mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-24 21:31:10 +08:00
perf: 修改损失值计算方式
This commit is contained in:
parent
f6b1ad6ebd
commit
447c28ff5e
@ -31,8 +31,22 @@ def load_minamo_gan_data(data: list):
|
|||||||
res.append((one['map1'], one['map2'], one['visionSimilarity'], one['topoSimilarity'], True))
|
res.append((one['map1'], one['map2'], one['visionSimilarity'], one['topoSimilarity'], True))
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def apply_curriculum_remove(
|
||||||
|
maps: torch.Tensor,
|
||||||
|
remove_classes: List[int], # 要移除的类别索引
|
||||||
|
):
|
||||||
|
C, H, W = maps.shape
|
||||||
|
device = maps.device
|
||||||
|
removed_maps = maps.clone()
|
||||||
|
|
||||||
|
remove_mask = removed_maps[remove_classes, :, :].sum(dim=0, keepdim=True) > 0
|
||||||
|
removed_maps[:, :, :][remove_mask.expand(C, -1, -1)] = 0
|
||||||
|
removed_maps[0][remove_mask[0, :, :]] = 1 # 设置为“空地”
|
||||||
|
|
||||||
|
return removed_maps.to(device)
|
||||||
|
|
||||||
def apply_curriculum_mask(
|
def apply_curriculum_mask(
|
||||||
maps: torch.Tensor, # [B, C, H, W]
|
maps: torch.Tensor, # [C, H, W]
|
||||||
mask_classes: List[int], # 要遮挡的类别索引
|
mask_classes: List[int], # 要遮挡的类别索引
|
||||||
remove_classes: List[int], # 要移除的类别索引
|
remove_classes: List[int], # 要移除的类别索引
|
||||||
mask_ratio: float # 遮挡比例 0~1
|
mask_ratio: float # 遮挡比例 0~1
|
||||||
@ -74,24 +88,58 @@ class GinkaWGANDataset(Dataset):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
|
|
||||||
|
def handle_stage1(self, target):
|
||||||
|
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1)
|
||||||
|
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2)
|
||||||
|
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3)
|
||||||
|
|
||||||
|
return removed1, masked1, removed2, masked2, removed3, masked3
|
||||||
|
|
||||||
|
def handle_stage2(self, target):
|
||||||
|
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
|
||||||
|
# 后面两个阶段由于会保留一些类别,所以完全随机遮挡即可
|
||||||
|
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 1))
|
||||||
|
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 1))
|
||||||
|
|
||||||
|
if self.random_ratio > 0:
|
||||||
|
rd = random.uniform(0, self.random_ratio)
|
||||||
|
masked1 = random_smooth_onehot(masked1, min_main=1 - rd, max_main=1.0, epsilon=rd)
|
||||||
|
masked2 = random_smooth_onehot(masked2, min_main=1 - rd, max_main=1.0, epsilon=rd)
|
||||||
|
masked3 = random_smooth_onehot(masked3, min_main=1 - rd, max_main=1.0, epsilon=rd)
|
||||||
|
|
||||||
|
return removed1, masked1, removed2, masked2, removed3, masked3
|
||||||
|
|
||||||
|
def handle_stage3(self, target):
|
||||||
|
rd = random.uniform(0, self.random_ratio)
|
||||||
|
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
|
||||||
|
removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
|
||||||
|
removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
|
||||||
|
masked1 = random_smooth_onehot(masked1, min_main=1 - rd, max_main=1.0, epsilon=rd)
|
||||||
|
return removed1, masked1, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target)
|
||||||
|
|
||||||
|
def handle_stage4(self, target):
|
||||||
|
input1 = torch.rand((32, 13, 13))
|
||||||
|
removed1 = apply_curriculum_remove(target, STAGE1_REMOVE)
|
||||||
|
removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
|
||||||
|
removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
|
||||||
|
return removed1, input1, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
item = self.data[idx]
|
item = self.data[idx]
|
||||||
|
|
||||||
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||||
|
|
||||||
if self.train_stage == 1:
|
if self.train_stage == 1:
|
||||||
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1)
|
return self.handle_stage1(target)
|
||||||
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2)
|
|
||||||
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3)
|
|
||||||
elif self.train_stage == 2:
|
elif self.train_stage == 2:
|
||||||
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
|
return self.handle_stage2(target)
|
||||||
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 0.9))
|
|
||||||
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 0.9))
|
|
||||||
|
|
||||||
if self.random_ratio > 0:
|
elif self.train_stage == 3:
|
||||||
removed1 = random_smooth_onehot(removed1, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio)
|
return self.handle_stage3(target)
|
||||||
removed2 = random_smooth_onehot(removed2, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio)
|
|
||||||
removed3 = random_smooth_onehot(removed3, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio)
|
|
||||||
|
|
||||||
return removed1, masked1, removed2, masked2, removed3, masked3
|
elif self.train_stage == 4:
|
||||||
|
return self.handle_stage4(target)
|
||||||
|
|
||||||
|
raise RuntimeError(f"Invalid train stage: {self.train_stage}")
|
||||||
|
|
||||||
@ -327,7 +327,7 @@ def immutable_penalty_loss(
|
|||||||
target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float()
|
target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float()
|
||||||
|
|
||||||
# 差异区域(模型试图改变的地方)
|
# 差异区域(模型试图改变的地方)
|
||||||
penalty = F.cross_entropy(input_mask, target_mask)
|
penalty = F.l1_loss(input_mask, target_mask)
|
||||||
|
|
||||||
return penalty
|
return penalty
|
||||||
|
|
||||||
@ -405,13 +405,13 @@ class WGANGinkaLoss:
|
|||||||
|
|
||||||
fake_scores, _, _ = critic(fake, fake_graph, stage)
|
fake_scores, _, _ = critic(fake, fake_graph, stage)
|
||||||
minamo_loss = -torch.mean(fake_scores)
|
minamo_loss = -torch.mean(fake_scores)
|
||||||
ce_loss = F.cross_entropy(fake, real)
|
ce_loss = F.l1_loss(fake, real)
|
||||||
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
|
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
|
||||||
constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
|
constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
|
||||||
|
|
||||||
losses = [
|
losses = [
|
||||||
minamo_loss * self.weight[0],
|
minamo_loss * self.weight[0],
|
||||||
ce_loss * self.weight[1] / mask_ratio * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小
|
ce_loss * self.weight[1] * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小
|
||||||
immutable_loss * self.weight[2],
|
immutable_loss * self.weight[2],
|
||||||
constraint_loss * self.weight[3]
|
constraint_loss * self.weight[3]
|
||||||
]
|
]
|
||||||
@ -423,4 +423,25 @@ class WGANGinkaLoss:
|
|||||||
|
|
||||||
# print(losses[2].item())
|
# print(losses[2].item())
|
||||||
|
|
||||||
return sum(losses), minamo_loss, ce_loss / mask_ratio, immutable_loss
|
return sum(losses), minamo_loss, ce_loss, immutable_loss
|
||||||
|
|
||||||
|
def generator_loss_total(self, critic, stage, fake) -> torch.Tensor:
|
||||||
|
fake_graph = batch_convert_soft_map_to_graph(fake)
|
||||||
|
|
||||||
|
fake_scores, _, _ = critic(fake, fake_graph, stage)
|
||||||
|
minamo_loss = -torch.mean(fake_scores)
|
||||||
|
immutable_loss = immutable_penalty_loss(fake, fake, STAGE_ALLOWED[stage])
|
||||||
|
constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
|
||||||
|
|
||||||
|
losses = [
|
||||||
|
minamo_loss * self.weight[0],
|
||||||
|
immutable_loss * self.weight[2],
|
||||||
|
constraint_loss * self.weight[3]
|
||||||
|
]
|
||||||
|
|
||||||
|
if stage == 1:
|
||||||
|
# 第一个阶段检查入口存在性
|
||||||
|
entrance_loss = entrance_constraint_loss(fake)
|
||||||
|
losses.append(entrance_loss * self.weight[4])
|
||||||
|
|
||||||
|
return sum(losses)
|
||||||
|
|||||||
@ -46,14 +46,19 @@ def gen_curriculum(gen, masked1, masked2, masked3, detach=False) -> tuple[torch.
|
|||||||
else:
|
else:
|
||||||
return fake1, fake2, fake3
|
return fake1, fake2, fake3
|
||||||
|
|
||||||
def gen_total(gen, input, detach=False) -> torch.Tensor:
|
def gen_total(gen, input, progress_detach=True, result_detach=False) -> torch.Tensor:
|
||||||
fake1 = gen(input, 1)
|
if progress_detach:
|
||||||
fake2 = gen(fake1, 2)
|
fake1 = gen(input.detach(), 1)
|
||||||
fake3 = gen(fake2, 3)
|
fake2 = gen(fake1.detach(), 2)
|
||||||
if detach:
|
fake3 = gen(fake2.detach(), 3)
|
||||||
return fake3.detach()
|
|
||||||
else:
|
else:
|
||||||
return fake3
|
fake1 = gen(input, 1)
|
||||||
|
fake2 = gen(fake1, 2)
|
||||||
|
fake3 = gen(fake2, 3)
|
||||||
|
if result_detach:
|
||||||
|
return fake1.detach(), fake2.detach(), fake3.detach()
|
||||||
|
else:
|
||||||
|
return fake1, fake2, fake3
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
||||||
@ -67,6 +72,7 @@ def train():
|
|||||||
train_stage = 1
|
train_stage = 1
|
||||||
mask_ratio = 0.1 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练
|
mask_ratio = 0.1 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练
|
||||||
random_ratio = 0
|
random_ratio = 0
|
||||||
|
stage3_epoch = 0 # 第三阶段 epoch 数,100 轮后进入第四阶段
|
||||||
|
|
||||||
ginka = GinkaModel()
|
ginka = GinkaModel()
|
||||||
minamo = MinamoScoreModule()
|
minamo = MinamoScoreModule()
|
||||||
@ -109,6 +115,9 @@ def train():
|
|||||||
if data_ginka.get("random_ratio") is not None:
|
if data_ginka.get("random_ratio") is not None:
|
||||||
random_ratio = data_ginka["random_ratio"]
|
random_ratio = data_ginka["random_ratio"]
|
||||||
|
|
||||||
|
if data_ginka.get("stage_epoch3") is not None:
|
||||||
|
stage3_epoch = data_ginka["stage_epoch3"]
|
||||||
|
|
||||||
if data_ginka.get("stage") is not None:
|
if data_ginka.get("stage") is not None:
|
||||||
train_stage = data_ginka["stage"]
|
train_stage = data_ginka["stage"]
|
||||||
|
|
||||||
@ -151,17 +160,18 @@ def train():
|
|||||||
if train_stage == 1 or train_stage == 2:
|
if train_stage == 1 or train_stage == 2:
|
||||||
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
|
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
|
||||||
|
|
||||||
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1)
|
elif train_stage == 3 or train_stage == 4:
|
||||||
loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2)
|
fake1, fake2, fake3 = gen_total(ginka, masked1, True, True)
|
||||||
loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3)
|
|
||||||
|
|
||||||
dis_avg = (dis1 + dis2 + dis3) / 3.0
|
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1)
|
||||||
loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0
|
loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2)
|
||||||
|
loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3)
|
||||||
|
|
||||||
# 反向传播
|
dis_avg = (dis1 + dis2 + dis3) / 3.0
|
||||||
loss_d_avg.backward()
|
loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0
|
||||||
elif train_stage == 3:
|
|
||||||
pass
|
# 反向传播
|
||||||
|
loss_d_avg.backward()
|
||||||
|
|
||||||
optimizer_minamo.step()
|
optimizer_minamo.step()
|
||||||
|
|
||||||
@ -188,8 +198,17 @@ def train():
|
|||||||
loss_total_ginka += loss_g.detach()
|
loss_total_ginka += loss_g.detach()
|
||||||
loss_ce_total += loss_ce.detach()
|
loss_ce_total += loss_ce.detach()
|
||||||
|
|
||||||
elif train_stage == 3:
|
elif train_stage == 3 or train_stage == 4:
|
||||||
pass
|
fake1, fake2, fake3 = gen_total(ginka, masked1, True, False)
|
||||||
|
|
||||||
|
loss_g1 = criterion.generator_loss_total(minamo, 1, fake1)
|
||||||
|
loss_g2 = criterion.generator_loss_total(minamo, 2, fake2)
|
||||||
|
loss_g3 = criterion.generator_loss_total(minamo, 3, fake3)
|
||||||
|
|
||||||
|
loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
|
||||||
|
loss_g.backward()
|
||||||
|
optimizer_ginka.step()
|
||||||
|
loss_total_ginka += loss_g.detach()
|
||||||
|
|
||||||
avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
|
avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
|
||||||
avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
|
avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
|
||||||
@ -202,12 +221,14 @@ def train():
|
|||||||
f"CE: {avg_loss_ce:.8f} | Mask: {mask_ratio:.2f}"
|
f"CE: {avg_loss_ce:.8f} | Mask: {mask_ratio:.2f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if avg_loss_ce < 0.5:
|
if avg_loss_ce < 0.1:
|
||||||
low_loss_epochs += 1
|
low_loss_epochs += 1
|
||||||
else:
|
else:
|
||||||
low_loss_epochs = 0
|
low_loss_epochs = 0
|
||||||
|
|
||||||
if low_loss_epochs >= 5 and train_stage == 2:
|
if low_loss_epochs >= 5 and train_stage == 2:
|
||||||
|
if random_ratio >= 0.5:
|
||||||
|
train_stage = 3
|
||||||
random_ratio += 0.1
|
random_ratio += 0.1
|
||||||
random_ratio = min(random_ratio, 0.5)
|
random_ratio = min(random_ratio, 0.5)
|
||||||
low_loss_epochs = 0
|
low_loss_epochs = 0
|
||||||
@ -215,11 +236,20 @@ def train():
|
|||||||
if low_loss_epochs >= 5 and train_stage == 1:
|
if low_loss_epochs >= 5 and train_stage == 1:
|
||||||
if mask_ratio >= 0.9:
|
if mask_ratio >= 0.9:
|
||||||
train_stage = 2
|
train_stage = 2
|
||||||
|
|
||||||
mask_ratio += 0.1
|
mask_ratio += 0.1
|
||||||
mask_ratio = min(mask_ratio, 0.9)
|
mask_ratio = min(mask_ratio, 0.9)
|
||||||
low_loss_epochs = 0
|
low_loss_epochs = 0
|
||||||
|
|
||||||
|
if train_stage == 3:
|
||||||
|
stage3_epoch += 1
|
||||||
|
if stage3_epoch >= 100:
|
||||||
|
train_stage = 4
|
||||||
|
stage3_epoch = 0
|
||||||
|
|
||||||
|
if train_stage >= 2:
|
||||||
|
# 第二阶段后 L1 损失不再应该生效
|
||||||
|
mask_ratio = 1.0
|
||||||
|
|
||||||
dataset.train_stage = 2
|
dataset.train_stage = 2
|
||||||
dataset_val.train_stage = 2
|
dataset_val.train_stage = 2
|
||||||
dataset.random_ratio = random_ratio
|
dataset.random_ratio = random_ratio
|
||||||
@ -235,8 +265,8 @@ def train():
|
|||||||
else:
|
else:
|
||||||
g_steps = 1
|
g_steps = 1
|
||||||
|
|
||||||
if avg_loss_ginka > 0 or avg_loss_minamo > 0:
|
if avg_loss_minamo > 0:
|
||||||
c_steps = int(max(min(5 + (avg_loss_ginka + avg_loss_minamo) * 5, 15), 1))
|
c_steps = int(min(5 + avg_loss_minamo * 5, 15))
|
||||||
else:
|
else:
|
||||||
c_steps = 5
|
c_steps = 5
|
||||||
|
|
||||||
@ -251,6 +281,7 @@ def train():
|
|||||||
"stage": train_stage,
|
"stage": train_stage,
|
||||||
"mask_ratio": mask_ratio,
|
"mask_ratio": mask_ratio,
|
||||||
"random_ratio": random_ratio,
|
"random_ratio": random_ratio,
|
||||||
|
"stage3_epoch": stage3_epoch,
|
||||||
}, f"result/wgan/ginka-{epoch + 1}.pth")
|
}, f"result/wgan/ginka-{epoch + 1}.pth")
|
||||||
torch.save({
|
torch.save({
|
||||||
"model_state": minamo.state_dict(),
|
"model_state": minamo.state_dict(),
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user