mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-24 05:01:41 +08:00
feat: 训练时的 heatmap
This commit is contained in:
parent
22a2db464f
commit
c000b90794
@ -95,7 +95,7 @@ def train():
|
|||||||
for file in os.listdir('tiles2'):
|
for file in os.listdir('tiles2'):
|
||||||
name = os.path.splitext(file)[0]
|
name = os.path.splitext(file)[0]
|
||||||
tile_dict[name] = cv2.imread(f"tiles2/{file}", cv2.IMREAD_UNCHANGED)
|
tile_dict[name] = cv2.imread(f"tiles2/{file}", cv2.IMREAD_UNCHANGED)
|
||||||
|
|
||||||
# 接续训练
|
# 接续训练
|
||||||
if args.resume:
|
if args.resume:
|
||||||
data_ginka = torch.load(args.state_ginka, map_location=device)
|
data_ginka = torch.load(args.state_ginka, map_location=device)
|
||||||
@ -105,7 +105,7 @@ def train():
|
|||||||
if args.load_optim:
|
if args.load_optim:
|
||||||
if data_ginka.get("optim_state") is not None:
|
if data_ginka.get("optim_state") is not None:
|
||||||
optimizer.load_state_dict(data_ginka["optim_state"])
|
optimizer.load_state_dict(data_ginka["optim_state"])
|
||||||
|
|
||||||
print("Train from loaded state.")
|
print("Train from loaded state.")
|
||||||
|
|
||||||
for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm):
|
for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm):
|
||||||
@ -113,9 +113,14 @@ def train():
|
|||||||
|
|
||||||
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
||||||
target_map = batch["target_map"].to(device)
|
target_map = batch["target_map"].to(device)
|
||||||
cond = batch["val_cond"].to(device)
|
cond = batch["cond"].to(device)
|
||||||
|
heatmap = batch["heatmap"].to(device)
|
||||||
B, H, W = target_map.shape
|
B, H, W = target_map.shape
|
||||||
|
|
||||||
target_map = target_map.view(B, H * W)
|
target_map = target_map.view(B, H * W)
|
||||||
|
rand = torch.randn_like(heatmap).to(device) * 0.05
|
||||||
|
if random.random() > 0.5:
|
||||||
|
heatmap = heatmap + rand
|
||||||
|
|
||||||
mask = np.zeros((B, H * W))
|
mask = np.zeros((B, H * W))
|
||||||
for i in range(B):
|
for i in range(B):
|
||||||
@ -127,7 +132,7 @@ def train():
|
|||||||
masked_input = target_map.clone()
|
masked_input = target_map.clone()
|
||||||
masked_input[mask] = MASK_TOKEN # 填充为 [MASK] 标记
|
masked_input[mask] = MASK_TOKEN # 填充为 [MASK] 标记
|
||||||
|
|
||||||
logits = model(masked_input, cond)
|
logits = model(masked_input, cond, heatmap)
|
||||||
|
|
||||||
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1)
|
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1)
|
||||||
loss = (loss * mask).sum() / (mask.sum() + 1e-6)
|
loss = (loss * mask).sum() / (mask.sum() + 1e-6)
|
||||||
@ -164,7 +169,8 @@ def train():
|
|||||||
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
||||||
# 1. 常规生成
|
# 1. 常规生成
|
||||||
target_map = batch["target_map"].to(device)
|
target_map = batch["target_map"].to(device)
|
||||||
cond = batch["val_cond"].to(device)
|
cond = batch["cond"].to(device)
|
||||||
|
heatmap = batch["heatmap"].to(device)
|
||||||
B, H, W = target_map.shape
|
B, H, W = target_map.shape
|
||||||
target_map = target_map.view(B, H * W)
|
target_map = target_map.view(B, H * W)
|
||||||
|
|
||||||
@ -178,7 +184,7 @@ def train():
|
|||||||
masked_input = target_map.clone()
|
masked_input = target_map.clone()
|
||||||
masked_input[mask] = MASK_TOKEN # 填充为 [MASK] 标记
|
masked_input[mask] = MASK_TOKEN # 填充为 [MASK] 标记
|
||||||
|
|
||||||
logits = model(masked_input, cond)
|
logits = model(masked_input, cond, heatmap)
|
||||||
|
|
||||||
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1)
|
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1)
|
||||||
loss = (loss * mask.view(-1)).sum() / (mask.sum() + 1e-6)
|
loss = (loss * mask.view(-1)).sum() / (mask.sum() + 1e-6)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user