perf: 改进网络结构

This commit is contained in:
unanmed 2025-05-01 22:08:39 +08:00
parent a7d21260e4
commit 53041ab754
8 changed files with 81 additions and 71 deletions

View File

@ -7,28 +7,41 @@ class ConditionEncoder(nn.Module):
super().__init__() super().__init__()
self.tag_embed = nn.Linear(tag_dim, hidden_dim) self.tag_embed = nn.Linear(tag_dim, hidden_dim)
self.val_embed = nn.Linear(val_dim, hidden_dim) self.val_embed = nn.Linear(val_dim, hidden_dim)
self.encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4,
batch_first=True
),
num_layers=6
)
self.fusion = nn.Sequential( self.fusion = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim*2),
nn.LayerNorm(hidden_dim*2), nn.LayerNorm(hidden_dim*2),
nn.ELU(), nn.ELU(),
nn.Linear(hidden_dim*2, hidden_dim*4), nn.Linear(hidden_dim*2, out_dim)
nn.LayerNorm(hidden_dim*4),
nn.ELU(),
nn.Linear(hidden_dim*4, out_dim)
) )
def forward(self, tag, val): def forward(self, tag, val):
tag = self.tag_embed(tag) tag = self.tag_embed(tag)
val = self.val_embed(val) val = self.val_embed(val)
feat = torch.cat([tag, val], dim=1) feat = torch.stack([tag, val], dim=1)
feat = self.encoder(feat)
feat = torch.mean(feat, dim=1)
feat = self.fusion(feat) feat = self.fusion(feat)
return feat return feat
class ConditionInjector(nn.Module): class ConditionInjector(nn.Module):
def __init__(self, cond_dim, out_dim): def __init__(self, cond_dim, out_dim):
super().__init__() super().__init__()
self.fc = nn.Sequential( self.gamma_layer = nn.Sequential(
nn.Linear(cond_dim, cond_dim*2),
nn.LayerNorm(cond_dim*2),
nn.ELU(),
nn.Linear(cond_dim*2, out_dim)
)
self.beta_layer = nn.Sequential(
nn.Linear(cond_dim, cond_dim*2), nn.Linear(cond_dim, cond_dim*2),
nn.LayerNorm(cond_dim*2), nn.LayerNorm(cond_dim*2),
nn.ELU(), nn.ELU(),
@ -37,7 +50,6 @@ class ConditionInjector(nn.Module):
) )
def forward(self, x, cond): def forward(self, x, cond):
cond = self.fc(cond) gamma = self.gamma_layer(cond).unsqueeze(2).unsqueeze(3)
B, D = cond.shape beta = self.beta_layer(cond).unsqueeze(2).unsqueeze(3)
cond = cond.view(B, D, 1, 1) return x * gamma + beta
return x + cond

View File

@ -2,12 +2,12 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.utils import spectral_norm from torch.nn.utils import spectral_norm
from torch_geometric.nn import global_max_pool, GCNConv, global_mean_pool from torch_geometric.nn import global_max_pool, GCNConv
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
from shared.graph import batch_convert_soft_map_to_graph from shared.graph import batch_convert_soft_map_to_graph
from .vision import MinamoVisionModel from .vision import MinamoVisionModel
from .topo import MinamoTopoModel from .topo import MinamoTopoModel
from ..common.cond import ConditionEncoder, ConditionInjector from ..common.cond import ConditionEncoder
def print_memory(tag=""): def print_memory(tag=""):
print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB") print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
@ -24,7 +24,7 @@ class CNNHead(nn.Module):
self.fc = nn.Sequential( self.fc = nn.Sequential(
spectral_norm(nn.Linear(in_ch*2*2, 1)) spectral_norm(nn.Linear(in_ch*2*2, 1))
) )
self.proj = nn.Linear(256, in_ch*2*2) self.proj = spectral_norm(nn.Linear(256, in_ch*2*2))
def forward(self, x, cond): def forward(self, x, cond):
x = self.cnn(x) x = self.cnn(x)
@ -39,7 +39,7 @@ class GCNHead(nn.Module):
def __init__(self, in_dim): def __init__(self, in_dim):
super().__init__() super().__init__()
self.gcn = GCNConv(in_dim, in_dim) self.gcn = GCNConv(in_dim, in_dim)
self.proj = nn.Linear(256, in_dim) self.proj = spectral_norm(nn.Linear(256, in_dim))
self.fc = nn.Sequential( self.fc = nn.Sequential(
spectral_norm(nn.Linear(in_dim, 1)) spectral_norm(nn.Linear(in_dim, 1))
) )
@ -69,7 +69,7 @@ class MinamoModel(nn.Module):
super().__init__() super().__init__()
self.topo_model = MinamoTopoModel(tile_types) self.topo_model = MinamoTopoModel(tile_types)
self.vision_model = MinamoVisionModel(tile_types) self.vision_model = MinamoVisionModel(tile_types)
self.cond = ConditionEncoder(64, 16, 128, 256) self.cond = ConditionEncoder(64, 16, 256, 256)
# 输出层 # 输出层
self.head1 = MinamoScoreHead(512, 512) self.head1 = MinamoScoreHead(512, 512)
self.head2 = MinamoScoreHead(512, 512) self.head2 = MinamoScoreHead(512, 512)

View File

@ -51,7 +51,6 @@ def apply_curriculum_mask(
mask_ratio: float # 遮挡比例 0~1 mask_ratio: float # 遮挡比例 0~1
) -> torch.Tensor: ) -> torch.Tensor:
C, H, W = maps.shape C, H, W = maps.shape
device = maps.device
masked_maps = maps.clone() masked_maps = maps.clone()
# Step 1: 移除不需要的类别(全设为 0 类) # Step 1: 移除不需要的类别(全设为 0 类)

View File

@ -347,7 +347,7 @@ def immutable_penalty_loss(
return penalty return penalty
class WGANGinkaLoss: class WGANGinkaLoss:
def __init__(self, lambda_gp=100, weight=[1, 0.5, 10, 0.2, 0.2, 0.2]): def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2]):
# weight: 判别器损失CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失 # weight: 判别器损失CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失
self.lambda_gp = lambda_gp # 梯度惩罚系数 self.lambda_gp = lambda_gp # 梯度惩罚系数
self.weight = weight self.weight = weight

View File

@ -15,7 +15,7 @@ class GinkaModel(nn.Module):
""" """
super().__init__() super().__init__()
self.head = RandomInputHead() self.head = RandomInputHead()
self.cond = ConditionEncoder(64, 16, 128, 256) self.cond = ConditionEncoder(64, 16, 256, 256)
self.input = GinkaInput(32, 32, (13, 13), (32, 32)) self.input = GinkaInput(32, 32, (13, 13), (32, 32))
self.unet = GinkaUNet(32, base_ch, base_ch) self.unet = GinkaUNet(32, base_ch, base_ch)
self.output = GinkaOutput(base_ch, out_ch, (13, 13)) self.output = GinkaOutput(base_ch, out_ch, (13, 13))

View File

@ -10,7 +10,11 @@ class StageHead(nn.Module):
self.gcn_head = GCNBlock(in_ch, in_ch*2, in_ch, 32, 32) self.gcn_head = GCNBlock(in_ch, in_ch*2, in_ch, 32, 32)
self.fusion = DoubleConvBlock([in_ch*2, in_ch*4, in_ch]) self.fusion = DoubleConvBlock([in_ch*2, in_ch*4, in_ch])
self.pool = nn.Sequential( self.pool = nn.Sequential(
nn.Conv2d(in_ch, in_ch, 3, padding=1, padding_mode='replicate'), nn.Conv2d(in_ch, in_ch*2, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(in_ch*2),
nn.ELU(),
nn.Conv2d(in_ch*2, in_ch, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(in_ch), nn.InstanceNorm2d(in_ch),
nn.ELU(), nn.ELU(),

View File

@ -167,10 +167,6 @@ class GinkaUNet(nn.Module):
"""Ginka Model UNet 部分 """Ginka Model UNet 部分
""" """
super().__init__() super().__init__()
# self.input = GinkaTransformerEncoder(
# in_dim=feat_dim, hidden_dim=feat_dim*2, out_dim=2*32*32, # 自动除以 token_size
# token_size=4, ff_dim=feat_dim*2, num_layers=4
# )
self.down1 = ConvBlock(in_ch, base_ch) self.down1 = ConvBlock(in_ch, base_ch)
self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16) self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16)
self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8) self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8)

View File

@ -11,7 +11,6 @@ from tqdm import tqdm
from .generator.model import GinkaModel from .generator.model import GinkaModel
from .dataset import GinkaWGANDataset from .dataset import GinkaWGANDataset
from .generator.loss import WGANGinkaLoss from .generator.loss import WGANGinkaLoss
from .generator.input import RandomInputHead
from .critic.model import MinamoModel from .critic.model import MinamoModel
from shared.image import matrix_to_image_cv from shared.image import matrix_to_image_cv
@ -106,13 +105,12 @@ def train():
stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程 stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程
ginka = GinkaModel().to(device) ginka = GinkaModel().to(device)
ginka_head = RandomInputHead().to(device)
minamo = MinamoModel().to(device) minamo = MinamoModel().to(device)
dataset = GinkaWGANDataset(args.train, device) dataset = GinkaWGANDataset(args.train, device)
dataset_val = GinkaWGANDataset(args.validate, device) dataset_val = GinkaWGANDataset(args.validate, device)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True) dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9)) optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9)) optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
@ -270,47 +268,6 @@ def train():
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f}" f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f}"
) )
if avg_loss_ce < 0.5:
low_loss_epochs += 1
else:
low_loss_epochs = 0
# 训练流程控制
if train_stage >= 2:
train_stage += 1
if train_stage == 5:
train_stage = 2
if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch:
if mask_ratio >= 0.9:
train_stage = 2
mask_ratio += 0.2
mask_ratio = min(mask_ratio, 0.9)
low_loss_epochs = 0
stage_epoch = 0
stage_epoch += 1
dataset.train_stage = train_stage
dataset_val.train_stage = train_stage
dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio
# scheduler_ginka.step()
# scheduler_minamo.step()
if avg_dis < 0:
g_steps = max(int(-avg_dis * 5), 1)
else:
g_steps = 1
if avg_loss_minamo > 0:
c_steps = int(min(5 + avg_loss_minamo * 5, 15))
else:
c_steps = 5
# 每若干轮输出一次图片,并保存检查点 # 每若干轮输出一次图片,并保存检查点
if (epoch + 1) % args.checkpoint == 0: if (epoch + 1) % args.checkpoint == 0:
# 保存检查点 # 保存检查点
@ -344,8 +301,7 @@ def train():
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True) fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)
elif train_stage == 3 or train_stage == 4: elif train_stage == 3 or train_stage == 4:
input = masked1 if train_stage == 3 else F.softmax(ginka_head(masked1), dim=1) fake1, fake2, fake3, _ = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
fake1, fake2, fake3, _ = gen_total(ginka, input, tag_cond, val_cond, True, True, train_stage == 4)
fake1 = torch.argmax(fake1, dim=1).cpu().numpy() fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
fake2 = torch.argmax(fake2, dim=1).cpu().numpy() fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
@ -359,6 +315,49 @@ def train():
idx += 1 idx += 1
# 训练流程控制
if mask_ratio < 0.5 and avg_loss_ce < 0.2:
low_loss_epochs += 1
elif mask_ratio > 0.5 and avg_loss_ce < 0.3:
low_loss_epochs += 1
else:
low_loss_epochs = 0
if train_stage >= 2:
train_stage += 1
if train_stage == 5:
train_stage = 2
if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch:
if mask_ratio >= 0.9:
train_stage = 2
mask_ratio += 0.2
mask_ratio = min(mask_ratio, 0.9)
low_loss_epochs = 0
stage_epoch = 0
stage_epoch += 1
# scheduler_ginka.step()
# scheduler_minamo.step()
if avg_dis < 0:
g_steps = max(int(-avg_dis * 5), 1)
else:
g_steps = 1
if avg_loss_minamo > 0:
c_steps = int(min(5 + avg_loss_minamo * 5, 15))
else:
c_steps = 5
dataset.train_stage = train_stage
dataset_val.train_stage = train_stage
dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio
print("Train ended.") print("Train ended.")
torch.save({ torch.save({
"model_state": ginka.state_dict(), "model_state": ginka.state_dict(),