From 8f087fe13840b99ff0c7ef42fbe3ab26ce9b1cfa Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Fri, 15 May 2026 23:49:50 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E5=88=A0=E9=99=A4=E6=97=A7=E8=AE=AD?= =?UTF-8?q?=E7=BB=83=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_pretrain_split.py | 380 ---------------- ginka/train_stage.py | 684 ---------------------------- ginka/train_vq.py | 812 ---------------------------------- 3 files changed, 1876 deletions(-) delete mode 100644 ginka/train_pretrain_split.py delete mode 100644 ginka/train_stage.py delete mode 100644 ginka/train_vq.py diff --git a/ginka/train_pretrain_split.py b/ginka/train_pretrain_split.py deleted file mode 100644 index 25b3871..0000000 --- a/ginka/train_pretrain_split.py +++ /dev/null @@ -1,380 +0,0 @@ -""" -三通道分拆预训练脚本(方案 B) - -三路编码器各自负责一个语义通道: - 通道 1:空间骨架(floor+wall),损失仅计算 wall(1) 位置 - 通道 2:关卡门控(floor+wall+door+mob+entrance),损失仅计算 {2,9,10} 位置 - 通道 3:收集资源(完整地图),损失仅计算 {3,4,5,6,7,8} 位置 - -预训练完成后保存各通道编码器权重(不含解码头), -供联合训练脚本 train_vq.py 加载并拼接 z。 - -用法示例: - python -m ginka.train_pretrain_split - python -m ginka.train_pretrain_split --resume True --state result/pretrain_split/split-10.pth - # 预训练完成后指定权重路径启动联合训练: - python -m ginka.train_vq --pretrain_split result/pretrain_split/split_final.pth -""" - -import argparse -import os -import sys -from datetime import datetime - -import numpy as np -import torch -import torch.optim as optim -from torch.utils.data import DataLoader -from tqdm import tqdm - -from .vqvae.model import GinkaVQVAE, VQDecodeHead -from .dataset import GinkaSplitDataset -from .utils import masked_focal - -# --------------------------------------------------------------------------- -# 超参数 -# --------------------------------------------------------------------------- -BATCH_SIZE = 64 -NUM_CLASSES = 7 -MAP_SIZE = 13 * 13 -FOCAL_GAMMA = 1.0 - -# 通道 1:空间骨架(floor+wall) -CH1_KEEP = {0, 1} # 编码器输入保留的 tile -CH1_LOSS = {0, 1} # 损失计算范围(仅 wall) -CH1_D_MODEL = 64 -CH1_NHEAD = 8 - -# 通道 2:关卡门控 -CH2_KEEP = {0, 1, 2, 4, 5} -CH2_LOSS = {2, 4, 5} -CH2_D_MODEL = 64 -CH2_NHEAD = 8 - -# 通道 3:收集资源 -CH3_KEEP = None # 完整地图,无需切片 -CH3_LOSS = {3} -CH3_D_MODEL = 64 -CH3_NHEAD = 8 - -# 三路共用的 VQ 超参 -VQ_L = 2 -VQ_K = 8 -VQ_D_Z = 64 -VQ_LAYERS = 3 -VQ_DIM_FF = 512 -VQ_BETA = 0.5 # commit loss 权重 -VQ_GAMMA = 0.0 # entropy loss 权重 - -# 解码头超参(三路共用相同规格) -DH_NHEAD = 8 -DH_DIM_FF = 512 -DH_LAYERS = 3 - -# --------------------------------------------------------------------------- -# 设备 -# --------------------------------------------------------------------------- -device = torch.device( - "cuda:1" if torch.cuda.is_available() - else "mps" if torch.backends.mps.is_available() - else "cpu" -) - -os.makedirs("result/pretrain_split", exist_ok=True) - -disable_tqdm = not sys.stdout.isatty() - -# --------------------------------------------------------------------------- -# 参数解析 -# --------------------------------------------------------------------------- -def _str2bool(v: str) -> bool: - """argparse 专用:将字符串 'True'/'False' 正确转为 bool。 - type=bool 会把任何非空字符串(包括 'False')解析为 True,故需此辅助。""" - if isinstance(v, bool): - return v - if v.lower() in ('true', '1', 'yes'): - return True - if v.lower() in ('false', '0', 'no'): - return False - raise argparse.ArgumentTypeError(f"布尔值应为 True/False,收到: {v!r}") - -def parse_arguments(): - parser = argparse.ArgumentParser(description="三通道分拆 VQ 编码器预训练(方案 B)") - parser.add_argument("--resume", type=_str2bool, default=False) - parser.add_argument("--state", type=str, default="result/pretrain_split/split-10.pth", - help="续训时加载的检查点路径") - parser.add_argument("--train", type=str, default="ginka-dataset.json") - parser.add_argument("--validate", type=str, default="ginka-eval.json") - parser.add_argument("--epochs", type=int, default=60) - parser.add_argument("--checkpoint", type=int, default=5, - help="每隔多少 epoch 保存检查点并输出验证指标") - parser.add_argument("--load_optim", type=_str2bool, default=True) - return parser.parse_args() - -# --------------------------------------------------------------------------- -# 验证:各通道专属 tile 召回率 + codebook 使用熵 -# --------------------------------------------------------------------------- -@torch.no_grad() -def validate( - enc1, enc2, enc3, - head1, head2, head3, - dataloader_val: DataLoader, -) -> dict: - for m in [enc1, enc2, enc3, head1, head2, head3]: - m.eval() - - # 每类 tile 的 tp / gt 计数 - ch1_tp, ch1_gt = 0, 0 # wall(1) - ch2_tp = {t: 0 for t in CH2_LOSS} # {2,4,5} - ch2_gt = {t: 0 for t in CH2_LOSS} - ch3_tp = {t: 0 for t in CH3_LOSS} # {3,4,5} - ch3_gt = {t: 0 for t in CH3_LOSS} - - # codebook 使用频次(用于熵估算) - codebook_counts = [ - torch.zeros(VQ_K, dtype=torch.long), # 通道 1 - torch.zeros(VQ_K, dtype=torch.long), # 通道 2 - torch.zeros(VQ_K, dtype=torch.long), # 通道 3 - ] - - for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm): - raw_map = batch["raw_map"].to(device) - s1 = batch["slice1"].to(device) - s2 = batch["slice2"].to(device) - s3 = batch["slice3"].to(device) - - # 通道 1 - z_q1, _, idx1, _, _, _ = enc1(s1) - logits1 = head1(z_q1, torch.zeros_like(raw_map)) - pred1 = logits1.argmax(dim=-1) # [B, H*W] - wall_m = (raw_map == 1) - ch1_tp += (pred1[wall_m] == 1).sum().item() - ch1_gt += wall_m.sum().item() - for code in idx1.view(-1).cpu(): - codebook_counts[0][code] += 1 - - # 通道 2 - z_q2, _, idx2, _, _, _ = enc2(s2) - logits2 = head2(z_q2, s1) - pred2 = logits2.argmax(dim=-1) - for t in CH2_LOSS: - m = (raw_map == t) - ch2_tp[t] += (pred2[m] == t).sum().item() - ch2_gt[t] += m.sum().item() - for code in idx2.view(-1).cpu(): - codebook_counts[1][code] += 1 - - # 通道 3 - z_q3, _, idx3, _, _, _ = enc3(s3) - logits3 = head3(z_q3, s2) - pred3 = logits3.argmax(dim=-1) - for t in CH3_LOSS: - m = (raw_map == t) - ch3_tp[t] += (pred3[m] == t).sum().item() - ch3_gt[t] += m.sum().item() - for code in idx3.view(-1).cpu(): - codebook_counts[2][code] += 1 - - def _entropy(counts): - """估算 codebook 使用熵(bits)。""" - counts = counts.float() + 1e-8 - p = counts / counts.sum() - return float(-(p * torch.log2(p)).sum()) - - metrics = { - "ch1_wall_recall": ch1_tp / max(ch1_gt, 1), - "ch2_recall": {t: ch2_tp[t] / max(ch2_gt[t], 1) for t in CH2_LOSS}, - "ch3_recall": {t: ch3_tp[t] / max(ch3_gt[t], 1) for t in CH3_LOSS}, - "codebook_entropy": [_entropy(c) for c in codebook_counts], - } - return metrics - -# --------------------------------------------------------------------------- -# 主训练函数 -# --------------------------------------------------------------------------- -def train(): - print(f"Using device: {device}") - args = parse_arguments() - - # ---- 三路编码器 ---- - enc1 = GinkaVQVAE( - num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_z=VQ_D_Z, - d_model=CH1_D_MODEL, nhead=CH1_NHEAD, num_layers=VQ_LAYERS, - dim_ff=VQ_DIM_FF, beta=VQ_BETA, gamma=VQ_GAMMA, - ).to(device) - - enc2 = GinkaVQVAE( - num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_z=VQ_D_Z, - d_model=CH2_D_MODEL, nhead=CH2_NHEAD, num_layers=VQ_LAYERS, - dim_ff=VQ_DIM_FF, beta=VQ_BETA, gamma=VQ_GAMMA, - ).to(device) - - enc3 = GinkaVQVAE( - num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_z=VQ_D_Z, - d_model=CH3_D_MODEL, nhead=CH3_NHEAD, num_layers=VQ_LAYERS, - dim_ff=VQ_DIM_FF, beta=VQ_BETA, gamma=VQ_GAMMA, - ).to(device) - - # ---- 三路解码头(预训练专用,训练后丢弃)---- - head1 = VQDecodeHead( - num_classes=NUM_CLASSES, d_z=VQ_D_Z, map_size=MAP_SIZE, - nhead=DH_NHEAD, dim_ff=DH_DIM_FF, num_layers=DH_LAYERS, - ).to(device) - - head2 = VQDecodeHead( - num_classes=NUM_CLASSES, d_z=VQ_D_Z, map_size=MAP_SIZE, - nhead=DH_NHEAD, dim_ff=DH_DIM_FF, num_layers=DH_LAYERS, - ).to(device) - - head3 = VQDecodeHead( - num_classes=NUM_CLASSES, d_z=VQ_D_Z, map_size=MAP_SIZE, - nhead=DH_NHEAD, dim_ff=DH_DIM_FF, num_layers=DH_LAYERS, - ).to(device) - - # ---- 优化器(三路同步训练) ---- - optimizer = optim.AdamW( - list(enc1.parameters()) + list(enc2.parameters()) + list(enc3.parameters()) + - list(head1.parameters()) + list(head2.parameters()) + list(head3.parameters()), - lr=1e-3, - weight_decay=1e-4, - ) - - start_epoch = 0 - - # ---- 续训 ---- - if args.resume: - ckpt = torch.load(args.state, map_location=device) - enc1.load_state_dict(ckpt["enc1"]) - enc2.load_state_dict(ckpt["enc2"]) - enc3.load_state_dict(ckpt["enc3"]) - head1.load_state_dict(ckpt["head1"]) - head2.load_state_dict(ckpt["head2"]) - head3.load_state_dict(ckpt["head3"]) - if args.load_optim and "optimizer" in ckpt: - optimizer.load_state_dict(ckpt["optimizer"]) - start_epoch = ckpt.get("epoch", 0) - print(f"Resumed from epoch {start_epoch}: {args.state}") - - # ---- 数据集 ---- - ds_train = GinkaSplitDataset(args.train) - ds_val = GinkaSplitDataset(args.validate) - dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True) - dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True) - - print(f"训练集大小: {len(ds_train)},验证集大小: {len(ds_val)}") - - total_params = ( - sum(p.numel() for p in enc1.parameters()) + - sum(p.numel() for p in enc2.parameters()) + - sum(p.numel() for p in enc3.parameters()) - ) - print(f"编码器总参数量(三路): {total_params:,} ({total_params / 1e6:.3f}M)") - total_params = ( - sum(p.numel() for p in head1.parameters()) + - sum(p.numel() for p in head2.parameters()) + - sum(p.numel() for p in head3.parameters()) - ) - print(f"解码器总参数量(三路): {total_params:,} ({total_params / 1e6:.3f}M)") - - # ---- 训练循环 ---- - for epoch in range(start_epoch, args.epochs): - for m in [enc1, enc2, enc3, head1, head2, head3]: - m.train() - - total_loss = 0.0 - ch_losses = [0.0, 0.0, 0.0] - - for batch in tqdm(dl_train, desc=f"Epoch {epoch + 1}/{args.epochs}", disable=disable_tqdm): - raw_map = batch["raw_map"].to(device) - s1 = batch["slice1"].to(device) - s2 = batch["slice2"].to(device) - s3 = batch["slice3"].to(device) - - optimizer.zero_grad() - - # ─── 通道 1 ─── - z_q1, _, _, vq_loss1, commit_loss1, entropy_loss1 = enc1(s1) - logits1 = head1(z_q1, torch.zeros_like(raw_map)) # [B, H*W, C] - fl1 = masked_focal(logits1, raw_map, CH1_LOSS, gamma=FOCAL_GAMMA) - loss1 = fl1 + VQ_BETA * commit_loss1 + VQ_GAMMA * entropy_loss1 - - # ─── 通道 2 ─── - z_q2, _, _, vq_loss2, commit_loss2, entropy_loss2 = enc2(s2) - logits2 = head2(z_q2, s1) - fl2 = masked_focal(logits2, raw_map, CH2_LOSS, gamma=FOCAL_GAMMA) - loss2 = fl2 + VQ_BETA * commit_loss2 + VQ_GAMMA * entropy_loss2 - - # ─── 通道 3 ─── - z_q3, _, _, vq_loss3, commit_loss3, entropy_loss3 = enc3(s3) - logits3 = head3(z_q3, s2) - fl3 = masked_focal(logits3, raw_map, CH3_LOSS, gamma=FOCAL_GAMMA) - loss3 = fl3 + VQ_BETA * commit_loss3 + VQ_GAMMA * entropy_loss3 - - loss = loss1 + loss2 + loss3 - loss.backward() - torch.nn.utils.clip_grad_norm_( - list(enc1.parameters()) + list(enc2.parameters()) + list(enc3.parameters()) + - list(head1.parameters()) + list(head2.parameters()) + list(head3.parameters()), - max_norm=1.0, - ) - optimizer.step() - - total_loss += loss.item() - ch_losses[0] += loss1.item() - ch_losses[1] += loss2.item() - ch_losses[2] += loss3.item() - - n_batches = len(dl_train) - print( - f"[{epoch + 1:03d}] total={total_loss / n_batches:.4f} " - f"ch1={ch_losses[0] / n_batches:.4f} " - f"ch2={ch_losses[1] / n_batches:.4f} " - f"ch3={ch_losses[2] / n_batches:.4f}" - ) - - # ---- 检查点 & 验证 ---- - if (epoch + 1) % args.checkpoint == 0 or epoch + 1 == args.epochs: - metrics = validate(enc1, enc2, enc3, head1, head2, head3, dl_val) - print( - f" 验证 ch1_wall_recall={metrics['ch1_wall_recall']:.3f} " - f"ch2_recall={metrics['ch2_recall']} " - f"ch3_recall={metrics['ch3_recall']}" - ) - print( - f" codebook_entropy ch1={metrics['codebook_entropy'][0]:.3f} " - f"ch2={metrics['codebook_entropy'][1]:.3f} " - f"ch3={metrics['codebook_entropy'][2]:.3f}" - ) - - ts = datetime.now().strftime("%m%d-%H%M") - ckpt_path = f"result/pretrain_split/split-{epoch + 1}.pth" - torch.save({ - "epoch": epoch + 1, - "enc1": enc1.state_dict(), - "enc2": enc2.state_dict(), - "enc3": enc3.state_dict(), - "head1": head1.state_dict(), - "head2": head2.state_dict(), - "head3": head3.state_dict(), - "optimizer": optimizer.state_dict(), - "metrics": metrics, - "ts": ts, - }, ckpt_path) - print(f" Saved checkpoint: {ckpt_path}") - - # ---- 保存最终编码器权重(供联合训练加载) ---- - final_path = "result/pretrain_split/split_final.pth" - torch.save({ - "epoch": args.epochs, - "enc1": enc1.state_dict(), - "enc2": enc2.state_dict(), - "enc3": enc3.state_dict(), - # 解码头不迁移,不保存 - }, final_path) - print(f"\n预训练完成,编码器权重已保存至: {final_path}") - print("接下来运行联合训练(阶段 1 冻结热身):") - print(f" python -m ginka.train_vq --pretrain_split {final_path} --freeze_vq True") - - -if __name__ == "__main__": - train() diff --git a/ginka/train_stage.py b/ginka/train_stage.py deleted file mode 100644 index d1c7e4c..0000000 --- a/ginka/train_stage.py +++ /dev/null @@ -1,684 +0,0 @@ -""" -三阶段级联训练脚本:各阶段独立训练,使用 GinkaStageDataset。 - -总损失 = L_CE(只对本阶段负责的 tile 位置计算)+ beta * L_commit + gamma * L_entropy - -各阶段分工: - stage=1 结构骨架:floor(0) + wall(1) - stage=2 功能元素:door(2) + monster(4) + entrance(5) - stage=3 资源放置:resource(3) - -用法示例: - python -m ginka.train_stage --stage 0 # 三阶段联合训练(推荐) - python -m ginka.train_stage --stage 1 # 只训练 stage1 - python -m ginka.train_stage --stage 0 --resume True --state result/joint-50.pth - python -m ginka.train_stage --stage 0 --pretrain_vq result/joint/joint-50.pth -""" - -import argparse -import math -import os -import sys -import random -from datetime import datetime - -import cv2 -import numpy as np -import torch -import torch.nn.functional as F -import torch.optim as optim -from tqdm import tqdm -from torch.utils.data import DataLoader - -from .vqvae.model import GinkaVQVAE -from .maskGIT.model import GinkaMaskGIT -from .dataset import GinkaStageDataset -from shared.image import matrix_to_image_cv - -# --------------------------------------------------------------------------- -# 各阶段配置 -# --------------------------------------------------------------------------- - -# 共用 VQ-VAE 超参 -VQ_L = 2 # 码字序列长度 -VQ_K = 8 # codebook 大小 -VQ_D_Z = 64 # 码字维度 -VQ_BETA = 0.5 # commit loss 权重 -VQ_GAMMA = 0.0 # entropy loss 权重 -VQ_LAYERS = 3 -VQ_DIM_FF = 512 -VQ_D_MODEL = 64 -VQ_NHEAD = 8 - -# 各阶段 MaskGIT 超参(按任务复杂度差异化配置) -STAGE_MG_CONFIGS = { - 1: dict(d_model=256, nhead=8, num_layers=6, dim_ff=2048), # 结构骨架,最重要 - 2: dict(d_model=192, nhead=8, num_layers=4, dim_ff=1024), # 功能元素 - 3: dict(d_model=128, nhead=8, num_layers=3, dim_ff=512), # 资源放置,最简单 -} - -# 各阶段监控的 tile 集合(用于分类别召回率统计) -STAGE_TILE_SETS = { - 1: {0: "floor", 1: "wall"}, - 2: {2: "door", 4: "monster", 5: "entrance"}, - 3: {3: "resource"}, -} - -# 各阶段损失权重(可单独调节 CE 与 VQ 损失的平衡) -# stage3 的 resource 极稀疏,大幅上调 ce_weight 以补偿类别不均衡 -STAGE_LOSS_CONFIG = { - 1: dict(ce_weight=1.0, vq_weight=1.0), # 结构骨架,标准权重 - 2: dict(ce_weight=1.5, vq_weight=0.5), # 功能元素较稀疏,上调 CE - 3: dict(ce_weight=3.0, vq_weight=0.5), # resource 极稀疏,显著上调 CE -} - -NUM_CLASSES = 7 -MASK_TOKEN = 6 -MAP_SIZE = 13 * 13 -MAP_H = MAP_W = 13 -FOCAL_GAMMA = 2.0 -GENERATE_STEP = 18 -BATCH_SIZE = 64 -WALL_MASK_RATIO = 0.8 - -MG_Z_DROPOUT = 0.1 -MG_STRUCT_DROPOUT = 0.1 - -SUBSET_WEIGHTS = (0.5, 0.2, 0.2, 0.1) - -device = torch.device( - "cuda:1" if torch.cuda.is_available() - else "mps" if torch.backends.mps.is_available() - else "cpu" -) - -disable_tqdm = not sys.stdout.isatty() - -# --------------------------------------------------------------------------- -# 参数解析 -# --------------------------------------------------------------------------- - -def _str2bool(v): - if isinstance(v, bool): return v - if v.lower() in ('true', '1', 'yes'): return True - if v.lower() in ('false', '0', 'no'): return False - raise argparse.ArgumentTypeError(f"布尔值应为 True/False,收到: {v!r}") - - -def parse_arguments(): - parser = argparse.ArgumentParser(description="三阶段级联训练") - parser.add_argument( - "--stage", type=int, required=True, choices=[0, 1, 2, 3], - help="训练阶段:1/2/3 单独训练,0 = 依次训练全部三个阶段", - ) - parser.add_argument("--resume", type=_str2bool, default=False) - parser.add_argument( - "--state", type=str, default="", - help="续训时加载的检查点路径(自动推断 stage{N}/stage{N}-*.pth)", - ) - parser.add_argument("--train", type=str, default="ginka-dataset.json") - parser.add_argument("--validate", type=str, default="ginka-eval.json") - parser.add_argument("--epochs", type=int, default=100) - parser.add_argument("--checkpoint", type=int, default=5) - parser.add_argument("--load_optim", type=_str2bool, default=True) - parser.add_argument( - "--freeze_vq", type=_str2bool, default=False, - help="冻结 VQ 编码器,只训练 MaskGIT(适合加载预训练编码器后热身)", - ) - parser.add_argument( - "--pretrain_vq", type=str, default="", - help="从 train_vq.py 的联合训练检查点中导入对应通道的 VQ 编码器权重", - ) - return parser.parse_args() - -# --------------------------------------------------------------------------- -# Focal Loss(与 train_vq.py 一致) -# --------------------------------------------------------------------------- - -def focal_loss(logits, targets, gamma=FOCAL_GAMMA, reduction='none'): - ce = F.cross_entropy(logits, targets, reduction='none') - pt = torch.exp(-ce) - fl = (1.0 - pt) ** gamma * ce - if reduction == 'mean': return fl.mean() - if reduction == 'sum': return fl.sum() - return fl - - -def masked_focal_loss(logits, targets, loss_mask, gamma=FOCAL_GAMMA): - """ - 只对 loss_mask 为 True 的位置计算 focal loss 均值。 - - Args: - logits: [B, C, H*W] - targets: [B, H*W] - loss_mask: [B, H*W] bool - """ - per_token = focal_loss(logits, targets, gamma, reduction='none') # [B, H*W] - selected = per_token[loss_mask] - if selected.numel() == 0: - return per_token.mean() - return selected.mean() - -# --------------------------------------------------------------------------- -# MaskGIT 推理(cosine schedule) -# --------------------------------------------------------------------------- - -@torch.no_grad() -def maskgit_generate( - model_mg: GinkaMaskGIT, - z: torch.Tensor, - steps: int = GENERATE_STEP, - init_map: torch.Tensor = None, - struct_cond: torch.Tensor = None, -) -> torch.Tensor: - """ - 迭代生成地图(cosine schedule unmasking)。 - - Args: - init_map: 可选初始地图;非 MASK 位置在生成中保持不变。 - - Returns: - [B, MAP_SIZE] LongTensor - """ - B = z.shape[0] - map_seq = ( - torch.full((B, MAP_SIZE), MASK_TOKEN, device=device) - if init_map is None else init_map.clone().to(device) - ) - - generatable = (map_seq == MASK_TOKEN) - - for step in range(steps): - if not generatable.any(): - break - - logits = model_mg(map_seq, z, struct_cond=struct_cond) # [B, S, C] - probs = F.softmax(logits, dim=-1) - dist = torch.distributions.Categorical(probs) - sampled = dist.sample() - confidences = torch.gather(probs, -1, sampled.unsqueeze(-1)).squeeze(-1) - confidences = confidences.masked_fill(~generatable, float('inf')) - - ratio = math.cos(((step + 1) / steps) * math.pi / 2) - new_map = map_seq.clone() - - for b in range(B): - n_gen = int(generatable[b].sum().item()) - n_keep = int(ratio * n_gen) - if n_keep > 0: - _, keep_idx = torch.topk(confidences[b], k=n_keep, largest=False) - pred_b = sampled[b].clone() - pred_b[keep_idx] = MASK_TOKEN - new_map[b] = torch.where(generatable[b], pred_b, map_seq[b]) - else: - new_map[b] = torch.where(generatable[b], sampled[b], map_seq[b]) - - map_seq = new_map - - return map_seq - -# --------------------------------------------------------------------------- -# 可视化工具(与 train_vq.py 保持一致) -# --------------------------------------------------------------------------- - -def make_map_image(map_flat, tile_dict): - arr = map_flat.cpu().numpy().reshape(MAP_H, MAP_W) - return matrix_to_image_cv(arr, tile_dict) - - -def hstack_images(imgs, gap=4, color=(255, 255, 255)): - max_h = max(img.shape[0] for img in imgs) - - def _pad(img): - dh = max_h - img.shape[0] - return img if dh == 0 else np.concatenate( - [img, np.full((dh, img.shape[1], 3), color, dtype=np.uint8)], axis=0) - - vline = np.full((max_h, gap, 3), color, dtype=np.uint8) - result = _pad(imgs[0]) - for img in imgs[1:]: - result = np.concatenate([result, vline, _pad(img)], axis=1) - return result - - -def grid_images(imgs, gap=4, bg=(255, 255, 255)): - n = len(imgs) - if n == 0: return np.zeros((1, 1, 3), dtype=np.uint8) - if n == 1: return imgs[0] - mid = math.ceil(n / 2) - top = hstack_images(imgs[:mid], gap, bg) - bot_imgs = imgs[mid:] - if not bot_imgs: return top - bot = hstack_images(bot_imgs, gap, bg) - tw, bw = top.shape[1], bot.shape[1] - if tw > bw: - bot = np.concatenate( - [bot, np.full((bot.shape[0], tw - bw, 3), bg, dtype=np.uint8)], axis=1) - elif bw > tw: - top = np.concatenate( - [top, np.full((top.shape[0], bw - tw, 3), bg, dtype=np.uint8)], axis=1) - hline = np.full((gap, top.shape[1], 3), bg, dtype=np.uint8) - return np.concatenate([top, hline, bot], axis=0) - - -def label_image(img, text, font_scale=0.45): - bar = np.full((16, img.shape[1], 3), (40, 40, 40), dtype=np.uint8) - cv2.putText( - bar, text, (2, 13), cv2.FONT_HERSHEY_SIMPLEX, - font_scale, (200, 200, 200), 1, cv2.LINE_AA, - ) - return np.concatenate([bar, img], axis=0) - - -def make_random_struct_cond(): - from .maskGIT.model import SYM_VOCAB, ROOM_VOCAB, BRANCH_VOCAB, OUTER_VOCAB - return torch.tensor([[ - random.randint(0, SYM_VOCAB - 2), - random.randint(0, ROOM_VOCAB - 2), - random.randint(0, BRANCH_VOCAB - 2), - random.randint(0, OUTER_VOCAB - 2), - ]], dtype=torch.long, device=device) - -# --------------------------------------------------------------------------- -# 按阶段构造推理初始地图 -# --------------------------------------------------------------------------- - -def make_stage_init(stage: int, context_map: torch.Tensor) -> torch.Tensor: - """ - 根据阶段构造 MaskGIT 的推理初始地图(与训练端掩码策略一致)。 - - Stage 1: 全 MASK - Stage 2: 只保留 wall(1),floor + 功能元素 → MASK - Stage 3: 保留 wall(1)/door(2)/monster(4)/entrance(5),floor + resource → MASK - """ - init = context_map.clone() - - if stage == 1: - init = torch.full_like(init, MASK_TOKEN) - - elif stage == 2: - # 只保留 wall,其余全部 → MASK - keep = torch.isin(init, torch.tensor([1], device=init.device)) - init[~keep] = MASK_TOKEN - - else: # stage == 3 - # 保留 wall + 功能元素,floor + resource → MASK - keep = torch.isin(init, torch.tensor([1, 2, 4, 5], device=init.device)) - init[~keep] = MASK_TOKEN - - return init - - -def make_random_wall_seed(ratio_min=0.02, ratio_max=0.08): - ratio = random.uniform(ratio_min, ratio_max) - n_wall = max(2, int(MAP_SIZE * ratio)) - seed = torch.full((1, MAP_SIZE), MASK_TOKEN, dtype=torch.long, device=device) - idx = torch.randperm(MAP_SIZE)[:n_wall] - seed[0, idx] = 1 - return seed - -# --------------------------------------------------------------------------- -# 验证函数 -# --------------------------------------------------------------------------- - -@torch.no_grad() -def validate( - stage: int, - enc: GinkaVQVAE, - model_mg: GinkaMaskGIT, - dataloader_val: DataLoader, - tile_dict: dict, - epoch: int, - n_rand: int = 3, -): - enc.eval() - model_mg.eval() - - epoch_dir = f"result/stage{stage}_img/e{epoch:04d}" - os.makedirs(epoch_dir, exist_ok=True) - - val_loss_total = 0.0 - val_steps = 0 - captured = {s: None for s in ('A', 'B', 'C', 'D')} - - for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm): - raw_map = batch["raw_map"].to(device) - vq_slice = batch["vq_slice"].to(device) - stage_input = batch["stage_input"].to(device) - target_map = batch["target_map"].to(device) - loss_mask = batch["loss_mask"].to(device) - struct_cond = batch["struct_cond"].to(device) - subsets = batch["subset"] - - z_q, _, _, vq_loss, _, _ = enc(vq_slice) - logits = model_mg(stage_input, z_q, struct_cond=struct_cond) - - ce = masked_focal_loss(logits.permute(0, 2, 1), target_map, loss_mask) - val_loss_total += (ce + vq_loss).item() - val_steps += 1 - - for i in range(raw_map.shape[0]): - s = subsets[i] - if captured[s] is None: - captured[s] = { - "raw": raw_map[i:i+1].clone(), - "stage_input": stage_input[i:i+1].clone(), - "z_q": z_q[i:i+1].clone(), - "struct_cond": struct_cond[i:i+1].clone(), - } - - if all(v is not None for v in captured.values()): - break - - # ---- 可视化:每个子集一张图 ---------------------------------------- - for sub, cap in captured.items(): - if cap is None: - continue - - raw_img = label_image(make_map_image(cap["raw"][0], tile_dict), "GT") - inp_img = label_image(make_map_image(cap["stage_input"][0], tile_dict), f"stage{stage} input") - - # 真实 z 的迭代生成 - init = make_stage_init(stage, cap["stage_input"][0].unsqueeze(0)) - gen = maskgit_generate( - model_mg, cap["z_q"], - init_map=init, struct_cond=cap["struct_cond"], - ) - gen_img = label_image(make_map_image(gen[0], tile_dict), "z_real gen") - - # 随机 z 的生成 - rand_imgs = [] - for i in range(n_rand): - z_r = enc.sample(1, device) - sc_r = make_random_struct_cond() - init2 = make_stage_init(stage, cap["raw"][0].unsqueeze(0)) - gen_r = maskgit_generate(model_mg, z_r, init_map=init2, struct_cond=sc_r) - rand_imgs.append(label_image(make_map_image(gen_r[0], tile_dict), f"z_rand_{i+1}")) - - row = [raw_img, inp_img, gen_img] + rand_imgs - cv2.imwrite(f"{epoch_dir}/subset_{sub}.png", grid_images(row)) - - # ---- 场景:完全自主生成(仅单阶段时执行,多阶段由级联验证统一覆盖)------ - if True: # 占位,避免缩进塌陷;单阶段验证不做级联,跳过 - pass - - return val_loss_total / max(val_steps, 1) - -# --------------------------------------------------------------------------- -# 主训练函数 -# --------------------------------------------------------------------------- - -def _build_stage(stage: int, args): - """初始化单个阶段的模型、数据集,返回状态字典(不含优化器)。""" - result_dir = f"result/stage{stage}" - os.makedirs(result_dir, exist_ok=True) - os.makedirs(f"result/stage{stage}_img", exist_ok=True) - - mg_cfg = STAGE_MG_CONFIGS[stage] - enc = GinkaVQVAE( - num_classes=NUM_CLASSES, - L=VQ_L, - K=VQ_K, - d_z=VQ_D_Z, - d_model=VQ_D_MODEL, - nhead=VQ_NHEAD, - num_layers=VQ_LAYERS, - dim_ff=VQ_DIM_FF, - map_size=MAP_SIZE, - beta=VQ_BETA, - gamma=VQ_GAMMA, - ).to(device) - model_mg = GinkaMaskGIT( - num_classes=NUM_CLASSES, - d_model=mg_cfg["d_model"], - d_z=VQ_D_Z, - dim_ff=mg_cfg["dim_ff"], - nhead=mg_cfg["nhead"], - num_layers=mg_cfg["num_layers"], - map_size=MAP_SIZE, - z_dropout=MG_Z_DROPOUT, - struct_dropout=MG_STRUCT_DROPOUT, - ).to(device) - - enc_params = sum(p.numel() for p in enc.parameters()) - mg_params = sum(p.numel() for p in model_mg.parameters()) - print(f"[Stage {stage}] VQ={enc_params/1e6:.2f}M MaskGIT={mg_params/1e6:.2f}M") - - dataset_train = GinkaStageDataset( - args.train, - stage=stage, - subset_weights=SUBSET_WEIGHTS, - wall_mask_ratio=WALL_MASK_RATIO, - ) - dataset_val = GinkaStageDataset( - args.validate, - stage=stage, - subset_weights=SUBSET_WEIGHTS, - room_thresholds=dataset_train.room_th, - branch_thresholds=dataset_train.branch_th, - wall_mask_ratio=WALL_MASK_RATIO, - ) - dataloader_train = DataLoader( - dataset_train, - batch_size=BATCH_SIZE, - shuffle=True, - num_workers=0, - pin_memory=(device.type == "cuda"), - ) - dataloader_val = DataLoader( - dataset_val, - batch_size=8, - shuffle=True, - num_workers=0, - ) - - if args.pretrain_vq: - ckpt = torch.load(args.pretrain_vq, map_location=device) - enc_key = f"enc{stage}" - if enc_key in ckpt: - enc.load_state_dict(ckpt[enc_key], strict=False) - print(f"[Stage {stage}] 已加载预训练 VQ 权重。") - else: - print(f"[Stage {stage}] 警告:检查点中未找到 {enc_key}。") - - if args.freeze_vq: - for p in enc.parameters(): - p.requires_grad_(False) - print(f"[Stage {stage}] VQ 编码器已冻结。") - - return { - "stage": stage, - "enc": enc, - "model_mg": model_mg, - "dataloader_train": dataloader_train, - "dataloader_val": dataloader_val, - "result_dir": result_dir, - } - - -# --------------------------------------------------------------------------- -def train(): - print(f"Using device: {device}") - args = parse_arguments() - stages = [1, 2, 3] if args.stage == 0 else [args.stage] - - # ---- tile 贴图(一次性加载,所有阶段共用)---- - tile_dict = {} - for f in os.listdir("tiles"): - name = os.path.splitext(f)[0] - img = cv2.imread(f"tiles/{f}", cv2.IMREAD_UNCHANGED) - if img is not None: - tile_dict[name] = img - - # ---- 初始化各阶段 ---- - states = {stage: _build_stage(stage, args) for stage in stages} - - # ---- 合并优化器(所有阶段参数统一管理)---- - all_params = [] - for st in states.values(): - all_params += list(st["enc"].parameters()) + list(st["model_mg"].parameters()) - optimizer = optim.AdamW(all_params, lr=2e-4, weight_decay=1e-2) - scheduler = optim.lr_scheduler.CosineAnnealingLR( - optimizer, T_max=args.epochs, eta_min=1e-6, - ) - - # ---- 续训 ---- - start_epoch = 0 - if args.resume: - ckpt = torch.load(args.state, map_location=device) - for stage in stages: - st = states[stage] - st["enc"].load_state_dict(ckpt[f"enc{stage}"], strict=False) - st["model_mg"].load_state_dict(ckpt[f"mg{stage}"], strict=False) - if args.load_optim and ckpt.get("optim_state") is not None: - optimizer.load_state_dict(ckpt["optim_state"]) - start_epoch = ckpt.get("epoch", 0) - print(f"从 epoch {start_epoch} 接续训练。") - - # ---- 数据集对齐:以最短的 dataloader 为准,zip 迭代 ---- - # 单阶段时直接用该阶段的 dataloader;多阶段时 zip 保证每个 batch 各阶段同步推进 - def _epoch_iters(): - loaders = [states[s]["dataloader_train"] for s in stages] - return zip(*loaders) - - # ---- 训练循环 ---- - for epoch in tqdm( - range(start_epoch, start_epoch + args.epochs), - desc="Training", - disable=disable_tqdm, - ): - for st in states.values(): - st["enc"].train() - st["model_mg"].train() - - loss_totals = {s: 0.0 for s in stages} - ce_totals = {s: 0.0 for s in stages} - vq_totals = {s: 0.0 for s in stages} - n_batches = 0 - - for batches in tqdm( - _epoch_iters(), - leave=False, - desc="Batch", - disable=disable_tqdm, - ): - optimizer.zero_grad() - total_loss = 0.0 - - for stage, batch in zip(stages, batches): - st = states[stage] - vq_slice = batch["vq_slice"].to(device) - stage_input = batch["stage_input"].to(device) - target_map = batch["target_map"].to(device) - loss_mask = batch["loss_mask"].to(device) - struct_cond = batch["struct_cond"].to(device) - - z_q, _, _, vq_loss, _, _ = st["enc"](vq_slice) - logits = st["model_mg"](stage_input, z_q, struct_cond=struct_cond) - - ce_loss = masked_focal_loss(logits.permute(0, 2, 1), target_map, loss_mask) - cfg = STAGE_LOSS_CONFIG[stage] - loss = cfg["ce_weight"] * ce_loss + cfg["vq_weight"] * vq_loss - total_loss = total_loss + loss - loss_totals[stage] += loss.detach().item() - ce_totals[stage] += ce_loss.detach().item() - vq_totals[stage] += vq_loss.detach().item() - - total_loss.backward() - torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0) - optimizer.step() - n_batches += 1 - - scheduler.step() - - n = max(n_batches, 1) - total_avg = sum(loss_totals.values()) / n - stage_loss_str = " ".join( - f"S{s}[focal={ce_totals[s]/n:.4f} vq={vq_totals[s]/n:.4f}]" for s in stages - ) - lr_now = scheduler.get_last_lr()[0] - tqdm.write( - f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " - f"Epoch {epoch + 1:4d} | " - f"Total {total_avg:.4f} | {stage_loss_str} | " - f"LR {lr_now:.2e}" - ) - - # ---- 检查点 + 验证 ---- - if (epoch + 1) % args.checkpoint == 0: - # 保存联合检查点 - ckpt_data = {"epoch": epoch + 1, "optim_state": optimizer.state_dict()} - for stage in stages: - st = states[stage] - ckpt_data[f"enc{stage}"] = st["enc"].state_dict() - ckpt_data[f"mg{stage}"] = st["model_mg"].state_dict() - ckpt_path = f"result/stage{stages[-1]}/joint-{epoch + 1}.pth" - torch.save(ckpt_data, ckpt_path) - tqdm.write(f" 检查点已保存: {ckpt_path}") - - # 各阶段验证 - val_loss_total = 0.0 - for stage in stages: - st = states[stage] - vl = validate( - stage, st["enc"], st["model_mg"], - st["dataloader_val"], tile_dict, epoch + 1, - ) - val_loss_total += vl - tqdm.write(f" [Stage {stage}] Val Loss {vl:.5f}") - - # 级联自由生成(stage1→stage2→stage3) - if len(stages) == 3: - _cascade_free_validate(states, tile_dict, epoch + 1) - - for st in states.values(): - st["enc"].train() - st["model_mg"].train() - - # ---- 最终存档 ---- - ckpt_data = {"epoch": start_epoch + args.epochs} - for stage in stages: - st = states[stage] - ckpt_data[f"enc{stage}"] = st["enc"].state_dict() - ckpt_data[f"mg{stage}"] = st["model_mg"].state_dict() - torch.save(ckpt_data, "result/joint_final.pth") - print("训练结束。") - - -@torch.no_grad() -def _cascade_free_validate(states: dict, tile_dict: dict, epoch: int, n: int = 4): - """ - 三阶段级联自由生成:stage1 生成结果 → stage2 上下文 → stage3 上下文, - 最终只展示 stage3 的完整地图(已含所有 tile)。 - """ - epoch_dir = f"result/cascade_img/e{epoch:04d}" - os.makedirs(epoch_dir, exist_ok=True) - - imgs = [] - for i in range(n): - sc = make_random_struct_cond() - - # Stage 1:全 MASK → 生成 floor/wall - z1 = states[1]["enc"].sample(1, device) - init1 = make_random_wall_seed() - map1 = maskgit_generate(states[1]["model_mg"], z1, init_map=init1, struct_cond=sc) - - # Stage 2:以 stage1 结果为上下文,生成 door/monster/entrance - z2 = states[2]["enc"].sample(1, device) - init2 = make_stage_init(2, map1) - map2 = maskgit_generate(states[2]["model_mg"], z2, init_map=init2, struct_cond=sc) - - # Stage 3:以 stage2 结果为上下文,生成 resource - z3 = states[3]["enc"].sample(1, device) - init3 = make_stage_init(3, map2) - map3 = maskgit_generate(states[3]["model_mg"], z3, init_map=init3, struct_cond=sc) - - imgs.append(label_image(make_map_image(map3[0], tile_dict), f"cascade_{i+1}")) - - cv2.imwrite(f"{epoch_dir}/cascade_free.png", grid_images(imgs)) - - -# --------------------------------------------------------------------------- -if __name__ == "__main__": - torch.set_num_threads(4) - train() diff --git a/ginka/train_vq.py b/ginka/train_vq.py deleted file mode 100644 index f7d3f70..0000000 --- a/ginka/train_vq.py +++ /dev/null @@ -1,812 +0,0 @@ -""" -联合训练脚本:VQ-VAE + MaskGIT - -总损失 = L_CE(MaskGIT 重建损失)+ beta * L_commit + gamma * L_entropy - + lambda * L_consist(z 一致性约束,方案 A) - -验证阶段对四种子集(A/B/C/D)分别输出图片, -每条样本额外采样 N_Z_SAMPLES 个随机 z, -便于直观对比同条件不同 z 下的生成差异。 - -用法示例: - python -m ginka.train_vq - python -m ginka.train_vq --resume True --state result/joint/joint-10.pth -""" - -import argparse -import math -import os -import sys -import random -from datetime import datetime - -import cv2 -import numpy as np -import torch -import torch.nn.functional as F -import torch.optim as optim -from tqdm import tqdm -from torch.utils.data import DataLoader - -from .vqvae.model import GinkaVQVAE -from .maskGIT.model import GinkaMaskGIT -from .dataset import GinkaVQDataset -from shared.image import matrix_to_image_cv - -# --------------------------------------------------------------------------- -# 超参数 -# --------------------------------------------------------------------------- -BATCH_SIZE = 64 -NUM_CLASSES = 7 -MASK_TOKEN = 6 -GENERATE_STEP = 18 # 推理时 MaskGIT 迭代步数 -MAP_SIZE = 13 * 13 -MAP_H = MAP_W = 13 -FOCAL_GAMMA = 2.0 # focal loss 聚焦参数(越大越关注难例/稀有类别) -WALL_MASK_RATIO = 0.8 - -# VQ-VAE 公共超参(三路编码器共用,方案 B 三通道分拆) -VQ_L = 2 # 每路码字序列长度(三路合计 L1+L2+L3 = 6) -VQ_K = 8 # codebook 大小 -VQ_D_Z = 64 # codebook 嵌入维度(三路保持一致,便于拼接) -VQ_BETA = 0.5 # commit loss 权重 -VQ_GAMMA = 0.0 # entropy loss 权重 - -# 各通道编码器配置 -CH1_D_MODEL = 64; CH1_NHEAD = 8 # 通道 1:空间骨架(floor+wall) -CH2_D_MODEL = 64; CH2_NHEAD = 8 # 通道 2:关卡门控 -CH3_D_MODEL = 64; CH3_NHEAD = 8 # 通道 3:收集资源 -VQ_LAYERS = 3 -VQ_DIM_FF = 512 - -# 通道专属损失计算范围(用于监控验证召回率) -CH1_LOSS = {1} -CH2_LOSS = {2, 4, 5} -CH3_LOSS = {3} # 资源已压缩为单一 tile=3 - -# MaskGIT 超参 -MG_D_MODEL = 256 -MG_NHEAD = 8 -MG_LAYERS = 6 -MG_DIM_FF = 2048 -MG_Z_DROPOUT = 0.1 # 训练时以此概率把 z 替换为随机噪声 -MG_STRUCT_DROPOUT= 0.1 # 训练时以此概率将结构标签替换为 null(无条件占位) - -# 一致性约束超参(方案 A) -CONSIST_LAMBDA = 0.1 # z 一致性损失权重 -CONSIST_TEMP = 2.0 # 计算软嵌入时对 logits 施加的温度(>1 平滑分布,降低 gap) - -# 验证时对每条样本额外采样的 z 数量(0 = 只用真实 z) -N_Z_SAMPLES = 3 - -# 四个子集 A/B/C/D 的采样权重(训练集与验证集共用) -SUBSET_WEIGHTS = (0.5, 0.2, 0.2, 0.1) - -# --------------------------------------------------------------------------- -# 设备 -# --------------------------------------------------------------------------- -device = torch.device( - "cuda:1" if torch.cuda.is_available() - else "mps" if torch.backends.mps.is_available() - else "cpu" -) - -os.makedirs("result/joint", exist_ok=True) -os.makedirs("result/joint_img", exist_ok=True) - -disable_tqdm = not sys.stdout.isatty() - -# --------------------------------------------------------------------------- -# 参数解析 -# --------------------------------------------------------------------------- -def _str2bool(v: str) -> bool: - """argparse 专用:将字符串 'True'/'False' 正确转为 bool。 - type=bool 会把任何非空字符串(包括 'False')解析为 True,故需此辅助。""" - if isinstance(v, bool): - return v - if v.lower() in ('true', '1', 'yes'): - return True - if v.lower() in ('false', '0', 'no'): - return False - raise argparse.ArgumentTypeError(f"布尔值应为 True/False,收到: {v!r}") - -def parse_arguments(): - parser = argparse.ArgumentParser(description="VQ-VAE + MaskGIT 联合训练") - parser.add_argument("--resume", type=_str2bool, default=False) - parser.add_argument("--state", type=str, default="result/joint/joint-10.pth", - help="续训时加载的检查点路径") - parser.add_argument("--train", type=str, default="ginka-dataset.json") - parser.add_argument("--validate", type=str, default="ginka-eval.json") - parser.add_argument("--epochs", type=int, default=100) - parser.add_argument("--checkpoint", type=int, default=5, - help="每隔多少 epoch 保存检查点并验证") - parser.add_argument("--load_optim", type=_str2bool, default=True) - parser.add_argument("--freeze_vq", type=_str2bool, default=False, - help="(方案 B 阶段 1)冻结三路 VQ 编码器,仅训练 MaskGIT。" - "适用于预训练权重加载后的热身阶段。") - parser.add_argument("--pretrain_split", type=str, default="", - help="(方案 B)三通道分拆预训练检查点路径;" - "指定后将从该检查点加载三路编码器初始权重。") - return parser.parse_args() - -# --------------------------------------------------------------------------- -# Focal Loss -# --------------------------------------------------------------------------- -def focal_loss( - logits: torch.Tensor, - targets: torch.Tensor, - gamma: float = FOCAL_GAMMA, - reduction: str = 'none', -) -> torch.Tensor: - """ - 多分类 Focal Loss:FL = -(1 - p_t)^gamma * log(p_t) - - 相比 CE,对已被正确分类的高置信度样本施加更小的权重, - 迫使模型关注难分类的稀有 tile(门/怪/资源等)。 - - Args: - logits: [B, C, *] 未经 softmax 的原始预测 - targets: [B, *] 整数类别标签 - gamma: 聚焦参数,0 时退化为标准 CE - reduction: 'none' | 'mean' | 'sum' - """ - ce = F.cross_entropy(logits, targets, reduction='none') # [B, *] - pt = torch.exp(-ce) # 正确类的预测概率 - fl = (1.0 - pt) ** gamma * ce - if reduction == 'mean': - return fl.mean() - if reduction == 'sum': - return fl.sum() - return fl # 'none' - -# --------------------------------------------------------------------------- -# MaskGIT 推理(cosine schedule 迭代解码) -# --------------------------------------------------------------------------- -@torch.no_grad() -def maskgit_generate( - model_mg: GinkaMaskGIT, - z: torch.Tensor, - steps: int = GENERATE_STEP, - init_map: torch.Tensor = None, - struct_cond: torch.Tensor | None = None, -) -> torch.Tensor: - """ - 迭代生成地图(cosine schedule unmasking)。 - - Args: - model_mg: GinkaMaskGIT(eval 模式) - z: [B, L, d_z] 条件 z - steps: 解码步数 - init_map: [B, MAP_SIZE] 可选初始地图;非 MASK 位置在生成过程中保持固定。 - 为 None 时从全 MASK 开始(自由生成)。 - - Returns: - map_out: [B, MAP_SIZE] - """ - B = z.shape[0] - if init_map is None: - map_seq = torch.full((B, MAP_SIZE), MASK_TOKEN, device=device) - else: - map_seq = init_map.clone().to(device) - - # 记录初始 MASK 位置,这些位置才需要生成 - generatable = (map_seq == MASK_TOKEN) # [B, S] bool - - for step in range(steps): - if not generatable.any(): - break - - logits = model_mg(map_seq, z, struct_cond=struct_cond) # [B, S, C] - probs = F.softmax(logits, dim=-1) - dist = torch.distributions.Categorical(probs) - sampled = dist.sample() # [B, S] - - # 计算置信度;固定位置设为 +inf(确保不会被选为“继续保持 MASK”) - confidences = torch.gather(probs, -1, sampled.unsqueeze(-1)).squeeze(-1) - confidences = confidences.masked_fill(~generatable, float('inf')) - - ratio = math.cos(((step + 1) / steps) * math.pi / 2) - new_map = map_seq.clone() - - for b in range(B): - n_gen = int(generatable[b].sum().item()) - n_keep = int(ratio * n_gen) # 本步仍保持 MASK 的位置数 - - if n_keep > 0: - _, keep_idx = torch.topk(confidences[b], k=n_keep, largest=False) - pred_b = sampled[b].clone() - pred_b[keep_idx] = MASK_TOKEN - new_map[b] = torch.where(generatable[b], pred_b, map_seq[b]) - else: - new_map[b] = torch.where(generatable[b], sampled[b], map_seq[b]) - - map_seq = new_map - - return map_seq - -# --------------------------------------------------------------------------- -# 可视化工具 -# --------------------------------------------------------------------------- -def make_map_image(map_flat: torch.Tensor, tile_dict: dict) -> np.ndarray: - """将 [MAP_SIZE] 的 tensor 转成 RGB 图片(numpy)。""" - arr = map_flat.cpu().numpy().reshape(MAP_H, MAP_W) - return matrix_to_image_cv(arr, tile_dict) - - -def hstack_images(imgs: list, gap: int = 4, color=(255, 255, 255)) -> np.ndarray: - """将多张图片横向拼接,之间插入竖线;高度不一致时底部补齐背景色。""" - max_h = max(img.shape[0] for img in imgs) - - def _pad_h(img): - dh = max_h - img.shape[0] - if dh == 0: - return img - pad = np.full((dh, img.shape[1], 3), color, dtype=np.uint8) - return np.concatenate([img, pad], axis=0) - - vline = np.full((max_h, gap, 3), color, dtype=np.uint8) - result = _pad_h(imgs[0]) - for img in imgs[1:]: - result = np.concatenate([result, vline, _pad_h(img)], axis=1) - return result - - -def grid_images(imgs: list, gap: int = 4, bg_color=(255, 255, 255)) -> np.ndarray: - """将图片列表排成两行网格(上行 ceil(N/2),下行 floor(N/2)),方便查看。""" - n = len(imgs) - if n == 0: - return np.zeros((1, 1, 3), dtype=np.uint8) - if n == 1: - return imgs[0] - - mid = math.ceil(n / 2) - top_row = hstack_images(imgs[:mid], gap, bg_color) - bot_imgs = imgs[mid:] - - if not bot_imgs: - return top_row - - bot_row = hstack_images(bot_imgs, gap, bg_color) - - # 补齐宽度(右侧填充背景色) - tw, bw = top_row.shape[1], bot_row.shape[1] - if tw > bw: - pad = np.full((bot_row.shape[0], tw - bw, 3), bg_color, dtype=np.uint8) - bot_row = np.concatenate([bot_row, pad], axis=1) - elif bw > tw: - pad = np.full((top_row.shape[0], bw - tw, 3), bg_color, dtype=np.uint8) - top_row = np.concatenate([top_row, pad], axis=1) - - hline = np.full((gap, top_row.shape[1], 3), bg_color, dtype=np.uint8) - return np.concatenate([top_row, hline, bot_row], axis=0) - - -def label_image(img: np.ndarray, text: str, font_scale: float = 0.45) -> np.ndarray: - """在图片顶部加一行文字标签(就地修改并返回)。""" - bar_h = 16 - bar = np.full((bar_h, img.shape[1], 3), (40, 40, 40), dtype=np.uint8) - cv2.putText( - bar, text, (2, bar_h - 3), - cv2.FONT_HERSHEY_SIMPLEX, font_scale, - (200, 200, 200), 1, cv2.LINE_AA - ) - return np.concatenate([bar, img], axis=0) - - -def struct_cond_to_text(sc: torch.Tensor) -> str: - """ - 将 struct_cond [4] LongTensor 解码为可读字符串。 - - sc 顺序:[cond_sym, cond_room, cond_branch, cond_outer] - cond_sym : sym_h*4 + sym_v*2 + sym_c,取值 0-6,7=null - cond_room : roomCountLevel 0-2,3=null - cond_branch: branchLevel 0-2,3=null - cond_outer : outerWall 0-1,2=null - """ - sym_val, room_val, branch_val, outer_val = (int(x) for x in sc.tolist()) - - # 对称性 - if sym_val == 7: - sym_str = "sym:-" - else: - flags = [] - if sym_val & 4: flags.append("H") - if sym_val & 2: flags.append("V") - if sym_val & 1: flags.append("C") - sym_str = "sym:" + ("".join(flags) if flags else "none") - - # 房间数量等级 - room_map = {0: "room:lo", 1: "room:mid", 2: "room:hi", 3: "room:-"} - room_str = room_map.get(room_val, f"room:{room_val}") - - # 分支等级 - branch_map = {0: "br:lo", 1: "br:mid", 2: "br:hi", 3: "br:-"} - branch_str = branch_map.get(branch_val, f"br:{branch_val}") - - # 外墙 - outer_map = {0: "wall:N", 1: "wall:Y", 2: "wall:-"} - outer_str = outer_map.get(outer_val, f"wall:{outer_val}") - - return f"{sym_str} {room_str} {branch_str} {outer_str}" - - -def annotate_struct(img: np.ndarray, sc: torch.Tensor) -> np.ndarray: - """在图片底部追加一行结构标签注释(深蓝底白字)。""" - text = struct_cond_to_text(sc) - bar_h = 14 - bar = np.full((bar_h, img.shape[1], 3), (60, 30, 10), dtype=np.uint8) - cv2.putText( - bar, text, (2, bar_h - 3), - cv2.FONT_HERSHEY_SIMPLEX, 0.38, - (180, 220, 255), 1, cv2.LINE_AA - ) - return np.concatenate([img, bar], axis=0) - - -def make_random_wall_seed(ratio_min: float = 0.02, ratio_max: float = 0.08) -> torch.Tensor: - """ - 在全 MASK 地图上随机放置少量墙壁作为推理种子,用于完全随机生成场景。 - - Returns: - [1, MAP_SIZE] MASK=15 背景 + 随机置少量墙壁(tile=1) - """ - ratio = random.uniform(ratio_min, ratio_max) - n_wall = max(2, int(MAP_SIZE * ratio)) - seed = torch.full((1, MAP_SIZE), MASK_TOKEN, dtype=torch.long, device=device) - idx = torch.randperm(MAP_SIZE)[:n_wall] - seed[0, idx] = 1 # wall - return seed - - -def make_random_struct_cond() -> torch.Tensor: - """ - 生成一个随机结构条件,所有标签均取合法非-null 值。 - - Returns: - [1, 4] LongTensor,顺序 [cond_sym, cond_room, cond_branch, cond_outer] - """ - from .maskGIT.model import SYM_VOCAB, ROOM_VOCAB, BRANCH_VOCAB, OUTER_VOCAB - sym = random.randint(0, SYM_VOCAB - 2) # 0-6 - room = random.randint(0, ROOM_VOCAB - 2) # 0-2 - branch = random.randint(0, BRANCH_VOCAB - 2) # 0-2 - outer = random.randint(0, OUTER_VOCAB - 2) # 0-1 - return torch.tensor([[sym, room, branch, outer]], dtype=torch.long, device=device) - -@torch.no_grad() -def validate( - enc1: GinkaVQVAE, - enc2: GinkaVQVAE, - enc3: GinkaVQVAE, - model_mg: GinkaMaskGIT, - dataloader_val: DataLoader, - tile_dict: dict, - epoch: int, -): - """ - 验证函数:计算 val loss 并输出 5 类推理场景的对比图。 - - 场景说明(按 epoch 建立子文件夹,避免图片堆积): - 场景1 (scene1_completion) : 子集 A,标准随机掩码补全 - 列: ground truth | masked input | z_real pred | z_real gen | z_rand×N - 场景2 (scene2_wall) : 子集 B,仅墙壁+空地 → 生成完整地图 - 列: ground truth | wall-only input | z_real gen | z_rand×N - 场景3 (scene3_sparse) : 子集 C,稀疏墙壁条件 → 生成完整地图 - 列: ground truth | sparse wall input | z_real gen | z_rand×N - 场景4 (scene4_entrance) : 子集 D,墙壁+入口 → 生成完整地图 - 列: ground truth | wall+entrance input | z_real gen | z_rand×N - 场景5 (scene5_random) : 无数据集参照,随机稀疏墙壁种子 → 完全随机生成 - 列: random seed | z_rand×(N+1) - """ - for enc in [enc1, enc2, enc3]: - enc.eval() - model_mg.eval() - - # 按 epoch 建立独立子文件夹,保留每次验证结果方便回溯 - epoch_dir = f"result/joint_img/e{epoch:04d}" - os.makedirs(epoch_dir, exist_ok=True) - - val_loss_total = 0.0 - val_steps = 0 - captured = {s: None for s in ('A', 'B', 'C', 'D')} - - # ── 计算 val loss + 捕获各子集样本 ────────────────────────────────────── - def _encode_three(s1, s2, s3): - """三路编码并拼接 z_q。""" - z_q1, _, _, vq1, _, _ = enc1(s1) - z_q2, _, _, vq2, _, _ = enc2(s2) - z_q3, _, _, vq3, _, _ = enc3(s3) - z_q = torch.cat([z_q1, z_q2, z_q3], dim=1) # [B, L1+L2+L3, d_z] - vq_loss = vq1 + vq2 + vq3 - return z_q, vq_loss - - def _sample_three(B_size): - """三路随机采样并拼接 z。""" - z1 = enc1.sample(B_size, device) - z2 = enc2.sample(B_size, device) - z3 = enc3.sample(B_size, device) - return torch.cat([z1, z2, z3], dim=1) - - for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm): - raw_map = batch["raw_map"].to(device) # [B, 169] - masked_map = batch["masked_map"].to(device) # [B, 169] - target_map = batch["target_map"].to(device) # [B, 169] - s1 = batch["slice1"].to(device) - s2 = batch["slice2"].to(device) - s3 = batch["slice3"].to(device) - subsets = batch["subset"] # list of str - B = raw_map.shape[0] - - z_q, vq_loss = _encode_three(s1, s2, s3) - struct_cond_b = batch["struct_cond"].to(device) # [B, 4] - logits = model_mg(masked_map, z_q, struct_cond=struct_cond_b) - mask = (masked_map == MASK_TOKEN) - - ce_loss = focal_loss(logits.permute(0, 2, 1), target_map) - masked_ce = (ce_loss * mask).sum() / (mask.sum() + 1e-6) - val_loss_total += (masked_ce + vq_loss).item() - val_steps += 1 - - for i in range(B): - s = subsets[i] - if captured[s] is None: - captured[s] = { - "raw": raw_map[i:i+1].clone(), - "masked": masked_map[i:i+1].clone(), - "z_q": z_q[i:i+1].clone(), - "struct_cond": struct_cond_b[i:i+1].clone(), - } - - if all(v is not None for v in captured.values()): - break - - # ── 公共辅助:对给定条件地图随机采样 n 次 z 并迭代生成(无条件)────────────── - def _rand_gens(cond_map, n): - imgs = [] - for i in range(n): - z_r = _sample_three(1) - gen = maskgit_generate(model_mg, z_r, init_map=cond_map) # struct_cond=None 无条件 - imgs.append(label_image(make_map_image(gen[0], tile_dict), f"z_rand_{i + 1}")) - return imgs - - # ── 公共辅助:对给定条件地图随机采样 n 次 z 并迭代生成(随机结构标签)──────── - def _rand_gens_with_struct(cond_map, n): - imgs = [] - for i in range(n): - z_r = _sample_three(1) - sc_r = make_random_struct_cond() # [1, 4] 随机合法标签 - gen = maskgit_generate(model_mg, z_r, init_map=cond_map, struct_cond=sc_r) - img = label_image(make_map_image(gen[0], tile_dict), f"z_rand_{i + 1}") - img = annotate_struct(img, sc_r[0]) - imgs.append(img) - return imgs - - # ── 场景1:标准掩码补全(子集 A)───────────────────────────────────────── - if captured['A'] is not None: - cap = captured['A'] - raw, cond, z_q, sc = cap['raw'], cap['masked'], cap['z_q'], cap['struct_cond'] - - real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth") - cond_img = label_image(make_map_image(cond[0], tile_dict), "masked input") - - # 单步 argmax 预测(观察模型对掩码位置的瞬时判断) - pred = model_mg(cond, z_q, struct_cond=sc).argmax(dim=-1)[0] - pred_img = label_image(make_map_image(pred, tile_dict), "z_real pred") - - # 迭代生成(从掩码输入出发,真实 z) - gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc) - gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") - - # 对使用了真实 struct_cond 的图片追加标签注释 - sc0 = sc[0] - real_img = annotate_struct(real_img, sc0) - cond_img = annotate_struct(cond_img, sc0) - pred_img = annotate_struct(pred_img, sc0) - gen_r_img = annotate_struct(gen_r_img, sc0) - - row = [real_img, cond_img, pred_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) - cv2.imwrite(f"{epoch_dir}/scene1_completion.png", grid_images(row)) - - # ── 场景2:墙壁辅助生成(子集 B)───────────────────────────────────────── - if captured['B'] is not None: - cap = captured['B'] - raw, cond, z_q, sc = cap['raw'], cap['masked'], cap['z_q'], cap['struct_cond'] - - real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth") - cond_img = label_image(make_map_image(cond[0], tile_dict), "wall-only input") - gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc) - gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") - - sc0 = sc[0] - real_img = annotate_struct(real_img, sc0) - cond_img = annotate_struct(cond_img, sc0) - gen_r_img = annotate_struct(gen_r_img, sc0) - - row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) - cv2.imwrite(f"{epoch_dir}/scene2_wall.png", grid_images(row)) - - # ── 场景3:稀疏墙壁条件生成(子集 C)──────────────────────────────────── - if captured['C'] is not None: - cap = captured['C'] - raw, cond, z_q, sc = cap['raw'], cap['masked'], cap['z_q'], cap['struct_cond'] - - real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth") - cond_img = label_image(make_map_image(cond[0], tile_dict), "sparse wall input") - gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc) - gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") - - sc0 = sc[0] - real_img = annotate_struct(real_img, sc0) - cond_img = annotate_struct(cond_img, sc0) - gen_r_img = annotate_struct(gen_r_img, sc0) - - row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) - cv2.imwrite(f"{epoch_dir}/scene3_sparse.png", grid_images(row)) - - # ── 场景4:墙壁+入口条件生成(子集 D)─────────────────────────────────── - if captured['D'] is not None: - cap = captured['D'] - raw, cond, z_q, sc = cap['raw'], cap['masked'], cap['z_q'], cap['struct_cond'] - - real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth") - cond_img = label_image(make_map_image(cond[0], tile_dict), "wall+entrance input") - gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc) - gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen") - - sc0 = sc[0] - real_img = annotate_struct(real_img, sc0) - cond_img = annotate_struct(cond_img, sc0) - gen_r_img = annotate_struct(gen_r_img, sc0) - - row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES) - cv2.imwrite(f"{epoch_dir}/scene4_entrance.png", grid_images(row)) - - # ── 场景5:完全随机生成(无数据集参照)────────────────────────────────── - # 5a:随机结构标签 — 验证结构导向能力 - rand_seed_a = make_random_wall_seed() - seed_img_a = label_image(make_map_image(rand_seed_a[0], tile_dict), "random seed") - row_a = [seed_img_a] + _rand_gens_with_struct(rand_seed_a, N_Z_SAMPLES + 1) - cv2.imwrite(f"{epoch_dir}/scene5a_random_cond.png", grid_images(row_a)) - - # 5b:无条件(struct_cond=None)— 验证基线生成质量 - rand_seed_b = make_random_wall_seed() - seed_img_b = label_image(make_map_image(rand_seed_b[0], tile_dict), "random seed") - row_b = [seed_img_b] + _rand_gens(rand_seed_b, N_Z_SAMPLES + 1) - cv2.imwrite(f"{epoch_dir}/scene5b_random_uncond.png", grid_images(row_b)) - - avg_val_loss = val_loss_total / max(val_steps, 1) - return avg_val_loss - -# --------------------------------------------------------------------------- -# 主训练函数 -# --------------------------------------------------------------------------- -def train(): - print(f"Using device: {device}") - args = parse_arguments() - - # ---- 三路编码器(方案 B 三通道分拆) ---- - _vq_common = dict( - num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_z=VQ_D_Z, - num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF, map_size=MAP_SIZE, - beta=VQ_BETA, gamma=VQ_GAMMA, - ) - enc1 = GinkaVQVAE(d_model=CH1_D_MODEL, nhead=CH1_NHEAD, **_vq_common).to(device) - enc2 = GinkaVQVAE(d_model=CH2_D_MODEL, nhead=CH2_NHEAD, **_vq_common).to(device) - enc3 = GinkaVQVAE(d_model=CH3_D_MODEL, nhead=CH3_NHEAD, **_vq_common).to(device) - - model_mg = GinkaMaskGIT( - num_classes=NUM_CLASSES, - d_model=MG_D_MODEL, d_z=VQ_D_Z, - dim_ff=MG_DIM_FF, nhead=MG_NHEAD, - num_layers=MG_LAYERS, - map_size=MAP_SIZE, - z_dropout=MG_Z_DROPOUT, - struct_dropout=MG_STRUCT_DROPOUT, - ).to(device) - - enc_params = sum(p.numel() for m in [enc1, enc2, enc3] for p in m.parameters()) - mg_params = sum(p.numel() for p in model_mg.parameters()) - print(f"Encoders 参数量(三路): {enc_params:,} ({enc_params/1e6:.3f}M)") - print(f"MaskGIT 参数量: {mg_params:,} ({mg_params/1e6:.3f}M)") - print(f"Total 参数量: {enc_params+mg_params:,} ({(enc_params+mg_params)/1e6:.3f}M)") - - # ---- 数据集 ---- - dataset_train = GinkaVQDataset( - args.train, - subset_weights=SUBSET_WEIGHTS, - wall_mask_ratio=WALL_MASK_RATIO, - ) - dataset_val = GinkaVQDataset( - args.validate, - subset_weights=SUBSET_WEIGHTS, - room_thresholds=dataset_train.room_th, - branch_thresholds=dataset_train.branch_th, - wall_mask_ratio=WALL_MASK_RATIO, - ) - dataloader_train = DataLoader( - dataset_train, batch_size=BATCH_SIZE, shuffle=True, - num_workers=0, pin_memory=(device.type == "cuda"), - ) - dataloader_val = DataLoader( - dataset_val, batch_size=8, shuffle=True, - num_workers=0, - ) - - # ---- 优化器(联合训练,三路编码器 + MaskGIT 共用)---- - enc_params_list = list(enc1.parameters()) + list(enc2.parameters()) + list(enc3.parameters()) - all_params = enc_params_list + list(model_mg.parameters()) - optimizer = optim.AdamW(all_params, lr=2e-4, weight_decay=1e-2) - scheduler = optim.lr_scheduler.CosineAnnealingLR( - optimizer, T_max=args.epochs, eta_min=1e-6 - ) - - # ---- 权重加载 ---- - start_epoch = 0 - if args.pretrain_split: - # 从分拆预训练检查点加载三路编码器初始权重(阶段 1 冻结热身前) - ckpt = torch.load(args.pretrain_split, map_location=device) - enc1.load_state_dict(ckpt["enc1"]) - enc2.load_state_dict(ckpt["enc2"]) - enc3.load_state_dict(ckpt["enc3"]) - print(f"已加载分拆预训练编码器权重: {args.pretrain_split}") - elif args.resume: - ckpt = torch.load(args.state, map_location=device) - enc1.load_state_dict(ckpt["enc1"], strict=False) - enc2.load_state_dict(ckpt["enc2"], strict=False) - enc3.load_state_dict(ckpt["enc3"], strict=False) - model_mg.load_state_dict(ckpt["mg_state"], strict=False) - if args.load_optim and ckpt.get("optim_state") is not None: - optimizer.load_state_dict(ckpt["optim_state"]) - start_epoch = ckpt.get("epoch", 0) - print(f"从 epoch {start_epoch} 接续训练。") - - # ---- tile 贴图(用于验证可视化)---- - tile_dict = {} - for file in os.listdir("tiles"): - name = os.path.splitext(file)[0] - img = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) - if img is not None: - tile_dict[name] = img - - # ---- 方案 B 阶段 1:冻结三路 VQ 编码器 ---- - if args.freeze_vq: - for enc in [enc1, enc2, enc3]: - for p in enc.parameters(): - p.requires_grad_(False) - print("三路 VQ 编码器已冻结(阶段 1:MaskGIT 热身)。") - - # ---- 训练循环 ---- - for epoch in tqdm(range(start_epoch, start_epoch + args.epochs), - desc="Joint Training", disable=disable_tqdm): - for enc in [enc1, enc2, enc3]: - enc.train() - model_mg.train() - - loss_total = 0.0 - ce_total = 0.0 - vq_loss_total = 0.0 - commit_total = 0.0 - entropy_total = 0.0 - consist_total = 0.0 - subset_stats = {'A': 0, 'B': 0, 'C': 0, 'D': 0} - - for batch in tqdm(dataloader_train, leave=False, - desc="Epoch Progress", disable=disable_tqdm): - raw_map = batch["raw_map"].to(device) # [B, 169] - masked_map = batch["masked_map"].to(device) # [B, 169] - target_map = batch["target_map"].to(device) # [B, 169] - s1 = batch["slice1"].to(device) # 通道 1 切片 - s2 = batch["slice2"].to(device) # 通道 2 切片 - s3 = batch["slice3"].to(device) # 通道 3 切片 - - for s in batch["subset"]: - subset_stats[s] = subset_stats.get(s, 0) + 1 - - # ---- 前向传播 ---- - # 1. 三路 VQ 编码器各自编码对应切片 → 拼接 z - z_q1, z_e1, _, vq_loss1, commit_loss1, entropy_loss1 = enc1(s1) - z_q2, z_e2, _, vq_loss2, commit_loss2, entropy_loss2 = enc2(s2) - z_q3, z_e3, _, vq_loss3, commit_loss3, entropy_loss3 = enc3(s3) - z_q = torch.cat([z_q1, z_q2, z_q3], dim=1) # [B, L1+L2+L3, d_z] - z_e = torch.cat([z_e1, z_e2, z_e3], dim=1) # [B, L1+L2+L3, d_z] - vq_loss = vq_loss1 + vq_loss2 + vq_loss3 - commit_loss = commit_loss1 + commit_loss2 + commit_loss3 - entropy_loss = entropy_loss1 + entropy_loss2 + entropy_loss3 - - # 2. MaskGIT 以掩码地图 + z + 结构标签预测原始 tile - struct_cond = batch["struct_cond"].to(device) # [B, 4] - logits = model_mg(masked_map, z_q, struct_cond=struct_cond) # [B, 169, C] - - # 3. 只对被 mask 的位置计算 focal loss(缓解墙壁/空地主导问题) - mask = (masked_map == MASK_TOKEN) # [B, 169] bool - ce_loss = focal_loss(logits.permute(0, 2, 1), target_map) - masked_ce = (ce_loss * mask).sum() / (mask.sum() + 1e-6) - - # 4. z 一致性约束(方案 A 扩展到三通道): - # MaskGIT logits 经温度平滑后与各编码器的 tile embedding 做加权求和, - # 得到软嵌入 → 各编码器再次编码 → z_pred_e_k 与真实 z_e_k 对齐。 - # 编码器权重在此路径上临时冻结,确保梯度仅回传至 MaskGIT。 - for enc in [enc1, enc2, enc3]: - for p in enc.parameters(): - p.requires_grad_(False) - - soft_probs = F.softmax(logits / CONSIST_TEMP, dim=-1) # [B, H*W, V] - z_pred_e1 = enc1.encode_soft(soft_probs @ enc1.tile_embedding.weight) - z_pred_e2 = enc2.encode_soft(soft_probs @ enc2.tile_embedding.weight) - z_pred_e3 = enc3.encode_soft(soft_probs @ enc3.tile_embedding.weight) - z_pred_e = torch.cat([z_pred_e1, z_pred_e2, z_pred_e3], dim=1) - consist_loss = F.mse_loss(z_pred_e, z_e.detach()) - - if not args.freeze_vq: - for enc in [enc1, enc2, enc3]: - for p in enc.parameters(): - p.requires_grad_(True) - - # 5. 联合损失 - loss = masked_ce + vq_loss + CONSIST_LAMBDA * consist_loss - - optimizer.zero_grad() - loss.backward() - torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0) - optimizer.step() - - loss_total += loss.detach().item() - ce_total += masked_ce.detach().item() - vq_loss_total += vq_loss.detach().item() - commit_total += commit_loss.detach().item() - entropy_total += entropy_loss.detach().item() - consist_total += consist_loss.detach().item() - - scheduler.step() - - n = len(dataloader_train) - tqdm.write( - f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " - f"Epoch {epoch + 1:4d} | " - f"Loss {loss_total/n:.5f} " - f"Focal {ce_total/n:.5f} " - f"VQ {vq_loss_total/n:.5f} " - f"Commit {commit_total/n:.5f} " - f"Entropy {entropy_total/n:.5f} " - f"Consist {consist_total/n:.5f} | " - f"LR {scheduler.get_last_lr()[0]:.6f} | " - f"Subsets {subset_stats}" - ) - - # ---- 检查点 + 验证 ---- - if (epoch + 1) % args.checkpoint == 0: - ckpt_path = f"result/joint/joint-{epoch + 1}.pth" - torch.save({ - "epoch": epoch + 1, - "enc1": enc1.state_dict(), - "enc2": enc2.state_dict(), - "enc3": enc3.state_dict(), - "mg_state": model_mg.state_dict(), - "optim_state":optimizer.state_dict(), - }, ckpt_path) - tqdm.write(f" 检查点已保存: {ckpt_path}") - - val_loss = validate( - enc1, enc2, enc3, model_mg, dataloader_val, tile_dict, epoch + 1 - ) - tqdm.write( - f"[Validate] Epoch {epoch + 1:4d} | Val Loss {val_loss:.5f}" - ) - # 恢复训练模式 - for enc in [enc1, enc2, enc3]: - enc.train() - model_mg.train() - - print("训练结束。") - torch.save({ - "epoch": start_epoch + args.epochs, - "enc1": enc1.state_dict(), - "enc2": enc2.state_dict(), - "enc3": enc3.state_dict(), - "mg_state": model_mg.state_dict(), - }, "result/joint/joint_final.pth") - - -# --------------------------------------------------------------------------- -if __name__ == "__main__": - torch.set_num_threads(4) - train()