mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-06-14 03:01:09 +08:00
Compare commits
2 Commits
dc3062bcee
...
bf3d24e680
| Author | SHA1 | Date | |
|---|---|---|---|
| bf3d24e680 | |||
| b52bfdb78f |
178
docs/entity-density-labels-design.md
Normal file
178
docs/entity-density-labels-design.md
Normal file
@ -0,0 +1,178 @@
|
||||
# 实体密度标签设计文档
|
||||
|
||||
## 背景与问题
|
||||
|
||||
当前三阶段级联生成(stage1 骨架、stage2 功能实体、stage3 资源)在结构可行性上基本稳定,但存在明显的分布偏移:
|
||||
|
||||
- 怪物数量偏多
|
||||
- 资源数量偏多
|
||||
- 门数量在部分样本上也偏高
|
||||
|
||||
已尝试在采样阶段通过“随机抛弃部分新揭开位并重新掩码”的方式抑制过密生成,但效果不稳定,核心原因是该策略属于推理期启发式约束,不能从训练目标层面改变模型对全局密度的先验。
|
||||
|
||||
因此需要引入显式条件:将每张地图中门、怪物、资源的密度离散为三档(低/中/高),并在训练和推理时作为条件输入,让模型学习“在指定密度档位下生成”。
|
||||
|
||||
## 目标
|
||||
|
||||
- 新增 3 个可控标签:`doorDensityLevel`、`monsterDensityLevel`、`resourceDensityLevel`,取值均为 `0 | 1 | 2`。
|
||||
- 标签计算与分档在 Python 端完成,保持与现有 `roomCountLevel`、`branchLevel` 一致的处理方式。
|
||||
- 标签注入模型后,支持在推理时显式控制三类实体密度。
|
||||
- 在不改动数据处理端(TypeScript)的前提下完成接入。
|
||||
|
||||
## 设计原则
|
||||
|
||||
- 统计口径稳定:密度分母采用固定地图面积(13x13),避免受随机掩码影响。
|
||||
- 分档可迁移:使用训练集等频分箱阈值;验证/推理复用同一阈值。
|
||||
- 最小侵入:优先扩展现有 Python 数据集与条件注入链路,不改变数据文件格式。
|
||||
- 可回溯:训练日志与可视化中输出目标密度档位与实际密度,便于诊断。
|
||||
|
||||
## 标签定义
|
||||
|
||||
### 1. 统计对象
|
||||
|
||||
基于原始地图 `item['map']`(未掩码、未降级)统计三类图块数量:
|
||||
|
||||
- `doorCount`: 图块 ID = 2
|
||||
- `resourceCount`: 图块 ID = 3
|
||||
- `monsterCount`: 图块 ID = 4
|
||||
|
||||
### 2. 密度定义
|
||||
|
||||
设地图面积为 `MAP_SIZE = 13 * 13 = 169`,则:
|
||||
|
||||
- `doorDensity = doorCount / 169`
|
||||
- `monsterDensity = monsterCount / 169`
|
||||
- `resourceDensity = resourceCount / 169`
|
||||
|
||||
### 3. 分档定义
|
||||
|
||||
采用等频分箱(三档)并与现有 `to_level` 规则一致:
|
||||
|
||||
- 训练集上收集某一密度指标的全量样本值,升序排序
|
||||
- 取 `n/3` 与 `2n/3` 位置作为阈值 `th1`、`th2`
|
||||
- 分档规则:
|
||||
- `< th1` -> `0`(Low)
|
||||
- `>= th1 且 < th2` -> `1`(Medium)
|
||||
- `>= th2` -> `2`(High)
|
||||
|
||||
阈值退化处理(与现有实现一致):
|
||||
|
||||
- 若 `th1 == th2`,将 `th2 = th1 + eps`
|
||||
- 对密度值建议 `eps = 1e-6`
|
||||
|
||||
## Python 端处理方案
|
||||
|
||||
### 1. 数据集初始化阶段
|
||||
|
||||
在 `GinkaSeperatedDataset.__init__` 中新增一次统计流程:
|
||||
|
||||
- 从 `self.data` 中提取每张图的 `doorDensity`、`monsterDensity`、`resourceDensity`
|
||||
- 分别计算三组阈值:
|
||||
- `self.door_density_th`
|
||||
- `self.monster_density_th`
|
||||
- `self.resource_density_th`
|
||||
- 回填每个样本:
|
||||
- `item['doorDensityLevel']`
|
||||
- `item['monsterDensityLevel']`
|
||||
- `item['resourceDensityLevel']`
|
||||
|
||||
### 2. 样本输出阶段
|
||||
|
||||
在 `__getitem__` 返回字典中新增条件向量(建议独立字段,避免影响旧逻辑):
|
||||
|
||||
- `density_inject = LongTensor([doorLevel, monsterLevel, resourceLevel])`
|
||||
|
||||
不建议直接复用旧 `struct_inject` 覆盖含义。推荐并行保留:
|
||||
|
||||
- `struct_inject`:结构语义(对称/房间/分支/外墙)
|
||||
- `density_inject`:实体密度语义(门/怪物/资源)
|
||||
|
||||
## 模型接入方案
|
||||
|
||||
### 1. 条件输入组织
|
||||
|
||||
密度条件与结构条件在语义上完全不同(结构描述地图拓扑形态,密度描述实体数量先验),不复用 `struct_inject` 的处理路径。
|
||||
|
||||
设计:在 MaskGIT 内新增一个独立的**密度 MLP**:
|
||||
|
||||
- 输入:3 个独立 embedding 表(每档取值 0/1/2)输出相加后的向量
|
||||
- `emb_door_density: Embedding(3, d_embed)`
|
||||
- `emb_monster_density: Embedding(3, d_embed)`
|
||||
- `emb_resource_density: Embedding(3, d_embed)`
|
||||
- 三个 embedding 相加后送入 2 层 MLP(`d_embed -> d_model -> d_model`,激活函数 GELU),输出一个 `d_model` 维向量
|
||||
- 该向量作为独立条件 token 拼接到主序列头部(与 struct token 并列,不替换)
|
||||
|
||||
结构条件(`struct_inject`)保留原有处理方式不变。
|
||||
|
||||
### 2. 训练与推理接口
|
||||
|
||||
- 训练前向:`mgX(inpX, z_q, struct_inject, density_inject)`
|
||||
- 推理采样:允许显式指定密度档位;未指定时可随机采样档位或使用数据先验分布采样
|
||||
|
||||
### 3. 条件 Dropout
|
||||
|
||||
对密度条件增加独立 dropout(例如 0.1):
|
||||
|
||||
- 训练时随机置空部分密度条件,降低过拟合风险
|
||||
- 推理时可在“无密度条件”与“强密度条件”两种模式间切换
|
||||
|
||||
## 训练与验证改造
|
||||
|
||||
### 1. 日志指标
|
||||
|
||||
在验证阶段新增统计输出:
|
||||
|
||||
- 按档位分组的密度 L1 误差:分别统计 door/monster/resource 三类实体在 Low/Medium/High 三档条件下,生成地图实际计数与档位中位期望值之间的 L1 距离(仅用于观察,不参与反向传播)
|
||||
|
||||
无需额外输出目标档位分布或实际密度均值,档位 L1 已足够直观反映控制效果。
|
||||
|
||||
### 2. 可视化对照
|
||||
|
||||
在每张验证生成图上直接标注所有条件标签,分两行显示:
|
||||
|
||||
- 第一行(结构标签):`sym=N room=L/M/H branch=L/M/H outer=0/1`
|
||||
- 第二行(密度标签):`d=L/M/H m=L/M/H r=L/M/H`
|
||||
|
||||
其中 `sym` 取 `cond_sym` 的原始整数值(0–7),`room`/`branch`/`d`/`m`/`r` 均以 `L`/`M`/`H` 表示三档。
|
||||
|
||||
标注位置:图像顶部左上角,两行叠加,与现有 `fix`/`free` 标注并列(可追加到同一 `annotate` 调用后)。
|
||||
|
||||
额外新增一类对照图:固定同一 `z` 和结构条件,仅扫遍密度档位(Low/Medium/High 三档),分别生成地图并排排列,用于直观验证"只改密度条件,生成实体数量随档位单调变化"。该对照图在每个 checkpoint 验证时生成一次,保存到 `result/seperated/eN/density_cmp.png`。
|
||||
|
||||
### 3. 验收标准
|
||||
|
||||
至少满足以下条件后再认为方案有效:
|
||||
|
||||
- 同一结构条件下,密度档位从 Low -> High 时,三类实体计数总体单调上升
|
||||
- 验证集上各档位的目标-实际密度 MAE 明显低于未加标签版本
|
||||
- 地图可玩性不退化(入口可达、关键路径连通性不显著恶化)
|
||||
|
||||
## 与现有流程的兼容性
|
||||
|
||||
- 数据源 JSON 无需新增字段。
|
||||
- 标签在 Python 读取后即时计算,不影响 `data/` 侧脚本。
|
||||
- 旧 checkpoint 不兼容新增输入维度,需要从旧权重迁移或重新训练。
|
||||
|
||||
## 实施步骤建议
|
||||
|
||||
1. 在数据集类中实现三类密度统计、分档和 `density_inject` 返回。
|
||||
2. 扩展 MaskGIT 条件嵌入与前向接口,打通三阶段训练调用。
|
||||
3. 更新训练/验证日志与可视化标注,增加按档位评估。
|
||||
4. 先做小规模过拟合与对照采样验证,再进入完整训练。
|
||||
|
||||
## 风险与应对
|
||||
|
||||
- 风险:档位边界样本噪声大,模型学习不稳定。
|
||||
- 应对:引入软标签邻域采样(可选)或在损失中增加密度一致性正则。
|
||||
|
||||
- 风险:实体密度受结构强约束,条件可控性受限。
|
||||
- 应对:在评估中按结构复杂度分组分析,必要时引入结构-密度联合条件建模。
|
||||
|
||||
- 风险:三阶段相互影响导致 stage2/stage3 条件冲突。
|
||||
- 应对:分别监控阶段内计数与最终合并计数,必要时增加阶段特异性权重。
|
||||
|
||||
## 后续可扩展方向
|
||||
|
||||
- 将三档扩展为五档,提升控制精度。
|
||||
- 在密度标签之外增加“功能实体聚集度/均匀度”标签。
|
||||
- 引入条件一致性判别器,进一步约束生成结果与目标档位一致。
|
||||
212
docs/film-adaln-cond-design.md
Normal file
212
docs/film-adaln-cond-design.md
Normal file
@ -0,0 +1,212 @@
|
||||
# 条件注入方式改进:从 Cross Attention 到 FiLM / AdaLN
|
||||
|
||||
## 问题背景
|
||||
|
||||
### 当前条件注入方式
|
||||
|
||||
`GinkaMaskGIT` 当前使用的条件注入策略如下:
|
||||
|
||||
1. VQ 码字 `z`(形状 `[B, L*3, d_z]`)通过 `z_proj` 投影到 `d_model` 维度
|
||||
2. 结构标签(`sym / room / branch / outer`)各自嵌入后拼接为 `[B, 4, d_model]`
|
||||
3. 密度标签(`door / monster / resource`)三个嵌入相加后经 MLP 得到 `[B, 1, d_model]`
|
||||
4. 上述三部分拼接为 `memory`(`[B, L*3+5, d_model]`),作为 cross-attention 的 key/value
|
||||
5. Transformer decoder 以 map token 作为 query,对 `memory` 做 cross-attention
|
||||
|
||||
### 问题分析
|
||||
|
||||
Cross-attention 的本质是**查询驱动**(query-driven)的检索机制:模型只在需要时才主动去 `memory` 中寻找相关信息,且注意力权重由 query(地图 token)与 key 的相似度决定。
|
||||
|
||||
这一机制对**空间局部条件**(如参考图像特征、空间先验)效果良好,但对**全局标量条件**(如"资源密度为 High")存在以下问题:
|
||||
|
||||
#### 1. 隐式性:条件无法强制生效
|
||||
|
||||
模型可以选择性地"忽视"某个 memory 条目。结构/密度条件只是 memory 序列中的几个 token,与 VQ 码字并列竞争注意力权重。当 VQ 码字已经携带了足够多的生成信息时,模型倾向于将注意力集中在 VQ 码字上,而对结构/密度 token 的注意力权重趋近于零。
|
||||
|
||||
实验现象印证了这一点:即使将密度标签设置为 High,模型生成的怪物/资源数量与 Low 时差异极小,说明密度条件被模型基本忽略。
|
||||
|
||||
#### 2. 语义不匹配:全局信号与局部查询不对齐
|
||||
|
||||
Cross-attention 的设计假设 key/value 携带**空间位置相关**的信息(例如编码器输出的特征图),query 在不同位置关注不同的 key。然而:
|
||||
|
||||
- 密度标签是一个全局标量(表示整张地图的资源密度档位),没有空间维度
|
||||
- 所有地图位置(169 个 token)的 query 若都要接收该全局信号,需要所有 query 一致地高度关注同一个 key,这与 cross-attention 的设计初衷相悖
|
||||
|
||||
#### 3. 与 VQ 码字竞争导致梯度稀释
|
||||
|
||||
结构/密度条件作为 memory token,与 VQ 码字通过同一个 softmax 竞争注意力。当 VQ 码字数量远多于条件 token(当前 L\*3=6 对 5),且 VQ 码字携带了更多"有用信息"时,梯度信号倾向于强化对 VQ 的关注,条件 token 的参数得不到有效更新。
|
||||
|
||||
#### 4. VQ 码字 z 本身也未被充分利用
|
||||
|
||||
即使将结构/密度从 cross-attention 中移出,VQ 码字 `z` 本身也存在相同的问题。训练前期观察到模型倾向于输出高度相似的地图(风格单一、多样性极低),这表明模型并未有效利用随机采样的 `z`。根本原因相同:cross-attention 是 query-driven 的,模型可以在不关注 `z` 的情况下仅靠地图 token 自注意力完成预测,`z` 的梯度信号因此极为稀弱。因此,`z` 同样需要改为全局 AdaLN 注入,而非仅依赖 cross-attention。
|
||||
|
||||
---
|
||||
|
||||
## 改进方案
|
||||
|
||||
### 核心思路
|
||||
|
||||
全局条件(结构标签、密度标签)应当作用于 **每一层的特征变换**,以加法偏移或缩放仿射的形式强制施加到所有 map token 上,使模型**无法绕过**该条件。这正是 FiLM 和 AdaLN 的设计目标。
|
||||
|
||||
### FiLM(Feature-wise Linear Modulation)
|
||||
|
||||
FiLM 对特征向量做逐元素仿射变换:
|
||||
|
||||
$$
|
||||
\text{FiLM}(x, c) = \gamma(c) \odot x + \beta(c)
|
||||
$$
|
||||
|
||||
其中 $\gamma(c)$ 和 $\beta(c)$ 是从条件 $c$ 预测出的缩放和偏移向量(维度均为 `d_model`),$\odot$ 为逐元素乘法。
|
||||
|
||||
FiLM 直接修改特征分布,条件信号强制影响所有 token 的表示,而不依赖 query 主动发起的检索。
|
||||
|
||||
### AdaLN(Adaptive Layer Normalization)
|
||||
|
||||
AdaLN 将 FiLM 与 LayerNorm 结合,用条件向量预测 LayerNorm 的缩放和偏移参数,替代原有的固定参数:
|
||||
|
||||
$$
|
||||
\text{AdaLN}(x, c) = \gamma(c) \odot \frac{x - \mu}{\sigma} + \beta(c)
|
||||
$$
|
||||
|
||||
与标准 LayerNorm 的区别仅在于 $\gamma$ 和 $\beta$ 不是可学习的静态参数,而是由条件 $c$ 动态生成。AdaLN 在 DiT(Diffusion Transformer)和 MaskGIT 的改进版本中已有广泛验证。
|
||||
|
||||
**选用 AdaLN** 作为主要方案,理由:
|
||||
|
||||
- 在 Transformer 架构中,LayerNorm 是特征归一化的核心节点,在此处注入条件效果最稳定
|
||||
- AdaLN 的参数量增加极少(仅新增 `2 * d_model` 的线性层输出)
|
||||
- 与 FiLM 效果等价,但更符合 Transformer 的设计惯例
|
||||
|
||||
---
|
||||
|
||||
## 架构设计
|
||||
|
||||
### 条件向量的构建
|
||||
|
||||
将结构标签、密度标签和 VQ 码字 `z` 全部融合为**单一全局条件向量** `c`(维度 `d_model`),通过 AdaLN 在每一层强制施加到所有 map token 上。
|
||||
|
||||
**结构标签**(4 个离散标量)各自独立嵌入后**拼接**,再经 Linear 投影:
|
||||
|
||||
```
|
||||
struct: [B, 4] → 各自 Embedding(d_cond) → cat → [B, 4*d_cond] → Linear → [B, d_model]
|
||||
```
|
||||
|
||||
**密度标签**(3 个离散标量)各自独立嵌入后**拼接**,再经 Linear 投影(不使用相加,避免各档位嵌入相互抵消):
|
||||
|
||||
```
|
||||
density: [B, 3] → 各自 Embedding(d_cond) → cat → [B, 3*d_cond] → Linear → [B, d_model]
|
||||
```
|
||||
|
||||
**VQ 码字 z**(序列)先做均值池化压缩为单个向量,再经 Linear 投影:
|
||||
|
||||
```
|
||||
z: [B, L*3, d_z] → mean(dim=1) → [B, d_z] → Linear → [B, d_model]
|
||||
```
|
||||
|
||||
三路向量相加得到最终条件向量:
|
||||
|
||||
```
|
||||
c = struct_vec + density_vec + z_vec # [B, d_model]
|
||||
```
|
||||
|
||||
> 说明:`z` 改为全局注入的动机在于,训练前期模型观察到输出地图高度相似、多样性极低,表明 cross-attention 方式下模型未能有效利用随机采样的 `z`。均值池化保留了 `z` 序列的整体语义,同时将其压缩为标量条件,适合 AdaLN 注入。
|
||||
|
||||
### 自定义 Transformer 层
|
||||
|
||||
由于 PyTorch 的 `nn.TransformerEncoderLayer` / `nn.TransformerDecoderLayer` 不支持外部注入 AdaLN 参数,需要自行实现:
|
||||
|
||||
#### AdaLN 模块
|
||||
|
||||
```python
|
||||
class AdaLN(nn.Module):
|
||||
# 自适应 LayerNorm:用条件向量 c 预测 LayerNorm 的 gamma 和 beta
|
||||
def __init__(self, d_model: int, d_cond: int):
|
||||
...
|
||||
self.norm = nn.LayerNorm(d_model, elementwise_affine=False)
|
||||
self.proj = nn.Linear(d_cond, d_model * 2) # 输出 [gamma, beta]
|
||||
|
||||
def forward(self, x, c):
|
||||
# x: [B, S, d_model]
|
||||
# c: [B, d_model] 全局条件向量
|
||||
gamma, beta = self.proj(c).chunk(2, dim=-1) # 各 [B, d_model]
|
||||
return (1 + gamma.unsqueeze(1)) * self.norm(x) + beta.unsqueeze(1)
|
||||
```
|
||||
|
||||
#### 自定义 Transformer 层
|
||||
|
||||
替换标准的 `TransformerEncoderLayer`,在每个 sub-layer 的 LayerNorm 处注入条件:
|
||||
|
||||
```python
|
||||
class CondTransformerLayer(nn.Module):
|
||||
# 带 AdaLN 条件注入的 Transformer Encoder 层
|
||||
# 结构:AdaLN-Self-Attn → AdaLN-FFN
|
||||
def __init__(self, d_model, nhead, dim_ff, d_cond):
|
||||
...
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
|
||||
self.adaln1 = AdaLN(d_model, d_cond) # 自注意力前的归一化
|
||||
self.adaln2 = AdaLN(d_model, d_cond) # FFN 前的归一化
|
||||
self.ffn = nn.Sequential(Linear, GELU, Linear)
|
||||
|
||||
def forward(self, x, c, key_padding_mask=None):
|
||||
# Pre-norm 结构
|
||||
residual = x
|
||||
x = self.adaln1(x, c)
|
||||
x, _ = self.self_attn(x, x, x)
|
||||
x = residual + x
|
||||
|
||||
residual = x
|
||||
x = self.adaln2(x, c)
|
||||
x = self.ffn(x)
|
||||
x = residual + x
|
||||
return x
|
||||
```
|
||||
|
||||
#### Cross-attention 层(移除)
|
||||
|
||||
`z` 已改为通过均值池化后加入全局条件向量 `c`,由 AdaLN 注入每一层,不再需要单独的 cross-attention 层。整个 Transformer 退化为纯 encoder(自注意力)结构,仅由 `CondTransformerLayer` 堆叠而成,无 decoder。
|
||||
|
||||
### 整体前向流程
|
||||
|
||||
```
|
||||
map → tile_embed + pos_embed → x [B, H*W, d_model]
|
||||
|
||||
struct: [B, 4] → 各自 Embed → cat → Linear → [B, d_model]
|
||||
density: [B, 3] → 各自 Embed → cat → Linear → [B, d_model]
|
||||
z: [B, L*3, d_z] → mean → Linear → [B, d_model]
|
||||
c = struct_vec + density_vec + z_vec # [B, d_model]
|
||||
|
||||
for each layer:
|
||||
x = CondTransformerLayer(x, c) # AdaLN 自注意力,纯 encoder 结构
|
||||
|
||||
logits = output_fc(x) [B, H*W, num_classes]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 参数量对比
|
||||
|
||||
以 `d_model=256, nhead=4, dim_ff=1024, num_layers=6` 为基准估算:
|
||||
|
||||
| 模块 | 当前方案 | 新方案(AdaLN) |
|
||||
| ------------------- | ------------------------ | -------------------------------------------- |
|
||||
| 条件嵌入层 | 小(各 Embedding + MLP) | 小(相似,略有增加) |
|
||||
| 每层 AdaLN 额外参数 | 0 | `2 * d_model * d_model = 131K` × 6 层 ≈ 786K |
|
||||
| cross-attention 层 | 6 层完整 decoder | 0(移除,z 改为 AdaLN 全局注入) |
|
||||
| 总参数量变化 | 基准 | +约 5~10%(可接受) |
|
||||
|
||||
---
|
||||
|
||||
## 实现文件规划
|
||||
|
||||
| 文件 | 改动内容 |
|
||||
| -------------------------- | --------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `ginka/maskGIT/maskGIT.py` | 重写 `Transformer` 为自定义纯 encoder 架构,新增 `AdaLN`、`CondTransformerLayer`;移除 `ZCrossAttentionLayer` |
|
||||
| `ginka/maskGIT/model.py` | 更新 `GinkaMaskGIT`:struct/density/z 三路融合为条件向量 `c`,密度标签改为拼接,z 改为均值池化后注入;移除旧 cross-attention 路径 |
|
||||
| `ginka/train_seperated.py` | 无需修改(接口不变,`forward` 签名保持) |
|
||||
|
||||
---
|
||||
|
||||
## 预期效果
|
||||
|
||||
- 密度标签、结构标签、VQ 码字 `z` 三路均通过 AdaLN 在每一层强制影响特征分布,模型无法绕过任何一路条件
|
||||
- 密度标签改为拼接(而非相加),避免不同档位嵌入线性叠加时相互抵消,使各密度维度保持独立的表示空间
|
||||
- `z` 通过均值池化压缩为全局向量后注入,保留 codebook 多样性的同时消除对 cross-attention 的依赖,预期解决训练前期输出地图高度相似的问题
|
||||
- 架构简化为纯 encoder,去掉 encoder-decoder 分离结构,降低实现复杂度和计算量
|
||||
@ -53,9 +53,33 @@ class GinkaSeperatedDataset(Dataset):
|
||||
item['roomCountLevel'] = self.to_level(item['roomCount'], self.room_th)
|
||||
item['branchLevel'] = self.to_level(item['highDegBranchCount'], self.branch_th)
|
||||
|
||||
# 实体密度等级:统计原始地图中门/怪物/资源的数量,等频三档
|
||||
eps = 1e-6
|
||||
door_counts = sorted(self.count_tile(item['map'], self.DOOR) for item in self.data)
|
||||
monster_counts = sorted(self.count_tile(item['map'], self.MONSTER) for item in self.data)
|
||||
resource_counts = sorted(self.count_tile(item['map'], self.RESOURCE) for item in self.data)
|
||||
th1_d, th2_d = door_counts[n // 3], door_counts[2 * n // 3]
|
||||
th1_m, th2_m = monster_counts[n // 3], monster_counts[2 * n // 3]
|
||||
th1_rc, th2_rc = resource_counts[n // 3], resource_counts[2 * n // 3]
|
||||
if th1_d == th2_d: th2_d = th1_d + eps
|
||||
if th1_m == th2_m: th2_m = th1_m + eps
|
||||
if th1_rc == th2_rc: th2_rc = th1_rc + eps
|
||||
self.door_density_th = (th1_d, th2_d)
|
||||
self.monster_density_th = (th1_m, th2_m)
|
||||
self.resource_density_th = (th1_rc, th2_rc)
|
||||
|
||||
for item in self.data:
|
||||
m = item['map']
|
||||
item['doorDensityLevel'] = self.to_level(self.count_tile(m, self.DOOR), self.door_density_th)
|
||||
item['monsterDensityLevel'] = self.to_level(self.count_tile(m, self.MONSTER), self.monster_density_th)
|
||||
item['resourceDensityLevel'] = self.to_level(self.count_tile(m, self.RESOURCE), self.resource_density_th)
|
||||
|
||||
def to_level(self, v, th):
|
||||
return 0 if v < th[0] else (1 if v < th[1] else 2)
|
||||
|
||||
def count_tile(self, map_data: list, tile_id: int) -> int:
|
||||
return sum(cell == tile_id for row in map_data for cell in row)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
@ -174,6 +198,12 @@ class GinkaSeperatedDataset(Dataset):
|
||||
cond_outer = item['outerWall']
|
||||
struct_inject = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer])
|
||||
|
||||
density_inject = torch.LongTensor([
|
||||
item['doorDensityLevel'],
|
||||
item['monsterDensityLevel'],
|
||||
item['resourceDensityLevel']
|
||||
])
|
||||
|
||||
return {
|
||||
"input_stage1": torch.LongTensor(out[0]),
|
||||
"target_stage1": torch.LongTensor(out[1]),
|
||||
@ -184,5 +214,6 @@ class GinkaSeperatedDataset(Dataset):
|
||||
"input_stage3": torch.LongTensor(out[6]),
|
||||
"target_stage3": torch.LongTensor(out[7]),
|
||||
"encoder_stage3": torch.LongTensor(out[8]),
|
||||
"struct_inject": struct_inject
|
||||
"struct_inject": struct_inject,
|
||||
"density_inject": density_inject
|
||||
}
|
||||
|
||||
@ -1,29 +1,57 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self, d_model=256, dim_ff=512, nhead=8, num_layers=4,
|
||||
):
|
||||
class AdaLN(nn.Module):
|
||||
# 自适应 LayerNorm:条件向量 c 动态预测 LayerNorm 的 gamma 和 beta
|
||||
def __init__(self, d_model: int, d_cond: int):
|
||||
super().__init__()
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True, activation='gelu'),
|
||||
num_layers=num_layers
|
||||
self.norm = nn.LayerNorm(d_model, elementwise_affine=False)
|
||||
self.proj = nn.Linear(d_cond, d_model * 2)
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
# x: [B, S, d_model] c: [B, d_cond]
|
||||
gamma, beta = self.proj(c).chunk(2, dim=-1) # 各 [B, d_model]
|
||||
return (1 + gamma.unsqueeze(1)) * self.norm(x) + beta.unsqueeze(1)
|
||||
|
||||
class CondTransformerLayer(nn.Module):
|
||||
# 带 AdaLN 条件注入的 Transformer Encoder 层
|
||||
# 结构:AdaLN → Self-Attn → 残差;AdaLN → FFN → 残差(Pre-norm)
|
||||
def __init__(self, d_model: int, nhead: int, dim_ff: int, d_cond: int):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
|
||||
self.adaln1 = AdaLN(d_model, d_cond)
|
||||
self.adaln2 = AdaLN(d_model, d_cond)
|
||||
self.ffn = nn.Sequential(
|
||||
nn.Linear(d_model, dim_ff),
|
||||
nn.GELU(),
|
||||
nn.Linear(dim_ff, d_model)
|
||||
)
|
||||
|
||||
self.decoder = nn.TransformerDecoder(
|
||||
nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff, batch_first=True, activation='gelu'),
|
||||
num_layers=num_layers
|
||||
)
|
||||
|
||||
def forward(self, x, memory=None):
|
||||
# x: [B, S, d_model] 地图 token 序列
|
||||
# memory: [B, L, d_model] 可选的 z 投影,用于 cross-attention
|
||||
# 若 memory 为 None,则退化为原始自编解码行为(向后兼容)
|
||||
enc_out = self.encoder(x)
|
||||
if memory is not None:
|
||||
# encoder 输出作为 query,z 作为 key/value
|
||||
out = self.decoder(enc_out, memory)
|
||||
else:
|
||||
out = self.decoder(x, enc_out)
|
||||
return out
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
# x: [B, S, d_model] c: [B, d_cond]
|
||||
residual = x
|
||||
normed = self.adaln1(x, c)
|
||||
x, _ = self.self_attn(normed, normed, normed)
|
||||
x = residual + x
|
||||
residual = x
|
||||
x = self.ffn(self.adaln2(x, c))
|
||||
x = residual + x
|
||||
return x
|
||||
|
||||
class Transformer(nn.Module):
|
||||
# 纯 encoder Transformer,每层使用 AdaLN 注入全局条件向量 c
|
||||
def __init__(
|
||||
self, d_model: int = 256, dim_ff: int = 512,
|
||||
nhead: int = 8, num_layers: int = 4
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([
|
||||
CondTransformerLayer(d_model=d_model, nhead=nhead, dim_ff=dim_ff, d_cond=d_model)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
# x: [B, S, d_model] c: [B, d_model] 全局条件向量
|
||||
for layer in self.layers:
|
||||
x = layer(x, c)
|
||||
return x
|
||||
|
||||
@ -10,10 +10,16 @@ ROOM_VOCAB = 3 # roomCountLevel 0-2
|
||||
BRANCH_VOCAB = 3 # branchLevel 0-2
|
||||
OUTER_VOCAB = 2 # outerWall 0-1
|
||||
|
||||
# 密度标签词表大小(Low/Medium/High 三档)
|
||||
DOOR_DENSITY_VOCAB = 3
|
||||
MONSTER_DENSITY_VOCAB = 3
|
||||
RESOURCE_DENSITY_VOCAB = 3
|
||||
|
||||
class GinkaMaskGIT(nn.Module):
|
||||
def __init__(
|
||||
self, num_classes: int = 16, d_model: int = 192, dim_ff: int = 512,
|
||||
nhead: int = 8, num_layers: int = 4, map_h: int = 13, map_w: int = 13, d_z: int = 64
|
||||
nhead: int = 8, num_layers: int = 4, map_h: int = 13, map_w: int = 13,
|
||||
d_z: int = 64, z_seq_len: int = 6
|
||||
):
|
||||
super().__init__()
|
||||
self.map_h = map_h
|
||||
@ -24,33 +30,25 @@ class GinkaMaskGIT(nn.Module):
|
||||
self.row_embedding = nn.Parameter(torch.randn(1, map_h, d_model) * 0.02)
|
||||
self.col_embedding = nn.Parameter(torch.randn(1, map_w, d_model) * 0.02)
|
||||
|
||||
# z 投影:将 VQ 码字从 d_z 维映射到 d_model 维,供 cross-attention 使用
|
||||
self.z_proj = nn.Sequential(
|
||||
nn.Linear(d_z, d_model * 2),
|
||||
nn.LayerNorm(d_model * 2),
|
||||
nn.GELU(),
|
||||
|
||||
nn.Linear(d_model * 2, d_model),
|
||||
nn.LayerNorm(d_model)
|
||||
)
|
||||
|
||||
# 结构标签嵌入(编码到 d_z 维度)
|
||||
# 注意:结构标签与 VQ 码字语义不同,使用独立投影层避免混用
|
||||
# 结构标签嵌入:各自独立嵌入到 d_z 维度,作为独立 token
|
||||
self.sym_embed = nn.Embedding(SYM_VOCAB, d_z)
|
||||
self.room_embed = nn.Embedding(ROOM_VOCAB, d_z)
|
||||
self.branch_embed = nn.Embedding(BRANCH_VOCAB, d_z)
|
||||
self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)
|
||||
|
||||
self.struct_proj = nn.Sequential(
|
||||
nn.Linear(d_z, d_model * 2),
|
||||
nn.LayerNorm(d_model * 2),
|
||||
nn.GELU(),
|
||||
|
||||
nn.Linear(d_model * 2, d_model),
|
||||
nn.LayerNorm(d_model)
|
||||
)
|
||||
# 密度标签嵌入:各自独立嵌入到 d_z 维度,作为独立 token
|
||||
self.door_density_embed = nn.Embedding(DOOR_DENSITY_VOCAB, d_z)
|
||||
self.monster_density_embed = nn.Embedding(MONSTER_DENSITY_VOCAB, d_z)
|
||||
self.resource_density_embed = nn.Embedding(RESOURCE_DENSITY_VOCAB, d_z)
|
||||
|
||||
# Transformer:encoder 做 map token 自注意力,decoder 做与 z 的 cross-attention
|
||||
# z 投影:逐 token 线性变换,保持序列结构
|
||||
self.z_proj = nn.Linear(d_z, d_z)
|
||||
|
||||
# 条件融合投影:将 (z_seq_len + 4 + 3) 个 d_z 维 token 拼接后降维到 d_model
|
||||
# 拼接顺序:z_seq_len 个 z token + 4 个结构 token + 3 个密度 token
|
||||
self.cond_proj = nn.Linear((z_seq_len + 7) * d_z, d_model)
|
||||
|
||||
# 纯 encoder Transformer,条件向量 c 通过 AdaLN 注入每一层
|
||||
self.transformer = Transformer(
|
||||
d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers
|
||||
)
|
||||
@ -61,29 +59,35 @@ class GinkaMaskGIT(nn.Module):
|
||||
self,
|
||||
map: torch.Tensor,
|
||||
z: torch.Tensor,
|
||||
struct: torch.Tensor
|
||||
struct: torch.Tensor,
|
||||
density: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
# map: [B, H * W]
|
||||
# z: [B, L * 3, d_z]
|
||||
# struch: [B, 4]
|
||||
# z: [B, z_seq_len, d_z]
|
||||
# struct: [B, 4]
|
||||
# density: [B, 3] — [door_level, monster_level, resource_level]
|
||||
|
||||
sym_idx = struct[:, 0]
|
||||
room_idx = struct[:, 1]
|
||||
branch_idx = struct[:, 2]
|
||||
outer_idx = struct[:, 3]
|
||||
# 结构标签:各自嵌入为独立 token,stack 成序列 [B, 4, d_z]
|
||||
e_struct = torch.stack([
|
||||
self.sym_embed(struct[:, 0]),
|
||||
self.room_embed(struct[:, 1]),
|
||||
self.branch_embed(struct[:, 2]),
|
||||
self.outer_embed(struct[:, 3])
|
||||
], dim=1)
|
||||
|
||||
# 嵌入结构标签到 d_z 维度,拼接到 z 序列末尾
|
||||
e_sym = self.sym_embed(sym_idx).unsqueeze(1) # [B, 1, d_z]
|
||||
e_room = self.room_embed(room_idx).unsqueeze(1) # [B, 1, d_z]
|
||||
e_branch = self.branch_embed(branch_idx).unsqueeze(1) # [B, 1, d_z]
|
||||
e_outer = self.outer_embed(outer_idx).unsqueeze(1) # [B, 1, d_z]
|
||||
# 密度标签:各自嵌入为独立 token,stack 成序列 [B, 3, d_z]
|
||||
e_density = torch.stack([
|
||||
self.door_density_embed(density[:, 0]),
|
||||
self.monster_density_embed(density[:, 1]),
|
||||
self.resource_density_embed(density[:, 2])
|
||||
], dim=1)
|
||||
|
||||
struct_seq = torch.cat([e_sym, e_room, e_branch, e_outer], dim=1) # [B, 4, d_z]
|
||||
# z:逐 token 投影,保留序列结构 [B, z_seq_len, d_z]
|
||||
z_proj = self.z_proj(z)
|
||||
|
||||
# VQ 码字与结构标签语义不同,使用各自独立的投影层后再拼接
|
||||
z_mem_vq = self.z_proj(z) # [B, L, d_model]
|
||||
z_mem_struct = self.struct_proj(struct_seq) # [B, 4, d_model]
|
||||
z_mem = torch.cat([z_mem_vq, z_mem_struct], dim=1) # [B, L * 3 + 4, d_model]
|
||||
# 拼接所有条件 token → [B, z_seq_len+7, d_z],展平后投影到 d_model
|
||||
cond_seq = torch.cat([z_proj, e_struct, e_density], dim=1)
|
||||
c = self.cond_proj(cond_seq.reshape(cond_seq.size(0), -1)) # [B, d_model]
|
||||
|
||||
# tile embedding + 位置编码
|
||||
row_idx = torch.arange(self.map_h, device=map.device).repeat_interleave(self.map_w)
|
||||
@ -91,8 +95,8 @@ class GinkaMaskGIT(nn.Module):
|
||||
pos = self.row_embedding[0, row_idx] + self.col_embedding[0, col_idx] # [H*W, d_model]
|
||||
x = self.tile_embedding(map) + pos # [B, H * W, d_model]
|
||||
|
||||
# Transformer:encoder 做 map 自注意力,decoder cross-attend z+struct
|
||||
x = self.transformer(x, memory=z_mem) # [B, H * W, d_model]
|
||||
# Transformer:纯 encoder,每层通过 AdaLN 接收全局条件向量 c
|
||||
x = self.transformer(x, c) # [B, H * W, d_model]
|
||||
|
||||
logits = self.output_fc(x) # [B, H * W, num_classes]
|
||||
return logits
|
||||
@ -101,29 +105,36 @@ if __name__ == "__main__":
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
map_input = torch.randint(0, 7, (4, 13 * 13)).to(device) # [4, 169]
|
||||
z_input = torch.randn(4, 2, 64).to(device) # [4, 2, 64]
|
||||
z_input = torch.randn(4, 6, 64).to(device) # [4, L*3, 64]
|
||||
struct_input = torch.tensor([
|
||||
[3, 1, 0, 1],
|
||||
[0, 2, 1, 0],
|
||||
[5, 1, 2, 1],
|
||||
[1, 0, 1, 0],
|
||||
], dtype=torch.long).to(device) # [4, 4]
|
||||
density_input = torch.tensor([
|
||||
[0, 1, 2],
|
||||
[2, 0, 1],
|
||||
[1, 2, 0],
|
||||
[0, 0, 1],
|
||||
], dtype=torch.long).to(device) # [4, 3]
|
||||
|
||||
model = GinkaMaskGIT(
|
||||
num_classes=7,
|
||||
d_model=192,
|
||||
d_model=256,
|
||||
d_z=64,
|
||||
dim_ff=2048,
|
||||
nhead=8,
|
||||
dim_ff=1024,
|
||||
nhead=4,
|
||||
num_layers=6,
|
||||
map_h=13,
|
||||
map_w=13
|
||||
map_w=13,
|
||||
z_seq_len=6
|
||||
).to(device)
|
||||
|
||||
print_memory(device, "初始化后")
|
||||
|
||||
start = time.perf_counter()
|
||||
logits = model(map_input, z_input, struct_input)
|
||||
logits = model(map_input, z_input, struct_input, density_input)
|
||||
end = time.perf_counter()
|
||||
|
||||
print_memory(device, "前向传播后")
|
||||
@ -131,8 +142,9 @@ if __name__ == "__main__":
|
||||
print(f"推理耗时: {end - start:.4f}s")
|
||||
print(f"输出形状: logits={logits.shape}")
|
||||
print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}")
|
||||
print(f"Z Projection parameters: {sum(p.numel() for p in model.z_proj.parameters())}")
|
||||
print(f"Struct Projection parameters: {sum(p.numel() for p in model.struct_proj.parameters())}")
|
||||
print(f"Density Projection parameters: {sum(p.numel() for p in model.density_proj.parameters())}")
|
||||
print(f"Z Projection parameters: {sum(p.numel() for p in model.z_proj.parameters())}")
|
||||
print(f"Transformer parameters: {sum(p.numel() for p in model.transformer.parameters())}")
|
||||
print(f"Output FC parameters: {sum(p.numel() for p in model.output_fc.parameters())}")
|
||||
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
|
||||
|
||||
@ -44,24 +44,24 @@ VQ_BETA = 0.5 # commit loss 权重(防止编码器输出漂离 codebook)
|
||||
VQ_GAMMA = 0.0 # entropy loss 权重(当前未启用)
|
||||
VQ_LAYERS = 3 # VQ-VAE Transformer 层数
|
||||
VQ_DIM_FF = 512 # VQ-VAE 前馈网络隐层维度
|
||||
VQ_D_MODEL = 64 # VQ-VAE Transformer 模型维度
|
||||
VQ_NHEAD = 8 # VQ-VAE 多头注意力头数
|
||||
VQ_D_MODEL = 128 # VQ-VAE Transformer 模型维度
|
||||
VQ_NHEAD = 4 # VQ-VAE 多头注意力头数
|
||||
|
||||
# 第一阶段 MaskGIT 超参
|
||||
STAGE1_MG_DMODEL = 192
|
||||
STAGE1_MG_NHEAD = 8
|
||||
STAGE1_MG_DMODEL = 256
|
||||
STAGE1_MG_NHEAD = 4
|
||||
STAGE1_MG_NUM_LAYERS = 6
|
||||
STAGE1_MG_DIM_FF = 1024
|
||||
|
||||
# 第二阶段 MaskGIT 超参
|
||||
STAGE2_MG_DMODEL = 192
|
||||
STAGE2_MG_NHEAD = 8
|
||||
STAGE2_MG_DMODEL = 256
|
||||
STAGE2_MG_NHEAD = 4
|
||||
STAGE2_MG_NUM_LAYERS = 6
|
||||
STAGE2_MG_DIM_FF = 1024
|
||||
|
||||
# 第三阶段 MaskGIT 超参
|
||||
STAGE3_MG_DMODEL = 192
|
||||
STAGE3_MG_NHEAD = 8
|
||||
STAGE3_MG_DMODEL = 256
|
||||
STAGE3_MG_NHEAD = 4
|
||||
STAGE3_MG_NUM_LAYERS = 6
|
||||
STAGE3_MG_DIM_FF = 1024
|
||||
|
||||
@ -136,15 +136,18 @@ def build_model(device: torch.device):
|
||||
# 三个独立 MaskGIT 解码器,均接收完整的三阶段 z_q 作为条件
|
||||
mg1 = GinkaMaskGIT(
|
||||
num_classes=NUM_CLASSES, d_model=STAGE1_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE1_MG_DIM_FF,
|
||||
nhead=STAGE1_MG_NHEAD, num_layers=STAGE1_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W
|
||||
nhead=STAGE1_MG_NHEAD, num_layers=STAGE1_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W,
|
||||
z_seq_len=VQ_L * 3
|
||||
).to(device)
|
||||
mg2 = GinkaMaskGIT(
|
||||
num_classes=NUM_CLASSES, d_model=STAGE2_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE2_MG_DIM_FF,
|
||||
nhead=STAGE2_MG_NHEAD, num_layers=STAGE2_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W
|
||||
nhead=STAGE2_MG_NHEAD, num_layers=STAGE2_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W,
|
||||
z_seq_len=VQ_L * 3
|
||||
).to(device)
|
||||
mg3 = GinkaMaskGIT(
|
||||
num_classes=NUM_CLASSES, d_model=STAGE3_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE3_MG_DIM_FF,
|
||||
nhead=STAGE3_MG_NHEAD, num_layers=STAGE3_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W
|
||||
nhead=STAGE3_MG_NHEAD, num_layers=STAGE3_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W,
|
||||
z_seq_len=VQ_L * 3
|
||||
).to(device)
|
||||
|
||||
# 六个模型参数合并到同一优化器,端到端联合训练
|
||||
@ -178,10 +181,18 @@ def random_struct(device: torch.device) -> torch.Tensor:
|
||||
cond_outer = random.randint(0, 1) # 是否有外围走廊
|
||||
return torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer]).unsqueeze(0).to(device)
|
||||
|
||||
def random_density(device: torch.device) -> torch.Tensor:
|
||||
# 随机采样一组密度参量,用于自由生成
|
||||
# density_inject 格式:[door_level(0-2), monster_level(0-2), resource_level(0-2)]
|
||||
door_lv = random.randint(0, 2)
|
||||
monster_lv = random.randint(0, 2)
|
||||
resource_lv = random.randint(0, 2)
|
||||
return torch.LongTensor([door_lv, monster_lv, resource_lv]).unsqueeze(0).to(device)
|
||||
|
||||
def maskgit_sample(
|
||||
model: torch.nn.Module, inp: torch.Tensor, z: torch.Tensor,
|
||||
struct: torch.Tensor, steps: int, target_tiles: list[int] | None = None,
|
||||
keep_fixed: bool = True
|
||||
struct: torch.Tensor, density: torch.Tensor, steps: int,
|
||||
target_tiles: list[int] | None = None, keep_fixed: bool = True
|
||||
) -> np.ndarray:
|
||||
# target_tiles: 本阶段负责生成的图块 ID 列表;None 表示接受所有类别(stage1)
|
||||
# keep_fixed=True:锁定输入中已有的非掩码/非空地位,使上一阶段结构保持不变
|
||||
@ -198,7 +209,7 @@ def maskgit_sample(
|
||||
|
||||
# 迭代去掩码:每步根据置信度分数重新决定掩码位置
|
||||
for step in range(steps):
|
||||
logits = model(current, z, struct)
|
||||
logits = model(current, z, struct, density)
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
|
||||
dist = torch.distributions.Categorical(probs)
|
||||
@ -264,7 +275,7 @@ def maskgit_sample(
|
||||
# 目标模式下,未被填充的位置视为空地(不属于本阶段负责的图块)
|
||||
current[0, still_masked] = 0
|
||||
else:
|
||||
logits = model(current, z, struct)
|
||||
logits = model(current, z, struct, density)
|
||||
current[0, still_masked] = torch.argmax(logits[0, still_masked], dim=-1)
|
||||
|
||||
return current[0].cpu().numpy().reshape(MAP_H, MAP_W)
|
||||
@ -272,6 +283,7 @@ def maskgit_sample(
|
||||
def full_generate_random_z(
|
||||
input: torch.Tensor,
|
||||
struct: torch.Tensor,
|
||||
density: torch.Tensor,
|
||||
models: list[torch.nn.Module],
|
||||
device: torch.device,
|
||||
keep_fixed: tuple[bool, bool, bool] = (True, True, True)
|
||||
@ -282,13 +294,13 @@ def full_generate_random_z(
|
||||
z = quantizer.sample(1, VQ_L, device)
|
||||
|
||||
# stage1:生成 floor/wall 骨架
|
||||
pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP, keep_fixed=keep_fixed[0])
|
||||
pred1_np = maskgit_sample(mg1, input.clone(), z, struct, density, GENERATE_STEP, keep_fixed=keep_fixed[0])
|
||||
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
||||
inp2[inp2 == 0] = MASK_TOKEN # 空地位交由 stage2 填充
|
||||
|
||||
# stage2:在骨架上生成 door(2)/monster(4)/entrance(5),非零结果覆盖合并
|
||||
pred2_np = maskgit_sample(
|
||||
mg2, inp2, z, struct, GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
|
||||
mg2, inp2, z, struct, density, GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
|
||||
)
|
||||
merged12 = pred1_np.copy()
|
||||
merged12[pred2_np != 0] = pred2_np[pred2_np != 0]
|
||||
@ -297,7 +309,7 @@ def full_generate_random_z(
|
||||
|
||||
# stage3:填充 resource(3)
|
||||
pred3_np = maskgit_sample(
|
||||
mg3, inp3, z, struct, GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
|
||||
mg3, inp3, z, struct, density, GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
|
||||
)
|
||||
merged123 = merged12.copy()
|
||||
merged123[pred3_np != 0] = pred3_np[pred3_np != 0]
|
||||
@ -308,6 +320,7 @@ def full_generate_specific_z(
|
||||
input: torch.Tensor,
|
||||
z: torch.Tensor,
|
||||
struct: torch.Tensor,
|
||||
density: torch.Tensor,
|
||||
models: list[torch.nn.Module],
|
||||
device: torch.device,
|
||||
keep_fixed: tuple[bool, bool, bool] = (True, True, True)
|
||||
@ -316,12 +329,12 @@ def full_generate_specific_z(
|
||||
|
||||
with torch.no_grad():
|
||||
# 与 full_generate_random_z 相同的三阶段级联,但使用给定的 z
|
||||
pred1_np = maskgit_sample(mg1, input.clone(), z, struct, GENERATE_STEP, keep_fixed=keep_fixed[0])
|
||||
pred1_np = maskgit_sample(mg1, input.clone(), z, struct, density, GENERATE_STEP, keep_fixed=keep_fixed[0])
|
||||
inp2 = torch.tensor(pred1_np.flatten(), dtype=torch.long, device=device).reshape(1, MAP_SIZE)
|
||||
inp2[inp2 == 0] = MASK_TOKEN
|
||||
|
||||
pred2_np = maskgit_sample(
|
||||
mg2, inp2, z, struct, GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
|
||||
mg2, inp2, z, struct, density, GENERATE_STEP, target_tiles=[2, 4, 5], keep_fixed=keep_fixed[1]
|
||||
)
|
||||
merged12 = pred1_np.copy()
|
||||
merged12[pred2_np != 0] = pred2_np[pred2_np != 0]
|
||||
@ -329,7 +342,7 @@ def full_generate_specific_z(
|
||||
inp3[inp3 == 0] = MASK_TOKEN
|
||||
|
||||
pred3_np = maskgit_sample(
|
||||
mg3, inp3, z, struct, GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
|
||||
mg3, inp3, z, struct, density, GENERATE_STEP, target_tiles=[3], keep_fixed=keep_fixed[2]
|
||||
)
|
||||
merged123 = merged12.copy()
|
||||
merged123[pred3_np != 0] = pred3_np[pred3_np != 0]
|
||||
@ -343,6 +356,23 @@ def annotate(img: np.ndarray, text: str) -> np.ndarray:
|
||||
cv2.putText(img, text, (2, 14), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
|
||||
return img
|
||||
|
||||
def annotate_labels(
|
||||
img: np.ndarray,
|
||||
struct: torch.Tensor,
|
||||
density: torch.Tensor
|
||||
) -> np.ndarray:
|
||||
# 两行标注:第一行结构标签,第二行密度标签
|
||||
lv = ['Low', 'Medium', 'High']
|
||||
s = struct.tolist()
|
||||
d = density.tolist()
|
||||
line1 = f"sym:{s[0]} room:{lv[s[1]]} branch:{lv[s[2]]} outer:{s[3]}"
|
||||
line2 = f"door:{lv[d[0]]} enemy:{lv[d[1]]} res:{lv[d[2]]}"
|
||||
img = img.copy()
|
||||
for text, y in [(line1, 12), (line2, 24)]:
|
||||
cv2.putText(img, text, (2, y), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (0, 0, 0), 2)
|
||||
cv2.putText(img, text, (2, y), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 255, 255), 1)
|
||||
return img
|
||||
|
||||
def rand_keep() -> tuple[bool, bool, bool]:
|
||||
b = random.choice([True, False])
|
||||
return (b, b, b)
|
||||
@ -404,23 +434,29 @@ def visualize_part2(batch, z_q, models, device, tile_dict):
|
||||
|
||||
inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
|
||||
struct_t = batch["struct_inject"][0:1].to(device)
|
||||
density_t = batch["density_inject"][0:1].to(device)
|
||||
kf = rand_keep()
|
||||
auto_pred1_np, auto_merged12, auto_merged123 = full_generate_specific_z(
|
||||
inp1_t, z_q[0:1], struct_t, models, device, keep_fixed=kf
|
||||
inp1_t, z_q[0:1], struct_t, density_t, models, device, keep_fixed=kf
|
||||
)
|
||||
kf_label = 'fix' if kf[0] else 'free'
|
||||
label1 = f"s1:{kf_label}"
|
||||
label2 = f"s2:{kf_label}"
|
||||
label3 = f"s3:{kf_label}"
|
||||
|
||||
enc1_np = batch["encoder_stage1"][0].numpy().reshape(MAP_H, MAP_W)
|
||||
enc2_np = batch["encoder_stage2"][0].numpy().reshape(MAP_H, MAP_W)
|
||||
enc3_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W)
|
||||
inp1_np = batch["input_stage1"][0].numpy().reshape(MAP_H, MAP_W)
|
||||
|
||||
struct_cpu = batch["struct_inject"][0]
|
||||
density_cpu = batch["density_inject"][0]
|
||||
|
||||
rows = [
|
||||
[to_img(enc1_np), to_img(enc2_np), to_img(enc3_np)],
|
||||
[to_img(inp1_np), annotate(to_img(auto_pred1_np), label1), annotate(to_img(auto_merged12), label2), annotate(to_img(auto_merged123), label3)],
|
||||
[
|
||||
annotate(to_img(inp1_np), kf_label),
|
||||
annotate_labels(to_img(auto_pred1_np), struct_cpu, density_cpu),
|
||||
annotate_labels(to_img(auto_merged12), struct_cpu, density_cpu),
|
||||
annotate_labels(to_img(auto_merged123), struct_cpu, density_cpu)
|
||||
],
|
||||
]
|
||||
grid = np.ones((2 * img_h + 3 * SEP, 4 * img_w + 5 * SEP, 3), dtype=np.uint8) * 255
|
||||
for r, row in enumerate(rows):
|
||||
@ -442,19 +478,24 @@ def visualize_part3(batch, models, device, tile_dict):
|
||||
|
||||
inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
|
||||
struct_ref = batch["struct_inject"][0:1].to(device)
|
||||
density_ref = batch["density_inject"][0:1].to(device)
|
||||
inp1_np = batch["input_stage1"][0].numpy().reshape(MAP_H, MAP_W)
|
||||
struct_cpu = batch["struct_inject"][0]
|
||||
density_cpu = batch["density_inject"][0]
|
||||
|
||||
row1 = [to_img(inp1_np)]
|
||||
for _ in range(2):
|
||||
kf = rand_keep()
|
||||
_, _, merged123 = full_generate_random_z(inp1_t, struct_ref, models, device, keep_fixed=kf)
|
||||
row1.append(annotate(to_img(merged123), keep_label(kf)))
|
||||
_, _, merged123 = full_generate_random_z(inp1_t, struct_ref, density_ref, models, device, keep_fixed=kf)
|
||||
row1.append(annotate_labels(to_img(merged123), struct_cpu, density_cpu))
|
||||
|
||||
row2 = []
|
||||
for _ in range(3):
|
||||
kf = rand_keep()
|
||||
_, _, merged123 = full_generate_random_z(inp1_t, random_struct(device), models, device, keep_fixed=kf)
|
||||
row2.append(annotate(to_img(merged123), keep_label(kf)))
|
||||
rnd_struct = random_struct(device)
|
||||
rnd_density = random_density(device)
|
||||
_, _, merged123 = full_generate_random_z(inp1_t, rnd_struct, rnd_density, models, device, keep_fixed=kf)
|
||||
row2.append(annotate_labels(to_img(merged123), rnd_struct[0].cpu(), rnd_density[0].cpu()))
|
||||
|
||||
rows = [row1, row2]
|
||||
grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
|
||||
@ -484,8 +525,10 @@ def visualize_part4(models, device, tile_dict):
|
||||
results = []
|
||||
for _ in range(5):
|
||||
kf = rand_keep()
|
||||
_, _, merged123 = full_generate_random_z(seed, random_struct(device), models, device, keep_fixed=kf)
|
||||
results.append(annotate(to_img(merged123), keep_label(kf)))
|
||||
rnd_struct = random_struct(device)
|
||||
rnd_density = random_density(device)
|
||||
_, _, merged123 = full_generate_random_z(seed, rnd_struct, rnd_density, models, device, keep_fixed=kf)
|
||||
results.append(annotate_labels(to_img(merged123), rnd_struct[0].cpu(), rnd_density[0].cpu()))
|
||||
|
||||
row1 = [to_img(seed_np)] + results[:2]
|
||||
row2 = results[2:]
|
||||
@ -507,6 +550,74 @@ def visualize_validate(
|
||||
cv2.imwrite(f"{save_dir}/val{batch_idx}.png", visualize_part1(batch, logits1, logits2, logits3, tile_dict))
|
||||
cv2.imwrite(f"{save_dir}/full{batch_idx}.png", visualize_part2(batch, z_q, models, device, tile_dict))
|
||||
cv2.imwrite(f"{save_dir}/rand{batch_idx}.png", visualize_part3(batch, models, device, tile_dict))
|
||||
cv2.imwrite(f"{save_dir}/dvar{batch_idx}.png", visualize_density_var(batch, z_q, models, device, tile_dict))
|
||||
|
||||
# 密度对照图:随机种子+随机结构,5 张随机密度生成,2×3 网格(左上角为种子图)
|
||||
def visualize_density_cmp(models, device, tile_dict):
|
||||
SEP = 3
|
||||
TILE_SIZE = 32
|
||||
img_h = MAP_H * TILE_SIZE
|
||||
img_w = MAP_W * TILE_SIZE
|
||||
|
||||
def to_img(mat):
|
||||
return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)
|
||||
|
||||
n_walls = random.randint(math.floor(MAP_SIZE * 0.02), math.floor(MAP_SIZE * 0.06))
|
||||
seed = torch.full((1, MAP_SIZE), MASK_TOKEN, dtype=torch.long, device=device)
|
||||
wall_pos = torch.randperm(MAP_SIZE, device=device)[:n_walls]
|
||||
seed[0, wall_pos] = 1
|
||||
seed_np = seed[0].cpu().numpy().reshape(MAP_H, MAP_W)
|
||||
rnd_struct = random_struct(device)
|
||||
struct_cpu = rnd_struct[0].cpu()
|
||||
gen_imgs = []
|
||||
for _ in range(5):
|
||||
rnd_density = random_density(device)
|
||||
density_cpu = rnd_density[0].cpu()
|
||||
_, _, merged123 = full_generate_random_z(seed, rnd_struct, rnd_density, models, device)
|
||||
gen_imgs.append(annotate_labels(to_img(merged123), struct_cpu, density_cpu))
|
||||
row1 = [to_img(seed_np)] + gen_imgs[:2]
|
||||
row2 = gen_imgs[2:]
|
||||
rows = [row1, row2]
|
||||
grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
|
||||
for r, row in enumerate(rows):
|
||||
for c, img in enumerate(row):
|
||||
y = SEP + r * (img_h + SEP)
|
||||
x = SEP + c * (img_w + SEP)
|
||||
grid[y:y + img_h, x:x + img_w] = img
|
||||
return grid
|
||||
|
||||
# 固定 z 和结构条件,使用 5 个随机密度各生成一次,2×3 网格(左上角为参考地图)
|
||||
def visualize_density_var(batch, z_q, models, device, tile_dict):
|
||||
SEP = 3
|
||||
TILE_SIZE = 32
|
||||
img_h = MAP_H * TILE_SIZE
|
||||
img_w = MAP_W * TILE_SIZE
|
||||
|
||||
def to_img(mat):
|
||||
return matrix_to_image_cv(mat, tile_dict, TILE_SIZE)
|
||||
|
||||
inp1_t = batch["input_stage1"][0:1].to(device).reshape(1, MAP_SIZE)
|
||||
struct_t = batch["struct_inject"][0:1].to(device)
|
||||
struct_cpu = batch["struct_inject"][0]
|
||||
ref_np = batch["encoder_stage3"][0].numpy().reshape(MAP_H, MAP_W)
|
||||
gen_imgs = []
|
||||
for _ in range(5):
|
||||
rnd_density = random_density(device)
|
||||
density_cpu = rnd_density[0].cpu()
|
||||
_, _, merged123 = full_generate_specific_z(
|
||||
inp1_t, z_q[0:1], struct_t, rnd_density, models, device
|
||||
)
|
||||
gen_imgs.append(annotate_labels(to_img(merged123), struct_cpu, density_cpu))
|
||||
row1 = [to_img(ref_np)] + gen_imgs[:2]
|
||||
row2 = gen_imgs[2:]
|
||||
rows = [row1, row2]
|
||||
grid = np.ones((2 * img_h + 3 * SEP, 3 * img_w + 4 * SEP, 3), dtype=np.uint8) * 255
|
||||
for r, row in enumerate(rows):
|
||||
for c, img in enumerate(row):
|
||||
y = SEP + r * (img_h + SEP)
|
||||
x = SEP + c * (img_w + SEP)
|
||||
grid[y:y + img_h, x:x + img_w] = img
|
||||
return grid
|
||||
|
||||
def validate(dataloader: DataLoader, models: list[torch.nn.Module], device: torch.device, tile_dict, epoch: int):
|
||||
vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models
|
||||
@ -521,10 +632,21 @@ def validate(dataloader: DataLoader, models: list[torch.nn.Module], device: torc
|
||||
loss3_total = torch.Tensor([0]).to(device)
|
||||
commit_total = torch.Tensor([0]).to(device)
|
||||
|
||||
# 按档位(0/1/2)累计实体计数差(L1),用于诊断密度条件可控性
|
||||
# 结构:{tile_id: {level: [累计误差, 样本数]}}
|
||||
density_l1 = {
|
||||
2: {0: [0.0, 0], 1: [0.0, 0], 2: [0.0, 0]}, # door
|
||||
4: {0: [0.0, 0], 1: [0.0, 0], 2: [0.0, 0]}, # monster
|
||||
3: {0: [0.0, 0], 1: [0.0, 0], 2: [0.0, 0]}, # resource
|
||||
}
|
||||
# 三类实体对应的 density_inject 索引
|
||||
tile_density_idx = {2: 0, 4: 1, 3: 2}
|
||||
|
||||
idx = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataloader, leave=False, desc="Validate Progress", disable=disable_tqdm):
|
||||
|
||||
# 三阶段各自的掩码输入、预测目标和 VQ 编码器输入
|
||||
inp1 = batch["input_stage1"].to(device).reshape(-1, MAP_SIZE)
|
||||
target1 = batch["target_stage1"].to(device).reshape(-1, MAP_SIZE)
|
||||
@ -539,6 +661,7 @@ def validate(dataloader: DataLoader, models: list[torch.nn.Module], device: torc
|
||||
enc3 = batch["encoder_stage3"].to(device).reshape(-1, MAP_SIZE)
|
||||
|
||||
struct = batch["struct_inject"].to(device)
|
||||
density = batch["density_inject"].to(device)
|
||||
|
||||
# VQ 编码:各阶段独立编码后拼接、量化
|
||||
z_e1 = vq1(enc1) # [B, L, d_z]
|
||||
@ -548,24 +671,56 @@ def validate(dataloader: DataLoader, models: list[torch.nn.Module], device: torc
|
||||
z_e_all = torch.cat([z_e1, z_e2, z_e3], dim=1) # [B, L*3, d_z]
|
||||
z_q, _, commit_loss = quantizer(z_e_all) # [B, L*3, d_z]
|
||||
|
||||
# 三阶段 MaskGIT 推理(均以完整 z_q 和 struct 为条件)
|
||||
logits1 = mg1(inp1, z_q, struct)
|
||||
logits2 = mg2(inp2, z_q, struct)
|
||||
logits3 = mg3(inp3, z_q, struct)
|
||||
# 三阶段 MaskGIT 推理(均以完整 z_q、struct 和 density 为条件)
|
||||
logits1 = mg1(inp1, z_q, struct, density)
|
||||
logits2 = mg2(inp2, z_q, struct, density)
|
||||
logits3 = mg3(inp3, z_q, struct, density)
|
||||
|
||||
loss1_total += focal_loss(logits1, target1)
|
||||
loss2_total += focal_loss(logits2, target2)
|
||||
loss3_total += focal_loss(logits3, target3)
|
||||
commit_total += commit_loss
|
||||
|
||||
# 计算 argmax 预测并统计各档位密度 L1(预测计数与真实计数之差的绝对值)
|
||||
pred2_map = torch.argmax(logits2, dim=-1).cpu() # [B, MAP_SIZE]
|
||||
pred3_map = torch.argmax(logits3, dim=-1).cpu()
|
||||
true2_map = target2.cpu() # [B, MAP_SIZE]
|
||||
true3_map = target3.cpu()
|
||||
density_cpu = batch["density_inject"] # [B, 3]
|
||||
for b in range(pred2_map.size(0)):
|
||||
for tile_id, d_idx in tile_density_idx.items():
|
||||
if tile_id == 3:
|
||||
pred_map = pred3_map[b]
|
||||
true_map = true3_map[b]
|
||||
else:
|
||||
pred_map = pred2_map[b]
|
||||
true_map = true2_map[b]
|
||||
pred_count = float((pred_map == tile_id).sum().item())
|
||||
true_count = float((true_map == tile_id).sum().item())
|
||||
lv = int(density_cpu[b, d_idx].item())
|
||||
density_l1[tile_id][lv][0] += abs(pred_count - true_count)
|
||||
density_l1[tile_id][lv][1] += 1
|
||||
|
||||
# 每个 batch 生成三种可视化图(val/full/rand)
|
||||
visualize_validate(batch, logits1, logits2, logits3, z_q, models, device, tile_dict, epoch, idx)
|
||||
idx += 1
|
||||
|
||||
# 每个 epoch 额外生成一张无条件自由生成图(不依赖任何 batch 样本)
|
||||
# 输出密度 L1 统计(各档位的平均实体计数,供诊断密度条件效果)
|
||||
lv_names = ['Low', 'Medium', 'High']
|
||||
tile_names = {2: 'door', 4: 'enemy', 3: 'resource'}
|
||||
for tile_id in [2, 4, 3]:
|
||||
parts = []
|
||||
for lv in range(3):
|
||||
acc, cnt = density_l1[tile_id][lv]
|
||||
avg = acc / cnt if cnt > 0 else 0.0
|
||||
parts.append(f"{lv_names[lv]}={avg:.2f}")
|
||||
tqdm.write(f" density {tile_names[tile_id]}: {' '.join(parts)}")
|
||||
|
||||
save_dir = f"result/seperated/e{epoch}"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
# 每个 epoch 额外生成:无条件自由生成图 + 全局密度对照图
|
||||
cv2.imwrite(f"{save_dir}/free.png", visualize_part4(models, device, tile_dict))
|
||||
cv2.imwrite(f"{save_dir}/density_cmp.png", visualize_density_cmp(models, device, tile_dict))
|
||||
|
||||
# 恢复训练模式
|
||||
for m in [vq1, vq2, vq3, mg1, mg2, mg3]:
|
||||
@ -659,6 +814,7 @@ def train(device: torch.device):
|
||||
|
||||
# 结构条件向量:[cond_sym, cond_room, cond_branch, cond_outer]
|
||||
struct = batch["struct_inject"].to(device)
|
||||
density = batch["density_inject"].to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
@ -671,10 +827,10 @@ def train(device: torch.device):
|
||||
z_e_all = torch.cat([z_e1, z_e2, z_e3], dim=1) # [B, L*3, d_z]
|
||||
z_q, _, commit_loss = quantizer(z_e_all) # [B, L*3, d_z]
|
||||
|
||||
# 三阶段 MaskGIT 前向(均接收完整三阶段 z_q)
|
||||
logits1 = mg1(inp1, z_q, struct)
|
||||
logits2 = mg2(inp2, z_q, struct)
|
||||
logits3 = mg3(inp3, z_q, struct)
|
||||
# 三阶段 MaskGIT 前向(均接收完整三阶段 z_q、struct 和 density 条件)
|
||||
logits1 = mg1(inp1, z_q, struct, density)
|
||||
logits2 = mg2(inp2, z_q, struct, density)
|
||||
logits3 = mg3(inp3, z_q, struct, density)
|
||||
|
||||
# 三阶段 Focal Loss + VQ commit loss 加权求和
|
||||
loss1 = focal_loss(logits1, target1)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user