mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-23 20:41:12 +08:00
chore: 删除旧训练代码
This commit is contained in:
parent
bf3d24e680
commit
8f087fe138
@ -1,380 +0,0 @@
|
||||
"""
|
||||
三通道分拆预训练脚本(方案 B)
|
||||
|
||||
三路编码器各自负责一个语义通道:
|
||||
通道 1:空间骨架(floor+wall),损失仅计算 {0,1}(floor、wall)位置
通道 2:关卡门控(floor+wall+door+mob+entrance),损失仅计算 {2,4,5} 位置
通道 3:收集资源(完整地图),损失仅计算 {3} 位置
|
||||
|
||||
预训练完成后保存各通道编码器权重(不含解码头),
|
||||
供联合训练脚本 train_vq.py 加载并拼接 z。
|
||||
|
||||
用法示例:
|
||||
python -m ginka.train_pretrain_split
|
||||
python -m ginka.train_pretrain_split --resume True --state result/pretrain_split/split-10.pth
|
||||
# 预训练完成后指定权重路径启动联合训练:
|
||||
python -m ginka.train_vq --pretrain_split result/pretrain_split/split_final.pth
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from .vqvae.model import GinkaVQVAE, VQDecodeHead
|
||||
from .dataset import GinkaSplitDataset
|
||||
from .utils import masked_focal
|
||||
|
||||
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
BATCH_SIZE = 64
NUM_CLASSES = 7
MAP_SIZE = 13 * 13
FOCAL_GAMMA = 1.0

# Channel 1: spatial skeleton (floor + wall)
CH1_KEEP = {0, 1}  # tile ids kept in the encoder input
# Tile ids the loss is restricted to. The set holds both 0 and 1, so an older
# "wall only" description no longer matches the code.
CH1_LOSS = {0, 1}
CH1_D_MODEL = 64
CH1_NHEAD = 8

# Channel 2: level gating
CH2_KEEP = {0, 1, 2, 4, 5}
CH2_LOSS = {2, 4, 5}
CH2_D_MODEL = 64
CH2_NHEAD = 8

# Channel 3: collectible resources
CH3_KEEP = None  # full map, no slicing needed
CH3_LOSS = {3}
CH3_D_MODEL = 64
CH3_NHEAD = 8

# VQ hyperparameters shared by all three encoders
VQ_L = 2
VQ_K = 8
VQ_D_Z = 64
VQ_LAYERS = 3
VQ_DIM_FF = 512
VQ_BETA = 0.5  # commit loss weight
VQ_GAMMA = 0.0  # entropy loss weight

# Decode-head hyperparameters (same spec for all three channels)
DH_NHEAD = 8
DH_DIM_FF = 512
DH_LAYERS = 3
|
||||
|
||||
# ---------------------------------------------------------------------------
# Device
# ---------------------------------------------------------------------------
# Preference order: CUDA, then Apple MPS, then CPU.
# NOTE(review): the CUDA branch hard-codes device index 1; on a single-GPU
# machine this looks like it would fail when tensors are moved -- confirm.
device = torch.device(
    "cuda:1" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Checkpoint directory is created at import time.
os.makedirs("result/pretrain_split", exist_ok=True)

# Progress bars are suppressed when stdout is not a terminal.
disable_tqdm = not sys.stdout.isatty()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 参数解析
|
||||
# ---------------------------------------------------------------------------
|
||||
def _str2bool(v: str) -> bool:
    """Convert a command-line string to a bool, for use as an argparse ``type``.

    ``type=bool`` would treat every non-empty string (including ``'False'``)
    as True, hence this helper. Matching is case-insensitive.

    Args:
        v: The raw option value; an actual bool is returned unchanged.

    Returns:
        True for 'true'/'1'/'yes', False for 'false'/'0'/'no'.

    Raises:
        argparse.ArgumentTypeError: For any other value.
    """
    if isinstance(v, bool):
        return v
    lookup = {
        'true': True, '1': True, 'yes': True,
        'false': False, '0': False, 'no': False,
    }
    token = v.lower()
    if token in lookup:
        return lookup[token]
    raise argparse.ArgumentTypeError(f"布尔值应为 True/False,收到: {v!r}")
|
||||
|
||||
def parse_arguments():
    """Parse the command-line options of the split pre-training script.

    Returns:
        argparse.Namespace with ``resume``, ``state``, ``train``, ``validate``,
        ``epochs``, ``checkpoint`` and ``load_optim``.
    """
    parser = argparse.ArgumentParser(description="三通道分拆 VQ 编码器预训练(方案 B)")
    # Registered in a fixed order so the generated help text stays stable.
    option_specs = (
        ("--resume", dict(type=_str2bool, default=False)),
        ("--state", dict(type=str, default="result/pretrain_split/split-10.pth",
                         help="续训时加载的检查点路径")),
        ("--train", dict(type=str, default="ginka-dataset.json")),
        ("--validate", dict(type=str, default="ginka-eval.json")),
        ("--epochs", dict(type=int, default=60)),
        ("--checkpoint", dict(type=int, default=5,
                              help="每隔多少 epoch 保存检查点并输出验证指标")),
        ("--load_optim", dict(type=_str2bool, default=True)),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 验证:各通道专属 tile 召回率 + codebook 使用熵
|
||||
# ---------------------------------------------------------------------------
|
||||
@torch.no_grad()
def validate(
    enc1, enc2, enc3,
    head1, head2, head3,
    dataloader_val: DataLoader,
) -> dict:
    """Evaluate the three encoder/head pairs on the validation loader.

    All six modules are switched to eval mode (the caller is responsible for
    switching back to train mode afterwards).

    Args:
        enc1, enc2, enc3: Channel encoders; each call returns a 6-tuple whose
            first element is the quantised latent and third the code indices.
        head1, head2, head3: Decode heads called as ``head(z_q, context)``.
        dataloader_val: Yields dict batches with ``raw_map``, ``slice1``,
            ``slice2`` and ``slice3`` tensors.

    Returns:
        dict with ``ch1_wall_recall`` (float), ``ch2_recall`` and
        ``ch3_recall`` (tile id -> recall) and ``codebook_entropy``
        (three floats, in bits).
    """
    for module in (enc1, enc2, enc3, head1, head2, head3):
        module.eval()

    # True-positive / ground-truth counters per monitored tile.
    wall_tp, wall_gt = 0, 0                # channel 1 reports wall (tile 1) only
    ch2_tp = dict.fromkeys(CH2_LOSS, 0)
    ch2_gt = dict.fromkeys(CH2_LOSS, 0)
    ch3_tp = dict.fromkeys(CH3_LOSS, 0)
    ch3_gt = dict.fromkeys(CH3_LOSS, 0)

    # Code usage histogram per channel, kept on CPU.
    codebook_counts = [torch.zeros(VQ_K, dtype=torch.long) for _ in range(3)]

    def _tally_codes(channel, indices):
        """Add one batch of code indices to the channel's usage histogram."""
        flat = indices.reshape(-1).cpu()
        # One vectorised histogram instead of a Python loop per code.
        codebook_counts[channel] += torch.bincount(flat, minlength=VQ_K)

    def _tally_recall(pred, target, tiles, tp, gt):
        """Accumulate per-tile hit and ground-truth counts."""
        for tile in tiles:
            where = target == tile
            tp[tile] += (pred[where] == tile).sum().item()
            gt[tile] += where.sum().item()

    def _entropy(counts):
        """Estimate the codebook usage entropy in bits."""
        smoothed = counts.float() + 1e-8   # avoids log2(0) for unused codes
        p = smoothed / smoothed.sum()
        return float(-(p * torch.log2(p)).sum())

    for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm):
        raw_map = batch["raw_map"].to(device)
        s1 = batch["slice1"].to(device)
        s2 = batch["slice2"].to(device)
        s3 = batch["slice3"].to(device)

        # Channel 1: decoded against an all-zero context map.
        z_q1, _, idx1, _, _, _ = enc1(s1)
        pred1 = head1(z_q1, torch.zeros_like(raw_map)).argmax(dim=-1)
        wall_mask = raw_map == 1
        wall_tp += (pred1[wall_mask] == 1).sum().item()
        wall_gt += wall_mask.sum().item()
        _tally_codes(0, idx1)

        # Channel 2: decoded with the channel-1 slice as context.
        z_q2, _, idx2, _, _, _ = enc2(s2)
        pred2 = head2(z_q2, s1).argmax(dim=-1)
        _tally_recall(pred2, raw_map, CH2_LOSS, ch2_tp, ch2_gt)
        _tally_codes(1, idx2)

        # Channel 3: decoded with the channel-2 slice as context.
        z_q3, _, idx3, _, _, _ = enc3(s3)
        pred3 = head3(z_q3, s2).argmax(dim=-1)
        _tally_recall(pred3, raw_map, CH3_LOSS, ch3_tp, ch3_gt)
        _tally_codes(2, idx3)

    # max(..., 1) keeps recall at 0 rather than dividing by zero when a tile
    # never occurs in the validation set.
    return {
        "ch1_wall_recall": wall_tp / max(wall_gt, 1),
        "ch2_recall": {t: ch2_tp[t] / max(ch2_gt[t], 1) for t in CH2_LOSS},
        "ch3_recall": {t: ch3_tp[t] / max(ch3_gt[t], 1) for t in CH3_LOSS},
        "codebook_entropy": [_entropy(c) for c in codebook_counts],
    }
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 主训练函数
|
||||
# ---------------------------------------------------------------------------
|
||||
def train():
    """Jointly pre-train the three channel encoders with throw-away decode heads.

    Reads options from the command line, optionally resumes from a checkpoint,
    trains for ``args.epochs`` epochs, writes a full checkpoint (plus
    validation metrics) every ``args.checkpoint`` epochs and at the last
    epoch, and finally saves the encoder weights alone to
    ``result/pretrain_split/split_final.pth``.
    """
    print(f"Using device: {device}")
    args = parse_arguments()

    def _make_encoder(d_model, nhead):
        """Build one channel encoder on the target device."""
        return GinkaVQVAE(
            num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_z=VQ_D_Z,
            d_model=d_model, nhead=nhead, num_layers=VQ_LAYERS,
            dim_ff=VQ_DIM_FF, beta=VQ_BETA, gamma=VQ_GAMMA,
        ).to(device)

    def _make_head():
        """Build one pre-training decode head on the target device."""
        return VQDecodeHead(
            num_classes=NUM_CLASSES, d_z=VQ_D_Z, map_size=MAP_SIZE,
            nhead=DH_NHEAD, dim_ff=DH_DIM_FF, num_layers=DH_LAYERS,
        ).to(device)

    def _param_count(modules):
        """Total number of parameters across the given modules."""
        return sum(p.numel() for m in modules for p in m.parameters())

    # ---- Three encoders, then three heads (construction order preserved) ----
    enc1 = _make_encoder(CH1_D_MODEL, CH1_NHEAD)
    enc2 = _make_encoder(CH2_D_MODEL, CH2_NHEAD)
    enc3 = _make_encoder(CH3_D_MODEL, CH3_NHEAD)

    # Decode heads are used for pre-training only and are not exported.
    head1 = _make_head()
    head2 = _make_head()
    head3 = _make_head()

    encoders = (enc1, enc2, enc3)
    heads = (head1, head2, head3)
    modules = encoders + heads

    # ---- One optimizer over all six modules ----
    # Built once and reused for gradient clipping; the order (enc1..3,
    # head1..3) must stay fixed so saved optimizer state still lines up.
    all_params = [p for m in modules for p in m.parameters()]
    optimizer = optim.AdamW(all_params, lr=1e-3, weight_decay=1e-4)

    start_epoch = 0

    # ---- Resume ----
    if args.resume:
        ckpt = torch.load(args.state, map_location=device)
        module_keys = ("enc1", "enc2", "enc3", "head1", "head2", "head3")
        for key, module in zip(module_keys, modules):
            module.load_state_dict(ckpt[key])
        if args.load_optim and "optimizer" in ckpt:
            optimizer.load_state_dict(ckpt["optimizer"])
        start_epoch = ckpt.get("epoch", 0)
        print(f"Resumed from epoch {start_epoch}: {args.state}")

    # ---- Datasets ----
    ds_train = GinkaSplitDataset(args.train)
    ds_val = GinkaSplitDataset(args.validate)
    # Pinned host memory only helps (and only avoids a warning) on CUDA.
    use_pinned = device.type == "cuda"
    dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=use_pinned)
    dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=use_pinned)

    print(f"训练集大小: {len(ds_train)},验证集大小: {len(ds_val)}")

    total_params = _param_count(encoders)
    print(f"编码器总参数量(三路): {total_params:,} ({total_params / 1e6:.3f}M)")
    total_params = _param_count(heads)
    print(f"解码器总参数量(三路): {total_params:,} ({total_params / 1e6:.3f}M)")

    # (encoder, head, tile set handed to masked_focal) per channel.
    channels = (
        (enc1, head1, CH1_LOSS),
        (enc2, head2, CH2_LOSS),
        (enc3, head3, CH3_LOSS),
    )

    # ---- Training loop ----
    for epoch in range(start_epoch, args.epochs):
        for module in modules:
            module.train()

        total_loss = 0.0
        ch_losses = [0.0, 0.0, 0.0]

        for batch in tqdm(dl_train, desc=f"Epoch {epoch + 1}/{args.epochs}", disable=disable_tqdm):
            raw_map = batch["raw_map"].to(device)
            s1 = batch["slice1"].to(device)
            s2 = batch["slice2"].to(device)
            s3 = batch["slice3"].to(device)

            optimizer.zero_grad()

            # Channel k encodes slice k; its head is conditioned on the
            # previous channel's slice (an all-zero map for channel 1).
            encoder_inputs = (s1, s2, s3)
            head_contexts = (torch.zeros_like(raw_map), s1, s2)

            losses = []
            for (enc, head, loss_tiles), enc_in, context in zip(channels, encoder_inputs, head_contexts):
                z_q, _, _, _, commit_loss, entropy_loss = enc(enc_in)
                logits = head(z_q, context)  # [B, H*W, C] per the original comment
                focal = masked_focal(logits, raw_map, loss_tiles, gamma=FOCAL_GAMMA)
                losses.append(focal + VQ_BETA * commit_loss + VQ_GAMMA * entropy_loss)

            loss = losses[0] + losses[1] + losses[2]
            loss.backward()
            torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            for channel, channel_loss in enumerate(losses):
                ch_losses[channel] += channel_loss.item()

        # Guard against an empty loader so the averages do not divide by zero.
        n_batches = max(len(dl_train), 1)
        print(
            f"[{epoch + 1:03d}] total={total_loss / n_batches:.4f} "
            f"ch1={ch_losses[0] / n_batches:.4f} "
            f"ch2={ch_losses[1] / n_batches:.4f} "
            f"ch3={ch_losses[2] / n_batches:.4f}"
        )

        # ---- Checkpoint & validation ----
        if (epoch + 1) % args.checkpoint == 0 or epoch + 1 == args.epochs:
            metrics = validate(enc1, enc2, enc3, head1, head2, head3, dl_val)
            print(
                f" 验证 ch1_wall_recall={metrics['ch1_wall_recall']:.3f} "
                f"ch2_recall={metrics['ch2_recall']} "
                f"ch3_recall={metrics['ch3_recall']}"
            )
            print(
                f" codebook_entropy ch1={metrics['codebook_entropy'][0]:.3f} "
                f"ch2={metrics['codebook_entropy'][1]:.3f} "
                f"ch3={metrics['codebook_entropy'][2]:.3f}"
            )

            ts = datetime.now().strftime("%m%d-%H%M")
            ckpt_path = f"result/pretrain_split/split-{epoch + 1}.pth"
            torch.save({
                "epoch": epoch + 1,
                "enc1": enc1.state_dict(),
                "enc2": enc2.state_dict(),
                "enc3": enc3.state_dict(),
                "head1": head1.state_dict(),
                "head2": head2.state_dict(),
                "head3": head3.state_dict(),
                "optimizer": optimizer.state_dict(),
                "metrics": metrics,
                "ts": ts,
            }, ckpt_path)
            print(f" Saved checkpoint: {ckpt_path}")

    # ---- Save the final encoder weights (for the joint-training script) ----
    final_path = "result/pretrain_split/split_final.pth"
    torch.save({
        "epoch": args.epochs,
        "enc1": enc1.state_dict(),
        "enc2": enc2.state_dict(),
        "enc3": enc3.state_dict(),
        # Decode heads are not transferred, so they are not saved here.
    }, final_path)
    print(f"\n预训练完成,编码器权重已保存至: {final_path}")
    print("接下来运行联合训练(阶段 1 冻结热身):")
    print(f" python -m ginka.train_vq --pretrain_split {final_path} --freeze_vq True")
|
||||
|
||||
|
||||
# Run pre-training only when this module is executed as a script.
if __name__ == "__main__":
    train()
|
||||
@ -1,684 +0,0 @@
|
||||
"""
|
||||
三阶段级联训练脚本:各阶段独立训练,使用 GinkaStageDataset。
|
||||
|
||||
总损失 = L_CE(只对本阶段负责的 tile 位置计算)+ beta * L_commit + gamma * L_entropy
|
||||
|
||||
各阶段分工:
|
||||
stage=1 结构骨架:floor(0) + wall(1)
|
||||
stage=2 功能元素:door(2) + monster(4) + entrance(5)
|
||||
stage=3 资源放置:resource(3)
|
||||
|
||||
用法示例:
|
||||
python -m ginka.train_stage --stage 0 # 三阶段联合训练(推荐)
|
||||
python -m ginka.train_stage --stage 1 # 只训练 stage1
|
||||
python -m ginka.train_stage --stage 0 --resume True --state result/joint-50.pth
|
||||
python -m ginka.train_stage --stage 0 --pretrain_vq result/joint/joint-50.pth
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from tqdm import tqdm
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .vqvae.model import GinkaVQVAE
|
||||
from .maskGIT.model import GinkaMaskGIT
|
||||
from .dataset import GinkaStageDataset
|
||||
from shared.image import matrix_to_image_cv
|
||||
|
||||
# ---------------------------------------------------------------------------
# Per-stage configuration
# ---------------------------------------------------------------------------

# VQ-VAE hyperparameters shared by every stage
VQ_L = 2  # code sequence length
VQ_K = 8  # codebook size
VQ_D_Z = 64  # code dimension
VQ_BETA = 0.5  # commit loss weight
VQ_GAMMA = 0.0  # entropy loss weight
VQ_LAYERS = 3
VQ_DIM_FF = 512
VQ_D_MODEL = 64
VQ_NHEAD = 8

# MaskGIT hyperparameters per stage (sized by task complexity)
STAGE_MG_CONFIGS = {
    1: dict(d_model=256, nhead=8, num_layers=6, dim_ff=2048),  # structural skeleton, most important
    2: dict(d_model=192, nhead=8, num_layers=4, dim_ff=1024),  # functional elements
    3: dict(d_model=128, nhead=8, num_layers=3, dim_ff=512),  # resource placement, simplest
}

# Tiles monitored per stage (for per-class recall statistics)
STAGE_TILE_SETS = {
    1: {0: "floor", 1: "wall"},
    2: {2: "door", 4: "monster", 5: "entrance"},
    3: {3: "resource"},
}

# Loss weights per stage (balance between the CE/focal term and the VQ term).
# Stage-3 resources are extremely sparse, so ce_weight is raised sharply to
# compensate for the class imbalance.
STAGE_LOSS_CONFIG = {
    1: dict(ce_weight=1.0, vq_weight=1.0),  # structural skeleton, standard weights
    2: dict(ce_weight=1.5, vq_weight=0.5),  # functional elements are sparser, CE raised
    3: dict(ce_weight=3.0, vq_weight=0.5),  # resources are extremely sparse, CE raised markedly
}

NUM_CLASSES = 7
MASK_TOKEN = 6
MAP_SIZE = 13 * 13
MAP_H = MAP_W = 13
FOCAL_GAMMA = 2.0
GENERATE_STEP = 18
BATCH_SIZE = 64
WALL_MASK_RATIO = 0.8

MG_Z_DROPOUT = 0.1
MG_STRUCT_DROPOUT = 0.1

SUBSET_WEIGHTS = (0.5, 0.2, 0.2, 0.1)

# Preference order: CUDA, then Apple MPS, then CPU.
# NOTE(review): the CUDA branch hard-codes device index 1 -- confirm this is
# intended on single-GPU machines.
device = torch.device(
    "cuda:1" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Progress bars are suppressed when stdout is not a terminal.
disable_tqdm = not sys.stdout.isatty()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 参数解析
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _str2bool(v):
    """Convert an argparse option string to a bool (case-insensitive).

    Bools are returned unchanged; 'true'/'1'/'yes' map to True and
    'false'/'0'/'no' to False. Anything else raises
    ``argparse.ArgumentTypeError``.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    for accepted, outcome in ((('true', '1', 'yes'), True), (('false', '0', 'no'), False)):
        if lowered in accepted:
            return outcome
    raise argparse.ArgumentTypeError(f"布尔值应为 True/False,收到: {v!r}")
|
||||
|
||||
|
||||
def parse_arguments():
    """Parse the command-line options of the staged training script.

    ``--stage`` is required and must be one of 0, 1, 2, 3.

    Returns:
        argparse.Namespace with ``stage``, ``resume``, ``state``, ``train``,
        ``validate``, ``epochs``, ``checkpoint``, ``load_optim``,
        ``freeze_vq`` and ``pretrain_vq``.
    """
    parser = argparse.ArgumentParser(description="三阶段级联训练")
    add = parser.add_argument

    add("--stage", type=int, required=True, choices=[0, 1, 2, 3],
        help="训练阶段:1/2/3 单独训练,0 = 依次训练全部三个阶段")
    add("--resume", type=_str2bool, default=False)
    add("--state", type=str, default="",
        help="续训时加载的检查点路径(自动推断 stage{N}/stage{N}-*.pth)")
    add("--train", type=str, default="ginka-dataset.json")
    add("--validate", type=str, default="ginka-eval.json")
    add("--epochs", type=int, default=100)
    add("--checkpoint", type=int, default=5)
    add("--load_optim", type=_str2bool, default=True)
    add("--freeze_vq", type=_str2bool, default=False,
        help="冻结 VQ 编码器,只训练 MaskGIT(适合加载预训练编码器后热身)")
    add("--pretrain_vq", type=str, default="",
        help="从 train_vq.py 的联合训练检查点中导入对应通道的 VQ 编码器权重")

    return parser.parse_args()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Focal Loss(与 train_vq.py 一致)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def focal_loss(logits, targets, gamma=FOCAL_GAMMA, reduction='none'):
    """Focal loss built on top of unreduced cross-entropy.

    Each element's cross-entropy ``ce`` is scaled by ``(1 - exp(-ce)) ** gamma``.

    Args:
        logits: Passed straight to ``F.cross_entropy`` (class dim second).
        targets: Class indices, passed straight to ``F.cross_entropy``.
        gamma: Focusing exponent.
        reduction: 'mean' or 'sum' reduce the result; any other value
            returns the per-element tensor.
    """
    ce = F.cross_entropy(logits, targets, reduction='none')
    confidence = torch.exp(-ce)          # probability assigned to the target
    weighted = (1.0 - confidence) ** gamma * ce
    if reduction not in ('mean', 'sum'):
        return weighted
    return weighted.mean() if reduction == 'mean' else weighted.sum()
|
||||
|
||||
|
||||
def masked_focal_loss(logits, targets, loss_mask, gamma=FOCAL_GAMMA):
    """Mean focal loss over the positions selected by ``loss_mask``.

    Args:
        logits: [B, C, H*W]
        targets: [B, H*W]
        loss_mask: [B, H*W] bool

    Returns:
        Scalar tensor. When the mask selects nothing, the mean over every
        position is returned instead.
    """
    token_losses = focal_loss(logits, targets, gamma, reduction='none')  # [B, H*W]
    picked = token_losses[loss_mask]
    # An empty selection falls back to the unmasked mean.
    pool = picked if picked.numel() != 0 else token_losses
    return pool.mean()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MaskGIT 推理(cosine schedule)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@torch.no_grad()
def maskgit_generate(
    model_mg: GinkaMaskGIT,
    z: torch.Tensor,
    steps: int = GENERATE_STEP,
    init_map: torch.Tensor = None,
    struct_cond: torch.Tensor = None,
) -> torch.Tensor:
    """Iteratively generate maps with a cosine unmasking schedule.

    Args:
        model_mg: Called as ``model_mg(map_seq, z, struct_cond=...)``;
            its output is softmaxed over the last dim.
        z: Latent condition; the batch size is taken from ``z.shape[0]``.
        steps: Number of refinement iterations.
        init_map: Optional starting map; positions that are not MASK_TOKEN
            are left untouched. Defaults to an all-MASK map.
        struct_cond: Optional structural condition forwarded to the model.

    Returns:
        [B, MAP_SIZE] tensor of tile ids.
    """
    batch_size = z.shape[0]
    if init_map is None:
        map_seq = torch.full((batch_size, MAP_SIZE), MASK_TOKEN, device=device)
    else:
        map_seq = init_map.clone().to(device)

    # Positions the sampler may write to, fixed from the initial map.
    # NOTE(review): this mask is never narrowed between steps, so every
    # initially-masked position is re-sampled on each iteration rather than
    # frozen once revealed -- confirm that is the intended variant.
    generatable = map_seq == MASK_TOKEN

    for step in range(steps):
        if not generatable.any():
            break

        logits = model_mg(map_seq, z, struct_cond=struct_cond)  # [B, S, C]
        probs = F.softmax(logits, dim=-1)
        sampled = torch.distributions.Categorical(probs).sample()
        # Probability of the sampled class; fixed positions get +inf so they
        # are never among the lowest-confidence picks.
        confidences = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)
        confidences = confidences.masked_fill(~generatable, float('inf'))

        # Fraction of writable positions to re-mask; reaches ~0 at the last step.
        ratio = math.cos(((step + 1) / steps) * math.pi / 2)
        next_map = map_seq.clone()

        for b in range(batch_size):
            writable = int(generatable[b].sum().item())
            n_remask = int(ratio * writable)
            proposal = sampled[b]
            if n_remask > 0:
                _, remask_idx = torch.topk(confidences[b], k=n_remask, largest=False)
                proposal = proposal.clone()
                proposal[remask_idx] = MASK_TOKEN
            next_map[b] = torch.where(generatable[b], proposal, map_seq[b])

        map_seq = next_map

    return map_seq
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 可视化工具(与 train_vq.py 保持一致)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_map_image(map_flat, tile_dict):
    """Render a flat tile-id tensor as an image via ``matrix_to_image_cv``.

    The tensor is moved to CPU and reshaped to (MAP_H, MAP_W) first.
    """
    grid = np.reshape(map_flat.cpu().numpy(), (MAP_H, MAP_W))
    return matrix_to_image_cv(grid, tile_dict)
|
||||
|
||||
|
||||
def hstack_images(imgs, gap=4, color=(255, 255, 255)):
    """Place images side by side, separated by ``gap``-pixel wide strips.

    Shorter images are padded at the bottom with ``color`` so every tile
    reaches the tallest image's height.

    Args:
        imgs: Non-empty sequence of H x W x 3 arrays.
        gap: Width of the separator strip between neighbours.
        color: Fill colour for padding and separators.

    Returns:
        The combined array; a single image that needs no padding is
        returned as-is.
    """
    target_h = max(img.shape[0] for img in imgs)

    def _filler(height, width):
        return np.full((height, width, 3), color, dtype=np.uint8)

    pieces = []
    for position, img in enumerate(imgs):
        if position:
            pieces.append(_filler(target_h, gap))
        missing = target_h - img.shape[0]
        if missing:
            img = np.concatenate([img, _filler(missing, img.shape[1])], axis=0)
        pieces.append(img)

    if len(pieces) == 1:
        return pieces[0]
    return np.concatenate(pieces, axis=1)
|
||||
|
||||
|
||||
def grid_images(imgs, gap=4, bg=(255, 255, 255)):
    """Arrange images in two rows, the first holding ``ceil(n / 2)`` of them.

    Args:
        imgs: Sequence of H x W x 3 arrays.
        gap: Spacing in pixels between neighbours and between the rows.
        bg: Fill colour for spacing and padding.

    Returns:
        A 1 x 1 black image for an empty input, the image itself for a
        single input, otherwise the two rows stacked vertically with the
        narrower row right-padded to the wider one.
    """
    n = len(imgs)
    if n == 0:
        return np.zeros((1, 1, 3), dtype=np.uint8)
    if n == 1:
        return imgs[0]

    # For n >= 2, ceil(n / 2) < n, so the bottom row is never empty.
    mid = math.ceil(n / 2)
    top = hstack_images(imgs[:mid], gap, bg)
    bot = hstack_images(imgs[mid:], gap, bg)
    width = max(top.shape[1], bot.shape[1])

    def _widen(row):
        """Right-pad a row with the background colour up to the grid width."""
        deficit = width - row.shape[1]
        if deficit == 0:
            return row
        pad = np.full((row.shape[0], deficit, 3), bg, dtype=np.uint8)
        return np.concatenate([row, pad], axis=1)

    top, bot = _widen(top), _widen(bot)
    hline = np.full((gap, width, 3), bg, dtype=np.uint8)
    return np.concatenate([top, hline, bot], axis=0)
|
||||
|
||||
|
||||
def label_image(img, text, font_scale=0.45):
    """Return ``img`` with a 16-pixel dark caption bar stacked on top.

    Args:
        img: H x W x 3 image; the bar takes its width.
        text: Caption drawn into the bar.
        font_scale: OpenCV font scale for the caption.
    """
    bar_height = 16
    bar_color = (40, 40, 40)
    text_color = (200, 200, 200)

    caption = np.full((bar_height, img.shape[1], 3), bar_color, dtype=np.uint8)
    cv2.putText(
        caption, text, (2, 13), cv2.FONT_HERSHEY_SIMPLEX,
        font_scale, text_color, thickness=1, lineType=cv2.LINE_AA,
    )
    return np.concatenate([caption, img], axis=0)
|
||||
|
||||
|
||||
def make_random_struct_cond():
    """Draw one random structural condition as a [1, 4] long tensor.

    One index is drawn per vocabulary, in the order symmetry, room,
    branch, outer.
    """
    from .maskGIT.model import SYM_VOCAB, ROOM_VOCAB, BRANCH_VOCAB, OUTER_VOCAB

    vocab_sizes = (SYM_VOCAB, ROOM_VOCAB, BRANCH_VOCAB, OUTER_VOCAB)
    # randint is inclusive, so the last index (size - 1) is never drawn;
    # presumably it is a reserved/null entry -- TODO confirm in the model.
    drawn = [random.randint(0, size - 2) for size in vocab_sizes]
    return torch.tensor([drawn], dtype=torch.long, device=device)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 按阶段构造推理初始地图
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_stage_init(stage: int, context_map: torch.Tensor) -> torch.Tensor:
    """Build the initial map for inference at the given stage.

    Stage 1: everything is MASK.
    Stage 2: only wall (1) is kept; all other tiles become MASK.
    Any other stage (i.e. stage 3): wall (1), door (2), monster (4) and
    entrance (5) are kept; all other tiles become MASK.

    The input tensor is not modified; a new tensor is returned.
    """
    if stage == 1:
        return torch.full_like(context_map, MASK_TOKEN)

    kept_tiles = [1] if stage == 2 else [1, 2, 4, 5]
    init = context_map.clone()
    keep = torch.isin(init, torch.tensor(kept_tiles, device=init.device))
    init[~keep] = MASK_TOKEN
    return init
|
||||
|
||||
|
||||
def make_random_wall_seed(ratio_min=0.02, ratio_max=0.08):
    """Create a [1, MAP_SIZE] all-MASK map with a few random wall tiles.

    The wall fraction is drawn uniformly from [ratio_min, ratio_max]; at
    least two positions are always set to wall (1).
    """
    wall_fraction = random.uniform(ratio_min, ratio_max)
    wall_count = max(2, int(MAP_SIZE * wall_fraction))
    positions = torch.randperm(MAP_SIZE)[:wall_count]

    seed = torch.full((1, MAP_SIZE), MASK_TOKEN, dtype=torch.long, device=device)
    seed[0, positions] = 1
    return seed
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 验证函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@torch.no_grad()
def validate(
    stage: int,
    enc: GinkaVQVAE,
    model_mg: GinkaMaskGIT,
    dataloader_val: DataLoader,
    tile_dict: dict,
    epoch: int,
    n_rand: int = 3,
):
    """Compute a validation loss and write sample images for one stage.

    Both models are switched to eval mode. One sample per subset
    ('A'..'D') is captured and rendered to
    ``result/stage{stage}_img/e{epoch:04d}/subset_{X}.png`` as: ground
    truth, stage input, a generation from the real latent, and ``n_rand``
    generations from sampled latents with random structural conditions.

    Returns:
        Mean of (masked focal loss + VQ loss) over the batches visited.
    """
    enc.eval()
    model_mg.eval()

    epoch_dir = f"result/stage{stage}_img/e{epoch:04d}"
    os.makedirs(epoch_dir, exist_ok=True)

    val_loss_total = 0.0
    val_steps = 0
    captured = {s: None for s in ('A', 'B', 'C', 'D')}

    for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm):
        raw_map = batch["raw_map"].to(device)
        vq_slice = batch["vq_slice"].to(device)
        stage_input = batch["stage_input"].to(device)
        target_map = batch["target_map"].to(device)
        loss_mask = batch["loss_mask"].to(device)
        struct_cond = batch["struct_cond"].to(device)
        subsets = batch["subset"]

        z_q, _, _, vq_loss, _, _ = enc(vq_slice)
        logits = model_mg(stage_input, z_q, struct_cond=struct_cond)

        # The loss helper expects the class dim second: [B, C, H*W].
        ce = masked_focal_loss(logits.permute(0, 2, 1), target_map, loss_mask)
        val_loss_total += (ce + vq_loss).item()
        val_steps += 1

        # Keep the first sample seen for each subset.
        for i in range(raw_map.shape[0]):
            s = subsets[i]
            if captured[s] is None:
                captured[s] = {
                    "raw": raw_map[i:i+1].clone(),
                    "stage_input": stage_input[i:i+1].clone(),
                    "z_q": z_q[i:i+1].clone(),
                    "struct_cond": struct_cond[i:i+1].clone(),
                }

        # NOTE(review): stopping here also ends the loss accumulation, so the
        # returned mean covers only the batches seen until all four subsets
        # were captured -- confirm this shortcut is intended.
        if all(v is not None for v in captured.values()):
            break

    # ---- Visualisation: one image per captured subset ----
    for sub, cap in captured.items():
        if cap is None:
            continue

        raw_img = label_image(make_map_image(cap["raw"][0], tile_dict), "GT")
        inp_img = label_image(make_map_image(cap["stage_input"][0], tile_dict), f"stage{stage} input")

        # Iterative generation conditioned on the real latent.
        init = make_stage_init(stage, cap["stage_input"][0].unsqueeze(0))
        gen = maskgit_generate(
            model_mg, cap["z_q"],
            init_map=init, struct_cond=cap["struct_cond"],
        )
        gen_img = label_image(make_map_image(gen[0], tile_dict), "z_real gen")

        # Generations from sampled latents start from the ground-truth map's
        # context; the init map does not depend on the sample, so it is built
        # once (maskgit_generate clones it before writing).
        init2 = make_stage_init(stage, cap["raw"][0].unsqueeze(0))
        rand_imgs = []
        for i in range(n_rand):
            z_r = enc.sample(1, device)
            sc_r = make_random_struct_cond()
            gen_r = maskgit_generate(model_mg, z_r, init_map=init2, struct_cond=sc_r)
            rand_imgs.append(label_image(make_map_image(gen_r[0], tile_dict), f"z_rand_{i+1}"))

        row = [raw_img, inp_img, gen_img] + rand_imgs
        cv2.imwrite(f"{epoch_dir}/subset_{sub}.png", grid_images(row))

    return val_loss_total / max(val_steps, 1)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 主训练函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _build_stage(stage: int, args):
    """Create the models and data loaders for one stage (no optimizer).

    Also creates ``result/stage{stage}`` and ``result/stage{stage}_img``,
    optionally loads pretrained VQ weights from ``args.pretrain_vq`` and
    optionally freezes the encoder when ``args.freeze_vq`` is set.

    Returns:
        dict with keys ``stage``, ``enc``, ``model_mg``,
        ``dataloader_train``, ``dataloader_val`` and ``result_dir``.
    """
    result_dir = f"result/stage{stage}"
    for directory in (result_dir, f"result/stage{stage}_img"):
        os.makedirs(directory, exist_ok=True)

    mg_cfg = STAGE_MG_CONFIGS[stage]

    # Encoder first, MaskGIT second, so weight initialisation order is fixed.
    enc_kwargs = dict(
        num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_z=VQ_D_Z,
        d_model=VQ_D_MODEL, nhead=VQ_NHEAD, num_layers=VQ_LAYERS,
        dim_ff=VQ_DIM_FF, map_size=MAP_SIZE, beta=VQ_BETA, gamma=VQ_GAMMA,
    )
    enc = GinkaVQVAE(**enc_kwargs).to(device)

    mg_kwargs = {key: mg_cfg[key] for key in ("d_model", "dim_ff", "nhead", "num_layers")}
    model_mg = GinkaMaskGIT(
        num_classes=NUM_CLASSES,
        d_z=VQ_D_Z,
        map_size=MAP_SIZE,
        z_dropout=MG_Z_DROPOUT,
        struct_dropout=MG_STRUCT_DROPOUT,
        **mg_kwargs,
    ).to(device)

    enc_params = sum(p.numel() for p in enc.parameters())
    mg_params = sum(p.numel() for p in model_mg.parameters())
    print(f"[Stage {stage}] VQ={enc_params/1e6:.2f}M MaskGIT={mg_params/1e6:.2f}M")

    shared_ds_kwargs = dict(
        stage=stage,
        subset_weights=SUBSET_WEIGHTS,
        wall_mask_ratio=WALL_MASK_RATIO,
    )
    dataset_train = GinkaStageDataset(args.train, **shared_ds_kwargs)
    # The validation set reuses the thresholds exposed by the training set.
    dataset_val = GinkaStageDataset(
        args.validate,
        room_thresholds=dataset_train.room_th,
        branch_thresholds=dataset_train.branch_th,
        **shared_ds_kwargs,
    )

    dataloader_train = DataLoader(
        dataset_train,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,
        pin_memory=(device.type == "cuda"),
    )
    dataloader_val = DataLoader(dataset_val, batch_size=8, shuffle=True, num_workers=0)

    if args.pretrain_vq:
        ckpt = torch.load(args.pretrain_vq, map_location=device)
        enc_key = f"enc{stage}"
        if enc_key not in ckpt:
            print(f"[Stage {stage}] 警告:检查点中未找到 {enc_key}。")
        else:
            # strict=False: missing or unexpected keys are tolerated.
            enc.load_state_dict(ckpt[enc_key], strict=False)
            print(f"[Stage {stage}] 已加载预训练 VQ 权重。")

    if args.freeze_vq:
        for param in enc.parameters():
            param.requires_grad_(False)
        print(f"[Stage {stage}] VQ 编码器已冻结。")

    return dict(
        stage=stage,
        enc=enc,
        model_mg=model_mg,
        dataloader_train=dataloader_train,
        dataloader_val=dataloader_val,
        result_dir=result_dir,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
def train() -> None:
    """Train the selected stage(s) jointly, checkpointing and validating periodically.

    Each stage owns a VQ encoder (``enc``) and a MaskGIT model (``model_mg``)
    built by ``_build_stage``.  All stages share one AdamW optimizer and one
    cosine LR schedule; in every step the per-stage losses are summed and
    back-propagated together.

    Side effects: reads tile images from ``tiles/``, writes checkpoints under
    ``result/stage<N>/`` and a final ``result/joint_final.pth``.
    NOTE(review): the checkpoint directory is not created here; presumably
    ``_build_stage`` creates it (it returns a ``result_dir``) - confirm.
    """
    print(f"Using device: {device}")
    args = parse_arguments()
    # --stage 0 means "train all three stages together".
    stages = [1, 2, 3] if args.stage == 0 else [args.stage]

    # ---- Tile sprites (loaded once, shared by every stage) ----
    tile_dict = {}
    for f in os.listdir("tiles"):
        name = os.path.splitext(f)[0]
        img = cv2.imread(f"tiles/{f}", cv2.IMREAD_UNCHANGED)
        # cv2.imread returns None for files it cannot decode; skip those.
        if img is not None:
            tile_dict[name] = img

    # ---- Build every requested stage ----
    states = {stage: _build_stage(stage, args) for stage in stages}

    # ---- One optimizer over the parameters of all stages ----
    all_params = []
    for st in states.values():
        all_params += list(st["enc"].parameters()) + list(st["model_mg"].parameters())
    optimizer = optim.AdamW(all_params, lr=2e-4, weight_decay=1e-2)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.epochs, eta_min=1e-6,
    )

    # ---- Resume from a checkpoint ----
    start_epoch = 0
    if args.resume:
        ckpt = torch.load(args.state, map_location=device)
        for stage in stages:
            st = states[stage]
            # strict=False tolerates missing/unexpected keys in the checkpoint.
            st["enc"].load_state_dict(ckpt[f"enc{stage}"], strict=False)
            st["model_mg"].load_state_dict(ckpt[f"mg{stage}"], strict=False)
        if args.load_optim and ckpt.get("optim_state") is not None:
            optimizer.load_state_dict(ckpt["optim_state"])
        # NOTE(review): the scheduler state is not restored, so the cosine
        # schedule restarts from its first step on resume.
        start_epoch = ckpt.get("epoch", 0)
        print(f"从 epoch {start_epoch} 接续训练。")

    # ---- Dataset alignment: zip the loaders, so the shortest one sets the length ----
    # With a single stage this is just that stage's loader; with several stages
    # zip keeps them advancing in lock-step, one batch per stage per step.
    def _epoch_iters():
        loaders = [states[s]["dataloader_train"] for s in stages]
        return zip(*loaders)

    # ---- Training loop ----
    for epoch in tqdm(
        range(start_epoch, start_epoch + args.epochs),
        desc="Training",
        disable=disable_tqdm,
    ):
        for st in states.values():
            st["enc"].train()
            st["model_mg"].train()

        # Per-stage running sums used for the epoch log line.
        loss_totals = {s: 0.0 for s in stages}
        ce_totals = {s: 0.0 for s in stages}
        vq_totals = {s: 0.0 for s in stages}
        n_batches = 0

        for batches in tqdm(
            _epoch_iters(),
            leave=False,
            desc="Batch",
            disable=disable_tqdm,
        ):
            optimizer.zero_grad()
            total_loss = 0.0

            for stage, batch in zip(stages, batches):
                st = states[stage]
                vq_slice = batch["vq_slice"].to(device)
                stage_input = batch["stage_input"].to(device)
                target_map = batch["target_map"].to(device)
                loss_mask = batch["loss_mask"].to(device)
                struct_cond = batch["struct_cond"].to(device)

                # The encoder returns a 6-tuple; only z_q (1st) and vq_loss (4th) are used.
                z_q, _, _, vq_loss, _, _ = st["enc"](vq_slice)
                logits = st["model_mg"](stage_input, z_q, struct_cond=struct_cond)

                # Logits are permuted so the class axis is second, as the loss helper is called.
                ce_loss = masked_focal_loss(logits.permute(0, 2, 1), target_map, loss_mask)
                cfg = STAGE_LOSS_CONFIG[stage]
                loss = cfg["ce_weight"] * ce_loss + cfg["vq_weight"] * vq_loss
                total_loss = total_loss + loss
                loss_totals[stage] += loss.detach().item()
                ce_totals[stage] += ce_loss.detach().item()
                vq_totals[stage] += vq_loss.detach().item()

            # One backward pass and one optimizer step for the summed loss of all stages.
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0)
            optimizer.step()
            n_batches += 1

        scheduler.step()

        # Avoid dividing by zero when the loaders produced no batch.
        n = max(n_batches, 1)
        total_avg = sum(loss_totals.values()) / n
        stage_loss_str = " ".join(
            f"S{s}[focal={ce_totals[s]/n:.4f} vq={vq_totals[s]/n:.4f}]" for s in stages
        )
        lr_now = scheduler.get_last_lr()[0]
        tqdm.write(
            f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
            f"Epoch {epoch + 1:4d} | "
            f"Total {total_avg:.4f} | {stage_loss_str} | "
            f"LR {lr_now:.2e}"
        )

        # ---- Checkpoint + validation ----
        if (epoch + 1) % args.checkpoint == 0:
            # Save one joint checkpoint holding every trained stage.
            ckpt_data = {"epoch": epoch + 1, "optim_state": optimizer.state_dict()}
            for stage in stages:
                st = states[stage]
                ckpt_data[f"enc{stage}"] = st["enc"].state_dict()
                ckpt_data[f"mg{stage}"] = st["model_mg"].state_dict()
            # The file goes into the directory of the last stage in the list.
            ckpt_path = f"result/stage{stages[-1]}/joint-{epoch + 1}.pth"
            torch.save(ckpt_data, ckpt_path)
            tqdm.write(f" 检查点已保存: {ckpt_path}")

            # Validate each stage on its own loader.
            # NOTE(review): val_loss_total is accumulated but never read afterwards.
            val_loss_total = 0.0
            for stage in stages:
                st = states[stage]
                vl = validate(
                    stage, st["enc"], st["model_mg"],
                    st["dataloader_val"], tile_dict, epoch + 1,
                )
                val_loss_total += vl
                tqdm.write(f" [Stage {stage}] Val Loss {vl:.5f}")

            # Cascaded free generation (stage1 -> stage2 -> stage3) needs all three stages.
            if len(stages) == 3:
                _cascade_free_validate(states, tile_dict, epoch + 1)

            # Put the models back into training mode after validation.
            for st in states.values():
                st["enc"].train()
                st["model_mg"].train()

    # ---- Final save (weights only, no optimizer state) ----
    ckpt_data = {"epoch": start_epoch + args.epochs}
    for stage in stages:
        st = states[stage]
        ckpt_data[f"enc{stage}"] = st["enc"].state_dict()
        ckpt_data[f"mg{stage}"] = st["model_mg"].state_dict()
    torch.save(ckpt_data, "result/joint_final.pth")
    print("训练结束。")
|
||||
|
||||
|
||||
@torch.no_grad()
def _cascade_free_validate(states: dict, tile_dict: dict, epoch: int, n: int = 4):
    """Render ``n`` maps produced by chaining stage 1 -> stage 2 -> stage 3.

    For every sample one random structure condition is drawn and reused by all
    three stages.  Stage 1 starts from a random sparse wall seed; stages 2 and
    3 start from ``make_stage_init(stage, previous_map)``.  Each stage samples
    its own latent ``z`` from that stage's encoder.  Only the stage-3 result is
    drawn; the images are written as one grid to
    ``result/cascade_img/e<epoch>/cascade_free.png``.

    Args:
        states: per-stage dicts holding at least ``"enc"`` and ``"model_mg"``
            under the keys 1, 2 and 3.
        tile_dict: tile sprites passed through to ``make_map_image``.
        epoch: epoch number used for the output folder name.
        n: number of cascaded samples to generate.
    """
    out_dir = f"result/cascade_img/e{epoch:04d}"
    os.makedirs(out_dir, exist_ok=True)

    panels = []
    for sample_no in range(1, n + 1):
        cond = make_random_struct_cond()

        generated = None
        for stage in (1, 2, 3):
            stage_state = states[stage]
            # The latent is sampled before the initial map is built, which
            # keeps the order of random draws fixed.
            latent = stage_state["enc"].sample(1, device)
            if stage == 1:
                start_map = make_random_wall_seed()
            else:
                # Later stages take the previous stage's output as context.
                start_map = make_stage_init(stage, generated)
            generated = maskgit_generate(
                stage_state["model_mg"], latent, init_map=start_map, struct_cond=cond
            )

        picture = make_map_image(generated[0], tile_dict)
        panels.append(label_image(picture, f"cascade_{sample_no}"))

    cv2.imwrite(f"{out_dir}/cascade_free.png", grid_images(panels))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
if __name__ == "__main__":
|
||||
torch.set_num_threads(4)
|
||||
train()
|
||||
@ -1,812 +0,0 @@
|
||||
"""
|
||||
联合训练脚本:VQ-VAE + MaskGIT
|
||||
|
||||
总损失 = L_CE(MaskGIT 重建损失)+ beta * L_commit + gamma * L_entropy
|
||||
+ lambda * L_consist(z 一致性约束,方案 A)
|
||||
|
||||
验证阶段对四种子集(A/B/C/D)分别输出图片,
|
||||
每条样本额外采样 N_Z_SAMPLES 个随机 z,
|
||||
便于直观对比同条件不同 z 下的生成差异。
|
||||
|
||||
用法示例:
|
||||
python -m ginka.train_vq
|
||||
python -m ginka.train_vq --resume True --state result/joint/joint-10.pth
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from tqdm import tqdm
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .vqvae.model import GinkaVQVAE
|
||||
from .maskGIT.model import GinkaMaskGIT
|
||||
from .dataset import GinkaVQDataset
|
||||
from shared.image import matrix_to_image_cv
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 超参数
|
||||
# ---------------------------------------------------------------------------
|
||||
BATCH_SIZE = 64
NUM_CLASSES = 7
MASK_TOKEN = 6
GENERATE_STEP = 18  # number of MaskGIT decoding iterations at inference time
MAP_H = 13
MAP_W = 13
MAP_SIZE = 13 * 13
FOCAL_GAMMA = 2.0  # focal-loss focusing parameter (larger = more weight on hard / rare classes)
WALL_MASK_RATIO = 0.8

# VQ-VAE settings shared by the three encoders (plan B: three-channel split)
VQ_L = 2  # code sequence length per encoder (L1 + L2 + L3 = 6 in total)
VQ_K = 8  # codebook size
VQ_D_Z = 64  # codebook embedding width (equal across encoders so z can be concatenated)
VQ_BETA = 0.5  # commitment-loss weight
VQ_GAMMA = 0.0  # entropy-loss weight

# Per-channel encoder configuration
# channel 1: spatial skeleton (floor + wall)
CH1_D_MODEL = 64
CH1_NHEAD = 8
# channel 2: gating elements
CH2_D_MODEL = 64
CH2_NHEAD = 8
# channel 3: collectable resources
CH3_D_MODEL = 64
CH3_NHEAD = 8
VQ_LAYERS = 3
VQ_DIM_FF = 512

# Tile ids each channel's loss covers (used to monitor validation recall)
CH1_LOSS = {1}
CH2_LOSS = {2, 4, 5}
CH3_LOSS = {3}  # resources are collapsed into the single tile id 3

# MaskGIT settings
MG_D_MODEL = 256
MG_NHEAD = 8
MG_LAYERS = 6
MG_DIM_FF = 2048
MG_Z_DROPOUT = 0.1  # training-time probability of replacing z with random noise
MG_STRUCT_DROPOUT = 0.1  # training-time probability of replacing the structure labels with null

# Consistency-constraint settings (plan A)
CONSIST_LAMBDA = 0.1  # weight of the z consistency loss
CONSIST_TEMP = 2.0  # temperature applied to logits for the soft embedding (>1 smooths the distribution)

# Extra random z samples drawn per validation sample (0 = real z only)
N_Z_SAMPLES = 3

# Sampling weights of subsets A/B/C/D (shared by training and validation sets)
SUBSET_WEIGHTS = (0.5, 0.2, 0.2, 0.1)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 设备
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pick the compute device: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    # The second GPU ("cuda:1") is preferred, as before, but that ordinal is
    # invalid on single-GPU hosts and fails on the first `.to(device)`.
    # Fall back to "cuda:0" when fewer than two GPUs are visible.
    _cuda_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{_cuda_index}")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Output folders for checkpoints and validation images.
os.makedirs("result/joint", exist_ok=True)
os.makedirs("result/joint_img", exist_ok=True)

# Progress bars are only shown when stdout is an interactive terminal.
disable_tqdm = not sys.stdout.isatty()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 参数解析
|
||||
# ---------------------------------------------------------------------------
|
||||
def _str2bool(v: str) -> bool:
|
||||
"""argparse 专用:将字符串 'True'/'False' 正确转为 bool。
|
||||
type=bool 会把任何非空字符串(包括 'False')解析为 True,故需此辅助。"""
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if v.lower() in ('true', '1', 'yes'):
|
||||
return True
|
||||
if v.lower() in ('false', '0', 'no'):
|
||||
return False
|
||||
raise argparse.ArgumentTypeError(f"布尔值应为 True/False,收到: {v!r}")
|
||||
|
||||
def parse_arguments():
    """Build the command-line parser for joint training and parse ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with ``resume``, ``state``, ``train``,
        ``validate``, ``epochs``, ``checkpoint``, ``load_optim``,
        ``freeze_vq`` and ``pretrain_split``.
    """
    parser = argparse.ArgumentParser(description="VQ-VAE + MaskGIT 联合训练")

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ("--resume", {"type": _str2bool, "default": False}),
        ("--state", {
            "type": str,
            "default": "result/joint/joint-10.pth",
            "help": "续训时加载的检查点路径",
        }),
        ("--train", {"type": str, "default": "ginka-dataset.json"}),
        ("--validate", {"type": str, "default": "ginka-eval.json"}),
        ("--epochs", {"type": int, "default": 100}),
        ("--checkpoint", {
            "type": int,
            "default": 5,
            "help": "每隔多少 epoch 保存检查点并验证",
        }),
        ("--load_optim", {"type": _str2bool, "default": True}),
        ("--freeze_vq", {
            "type": _str2bool,
            "default": False,
            "help": "(方案 B 阶段 1)冻结三路 VQ 编码器,仅训练 MaskGIT。"
                    "适用于预训练权重加载后的热身阶段。",
        }),
        ("--pretrain_split", {
            "type": str,
            "default": "",
            "help": "(方案 B)三通道分拆预训练检查点路径;"
                    "指定后将从该检查点加载三路编码器初始权重。",
        }),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Focal Loss
|
||||
# ---------------------------------------------------------------------------
|
||||
def focal_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    gamma: float = FOCAL_GAMMA,
    reduction: str = 'none',
) -> torch.Tensor:
    """Multi-class focal loss: ``FL = (1 - p_t) ** gamma * CE``.

    Compared with plain cross-entropy, elements the model already classifies
    with high confidence are down-weighted, so rare or hard classes contribute
    relatively more.

    Args:
        logits: raw scores with the class axis second, as
            ``F.cross_entropy`` expects.
        targets: integer class labels.
        gamma: focusing parameter; ``0`` reduces this to cross-entropy.
        reduction: ``'mean'`` or ``'sum'`` to reduce; any other value
            (including the default ``'none'``) returns the per-element loss.

    Returns:
        The reduced scalar, or the unreduced loss shaped like ``targets``.
    """
    per_element_ce = F.cross_entropy(logits, targets, reduction='none')
    # exp(-CE) is the predicted probability of the true class.
    true_class_prob = torch.exp(-per_element_ce)
    weighted = (1.0 - true_class_prob) ** gamma * per_element_ce

    if reduction == 'mean':
        return weighted.mean()
    elif reduction == 'sum':
        return weighted.sum()
    else:
        return weighted
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MaskGIT 推理(cosine schedule 迭代解码)
|
||||
# ---------------------------------------------------------------------------
|
||||
@torch.no_grad()
def maskgit_generate(
    model_mg: GinkaMaskGIT,
    z: torch.Tensor,
    steps: int = GENERATE_STEP,
    init_map: torch.Tensor | None = None,
    struct_cond: torch.Tensor | None = None,
) -> torch.Tensor:
    """Iteratively generate maps with a cosine unmasking schedule.

    Positions that are ``MASK_TOKEN`` in the starting map are "generatable";
    all other positions stay fixed.  At every step the model is run on the
    current map, a token is sampled for every position, and the
    ``int(cos((step + 1) / steps * pi / 2) * n_generatable)`` least confident
    generatable positions of each sample are set back to ``MASK_TOKEN``.
    Every generatable position is re-sampled at every step, so a token
    revealed earlier may still change later.  The schedule reaches zero on
    the last step, so the full run leaves no scheduled re-masking.

    Args:
        model_mg: the MaskGIT model (expected to be in eval mode).
        z: conditioning latent, batch-first (``z.shape[0]`` is the batch size).
        steps: number of decoding iterations.
        init_map: optional ``[B, MAP_SIZE]`` starting map; ``None`` starts
            from an all-``MASK_TOKEN`` map (free generation).
        struct_cond: optional structure condition forwarded to the model
            unchanged; ``None`` is passed through as-is.

    Returns:
        The generated map, ``[B, MAP_SIZE]``.
    """
    B = z.shape[0]
    if init_map is None:
        map_seq = torch.full((B, MAP_SIZE), MASK_TOKEN, dtype=torch.long, device=device)
    else:
        map_seq = init_map.clone().to(device)

    # Positions masked at the start are the only ones ever written to.
    generatable = (map_seq == MASK_TOKEN)  # [B, S] bool

    # `generatable` never changes below, so the emptiness check and the
    # per-sample counts are computed once instead of at every step.
    if not generatable.any():
        return map_seq
    n_gen_per_sample = generatable.sum(dim=1).tolist()

    for step in range(steps):
        logits = model_mg(map_seq, z, struct_cond=struct_cond)  # [B, S, C]
        probs = F.softmax(logits, dim=-1)
        sampled = torch.distributions.Categorical(probs).sample()  # [B, S]

        # Confidence = probability of the sampled token.  Fixed positions get
        # +inf so they are never picked as "lowest confidence, re-mask".
        confidences = torch.gather(probs, -1, sampled.unsqueeze(-1)).squeeze(-1)
        confidences = confidences.masked_fill(~generatable, float('inf'))

        ratio = math.cos(((step + 1) / steps) * math.pi / 2)
        new_map = map_seq.clone()

        for b in range(B):
            # Number of positions that stay masked after this step.
            n_keep = int(ratio * n_gen_per_sample[b])

            pred_b = sampled[b]
            if n_keep > 0:
                _, keep_idx = torch.topk(confidences[b], k=n_keep, largest=False)
                pred_b = pred_b.clone()
                pred_b[keep_idx] = MASK_TOKEN
            new_map[b] = torch.where(generatable[b], pred_b, map_seq[b])

        map_seq = new_map

    return map_seq
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 可视化工具
|
||||
# ---------------------------------------------------------------------------
|
||||
def make_map_image(map_flat: torch.Tensor, tile_dict: dict) -> np.ndarray:
    """Render a flat ``[MAP_SIZE]`` tile tensor as an image.

    The tensor is moved to the CPU, reshaped to ``MAP_H x MAP_W`` and handed
    to ``matrix_to_image_cv`` together with the tile sprites.
    """
    tile_grid = map_flat.cpu().numpy().reshape((MAP_H, MAP_W))
    rendered = matrix_to_image_cv(tile_grid, tile_dict)
    return rendered
|
||||
|
||||
|
||||
def hstack_images(imgs: list, gap: int = 4, color=(255, 255, 255)) -> np.ndarray:
    """Join images left-to-right with a ``gap``-pixel separator between them.

    Images shorter than the tallest one are extended at the bottom with
    ``color``; the separators use the same colour.  Padding and separators
    are built as 3-channel ``uint8`` arrays.

    Args:
        imgs: non-empty list of ``H x W x 3`` arrays.
        gap: separator width in pixels.
        color: fill colour for padding and separators.

    Returns:
        The combined image (the sole input array itself when only one image
        is given).
    """
    target_h = max(picture.shape[0] for picture in imgs)

    def _extend_to_target(picture):
        missing_rows = target_h - picture.shape[0]
        if missing_rows == 0:
            return picture
        filler = np.full((missing_rows, picture.shape[1], 3), color, dtype=np.uint8)
        return np.concatenate([picture, filler], axis=0)

    separator = np.full((target_h, gap, 3), color, dtype=np.uint8)

    # Collect image, separator, image, ... and join them in one go.
    pieces = [_extend_to_target(imgs[0])]
    for picture in imgs[1:]:
        pieces.append(separator)
        pieces.append(_extend_to_target(picture))

    if len(pieces) == 1:
        return pieces[0]
    return np.concatenate(pieces, axis=1)
|
||||
|
||||
|
||||
def grid_images(imgs: list, gap: int = 4, bg_color=(255, 255, 255)) -> np.ndarray:
    """Arrange images in a two-row grid (top: ceil(N/2), bottom: floor(N/2)).

    An empty list yields a 1x1 black image and a single image is returned
    as-is.  Otherwise both rows are built with ``hstack_images``, the
    narrower row is padded on the right with ``bg_color``, and the rows are
    stacked with a ``gap``-pixel horizontal strip between them.
    """
    count = len(imgs)
    if count == 0:
        return np.zeros((1, 1, 3), dtype=np.uint8)
    if count == 1:
        return imgs[0]

    # With two or more images the bottom row is never empty.
    split = math.ceil(count / 2)
    upper = hstack_images(imgs[:split], gap, bg_color)
    lower = hstack_images(imgs[split:], gap, bg_color)

    def _pad_right(row, width):
        extra = width - row.shape[1]
        if extra <= 0:
            return row
        filler = np.full((row.shape[0], extra, 3), bg_color, dtype=np.uint8)
        return np.concatenate([row, filler], axis=1)

    # Bring both rows to the width of the wider one.
    full_width = max(upper.shape[1], lower.shape[1])
    upper = _pad_right(upper, full_width)
    lower = _pad_right(lower, full_width)

    spacer = np.full((gap, full_width, 3), bg_color, dtype=np.uint8)
    return np.concatenate([upper, spacer, lower], axis=0)
|
||||
|
||||
|
||||
def label_image(img: np.ndarray, text: str, font_scale: float = 0.45) -> np.ndarray:
    """Return a new image with a 16-pixel caption bar stacked above ``img``.

    The input array is not modified; the bar is a separate dark
    ``(40, 40, 40)`` strip with ``text`` drawn in light grey.

    Args:
        img: ``H x W x 3`` image.
        text: caption to draw.
        font_scale: OpenCV font scale for the caption.
    """
    bar_height = 16
    caption_bar = np.full((bar_height, img.shape[1], 3), (40, 40, 40), dtype=np.uint8)
    baseline = (2, bar_height - 3)
    cv2.putText(
        caption_bar, text, baseline,
        cv2.FONT_HERSHEY_SIMPLEX, font_scale,
        (200, 200, 200), 1, cv2.LINE_AA
    )
    return np.vstack([caption_bar, img])
|
||||
|
||||
|
||||
def struct_cond_to_text(sc: torch.Tensor) -> str:
    """Decode a 4-element structure-condition tensor into a readable string.

    ``sc`` holds ``[cond_sym, cond_room, cond_branch, cond_outer]``:

    * ``cond_sym``: bit flags (4 = H, 2 = V, 1 = C); ``7`` is shown as ``-``.
    * ``cond_room``: 0/1/2 -> ``lo/mid/hi``; ``3`` -> ``-``.
    * ``cond_branch``: 0/1/2 -> ``lo/mid/hi``; ``3`` -> ``-``.
    * ``cond_outer``: 0 -> ``N``, 1 -> ``Y``, 2 -> ``-``.

    Values outside these tables are printed as the raw number.

    Returns:
        e.g. ``"sym:HC room:hi br:hi wall:Y"``.
    """
    sym_val, room_val, branch_val, outer_val = map(int, sc.tolist())

    # Symmetry flags
    if sym_val == 7:
        sym_part = "sym:-"
    else:
        letters = "".join(
            letter
            for bit, letter in ((4, "H"), (2, "V"), (1, "C"))
            if sym_val & bit
        )
        sym_part = "sym:" + (letters or "none")

    def _lookup(prefix, names, value):
        # Known ids map to a short name, anything else is shown verbatim.
        if 0 <= value < len(names):
            return f"{prefix}:{names[value]}"
        return f"{prefix}:{value}"

    levels = ("lo", "mid", "hi", "-")
    room_part = _lookup("room", levels, room_val)      # room-count level
    branch_part = _lookup("br", levels, branch_val)    # branch level
    outer_part = _lookup("wall", ("N", "Y", "-"), outer_val)  # outer wall

    return " ".join((sym_part, room_part, branch_part, outer_part))
|
||||
|
||||
|
||||
def annotate_struct(img: np.ndarray, sc: torch.Tensor) -> np.ndarray:
    """Return a new image with a 14-pixel structure-label bar stacked below ``img``.

    The label text comes from ``struct_cond_to_text(sc)`` and is drawn on a
    ``(60, 30, 10)`` strip in ``(180, 220, 255)``.
    NOTE(review): these are presumably BGR values (OpenCV's default
    channel order), i.e. a dark-blue bar - confirm.
    """
    caption = struct_cond_to_text(sc)
    bar_height = 14
    footer = np.full((bar_height, img.shape[1], 3), (60, 30, 10), dtype=np.uint8)
    baseline = (2, bar_height - 3)
    cv2.putText(
        footer, caption, baseline,
        cv2.FONT_HERSHEY_SIMPLEX, 0.38,
        (180, 220, 255), 1, cv2.LINE_AA
    )
    return np.vstack([img, footer])
|
||||
|
||||
|
||||
def make_random_wall_seed(ratio_min: float = 0.02, ratio_max: float = 0.08) -> torch.Tensor:
    """Build an inference seed: an all-MASK map with a few random wall tiles.

    A wall fraction is drawn uniformly from ``[ratio_min, ratio_max]``; that
    share of the ``MAP_SIZE`` cells (at least two) is set to the wall tile
    ``1`` at random positions, everything else stays ``MASK_TOKEN``.

    Returns:
        ``[1, MAP_SIZE]`` long tensor on ``device``.
    """
    wall_fraction = random.uniform(ratio_min, ratio_max)
    wall_count = max(2, int(MAP_SIZE * wall_fraction))

    seed_map = torch.full((1, MAP_SIZE), MASK_TOKEN, dtype=torch.long, device=device)
    wall_positions = torch.randperm(MAP_SIZE)[:wall_count]
    seed_map[0, wall_positions] = 1  # wall
    return seed_map
|
||||
|
||||
|
||||
def make_random_struct_cond() -> torch.Tensor:
    """Draw one random structure condition.

    Each label is sampled uniformly from ``0 .. VOCAB - 2`` of its
    vocabulary, in the order sym, room, branch, outer.
    NOTE(review): the last id of each vocabulary is presumably the null
    label being excluded here - confirm against ``maskGIT.model``.

    Returns:
        ``[1, 4]`` long tensor on ``device``, ordered
        ``[cond_sym, cond_room, cond_branch, cond_outer]``.
    """
    from .maskGIT.model import SYM_VOCAB, ROOM_VOCAB, BRANCH_VOCAB, OUTER_VOCAB

    vocab_sizes = (SYM_VOCAB, ROOM_VOCAB, BRANCH_VOCAB, OUTER_VOCAB)
    labels = [random.randint(0, size - 2) for size in vocab_sizes]
    return torch.tensor([labels], dtype=torch.long, device=device)
|
||||
|
||||
@torch.no_grad()
def validate(
    enc1: GinkaVQVAE,
    enc2: GinkaVQVAE,
    enc3: GinkaVQVAE,
    model_mg: GinkaMaskGIT,
    dataloader_val: DataLoader,
    tile_dict: dict,
    epoch: int,
):
    """Compute the validation loss and write comparison images for five scenes.

    All models are switched to eval mode (the caller is responsible for
    switching back).  The loss is accumulated over the whole validation
    loader; the first sample seen of each subset A/B/C/D is kept for the
    images, which go to ``result/joint_img/e<epoch>/``:

    * ``scene1_completion``: subset A, random-mask completion.
      Columns: ground truth | masked input | z_real pred | z_real gen | z_rand x N
    * ``scene2_wall``: subset B, wall/floor-only input.
      Columns: ground truth | wall-only input | z_real gen | z_rand x N
    * ``scene3_sparse``: subset C, sparse-wall input (same columns as scene 2).
    * ``scene4_entrance``: subset D, wall + entrance input (same columns).
    * ``scene5a_random_cond`` / ``scene5b_random_uncond``: no dataset
      reference; a random wall seed followed by N + 1 generations with random
      z, with a random structure condition (5a) or none (5b).

    A scene whose subset never appeared in the loader is skipped.

    Args:
        enc1, enc2, enc3: the three channel encoders.
        model_mg: the MaskGIT model.
        dataloader_val: loader whose batches provide ``raw_map``,
            ``masked_map``, ``target_map``, ``slice1..3``, ``subset`` and
            ``struct_cond``.
        tile_dict: tile sprites used for rendering.
        epoch: epoch number used for the output folder name.

    Returns:
        Mean of (masked focal loss + summed VQ loss) over the batches.
    """
    for enc in (enc1, enc2, enc3):
        enc.eval()
    model_mg.eval()

    # One folder per epoch so earlier validation images are kept.
    epoch_dir = f"result/joint_img/e{epoch:04d}"
    os.makedirs(epoch_dir, exist_ok=True)

    val_loss_total = 0.0
    val_steps = 0
    captured = {s: None for s in ('A', 'B', 'C', 'D')}

    def _encode_three(s1, s2, s3):
        """Encode the three slices and concatenate their quantized latents."""
        z_q1, _, _, vq1, _, _ = enc1(s1)
        z_q2, _, _, vq2, _, _ = enc2(s2)
        z_q3, _, _, vq3, _, _ = enc3(s3)
        z_q = torch.cat([z_q1, z_q2, z_q3], dim=1)  # [B, L1+L2+L3, d_z]
        return z_q, vq1 + vq2 + vq3

    def _sample_three(B_size):
        """Sample a random latent from each encoder and concatenate them."""
        z1 = enc1.sample(B_size, device)
        z2 = enc2.sample(B_size, device)
        z3 = enc3.sample(B_size, device)
        return torch.cat([z1, z2, z3], dim=1)

    # ---- Validation loss + capture one sample per subset ----
    for batch in tqdm(dataloader_val, desc="Validating", leave=False, disable=disable_tqdm):
        raw_map = batch["raw_map"].to(device)        # [B, 169]
        masked_map = batch["masked_map"].to(device)  # [B, 169]
        target_map = batch["target_map"].to(device)  # [B, 169]
        s1 = batch["slice1"].to(device)
        s2 = batch["slice2"].to(device)
        s3 = batch["slice3"].to(device)
        subsets = batch["subset"]                    # list of str

        z_q, vq_loss = _encode_three(s1, s2, s3)
        struct_cond_b = batch["struct_cond"].to(device)  # [B, 4]
        logits = model_mg(masked_map, z_q, struct_cond=struct_cond_b)
        mask = (masked_map == MASK_TOKEN)

        # Focal loss averaged over the masked positions only.
        ce_loss = focal_loss(logits.permute(0, 2, 1), target_map)
        masked_ce = (ce_loss * mask).sum() / (mask.sum() + 1e-6)
        val_loss_total += (masked_ce + vq_loss).item()
        val_steps += 1

        # Keep the first sample of every subset.  This scan is skipped once
        # all four are captured, but the batch loop itself always runs to the
        # end so the reported loss covers the whole validation set.
        if any(v is None for v in captured.values()):
            for i, s in enumerate(subsets):
                if captured[s] is None:
                    captured[s] = {
                        "raw": raw_map[i:i+1].clone(),
                        "masked": masked_map[i:i+1].clone(),
                        "z_q": z_q[i:i+1].clone(),
                        "struct_cond": struct_cond_b[i:i+1].clone(),
                    }

    # ---- Helper: n generations from random z, no structure condition ----
    def _rand_gens(cond_map, n):
        imgs = []
        for i in range(n):
            z_r = _sample_three(1)
            gen = maskgit_generate(model_mg, z_r, init_map=cond_map)  # struct_cond=None
            imgs.append(label_image(make_map_image(gen[0], tile_dict), f"z_rand_{i + 1}"))
        return imgs

    # ---- Helper: n generations from random z with a random structure condition ----
    def _rand_gens_with_struct(cond_map, n):
        imgs = []
        for i in range(n):
            z_r = _sample_three(1)
            sc_r = make_random_struct_cond()  # [1, 4]
            gen = maskgit_generate(model_mg, z_r, init_map=cond_map, struct_cond=sc_r)
            img = label_image(make_map_image(gen[0], tile_dict), f"z_rand_{i + 1}")
            img = annotate_struct(img, sc_r[0])
            imgs.append(img)
        return imgs

    # ---- Helper: one dataset-backed scene (scenes 1-4 share this layout) ----
    def _render_scene(subset_key, cond_label, filename, with_pred=False):
        cap = captured[subset_key]
        if cap is None:
            return
        raw, cond = cap['raw'], cap['masked']
        z_real, sc = cap['z_q'], cap['struct_cond']

        row = [
            label_image(make_map_image(raw[0], tile_dict), "ground truth"),
            label_image(make_map_image(cond[0], tile_dict), cond_label),
        ]
        if with_pred:
            # Single-pass argmax prediction: the model's immediate guess.
            pred = model_mg(cond, z_real, struct_cond=sc).argmax(dim=-1)[0]
            row.append(label_image(make_map_image(pred, tile_dict), "z_real pred"))

        # Iterative generation from the conditioned input with the real z.
        gen_real = maskgit_generate(model_mg, z_real, init_map=cond, struct_cond=sc)
        row.append(label_image(make_map_image(gen_real[0], tile_dict), "z_real gen"))

        # Images produced with the real structure condition get its label bar.
        row = [annotate_struct(img, sc[0]) for img in row]

        row += _rand_gens(cond, N_Z_SAMPLES)
        cv2.imwrite(f"{epoch_dir}/{filename}", grid_images(row))

    # Scenes 1-4, in the same order as before (keeps the random draws aligned).
    _render_scene('A', "masked input", "scene1_completion.png", with_pred=True)
    _render_scene('B', "wall-only input", "scene2_wall.png")
    _render_scene('C', "sparse wall input", "scene3_sparse.png")
    _render_scene('D', "wall+entrance input", "scene4_entrance.png")

    # ---- Scene 5: fully random generation (no dataset reference) ----
    # 5a: random structure labels, to check that the condition steers the output.
    rand_seed_a = make_random_wall_seed()
    seed_img_a = label_image(make_map_image(rand_seed_a[0], tile_dict), "random seed")
    row_a = [seed_img_a] + _rand_gens_with_struct(rand_seed_a, N_Z_SAMPLES + 1)
    cv2.imwrite(f"{epoch_dir}/scene5a_random_cond.png", grid_images(row_a))

    # 5b: unconditional (struct_cond=None), a baseline for generation quality.
    rand_seed_b = make_random_wall_seed()
    seed_img_b = label_image(make_map_image(rand_seed_b[0], tile_dict), "random seed")
    row_b = [seed_img_b] + _rand_gens(rand_seed_b, N_Z_SAMPLES + 1)
    cv2.imwrite(f"{epoch_dir}/scene5b_random_uncond.png", grid_images(row_b))

    avg_val_loss = val_loss_total / max(val_steps, 1)
    return avg_val_loss
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 主训练函数
|
||||
# ---------------------------------------------------------------------------
|
||||
def train():
    """Jointly train the three VQ encoders and the MaskGIT model.

    Reads its configuration from ``parse_arguments()`` (the attributes used
    here are ``train``, ``validate``, ``epochs``, ``checkpoint``,
    ``pretrain_split``, ``resume``, ``state``, ``load_optim`` and
    ``freeze_vq``), builds models, datasets, optimizer and scheduler, then
    runs the epoch loop.

    Side effects:
        * prints / ``tqdm.write``s progress information;
        * reads tile images from the ``tiles`` directory;
        * every ``args.checkpoint`` epochs writes
          ``result/joint/joint-<epoch>.pth`` and runs ``validate``;
        * at the end writes ``result/joint/joint_final.pth`` (without
          optimizer state).
    """
    print(f"Using device: {device}")
    args = parse_arguments()

    # ---- Three encoders (scheme B: three-channel split) ----
    _vq_common = dict(
        num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_z=VQ_D_Z,
        num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF, map_size=MAP_SIZE,
        beta=VQ_BETA, gamma=VQ_GAMMA,
    )
    enc1 = GinkaVQVAE(d_model=CH1_D_MODEL, nhead=CH1_NHEAD, **_vq_common).to(device)
    enc2 = GinkaVQVAE(d_model=CH2_D_MODEL, nhead=CH2_NHEAD, **_vq_common).to(device)
    enc3 = GinkaVQVAE(d_model=CH3_D_MODEL, nhead=CH3_NHEAD, **_vq_common).to(device)
    encoders = [enc1, enc2, enc3]

    model_mg = GinkaMaskGIT(
        num_classes=NUM_CLASSES,
        d_model=MG_D_MODEL, d_z=VQ_D_Z,
        dim_ff=MG_DIM_FF, nhead=MG_NHEAD,
        num_layers=MG_LAYERS,
        map_size=MAP_SIZE,
        z_dropout=MG_Z_DROPOUT,
        struct_dropout=MG_STRUCT_DROPOUT,
    ).to(device)

    enc_params = sum(p.numel() for m in encoders for p in m.parameters())
    mg_params = sum(p.numel() for p in model_mg.parameters())
    print(f"Encoders 参数量(三路): {enc_params:,} ({enc_params/1e6:.3f}M)")
    print(f"MaskGIT 参数量: {mg_params:,} ({mg_params/1e6:.3f}M)")
    print(f"Total 参数量: {enc_params+mg_params:,} ({(enc_params+mg_params)/1e6:.3f}M)")

    # ---- Datasets ----
    dataset_train = GinkaVQDataset(
        args.train,
        subset_weights=SUBSET_WEIGHTS,
        wall_mask_ratio=WALL_MASK_RATIO,
    )
    # The validation set reuses the thresholds exposed by the training set
    # (room_th / branch_th) instead of deriving its own.
    dataset_val = GinkaVQDataset(
        args.validate,
        subset_weights=SUBSET_WEIGHTS,
        room_thresholds=dataset_train.room_th,
        branch_thresholds=dataset_train.branch_th,
        wall_mask_ratio=WALL_MASK_RATIO,
    )
    dataloader_train = DataLoader(
        dataset_train, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=0, pin_memory=(device.type == "cuda"),
    )
    dataloader_val = DataLoader(
        dataset_val, batch_size=8, shuffle=True,
        num_workers=0,
    )

    # ---- Optimizer (joint training: encoders + MaskGIT share one) ----
    enc_params_list = list(enc1.parameters()) + list(enc2.parameters()) + list(enc3.parameters())
    all_params = enc_params_list + list(model_mg.parameters())
    optimizer = optim.AdamW(all_params, lr=2e-4, weight_decay=1e-2)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.epochs, eta_min=1e-6
    )

    # ---- Weight loading ----
    start_epoch = 0
    if args.pretrain_split:
        # Initialise the three encoders from a split-pretraining checkpoint
        # (used before the stage-1 frozen warm-up). MaskGIT stays untouched.
        ckpt = torch.load(args.pretrain_split, map_location=device)
        enc1.load_state_dict(ckpt["enc1"])
        enc2.load_state_dict(ckpt["enc2"])
        enc3.load_state_dict(ckpt["enc3"])
        print(f"已加载分拆预训练编码器权重: {args.pretrain_split}")
    elif args.resume:
        # Resume a joint-training run; strict=False tolerates missing or
        # unexpected keys in the stored state dicts.
        ckpt = torch.load(args.state, map_location=device)
        enc1.load_state_dict(ckpt["enc1"], strict=False)
        enc2.load_state_dict(ckpt["enc2"], strict=False)
        enc3.load_state_dict(ckpt["enc3"], strict=False)
        model_mg.load_state_dict(ckpt["mg_state"], strict=False)
        if args.load_optim and ckpt.get("optim_state") is not None:
            # NOTE(review): the scheduler state is neither saved nor restored,
            # and loading the optimizer state also restores the param-group
            # learning rate stored in the checkpoint. The cosine schedule may
            # therefore not restart from 2e-4 on resume - confirm intended.
            optimizer.load_state_dict(ckpt["optim_state"])
        start_epoch = ckpt.get("epoch", 0)
        print(f"从 epoch {start_epoch} 接续训练。")

    # ---- Tile images (used for validation visualisation) ----
    tile_dict = {}
    for file in os.listdir("tiles"):
        name = os.path.splitext(file)[0]
        img = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
        # cv2.imread returns None for unreadable files; skip those.
        if img is not None:
            tile_dict[name] = img

    # ---- Scheme B stage 1: freeze the three VQ encoders ----
    if args.freeze_vq:
        _set_requires_grad(encoders, False)
        print("三路 VQ 编码器已冻结(阶段 1:MaskGIT 热身)。")

    # Make sure the checkpoint directory exists up front so the first
    # torch.save cannot fail after a long stretch of training.
    os.makedirs("result/joint", exist_ok=True)

    # ---- Training loop ----
    for epoch in tqdm(range(start_epoch, start_epoch + args.epochs),
                      desc="Joint Training", disable=disable_tqdm):
        # NOTE(review): encoders are switched to train mode even when
        # --freeze_vq is set; if they contain dropout or other train-only
        # behaviour it stays active during warm-up - confirm intended.
        for enc in encoders:
            enc.train()
        model_mg.train()

        loss_total = 0.0
        ce_total = 0.0
        vq_loss_total = 0.0
        commit_total = 0.0
        entropy_total = 0.0
        consist_total = 0.0
        subset_stats = {'A': 0, 'B': 0, 'C': 0, 'D': 0}

        for batch in tqdm(dataloader_train, leave=False,
                          desc="Epoch Progress", disable=disable_tqdm):
            masked_map = batch["masked_map"].to(device)  # [B, 169]
            target_map = batch["target_map"].to(device)  # [B, 169]
            s1 = batch["slice1"].to(device)              # channel-1 slice
            s2 = batch["slice2"].to(device)              # channel-2 slice
            s3 = batch["slice3"].to(device)              # channel-3 slice

            # Count how many samples of each subset were seen this epoch.
            for s in batch["subset"]:
                subset_stats[s] = subset_stats.get(s, 0) + 1

            # ---- Forward pass ----
            # 1. Each VQ encoder encodes its own slice; the latents are
            #    concatenated along dim 1.
            z_q1, z_e1, _, vq_loss1, commit_loss1, entropy_loss1 = enc1(s1)
            z_q2, z_e2, _, vq_loss2, commit_loss2, entropy_loss2 = enc2(s2)
            z_q3, z_e3, _, vq_loss3, commit_loss3, entropy_loss3 = enc3(s3)
            z_q = torch.cat([z_q1, z_q2, z_q3], dim=1)  # [B, L1+L2+L3, d_z]
            z_e = torch.cat([z_e1, z_e2, z_e3], dim=1)  # [B, L1+L2+L3, d_z]
            vq_loss = vq_loss1 + vq_loss2 + vq_loss3
            commit_loss = commit_loss1 + commit_loss2 + commit_loss3
            entropy_loss = entropy_loss1 + entropy_loss2 + entropy_loss3

            # 2. MaskGIT predicts the original tiles from the masked map,
            #    z and the structure labels.
            struct_cond = batch["struct_cond"].to(device)  # [B, 4]
            logits = model_mg(masked_map, z_q, struct_cond=struct_cond)  # [B, 169, C]

            # 3. Focal loss averaged over masked positions only (mitigates
            #    wall/floor dominance). focal_loss is expected to return an
            #    un-reduced per-position loss - TODO confirm.
            mask = (masked_map == MASK_TOKEN)  # [B, 169] bool
            ce_loss = focal_loss(logits.permute(0, 2, 1), target_map)
            masked_ce = (ce_loss * mask).sum() / (mask.sum() + 1e-6)

            # 4. z consistency constraint (scheme A extended to 3 channels):
            #    temperature-smoothed MaskGIT probabilities weight each
            #    encoder's tile embedding into a soft embedding, which the
            #    encoder re-encodes; the result is pulled towards the real
            #    z_e. Encoder parameters are frozen while this path is
            #    recorded so its gradient only reaches MaskGIT.
            _set_requires_grad(encoders, False)

            soft_probs = F.softmax(logits / CONSIST_TEMP, dim=-1)  # [B, H*W, V]
            z_pred_e1 = enc1.encode_soft(soft_probs @ enc1.tile_embedding.weight)
            z_pred_e2 = enc2.encode_soft(soft_probs @ enc2.tile_embedding.weight)
            z_pred_e3 = enc3.encode_soft(soft_probs @ enc3.tile_embedding.weight)
            z_pred_e = torch.cat([z_pred_e1, z_pred_e2, z_pred_e3], dim=1)
            consist_loss = F.mse_loss(z_pred_e, z_e.detach())

            # Re-enable encoder gradients before backward unless the
            # encoders are frozen for the whole run.
            if not args.freeze_vq:
                _set_requires_grad(encoders, True)

            # 5. Joint loss
            loss = masked_ce + vq_loss + CONSIST_LAMBDA * consist_loss

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0)
            optimizer.step()

            loss_total += loss.detach().item()
            ce_total += masked_ce.detach().item()
            vq_loss_total += vq_loss.detach().item()
            commit_total += commit_loss.detach().item()
            entropy_total += entropy_loss.detach().item()
            consist_total += consist_loss.detach().item()

        # The cosine schedule advances once per epoch (T_max=args.epochs).
        scheduler.step()

        # Guard against an empty training loader, mirroring the
        # max(val_steps, 1) guard used when averaging the validation loss.
        n = max(len(dataloader_train), 1)
        tqdm.write(
            f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
            f"Epoch {epoch + 1:4d} | "
            f"Loss {loss_total/n:.5f} "
            f"Focal {ce_total/n:.5f} "
            f"VQ {vq_loss_total/n:.5f} "
            f"Commit {commit_total/n:.5f} "
            f"Entropy {entropy_total/n:.5f} "
            f"Consist {consist_total/n:.5f} | "
            f"LR {scheduler.get_last_lr()[0]:.6f} | "
            f"Subsets {subset_stats}"
        )

        # ---- Checkpoint + validation ----
        if (epoch + 1) % args.checkpoint == 0:
            ckpt_path = f"result/joint/joint-{epoch + 1}.pth"
            torch.save({
                "epoch": epoch + 1,
                "enc1": enc1.state_dict(),
                "enc2": enc2.state_dict(),
                "enc3": enc3.state_dict(),
                "mg_state": model_mg.state_dict(),
                "optim_state": optimizer.state_dict(),
            }, ckpt_path)
            tqdm.write(f" 检查点已保存: {ckpt_path}")

            val_loss = validate(
                enc1, enc2, enc3, model_mg, dataloader_val, tile_dict, epoch + 1
            )
            tqdm.write(
                f"[Validate] Epoch {epoch + 1:4d} | Val Loss {val_loss:.5f}"
            )
            # Switch back to training mode after validation.
            for enc in encoders:
                enc.train()
            model_mg.train()

    print("训练结束。")
    # Final snapshot: model weights only, no optimizer state.
    torch.save({
        "epoch": start_epoch + args.epochs,
        "enc1": enc1.state_dict(),
        "enc2": enc2.state_dict(),
        "enc3": enc3.state_dict(),
        "mg_state": model_mg.state_dict(),
    }, "result/joint/joint_final.pth")


def _set_requires_grad(modules, flag):
    """Set ``requires_grad`` to *flag* on every parameter of *modules*."""
    for module in modules:
        for p in module.parameters():
            p.requires_grad_(flag)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
if __name__ == "__main__":
    # Limit PyTorch's CPU intra-op thread pool to 4 threads, then start
    # the joint training run.
    torch.set_num_threads(4)
    train()
|
||||
Loading…
Reference in New Issue
Block a user