feat: 训练时的 heatmap

This commit is contained in:
unanmed 2026-03-11 16:33:15 +08:00
parent 22a2db464f
commit c000b90794

View File

@ -113,9 +113,14 @@ def train():
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
target_map = batch["target_map"].to(device)
cond = batch["val_cond"].to(device)
cond = batch["cond"].to(device)
heatmap = batch["heatmap"].to(device)
B, H, W = target_map.shape
target_map = target_map.view(B, H * W)
rand = torch.randn_like(heatmap).to(device) * 0.05
if random.random() > 0.5:
heatmap = heatmap + rand
mask = np.zeros((B, H * W))
for i in range(B):
@ -127,7 +132,7 @@ def train():
masked_input = target_map.clone()
masked_input[mask] = MASK_TOKEN # 填充为 [MASK] 标记
logits = model(masked_input, cond)
logits = model(masked_input, cond, heatmap)
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1)
loss = (loss * mask).sum() / (mask.sum() + 1e-6)
@ -164,7 +169,8 @@ def train():
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
# 1. 常规生成
target_map = batch["target_map"].to(device)
cond = batch["val_cond"].to(device)
cond = batch["cond"].to(device)
heatmap = batch["heatmap"].to(device)
B, H, W = target_map.shape
target_map = target_map.view(B, H * W)
@ -178,7 +184,7 @@ def train():
masked_input = target_map.clone()
masked_input[mask] = MASK_TOKEN # 填充为 [MASK] 标记
logits = model(masked_input, cond)
logits = model(masked_input, cond, heatmap)
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1)
loss = (loss * mask.view(-1)).sum() / (mask.sum() + 1e-6)