refactor: 仅保留 SOTA

This commit is contained in:
unanmed 2026-03-12 20:36:49 +08:00
parent 4608a601be
commit 513f27c7ac
96 changed files with 7 additions and 5921 deletions

View File

@ -1,73 +1,3 @@
# GINKA 地图生成器 # GINKA 地图生成器
GINKA Model 是一个用于生成网格状魔塔地图的模型,采用 UNet 网络。 Ginka 地图生成器是专门训练用来生成魔塔地图的工具,采用 `MaskGIT` 模型及生成方法。
GINKA Model 内部集成了 Minamo Model 用做判别器,与 Ginka Model 对抗训练,训练使用 Wasserstein GAN 训练方式。
## 贡献 GINKA Model 数据集
对于 HTML5 魔塔,如果你想要贡献数据集,需要对你的魔塔进行手动数据处理,流程如下:
1. 在 `project` 文件夹下创建 `ginka-config.json` 文件,双击进入编辑,粘贴如下模板:
```json
{
"clip": {
"defaults": [0, 0, 13, 13],
"special": {}
},
"mapping": {
"redGem": {
"27": 1
},
"blueGem": {
"28": 1
},
"greenGem": {
"29": 1
},
"yellowGem": {
"30": 1
},
"item": {
"47": 1,
"49": 1,
"50": 0,
"51": 1,
"52": 1,
"53": 2
},
"potion": {
"31": 100,
"32": 200,
"33": 400,
"34": 800
},
"key": {
"21": 0,
"22": 1,
"23": 2,
"24": 2,
"25": 2
},
"door": {
"81": 0,
"82": 1,
"83": 2,
"84": 2,
"85": 3,
"86": 2
},
"wall": [1, 17],
"decoration": [],
"floor": [87, 88],
"arrow": [91, 92, 93, 94]
},
"data": {}
}
```
其中,`clip` 属性表示你的每张地图的哪一部分会被当成数据集,例如填写 `[0, 0, 13, 13]` 就会让左上角坐标为 `(0, 0)`、宽高为 `(13, 13)` 的矩形内容作为数据集。`special` 不用管。注意装饰所使用的贴图是白墙,如果白墙是墙壁的话,需要将白墙设置为墙壁。注意不要忘记保存。
2. 使用 [在线工具](https://unanmed.github.io/ginka-process) 处理数据,需要给每个地图添加标签,为每个图块分配种类,有一些图块包含多种等级,需要填写正确。
3. 将 `project` 文件夹打包发给我

View File

@ -1,295 +0,0 @@
import { GinkaConfig } from './types';
/**
 * Built-in tile-id → class-id table applied before the per-tower
 * configuration. Entries from `config.mapping` override these.
 */
const numMap: Record<number, number> = {
    0: 0,
    1: 1,
    2: 2,
    91: 30,
    92: 30,
    93: 30,
    94: 30,
    87: 29,
    88: 29
};

/** Stats of one enemy; `(atk + def) * hp` is used as its strength score. */
export interface Enemy {
    num: number;
    hp: number;
    atk: number;
    def: number;
}

/** Smallest and largest value observed for one family of tiles. */
interface ValueRange {
    min: number;
    max: number;
}

/** Computes the min/max of `values`; an empty input yields `Infinity`/`-Infinity`. */
function rangeOf(values: Iterable<number>): ValueRange {
    let min = Infinity;
    let max = -Infinity;
    for (const value of values) {
        if (value < min) min = value;
        if (value > max) max = value;
    }
    return { min, max };
}

/**
 * Buckets `value` into one of `levels` equally sized levels across `range`.
 * Returns 0 when the range is (almost) empty, i.e. every value is the same.
 */
function scaleLevel(value: number, range: ValueRange, levels: number): number {
    const span = range.max - range.min;
    if (span < 1e-8) return 0;
    return Math.min(
        Math.floor(((value - range.min) / span) * levels),
        levels - 1
    );
}

/**
 * Clips a raw tower floor and converts every tile id into a class id.
 *
 * Class ids written by this function:
 * - `numMap` / `config.mapping` lists: 1 (wall), 2 (decoration), 29 (floor), 30 (arrow)
 * - 3 + level: doors, 7 + level: keys, 22 + level: items
 * - 10..12 red gems, 13..15 blue gems, 16..18 green gems (three levels each)
 * - 19..22 potions (four levels)
 * - 26 / 27 / 28: weak / medium / strong enemies
 * Tiles that match nothing stay 0. Later rules override earlier ones in the
 * order: `numMap`, plain mapping, levelled tiles, enemies.
 *
 * Gem, potion and enemy levels are relative to the values present in the
 * clipped area, not absolute. A yellow gem is randomly converted to a red
 * (2/5), blue (2/5) or green (1/5) gem.
 *
 * @param map Raw floor, indexed as `map[y][x]`.
 * @param clip `[x, y, w, h]` rectangle to keep: top-left corner, width, height.
 * @param config Tower configuration providing `mapping`.
 * @param enemyMap Tile id → enemy stats for enemy tiles.
 * @returns A new `h`-row by `w`-column grid of class ids.
 */
function convert(
    map: number[][],
    [x, y, w, h]: [number, number, number, number],
    config: GinkaConfig,
    enemyMap: Record<number, Enemy>
) {
    // 1. Clip: rows span the height `h`, columns start at `x` and span the width `w`.
    const clipped: number[][] = [];
    for (let ny = y; ny < y + h; ny++) {
        const row: number[] = [];
        for (let nx = x; nx < x + w; nx++) {
            row.push(map[ny][nx]);
        }
        clipped.push(row);
    }

    // 2. Plain tiles configured by the tower.
    const dict = config.mapping;
    const mapping: Record<number, number> = {};
    dict.wall.forEach((v: number) => (mapping[v] = 1));
    dict.decoration.forEach((v: number) => (mapping[v] = 2));
    dict.floor.forEach((v: number) => (mapping[v] = 29));
    dict.arrow.forEach((v: number) => (mapping[v] = 30));

    // 3. Gather the value ranges that the relative levels are computed from.
    const strength = (enemy: Enemy) => (enemy.atk + enemy.def) * enemy.hp;
    const redValues: number[] = [];
    const blueValues: number[] = [];
    const greenValues: number[] = [];
    const potionValues: number[] = [];
    const enemyStrengths: number[] = [];
    for (const row of clipped) {
        for (const tile of row) {
            if (dict.redGem[tile] !== undefined) {
                redValues.push(dict.redGem[tile]);
            } else if (dict.blueGem[tile] !== undefined) {
                blueValues.push(dict.blueGem[tile]);
            } else if (dict.greenGem[tile] !== undefined) {
                greenValues.push(dict.greenGem[tile]);
            } else if (dict.yellowGem[tile] !== undefined) {
                // A yellow gem may become any colour, so it widens all three ranges.
                redValues.push(dict.yellowGem[tile]);
                blueValues.push(dict.yellowGem[tile]);
                greenValues.push(dict.yellowGem[tile]);
            } else if (dict.potion[tile] !== undefined) {
                potionValues.push(dict.potion[tile]);
            }
            const enemy = enemyMap[tile];
            if (enemy) enemyStrengths.push(strength(enemy));
        }
    }
    const red = rangeOf(redValues);
    const blue = rangeOf(blueValues);
    const green = rangeOf(greenValues);
    const potion = rangeOf(potionValues);
    const enemyRange = rangeOf(enemyStrengths);
    const delta = enemyRange.max - enemyRange.min;

    // 4. Convert every tile; later rules override earlier ones.
    return clipped.map(row =>
        row.map(tile => {
            let out = 0;
            if (numMap[tile] !== undefined) out = numMap[tile];
            if (mapping[tile] !== undefined) out = mapping[tile];

            if (dict.redGem[tile] !== undefined) {
                out = 10 + scaleLevel(dict.redGem[tile], red, 3);
            } else if (dict.blueGem[tile] !== undefined) {
                out = 13 + scaleLevel(dict.blueGem[tile], blue, 3);
            } else if (dict.greenGem[tile] !== undefined) {
                out = 16 + scaleLevel(dict.greenGem[tile], green, 3);
            } else if (dict.yellowGem[tile] !== undefined) {
                const rand = Math.random();
                const value = dict.yellowGem[tile];
                if (rand < 2 / 5) {
                    out = 10 + scaleLevel(value, red, 3);
                } else if (rand < 4 / 5) {
                    out = 13 + scaleLevel(value, blue, 3);
                } else {
                    out = 16 + scaleLevel(value, green, 3);
                }
            } else if (dict.potion[tile] !== undefined) {
                out = 19 + scaleLevel(dict.potion[tile], potion, 4);
            } else if (dict.door[tile] !== undefined) {
                out = 3 + dict.door[tile];
            } else if (dict.key[tile] !== undefined) {
                out = 7 + dict.key[tile];
            } else if (dict.item[tile] !== undefined) {
                out = 22 + dict.item[tile];
            }

            // Enemies are split into thirds of the observed strength range.
            const enemy = enemyMap[tile];
            if (enemy) {
                const ad = strength(enemy) - enemyRange.min;
                if (ad < delta / 3 || delta === 0) {
                    out = 26;
                } else if (ad < (delta * 2) / 3) {
                    out = 27;
                } else {
                    out = 28;
                }
            }
            return out;
        })
    );
}
/**
 * Exported entry point for converting one floor. Forwards all four
 * arguments to `convert` unchanged and returns its result.
 *
 * @param map Raw floor tiles.
 * @param clip Four-number clip tuple, passed through as-is.
 * @param config Tower configuration.
 * @param enemyMap Tile number → enemy stats.
 */
export function convertFloor(
map: number[][],
clip: [number, number, number, number],
config: GinkaConfig,
enemyMap: Record<number, Enemy>
) {
return convert(map, clip, config, enemyMap);
}
/**
 * Counts the cells of `map` whose value appears in `tiles`.
 *
 * @param map Grid of tile ids.
 * @param tiles Tile ids to count.
 * @returns Number of matching cells.
 */
export function getCount(map: number[][], tiles: number[]) {
    let total = 0;
    for (const row of map) {
        for (const tile of row) {
            if (tiles.includes(tile)) total += 1;
        }
    }
    return total;
}
/**
 * Fraction of the cells of `map` whose value appears in `tiles`.
 * The area is taken as row count times the length of the first row.
 *
 * @param map Non-empty grid of tile ids.
 * @param tiles Tile ids to count.
 */
export function getRatio(map: number[][], tiles: number[]) {
    const rows = map.length;
    const cols = map[0].length;
    return getCount(map, tiles) / (rows * cols);
}
/**
 * Builds the consecutive integers `from, from + 1, …, to - 1`.
 * Returns an empty array when `to <= from`.
 */
function range(from: number, to: number) {
    return Array.from({ length: to - from }).map((_, offset) => from + offset);
}
/**
 * Builds the 16-entry statistics vector for a converted map.
 *
 * Entries 0..8 are ratios of class-id groups and entries 9..10 are raw
 * counts; the remaining entries are left at 0.
 * NOTE(review): entries 9 and 10 use `getCount` while all others use
 * `getRatio` — confirm that the mix is intended.
 */
export function getGinkaRatio(map: number[][]): number[] {
    const filled = [
        getRatio(map, [1, ...range(3, 32)]),
        getRatio(map, [1]),
        getRatio(map, [2]),
        getRatio(map, [3, 4, 5, 6]),
        getRatio(map, [26, 27, 28]),
        getRatio(map, range(7, 26)),
        getRatio(map, range(10, 19)),
        getRatio(map, range(19, 23)),
        getRatio(map, [7, 8, 9]),
        getCount(map, [23, 24, 25]),
        getCount(map, [29, 30])
    ];
    const padding: number[] = Array(16 - filled.length).fill(0);
    return [...filled, ...padding];
}

View File

@ -1,177 +0,0 @@
import { createConnection, Socket } from 'net';
import { chooseFrom, FloorData, readOne } from './utils';
import { MinamoTrainData } from './types';
import { generateTrainData } from './process/minamo';
// Path handed to `createConnection` when this client connects.
const SOCKET_FILE = '../tmp/ginka_uds';
// CLI arguments: the reference source passed to `readOne`, then an optional replay file path.
const [refer, replayPath = '../datasets/replay.bin'] = process.argv.slice(2);
// Running counter used to build a fresh id for every received map.
let id = 0;
/**
 * Unflattens `arr` into `count` maps of `h` rows by `w` columns.
 * Values are laid out map by map, row by row; cells not covered by `arr`
 * stay 0.
 *
 * @param count Number of maps.
 * @param arr Flat cell values.
 * @param h Map height (rows).
 * @param w Map width (columns).
 */
function readMap(count: number, arr: number[], h: number, w: number) {
    const area = w * h;
    const maps: number[][][] = Array.from({ length: count }, () =>
        Array.from({ length: h }, () => Array.from({ length: w }, () => 0))
    );
    arr.forEach((value, index) => {
        const mapIndex = Math.floor(index / area);
        const inMap = index - mapIndex * area;
        const rowIndex = Math.floor(inMap / w);
        const colIndex = inMap - rowIndex * w;
        maps[mapIndex][rowIndex][colIndex] = value;
    });
    return maps;
}
/**
 * Builds comparison data between one received map and four reference
 * floors chosen via `chooseFrom`. Reference floors that are missing or
 * whose size differs from `map` contribute nothing.
 *
 * @param keys Keys of the reference floors.
 * @param refer Reference floors by key.
 * @param map Received map to compare.
 * @returns Flat list of train-data entries from `generateTrainData`.
 */
function generateGANData(
    keys: string[],
    refer: Map<string, FloorData>,
    map: number[][]
) {
    const id2 = `$${id++}`;
    return chooseFrom(keys, 4).flatMap((key): MinamoTrainData[] => {
        const floor = refer.get(key);
        if (!floor) return [];
        const size1: [number, number] = [floor.map[0].length, floor.map.length];
        const size2: [number, number] = [map[0].length, map.length];
        const sameSize = size1[0] === size2[0] && size1[1] === size2[1];
        return sameSize
            ? generateTrainData(key, id2, floor.map, map, size1, false, false, false)
            : [];
    });
}
/** Parsing phase of a `DataReceiver`. */
const enum ReceiverStatus {
// The next chunk is expected to start with the message header.
Header,
// Header already parsed; chunks are appended to the payload.
Content
}
/**
 * Reassembles one incoming message from the socket stream.
 *
 * Input protocol, sizes in bytes:
 * 2 - tensor count (int16 BE); 1 - map height; 1 - map width;
 * N*1*H*W - map tensor, one byte per cell.
 *
 * A stream may split a message at any byte, so both the header and the
 * payload are accumulated across chunks.
 */
class DataReceiver {
    /** Receiver of the message currently being assembled, if any. */
    static active?: DataReceiver;

    /** Number of header bytes preceding the payload. */
    private static readonly HEADER_SIZE = 4;

    /** Receive state. */
    private status: ReceiverStatus = ReceiverStatus.Header;
    /** Header bytes collected so far (at most `HEADER_SIZE`). */
    private header: Buffer = Buffer.alloc(0);
    private received: number[] = [];
    private count: number = 0;
    private h: number = 0;
    private w: number = 0;

    /**
     * Feeds one chunk into the receiver.
     *
     * @returns `[maps, count, h, w]` once the whole message has arrived,
     * otherwise `null`. Bytes beyond the announced payload size are ignored.
     */
    receive(buf: Buffer): [number[][][], number, number, number] | null {
        let payload = buf;
        if (this.status === ReceiverStatus.Header) {
            // The header itself may be fragmented; wait until all 4 bytes are here.
            const missing = DataReceiver.HEADER_SIZE - this.header.length;
            this.header = Buffer.concat([this.header, buf.subarray(0, missing)]);
            if (this.header.length < DataReceiver.HEADER_SIZE) return null;
            this.count = this.header.readInt16BE(0);
            this.h = this.header.readInt8(2);
            this.w = this.header.readInt8(3);
            this.status = ReceiverStatus.Content;
            payload = buf.subarray(missing);
        }
        // Push byte by byte: spreading a large chunk into `push(...)` can
        // exceed the engine's argument limit.
        for (const byte of payload) {
            this.received.push(byte);
        }

        const expected = this.count * this.h * this.w;
        // A negative size (malformed header) never completes, as before.
        if (expected < 0 || this.received.length < expected) {
            return null;
        }
        DataReceiver.active = undefined;
        const cells =
            this.received.length === expected
                ? this.received
                : this.received.slice(0, expected);
        return [readMap(this.count, cells, this.h, this.w), this.count, this.h, this.w];
    }

    /**
     * Routes a chunk to the active receiver, creating one when no message
     * is in progress.
     */
    static check(buf: Buffer) {
        if (!this.active) {
            this.active = new DataReceiver();
        }
        return this.active.receive(buf);
    }
}
// Entry point: loads the reference floors, connects to the socket, and answers
// every complete incoming message with comparison data.
(async () => {
const referTower = await readOne(refer);
const keys = [...referTower.keys()];
const client = createConnection(SOCKET_FILE, () => {
console.log(`UDS IPC connected successfully.`);
});
client.on('data', async buffer => {
// `check` returns null until a whole message has been reassembled.
const data = DataReceiver.check(buffer);
if (!data) return;
const [map, count, h, w] = data;
// One list of comparison entries per received map.
const simData = map.map(v => generateGANData(keys, referTower, v));
// Replay count; fixed to 0 here, so the replay sections below are empty.
const rc = 0;
const compareData = simData.flat();
// Output protocol sent by node, sizes in bytes:
// 2 - Tensor count; 2 - Replay count. Replay is right behind train data;
// 1*tc - Compare count for every map tensor delivered.
// 2*4*(N+rc) - Vision similarity and topo similarity, like vis, topo, vis, topo;
// N*1*H*W - Compare map for every map tensor. rc*2*H*W - Replay map tensor.
const toSend = Buffer.alloc(
2 + // Tensor count
2 + // Replay count
1 * count + // Compare count
2 * 4 * (compareData.length + rc) + // Similarity data
compareData.length * 1 * h * w + // Compare map
rc * 2 * h * w, // Replay map
0
);
// Debug output: the size of every section of the outgoing buffer.
console.log(
2,
2,
count,
2 * 4 * (compareData.length + rc),
compareData.length * 1 * h * w,
rc * 2 * h * w,
compareData.length,
rc
);
let offset = 0;
toSend.writeInt16BE(count); // Tensor count
toSend.writeInt16BE(0, 2); // Replay count
offset += 2 + 2;
// Compare count
// NOTE(review): each count is stored in a single byte — confirm counts stay below 256.
toSend.set(
simData.map(v => v.length),
offset
);
offset += 1 * count;
// Similarity data
compareData.forEach(v => {
// console.log(v.visionSimilarity, v.topoSimilarity);
toSend.writeFloatBE(v.visionSimilarity, offset);
offset += 4;
toSend.writeFloatBE(v.topoSimilarity, offset);
offset += 4;
});
// Compare map
toSend.set(
new Uint8Array(compareData.map(v => v.map1).flat(3)),
offset // Set from Compare map
);
offset += compareData.length * 1 * h * w;
client.write(toSend);
});
client.on('end', () => {
console.log(`Connection lose.`);
});
// Any socket error closes the connection; the error itself is not logged.
client.on('error', () => {
client.end();
});
})();

View File

@ -1,16 +0,0 @@
import { writeFile } from 'fs-extra';
import { getAllFloors, parseTowerInfo } from './utils';
import { parseGinka } from './process/ginka';
// CLI arguments: output file path, followed by the tower paths to process.
const [output, ...list] = process.argv.slice(2);
// Entry point: parses every tower, converts all floors, and writes the
// resulting dataset to `output` as JSON.
(async () => {
const towers = await Promise.all(
list.map(v => parseTowerInfo(v, 'ginka-config.json'))
);
const floors = await getAllFloors(...towers);
const results = parseGinka(floors);
await writeFile(output, JSON.stringify(results, void 0), 'utf-8');
// Number of entries in the dataset, reported in the summary line.
const size = Object.keys(results.data).length;
console.log(`✅ 已处理 ${list.length} 个塔,共 ${size} 个地图`);
})();

View File

@ -1,40 +0,0 @@
import { writeFile } from 'fs-extra';
import { readOne, getAllFloors, parseTowerInfo } from './utils';
import { generateAssignedData, parseMinamo } from './process/minamo';
// CLI arguments: output file path, followed by tower paths and an optional
// trailing `assigned…` flag.
const [output, ...list] = process.argv.slice(2);
// Detect assigned mode. In this mode only the first two towers are processed:
// maps are compared between the two towers, never within a single tower.
const lastArg = list.at(-1);
const assigned = lastArg?.startsWith('assigned') ?? false;
// Fall back to an empty string so a missing argument yields the default
// counts instead of throwing inside `parseAssigned`.
const assignedCount = parseAssigned(lastArg ?? '');
if (assigned) list.pop();
/**
 * Parses the two counts out of an `assigned` argument.
 *
 * The first nine characters are skipped and the rest is read as `a:b`.
 * A missing, non-numeric or zero part falls back to 100.
 *
 * @param arg Raw command-line argument.
 * @returns `[a, b]`.
 */
function parseAssigned(arg: string): [number, number] {
    const parts = arg.slice(9).split(':');
    const toCount = (text: string) => {
        const parsed = parseInt(text);
        return parsed ? parsed : 100;
    };
    return [toCount(parts[0]), toCount(parts[1])];
}
// Entry point: builds the dataset and writes it to `output` as JSON.
(async () => {
if (!assigned) {
// Normal mode: pair up floors across all the given towers.
const towers = await Promise.all(
list.map(v => parseTowerInfo(v, 'minamo-config.json'))
);
const floors = await getAllFloors(...towers);
const results = parseMinamo(floors);
await writeFile(output, JSON.stringify(results, void 0), 'utf-8');
const size = Object.keys(results.data).length;
console.log(`✅ 已处理 ${list.length} 个塔,共 ${size} 个组合`);
} else {
// Assigned mode: compare floors of the first tower against the second.
const [tower1, tower2] = list;
if (!tower1 || !tower2) {
console.log(`⚠️ assigned 模式下必须传入两个塔!`);
return;
}
const data1 = await readOne(tower1);
const data2 = await readOne(tower2);
const results = generateAssignedData(data1, data2, assignedCount);
await writeFile(output, JSON.stringify(results, void 0), 'utf-8');
const size = Object.keys(results.data).length;
console.log(`✅ 已处理 ${list.length} 个塔,共 ${size} 个组合`);
}
})();

View File

@ -1,36 +0,0 @@
import { SingleBar, Presets } from 'cli-progress';
import { getGinkaRatio } from 'src/floor';
import { GinkaTrainData, GinkaConfig, GinkaDataset } from 'src/types';
import { FloorData } from 'src/utils';
/**
 * Turns converted floors into a Ginka dataset.
 *
 * Each floor contributes its map, its `[width, height]` size, the tag stored
 * for it in the tower config (64 zeros when absent) and the statistics
 * vector from `getGinkaRatio`. Progress is shown on a CLI progress bar.
 *
 * @param data Floors by key.
 * @returns Dataset with a random `datasetId`.
 */
export function parseGinka(data: Map<string, FloorData>) {
    const resolved: Record<string, GinkaTrainData> = {};
    const bar = new SingleBar({}, Presets.shades_classic);
    bar.start(data.size, 0);

    let done = 0;
    for (const [key, floor] of data) {
        const perFloor = (floor.config as GinkaConfig).data;
        const { tag } = perFloor[floor.id] ?? {
            tag: Array(64).fill(0)
        };
        resolved[key] = {
            map: floor.map,
            size: [floor.map[0].length, floor.map.length],
            tag,
            val: getGinkaRatio(floor.map)
        };
        done += 1;
        bar.update(done);
    }

    const dataset: GinkaDataset = {
        datasetId: Math.floor(Math.random() * 1e12),
        data: resolved
    };
    bar.stop();
    return dataset;
}

View File

@ -1,406 +0,0 @@
import { SingleBar, Presets } from 'cli-progress';
import { compareMap } from 'src/topology/compare';
import { directions, tileType } from 'src/topology/graph';
import { rotateMap, mirrorMapX, mirrorMapY } from 'src/topology/transform';
import { MinamoTrainData, MinamoDataset } from 'src/types';
import { chooseFrom, FloorData } from 'src/utils';
import { calculateVisualSimilarity } from 'src/vision/similarity';
/**
 * Passes the indices `0 .. maxCount - 1` to `chooseFrom`, asking for `n` of them.
 */
function chooseN(maxCount: number, n: number) {
    const indices = [...Array(maxCount).keys()];
    return chooseFrom(indices, n);
}
/**
 * Picks up to `max` distinct unordered index pairs `(i, j)` with `i < j < n`
 * in random order.
 *
 * Each pair is encoded as the single number `i * n + j`; decode it with
 * `Math.floor(code / n)` and `code % n`.
 *
 * @param n Number of items to pair up.
 * @param max Upper bound on the number of returned pairs.
 */
function choosePair(n: number, max: number = 1000) {
    const encoded: number[] = [];
    for (let first = 0; first < n; first++) {
        for (let second = first + 1; second < n; second++) {
            encoded.push(first * n + second);
        }
    }

    // Fisher–Yates shuffle, walking from the last slot down to the second.
    let tail = encoded.length;
    while (--tail > 0) {
        const pick = Math.floor(Math.random() * (tail + 1));
        const held = encoded[tail];
        encoded[tail] = encoded[pick];
        encoded[pick] = held;
    }

    const limit = Math.min(Math.round((n * (n - 1)) / 2), max);
    return encoded.slice(0, limit);
}
/**
 * Applies `rot` successive `rotateMap` calls, then optional mirroring.
 *
 * @param map Source map.
 * @param rot Number of rotations to apply.
 * @param flip Bit mask: bit 0 applies `mirrorMapX`, bit 1 applies `mirrorMapY`.
 */
function transform(map: number[][], rot: number, flip: number) {
    let result = map;
    for (let turnsLeft = rot; turnsLeft > 0; turnsLeft--) {
        result = rotateMap(result);
    }
    const flipX = (flip & 0b01) !== 0;
    const flipY = (flip & 0b10) !== 0;
    if (flipX) result = mirrorMapX(result);
    if (flipY) result = mirrorMapY(result);
    return result;
}
/**
 * Builds training pairs in which one side is a rotated and/or mirrored copy
 * of `map1` or `map2`.
 *
 * For each chosen transform there are four candidates (transformed map1 or
 * map2, compared against map1 or map2); a random subset is picked via
 * `chooseFrom`. A transformed map compared with its own source gets
 * topology similarity 1, otherwise `simi`.
 *
 * NOTE(review): `Math.floor(Math.random() * 1)` is always 0, so zero
 * transforms are requested from `chooseFrom`; the original comment says
 * "randomly pick at most one" — confirm whether `* 2` was intended.
 *
 * @returns `[id, data]` tuples, id formatted as `source.rot.flip:other`.
 */
function generateTransformData(
    id1: string,
    id2: string,
    map1: number[][],
    map2: number[][],
    simi: number
) {
    // Every rot/flip combination except the identity, in rot-major order.
    const types: [rot: number, flip: number][] = [];
    for (let code = 1; code < 16; code++) {
        types.push([code >> 2, code & 0b11]);
    }
    const trans = chooseFrom(types, Math.floor(Math.random() * 1));

    const out: [id: string, data: MinamoTrainData][] = [];
    for (const [rot, flip] of trans) {
        const candidates: {
            key: string;
            source: number[][];
            other: number[][];
            topo: number;
            otherFirst: boolean;
        }[] = [
            { key: `${id1}.${rot}.${flip}:${id1}`, source: map1, other: map1, topo: 1, otherFirst: true },
            { key: `${id1}.${rot}.${flip}:${id2}`, source: map1, other: map2, topo: simi, otherFirst: false },
            { key: `${id2}.${rot}.${flip}:${id1}`, source: map2, other: map1, topo: simi, otherFirst: false },
            { key: `${id2}.${rot}.${flip}:${id2}`, source: map2, other: map2, topo: 1, otherFirst: false }
        ];
        const choose = chooseFrom(
            candidates.map(c => c.key),
            Math.floor(Math.random() * 2)
        );
        for (const candidate of candidates) {
            if (!choose.includes(candidate.key)) continue;
            const t = transform(candidate.source, rot, flip);
            out.push([
                candidate.key,
                {
                    map1: t,
                    map2: candidate.other,
                    topoSimilarity: candidate.topo,
                    visionSimilarity: candidate.otherFirst
                        ? calculateVisualSimilarity(candidate.other, t)
                        : calculateVisualSimilarity(t, candidate.other),
                    size: [map1[0].length, map1.length]
                }
            ]);
        }
    }
    return out;
}
/**
 * Builds training pairs between `map` and randomly perturbed copies of it.
 *
 * Each copy gets its own perturbation probability (below 0.3); a perturbed
 * cell is either swapped with a neighbour or replaced with a random tile.
 *
 * @param id Id of the source map; copies are named `${id}.S${i}`.
 * @param map Source map (not modified).
 * @returns `[id, data]` tuples, one per generated copy.
 */
function generateSimilarData(id: string, map: number[][]) {
// Generate the perturbed maps.
// NOTE(review): the original comment said "at most two", but
// Math.floor(Math.random() * 2) only yields 0 or 1 — confirm the intent.
const width = map[0].length;
const height = map.length;
const num = Math.floor(Math.random() * 2);
const res: [id: string, data: MinamoTrainData][] = [];
for (let i = 0; i < num; i++) {
const clone = map.map(v => v.slice());
const prob = Math.random() * 0.3;
for (let ny = 0; ny < height; ny++) {
for (let nx = 0; nx < width; nx++) {
if (Math.random() > prob) {
// Only perturb a cell with probability `prob`.
continue;
}
if (Math.random() < 0.2) {
// 20% chance: swap with a neighbouring tile.
const [dx, dy] =
directions[
Math.floor(Math.random() * directions.length)
];
const px = nx + dx;
const py = ny + dy;
// Skip the swap when the neighbour lies outside the map.
if (px < 0 || px >= width || py < 0 || py >= height) {
continue;
}
[clone[ny][nx], clone[py][px]] = [
clone[py][px],
clone[ny][nx]
];
} else {
// 80% chance: replace the current tile with a random one.
clone[ny][nx] = Math.floor(Math.random() * tileType.size);
}
}
}
const id2 = `${id}.S${i}`;
const sid = `${id}:${id2}`;
const simi = compareMap(id, id2, map, clone);
res.push([
sid,
{
map1: map,
map2: clone,
size: [width, height],
topoSimilarity: simi,
visionSimilarity: calculateVisualSimilarity(map, clone)
}
]);
}
return res;
}
/**
 * Builds the training entries for one pair of maps.
 *
 * The first entry always compares `map1` with `map2`. Depending on the
 * flags, self comparisons (similarity 1), transformed variants and
 * perturbed variants are appended in that order.
 *
 * NOTE(review): `Math.floor(Math.random() * 1)` is always 0, so `chooseFrom`
 * is asked for zero self-comparison ids — confirm whether this is intended.
 *
 * @param size `[width, height]` shared by both maps.
 * @param hasSelf Whether self comparisons may be added.
 * @param hasTransform Whether rotated/mirrored variants are added.
 * @param hasSimilar Whether perturbed variants of `map1` are added.
 */
export function generateTrainData(
    id1: string,
    id2: string,
    map1: number[][],
    map2: number[][],
    size: [number, number],
    hasSelf: boolean = true,
    hasTransform: boolean = true,
    hasSimilar: boolean = true
) {
    const topoSimilarity = compareMap(id1, id2, map1, map2);
    const visionSimilarity = calculateVisualSimilarity(map1, map2);
    const data: MinamoTrainData[] = [
        { map1, map2, topoSimilarity, visionSimilarity, size: size }
    ];

    if (hasSelf) {
        // A map compared with itself, so the model learns to output 1 for identical maps.
        const selfPairs: [string, number[][]][] = [
            [`${id1}:${id1}`, map1],
            [`${id2}:${id2}`, map2]
        ];
        const selfTrain = chooseFrom(
            selfPairs.map(([key]) => key),
            Math.floor(Math.random() * 1)
        );
        for (const [key, grid] of selfPairs) {
            if (!selfTrain.includes(key)) continue;
            data.push({
                map1: grid,
                map2: grid,
                topoSimilarity: 1,
                visionSimilarity: 1,
                size: size
            });
        }
    }

    if (hasTransform) {
        const transformed = generateTransformData(id1, id2, map1, map2, topoSimilarity);
        for (const [, entry] of transformed) data.push(entry);
    }

    if (hasSimilar) {
        const similar = generateSimilarData(id1, map1);
        for (const [, entry] of similar) data.push(entry);
    }

    return data;
}
/**
 * Adds the training entries for one pair of maps to `data`, keyed by
 * `idA:idB`.
 *
 * Writes the direct comparison, then optional self comparisons (never
 * overwriting an existing key), then transformed and perturbed variants.
 *
 * NOTE(review): `Math.floor(Math.random() * 1)` is always 0, so `chooseFrom`
 * is asked for zero self-comparison ids — confirm whether this is intended.
 *
 * @param data Dataset record that is mutated in place.
 * @param size `[width, height]` shared by both maps.
 */
export function generatePair(
    data: Record<string, MinamoTrainData>,
    id1: string,
    id2: string,
    map1: number[][],
    map2: number[][],
    size: [number, number]
) {
    const topoSimilarity = compareMap(id1, id2, map1, map2);
    const visionSimilarity = calculateVisualSimilarity(map1, map2);
    data[`${id1}:${id2}`] = {
        map1,
        map2,
        topoSimilarity,
        visionSimilarity,
        size: size
    };

    // A map compared with itself, so the model learns to output 1 for identical maps.
    const selfPairs: [string, number[][]][] = [
        [`${id1}:${id1}`, map1],
        [`${id2}:${id2}`, map2]
    ];
    const selfTrain = chooseFrom(
        selfPairs.map(([key]) => key),
        Math.floor(Math.random() * 1)
    );
    for (const [key, grid] of selfPairs) {
        if (!selfTrain.includes(key) || data[key]) continue;
        data[key] = {
            map1: grid,
            map2: grid,
            topoSimilarity: 1,
            visionSimilarity: 1,
            size: size
        };
    }

    // Rotated / mirrored variants.
    for (const [key, entry] of generateTransformData(id1, id2, map1, map2, topoSimilarity)) {
        data[key] = entry;
    }
    // Perturbed variants of the first map.
    for (const [key, entry] of generateSimilarData(id1, map1)) {
        data[key] = entry;
    }
}
/**
 * Generates training data for every encoded pair in `pairs`.
 *
 * A pair code decodes to two indices into `floorIds`
 * (`Math.floor(code / n)` and `code % n`). Pairs with a missing floor or
 * with differently sized maps are skipped without advancing the progress bar.
 *
 * @param floors Floors by id.
 * @param pairs Encoded index pairs.
 * @param floorIds Floor ids that the pair indices refer to.
 */
function generateDataset(
    floors: Map<string, FloorData>,
    pairs: number[],
    floorIds: string[]
): Record<string, MinamoTrainData> {
    const data: Record<string, MinamoTrainData> = {};
    const bar = new SingleBar({}, Presets.shades_classic);
    bar.start(pairs.length, 0);

    const total = floorIds.length;
    for (let index = 0; index < pairs.length; index++) {
        const code = pairs[index];
        const id1 = floorIds[Math.floor(code / total)];
        const id2 = floorIds[code % total];
        const map1 = floors.get(id1)?.map;
        const map2 = floors.get(id2)?.map;
        if (!map1 || !map2) continue;

        const size1: [number, number] = [map1[0].length, map1.length];
        const size2: [number, number] = [map2[0].length, map2.length];
        if (size1[0] !== size2[0] || size1[1] !== size2[1]) continue;

        generatePair(data, id1, id2, map1, map2, size1);
        bar.update(index + 1);
    }

    bar.stop();
    return data;
}
/**
 * Builds a Minamo dataset from all floors by pairing them up at random
 * (at most 10000 pairs) and generating training data for each pair.
 *
 * @param data Floors by id.
 * @returns Dataset with a random `datasetId`.
 */
export function parseMinamo(data: Map<string, FloorData>): MinamoDataset {
    const length = data.size;
    const pairs = choosePair(length, 10000);
    const totalCount = Math.round((length * (length - 1)) / 2);
    console.log(
        `✅ 共发现 ${length} 个楼层,共 ${totalCount} 种组合,选取 ${pairs.length} 个组合`
    );
    const floorIds = Array.from(data.keys());
    const trainData = generateDataset(data, pairs, floorIds);
    return {
        datasetId: Math.floor(Math.random() * 1e12),
        data: trainData
    };
}
/**
 * Builds a Minamo dataset that only compares floors of `data1` against
 * floors of `data2`.
 *
 * Up to `count[0]` floors are chosen from `data1`; for each of them up to
 * `count[1]` floors are chosen from `data2`. Pairs with a missing map or
 * with differently sized maps are skipped.
 *
 * @param data1 Floors of the first tower.
 * @param data2 Floors of the second tower.
 * @param count Maximum floors to take from each tower.
 * @returns Dataset with a random `datasetId`.
 */
export function generateAssignedData(
    data1: Map<string, FloorData>,
    data2: Map<string, FloorData>,
    count: [number, number]
): MinamoDataset {
    const [limit1, limit2] = count;
    const count1 = Math.min(limit1, data1.size);
    const count2 = Math.min(limit2, data2.size);
    const keys1 = Array.from(data1.keys());
    const keys2 = Array.from(data2.keys());
    const choose1 = chooseFrom(keys1, count1);
    const trainData: Record<string, MinamoTrainData> = {};

    const length = data1.size + data2.size;
    const totalCount = data1.size * data2.size;
    const chosen = count1 * count2;
    console.log(
        `✅ 共发现 ${length} 个楼层,共 ${totalCount} 种组合,选取 ${chosen} `
    );

    const bar = new SingleBar({}, Presets.shades_classic);
    bar.start(count1 * count2, 0);
    let processed = 0;
    for (const key1 of choose1) {
        // A fresh selection from the second tower for every first-tower floor.
        for (const key2 of chooseFrom(keys2, count2)) {
            const map1 = data1.get(key1)!.map;
            const map2 = data2.get(key2)!.map;
            if (!map1 || !map2) continue;

            const size1: [number, number] = [map1[0].length, map1.length];
            const size2: [number, number] = [map2[0].length, map2.length];
            if (size1[0] !== size2[0] || size1[1] !== size2[1]) continue;

            generatePair(trainData, key1, key2, map1, map2, size1);
            processed += 1;
            bar.update(processed);
        }
    }
    bar.stop();

    return {
        datasetId: Math.floor(Math.random() * 1e12),
        data: trainData
    };
}

View File

@ -1,38 +0,0 @@
import { readFile, writeFile } from 'fs-extra';
import { chooseFrom, DatasetMergable, mergeDataset } from './utils';
// CLI arguments: the target dataset file, then the dataset files to sample
// from, optionally followed by a numeric sample size.
const [target, ...review] = process.argv.slice(2);
// Sample size; `getNum` also removes the numeric argument from `review`.
const n = getNum();
/**
 * Reads the sample size from the last CLI argument.
 *
 * When the last entry of `review` parses to a non-zero integer it is
 * removed from `review` and returned; otherwise `review` is left untouched
 * and 1000 is returned.
 */
function getNum() {
    const last = review.at(-1);
    const parsed = last ? parseInt(last) : NaN;
    if (!parsed) return 1000;
    review.pop();
    return parsed;
}
// Entry point: samples `n` entries from the merged review datasets and merges
// them into the target dataset file, which is rewritten in place.
(async () => {
// Load every review dataset in parallel.
const datas = await Promise.all(
review.map(async v => {
const file = await readFile(v, 'utf-8');
return JSON.parse(file) as DatasetMergable<any>;
})
);
const targetFile = await readFile(target, 'utf-8');
const targetData = JSON.parse(targetFile) as DatasetMergable<any>;
const merged = mergeDataset(true, ...datas);
const keys = Object.keys(merged.data);
// Pick the entries to carry over into the target.
const toReview = chooseFrom(keys, n);
const reviewData: DatasetMergable<any> = {
datasetId: Math.floor(Math.random() * 1e12),
data: Object.fromEntries(toReview.map(v => [v, merged.data[v]]))
};
const reviewed = mergeDataset(false, targetData, reviewData);
await writeFile(target, JSON.stringify(reviewed), 'utf-8');
})();

View File

@ -1,95 +0,0 @@
import { buildTopologicalGraph } from './graph';
import { mirrorMapX, mirrorMapY, rotateMap } from './transform';
import { overallSimilarity } from './similarity';
// Manual smoke test: prints topological similarity scores for two sample
// maps under identity, mirroring and rotation.
(() => {
    // MT3
    const map1 = [
        [1, 1, 1, 1, 1, 1, 1],
        [1, 3, 1, 10, 7, 3, 1],
        [1, 6, 1, 5, 1, 7, 1],
        [1, 9, 1, 8, 1, 6, 1],
        [1, 5, 8, 0, 1, 7, 1],
        [1, 2, 1, 5, 7, 10, 1],
        [1, 1, 1, 1, 1, 1, 1]
    ];
    // MT6
    const map2 = [
        [1, 1, 1, 1, 1, 1, 1],
        [1, 5, 6, 4, 1, 10, 1],
        [1, 6, 1, 9, 1, 5, 1],
        [1, 8, 0, 6, 0, 8, 1],
        [1, 5, 1, 10, 1, 2, 1],
        [1, 9, 3, 1, 4, 9, 1],
        [1, 1, 1, 1, 1, 1, 1]
    ];
    // MT8
    // const map2 = [
    //     [1, 1, 1, 1, 1, 1, 1],
    //     [1, 5, 8, 10, 7, 2, 1],
    //     [1, 2, 1, 5, 1, 7, 1],
    //     [1, 3, 1, 3, 6, 4, 1],
    //     [1, 6, 1, 6, 1, 8, 1],
    //     [1, 10, 7, 5, 1, 5, 1],
    //     [1, 1, 1, 1, 1, 1, 1]
    // ];
    // MT3, slightly modified
    // const map2 = [
    //     [1, 1, 1, 1, 1, 1, 1],
    //     [1, 3, 1, 10, 7, 3, 1],
    //     [1, 6, 1, 5, 1, 7, 1],
    //     [1, 9, 1, 8, 1, 6, 1],
    //     [1, 5, 8, 0, 1, 7, 1],
    //     [1, 2, 1, 4, 7, 10, 1],
    //     [1, 1, 1, 1, 1, 1, 1]
    // ];

    // 1. Each map against itself
    const graph1 = buildTopologicalGraph(map1);
    const graph2 = buildTopologicalGraph(map2);
    console.log(`map1 vs map1: ${overallSimilarity(graph1, graph1)}`);
    console.log(`map2 vs map2: ${overallSimilarity(graph2, graph2)}`);

    // 2. The two maps against each other
    console.log(`map1 vs map2: ${overallSimilarity(graph1, graph2)}`);
    console.log(`map2 vs map1: ${overallSimilarity(graph2, graph1)}`);

    // 3. Mirrored along x
    const xFlipped1 = mirrorMapX(map1);
    const xFlipped2 = mirrorMapX(map2);
    const graphX1 = buildTopologicalGraph(xFlipped1);
    const graphX2 = buildTopologicalGraph(xFlipped2);
    console.log(`map1:x vs map1: ${overallSimilarity(graphX1, graph1)}`);
    console.log(`map1:x vs map2: ${overallSimilarity(graphX1, graph2)}`);
    console.log(`map1 vs map2:x: ${overallSimilarity(graph1, graphX2)}`);
    console.log(`map2:x vs map2: ${overallSimilarity(graphX2, graph2)}`);
    // Compare against graph1 so the output matches its label.
    console.log(`map2:x vs map1: ${overallSimilarity(graphX2, graph1)}`);
    console.log(`map2 vs map1:x: ${overallSimilarity(graph2, graphX1)}`);

    // 4. Mirrored along y
    const yFlipped1 = mirrorMapY(map1);
    const yFlipped2 = mirrorMapY(map2);
    const graphY1 = buildTopologicalGraph(yFlipped1);
    const graphY2 = buildTopologicalGraph(yFlipped2);
    console.log(`map1:y vs map1: ${overallSimilarity(graphY1, graph1)}`);
    console.log(`map1:y vs map2: ${overallSimilarity(graphY1, graph2)}`);
    console.log(`map1 vs map2:y: ${overallSimilarity(graph1, graphY2)}`);
    console.log(`map2:y vs map2: ${overallSimilarity(graphY2, graph2)}`);
    console.log(`map2:y vs map1: ${overallSimilarity(graphY2, graph1)}`);
    console.log(`map2 vs map1:y: ${overallSimilarity(graph2, graphY1)}`);

    // 5. Mixed x / y mirrors
    console.log(`map1:x vs map1:y: ${overallSimilarity(graphX1, graphY1)}`);
    console.log(`map1:y vs map2:x: ${overallSimilarity(graphY1, graphX2)}`);
    console.log(`map1:x vs map2:x: ${overallSimilarity(graphX1, graphX2)}`);
    console.log(`map1:x vs map2:y: ${overallSimilarity(graphX1, graphY2)}`);

    // 6. Rotated once with rotateMap
    const rot901 = rotateMap(map1);
    const rot902 = rotateMap(map2);
    const graph901 = buildTopologicalGraph(rot901);
    const graph902 = buildTopologicalGraph(rot902);
    console.log(`map1:90 vs map1: ${overallSimilarity(graph1, graph901)}`);
    console.log(`map2:90 vs map2: ${overallSimilarity(graph2, graph902)}`);
})();

View File

@ -1,13 +0,0 @@
/**
 * Returns a copy of `map` with every row reversed (left-right mirror).
 * The input is not modified.
 */
export function mirrorMapX(map: number[][]) {
    return map.map(row => Array.from(row).reverse());
}
/**
 * Returns a new outer array with the rows in reverse order (top-bottom
 * mirror). The rows themselves are shared with the input, not copied.
 */
export function mirrorMapY(map: number[][]) {
    return Array.from(map).reverse();
}
/**
 * Rotates `map` by a quarter turn: the last column of the input becomes
 * the first row of the output, read top to bottom.
 *
 * @param map Non-empty grid; its width is taken from the first row.
 */
export function rotateMap(map: number[][]) {
    const rotated: number[][] = [];
    for (let col = map[0].length - 1; col >= 0; col--) {
        const line: number[] = [];
        for (const row of map) {
            line.push(row[col]);
        }
        rotated.push(line);
    }
    return rotated;
}

View File

@ -1,7 +1,6 @@
import { readFile } from 'fs-extra'; import { readFile } from 'fs-extra';
import { join } from 'path'; import { join } from 'path';
import { BaseConfig, GinkaConfig, TowerInfo } from './types'; import { BaseConfig, GinkaConfig, TowerInfo } from './types';
import { convertFloor } from './floor';
export interface DatasetMergable<T> { export interface DatasetMergable<T> {
datasetId: number; datasetId: number;
@ -77,81 +76,6 @@ export async function parseTowerInfo(
}; };
} }
/**
 * Loads and converts every floor of the given towers.
 *
 * For each tower, `enemys.js` and `maps.js` are read (first line dropped,
 * rest parsed as JSON) to build a tile-number → enemy table, then every
 * floor file is parsed, clipped and converted with `convertFloor`.
 *
 * @param info Towers to load.
 * @returns Floors keyed by `${towerName}::${floorId}`.
 * @throws Rethrows any read/parse/convert error after logging the tower
 * name and floor id.
 */
export async function getAllFloors(...info: TowerInfo[]) {
const floorData = await Promise.all(
info.map(async tower => {
// Load the required tower files.
const enemyFile = await readFile(
join(tower.path, 'enemys.js'),
'utf-8'
);
const mapFile = await readFile(
join(tower.path, 'maps.js'),
'utf-8'
);
// Drop the first line of each file and parse the remainder as JSON.
const enemyMap = JSON.parse(
enemyFile.split('\n').slice(1).join('\n')
) as Record<string, any>;
const mapData = JSON.parse(
mapFile.split('\n').slice(1).join('\n')
) as Record<number, any>;
const enemyNumMap: Record<number, any> = {};
// Map tile numbers of enemy tiles to their enemy definitions.
for (const [key, value] of Object.entries(mapData)) {
if (value.cls === 'enemys') {
enemyNumMap[parseInt(key)] = enemyMap[value.id];
}
}
return Promise.all(
tower.floorIds.map(async id => {
const floorFile = await readFile(
join(tower.path, 'floors', `${id}.js`),
'utf-8'
);
try {
// Parse everything after the first '=' as JSON.
const data = JSON.parse(
floorFile
// .replaceAll("'", '"')
.slice(floorFile.indexOf('=') + 1)
);
const map = data.map as number[][];
// Clip the map; a per-floor clip takes precedence over the default.
const { clip } = tower.config;
const area = clip.special[id] ?? clip.defaults;
return convertFloor(
map,
area,
tower.config as GinkaConfig,
enemyNumMap
);
} catch (e) {
console.log(
`Error when processing '${tower.name}' '${id}'`
);
throw e;
}
})
);
})
);
// Flatten the per-tower results into a single map.
const maps: Map<string, FloorData> = new Map();
floorData.forEach((tower, tid) => {
const name = info[tid].name;
tower.forEach((map, mid) => {
const floorId = info[tid].floorIds[mid];
maps.set(`${name}::${floorId}`, {
map,
id: floorId,
config: info[tid].config
});
});
});
return maps;
}
export function mergeFloorIds(...info: TowerInfo[]) { export function mergeFloorIds(...info: TowerInfo[]) {
const ids: string[] = []; const ids: string[] = [];
info.forEach(v => { info.forEach(v => {
@ -160,14 +84,6 @@ export function mergeFloorIds(...info: TowerInfo[]) {
return ids; return ids;
} }
/**
 * Loads floors from `path`: a path ending in `.json` is read with
 * `fromJSON`, anything else is treated as a tower directory and parsed
 * with `minamo-config.json`.
 */
export async function readOne(path: string) {
    if (!path.endsWith('.json')) {
        const tower = await parseTowerInfo(path, 'minamo-config.json');
        return getAllFloors(tower);
    }
    return fromJSON(path);
}
export async function fromJSON(path: string) { export async function fromJSON(path: string) {
const file = await readFile(path, 'utf-8'); const file = await readFile(path, 'utf-8');
const data = JSON.parse(file) as Record<string, number[][]>; const data = JSON.parse(file) as Record<string, number[][]>;

View File

@ -1,145 +0,0 @@
/** Options for `calculateVisualSimilarity`. */
interface VisualSimilarityConfig {
// Importance weight per tile type (tune to the game's settings).
typeWeights: { [key: number]: number };
// Whether to apply visual-focus weighting.
enableVisualFocus: boolean;
// Whether to apply density-aware weighting.
enableDensityAwareness: boolean;
}
/** Default weights and switches used when no config is passed. */
const DEFAULT_CONFIG: VisualSimilarityConfig = {
typeWeights: {
0: 0.2, // empty floor
1: 0.3, // wall
2: 0.6, // key
3: 0.7, // red gem
4: 0.7, // blue gem
5: 0.5, // potion
6: 0.4, // door
7: 0.5, // weak enemy
8: 0.6, // medium enemy
9: 0.6, // strong enemy
10: 0.4, // stairs
11: 0.4, // arrow
12: 0.7 // item
},
enableVisualFocus: true,
enableDensityAwareness: true
};
/**
 * Compute a weighted similarity ratio between two tile maps.
 *
 * Every cell adds `baseWeight * spatialWeight` to the maximum possible score,
 * and adds the same amount to the actual score when both maps hold the same
 * tile type at that cell. The result is `actual / maximum`.
 *
 * - `baseWeight` is the larger of the two tiles' configured type weights;
 *   types without a configured weight use 0.5. An explicit weight of 0 is
 *   honoured (it is not replaced by the 0.5 fallback).
 * - `spatialWeight` is the product of the focus weight and the density weight
 *   for that cell; each is 1 everywhere when its feature is disabled.
 *
 * @param map1 First map, indexed as `map[row][col]`.
 * @param map2 Second map; must have the same row count and first-row length.
 * @param config Type weights and feature switches; defaults to `DEFAULT_CONFIG`.
 * @returns The similarity ratio, or 0 when the sizes differ, when the maps are
 * empty, or when no cell carries any weight.
 */
export function calculateVisualSimilarity(
    map1: number[][],
    map2: number[][],
    config = DEFAULT_CONFIG
): number {
    // Size check: row count and first-row length must agree.
    if (map1.length !== map2.length || map1[0]?.length !== map2[0]?.length) {
        return 0;
    }

    const rows = map1.length;
    // Two empty maps pass the check above but have no first row to measure,
    // so bail out instead of dereferencing `map1[0]`.
    const cols = map1[0]?.length ?? 0;
    if (rows === 0 || cols === 0) {
        return 0;
    }

    const uniformWeights = (): number[][] =>
        Array.from({ length: rows }, () =>
            Array.from({ length: cols }, () => 1)
        );

    // Visual-focus weight map
    const focusWeights = config.enableVisualFocus
        ? generateFocusWeights(rows, cols)
        : uniformWeights();

    // Local type-density weight map
    const densityMap = config.enableDensityAwareness
        ? calculateDensityImpact(map1, map2, config.typeWeights)
        : uniformWeights();

    // `??` rather than `||` so that a configured weight of 0 stays 0.
    const weightOf = (type: number): number => config.typeWeights[type] ?? 0.5;

    let totalScore = 0;
    let maxPossibleScore = 0;

    for (let i = 0; i < rows; i++) {
        for (let j = 0; j < cols; j++) {
            const type1 = map1[i][j];
            const type2 = map2[i][j];

            const baseWeight = Math.max(weightOf(type1), weightOf(type2));
            const spatialWeight = focusWeights[i][j] * densityMap[i][j];
            const cellWeight = baseWeight * spatialWeight;

            if (type1 === type2) {
                totalScore += cellWeight;
            }
            maxPossibleScore += cellWeight;
        }
    }

    return maxPossibleScore > 0 ? totalScore / maxPossibleScore : 0;
}
/**
 * Build a `rows x cols` grid of visual-focus weights.
 *
 * Each weight is `1.0 + 0.6 * g`, where `g` is a Gaussian (sigma 0.3) of the
 * cell's distance from the point `(cols / 2, rows / 2)`, with the distance
 * measured in map-relative units (offsets are divided by `cols` and `rows`).
 * Weights therefore fall in the range (1.0, 1.6], larger near the centre.
 *
 * @param rows Number of rows in the grid.
 * @param cols Number of columns in the grid.
 * @returns The weight grid, indexed as `weights[row][col]`.
 */
function generateFocusWeights(rows: number, cols: number): number[][] {
    const centerX = cols / 2;
    const centerY = rows / 2;
    const sigma = 0.3;

    return Array.from({ length: rows }, (_row, i) =>
        Array.from({ length: cols }, (_cell, j) => {
            // Gaussian falloff models the visual focus.
            const dx = (j - centerX) / cols;
            const dy = (i - centerY) / rows;
            const distance = Math.sqrt(dx ** 2 + dy ** 2);
            const gaussian = Math.exp(-(distance ** 2) / (2 * sigma ** 2));
            return 1.0 + 0.6 * gaussian; // up to 1.6x weight at the focus point
        })
    );
}
/**
 * Build a per-cell density weight map from two equally sized maps.
 *
 * For every cell, the type weights of both maps are averaged over the 3x3
 * window centred on that cell (neighbours outside the map are skipped), the
 * sum is divided by 9, and the cell weight is `1.0 + 0.4 * thatValue`.
 * Types without a configured weight use 0.5; an explicit weight of 0 is
 * honoured (it is not replaced by the 0.5 fallback).
 *
 * @param map1 First map, indexed as `map[row][col]`.
 * @param map2 Second map; read at the same indices as `map1`.
 * @param typeWeights Weight per tile type.
 * @returns The density weight grid with the dimensions of `map1`, or an empty
 * array when `map1` has no rows.
 */
function calculateDensityImpact(
    map1: number[][],
    map2: number[][],
    typeWeights: { [key: number]: number }
): number[][] {
    const rows = map1.length;
    // An empty map has no first row; treat its width as 0 instead of throwing.
    const cols = map1[0]?.length ?? 0;

    // Sliding window used to analyse local density.
    const windowSize = 3;
    const halfWindow = Math.floor(windowSize / 2);

    // `??` rather than `||` so that a configured weight of 0 stays 0.
    const weightOf = (type: number): number => typeWeights[type] ?? 0.5;

    const densityMap: number[][] = [];
    for (let i = 0; i < rows; i++) {
        const densityRow: number[] = [];
        for (let j = 0; j < cols; j++) {
            let density = 0;
            for (let di = -halfWindow; di <= halfWindow; di++) {
                const ni = i + di;
                if (ni < 0 || ni >= rows) continue;
                for (let dj = -halfWindow; dj <= halfWindow; dj++) {
                    const nj = j + dj;
                    if (nj < 0 || nj >= cols) continue;
                    density +=
                        (weightOf(map1[ni][nj]) + weightOf(map2[ni][nj])) / 2;
                }
            }
            // Density weight: denser neighbourhoods get a larger multiplier.
            densityRow.push(1.0 + 0.4 * (density / windowSize ** 2));
        }
        densityMap.push(densityRow);
    }
    return densityMap;
}

View File

@ -1,74 +0,0 @@
import { calculateVisualSimilarity } from './similarity';
/**
 * Manual smoke test: prints the similarity score for a fixed set of 7x7 sample
 * maps (self comparisons, pairwise comparisons, and two reversed pairs to
 * eyeball symmetry).
 */
(() => {
    type NamedMap = readonly [label: string, grid: number[][]];

    // MT3
    const mt3: NamedMap = [
        'map1',
        [
            [1, 1, 1, 1, 1, 1, 1],
            [1, 3, 1, 10, 7, 3, 1],
            [1, 6, 1, 5, 1, 7, 1],
            [1, 9, 1, 8, 1, 6, 1],
            [1, 5, 8, 0, 1, 7, 1],
            [1, 2, 1, 5, 7, 10, 1],
            [1, 1, 1, 1, 1, 1, 1]
        ]
    ];
    // MT6
    const mt6: NamedMap = [
        'map2',
        [
            [1, 1, 1, 1, 1, 1, 1],
            [1, 5, 6, 4, 1, 10, 1],
            [1, 6, 1, 9, 1, 5, 1],
            [1, 8, 0, 6, 0, 8, 1],
            [1, 5, 1, 10, 1, 2, 1],
            [1, 9, 3, 1, 4, 9, 1],
            [1, 1, 1, 1, 1, 1, 1]
        ]
    ];
    // MT8
    const mt8: NamedMap = [
        'map3',
        [
            [1, 1, 1, 1, 1, 1, 1],
            [1, 5, 8, 10, 7, 2, 1],
            [1, 2, 1, 5, 1, 7, 1],
            [1, 3, 1, 3, 6, 4, 1],
            [1, 6, 1, 6, 1, 8, 1],
            [1, 10, 7, 5, 1, 5, 1],
            [1, 1, 1, 1, 1, 1, 1]
        ]
    ];
    // MT3 with a single tile tweaked
    const mt3Tweaked: NamedMap = [
        'map4',
        [
            [1, 1, 1, 1, 1, 1, 1],
            [1, 3, 1, 10, 7, 3, 1],
            [1, 6, 1, 5, 1, 7, 1],
            [1, 9, 1, 8, 1, 6, 1],
            [1, 5, 8, 0, 1, 7, 1],
            [1, 2, 1, 4, 7, 10, 1],
            [1, 1, 1, 1, 1, 1, 1]
        ]
    ];
    // MT10
    const mt10: NamedMap = [
        'map5',
        [
            [1, 1, 1, 1, 1, 1, 1],
            [1, 5, 1, 10, 1, 5, 1],
            [1, 6, 7, 7, 7, 6, 1],
            [1, 1, 6, 5, 6, 1, 1],
            [1, 4, 5, 9, 5, 4, 1],
            [1, 3, 1, 1, 1, 3, 1],
            [1, 1, 1, 1, 1, 1, 1]
        ]
    ];

    const comparisons: [NamedMap, NamedMap][] = [
        // Self comparisons
        [mt3, mt3],
        [mt6, mt6],
        [mt8, mt8],
        [mt3Tweaked, mt3Tweaked],
        // Pairwise comparisons
        [mt3, mt6],
        [mt3, mt8],
        [mt3, mt3Tweaked],
        [mt3, mt10],
        [mt6, mt8],
        [mt6, mt3Tweaked],
        [mt6, mt10],
        [mt8, mt3Tweaked],
        [mt8, mt10],
        [mt3Tweaked, mt10],
        // Argument order swapped, to check commutativity
        [mt6, mt3],
        [mt3Tweaked, mt6]
    ];

    for (const [[leftName, left], [rightName, right]] of comparisons) {
        console.log(
            `${leftName} vs ${rightName}: ${calculateVisualSimilarity(left, right)}`
        );
    }
})();

View File

@ -1,133 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, TransformerConv
from torch_geometric.utils import grid
def batch_edge_index(B, edge_index, num_nodes_per_batch):
    """Replicate a single-graph edge index for a batch of ``B`` graphs.

    Copy ``i`` has ``i * num_nodes_per_batch`` added to every node id; the
    copies are concatenated along dim 1. The input tensor is not modified.
    """
    base = edge_index.clone()  # [2, E]
    shifted = [base + graph_id * num_nodes_per_batch for graph_id in range(B)]
    return torch.cat(shifted, dim=1)
class DoubleConvBlock(nn.Module):
    """Two 3x3 convolutions (replicate padding), each followed by InstanceNorm2d and GELU.

    ``feats`` gives the (input, hidden, output) channel counts. Spatial size
    is preserved.
    """

    def __init__(self, feats: tuple[int, int, int]):
        super().__init__()
        channels = (feats[0], feats[1], feats[2])
        stages = []
        for src, dst in zip(channels, channels[1:]):
            stages.append(nn.Conv2d(src, dst, 3, padding=1, padding_mode='replicate'))
            stages.append(nn.InstanceNorm2d(dst))
            stages.append(nn.GELU())
        self.cnn = nn.Sequential(*stages)

    def forward(self, x):
        return self.cnn(x)
class GCNBlock(nn.Module):
    """Three GCNConv layers over a grid graph, each followed by LayerNorm and GELU.

    Takes and returns image-layout tensors ``[B, C, H, W]``; every pixel is
    treated as one graph node.
    """

    def __init__(self, in_ch, hidden_ch, out_ch, w, h):
        super().__init__()
        self.conv1 = GCNConv(in_ch, hidden_ch)
        self.conv2 = GCNConv(hidden_ch, hidden_ch)
        self.conv3 = GCNConv(hidden_ch, out_ch)
        self.norm1 = nn.LayerNorm(hidden_ch)
        self.norm2 = nn.LayerNorm(hidden_ch)
        self.norm3 = nn.LayerNorm(out_ch)
        # Edge index [2, E] for a single map; offset per sample in forward().
        self.single_edge_index = grid(h, w)[0]

    def forward(self, x):
        # x: [B, C, H, W]
        B, C, H, W = x.shape

        # Flatten pixels into graph nodes: [B * H * W, C]
        nodes = x.permute(0, 2, 3, 1).reshape(B * H * W, C)

        # Build the batched edge index on the same device as the features.
        edges = batch_edge_index(B, self.single_edge_index.to(nodes.device), H * W)

        layers = (
            (self.conv1, self.norm1),
            (self.conv2, self.norm2),
            (self.conv3, self.norm3),
        )
        for conv, norm in layers:
            nodes = F.gelu(norm(conv(nodes, edges)))

        # Back to image layout: [B, C_out, H, W]
        return nodes.view(B, H, W, -1).permute(0, 3, 1, 2)
class TransformerGCNBlock(nn.Module):
    """Two TransformerConv layers over a grid graph, each followed by LayerNorm and GELU.

    The first layer uses 8 concatenated heads of width ``hidden_ch // 8``; the
    second uses a single head. Takes and returns ``[B, C, H, W]`` tensors.
    """

    def __init__(self, in_ch, hidden_ch, out_ch, w, h):
        super().__init__()
        self.conv1 = TransformerConv(in_ch, hidden_ch // 8, heads=8, concat=True)
        self.conv2 = TransformerConv(hidden_ch, out_ch, heads=1)
        self.norm1 = nn.LayerNorm(hidden_ch)
        self.norm2 = nn.LayerNorm(out_ch)
        # Edge index [2, E] for a single map; offset per sample in forward().
        self.single_edge_index = grid(h, w)[0]

    def forward(self, x):
        # x: [B, C, H, W]
        B, C, H, W = x.shape

        # Flatten pixels into graph nodes: [B * H * W, C]
        nodes = x.permute(0, 2, 3, 1).reshape(B * H * W, C)

        # Build the batched edge index on the same device as the features.
        edges = batch_edge_index(B, self.single_edge_index.to(nodes.device), H * W)

        for conv, norm in ((self.conv1, self.norm1), (self.conv2, self.norm2)):
            nodes = F.gelu(norm(conv(nodes, edges)))

        # Back to image layout: [B, C_out, H, W]
        return nodes.view(B, H, W, -1).permute(0, 3, 1, 2)
class ConvFusionModule(nn.Module):
    """Run a CNN branch and a graph branch in parallel and fuse them with a conv block.

    Both branches map ``in_ch -> in_ch``; their outputs are concatenated on the
    channel axis and reduced to ``out_ch`` by ``fusion``.
    """

    def __init__(self, in_ch, hidden_ch, out_ch, w: int, h: int):
        super().__init__()
        self.cnn = DoubleConvBlock([in_ch, hidden_ch, in_ch])
        self.gcn = TransformerGCNBlock(in_ch, hidden_ch, in_ch, w, h)
        self.fusion = DoubleConvBlock([in_ch*2, hidden_ch, out_ch])

    def forward(self, x):
        conv_feat = self.cnn(x)
        graph_feat = self.gcn(x)
        return self.fusion(torch.cat([conv_feat, graph_feat], dim=1))
class DoubleFCModule(nn.Module):
    """Two Linear layers, each followed by LayerNorm and GELU."""

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        widths = (in_dim, hidden_dim, out_dim)
        stages = []
        for src, dst in zip(widths, widths[1:]):
            stages.extend([nn.Linear(src, dst), nn.LayerNorm(dst), nn.GELU()])
        self.fc = nn.Sequential(*stages)

    def forward(self, x):
        return self.fc(x)

View File

@ -1,50 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .common import DoubleFCModule
class ConditionEncoder(nn.Module):
    """Encode the value and stage conditions into one conditioning vector.

    Each condition is embedded into a ``hidden_dim`` token; the tokens pass
    through a 4-layer Transformer encoder, are mean-pooled, and are projected
    to ``out_dim``. ``tag_embed`` is constructed but the tag input is currently
    not used in ``forward``.
    """

    def __init__(self, tag_dim=64, val_dim=16, hidden_dim=256, out_dim=256):
        super().__init__()
        self.tag_embed = DoubleFCModule(tag_dim, hidden_dim, hidden_dim)
        self.val_embed = DoubleFCModule(val_dim, hidden_dim, hidden_dim)
        self.stage_embed = DoubleFCModule(1, hidden_dim, hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=8,
            dim_feedforward=hidden_dim*4,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, tag, val, stage):
        # The tag embedding is disabled here: only value and stage tokens are encoded.
        val_token = self.val_embed(val)
        stage_token = self.stage_embed(stage)
        tokens = torch.stack([val_token, stage_token], dim=1)
        pooled = torch.mean(self.encoder(tokens), dim=1)
        return self.fusion(pooled)
class ConditionInjector(nn.Module):
    """Channel-wise affine modulation of a feature map by a condition vector.

    Computes ``x * gamma + beta`` where ``gamma`` and ``beta`` are linear
    projections of ``cond`` broadcast over the two trailing spatial dims.
    """

    def __init__(self, cond_dim, out_dim):
        super().__init__()
        self.gamma_layer = nn.Sequential(nn.Linear(cond_dim, out_dim))
        self.beta_layer = nn.Sequential(nn.Linear(cond_dim, out_dim))

    def forward(self, x, cond):
        scale = self.gamma_layer(cond)[:, :, None, None]
        shift = self.beta_layer(cond)[:, :, None, None]
        return x * scale + shift

View File

@ -1,292 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from torch_geometric.nn import global_max_pool, GCNConv, TransformerConv
from torch_geometric.utils import grid
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
from .vision import MinamoVisionModel
from .topo import MinamoTopoModel
def print_memory(tag=""):
    """Print the current and peak CUDA memory allocated by torch, in MB, prefixed with ``tag``."""
    bytes_per_mb = 1024**2
    current_mb = torch.cuda.memory_allocated() / bytes_per_mb
    peak_mb = torch.cuda.max_memory_allocated() / bytes_per_mb
    print(f"{tag} | 当前显存: {current_mb:.2f} MB, 最大显存: {peak_mb:.2f} MB")
def batch_edge_index(B, edge_index, num_nodes_per_batch):
    """Tile a single-graph edge index ``[2, E]`` into a batched one ``[2, B * E]``.

    The ``i``-th copy is shifted by ``i * num_nodes_per_batch`` so that each
    sample addresses its own block of node ids.
    """
    template = edge_index.clone()  # [2, E]
    per_sample = []
    for sample_idx in range(B):
        per_sample.append(template + sample_idx * num_nodes_per_batch)
    return torch.cat(per_sample, dim=1)
class DoubleConvBlock(nn.Module):
    """Two spectrally normalised 3x3 convolutions (replicate padding), each followed by GELU.

    ``feats`` gives the (input, hidden, output) channel counts. Spatial size
    is preserved.
    """

    def __init__(self, feats: tuple[int, int, int]):
        super().__init__()
        channels = (feats[0], feats[1], feats[2])
        stages = []
        for src, dst in zip(channels, channels[1:]):
            conv = nn.Conv2d(src, dst, 3, padding=1, padding_mode='replicate')
            stages.append(spectral_norm(conv))
            stages.append(nn.GELU())
        self.cnn = nn.Sequential(*stages)

    def forward(self, x):
        return self.cnn(x)
class TransformerGCNBlock(nn.Module):
    """Two TransformerConv layers over a grid graph, each followed by GELU (no normalisation).

    Takes and returns image-layout tensors ``[B, C, H, W]``; every pixel is
    treated as one graph node.
    """

    def __init__(self, in_ch, hidden_ch, out_ch, w, h):
        super().__init__()
        self.conv1 = TransformerConv(in_ch, hidden_ch // 8, heads=8, concat=True)
        self.conv2 = TransformerConv(hidden_ch, out_ch, heads=1)
        # Edge index [2, E] for a single map; offset per sample in forward().
        self.single_edge_index = grid(h, w)[0]

    def forward(self, x):
        B, C, H, W = x.shape
        nodes = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
        edges = batch_edge_index(B, self.single_edge_index.to(nodes.device), H * W)

        for conv in (self.conv1, self.conv2):
            nodes = F.gelu(conv(nodes, edges))

        return nodes.view(B, H, W, -1).permute(0, 3, 1, 2)
class ConvFusionModule(nn.Module):
    """Parallel CNN and graph branches whose outputs are concatenated and fused by a conv block."""

    def __init__(self, in_ch, hidden_ch, out_ch, w: int, h: int):
        super().__init__()
        self.cnn = DoubleConvBlock([in_ch, hidden_ch, in_ch])
        self.gcn = TransformerGCNBlock(in_ch, hidden_ch, in_ch, w, h)
        self.fusion = DoubleConvBlock([in_ch*2, hidden_ch, out_ch])

    def forward(self, x):
        conv_feat = self.cnn(x)
        graph_feat = self.gcn(x)
        merged = torch.cat([conv_feat, graph_feat], dim=1)
        return self.fusion(merged)
class DoubleFCModule(nn.Module):
    """Two spectrally normalised Linear layers, each followed by GELU."""

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        widths = (in_dim, hidden_dim, out_dim)
        stages = []
        for src, dst in zip(widths, widths[1:]):
            stages.append(spectral_norm(nn.Linear(src, dst)))
            stages.append(nn.GELU())
        self.fc = nn.Sequential(*stages)

    def forward(self, x):
        return self.fc(x)
class ConditionEncoder(nn.Module):
    """Encode tag, value and stage conditions into one conditioning vector.

    Each condition is embedded into a ``hidden_dim`` token; the three tokens
    pass through a 4-layer Transformer encoder, are mean-pooled, and are
    projected to ``out_dim`` by spectrally normalised Linear layers.
    """

    def __init__(self, tag_dim, val_dim, hidden_dim, out_dim):
        super().__init__()
        self.tag_embed = DoubleFCModule(tag_dim, hidden_dim, hidden_dim)
        self.val_embed = DoubleFCModule(val_dim, hidden_dim, hidden_dim)
        self.stage_embed = DoubleFCModule(1, hidden_dim, hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=8,
            dim_feedforward=hidden_dim*4,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
        self.fusion = nn.Sequential(
            spectral_norm(nn.Linear(hidden_dim, hidden_dim)),
            nn.GELU(),
            spectral_norm(nn.Linear(hidden_dim, out_dim)),
        )

    def forward(self, tag, val, stage):
        tag_token = self.tag_embed(tag)
        val_token = self.val_embed(val)
        stage_token = self.stage_embed(stage)
        tokens = torch.stack([tag_token, val_token, stage_token], dim=1)
        pooled = torch.mean(self.encoder(tokens), dim=1)
        return self.fusion(pooled)
class ConditionInjector(nn.Module):
    """Channel-wise affine modulation ``x * gamma + beta`` driven by a condition vector.

    ``gamma`` and ``beta`` come from spectrally normalised Linear projections
    of ``cond`` and are broadcast over the two trailing spatial dims.
    """

    def __init__(self, cond_dim, out_dim):
        super().__init__()
        self.gamma_layer = nn.Sequential(spectral_norm(nn.Linear(cond_dim, out_dim)))
        self.beta_layer = nn.Sequential(spectral_norm(nn.Linear(cond_dim, out_dim)))

    def forward(self, x, cond):
        scale = self.gamma_layer(cond)[:, :, None, None]
        shift = self.beta_layer(cond)[:, :, None, None]
        return x * scale + shift
class CNNHead(nn.Module):
    """Scoring head for image features with a projection-style condition term.

    The feature map goes through an unpadded 3x3 conv, GELU and a 2x2 adaptive
    max-pool, then is flattened. The score is ``fc(feat)`` plus the inner
    product of ``feat`` with a linear projection of the 256-d condition.
    """

    def __init__(self, in_ch):
        super().__init__()
        flat_dim = in_ch*2*2
        self.cnn = nn.Sequential(
            spectral_norm(nn.Conv2d(in_ch, in_ch, 3)),
            nn.GELU(),
            nn.AdaptiveMaxPool2d((2, 2)),
        )
        self.fc = nn.Sequential(spectral_norm(nn.Linear(flat_dim, 1)))
        self.proj = spectral_norm(nn.Linear(256, flat_dim))

    def forward(self, x, cond):
        feat = self.cnn(x)
        batch, _, _, _ = feat.shape
        feat = feat.view(batch, -1)
        cond_term = torch.sum(feat * self.proj(cond), dim=1, keepdim=True)
        return self.fc(feat) + cond_term
class GCNHead(nn.Module):
    """Scoring head for graph features with a projection-style condition term.

    Node features go through one GCNConv and GELU and are max-pooled per graph
    (using ``graph.batch``). The score is ``fc(pooled)`` plus the inner product
    of ``pooled`` with a linear projection of the 256-d condition.
    """

    def __init__(self, in_dim):
        super().__init__()
        self.gcn = GCNConv(in_dim, in_dim)
        self.proj = spectral_norm(nn.Linear(256, in_dim))
        self.fc = nn.Sequential(spectral_norm(nn.Linear(in_dim, 1)))

    def forward(self, x, graph, cond):
        node_feat = F.gelu(self.gcn(x, graph.edge_index))
        pooled = global_max_pool(node_feat, graph.batch)
        cond_term = torch.sum(pooled * self.proj(cond), dim=1, keepdim=True)
        return self.fc(pooled) + cond_term
class MinamoScoreHead(nn.Module):
    """Pair of scoring heads: a ``CNNHead`` for vision features and a ``GCNHead`` for topology features."""

    def __init__(self, vision_dim, topo_dim):
        super().__init__()
        self.vision_head = CNNHead(vision_dim)
        self.topo_head = GCNHead(topo_dim)

    def forward(self, vis, topo, graph, cond):
        """Return ``(vision_score, topo_score)`` for the given features and condition."""
        return (
            self.vision_head(vis, cond),
            self.topo_head(topo, graph, cond),
        )
class MinamoModel(nn.Module):
    """Critic combining a vision branch and a topology branch, with one score head per stage (1-3)."""

    def __init__(self, tile_types=32):
        super().__init__()
        self.topo_model = MinamoTopoModel(tile_types)
        self.vision_model = MinamoVisionModel(tile_types)
        self.cond = ConditionEncoder(64, 16, 256, 256)
        # Output heads, one per training stage
        self.head1 = MinamoScoreHead(512, 512)
        self.head2 = MinamoScoreHead(512, 512)
        self.head3 = MinamoScoreHead(512, 512)

    def forward(self, map, graph, stage, tag_cond, val_cond):
        """Score a map for the given stage.

        Returns ``(score, vision_score, topo_score)`` where ``score`` is the
        ``VISION_WEIGHT`` / ``TOPO_WEIGHT`` weighted sum of the other two.
        Raises ``RuntimeError`` when ``stage`` is not 1, 2 or 3.
        """
        B, _ = tag_cond.shape
        stage_tensor = torch.Tensor([stage]).expand(B, 1).to(map.device)

        vision = self.vision_model(map)
        topo = self.topo_model(graph)
        cond = self.cond(tag_cond, val_cond, stage_tensor)

        stage_heads = ((1, self.head1), (2, self.head2), (3, self.head3))
        for stage_id, head in stage_heads:
            if stage == stage_id:
                vision_score, topo_score = head(vision, topo, graph, cond)
                break
        else:
            raise RuntimeError("Unknown critic stage.")

        score = VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score
        return score, vision_score, topo_score
class MinamoHead2(nn.Module):
    """Scoring head: fusion conv on a 13x13 grid, global max-pool, then a linear score plus a condition projection term."""

    def __init__(self, in_ch, hidden_ch):
        super().__init__()
        self.conv = ConvFusionModule(in_ch, hidden_ch, hidden_ch, 13, 13)
        self.pool = nn.AdaptiveMaxPool2d(1)
        self.proj = spectral_norm(nn.Linear(256, hidden_ch))
        self.fc = spectral_norm(nn.Linear(hidden_ch, 1))

    def forward(self, x, cond):
        pooled = self.pool(self.conv(x))
        # Drop the two singleton spatial dims left by the global pool.
        feat = pooled.squeeze(3).squeeze(2)
        cond_term = torch.sum(feat * self.proj(cond), dim=1, keepdim=True)
        return self.fc(feat) + cond_term
class MinamoModel2(nn.Module):
    """Critic built from three fusion conv stages, one condition injection and one score head per stage (0-3)."""

    def __init__(self, tile_types=32):
        super().__init__()
        self.cond = ConditionEncoder(64, 16, 256, 256)
        self.conv1 = ConvFusionModule(tile_types, 256, 256, 13, 13)
        self.conv2 = ConvFusionModule(256, 512, 256, 13, 13)
        self.conv3 = ConvFusionModule(256, 512, 256, 13, 13)
        self.head0 = MinamoHead2(256, 256)  # critic head for the random-input head
        self.head1 = MinamoHead2(256, 256)
        self.head2 = MinamoHead2(256, 256)
        self.head3 = MinamoHead2(256, 256)
        # Condition injection after conv1 / conv2 is currently disabled:
        # self.inject1 = ConditionInjector(256, 256)
        # self.inject2 = ConditionInjector(256, 256)
        self.inject3 = ConditionInjector(256, 256)

    def forward(self, x, stage, tag_cond, val_cond):
        """Return the critic score for ``x`` at ``stage``; raises ``RuntimeError`` for stages other than 0-3."""
        B, _ = tag_cond.shape
        stage_tensor = torch.Tensor([stage]).expand(B, 1).to(x.device)
        cond = self.cond(tag_cond, val_cond, stage_tensor)

        feat = self.conv1(x)
        # feat = self.inject1(feat, cond)
        feat = self.conv2(feat)
        # feat = self.inject2(feat, cond)
        feat = self.conv3(feat)
        feat = self.inject3(feat, cond)

        stage_heads = (
            (0, self.head0),
            (1, self.head1),
            (2, self.head2),
            (3, self.head3),
        )
        for stage_id, head in stage_heads:
            if stage == stage_id:
                return head(feat, cond)
        raise RuntimeError("Unknown critic stage.")
# Check GPU memory usage
if __name__ == "__main__":
    sample = torch.randn((1, 32, 13, 13)).cuda()
    tag = torch.rand(1, 64).cuda()
    val = torch.rand(1, 16).cuda()

    # Build the model
    model = MinamoModel2().cuda()
    print_memory("初始化后")

    # Forward pass
    output = model(sample, 1, tag, val)
    print_memory("前向传播后")

    def count_params(module):
        """Total number of parameter elements in ``module``."""
        return sum(p.numel() for p in module.parameters())

    print(f"输入形状: feat={sample.shape}")
    print(f"输出形状: output={output.shape}")
    # print(f"Vision parameters: {sum(p.numel() for p in model.vision_model.parameters())}")
    # print(f"Topo parameters: {sum(p.numel() for p in model.topo_model.parameters())}")
    print(f"Cond parameters: {count_params(model.cond)}")
    print(f"Head parameters: {count_params(model.head1)}")
    print(f"Total parameters: {count_params(model)}")

View File

@ -1,36 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from torch_geometric.nn import GATConv, TransformerConv
from torch_geometric.data import Data
class MinamoTopoModel(nn.Module):
    """Topology branch: a linear input projection followed by three TransformerConv layers with LeakyReLU(0.2)."""

    def __init__(
        self, tile_types=32, emb_dim=128, hidden_dim=128, out_dim=512
    ):
        super().__init__()
        # Node features are softmax probabilities, so they are projected directly.
        self.input_proj = nn.Sequential(
            spectral_norm(nn.Linear(tile_types, emb_dim)),
            nn.LeakyReLU(0.2),
        )
        # Graph convolution layers (8 concatenated heads widen the features by 8x).
        multi_head_dim = hidden_dim*8
        self.conv1 = TransformerConv(emb_dim, hidden_dim, heads=8)
        self.conv2 = TransformerConv(multi_head_dim, hidden_dim, heads=8)
        self.conv3 = TransformerConv(multi_head_dim, out_dim, heads=1)

    def forward(self, graph: Data):
        feat = self.input_proj(graph.x)
        for conv in (self.conv1, self.conv2, self.conv3):
            feat = F.leaky_relu(conv(feat, graph.edge_index), 0.2)
        return feat

View File

@ -1,22 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
class MinamoVisionModel(nn.Module):
    """Vision branch: three unpadded, spectrally normalised 3x3 convs with LeakyReLU(0.2).

    Each conv shrinks the spatial size by 2, e.g. 13x13 -> 11x11 -> 9x9 -> 7x7.
    Channels go ``in_ch -> in_ch*2 -> in_ch*8 -> out_ch``.
    """

    def __init__(self, in_ch=32, out_ch=512):
        super().__init__()
        widths = [in_ch, in_ch*2, in_ch*8, out_ch]
        stages = []
        for src, dst in zip(widths, widths[1:]):
            stages.append(spectral_norm(nn.Conv2d(src, dst, 3)))
            stages.append(nn.LeakyReLU(0.2))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

View File

@ -8,13 +8,6 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from typing import List from typing import List
# Per-stage class-index lists used by GinkaWGANDataset.
# STAGEn_MASK is passed as `mask_classes` (classes that may be masked out) and
# STAGEn_REMOVE as `remove_classes` (classes replaced by class 0) to the
# apply_curriculum_* helpers.
STAGE1_MASK = [0, 1, 2, 29, 30]
STAGE1_REMOVE = list(range(3, 29))
STAGE2_MASK = [3, 4, 5, 6, 26, 27, 28]
STAGE2_REMOVE = list(range(7, 26))
STAGE3_MASK = list(range(7, 26))
STAGE3_REMOVE = []
def load_data(path: str): def load_data(path: str):
with open(path, 'r', encoding="utf-8") as f: with open(path, 'r', encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
@ -24,220 +17,6 @@ def load_data(path: str):
data_list.append(value) data_list.append(value)
return data_list return data_list
def load_minamo_gan_data(data: list):
    """Convert similarity records into ``(map1, map2, visionSimilarity, topoSimilarity, True)`` tuples."""
    return [
        (
            record['map1'],
            record['map2'],
            record['visionSimilarity'],
            record['topoSimilarity'],
            True,
        )
        for record in data
    ]
def apply_curriculum_remove(
    maps: torch.Tensor,
    remove_classes: List[int],  # class indices to remove
):
    """Return a copy of a one-hot map ``[C, H, W]`` with the given classes replaced by class 0.

    Every pixel where any of ``remove_classes`` is active has all channels
    cleared and channel 0 (the "empty floor" class) set to 1. The input is not
    modified.
    """
    C, _, _ = maps.shape
    device = maps.device
    result = maps.clone()

    # [1, H, W] boolean map of pixels that hold a removed class.
    hit = result[remove_classes, :, :].sum(dim=0, keepdim=True) > 0
    result.masked_fill_(hit.expand(C, -1, -1), 0)
    result[0].masked_fill_(hit[0, :, :], 1)  # mark as "empty floor"

    return result.to(device)
def apply_curriculum_mask(
    maps: torch.Tensor,  # [C, H, W]
    mask_classes: List[int],  # class indices to mask
    remove_classes: List[int],  # class indices to remove
    mask_ratio: float  # fraction of pixels to mask, 0~1
) -> tuple[torch.Tensor, torch.Tensor]:
    """Remove some classes from a one-hot map, then randomly mask others.

    Args:
        maps: One-hot map of shape ``[C, H, W]``; not modified.
        mask_classes: For each of these classes, a random ``mask_ratio``
            fraction of its pixels (rounded down) is reset to class 0.
        remove_classes: Every pixel of these classes is reset to class 0
            before masking.
        mask_ratio: Fraction of each masked class's pixels to reset.

    Returns:
        ``(removed_maps, masked_maps)``: the map after the removal step only,
        and the map after removal plus random masking.
    """
    C, H, W = maps.shape
    masked_maps = maps.clone()

    # Step 1: drop unwanted classes (pixels become class 0).
    if remove_classes:
        remove_mask = masked_maps[remove_classes, :, :].sum(dim=0, keepdim=True) > 0
        masked_maps[:, :, :][remove_mask.expand(C, -1, -1)] = 0
        masked_maps[0][remove_mask[0, :, :]] = 1  # mark as "empty floor"

    removed_maps = masked_maps.clone()

    # Step 2: randomly mask pixels of the requested classes.
    for cls in mask_classes:
        cls_mask = masked_maps[cls] > 0  # boolean pixel mask of this class, [H, W]
        indices = cls_mask.nonzero(as_tuple=False)  # coordinates of all such pixels
        num_mask = int(len(indices) * mask_ratio)
        if num_mask > 0:
            selected = indices[torch.randperm(len(indices))[:num_mask]]
            masked_maps[cls, selected[:, 0], selected[:, 1]] = 0
            masked_maps[0, selected[:, 0], selected[:, 1]] = 1  # mark as "empty floor"

    return removed_maps, masked_maps
def apply_curriculum_wall_mask(
    maps: torch.Tensor,  # [C, H, W]
    mask_classes: List[int],  # class indices to mask
    remove_classes: List[int],  # class indices to remove
    mask_ratio: float  # fraction of the map area to mask, 0~1
) -> tuple[torch.Tensor, torch.Tensor]:
    """Remove some classes from a one-hot map, then blank out one random square window.

    Args:
        maps: One-hot map of shape ``[C, H, W]``; not modified.
        mask_classes: Channels cleared inside the window.
        remove_classes: Every pixel of these classes is reset to class 0
            before masking.
        mask_ratio: Target window area as a fraction of ``H * W``. The window
            side is ``floor(sqrt(H * W * mask_ratio))``, capped at
            ``min(H, W)`` so the window always fits inside the map.

    Returns:
        ``(removed_maps, masked_maps)``: the map after the removal step only,
        and the map after removal plus window masking. Inside the window the
        ``mask_classes`` channels are 0 and channel 0 is 1.
    """
    C, H, W = maps.shape
    masked_maps = maps.clone()

    # Step 1: drop unwanted classes (pixels become class 0).
    if remove_classes:
        remove_mask = masked_maps[remove_classes, :, :].sum(dim=0, keepdim=True) > 0
        masked_maps[:, :, :][remove_mask.expand(C, -1, -1)] = 0
        masked_maps[0][remove_mask[0, :, :]] = 1  # mark as "empty floor"

    removed_maps = masked_maps.clone()

    # Step 2: pick a square window of roughly the requested area.
    area = H * W * mask_ratio
    # Cap the side so random.randint never receives an empty range.
    side = min(math.floor(math.sqrt(area)), H, W)
    nx = random.randint(0, W - side)  # left column of the window
    ny = random.randint(0, H - side)  # top row of the window

    # Dim 1 is rows (y) and dim 2 is columns (x).
    masked_maps[mask_classes, ny:ny+side, nx:nx+side] = 0
    masked_maps[0, ny:ny+side, nx:nx+side] = 1

    return removed_maps, masked_maps
class GinkaWGANDataset(Dataset):
    """Dataset producing per-stage real / masked one-hot maps for curriculum WGAN training.

    ``train_stage`` (1-4) selects which ``handle_stageN`` builds the sample;
    ``mask_ratio1..3`` are the masking ratios used by stage 1.
    """

    def __init__(self, data_path: str, device):
        self.data = load_data(data_path)  # project-specific loader
        self.device = device
        self.train_stage = 1
        self.mask_ratio1 = 0.1
        self.mask_ratio2 = 0.1
        self.mask_ratio3 = 0.1

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _pack(rand, reals, maskeds, tag_cond, val_cond):
        """Assemble the sample dict shared by all stages (``real0`` aliases ``real1``)."""
        real1, real2, real3 = reals
        masked1, masked2, masked3 = maskeds
        return {
            "rand": rand,
            "real0": real1,
            "real1": real1,
            "masked1": masked1,
            "real2": real2,
            "masked2": masked2,
            "real3": real3,
            "masked3": masked3,
            "tag_cond": tag_cond,
            "val_cond": val_cond
        }

    def handle_stage1(self, target, tag_cond, val_cond):
        # Curriculum stage 1: fill in masked regions, using the fixed mask ratios.
        first = apply_curriculum_wall_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1)
        second = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2)
        third = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3)
        rand = torch.rand(32, 32, 32, device=target.device)
        reals, maskeds = zip(first, second, third)
        return self._pack(rand, reals, maskeds, tag_cond, val_cond)

    def handle_stage2(self, target, tag_cond, val_cond):
        # Curriculum stage 2: fully random mask ratios.
        first = apply_curriculum_wall_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
        # The later two stages keep some classes, so a fully random ratio is fine.
        second = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 1))
        third = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 1))
        rand = torch.rand(32, 32, 32, device=target.device)
        reals, maskeds = zip(first, second, third)
        return self._pack(rand, reals, maskeds, tag_cond, val_cond)

    def handle_stage3(self, target, tag_cond, val_cond):
        # Stage 3: joint generation; the stage-1 input is the removed map itself.
        reals = (
            apply_curriculum_remove(target, STAGE1_REMOVE),
            apply_curriculum_remove(target, STAGE2_REMOVE),
            apply_curriculum_remove(target, STAGE3_REMOVE),
        )
        rand = torch.rand(32, 32, 32, device=target.device)
        maskeds = (reals[0], torch.zeros_like(target), torch.zeros_like(target))
        return self._pack(rand, reals, maskeds, tag_cond, val_cond)

    def handle_stage4(self, target, tag_cond, val_cond):
        # Stage 4: the stage-1 input is the random tensor.
        reals = (
            apply_curriculum_remove(target, STAGE1_REMOVE),
            apply_curriculum_remove(target, STAGE2_REMOVE),
            apply_curriculum_remove(target, STAGE3_REMOVE),
        )
        rand = torch.rand(32, 32, 32, device=target.device)
        maskeds = (rand, torch.zeros_like(target), torch.zeros_like(target))
        return self._pack(rand, reals, maskeds, tag_cond, val_cond)

    def __getitem__(self, idx):
        entry = self.data[idx]

        target = F.one_hot(torch.LongTensor(entry['map']), num_classes=32).permute(2, 0, 1).float()  # [32, H, W]
        C, H, W = target.shape
        tag_cond = torch.FloatTensor(entry['tag'])
        val_cond = torch.FloatTensor(entry['val'])
        # Entries 9 and 10 are divided by the map area (H * W).
        for slot in (9, 10):
            val_cond[slot] = val_cond[slot] / H / W

        stage_handlers = (
            (1, self.handle_stage1),
            (2, self.handle_stage2),
            (3, self.handle_stage3),
            (4, self.handle_stage4),
        )
        for stage_id, handler in stage_handlers:
            if self.train_stage == stage_id:
                return handler(target, tag_cond, val_cond)

        raise RuntimeError(f"Invalid train stage: {self.train_stage}")
class GinkaRNNDataset(Dataset):
    """Dataset yielding the integer tile map together with its tag and value condition vectors."""

    def __init__(self, data_path: str, device):
        self.data = load_data(data_path)  # project-specific loader
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]

        target_map = torch.LongTensor(entry['map'])  # [H, W]
        H, W = target_map.shape  # also asserts that the map is 2-D

        return {
            "tag_cond": torch.FloatTensor(entry['tag']),
            "val_cond": torch.FloatTensor(entry['val']),
            "target_map": target_map
        }
class GinkaMaskGITDataset(Dataset): class GinkaMaskGITDataset(Dataset):
def __init__(self, data_path: str, device): def __init__(self, data_path: str, device):

View File

@ -1,37 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, TransformerConv
from torch_geometric.utils import grid
from ..common.cond import ConditionInjector
# GCN is being considered as the generator's main path; kept around for now.
class GCNBlock(nn.Module):
    """Two GCNConv layers, each followed by LayerNorm and ELU.

    ``feats`` gives the (input, hidden, output) feature sizes.
    """

    def __init__(self, feats: tuple[int, int, int]):
        super().__init__()
        in_dim, hidden_dim, out_dim = feats[0], feats[1], feats[2]
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(out_dim)

    def forward(self, x, edge_index):
        for conv, norm in ((self.conv1, self.norm1), (self.conv2, self.norm2)):
            x = F.elu(norm(conv(x, edge_index)))
        return x
class GinkaGCNEncoder(nn.Module):
    """Placeholder for a GCN-based encoder; defines no layers and no forward pass yet."""

    def __init__(self):
        super(GinkaGCNEncoder, self).__init__()
class GinkaGCNDecoder(nn.Module):
    """Placeholder for a GCN-based decoder; defines no layers and no forward pass yet."""

    def __init__(self):
        super(GinkaGCNDecoder, self).__init__()
class GinkaGCNModel(nn.Module):
    """Placeholder for a GCN-based generator; defines no layers and no forward pass yet."""

    def __init__(self):
        super(GinkaGCNModel, self).__init__()

View File

@ -1,58 +0,0 @@
import torch
import torch.nn as nn
from ..common.common import ConvFusionModule
from ..common.cond import ConditionInjector
from .unet import GinkaEncoderPath, GinkaDecoderPath
class RandomInputHead(nn.Module):
    """Turn a random input tensor into a 32-channel map via an encoder/decoder pair and an output conv stack.

    The output stack pools to 15x15, applies an unpadded 3x3 conv (giving
    13x13), InstanceNorm2d, GELU and a final 1x1 conv.
    """

    def __init__(self):
        super().__init__()
        self.enc = GinkaEncoderPath(32, 32)
        self.dec = GinkaDecoderPath(32)
        self.out_conv = nn.Sequential(
            nn.AdaptiveMaxPool2d((15, 15)),
            nn.Conv2d(32, 64, 3, padding=0),
            nn.InstanceNorm2d(64),
            nn.GELU(),
            nn.Conv2d(64, 32, 1),
        )

    def forward(self, x, cond):
        skips = self.enc(x, cond)
        lvl1, lvl2, lvl3, lvl4 = skips
        decoded = self.dec(lvl1, lvl2, lvl3, lvl4, cond)
        return self.out_conv(decoded)
class InputUpsample(nn.Module):
    """Upsample a 13x13 feature map to 32x32 through fusion conv stages and nearest-neighbour upsampling."""

    def __init__(self, in_ch, hidden_ch=64, out_ch=64):
        super().__init__()
        stages = [
            ConvFusionModule(in_ch, hidden_ch, hidden_ch, 13, 13),
            nn.Upsample(scale_factor=2, mode='nearest'),  # 13x13 -> 26x26
            ConvFusionModule(hidden_ch, hidden_ch, hidden_ch, 26, 26),
            nn.Upsample(size=(32, 32), mode='nearest'),  # 26x26 -> 32x32
            ConvFusionModule(hidden_ch, hidden_ch, out_ch, 32, 32),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):  # [B, C, 13, 13]
        return self.net(x)  # [B, C, 32, 32]
class GinkaInput(nn.Module):
    """Input stem: upsample the map, then apply condition injection before and after a fusion conv stage."""

    def __init__(self, in_ch=32, out_ch=64, in_size=(13, 13), out_size=(32, 32)):
        super().__init__()
        self.out_size = out_size
        out_w, out_h = out_size[0], out_size[1]
        self.upsample = InputUpsample(in_ch, in_ch*2, out_ch)
        self.enc = ConvFusionModule(out_ch, out_ch*2, out_ch, out_w, out_h)
        self.inject1 = ConditionInjector(256, out_ch)
        self.inject2 = ConditionInjector(256, out_ch)

    def forward(self, x, cond):
        feat = self.inject1(self.upsample(x), cond)
        feat = self.inject2(self.enc(feat), cond)
        return feat

View File

@ -1,420 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
# Number of tile classes; get_not_allowed iterates over range(0, CLASS_NUM).
CLASS_NUM = 32
# Class indices greater than this are only reported by get_not_allowed when
# include_illegal is true.
ILLEGAL_MAX_NUM = 30
# Class indices grouped by stage; the list position is the stage number
# (position 0 is empty).
STAGE_CHANGEABLE = [
[],
[0, 1, 2, 29, 30],
[3, 4, 5, 6, 26, 27, 28],
list(range(7, 26))
]
# Cumulative version of STAGE_CHANGEABLE: entry n concatenates the groups of
# stages 1..n.
STAGE_ALLOWED = [
[],
STAGE_CHANGEABLE[1],
[*STAGE_CHANGEABLE[1], *STAGE_CHANGEABLE[2]],
[*STAGE_CHANGEABLE[1], *STAGE_CHANGEABLE[2], *STAGE_CHANGEABLE[3]]
]
# Groups of class indices; compute_multi_density_loss averages the predicted
# probabilities over each group's channels to obtain that group's density.
DENSITY_MAP = [
[1, *list(range(3, 30))],
[1],
[2],
[3, 4, 5, 6],
[26, 27, 28],
list(range(7, 26)),
list(range(10, 19)),
[19, 20, 21, 22],
[7, 8, 9],
[23, 24, 25],
[29, 30]
]
# Loss multiplier for each DENSITY_MAP group (same order).
DENSITY_WEIGHTS = [
1,
1.5,
0.5,
5,
4,
3,
3,
3,
5,
10,
20
]
# Lists of DENSITY_MAP group indices, one list per stage number.
# NOTE(review): the last entry covers groups 0..9 only, so group 10
# ([29, 30]) is in no stage -- confirm whether that is intended.
DENSITY_STAGE = [
[],
[1, 2],
[1, 2, 3, 4],
list(range(0, 10))
]
def get_not_allowed(classes: list[int], include_illegal=False):
    """Return the class indices in ``range(CLASS_NUM)`` that are not in ``classes``.

    Indices greater than ``ILLEGAL_MAX_NUM`` are included only when
    ``include_illegal`` is true. The result is in ascending order.
    """
    excluded = []
    for tile in range(CLASS_NUM):
        if tile in classes:
            continue
        if tile <= ILLEGAL_MAX_NUM or include_illegal:
            excluded.append(tile)
    return excluded
def inner_constraint_loss(pred: torch.Tensor, allowed=list(range(0, 30))):
    """Penalise disallowed tile classes in the interior of the map.

    Args:
        pred (torch.Tensor): Predicted class probabilities, [B, C, H, W].
        allowed (list, optional): Tile classes permitted everywhere except
            the outermost ring of the map.

    Returns:
        MSE between the summed probability of all non-allowed classes at the
        interior pixels and zero.
    """
    B, C, H, W = pred.shape

    # Interior mask [H, W]: True everywhere except the first/last row and column.
    interior = torch.zeros((H, W), dtype=torch.bool, device=pred.device)
    interior[1:-1, 1:-1] = True

    # Summed probability of every non-allowed class, [B, H, W].
    banned = get_not_allowed(allowed, include_illegal=True)
    banned_probs = pred[:, banned, :, :].sum(dim=1)

    # Keep interior pixels only, [B, N_pixels], and push them towards zero.
    interior_probs = banned_probs[:, interior]
    return F.mse_loss(interior_probs, torch.zeros_like(interior_probs))
def _create_distance_kernel(size):
"""生成一个环状衰减核"""
y, x = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij')
center = size // 2
dist = torch.sqrt((x - center)**2 + (y - center)**2)
kernel = 1 / (dist + 1)
kernel /= kernel.sum() # 归一化
return kernel.unsqueeze(0).unsqueeze(0) # [1,1,H,W]
def entrance_constraint_loss(
    pred: torch.Tensor,
    entrance_classes=[29, 30],
    min_distance=9,
    presence_threshold=0.8,
    lambda_presence=1.0,
    lambda_spacing=0.5
):
    """
    Entrance constraint loss.

    Args:
        pred: Predicted class probabilities, [B, C, H, W].
        entrance_classes: Class indices that count as entrances.
        min_distance: Minimum spacing; used as the size of the distance kernel.
        presence_threshold: Probability the strongest entrance pixel should reach.
        lambda_presence: Weight of the presence term.
        lambda_spacing: Weight of the spacing term.

    Returns:
        total_loss: Weighted sum of the presence and spacing terms.
    """
    B, C, H, W = pred.shape
    entrance_probs = pred[:, entrance_classes, :, :].sum(dim=1)  # [B, H, W]

    # Presence term: encourage at least one high-confidence entrance per sample.
    strongest = entrance_probs.view(B, -1).max(dim=1).values  # [B]
    presence_loss = F.relu(presence_threshold - strongest).mean()

    # Spatial weight mask that decays away from the centre, [H, W].
    rows, cols = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    radial = torch.sqrt(((cols-W//2)/W*2)**2 + ((rows-H//2)/H*2)**2)
    center_weight = (1 - radial).clamp(0, 1).to(pred.device)

    # Spacing term: mean of the entrance probabilities smoothed by the distance kernel.
    kernel = _create_distance_kernel(min_distance).to(pred.device)
    density_map = F.conv2d(entrance_probs.unsqueeze(1), kernel, padding=min_distance-1)
    spacing_loss = density_map.mean()

    # The scalar spacing term is scaled by the centre weights and averaged.
    weighted_spacing = (spacing_loss * center_weight).mean()
    return lambda_presence * presence_loss + lambda_spacing * weighted_spacing
def input_head_illegal_loss(input_map, allowed_classes=[0, 1, 2]):
    """L1 penalty on every channel of ``input_map`` that is not an allowed class.

    Args:
        input_map: tensor indexed as [B, C, H, W].
        allowed_classes: channel indices that may carry mass.

    Returns:
        Mean absolute value of the disallowed channels.
    """
    forbidden = get_not_allowed(allowed_classes, include_illegal=True)
    illegal = input_map[:, forbidden, :, :]
    return F.l1_loss(illegal, torch.zeros_like(illegal))
def input_head_wall_loss(input_map, max_wall_ratio=0.2, wall_class=[1, 2]):
    """Penalize maps whose share of wall cells exceeds ``max_wall_ratio``.

    The probabilities of all ``wall_class`` channels are summed per cell before
    averaging, so the ratio is the expected fraction of cells that are any kind
    of wall. (Averaging over the selected channels as well would divide that
    fraction by ``len(wall_class)``.)

    Args:
        input_map: class probabilities, [B, C, H, W].
        max_wall_ratio: largest wall share that is not penalized.
        wall_class: channel indices treated as wall.

    Returns:
        Scalar tensor ``max(wall_ratio - max_wall_ratio, 0)``.
    """
    wall_prob = input_map[:, wall_class].sum(dim=1)  # [B, H, W]
    wall_ratio = wall_prob.mean()  # average wall share over batch and cells
    return torch.clamp(wall_ratio - max_wall_ratio, min=0.0)
def compute_multi_density_loss(probs, target_densities, tile_list):
    """Weighted MSE between predicted and requested densities of tile groups.

    Args:
        probs: class probabilities, [B, C, H, W].
        target_densities: requested densities, indexed as [:, group].
        tile_list: indices of the DENSITY_MAP groups that take part in the loss.

    Returns:
        Sum of the weighted per-group losses (the int 0 when no group is selected).
    """
    terms = []
    for group_idx, channel_ids in enumerate(DENSITY_MAP):
        # Mean over the group's channels and all cells, one value per sample.
        predicted = probs[:, channel_ids, :, :].mean(dim=(1, 2, 3))
        if group_idx not in tile_list:
            continue
        group_loss = F.mse_loss(predicted, target_densities[:, group_idx])
        terms.append(group_loss * DENSITY_WEIGHTS[group_idx])
    return sum(terms)
def interpolate_data(real_data, fake_data, epsilon):
    """Linearly blend real and fake samples: ``epsilon`` is the share of ``real_data``."""
    blended = epsilon * real_data + (1 - epsilon) * fake_data
    return blended
def interpolate_graph_features(real_graph, fake_graph, epsilon=0.5):
    """Blend the node features of two graphs while keeping the real graph's edges.

    Only ``x`` is interpolated; ``edge_index`` and ``edge_attr`` are taken from
    ``real_graph`` unchanged.
    """
    mixed_x = epsilon * real_graph.x + (1 - epsilon) * fake_graph.x
    return Data(
        x=mixed_x,
        edge_index=real_graph.edge_index,
        edge_attr=real_graph.edge_attr,
    )
def js_divergence(p, q, eps=1e-6, softmax=False):
    """Log-compressed Jensen-Shannon divergence between two distributions.

    Computes ``log1p(0.5 * (KL(p || m) + KL(q || m)))`` with ``m = (p + q) / 2``.

    Args:
        p, q: tensors of identical shape with the class axis at dim 1. They are
            probabilities, or logits when ``softmax`` is True.
        eps: added before taking logarithms to avoid ``log(0)``.
        softmax: apply a softmax over dim 1 to both inputs first.

    Returns:
        Scalar tensor. The KL terms use ``reduction='batchmean'``, i.e. they are
        summed over every non-batch element and divided by the batch size.
    """
    if softmax:
        p = F.softmax(p, dim=1)
        q = F.softmax(q, dim=1)
    m = 0.5 * (p + q)
    log_p = torch.log(p + eps)
    log_q = torch.log(q + eps)
    log_m = torch.log(m + eps)
    # F.kl_div(input, target) evaluates KL(target || input), so the mixture m
    # has to be the *input* in order to obtain KL(p || m) and KL(q || m).
    kl_pm = F.kl_div(log_m, log_p, reduction='batchmean', log_target=True)  # KL(p || m)
    kl_qm = F.kl_div(log_m, log_q, reduction='batchmean', log_target=True)  # KL(q || m)
    return torch.log1p(0.5 * (kl_pm + kl_qm))
def immutable_penalty_loss(
    pred: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int]
) -> torch.Tensor:
    """Penalize the model for changing tiles that it is not allowed to modify.

    Args:
        pred: model output, [B, C, H, W]. It is handed to ``F.cross_entropy``
            as the prediction, which applies its own log-softmax.
        input: the map that was fed to the model, [B, C, H, W].
        modifiable_classes: classes the model may change; every other class is
            treated as frozen.

    Returns:
        Cross-entropy over the frozen classes, lowered by 0.2 and clamped at 0.
    """
    frozen = get_not_allowed(modifiable_classes, include_illegal=True)
    frozen_pred = pred[:, frozen, :, :]
    with torch.no_grad():
        # Hard target: the dominant frozen class of the input at every cell.
        frozen_idx = torch.argmax(input[:, frozen, :, :], dim=1)
        frozen_target = F.one_hot(frozen_idx, num_classes=len(frozen)).permute(0, 3, 1, 2).float()
    mismatch = F.cross_entropy(frozen_pred, frozen_target)
    # Losses below the 0.2 margin are not penalized.
    return torch.clamp(mismatch - 0.2, min=0)
def modifiable_penalty_loss(
    probs: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int]
) -> torch.Tensor:
    """Penalize probability mass that the input had on modifiable classes but the output lost.

    Only decreases count: where ``probs`` is at least as large as ``input`` the
    contribution is zero.

    Returns:
        Mean squared shortfall over the selected channels.
    """
    shortfall = torch.clamp(
        input[:, modifiable_classes, :, :] - probs[:, modifiable_classes, :, :],
        min=0.0,
        max=1.0,
    )
    return F.mse_loss(shortfall, torch.zeros_like(shortfall))
def illegal_penalty_loss(pred: torch.Tensor, legal_classes: list[int]):
    """Penalize probability mass placed on classes outside ``legal_classes``.

    The illegal channels are summed per cell and pushed towards zero with an
    MSE. (A cross-entropy against an all-zero soft target, as used before, is
    identically zero and yields no gradient.)

    Args:
        pred: class probabilities, [B, C, H, W].
        legal_classes: classes that are allowed to appear.

    Returns:
        Scalar tensor, 0 when no mass lies on illegal classes.
    """
    not_allowed = get_not_allowed(legal_classes, include_illegal=True)
    illegal_mass = pred[:, not_allowed, :, :].sum(dim=1)  # [B, H, W]
    return F.mse_loss(illegal_mass, torch.zeros_like(illegal_mass))
class WGANGinkaLoss:
    """WGAN-GP losses for the staged Ginka generator and its critic.

    ``weight`` holds the coefficients of the generator terms:
        [0] critic (adversarial) term
        [1] cross-entropy against the real map
        [2] immutable-tile / illegal-tile penalty
        [3] tile-type (inner constraint) penalty
        [4] entrance presence penalty (stage 1 only)
        [5] diversity term (negative JS divergence between batch halves)
        [6] density term
    """

    def __init__(self, lambda_gp=100, weight=[1, 0.4, 20, 0.2, 0.2, 0.05, 0.4]):
        self.lambda_gp = lambda_gp  # gradient penalty coefficient
        # Copy so instances never share (and cannot mutate) the default list.
        self.weight = list(weight)

    def compute_gradient_penalty(self, critic, stage, real_data, fake_data, tag_cond, val_cond):
        """Return the WGAN-GP penalty ``mean((||grad||_2 - 1)^2)`` on random interpolates."""
        batch_size = real_data.size(0)
        # One mixing factor per sample, broadcast over C, H, W.
        epsilon_data = torch.rand(batch_size, 1, 1, 1, device=real_data.device)
        interp_data = interpolate_data(real_data, fake_data, epsilon_data).to(real_data.device)
        interp_data.requires_grad_()

        d_score = critic(interp_data, stage, tag_cond, val_cond)
        # create_graph=True keeps the penalty differentiable w.r.t. the critic.
        grad = torch.autograd.grad(
            outputs=d_score, inputs=interp_data,
            grad_outputs=torch.ones_like(d_score),
            create_graph=True, retain_graph=True, only_inputs=True
        )[0]
        grad_norm = grad.reshape(batch_size, -1).norm(2, dim=1)
        return ((grad_norm - 1.0) ** 2).mean()

    def discriminator_loss(
        self, critic, stage: int, real_data: torch.Tensor, fake_data: torch.Tensor,
        tag_cond: torch.Tensor, val_cond: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Critic loss.

        ``fake_data`` is expected as logits; it is softmaxed over dim 1 here.

        Returns:
            ``(total_loss, d_loss)`` where ``d_loss`` is the Wasserstein estimate
            and ``total_loss`` adds the weighted gradient penalty.
        """
        fake_data = F.softmax(fake_data, dim=1)
        real_scores = critic(real_data, stage, tag_cond, val_cond)
        fake_scores = critic(fake_data, stage, tag_cond, val_cond)
        d_loss = fake_scores.mean() - real_scores.mean()
        grad_loss = self.compute_gradient_penalty(critic, stage, real_data, fake_data, tag_cond, val_cond)
        total_loss = d_loss + self.lambda_gp * grad_loss
        return total_loss, d_loss

    def _adversarial_term(self, critic, probs_fake, stage, tag_cond, val_cond):
        """Weighted negative mean critic score of the generated probabilities."""
        fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
        return -torch.mean(fake_scores) * self.weight[0]

    def _regularizer_terms(self, fake, probs_fake, stage, val_cond):
        """Weighted [tile-type, diversity, density] terms shared by all generator losses."""
        constraint_loss = inner_constraint_loss(probs_fake)
        density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])
        # Diversity: the two halves of the batch should differ from each other.
        # NOTE(review): chunk(2) needs a batch of at least 2, and presumably an
        # even one so both halves have the same shape; confirm with the callers.
        fake_a, fake_b = fake.chunk(2, dim=0)
        return [
            constraint_loss * self.weight[3],
            -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
            density_loss * self.weight[6],
        ]

    def _finish(self, losses, stage, probs_fake):
        """Append the entrance term for stage 1 and sum all terms."""
        if stage == 1:
            entrance_loss = entrance_constraint_loss(probs_fake)
            losses.append(entrance_loss * self.weight[4])
        return sum(losses)

    def generator_loss(
        self, critic, stage, mask_ratio, real, fake: torch.Tensor, input, tag_cond, val_cond
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Generator loss with a reconstruction term.

        Returns:
            ``(total_loss, ce_loss)``; ``ce_loss`` is already scaled by
            ``1 - mask_ratio`` (the larger the mask, the smaller its weight).
        """
        probs_fake = F.softmax(fake, dim=1)
        ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio)
        immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage])
        losses = [
            self._adversarial_term(critic, probs_fake, stage, tag_cond, val_cond),
            ce_loss * self.weight[1],
            immutable_loss * self.weight[2],
        ] + self._regularizer_terms(fake, probs_fake, stage, val_cond)
        return self._finish(losses, stage, probs_fake), ce_loss

    def generator_loss_total(self, critic, stage, fake, tag_cond, val_cond) -> torch.Tensor:
        """Generator loss without a reference input: uses the illegal-tile penalty."""
        probs_fake = F.softmax(fake, dim=1)
        illegal_loss = illegal_penalty_loss(probs_fake, STAGE_ALLOWED[stage])
        losses = [
            self._adversarial_term(critic, probs_fake, stage, tag_cond, val_cond),
            illegal_loss * self.weight[2],
        ] + self._regularizer_terms(fake, probs_fake, stage, val_cond)
        return self._finish(losses, stage, probs_fake)

    def generator_loss_total_with_input(self, critic, stage, fake, input, tag_cond, val_cond) -> torch.Tensor:
        """Generator loss with a reference input: uses the immutable-tile penalty."""
        probs_fake = F.softmax(fake, dim=1)
        immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage])
        losses = [
            self._adversarial_term(critic, probs_fake, stage, tag_cond, val_cond),
            immutable_loss * self.weight[2],
        ] + self._regularizer_terms(fake, probs_fake, stage, val_cond)
        return self._finish(losses, stage, probs_fake)

    def generator_input_head_loss(self, critic, probs: torch.Tensor, tag_cond, val_cond) -> torch.Tensor:
        """Loss for the stage-0 input head; ``probs`` is used as-is (no softmax applied)."""
        head_scores = -torch.mean(critic(probs, 0, tag_cond, val_cond))
        probs_a, probs_b = probs.chunk(2, dim=0)
        losses = [
            head_scores,
            input_head_illegal_loss(probs) * 50,
            -js_divergence(probs_a, probs_b, softmax=False) * 0.5
        ]
        return sum(losses)
class RNNGinkaLoss:
    """Cross-entropy loss for the autoregressive (RNN) generator."""

    def __init__(self, num_classes, device):
        self.num_classes = num_classes
        # Per-class weights with classes 0 and 1 down-weighted.
        # NOTE(review): rnn_loss below does not pass this tensor to
        # F.cross_entropy, so it currently has no effect on the loss.
        class_weight = torch.ones(num_classes)
        class_weight[0] = 0.3
        class_weight[1] = 0.5
        self.weight = class_weight.to(device)

    def rnn_loss(self, fake, target):
        """Label-smoothed cross-entropy.

        Args:
            fake: logits, [B, C, H, W].
            target: class indices, [B, H, W].
        """
        soft_target = F.one_hot(target, num_classes=self.num_classes).float().permute(0, 3, 1, 2)
        return F.cross_entropy(fake, soft_target, label_smoothing=0.1)

View File

@ -1,66 +0,0 @@
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from .unet import GinkaUNet
from .output import GinkaOutput
from .input import GinkaInput, RandomInputHead
from ..common.cond import ConditionEncoder
def print_memory(tag=""):
    """Print the current and peak CUDA memory allocation in MB, prefixed with ``tag``."""
    mib = 1024**2
    current = torch.cuda.memory_allocated() / mib
    peak = torch.cuda.max_memory_allocated() / mib
    print(f"{tag} | 当前显存: {current:.2f} MB, 最大显存: {peak:.2f} MB")
class GinkaModel(nn.Module):
    """Staged Ginka generator: random input head, condition encoder and conditional U-Net."""

    def __init__(self, base_ch=64, out_ch=32):
        """Build the sub-modules.

        Args:
            base_ch: base channel width of the U-Net.
            out_ch: number of output channels (tile classes).
        """
        super().__init__()
        self.head = RandomInputHead()
        self.cond = ConditionEncoder(64, 16, 256, 256)
        self.input = GinkaInput(32, 64, (13, 13), (32, 32))
        self.unet = GinkaUNet(64, base_ch, base_ch)
        self.output = GinkaOutput(base_ch, out_ch, (13, 13))

    def forward(self, x, stage, tag_cond, val_cond):
        """Run one generation stage.

        Args:
            x: stage input.
            stage: 0 for the random input head, otherwise the U-Net stage.
            tag_cond: tag condition, 2-D with the batch on dim 0.
            val_cond: scalar-value condition.
        """
        batch, _ = tag_cond.shape
        stage_tensor = torch.Tensor([stage]).expand(batch, 1).to(x.device)
        cond = self.cond(tag_cond, val_cond, stage_tensor)
        # NOTE(review): the extracted source lost its indentation. The U-Net and
        # output head are placed on the non-zero-stage path because GinkaOutput
        # raises for any stage other than 1, 2 or 3 - confirm against the original.
        if stage == 0:
            return self.head(x, cond)
        hidden = self.input(x, cond)
        hidden = self.unet(hidden, cond)
        return self.output(hidden, stage, cond)
# Smoke test: run the model on random data and report memory use, latency
# and per-module parameter counts. Requires a CUDA device.
if __name__ == "__main__":
    input = torch.rand(1, 32, 32, 32).cuda()
    tag = torch.rand(1, 64).cuda()
    val = torch.rand(1, 16).cuda()
    # Build the model
    model = GinkaModel().cuda()
    print_memory("初始化后")
    # Forward passes: stage 0 once, then three refinement passes
    start = time.perf_counter()
    fake0 = model(input, 0, tag, val)
    # NOTE(review): all three refinement passes use stage 1, although
    # GinkaOutput provides heads for stages 1, 2 and 3 - confirm this is intended.
    fake1 = model(F.softmax(fake0, dim=1), 1, tag, val)
    fake2 = model(F.softmax(fake1, dim=1), 1, tag, val)
    fake3 = model(F.softmax(fake2, dim=1), 1, tag, val)
    end = time.perf_counter()
    print_memory("前向传播后")
    print(f"推理耗时: {end - start}")
    print(f"输入形状: feat={input.shape}")
    print(f"输出形状: output={fake3.shape}")
    print(f"Random parameters: {sum(p.numel() for p in model.head.parameters())}")
    print(f"Cond parameters: {sum(p.numel() for p in model.cond.parameters())}")
    print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")
    print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}")
    print(f"Output parameters: {sum(p.numel() for p in model.output.parameters())}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

View File

@ -1,45 +0,0 @@
import torch
import torch.nn as nn
from ..common.common import ConvFusionModule
from ..common.cond import ConditionInjector
class StageHead(nn.Module):
    """Output head of one stage: two condition-injected conv blocks, then pooling to the map size."""

    def __init__(self, in_ch, out_ch, out_size=(13, 13)):
        super().__init__()
        wide = in_ch * 2
        self.dec1 = ConvFusionModule(in_ch, wide, wide, 32, 32)
        self.dec2 = ConvFusionModule(wide, wide, wide, 32, 32)
        self.pool = nn.Sequential(
            ConvFusionModule(wide, wide, wide, 32, 32),
            ConvFusionModule(wide, wide, in_ch, 32, 32),
            nn.AdaptiveMaxPool2d(out_size),  # resize to the target map size
            nn.Conv2d(in_ch, out_ch, 1)  # 1x1 projection to the output channels
        )
        self.inject1 = ConditionInjector(256, wide)
        self.inject2 = ConditionInjector(256, wide)

    def forward(self, x, cond):
        """Apply both conditioned blocks, then pool and project."""
        x = self.inject1(self.dec1(x), cond)
        x = self.inject2(self.dec2(x), cond)
        return self.pool(x)
class GinkaOutput(nn.Module):
    """Holds one independent StageHead per generation stage (1, 2 and 3)."""

    def __init__(self, in_ch=64, out_ch=32, out_size=(13, 13)):
        super().__init__()
        self.head1 = StageHead(in_ch, out_ch, out_size)
        self.head2 = StageHead(in_ch, out_ch, out_size)
        self.head3 = StageHead(in_ch, out_ch, out_size)

    def forward(self, x, stage, cond):
        """Route ``x`` through the head of ``stage``.

        Raises:
            RuntimeError: if ``stage`` is not 1, 2 or 3.
        """
        if stage == 1:
            return self.head1(x, cond)
        if stage == 2:
            return self.head2(x, cond)
        if stage == 3:
            return self.head3(x, cond)
        raise RuntimeError("Unknown generate stage.")

View File

@ -1,258 +0,0 @@
import time
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
class RNNConditionEncoder(nn.Module):
    """Encode the scalar condition vector into a single embedding.

    ``width`` and ``height`` are accepted but not used by this module.
    """

    def __init__(self, val_dim=16, output_dim=256, width=13, height=13):
        super().__init__()
        hidden = output_dim * 4
        # Expand, normalize, activate ...
        self.val_fc = nn.Sequential(
            nn.Linear(val_dim, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
        )
        # ... then project down to the embedding size.
        self.fusion = nn.Sequential(
            nn.Linear(hidden, output_dim)
        )

    def forward(self, val_cond: torch.Tensor):
        """Map ``val_cond`` [B, val_dim] to [B, output_dim]."""
        return self.fusion(self.val_fc(val_cond))
class GinkaMapPatch(nn.Module):
    """Encode the 5x5 neighbourhood of already drawn tiles around the current cell.

    The window spans columns ``x-2 .. x+2`` and rows ``y-4 .. y``, so the current
    cell sits at window position (row 4, column 2). Cells outside the map, the
    current cell and the cells to its right in the current row are marked as
    unknown in an extra mask channel.
    """

    def __init__(self, tile_classes=32, width=13, height=13):
        super().__init__()
        self.width = width
        self.height = height
        self.tile_classes = tile_classes
        # Local CNN over the one-hot window plus the mask channel.
        self.patch_cnn = nn.Sequential(
            nn.Conv2d(tile_classes + 1, 256, 3, padding=1),
            nn.Dropout(0.2),
            nn.BatchNorm2d(256),
            nn.GELU(),
            nn.Conv2d(256, 512, 3),  # no padding: 5x5 -> 3x3
            nn.Dropout(0.2),
            nn.BatchNorm2d(512),
            nn.GELU(),
            nn.Flatten()
        )
        self.fc = nn.Linear(512 * 3 * 3, 256)

    def forward(self, map: torch.Tensor, x: int, y: int):
        """Return a [B, 256] feature for cell ``(x, y)``.

        Args:
            map: tile indices, [B, H, W].
            x: column of the current cell.
            y: row of the current cell.
        """
        B, H, W = map.shape
        mask = torch.zeros([B, 5, 5], device=map.device)
        result = torch.zeros([B, 5, 5], dtype=torch.long, device=map.device)

        # Window bounds clipped to the map ...
        left = x - 2 if x >= 2 else 0
        right = x + 3 if x < self.width - 2 else self.width
        top = y - 4 if y >= 4 else 0
        bottom = y + 1
        # ... and where that clipped region lands inside the 5x5 window.
        res_left = left - (x - 2)
        res_right = right - (x + 3) + 5
        res_top = top - (y - 4)
        res_bottom = 5

        result[:, res_top:res_bottom, res_left:res_right] = map[:, top:bottom, left:right]
        mask[:, res_top:res_bottom, res_left:res_right] = 1
        # The current cell and everything right of it in this row is not drawn yet.
        result[:, 4, 2:5] = 0
        mask[:, 4, 2:5] = 0

        masked_result = torch.zeros([B, self.tile_classes + 1, 5, 5], device=map.device)
        # one_hot yields [B, row, col, C]; move C to dim 1 while keeping the
        # (row, col) order so the tile planes line up with the mask plane.
        masked_result[:, :self.tile_classes] = (
            F.one_hot(result, num_classes=self.tile_classes).permute(0, 3, 1, 2).float()
        )
        masked_result[:, self.tile_classes] = mask

        feat = self.patch_cnn(masked_result)
        return self.fc(feat)
class GinkaTileEmbedding(nn.Module):
    """Learned embedding of a tile index (used for the previously drawn tile)."""

    def __init__(self, tile_classes=32, embed_dim=256):
        super().__init__()
        self.embedding = nn.Embedding(tile_classes, embed_dim)

    def forward(self, tile: torch.Tensor):
        """Look up the embedding vectors of the given tile indices."""
        vectors = self.embedding(tile)
        return vectors
class GinkaPosEmbedding(nn.Module):
    """Separate learned embeddings for the x and y coordinate of a cell.

    ``row_embedding`` has ``width`` entries and is indexed by ``x``;
    ``col_embedding`` has ``height`` entries and is indexed by ``y``.
    """

    def __init__(self, width=13, height=13, embed_dim=256):
        super().__init__()
        self.width = width
        self.height = height
        self.row_embedding = nn.Embedding(width, embed_dim)
        self.col_embedding = nn.Embedding(height, embed_dim)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Embed coordinates and drop dim 1 when it has size 1 (e.g. [B, 1] -> [B, embed_dim])."""
        x_vec = self.row_embedding(x).squeeze(1)
        y_vec = self.col_embedding(y).squeeze(1)
        return x_vec, y_vec
class GinkaInputFusion(nn.Module):
    """Fuse the five per-step feature vectors with a small Transformer encoder."""

    def __init__(self, d_model=256):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=2, dim_feedforward=d_model*2, batch_first=True,
            dropout=0.2
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=4)

    def forward(
        self, tile_embed: torch.Tensor, cond_vec: torch.Tensor,
        row_embed: torch.Tensor, col_embed: torch.Tensor, patch_vec: torch.Tensor
    ):
        """Treat the inputs as a 5-token sequence and return the first output token.

        Every argument is [B, d_model]; the result is [B, d_model] and
        corresponds to the ``tile_embed`` position.
        """
        tokens = torch.stack(
            [tile_embed, cond_vec, row_embed, col_embed, patch_vec], dim=1
        )  # [B, 5, d_model]
        return self.transformer(tokens)[:, 0]
class GinkaRNN(nn.Module):
    """Single GRU step followed by an MLP that predicts tile logits."""

    def __init__(self, tile_classes=32, input_dim=256, hidden_dim=512):
        super().__init__()
        self.gru = nn.GRUCell(input_dim, hidden_dim)
        self.drop = nn.Dropout(0.2)
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, tile_classes)
        )

    def forward(self, feat_fusion: torch.Tensor, hidden: torch.Tensor):
        """Advance the recurrent state by one step.

        Args:
            feat_fusion: [B, input_dim]
            hidden: [B, hidden_dim]

        Returns:
            ``(logits, new_hidden)``. Dropout is applied to the new hidden state
            and that dropped-out state is both classified and returned.
        """
        new_hidden = self.drop(self.gru(feat_fusion, hidden))
        return self.fc(new_hidden), new_hidden
class GinkaRNNModel(nn.Module):
    """Autoregressive map generator that draws the map cell by cell in row-major order.

    ``width`` and ``height`` only control the drawing loop here; the
    sub-modules are built with their own default sizes.
    """

    def __init__(self, device: torch.device, start_tile=31, width=13, height=13):
        super().__init__()
        self.device = device
        self.width = width
        self.height = height
        self.start_tile = start_tile
        self.rnn_hidden = 512
        self.tile_classes = 32
        # Sub-modules
        self.cond = RNNConditionEncoder()
        self.tile_embedding = GinkaTileEmbedding(tile_classes=self.tile_classes)
        self.pos_embedding = GinkaPosEmbedding()
        self.map_patch = GinkaMapPatch(tile_classes=self.tile_classes)
        self.feat_fusion = GinkaInputFusion()
        self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)

    def forward(self, val_cond: torch.Tensor, target_map: torch.Tensor, use_self_probility=0):
        """Generate a map, mixing teacher forcing with the model's own predictions.

        Args:
            val_cond: [B, val_dim]
            target_map: ground-truth tile indices, [B, H, W]
            use_self_probility: per-cell probability of conditioning on the
                model's own previous outputs instead of ``target_map``.

        Returns:
            ``(logits, map)`` with logits [B, tile_classes, H, W] and the
            arg-max tile indices [B, H, W].
        """
        B, C = val_cond.shape
        dev = self.device
        # Working tensors
        prev_tile = torch.LongTensor([self.start_tile]).to(dev).expand(B)
        drawn = torch.zeros([B, self.height, self.width], dtype=torch.int32).to(dev)
        all_logits = torch.zeros([B, self.height, self.width, self.tile_classes]).to(dev)
        state: torch.Tensor = torch.zeros(B, self.rnn_hidden).to(dev)
        # The condition is global, so it is encoded only once.
        cond = self.cond(val_cond)

        for y in range(self.height):
            for x in range(self.width):
                x_idx = torch.LongTensor([x]).to(dev).expand(B, -1)
                y_idx = torch.LongTensor([y]).to(dev).expand(B, -1)
                # Tile, position and local-patch encodings
                tile_embed = self.tile_embedding(prev_tile)
                row_embed, col_embed = self.pos_embedding(x_idx, y_idx)
                use_self = random.random() < use_self_probility
                context_map = drawn if use_self else target_map
                patch_vec = self.map_patch(context_map, x, y)
                # Fuse the features and advance the RNN
                fused = self.feat_fusion(tile_embed, cond, row_embed, col_embed, patch_vec)
                logits, state = self.rnn(fused, state)
                # Record the prediction
                all_logits[:, y, x] = logits[:]
                tile_id = torch.argmax(logits, dim=1).detach()
                drawn[:, y, x] = tile_id[:]
                prev_tile = tile_id if use_self else target_map[:, y, x].detach()

        return all_logits.permute(0, 3, 1, 2), drawn
def print_memory(device, tag=""):
    """Print current and peak CUDA memory of ``device`` in MB, or a notice when CUDA is unavailable."""
    if not torch.cuda.is_available():
        print("当前设备不支持 cuda.")
        return
    mib = 1024**2
    current = torch.cuda.memory_allocated(device) / mib
    peak = torch.cuda.max_memory_allocated(device) / mib
    print(f"{tag} | 当前显存: {current:.2f} MB, 最大显存: {peak:.2f} MB")
# Smoke test: one forward pass on random data, reporting memory use, latency
# and per-module parameter counts.
if __name__ == "__main__":
    device = torch.device("cpu")
    input = torch.randint(0, 32, [1, 13, 13]).to(device)
    cond = torch.rand(1, 16).to(device)
    # Build the model; pass the torch.device object the constructor is annotated with.
    model = GinkaRNNModel(device).to(device)
    # print_memory takes (device, tag); the tag must not be passed as the device.
    print_memory(device, "初始化后")
    # Forward pass (probability 0 of using the model's own predictions)
    start = time.perf_counter()
    fake_logits, fake_map = model(cond, input, False)
    end = time.perf_counter()
    print_memory(device, "前向传播后")
    print(f"推理耗时: {end - start}")
    print(f"输出形状: fake_logits={fake_logits.shape}, fake_map={fake_map.shape}")
    print(f"Condition Encoder parameters: {sum(p.numel() for p in model.cond.parameters())}")
    print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}")
    print(f"Position Embedding parameters: {sum(p.numel() for p in model.pos_embedding.parameters())}")
    print(f"Map Patch parameters: {sum(p.numel() for p in model.map_patch.parameters())}")
    print(f"Feature Fusion parameters: {sum(p.numel() for p in model.feat_fusion.parameters())}")
    print(f"RNN parameters: {sum(p.numel() for p in model.rnn.parameters())}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

View File

@ -1,188 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from shared.attention import ChannelAttention
from ..common.common import GCNBlock, TransformerGCNBlock, DoubleConvBlock, ConvFusionModule
from ..common.cond import ConditionInjector
class GinkaTransformerEncoder(nn.Module):
    """Transformer encoder over a fixed number of tokens.

    ``in_dim``, ``hidden_dim`` and ``out_dim`` are totals over all tokens; each
    is divided by ``token_size`` to obtain the per-token width.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, token_size, ff_dim, num_heads=8, num_layers=6):
        super().__init__()
        token_in = in_dim // token_size
        token_hidden = hidden_dim // token_size
        token_out = out_dim // token_size
        self.embedding = nn.Sequential(
            nn.Linear(token_in, token_hidden),
            nn.LayerNorm(token_hidden)
        )
        # Learned positional embedding, one vector per token.
        self.pos_embedding = nn.Parameter(torch.randn(1, token_size, token_hidden))
        encoder_layer = nn.TransformerEncoderLayer(
            token_hidden, num_heads, dim_feedforward=ff_dim, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Sequential(
            nn.Linear(token_hidden, token_out),
            nn.LayerNorm(token_out)
        )

    def forward(self, x):
        """Map [B, L, in_dim // L] to [B, L, out_dim // L] with L = ``token_size``."""
        tokens = self.embedding(x) + self.pos_embedding  # [B, L, hidden]
        encoded = self.transformer(tokens)  # [B, L, hidden]
        return self.fc(encoded)  # [B, L, out]
class ConvBlock(nn.Module):
    """Thin wrapper around DoubleConvBlock (in_ch -> out_ch -> out_ch).

    ``attn`` is accepted but unused: the attention variant below is commented out.
    """

    def __init__(self, in_ch, out_ch, attn=True):
        super().__init__()
        channels = [in_ch, out_ch, out_ch]
        self.conv = DoubleConvBlock(channels)
        # Earlier hand-written variant, kept for reference:
        # self.conv = nn.Sequential(
        #     nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
        #     nn.InstanceNorm2d(out_ch),
        #     nn.ELU(),
        #     nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
        #     nn.InstanceNorm2d(out_ch),
        # )
        # if attn:
        #     self.conv.append(ChannelAttention(out_ch))
        # self.conv.append(nn.ELU())

    def forward(self, x):
        """Apply the double convolution."""
        out = self.conv(x)
        return out
class FusionModule(nn.Module):
    """Concatenate two feature maps along channels and mix them with a 3x3 convolution.

    ``in_ch`` is the channel count after concatenation.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate')

    def forward(self, x1, x2):
        """Return ``conv(cat([x1, x2], dim=1))``."""
        return self.conv(torch.cat([x1, x2], dim=1))
class GinkaUNetInput(nn.Module):
    """First U-Net level: conv fusion block followed by condition injection (no pooling)."""

    def __init__(self, in_ch, out_ch, w, h):
        super().__init__()
        self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h)
        self.inject = ConditionInjector(256, out_ch)

    def forward(self, x, cond):
        """Convolve ``x`` and inject ``cond`` into the result."""
        return self.inject(self.conv(x), cond)
class GinkaEncoder(nn.Module):
    """U-Net down step: 2x max-pool, conv fusion block, condition injection."""

    def __init__(self, in_ch, out_ch, w, h):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h)
        self.inject = ConditionInjector(256, out_ch)

    def forward(self, x, cond):
        """Halve the resolution, convolve, then inject ``cond``."""
        pooled = self.pool(x)
        features = self.conv(pooled)
        return self.inject(features, cond)
class GinkaUpSample(nn.Module):
    """Double the spatial resolution with a transposed convolution, then refine with a 3x3 conv."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2),  # x2 upsampling
            nn.InstanceNorm2d(out_ch),
            nn.GELU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
            nn.InstanceNorm2d(out_ch),
            nn.GELU(),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Map [B, in_ch, H, W] to [B, out_ch, 2H, 2W]."""
        upsampled = self.conv(x)
        return upsampled
class GinkaDecoder(nn.Module):
    """U-Net up step: upsample, concatenate the skip feature, fuse, convolve, inject condition."""

    def __init__(self, in_ch, out_ch, w, h):
        super().__init__()
        self.upsample = GinkaUpSample(in_ch, in_ch // 2)
        # 1x1 mix of the concatenated tensor, which the layer expects to have in_ch channels.
        self.fusion = nn.Conv2d(in_ch, in_ch, 1)
        self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h)
        self.inject = ConditionInjector(256, out_ch)

    def forward(self, x, feat, cond):
        """Decode ``x`` using the skip connection ``feat`` and the condition ``cond``."""
        merged = torch.cat([self.upsample(x), feat], dim=1)
        merged = self.fusion(merged)
        merged = self.conv(merged)
        return self.inject(merged, cond)
class GinkaBottleneck(nn.Module):
    """U-Net bottleneck: one conv fusion block plus condition injection.

    The Transformer/GCN variant below is disabled and kept for reference only.
    """

    def __init__(self, module_ch, w, h):
        super().__init__()
        # self.transformer = GinkaTransformerEncoder(
        #     in_dim=module_ch*w*h, hidden_dim=module_ch*w*h, out_dim=module_ch*w*h,
        #     token_size=16, ff_dim=1024, num_layers=4
        # )
        # self.gcn = TransformerGCNBlock(module_ch, module_ch*2, module_ch, 4, 4)
        # self.fusion = nn.Conv2d(module_ch*3, module_ch, 1)
        self.conv = ConvFusionModule(module_ch, module_ch, module_ch, w, h)
        self.inject = ConditionInjector(256, module_ch)

    def forward(self, x, cond):
        """Convolve ``x`` and inject ``cond``; the channel count is unchanged."""
        # x1 = x.view(B, 512, 16).permute(0, 2, 1) # [B, 16, in_ch]
        # x1 = self.transformer(x1)
        # x1 = x1.permute(0, 2, 1).view(B, 512, 4, 4) # [B, out_ch, 4, 4]
        return self.inject(self.conv(x), cond)
class GinkaEncoderPath(nn.Module):
    """Contracting path of the U-Net: four levels from 32x32 down to 4x4."""

    def __init__(self, in_ch, base_ch):
        super().__init__()
        self.down1 = GinkaUNetInput(in_ch, base_ch, 32, 32)
        self.down2 = GinkaEncoder(base_ch, base_ch*2, 16, 16)
        self.down3 = GinkaEncoder(base_ch*2, base_ch*4, 8, 8)
        self.down4 = GinkaEncoder(base_ch*4, base_ch*8, 4, 4)

    def forward(self, x, cond):
        """Return the features of all four levels, shallowest first.

        Shape comments assume base_ch = 64.
        """
        level1 = self.down1(x, cond)  # [B, 64, 32, 32]
        level2 = self.down2(level1, cond)  # [B, 128, 16, 16]
        level3 = self.down3(level2, cond)  # [B, 256, 8, 8]
        level4 = self.down4(level3, cond)  # [B, 512, 4, 4]
        return level1, level2, level3, level4
class GinkaDecoderPath(nn.Module):
    """Expanding path of the U-Net: three up steps from 4x4 back to 32x32."""

    def __init__(self, base_ch):
        super().__init__()
        self.up1 = GinkaDecoder(base_ch*8, base_ch*4, 8, 8)
        self.up2 = GinkaDecoder(base_ch*4, base_ch*2, 16, 16)
        self.up3 = GinkaDecoder(base_ch*2, base_ch, 32, 32)

    def forward(self, x1, x2, x3, x4, cond):
        """Decode from the deepest feature ``x4`` using ``x3``, ``x2``, ``x1`` as skips.

        Shape comments assume base_ch = 64.
        """
        out = self.up1(x4, x3, cond)  # [B, 256, 8, 8]
        out = self.up2(out, x2, cond)  # [B, 128, 16, 16]
        return self.up3(out, x1, cond)  # [B, 64, 32, 32]
class GinkaUNet(nn.Module):
    """Conditional U-Net backbone of the Ginka model."""

    def __init__(self, in_ch=32, base_ch=32, out_ch=32):
        """Build encoder path, bottleneck, decoder path and the final conv block."""
        super().__init__()
        self.enc = GinkaEncoderPath(in_ch, base_ch)
        self.bottleneck = GinkaBottleneck(base_ch*8, 4, 4)
        self.dec = GinkaDecoderPath(base_ch)
        self.final = ConvFusionModule(base_ch, base_ch, out_ch, 32, 32)

    def forward(self, x, cond):
        """Encode, pass the deepest feature through the bottleneck, decode, and project."""
        skip1, skip2, skip3, deepest = self.enc(x, cond)
        deepest = self.bottleneck(deepest, cond)  # [B, base_ch*8, 4, 4]
        decoded = self.dec(skip1, skip2, skip3, deepest, cond)
        return self.final(decoded)  # [B, out_ch, 32, 32]

View File

@ -1,16 +0,0 @@
import argparse
import torch
def to_deployment(path: str, output: str):
    """Strip a training checkpoint down to the model weights for deployment.

    Args:
        path: checkpoint file containing at least a ``"model_state"`` entry.
        output: destination file; it will contain only ``"model_state"``.

    Raises:
        KeyError: if the checkpoint has no ``"model_state"`` entry.
    """
    # Load onto the CPU so the conversion also works on machines that lack the
    # GPU the checkpoint was saved from.
    state = torch.load(path, map_location="cpu")
    torch.save({
        "model_state": state["model_state"]
    }, output)
# Command-line entry point: convert a training checkpoint into a deployment file.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Checkpoint to read and deployment file to write.
    parser.add_argument("--input", type=str, default="result/ginka.pth")
    parser.add_argument("--output", type=str, default="result/ginka_deploy.pth")
    args = parser.parse_args()
    to_deployment(args.input, args.output)

View File

@ -89,9 +89,9 @@ def train():
# 用于生成图片 # 用于生成图片
tile_dict = dict() tile_dict = dict()
for file in os.listdir('tiles2'): for file in os.listdir('tiles'):
name = os.path.splitext(file)[0] name = os.path.splitext(file)[0]
tile_dict[name] = cv2.imread(f"tiles2/{file}", cv2.IMREAD_UNCHANGED) tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
# 接续训练 # 接续训练
if args.resume: if args.resume:

View File

@ -1,188 +0,0 @@
import argparse
import os
import sys
from datetime import datetime
import torch
import torch.optim as optim
import cv2
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from .generator.rnn import GinkaRNNModel
from .dataset import GinkaRNNDataset
from .generator.loss import RNNGinkaLoss
from shared.image import matrix_to_image_cv
# 手工标注标签定义(暂时不用):
# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层,
# 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具风
# 16. 区域入口, 17. 区域连接, 18. 有机关门, 19. 道具层, 20. 斜向对称, 21. 左右通道, 22. 上下通道, 23. 多机关门
# 24. 中心对称, 25. 部分对称, 26. 鱼骨
# 自动标注标签定义(暂时不用):
# 0. 左右对称, 1. 上下对称, 2. 中心对称, 3. 斜向对称, 4. 伪对称, 5. 多房间, 6. 多走廊
# 32. 平面塔, 33. 转换塔, 34. 道具塔
# 标量值定义:
# 0. 整体密度,非空白图块/地图面积,空白图块还包括装饰图块
# 1. 墙体密度,墙壁/地图面积
# 2. 装饰密度,装饰数量/地图面积
# 3. 门密度,门数量/地图面积
# 4. 怪物密度,怪物数量/地图面积
# 5. 资源密度,资源数量/地图面积
# 6. 宝石密度,宝石数量/地图面积
# 7. 血瓶密度,血瓶数量/地图面积
# 8. 钥匙密度,钥匙数量/地图面积
# 9. 道具密度,道具数量/地图面积
# 10. 入口数量
# 11. 机关门数量
# 12. 咸鱼门数量(多层咸鱼门只算一个)
# 图块定义:
# 0. 空地, 1. 墙壁, 2. 装饰(用于野外装饰,视为空地),
# 3. 黄门, 4. 蓝门, 5. 红门, 6. 机关门, 其余种类的门如绿门都视为红门
# 7-9. 黄蓝红门钥匙,机关门不使用钥匙开启
# 10-12. 三种等级的红宝石
# 13-15. 三种等级的蓝宝石
# 16-18. 三种等级的绿宝石
# 19-22. 四种等级的血瓶
# 23-25. 三种等级的道具
# 26-28. 三种等级的怪物
# 29. 入口,不区分楼梯和箭头
# Training batch size (validation uses BATCH_SIZE // 8, see train()).
BATCH_SIZE = 96
# Train on the second CUDA device when CUDA is available, otherwise on the CPU.
# NOTE(review): "cuda:1" presumably fails on a host with a single GPU - confirm.
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
# Output directories are created at import time.
os.makedirs("result", exist_ok=True)
os.makedirs("result/rnn", exist_ok=True)
os.makedirs("result/ginka_rnn_img", exist_ok=True)
# Hide progress bars when stdout is not a terminal (e.g. redirected to a file).
disable_tqdm = not sys.stdout.isatty()
def gt_prob(epoch: int, max_epoch: int) -> float:
    """Linear schedule rising from 0.1 at epoch 0 to 1.0 at ``max_epoch``."""
    return 0.1 + 0.9 * (epoch / max_epoch)
def parse_arguments():
    """Parse the command-line options of the RNN training script.

    Boolean options accept true/false style values (``true``, ``false``, ``1``,
    ``0``, ``yes``, ``no``, ...). With ``type=bool`` every non-empty string,
    including ``"False"``, would be parsed as True.
    """
    def _str2bool(value):
        # Explicit text-to-bool conversion for argparse.
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ("1", "true", "t", "yes", "y", "on"):
            return True
        if text in ("0", "false", "f", "no", "n", "off", ""):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description="training codes")
    parser.add_argument("--resume", type=_str2bool, default=False)
    parser.add_argument("--state_ginka", type=str, default="result/rnn/ginka-100.pth")
    parser.add_argument("--train", type=str, default="ginka-dataset.json")
    parser.add_argument("--validate", type=str, default="ginka-eval.json")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--checkpoint", type=int, default=5)
    parser.add_argument("--load_optim", type=_str2bool, default=True)
    args = parser.parse_args()
    return args
def train():
    """Train the RNN generator, with periodic checkpoints, validation and preview images."""
    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")

    args = parse_arguments()

    ginka_rnn = GinkaRNNModel(device).to(device)

    dataset = GinkaRNNDataset(args.train, device)
    dataset_val = GinkaRNNDataset(args.validate, device)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 8)

    optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)

    criterion = RNNGinkaLoss(32, device)

    # Tile images used to render preview pictures.
    tile_dict = dict()
    for file in os.listdir('tiles'):
        name = os.path.splitext(file)[0]
        tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)

    # Resume from a saved state if requested.
    if args.resume:
        data_ginka = torch.load(args.state_ginka, map_location=device)
        ginka_rnn.load_state_dict(data_ginka["model_state"], strict=False)
        if args.load_optim:
            if data_ginka.get("optim_state") is not None:
                optimizer_ginka.load_state_dict(data_ginka["optim_state"])
        print("Train from loaded state.")

    for epoch in tqdm(range(args.epochs), desc="RNN Training", disable=disable_tqdm):
        # Validation below switches to eval mode, so re-enable training mode here.
        ginka_rnn.train()
        loss_total_ginka = torch.Tensor([0]).to(device)

        for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
            val_cond = batch["val_cond"].to(device)
            target_map = batch["target_map"].to(device)

            # Clear the gradients of the previous step; otherwise every
            # backward() call would add to them.
            optimizer_ginka.zero_grad()
            fake_logits, fake_map = ginka_rnn(val_cond, target_map, 1 - gt_prob(epoch, args.epochs))

            loss = criterion.rnn_loss(fake_logits, target_map)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(ginka_rnn.parameters(), max_norm=1.0)
            optimizer_ginka.step()

            loss_total_ginka += loss.detach()

        avg_loss_ginka = loss_total_ginka.item() / len(dataloader)
        tqdm.write(
            f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
            f"E: {epoch + 1} | Loss: {avg_loss_ginka:.6f} | " +
            f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
        )

        scheduler_ginka.step()

        # Every `checkpoint` epochs: save a checkpoint, validate and write preview images.
        if (epoch + 1) % args.checkpoint == 0:
            torch.save({
                "model_state": ginka_rnn.state_dict(),
                "optim_state": optimizer_ginka.state_dict(),
            }, f"result/rnn/ginka-{epoch + 1}.pth")

            val_loss_total = torch.Tensor([0]).to(device)
            # Evaluate without dropout and with the batch-norm running statistics.
            ginka_rnn.eval()
            with torch.no_grad():
                idx = 0
                for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
                    val_cond = batch["val_cond"].to(device)
                    target_map = batch["target_map"].to(device)

                    fake_logits, fake_map = ginka_rnn(val_cond, target_map, 1 - gt_prob(epoch, args.epochs))
                    val_loss_total += criterion.rnn_loss(fake_logits, target_map).detach()

                    # Render the first sample of every validation batch.
                    fake_map = fake_map.cpu().numpy()
                    fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
                    cv2.imwrite(f"result/ginka_rnn_img/{idx}.png", fake_img)

                    idx += 1

            avg_loss_val = val_loss_total.item() / len(dataloader_val)
            tqdm.write(
                f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch + 1} | " +
                f"Loss: {avg_loss_val:.6f}"
            )

    print("Train ended.")
    torch.save({
        "model_state": ginka_rnn.state_dict(),
    }, f"result/ginka_rnn.pth")
# Script entry point.
if __name__ == "__main__":
    # Limit PyTorch to 4 intra-op CPU threads before training starts.
    torch.set_num_threads(4)
    train()

View File

@ -1,271 +0,0 @@
import argparse
import os
import sys
import random
from datetime import datetime
import torch
import torch.nn.functional as F
import torch.optim as optim
import cv2
import numpy as np
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from .transformer.vae import GinkaTransformerVAE
from .vae_rnn.loss import VAELoss
from .vae_rnn.scheduler import VAEScheduler
from .dataset import GinkaRNNDataset
from shared.image import matrix_to_image_cv
# 手工标注标签定义(暂时不用):
# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层,
# 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具风
# 16. 区域入口, 17. 区域连接, 18. 有机关门, 19. 道具层, 20. 斜向对称, 21. 左右通道, 22. 上下通道, 23. 多机关门
# 24. 中心对称, 25. 部分对称, 26. 鱼骨
# 自动标注标签定义(暂时不用):
# 0. 左右对称, 1. 上下对称, 2. 中心对称, 3. 斜向对称, 4. 伪对称, 5. 多房间, 6. 多走廊
# 32. 平面塔, 33. 转换塔, 34. 道具塔
# 标量值定义:
# 0. 整体密度,非空白图块/地图面积,空白图块还包括装饰图块
# 1. 墙体密度,墙壁/地图面积
# 2. 装饰密度,装饰数量/地图面积
# 3. 门密度,门数量/地图面积
# 4. 怪物密度,怪物数量/地图面积
# 5. 资源密度,资源数量/地图面积
# 6. 宝石密度,宝石数量/地图面积
# 7. 血瓶密度,血瓶数量/地图面积
# 8. 钥匙密度,钥匙数量/地图面积
# 9. 道具密度,道具数量/地图面积
# 10. 入口数量
# 11. 机关门数量
# 12. 咸鱼门数量(多层咸鱼门只算一个)
# 图块定义:
# 0. 空地, 1. 墙壁, 2. 门, 3. 钥匙, 4. 红宝石, 5. 蓝宝石, 6. 绿宝石, 7. 血瓶
# 8. 道具, 9. 怪物, 10. 入口, 14. 起始 token, 15. 终止 token
# Training batch size.
BATCH_SIZE = 128
# Latent dimension passed to GinkaTransformerVAE.
LATENT_DIM = 32
# NOTE(review): the constants below are used further down in train(), outside
# this view; their names suggest a KL weight and the schedule of the
# self-feeding probability - confirm against the training loop.
KL_BETA = 0.01
SELF_GATE = 0.5
GATE_EPOCH = 5
# Validation batch size is BATCH_SIZE // VAL_BATCH_DIVIDER (1 with these values).
VAL_BATCH_DIVIDER = 128
PROB_STEP = 0.05
# Number of tile classes passed to GinkaTransformerVAE.
NUM_CLASSES = 16
# Device preference: second CUDA device, then Apple MPS, then CPU.
# NOTE(review): torch.backends.mps.is_available() is the long-standing MPS check;
# confirm torch.mps.is_available() exists in the targeted PyTorch version.
device = torch.device(
"cuda:1" if torch.cuda.is_available()
else "mps" if torch.mps.is_available()
else "cpu"
)
# Output directories are created at import time.
os.makedirs("result", exist_ok=True)
os.makedirs("result/vae", exist_ok=True)
os.makedirs("result/ginka_vae_img", exist_ok=True)
# Hide progress bars when stdout is not a terminal.
disable_tqdm = not sys.stdout.isatty()
def _str2bool(value):
    """Turn a command-line token such as "true", "0" or "no" into a real bool."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace with ``resume``, ``state_ginka``, ``train``, ``validate``,
        ``epochs``, ``checkpoint`` and ``load_optim``.

    Boolean options accept an explicit value (``--resume True`` / ``--resume false``)
    or no value at all (``--resume`` means True). ``type=bool`` is deliberately not
    used: it calls ``bool()`` on the raw string, so ``--resume False`` would be True.
    """
    parser = argparse.ArgumentParser(description="training codes")
    parser.add_argument("--resume", type=_str2bool, nargs="?", const=True, default=False)
    parser.add_argument("--state_ginka", type=str, default="result/vae/ginka-100.pth")
    parser.add_argument("--train", type=str, default="ginka-dataset.json")
    parser.add_argument("--validate", type=str, default="ginka-eval.json")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--checkpoint", type=int, default=5)
    parser.add_argument("--load_optim", type=_str2bool, nargs="?", const=True, default=True)
    args = parser.parse_args()
    return args
def train():
    """Train the transformer VAE, saving checkpoints and preview images along the way.

    Configuration comes from ``parse_arguments()``. Every ``args.checkpoint`` epochs
    the model and optimizer state are written to ``result/vae/`` and three kinds of
    preview images are written to ``result/ginka_vae_img/``: reconstructions of the
    validation maps, maps decoded from random latent vectors, and maps decoded along
    a straight line between the latents of two random validation maps. The final
    weights are stored in ``result/ginka_transformer.pth``.
    """
    print(f"Using {device.type} to train model.")

    args = parse_arguments()

    vae = GinkaTransformerVAE(num_classes=NUM_CLASSES, latent_dim=LATENT_DIM).to(device)

    dataset = GinkaRNNDataset(args.train, device)
    dataset_val = GinkaRNNDataset(args.validate, device)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True)

    optimizer_ginka = optim.AdamW(vae.parameters(), lr=3e-4, weight_decay=1e-2, betas=(0.9, 0.95))
    # Custom scheduler: it is told the current self_prob so it can reset its state and
    # raise the learning rate when self_prob increases.
    scheduler_ginka = VAEScheduler(
        optimizer_ginka, factor=0.9, increase_factor=2, patience=10, max_lr=1e-4, min_lr=1e-6
    )

    criterion = VAELoss()

    self_prob = 0
    prob_epochs = 0

    # Tile images used to render maps as pictures, keyed by file name without extension.
    tile_dict = dict()
    for file in os.listdir('tiles2'):
        name = os.path.splitext(file)[0]
        tile_dict[name] = cv2.imread(f"tiles2/{file}", cv2.IMREAD_UNCHANGED)

    if args.resume:
        data_ginka = torch.load(args.state_ginka, map_location=device)
        vae.load_state_dict(data_ginka["model_state"], strict=False)
        if args.load_optim:
            if data_ginka.get("optim_state") is not None:
                optimizer_ginka.load_state_dict(data_ginka["optim_state"])
        print("Train from loaded state.")

    for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm):
        loss_total = torch.Tensor([0]).to(device)
        reco_loss_total = torch.Tensor([0]).to(device)
        kl_loss_total = torch.Tensor([0]).to(device)

        # The checkpoint branch below switches the model to eval mode; switch back so
        # that every epoch (not only those before the first checkpoint) trains with
        # dropout enabled.
        vae.train()
        vae.teacher_forcing()

        for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
            target_map = batch["target_map"].to(device)
            B, H, W = target_map.shape
            flat_map = target_map.view(B, H * W)

            optimizer_ginka.zero_grad()
            fake_logits, mu, logvar = vae(flat_map, self_prob)
            loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, flat_map, mu, logvar, KL_BETA)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)
            optimizer_ginka.step()

            loss_total += loss.detach()
            reco_loss_total += reco_loss.detach()
            kl_loss_total += kl_loss.detach()

        avg_loss = loss_total.item() / len(dataloader)
        avg_reco_loss = reco_loss_total.item() / len(dataloader)
        avg_kl_loss = kl_loss_total.item() / len(dataloader)
        tqdm.write(
            f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
            f"E: {epoch + 1} | Loss: {avg_loss:.6f} | Reco: {avg_reco_loss:.6f} | " +
            f"KL: {avg_kl_loss:.6f} | Prob: {self_prob:.2f} | LR: {scheduler_ginka.get_last_lr()[0]:.6f}"
        )

        # The gate uses the training loss for now: per the original author's note the
        # model overfits heavily, so per-epoch validation gating is disabled.
        if avg_loss < SELF_GATE:
            prob_epochs += 1
        else:
            prob_epochs = 0

        if prob_epochs >= GATE_EPOCH and self_prob < 1:
            self_prob += PROB_STEP
            prob_epochs = 0
        if self_prob > 1:
            self_prob = 1
        # NOTE(review): the source contains a second, unguarded `self_prob = 1` right
        # here; its nesting was lost with the indentation. It is kept as an
        # unconditional override (which makes the gate above a no-op after the first
        # epoch) -- confirm whether this was a temporary experiment.
        self_prob = 1

        scheduler_ginka.step(avg_loss, self_prob)

        # Every `args.checkpoint` epochs: save a checkpoint and render preview images.
        if (epoch + 1) % args.checkpoint == 0:
            # Saved under result/vae: that is the directory this script creates and the
            # one the --state_ginka default reads from (the old "result/rnn" target is
            # never created by this script).
            torch.save({
                "model_state": vae.state_dict(),
                "optim_state": optimizer_ginka.state_dict(),
            }, f"result/vae/ginka-{epoch + 1}.pth")

            val_loss_total = torch.Tensor([0]).to(device)
            val_reco_loss_total = torch.Tensor([0]).to(device)
            val_kl_loss_total = torch.Tensor([0]).to(device)

            vae.eval()
            with torch.no_grad():
                idx = 0
                gap = 5
                color = (255, 255, 255)  # white
                vline = np.full((416, gap, 3), color, dtype=np.uint8)  # vertical separator

                # --- Reconstruction preview (teacher forcing) ---
                vae.teacher_forcing()
                for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
                    target_map = batch["target_map"].to(device)
                    B, H, W = target_map.shape
                    flat_map = target_map.view(B, H * W)

                    fake_logits, mu, logvar = vae(flat_map, self_prob)
                    loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, flat_map, mu, logvar, KL_BETA)
                    val_loss_total += loss.detach()
                    val_reco_loss_total += reco_loss.detach()
                    val_kl_loss_total += kl_loss.detach()

                    # Only the first 169 positions are map tiles; the sequence carries
                    # one extra position.
                    fake_map = torch.argmax(fake_logits, dim=2)[:, 0:169].view(B, H, W).cpu().numpy()
                    fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
                    real_map = target_map.cpu().numpy()
                    real_img = matrix_to_image_cv(real_map[0], tile_dict)
                    img = np.block([[real_img], [vline], [fake_img]])
                    cv2.imwrite(f"result/ginka_vae_img/{idx}.png", img)
                    idx += 1

                # --- Random sampling (autoregressive decoding) ---
                vae.autoregressive()
                for i in range(0, 8):
                    z = torch.randn(1, LATENT_DIM).to(device)
                    fake_tokens = vae.decoder(z, torch.zeros(1, 169).to(device))
                    fake_map = fake_tokens[:, 0:169].view(-1, 13, 13).cpu().numpy()
                    fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
                    cv2.imwrite(f"result/ginka_vae_img/{i}_rand.png", fake_img)

                # --- Latent interpolation between two random validation maps ---
                val_length = len(dataset_val.data)
                index1 = random.randint(0, val_length - 1)
                index2 = random.randint(0, val_length - 1)
                map1 = torch.LongTensor(dataset_val.data[index1]["map"]).to(device).view(1, 169)
                map2 = torch.LongTensor(dataset_val.data[index2]["map"]).to(device).view(1, 169)
                mu1, logvar1 = vae.encoder(map1)
                mu2, logvar2 = vae.encoder(map2)
                z1 = vae.reparameterize(mu1, logvar1)
                z2 = vae.reparameterize(mu2, logvar2)
                real_img1 = matrix_to_image_cv(map1[0].view(13, 13).cpu().numpy(), tile_dict)
                real_img2 = matrix_to_image_cv(map2[0].view(13, 13).cpu().numpy(), tile_dict)
                for i, t in enumerate(torch.linspace(0, 1, 8)):
                    # t already runs from 0 to 1; dividing it by 8 again (as before)
                    # only ever explored the first eighth of the path from z1 to z2.
                    weight = float(t)
                    z = z1 * (1 - weight) + z2 * weight
                    fake_tokens = vae.decoder(z, torch.zeros(1, 169).to(device))
                    fake_map = fake_tokens[:, 0:169].view(-1, 13, 13).cpu().numpy()
                    fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
                    img = np.block([[real_img1], [vline], [fake_img], [vline], [real_img2]])
                    cv2.imwrite(f"result/ginka_vae_img/{i}_linspace.png", img)

            avg_loss_val = val_loss_total.item() / len(dataloader_val)
            avg_reco_loss_val = val_reco_loss_total.item() / len(dataloader_val)
            avg_kl_loss_val = val_kl_loss_total.item() / len(dataloader_val)
            tqdm.write(
                f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch + 1} | " +
                f"Loss: {avg_loss_val:.6f} | Reco: {avg_reco_loss_val:.6f} | KL: {avg_kl_loss_val:.6f}"
            )

    print("Train ended.")
    torch.save({
        "model_state": vae.state_dict(),
    }, f"result/ginka_transformer.pth")
# Script entry point: cap PyTorch's intra-op CPU thread pool at 4 threads, then start training.
if __name__ == "__main__":
    torch.set_num_threads(4)
    train()

View File

@ -1,270 +0,0 @@
import argparse
import os
import sys
import random
from datetime import datetime
import torch
import torch.nn.functional as F
import torch.optim as optim
import cv2
import numpy as np
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from .vae_rnn.vae import GinkaVAE
from .vae_rnn.loss import VAELoss
from .vae_rnn.scheduler import VAEScheduler
from .dataset import GinkaRNNDataset
from shared.image import matrix_to_image_cv
# 手工标注标签定义(暂时不用):
# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层,
# 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具风
# 16. 区域入口, 17. 区域连接, 18. 有机关门, 19. 道具层, 20. 斜向对称, 21. 左右通道, 22. 上下通道, 23. 多机关门
# 24. 中心对称, 25. 部分对称, 26. 鱼骨
# 自动标注标签定义(暂时不用):
# 0. 左右对称, 1. 上下对称, 2. 中心对称, 3. 斜向对称, 4. 伪对称, 5. 多房间, 6. 多走廊
# 32. 平面塔, 33. 转换塔, 34. 道具塔
# 标量值定义:
# 0. 整体密度,非空白图块/地图面积,空白图块还包括装饰图块
# 1. 墙体密度,墙壁/地图面积
# 2. 装饰密度,装饰数量/地图面积
# 3. 门密度,门数量/地图面积
# 4. 怪物密度,怪物数量/地图面积
# 5. 资源密度,资源数量/地图面积
# 6. 宝石密度,宝石数量/地图面积
# 7. 血瓶密度,血瓶数量/地图面积
# 8. 钥匙密度,钥匙数量/地图面积
# 9. 道具密度,道具数量/地图面积
# 10. 入口数量
# 11. 机关门数量
# 12. 咸鱼门数量(多层咸鱼门只算一个)
# 图块定义:
# 0. 空地, 1. 墙壁, 2. 装饰(用于野外装饰,视为空地),
# 3. 黄门, 4. 蓝门, 5. 红门, 6. 机关门, 其余种类的门如绿门都视为红门
# 7-9. 黄蓝红门钥匙,机关门不使用钥匙开启
# 10-12. 三种等级的红宝石
# 13-15. 三种等级的蓝宝石
# 16-18. 三种等级的绿宝石
# 19-22. 四种等级的血瓶
# 23-25. 三种等级的道具
# 26-28. 三种等级的怪物
# 29. 入口,不区分楼梯和箭头
# ---- Training hyper-parameters ------------------------------------------------
BATCH_SIZE = 128          # training batch size
LATENT_DIM = 64           # size of the VAE latent vector
KL_BETA = 0.01            # weight passed to the loss for the KL term
SELF_GATE = 0.3           # average-loss threshold used to count "good" epochs
GATE_EPOCH = 10           # consecutive "good" epochs required before raising self_prob
VAL_BATCH_DIVIDER = 128   # validation batch size is BATCH_SIZE // VAL_BATCH_DIVIDER
PROB_STEP = 0.05          # increment applied to self_prob when the gate opens

# ---- Device selection ---------------------------------------------------------
if torch.cuda.is_available():
    # The second GPU is still preferred, but a single-GPU machine falls back to
    # "cuda:0" instead of failing later with an invalid device ordinal.
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
elif torch.backends.mps.is_available():
    # torch.backends.mps is the long-standing availability check; torch.mps.is_available
    # does not exist on older PyTorch releases.
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# ---- Output directories -------------------------------------------------------
os.makedirs("result", exist_ok=True)
os.makedirs("result/vae", exist_ok=True)
os.makedirs("result/ginka_vae_img", exist_ok=True)

# Progress bars are disabled when stdout is not a terminal (e.g. redirected to a log file).
disable_tqdm = not sys.stdout.isatty()
def _str2bool(value):
    """Turn a command-line token such as "true", "0" or "no" into a real bool."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace with ``resume``, ``state_ginka``, ``train``, ``validate``,
        ``epochs``, ``checkpoint`` and ``load_optim``.

    Boolean options accept an explicit value (``--resume True`` / ``--resume false``)
    or no value at all (``--resume`` means True). ``type=bool`` is deliberately not
    used: it calls ``bool()`` on the raw string, so ``--resume False`` would be True.
    """
    parser = argparse.ArgumentParser(description="training codes")
    parser.add_argument("--resume", type=_str2bool, nargs="?", const=True, default=False)
    parser.add_argument("--state_ginka", type=str, default="result/vae/ginka-100.pth")
    parser.add_argument("--train", type=str, default="ginka-dataset.json")
    parser.add_argument("--validate", type=str, default="ginka-eval.json")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--checkpoint", type=int, default=5)
    parser.add_argument("--load_optim", type=_str2bool, nargs="?", const=True, default=True)
    args = parser.parse_args()
    return args
def train():
    """Train the VAE, saving checkpoints and preview images along the way.

    Configuration comes from ``parse_arguments()``. ``self_prob`` starts at 0 and is
    raised by ``PROB_STEP`` each time the average training loss has stayed below
    ``SELF_GATE`` for ``GATE_EPOCH`` consecutive epochs. Every ``args.checkpoint``
    epochs the model and optimizer state are written to ``result/vae/`` and preview
    images (reconstructions, random samples, latent interpolation) are written to
    ``result/ginka_vae_img/``. The final weights are stored in ``result/ginka_rnn.pth``.
    """
    print(f"Using {device.type} to train model.")

    args = parse_arguments()

    vae = GinkaVAE(device, latent_dim=LATENT_DIM).to(device)

    dataset = GinkaRNNDataset(args.train, device)
    dataset_val = GinkaRNNDataset(args.validate, device)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True)

    optimizer_ginka = optim.AdamW(vae.parameters(), lr=3e-4, weight_decay=1e-4)
    # Custom scheduler: it is told the current self_prob so it can reset its state and
    # raise the learning rate when self_prob increases.
    scheduler_ginka = VAEScheduler(
        optimizer_ginka, factor=0.9, increase_factor=1.5, patience=20, max_lr=3e-4, min_lr=1e-6
    )

    criterion = VAELoss()

    self_prob = 0
    prob_epochs = 0

    # Tile images used to render maps as pictures, keyed by file name without extension.
    tile_dict = dict()
    for file in os.listdir('tiles'):
        name = os.path.splitext(file)[0]
        tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)

    if args.resume:
        data_ginka = torch.load(args.state_ginka, map_location=device)
        vae.load_state_dict(data_ginka["model_state"], strict=False)
        if args.load_optim:
            if data_ginka.get("optim_state") is not None:
                optimizer_ginka.load_state_dict(data_ginka["optim_state"])
        print("Train from loaded state.")

    for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm):
        loss_total = torch.Tensor([0]).to(device)
        reco_loss_total = torch.Tensor([0]).to(device)
        kl_loss_total = torch.Tensor([0]).to(device)

        vae.train()
        for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
            target_map = batch["target_map"].to(device)

            optimizer_ginka.zero_grad()
            fake_logits, mu, logvar = vae(target_map, self_prob)
            loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)
            optimizer_ginka.step()

            loss_total += loss.detach()
            reco_loss_total += reco_loss.detach()
            kl_loss_total += kl_loss.detach()

        avg_loss = loss_total.item() / len(dataloader)
        avg_reco_loss = reco_loss_total.item() / len(dataloader)
        avg_kl_loss = kl_loss_total.item() / len(dataloader)
        tqdm.write(
            f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
            f"E: {epoch + 1} | Loss: {avg_loss:.6f} | Reco: {avg_reco_loss:.6f} | " +
            f"KL: {avg_kl_loss:.6f} | Prob: {self_prob:.2f} | LR: {scheduler_ginka.get_last_lr()[0]:.6f}"
        )

        # The gate uses the training loss for now: per the original author's note the
        # model overfits heavily, so per-epoch validation gating is disabled.
        if avg_loss < SELF_GATE:
            prob_epochs += 1
        else:
            prob_epochs = 0

        if prob_epochs >= GATE_EPOCH and self_prob < 1:
            self_prob += PROB_STEP
            prob_epochs = 0
        if self_prob > 1:
            self_prob = 1

        scheduler_ginka.step(avg_loss, self_prob)

        # Every `args.checkpoint` epochs: save a checkpoint and render preview images.
        if (epoch + 1) % args.checkpoint == 0:
            vae.eval()
            # Saved under result/vae: that is the directory this script creates and the
            # one the --state_ginka default reads from (the old "result/rnn" target is
            # never created by this script).
            torch.save({
                "model_state": vae.state_dict(),
                "optim_state": optimizer_ginka.state_dict(),
            }, f"result/vae/ginka-{epoch + 1}.pth")

            val_loss_total = torch.Tensor([0]).to(device)
            val_reco_loss_total = torch.Tensor([0]).to(device)
            val_kl_loss_total = torch.Tensor([0]).to(device)

            with torch.no_grad():
                idx = 0
                gap = 5
                color = (255, 255, 255)  # white
                vline = np.full((416, gap, 3), color, dtype=np.uint8)  # vertical separator

                # --- Reconstruction preview ---
                for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
                    target_map = batch["target_map"].to(device)

                    fake_logits, mu, logvar = vae(target_map, self_prob)
                    loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
                    val_loss_total += loss.detach()
                    val_reco_loss_total += reco_loss.detach()
                    val_kl_loss_total += kl_loss.detach()

                    fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy()
                    fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
                    real_map = target_map.cpu().numpy()
                    real_img = matrix_to_image_cv(real_map[0], tile_dict)
                    img = np.block([[real_img], [vline], [fake_img]])
                    cv2.imwrite(f"result/ginka_vae_img/{idx}.png", img)
                    idx += 1

                # --- Random sampling ---
                for i in range(0, 8):
                    z = torch.randn(1, LATENT_DIM).to(device)
                    fake_logits = vae.decoder(z, torch.zeros(1, 13, 13).to(device), 1)
                    fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy()
                    fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
                    cv2.imwrite(f"result/ginka_vae_img/{i}_rand.png", fake_img)

                # --- Latent interpolation between two random validation maps ---
                val_length = len(dataset_val.data)
                index1 = random.randint(0, val_length - 1)
                index2 = random.randint(0, val_length - 1)
                map1 = torch.LongTensor(dataset_val.data[index1]["map"]).to(device).reshape(1, 13, 13)
                map2 = torch.LongTensor(dataset_val.data[index2]["map"]).to(device).reshape(1, 13, 13)
                mu1, logvar1 = vae.encoder(map1)
                mu2, logvar2 = vae.encoder(map2)
                z1 = vae.reparameterize(mu1, logvar1)
                z2 = vae.reparameterize(mu2, logvar2)
                real_img1 = matrix_to_image_cv(map1[0].cpu().numpy(), tile_dict)
                real_img2 = matrix_to_image_cv(map2[0].cpu().numpy(), tile_dict)
                for i, t in enumerate(torch.linspace(0, 1, 8)):
                    # t already runs from 0 to 1; dividing it by 8 again (as before)
                    # only ever explored the first eighth of the path from z1 to z2.
                    weight = float(t)
                    z = z1 * (1 - weight) + z2 * weight
                    fake_logits = vae.decoder(z, torch.zeros(1, 13, 13).to(device), 1)
                    fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy()
                    fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
                    img = np.block([[real_img1], [vline], [fake_img], [vline], [real_img2]])
                    cv2.imwrite(f"result/ginka_vae_img/{i}_linspace.png", img)

            avg_loss_val = val_loss_total.item() / len(dataloader_val)
            avg_reco_loss_val = val_reco_loss_total.item() / len(dataloader_val)
            avg_kl_loss_val = val_kl_loss_total.item() / len(dataloader_val)
            tqdm.write(
                f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch + 1} | " +
                f"Loss: {avg_loss_val:.6f} | Reco: {avg_reco_loss_val:.6f} | KL: {avg_kl_loss_val:.6f}"
            )

    print("Train ended.")
    # NOTE(review): this final file name is the same one the sibling RNN training
    # script in this chunk writes to -- confirm the overwrite is intended.
    torch.save({
        "model_state": vae.state_dict(),
    }, f"result/ginka_rnn.pth")
# Script entry point: cap PyTorch's intra-op CPU thread pool at 4 threads, then start training.
if __name__ == "__main__":
    torch.set_num_threads(4)
    train()

View File

@ -1,428 +0,0 @@
import argparse
import os
import sys
from datetime import datetime
import torch
import torch.optim as optim
import torch.nn.functional as F
import cv2
import numpy as np
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from .generator.model import GinkaModel
from .dataset import GinkaWGANDataset
from .generator.loss import WGANGinkaLoss
from .critic.model import MinamoModel2
from shared.image import matrix_to_image_cv
# 手工标注标签定义:
# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层,
# 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具风
# 16. 区域入口, 17. 区域连接, 18. 有机关门, 19. 道具层, 20. 斜向对称, 21. 左右通道, 22. 上下通道, 23. 多机关门
# 24. 中心对称, 25. 部分对称, 26. 鱼骨
# 自动标注标签定义:
# 0. 左右对称, 1. 上下对称, 2. 中心对称, 3. 斜向对称, 4. 伪对称, 5. 多房间, 6. 多走廊
# 32. 平面塔, 33. 转换塔, 34. 道具塔
# 标量值定义:
# 0. 整体密度,非空白图块/地图面积,空白图块还包括装饰图块
# 1. 墙体密度,墙壁/地图面积
# 2. 装饰密度,装饰数量/地图面积
# 3. 门密度,门数量/地图面积
# 4. 怪物密度,怪物数量/地图面积
# 5. 资源密度,资源数量/地图面积
# 6. 宝石密度,宝石数量/地图面积
# 7. 血瓶密度,血瓶数量/地图面积
# 8. 钥匙密度,钥匙数量/地图面积
# 9. 道具密度,道具数量/地图面积
# 10. 入口数量
# 11. 机关门数量
# 12. 咸鱼门数量(多层咸鱼门只算一个)
# 图块定义:
# 0. 空地, 1. 墙壁, 2. 装饰(用于野外装饰,视为空地),
# 3. 黄门, 4. 蓝门, 5. 红门, 6. 机关门, 其余种类的门如绿门都视为红门
# 7-9. 黄蓝红门钥匙,机关门不使用钥匙开启
# 10-12. 三种等级的红宝石
# 13-15. 三种等级的蓝宝石
# 16-18. 三种等级的绿宝石
# 19-22. 四种等级的血瓶
# 23-25. 三种等级的道具
# 26-28. 三种等级的怪物
# 29. 入口,不区分楼梯和箭头
# Samples per training/validation batch.
BATCH_SIZE = 6

# Use CUDA when it is available, otherwise the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# Make sure the output directories exist before anything is written to them.
for _out_dir in ("result", "result/wgan"):
    os.makedirs(_out_dir, exist_ok=True)

# Progress bars are turned off when stdout is not an interactive terminal.
disable_tqdm = not sys.stdout.isatty()
def _str2bool(value):
    """Turn a command-line token such as "true", "0" or "no" into a real bool."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments():
    """Parse the command-line options of the WGAN training script.

    Returns:
        argparse.Namespace with ``resume``, ``state_ginka``, ``state_minamo``,
        ``train``, ``validate``, ``epochs``, ``checkpoint``, ``load_optim``,
        ``curr_epoch`` and ``tuning``.

    Boolean options accept an explicit value (``--resume True`` / ``--resume false``)
    or no value at all (``--resume`` means True). ``type=bool`` is deliberately not
    used: it calls ``bool()`` on the raw string, so ``--resume False`` would be True.
    """
    parser = argparse.ArgumentParser(description="training codes")
    parser.add_argument("--resume", type=_str2bool, nargs="?", const=True, default=False)
    parser.add_argument("--state_ginka", type=str, default="result/wgan/ginka-100.pth")
    parser.add_argument("--state_minamo", type=str, default="result/wgan/minamo-100.pth")
    parser.add_argument("--train", type=str, default="ginka-dataset.json")
    parser.add_argument("--validate", type=str, default="ginka-eval.json")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--checkpoint", type=int, default=5)
    parser.add_argument("--load_optim", type=_str2bool, nargs="?", const=True, default=True)
    # Minimum number of epochs spent on each curriculum step.
    parser.add_argument("--curr_epoch", type=int, default=20)
    parser.add_argument("--tuning", type=_str2bool, nargs="?", const=True, default=False)
    args = parser.parse_args()
    return args
def gen_curriculum(gen, masked1, masked2, masked3, tag, val, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run generator stages 1, 2 and 3, each on its own masked input.

    Args:
        gen: callable invoked as ``gen(x, stage, tag, val)``.
        masked1, masked2, masked3: inputs for stage 1, 2 and 3 respectively.
        tag, val: conditioning values forwarded unchanged to every call.
        detach: when true, the three outputs are detached from the autograd graph.

    Returns:
        The three stage outputs, in stage order.
    """
    # tuple() consumes the generator eagerly, so the stages still run in order 1, 2, 3.
    outputs = tuple(
        gen(masked, stage, tag, val)
        for stage, masked in enumerate((masked1, masked2, masked3), start=1)
    )
    if not detach:
        return outputs
    return tuple(out.detach() for out in outputs)
def gen_total(gen, input, tag, val, progress_detach=True, result_detach=False, random=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the generator stages as a chain, each stage consuming the previous output.

    Args:
        gen: callable invoked as ``gen(x, stage, tag, val)``.
        input: the starting tensor.
        tag, val: conditioning values forwarded unchanged to every call.
        progress_detach: when true, each stage receives a detached copy of the
            previous stage's output, so gradients do not flow between stages.
        result_detach: when true, all returned tensors are detached.
        random: when true, stage 0 is run on ``input`` first and its channel-softmax
            becomes the input of stage 1; otherwise ``input`` is fed to stage 1 as is.

    Returns:
        ``(fake1, fake2, fake3, fake0)`` -- note that the stage-0 tensor comes last.
        ``fake0`` is the raw stage-0 output when ``random`` is true and ``input``
        itself otherwise.
    """
    if random:
        fake0 = gen(input, 0, tag, val)
        x_in = F.softmax(fake0, dim=1)
    else:
        fake0 = input
        x_in = input

    def _link(tensor):
        # Optionally cut the autograd graph between consecutive stages.
        return tensor.detach() if progress_detach else tensor

    fake1 = gen(_link(x_in), 1, tag, val)
    fake2 = gen(F.softmax(_link(fake1), dim=1), 2, tag, val)
    fake3 = gen(F.softmax(_link(fake2), dim=1), 3, tag, val)

    if result_detach:
        return fake1.detach(), fake2.detach(), fake3.detach(), fake0.detach()
    return fake1, fake2, fake3, fake0
def train():
    """Adversarially train the generator (ginka) against the critic (minamo).

    Each batch runs ``c_steps`` critic updates followed by ``g_steps`` generator
    updates. In stage 1 the mask ratio grows from 0.2 in steps of 0.2 and the run
    jumps to stage 4 once it reaches 0.8. Every ``args.checkpoint`` epochs both
    models are checkpointed to ``result/wgan/`` and preview images are written to
    ``result/ginka_img/``. Final weights go to ``result/ginka.pth`` and
    ``result/minamo.pth``.
    """
    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")

    args = parse_arguments()

    c_steps = 2
    g_steps = 1
    train_stage = 1   # current training stage
    mask_ratio = 0.2  # size of the masked region
    stage_epoch = 0   # epochs spent in the current stage/step, drives the curriculum
    total_epoch = 0

    ginka = GinkaModel().to(device)
    minamo = MinamoModel2().to(device)

    dataset = GinkaWGANDataset(args.train, device)
    dataset_val = GinkaWGANDataset(args.validate, device)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)

    optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
    optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-4, betas=(0.0, 0.9))

    scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2)
    scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2)

    criterion = WGANGinkaLoss()

    # Tile images used to render maps as pictures, keyed by file name without extension.
    tile_dict = dict()
    for file in os.listdir('tiles'):
        name = os.path.splitext(file)[0]
        tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)

    if args.resume:
        data_ginka = torch.load(args.state_ginka, map_location=device)
        data_minamo = torch.load(args.state_minamo, map_location=device)

        ginka.load_state_dict(data_ginka["model_state"], strict=False)
        minamo.load_state_dict(data_minamo["model_state"], strict=False)

        # Restoring c_steps / g_steps was commented out in the original and stays disabled.
        if data_ginka.get("mask_ratio") is not None:
            mask_ratio = data_ginka["mask_ratio"]
        if data_ginka.get("stage_epoch") is not None:
            stage_epoch = data_ginka["stage_epoch"]
        if data_ginka.get("stage") is not None:
            train_stage = data_ginka["stage"]
        if data_ginka.get("total_epoch") is not None:
            # Was data_ginka["data_ginka"], a key that is never written (KeyError).
            total_epoch = data_ginka["total_epoch"]

        if args.load_optim:
            if data_ginka.get("optim_state") is not None:
                optimizer_ginka.load_state_dict(data_ginka["optim_state"])
            if data_minamo.get("optim_state") is not None:
                optimizer_minamo.load_state_dict(data_minamo["optim_state"])

        print("Train from loaded state.")

    curr_epoch = args.curr_epoch
    first_curr = curr_epoch * 3  # the first curriculum step lasts three times as long

    if args.tuning:
        # Fine-tuning restarts the curriculum from the beginning with 4x shorter steps.
        train_stage = 1
        curr_epoch = curr_epoch // 4
        first_curr = first_curr // 4
        stage_epoch = 0
        mask_ratio = 0.2

    dataset.train_stage = train_stage
    dataset.mask_ratio1 = mask_ratio
    dataset.mask_ratio2 = mask_ratio
    dataset.mask_ratio3 = mask_ratio
    dataset_val.train_stage = train_stage
    dataset_val.mask_ratio1 = mask_ratio
    dataset_val.mask_ratio2 = mask_ratio
    dataset_val.mask_ratio3 = mask_ratio

    for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm):
        loss_total_minamo = torch.Tensor([0]).to(device)
        loss_total_ginka = torch.Tensor([0]).to(device)
        dis_total = torch.Tensor([0]).to(device)
        loss_ce_total = torch.Tensor([0]).to(device)

        iters = 0

        for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
            rand = batch["rand"].to(device)
            real0 = batch["real0"].to(device)
            real1 = batch["real1"].to(device)
            masked1 = batch["masked1"].to(device)
            real2 = batch["real2"].to(device)
            masked2 = batch["masked2"].to(device)
            real3 = batch["real3"].to(device)
            masked3 = batch["masked3"].to(device)
            tag_cond = batch["tag_cond"].to(device)
            val_cond = batch["val_cond"].to(device)

            # ---------- Critic updates
            for _ in range(c_steps):
                optimizer_minamo.zero_grad()
                optimizer_ginka.zero_grad()

                # Fake samples for the critic; no generator gradients are needed here.
                with torch.no_grad():
                    if train_stage == 1 or train_stage == 2:
                        fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)
                    elif train_stage == 3 or train_stage == 4:
                        fake1, fake2, fake3, fake0 = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)

                    if train_stage < 4:
                        fake0 = ginka(rand, 0, tag_cond, val_cond)

                loss_d0, dis0 = criterion.discriminator_loss(minamo, 0, real0, fake0, tag_cond, val_cond)
                loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1, tag_cond, val_cond)
                loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2, tag_cond, val_cond)
                loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3, tag_cond, val_cond)

                dis = [dis0, dis1, dis2, dis3]
                loss_d = [loss_d0, loss_d1, loss_d2, loss_d3]
                dis_avg = sum(dis) / len(dis)
                loss_d_avg = sum(loss_d) / len(loss_d)

                loss_d_avg.backward()
                optimizer_minamo.step()

                loss_total_minamo += loss_d_avg.detach()
                dis_total += dis_avg.detach()

            # ---------- Generator updates
            for _ in range(g_steps):
                optimizer_minamo.zero_grad()
                optimizer_ginka.zero_grad()

                if train_stage == 1 or train_stage == 2:
                    fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, False)

                    loss_g1, loss_ce_g1 = criterion.generator_loss(minamo, 1, mask_ratio, real1, fake1, masked1, tag_cond, val_cond)
                    loss_g2, loss_ce_g2 = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2, tag_cond, val_cond)
                    loss_g3, loss_ce_g3 = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3, tag_cond, val_cond)

                    # Stage 1 is weighted three times as heavily as stages 2 and 3.
                    loss_g = (loss_g1 * 3.0 + loss_g2 + loss_g3) / 5.0
                    loss_ce = max(loss_ce_g1, loss_ce_g2, loss_ce_g3)
                    loss_ce_total += loss_ce.detach()

                elif train_stage == 3 or train_stage == 4:
                    fake1, fake2, fake3, fake0 = gen_total(ginka, masked1, tag_cond, val_cond, True, False, train_stage == 4)
                    if train_stage == 4:
                        fake0 = F.softmax(fake0, dim=1)

                    loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, fake0, tag_cond, val_cond)
                    loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1, tag_cond, val_cond)
                    loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2, tag_cond, val_cond)

                    loss_g = (loss_g1 * 3.0 + loss_g2 + loss_g3) / 5.0

                if train_stage < 4:
                    fake0 = F.softmax(ginka(rand, 0, tag_cond, val_cond), dim=1)

                # NOTE(review): the source lost its indentation, so whether the stage-0
                # head loss is applied only when train_stage < 4 or in every stage cannot
                # be read off directly. It is applied in every stage here, mirroring the
                # critic step above where the stage-0 loss is always computed -- confirm.
                loss_g0 = criterion.generator_input_head_loss(minamo, fake0, tag_cond, val_cond)
                loss_g += loss_g0

                loss_g.backward()
                optimizer_ginka.step()
                loss_total_ginka += loss_g.detach()

            iters += 1

            # Intermediate report every 50 iterations (running averages so far).
            if iters % 50 == 0:
                avg_loss_ginka = loss_total_ginka.item() / iters / g_steps
                avg_loss_minamo = loss_total_minamo.item() / iters / c_steps
                avg_loss_ce = loss_ce_total.item() / iters / g_steps
                avg_dis = dis_total.item() / iters / c_steps
                tqdm.write(
                    f"[Iters {iters} {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
                    f"E: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
                    f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " +
                    f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | " +
                    f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
                )

        avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
        avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
        avg_loss_ce = loss_ce_total.item() / len(dataloader) / g_steps
        avg_dis = dis_total.item() / len(dataloader) / c_steps
        tqdm.write(
            f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
            f"E: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
            f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " +
            f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | " +
            f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
        )

        # Every `args.checkpoint` epochs: save checkpoints and render preview images.
        if (epoch + 1) % args.checkpoint == 0:
            torch.save({
                "model_state": ginka.state_dict(),
                "optim_state": optimizer_ginka.state_dict(),
                "c_steps": c_steps,
                "g_steps": g_steps,
                "stage": train_stage,
                "mask_ratio": mask_ratio,
                "stage_epoch": stage_epoch,
                # Stored so that the resume branch can actually restore it.
                "total_epoch": total_epoch,
            }, f"result/wgan/ginka-{epoch + 1}.pth")
            torch.save({
                "model_state": minamo.state_dict(),
                "optim_state": optimizer_minamo.state_dict()
            }, f"result/wgan/minamo-{epoch + 1}.pth")

            # cv2.imwrite does not create missing directories (it just returns False),
            # so make sure the preview directory exists.
            os.makedirs("result/ginka_img", exist_ok=True)

            idx = 0
            gap = 5
            color = (255, 255, 255)  # white

            with torch.no_grad():
                for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
                    real1 = batch["real1"].to(device)
                    masked1 = batch["masked1"].to(device)
                    real2 = batch["real2"].to(device)
                    masked2 = batch["masked2"].to(device)
                    real3 = batch["real3"].to(device)
                    masked3 = batch["masked3"].to(device)
                    tag_cond = batch["tag_cond"].to(device)
                    val_cond = batch["val_cond"].to(device)

                    if train_stage == 1 or train_stage == 2:
                        fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)
                    elif train_stage == 3 or train_stage == 4:
                        fake1, fake2, fake3, fake0 = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
                        # fake0 is only produced (and only drawn) in stages 3 and 4.
                        fake0 = torch.argmax(fake0, dim=1).cpu().numpy()

                    fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
                    fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
                    fake3 = torch.argmax(fake3, dim=1).cpu().numpy()
                    masked1 = torch.argmax(masked1, dim=1).cpu().numpy()
                    masked2 = torch.argmax(masked2, dim=1).cpu().numpy()
                    masked3 = torch.argmax(masked3, dim=1).cpu().numpy()

                    for i in range(fake1.shape[0]):
                        fake1_img = matrix_to_image_cv(fake1[i], tile_dict)
                        fake2_img = matrix_to_image_cv(fake2[i], tile_dict)
                        fake3_img = matrix_to_image_cv(fake3[i], tile_dict)

                        if train_stage == 1 or train_stage == 2:
                            # Top row: the three masked inputs; bottom row: the three outputs.
                            vline = np.full((416, gap, 3), color, dtype=np.uint8)  # vertical separator
                            hline = np.full((gap, 3 * 416 + gap * 2, 3), color, dtype=np.uint8)  # horizontal separator
                            in1_img = matrix_to_image_cv(masked1[i], tile_dict)
                            in2_img = matrix_to_image_cv(masked2[i], tile_dict)
                            in3_img = matrix_to_image_cv(masked3[i], tile_dict)
                            img = np.block([
                                [[in1_img], [vline], [in2_img], [vline], [in3_img]],
                                [[hline]],
                                [[fake1_img], [vline], [fake2_img], [vline], [fake3_img]]
                            ])
                        elif train_stage == 3 or train_stage == 4:
                            # 2x2 grid: stage-0 input and the three chained stage outputs.
                            vline = np.full((416, gap, 3), color, dtype=np.uint8)  # vertical separator
                            hline = np.full((gap, 2 * 416 + gap, 3), color, dtype=np.uint8)  # horizontal separator
                            in_img = matrix_to_image_cv(fake0[i], tile_dict)
                            img = np.block([
                                [[in_img], [vline], [fake1_img]],
                                [[hline]],
                                [[fake2_img], [vline], [fake3_img]]
                            ])

                        cv2.imwrite(f"result/ginka_img/{idx}.png", img)
                        idx += 1

        # ---------- Curriculum control
        # (An older, epoch-number based stage rotation was commented out in the original
        # and is not reproduced here.)
        if train_stage == 1:
            if (mask_ratio < 0.3 and stage_epoch >= first_curr) or \
               (mask_ratio > 0.3 and stage_epoch >= curr_epoch):
                mask_ratio += 0.2
                mask_ratio = min(mask_ratio, 0.8)
                stage_epoch = 0
            if mask_ratio >= 0.8:
                train_stage = 4

        stage_epoch += 1
        total_epoch += 1

        dataset.train_stage = train_stage
        dataset_val.train_stage = train_stage
        dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
        dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio

        scheduler_ginka.step()
        scheduler_minamo.step()

    print("Train ended.")
    torch.save({
        "model_state": ginka.state_dict(),
    }, f"result/ginka.pth")
    torch.save({
        "model_state": minamo.state_dict(),
    }, f"result/minamo.pth")
# Script entry point: cap PyTorch's intra-op CPU thread pool at 4 threads, then start training.
if __name__ == "__main__":
    torch.set_num_threads(4)
    train()

View File

@ -1,106 +0,0 @@
import time
import torch
import torch.nn as nn
from ..utils import print_memory
class GinkaTransformerDecoder(nn.Module):
    """Transformer decoder that turns a conditioning vector into a token sequence.

    The conditioning vector is passed through a small transformer encoder and used as
    cross-attention memory. With ``self.autoregressive`` False (the default) the module
    runs with teacher forcing and returns logits; with it True the module generates
    tokens greedily, one position at a time.

    Args:
        num_classes: vocabulary size of the embedding and of the output layer.
        dim_ff: model width (also used as the feed-forward width).
        nhead: number of attention heads.
        num_layers: decoder layers; the memory encoder uses half as many (at least 1).
        map_size: maximum number of map tokens; positions 0..map_size are embedded.
        start_token: id of the token prepended to every sequence. Defaults to 31 when
            the vocabulary is large enough (the previously hard-coded value) and to
            ``num_classes - 1`` otherwise, so small vocabularies no longer index past
            the embedding table.
    """

    def __init__(self, num_classes=32, dim_ff=256, nhead=4, num_layers=4, map_size=13*13, start_token=None):
        super().__init__()
        self.autoregressive = False
        self.dim_ff = dim_ff
        self.map_size = map_size

        if start_token is None:
            start_token = min(31, num_classes - 1)
        if not 0 <= start_token < num_classes:
            raise ValueError(f"start_token {start_token} is outside [0, {num_classes})")
        self.start_token = start_token

        self.embedding = nn.Embedding(num_classes, dim_ff)
        self.pos_embedding = nn.Embedding(map_size + 1, dim_ff)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim_ff, dim_feedforward=dim_ff, nhead=nhead, batch_first=True),
            num_layers=max(num_layers // 2, 1)
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=dim_ff, dim_feedforward=dim_ff, nhead=nhead, batch_first=True),
            num_layers=num_layers
        )
        self.fc = nn.Sequential(
            nn.Linear(dim_ff, num_classes)
        )

    def _decode(self, tokens: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        """Return logits [B, T, num_classes] for a token prefix [B, T] under a causal mask."""
        length = tokens.shape[1]
        positions = torch.arange(length, dtype=torch.long, device=tokens.device)
        hidden = self.embedding(tokens) + self.pos_embedding(positions)
        # In a boolean attention mask True means "may NOT attend". diagonal=1 keeps the
        # diagonal unmasked so every position sees itself and everything before it;
        # including the diagonal (as before) left the first row with nothing to attend to.
        causal = torch.triu(
            torch.ones(length, length, dtype=torch.bool, device=tokens.device), diagonal=1
        )
        return self.fc(self.decoder(hidden, memory, tgt_mask=causal))

    def forward(self, z: torch.Tensor, target_map: torch.Tensor):
        """Decode ``z`` into a map.

        Args:
            z: [B, dim_ff] conditioning vectors.
            target_map: [B, L] token ids. In autoregressive mode only its shape is used.

        Returns:
            Teacher forcing: logits of shape [B, L + 1, num_classes]; position ``i``
            is computed from the start token and ``target_map[:, :i]``.
            Autoregressive: int tensor of shape [B, L + 1] whose first ``L`` entries
            are the generated tokens (the last entry is left at 0).
        """
        B, L = target_map.shape
        memory = self.encoder(z.unsqueeze(1))  # [B, 1, dim_ff]
        start = torch.full((B, 1), self.start_token, dtype=torch.long, device=z.device)

        # Teacher forcing: feed the ground-truth sequence shifted right by the start token.
        if not self.autoregressive:
            tokens = torch.cat([start, target_map.long()], dim=1)  # [B, L + 1]
            return self._decode(tokens, memory)

        # Greedy generation. The prefix fed at step idx is [start, g_0 .. g_{idx-1}] and the
        # logits at its last position pick g_idx -- the same alignment teacher forcing uses.
        output = torch.zeros([B, L + 1], dtype=torch.int, device=z.device)
        for idx in range(L):
            prefix = torch.cat([start, output[:, :idx].long()], dim=1)  # [B, idx + 1]
            logits = self._decode(prefix, memory)
            output[:, idx] = torch.argmax(logits[:, -1, :], dim=1)
        return output
class GinkaTransformerVAEDecoder(nn.Module):
    """VAE decoder head: projects a latent vector to model width, then decodes a map.

    Args:
        latent_dim: size of the latent vector ``z``.
        num_classes, dim_ff, nhead, num_layers, map_size: forwarded unchanged to
            ``GinkaTransformerDecoder``.
    """

    def __init__(
        self, latent_dim=32, num_classes=32, dim_ff=256, nhead=4, num_layers=4,
        map_size=13*13
    ):
        super().__init__()
        self.map_size = map_size

        # Latent -> dim_ff projection (linear, dropout, layer norm, ReLU, linear).
        projection_layers = [
            nn.Linear(latent_dim, dim_ff),
            nn.Dropout(0.3),
            nn.LayerNorm(dim_ff),
            nn.ReLU(),
            nn.Linear(dim_ff, dim_ff),
        ]
        self.input = nn.Sequential(*projection_layers)

        decoder_kwargs = dict(
            num_classes=num_classes,
            dim_ff=dim_ff,
            nhead=nhead,
            num_layers=num_layers,
            map_size=map_size,
        )
        self.decoder = GinkaTransformerDecoder(**decoder_kwargs)

    def forward(self, z: torch.Tensor, map: torch.Tensor):
        """Project ``z`` and hand it, together with ``map``, to the transformer decoder."""
        return self.decoder(self.input(z), map)
if __name__ == "__main__":
    # Smoke test: one forward pass on random CPU inputs, printing timing,
    # output shape and parameter counts.
    device = torch.device("cpu")
    input = torch.randn(1, 32).to(device)
    map = torch.randint(0, 32, [1, 169]).to(device)
    # Build the model
    model = GinkaTransformerVAEDecoder().to(device)
    # NOTE(review): print_memory is imported from ..utils; its signature is
    # not visible here — confirm it accepts a single label argument.
    print_memory("初始化后")
    # Forward pass
    start = time.perf_counter()
    output = model(input, map)
    end = time.perf_counter()
    print_memory("前向传播后")
    print(f"推理耗时: {end - start}")
    print(f"输出形状: output={output.shape}")
    print(f"Input Embedding parameters: {sum(p.numel() for p in model.input.parameters())}")
    print(f"Decoder parameters: {sum(p.numel() for p in model.decoder.parameters())}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

View File

@ -1,97 +0,0 @@
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import print_memory
class GinkaTransformerEncoder(nn.Module):
    """Encode a token sequence and pool it into one vector per sample.

    The sequence goes through a transformer encoder; a single query token
    then cross-attends to the encoded sequence through a (shallower)
    transformer decoder, and its output is the pooled representation.
    """

    def __init__(self, dim_ff=256, nhead=4, num_layers=4):
        super().__init__()
        self.dim_ff = dim_ff
        layer_kwargs = dict(
            d_model=dim_ff, dim_feedforward=dim_ff, nhead=nhead,
            batch_first=True, activation=F.gelu,
        )
        encoder_layer = nn.TransformerEncoderLayer(**layer_kwargs)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        decoder_layer = nn.TransformerDecoderLayer(**layer_kwargs)
        # The pooling decoder uses half as many layers, but at least one.
        self.decoder = nn.TransformerDecoder(
            decoder_layer, num_layers=max(num_layers // 2, 1)
        )

    def forward(self, x: torch.Tensor):
        """Map ``x`` of shape [B, H * W, S] to a pooled tensor of shape [B, dim_ff]."""
        batch, _, _ = x.shape
        # The query token is freshly sampled Gaussian noise on every call.
        query = torch.randn(batch, 1, self.dim_ff).to(x.device)
        memory = self.encoder(x)
        pooled = self.decoder(query, memory)
        return pooled.squeeze(1)
class GinkaTransformerBottleneck(nn.Module):
    """VAE bottleneck: shared hidden layer followed by mean and log-variance heads."""

    def __init__(self, dim_ff=256, hidden_dim=512, latent_dim=32):
        super().__init__()
        shared_layers = [
            nn.Linear(dim_ff, hidden_dim),
            nn.Dropout(0.3),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
        ]
        self.fc = nn.Sequential(*shared_layers)
        self.fc_mu = nn.Sequential(nn.Linear(hidden_dim, latent_dim))
        self.fc_logvar = nn.Sequential(nn.Linear(hidden_dim, latent_dim))

    def forward(self, x):
        """Return ``(mu, logvar)``, each [B, latent_dim], for ``x`` of shape [B, dim_ff]."""
        shared = self.fc(x)
        return self.fc_mu(shared), self.fc_logvar(shared)
class GinkaTransformerVAEEncoder(nn.Module):
    """Encoder half of the transformer VAE: tile ids -> ``(mu, logvar)``."""

    def __init__(
        self, num_classes=32, latent_dim=32, bottleneck_dim=512, dim_ff=256,
        nhead=4, num_layers=4, map_size=13*13
    ):
        super().__init__()
        self.map_size = map_size
        self.embedding = nn.Embedding(num_classes, dim_ff)
        self.pos_embedding = nn.Embedding(map_size, dim_ff)
        self.encoder = GinkaTransformerEncoder(
            dim_ff=dim_ff, nhead=nhead, num_layers=num_layers
        )
        self.bottleneck = GinkaTransformerBottleneck(
            dim_ff=dim_ff, hidden_dim=bottleneck_dim, latent_dim=latent_dim
        )

    def forward(self, x: torch.Tensor):
        """Encode ``x`` of shape [B, map_size] and return the bottleneck's ``(mu, logvar)``."""
        positions = torch.arange(self.map_size, dtype=torch.long).to(x.device)
        # Token embedding plus learned positional embedding (broadcast over B).
        tokens = self.embedding(x) + self.pos_embedding(positions)
        return self.bottleneck(self.encoder(tokens))
if __name__ == "__main__":
    # Smoke test: encode one random 169-tile map on the CPU and print
    # timing, output shapes and parameter counts.
    device = torch.device("cpu")
    input = torch.randint(0, 32, [1, 169]).to(device)
    # Build the model
    model = GinkaTransformerVAEEncoder().to(device)
    # NOTE(review): print_memory comes from ..utils; its signature is not
    # visible here — confirm it accepts a single label argument.
    print_memory("初始化后")
    # Forward pass
    start = time.perf_counter()
    mu, logvar = model(input)
    end = time.perf_counter()
    print_memory("前向传播后")
    print(f"推理耗时: {end - start}")
    print(f"输出形状: mu={mu.shape}, logvar={logvar.shape}")
    print(f"Embedding parameters: {sum(p.numel() for p in model.embedding.parameters())}")
    print(f"Position Embedding parameters: {sum(p.numel() for p in model.pos_embedding.parameters())}")
    print(f"Encoder parameters: {sum(p.numel() for p in model.encoder.parameters())}")
    print(f"bottleneck parameters: {sum(p.numel() for p in model.bottleneck.parameters())}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

View File

@ -1,22 +0,0 @@
import torch
import torch.nn as nn
class FSQ(nn.Module):
    """Scalar quantizer with a straight-through gradient.

    The input is squashed with ``tanh`` and snapped to the grid
    ``k / scale`` (``k`` integer, ``scale = (levels - 1) / 2``).

    Args:
        levels: number of quantization levels; must be at least 2, because
            a single level makes ``scale`` zero and the output NaN.

    Raises:
        ValueError: if ``levels`` is smaller than 2.
    """

    def __init__(self, levels=7):
        super().__init__()
        if levels < 2:
            raise ValueError(f"levels must be >= 2, got {levels}")
        self.levels = levels
        self.scale = (levels - 1) / 2
        # NOTE(review): for an even `levels` the grid k / scale does not have
        # exactly `levels` points inside (-1, 1) — confirm odd values are intended.

    def forward(self, z):
        """Quantize ``z`` element-wise; the result has the same shape as ``z``."""
        # Bound the values to (-1, 1).
        z = torch.tanh(z)
        # Snap to the nearest grid point.
        z_q = torch.round(z * self.scale) / self.scale
        # Straight-through estimator: forward value is z_q, gradient is that of z.
        return z + (z_q - z).detach()

View File

@ -1,54 +0,0 @@
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from .encoder import GinkaTransformerVAEEncoder
from .decoder import GinkaTransformerVAEDecoder
from ..utils import print_memory
class GinkaTransformerVAE(nn.Module):
    """Transformer VAE over flattened tile maps.

    Args:
        num_classes: number of tile classes; forwarded to BOTH the encoder
            and the decoder so the two always agree.
        latent_dim: size of the latent vector.
    """

    def __init__(self, num_classes=32, latent_dim=32):
        super().__init__()
        self.encoder = GinkaTransformerVAEEncoder(num_classes=num_classes, latent_dim=latent_dim)
        # num_classes used to be dropped here, leaving the decoder at its
        # default of 32 classes regardless of the requested value.
        self.decoder = GinkaTransformerVAEDecoder(latent_dim=latent_dim, num_classes=num_classes)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``std = exp(0.5 * logvar)`` and ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def autoregressive(self):
        """Switch the inner decoder to autoregressive generation."""
        self.decoder.decoder.autoregressive = True

    def teacher_forcing(self):
        """Switch the inner decoder to teacher forcing."""
        self.decoder.decoder.autoregressive = False

    def forward(self, target_map: torch.Tensor, use_self_probility=0):
        """Encode ``target_map`` ([B, H * W]), sample a latent and decode it.

        ``use_self_probility`` is accepted for interface compatibility and is
        not used. Returns ``(decoder_output, mu, logvar)``.
        """
        mu, logvar = self.encoder(target_map)  # [B, latent_dim]
        z = self.reparameterize(mu, logvar)
        logits = self.decoder(z, target_map)  # [B, H * W, num_classes] | [B, H * W]
        return logits, mu, logvar
if __name__ == "__main__":
    # Smoke test: one full encode/decode pass on a random 169-tile map on
    # the CPU, printing timing, output shapes and parameter counts.
    device = torch.device("cpu")
    input = torch.randint(0, 32, [1, 169]).to(device)
    # Build the model
    model = GinkaTransformerVAE().to(device)
    # NOTE(review): print_memory comes from ..utils; its signature is not
    # visible here — confirm it accepts a single label argument.
    print_memory("初始化后")
    # Forward pass
    start = time.perf_counter()
    logits, mu, logvar = model(input)
    end = time.perf_counter()
    print_memory("前向传播后")
    print(f"推理耗时: {end - start}")
    print(f"输出形状: logits= {logits.shape}, mu={mu.shape}, logvar={logvar.shape}")
    print(f"Encoder parameters: {sum(p.numel() for p in model.encoder.parameters())}")
    print(f"Decoder parameters: {sum(p.numel() for p in model.decoder.parameters())}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

View File

@ -1,264 +0,0 @@
import time
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import print_memory
class DecoderMapPatch(nn.Module):
    """Encode the 5x5 neighbourhood of already drawn tiles around ``(x, y)``.

    The window covers rows ``y - 4 .. y`` and columns ``x - 2 .. x + 2``; the
    current tile sits at window position (row 4, column 2). The CNN input
    has one one-hot channel per tile class plus a validity-mask channel.

    Args:
        tile_classes: number of tile classes (size of the one-hot encoding).
        width: map width, used to clip the window on the right.
        height: map height (stored only).
    """

    def __init__(self, tile_classes=32, width=13, height=13):
        super().__init__()
        self.width = width
        self.height = height
        # Honour the argument; it used to be overwritten with 32.
        self.tile_classes = tile_classes
        # Local convolution over the patch to capture local structure.
        self.patch_cnn = nn.Sequential(
            nn.Conv2d(tile_classes + 1, 64, 3, padding=1),
            nn.Dropout(0.2),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.Conv2d(64, 128, 3),
            nn.Dropout(0.2),
            nn.BatchNorm2d(128),
            nn.GELU(),
            nn.Flatten()
        )
        self.fc = nn.Linear(128 * 3 * 3, 256)

    def forward(self, map: torch.Tensor, x: int, y: int):
        """
        map: [B, H, W] tile ids
        x, y: column and row of the tile about to be drawn
        returns: the output of ``self.fc`` ([B, 256])
        """
        B, H, W = map.shape
        mask = torch.zeros([B, 5, 5]).to(map.device)
        result = torch.zeros([B, 5, 5], dtype=torch.long).to(map.device)
        # Source rectangle in map coordinates, clipped to the map borders.
        left = x - 2 if x >= 2 else 0
        right = x + 3 if x < self.width - 2 else self.width
        top = y - 4 if y >= 4 else 0
        bottom = y + 1
        # Matching rectangle inside the 5x5 window.
        res_left = left - (x - 2)
        res_right = right - (x + 3) + 5
        res_top = top - (y - 4)
        res_bottom = 5
        result[:, res_top:res_bottom, res_left:res_right] = map[:, top:bottom, left:right]
        mask[:, res_top:res_bottom, res_left:res_right] = 1
        # The current tile and the tiles to its right are not drawn yet.
        result[:, 4, 2:5] = 0
        mask[:, 4, 2:5] = 0

        classes = self.tile_classes
        masked_result = torch.zeros([B, classes + 1, 5, 5]).to(map.device)
        # [B, row, col, C] -> [B, C, row, col]; the previous permute swapped
        # rows and columns, which disagreed with the layout of the mask channel.
        masked_result[:, 0:classes] = F.one_hot(result, num_classes=classes).permute(0, 3, 1, 2).float()
        masked_result[:, classes] = mask

        feat = self.patch_cnn(masked_result)
        feat = self.fc(feat)
        return feat
class DecoderTileEmbedding(nn.Module):
    """Learned embedding of a tile id (the decoder feeds it the previously drawn tile)."""

    def __init__(self, tile_classes=32, embed_dim=256):
        super().__init__()
        self.embedding = nn.Embedding(tile_classes, embed_dim)

    def forward(self, tile: torch.Tensor):
        """Look up the embedding vectors for the tile ids in ``tile``."""
        embedded = self.embedding(tile)
        return embedded
class DecoderPosEmbedding(nn.Module):
    """Separate learned embeddings for the row and the column of a tile."""

    def __init__(self, width=13, height=13, embed_dim=256):
        super().__init__()
        self.width = width
        self.height = height
        # One table per axis: `height` rows and `width` columns.
        self.row_embedding = nn.Embedding(height, embed_dim)
        self.col_embedding = nn.Embedding(width, embed_dim)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Return ``(row_embedding(y), col_embedding(x))`` — row first, column second."""
        return self.row_embedding(y), self.col_embedding(x)
class DecoderInputFusion(nn.Module):
    """Fuse the five per-step decoder inputs into a single feature vector.

    The inputs are stacked as a 5-token sequence, mixed by a small
    transformer encoder, pooled with mean and max over the tokens, and
    reduced back to ``d_model`` by an MLP.
    """

    def __init__(self, d_model=256):
        super().__init__()
        mixer_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=2, dim_feedforward=d_model, batch_first=True,
            dropout=0.2
        )
        self.transformer = nn.TransformerEncoder(mixer_layer, num_layers=2)
        self.norm = nn.LayerNorm(d_model)
        pooled_dim = d_model * 2
        self.fusion = nn.Sequential(
            nn.Linear(pooled_dim, pooled_dim),
            nn.Dropout(0.2),
            nn.LayerNorm(pooled_dim),
            nn.GELU(),
            nn.Linear(pooled_dim, d_model),
            nn.Dropout(0.1),
            nn.LayerNorm(d_model),
            nn.GELU()
        )

    def forward(
        self, tile_embed: torch.Tensor, cond_vec: torch.Tensor,
        col_embed: torch.Tensor, row_embed: torch.Tensor, patch_vec: torch.Tensor
    ):
        """
        Every argument is a [B, d_model] tensor; the result is [B, d_model].
        """
        tokens = torch.stack(
            [tile_embed, cond_vec, col_embed, row_embed, patch_vec], dim=1
        )  # [B, 5, d_model]
        mixed = self.norm(self.transformer(tokens))
        token_mean = torch.mean(mixed, dim=1)
        token_max = torch.max(mixed, dim=1).values
        return self.fusion(torch.cat([token_mean, token_max], dim=1))
class DecoderRNN(nn.Module):
    """One GRU step followed by an MLP head that scores the tile classes."""

    def __init__(self, tile_classes=32, input_dim=256, hidden_dim=512):
        super().__init__()
        self.gru = nn.GRUCell(input_dim, hidden_dim)
        self.drop = nn.Dropout(0.2)
        head_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.Dropout(0.1),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, tile_classes),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, feat_fusion: torch.Tensor, hidden: torch.Tensor):
        """
        feat_fusion: [B, input_dim]
        hidden: [B, hidden_dim]
        returns: (logits [B, tile_classes], next hidden state [B, hidden_dim])
        """
        # Dropout is applied to the state that is both scored and carried forward.
        next_hidden = self.drop(self.gru(feat_fusion, hidden))
        return self.fc(next_hidden), next_hidden
class VAEDecoder(nn.Module):
    """Recurrent decoder that draws a map tile by tile in row-major order.

    Args:
        device: device on which the working tensors are created.
        start_tile: tile id fed as the "previous tile" for the first step.
        map_vec_dim: size of the latent map vector.
        width, height: map size; forwarded to the positional embedding and
            the patch encoder so non-default sizes are handled consistently.
    """

    def __init__(self, device: torch.device, start_tile=31, map_vec_dim=32, width=13, height=13):
        super().__init__()
        self.device = device
        self.width = width
        self.height = height
        self.start_tile = start_tile
        self.rnn_hidden = 512
        self.tile_classes = 32

        # Sub-modules
        self.map_vec_fc = nn.Sequential(
            nn.Linear(map_vec_dim, 128),
            nn.Dropout(0.1),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Linear(128, 256)
        )
        self.tile_embedding = DecoderTileEmbedding(tile_classes=self.tile_classes)
        # width/height used to be dropped here, so the sub-modules silently
        # stayed at their 13x13 defaults.
        self.pos_embedding = DecoderPosEmbedding(width=width, height=height)
        self.map_patch = DecoderMapPatch(tile_classes=self.tile_classes, width=width, height=height)
        self.feat_fusion = DecoderInputFusion()
        self.rnn = DecoderRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)

        # Column / row index of every cell in row-major order.
        self.col_list = [x for _ in range(height) for x in range(width)]
        self.row_list = [y for y in range(height) for _ in range(width)]

    def forward(self, map_vec: torch.Tensor, target_map: torch.Tensor, use_self_probility=0):
        """
        map_vec: [B, vec_dim]
        target_map: [B, H, W]
        use_self_probility: probability, drawn independently at every step,
            of continuing from the model's own previous output instead of
            the ground truth
        returns: logits of shape [B, tile_classes, H, W]
        """
        B, _ = map_vec.shape

        # Working tensors
        now_tile = torch.LongTensor([self.start_tile]).to(self.device).expand(B)
        map = torch.zeros([B, self.height, self.width], dtype=torch.int32).to(self.device)
        output_logits = torch.zeros([B, self.height, self.width, self.tile_classes]).to(self.device)
        hidden: torch.Tensor = torch.zeros(B, self.rnn_hidden).to(self.device)

        col_list = torch.IntTensor(self.col_list).to(self.device).expand(B, -1)
        row_list = torch.IntTensor(self.row_list).to(self.device).expand(B, -1)
        # pos_embedding returns (row, col); they used to be unpacked swapped.
        row_embed, col_embed = self.pos_embedding(col_list, row_list)
        map_vec = self.map_vec_fc(map_vec)

        for y in range(0, self.height):
            for x in range(0, self.width):
                idx = y * self.width + x
                # Previous-tile embedding and local patch encoding
                tile_embed = self.tile_embedding(now_tile)
                use_self = random.random() < use_self_probility
                map_patch = self.map_patch(map if use_self else target_map, x, y)
                # Fuse all per-step features
                feat = self.feat_fusion(tile_embed, map_vec, col_embed[:, idx], row_embed[:, idx], map_patch)
                # Recurrent step
                logits, hidden = self.rnn(feat, hidden)
                # Record the output and pick the tile fed to the next step
                output_logits[:, y, x] = logits[:]
                tile_id = torch.argmax(logits, dim=1).detach()
                map[:, y, x] = tile_id[:]
                now_tile = tile_id if use_self else target_map[:, y, x].detach()

        return output_logits.permute(0, 3, 1, 2)
if __name__ == "__main__":
    # Smoke test: decode one random latent against a random 13x13 map and
    # print timing, output shape and parameter counts.
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    input = torch.randint(0, 32, [1, 13, 13]).to(device)
    map_vec = torch.rand(1, 32).to(device)
    # Build the model
    model = VAEDecoder(device).to(device)
    # NOTE(review): print_memory comes from ..utils; its signature is not
    # visible here — confirm it accepts (device, label).
    print_memory(device, "初始化后")
    # Forward pass
    start = time.perf_counter()
    fake_logits = model(map_vec, input, 0)
    end = time.perf_counter()
    print_memory(device, "前向传播后")
    print(f"推理耗时: {end - start}")
    print(f"输出形状: fake_logits={fake_logits.shape}")
    print(f"Map Vector FC parameters: {sum(p.numel() for p in model.map_vec_fc.parameters())}")
    print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}")
    print(f"Position Embedding parameters: {sum(p.numel() for p in model.pos_embedding.parameters())}")
    print(f"Map Patch parameters: {sum(p.numel() for p in model.map_patch.parameters())}")
    print(f"Feature Fusion parameters: {sum(p.numel() for p in model.feat_fusion.parameters())}")
    print(f"RNN parameters: {sum(p.numel() for p in model.rnn.parameters())}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

View File

@ -1,161 +0,0 @@
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import print_memory
class EncoderEmbedding(nn.Module):
    """Embed tile id, column and row separately and fuse them with a linear layer."""

    def __init__(self, tile_classes=32, width=13, height=13, hidden_dim=128, output_dim=256):
        super().__init__()

        def embedding_branch(num_entries):
            # Lookup table followed by LayerNorm and GELU.
            return nn.Sequential(
                nn.Embedding(num_entries, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU()
            )

        self.tile_embedding = embedding_branch(tile_classes)
        self.col_embedding = embedding_branch(width)
        self.row_embedding = embedding_branch(height)
        self.fusion = nn.Linear(hidden_dim * 3, output_dim)

    def forward(self, tile, x, y):
        """Concatenate the three embeddings along dim 2 and project to ``output_dim``."""
        parts = [
            self.tile_embedding(tile),
            self.col_embedding(x),
            self.row_embedding(y),
        ]
        return self.fusion(torch.cat(parts, dim=2))
class EncoderGRU(nn.Module):
    """One GRU step followed by an MLP that projects the state to ``output_dim``."""

    def __init__(self, input_dim=256, hidden_dim=512, output_dim=256):
        super().__init__()
        self.gru = nn.GRUCell(input_dim, hidden_dim)
        self.drop = nn.Dropout(0.2)
        projection_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.Dropout(0.1),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.fc = nn.Sequential(*projection_layers)

    def forward(self, feat: torch.Tensor, hidden: torch.Tensor):
        """
        feat: [B, input_dim]
        hidden: [B, hidden_dim]
        returns: (projection [B, output_dim], next hidden state [B, hidden_dim])
        """
        next_hidden = self.drop(self.gru(feat, hidden))
        return self.fc(next_hidden), next_hidden
class EncoderFusion(nn.Module):
    """Mix a sequence with a transformer encoder and pool it to one vector.

    The pooled vector is the concatenation of the mean and the max over the
    sequence, so it is ``2 * d_model`` wide.
    """

    def __init__(self, d_model=256):
        super().__init__()
        mixer_layer = nn.TransformerEncoderLayer(
            d_model=d_model, dim_feedforward=d_model*2, nhead=2, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(mixer_layer, num_layers=3)
        self.norm = nn.LayerNorm(d_model)
        pooled_dim = d_model * 2
        self.fc = nn.Sequential(
            nn.Linear(pooled_dim, pooled_dim),
            nn.Dropout(0.1),
            nn.LayerNorm(pooled_dim),
            nn.GELU()
        )

    def forward(self, logits):
        """Map ``logits`` of shape [B, L, d_model] to [B, 2 * d_model]."""
        mixed = self.norm(self.transformer(logits))
        pooled = torch.cat(
            [torch.mean(mixed, dim=1), torch.max(mixed, dim=1).values], dim=1
        )
        return self.fc(pooled)
class VAEEncoder(nn.Module):
    """Recurrent VAE encoder: a [B, H, W] tile map -> ``(mu, logvar)``.

    Tiles are embedded together with their position, fed through a GRU in
    row-major order, and the per-step outputs are pooled by
    ``EncoderFusion`` before the two latent heads.
    """

    def __init__(self, device, tile_classes=32, latent_dim=32, width=13, height=13):
        super().__init__()
        self.device = device
        self.rnn_hidden = 512
        self.logits_dim = 256
        self.embedding = EncoderEmbedding(tile_classes, width, height, 128, 256)
        self.rnn = EncoderGRU(256, self.rnn_hidden, self.logits_dim)
        self.fusion = EncoderFusion(256)

        def latent_head():
            # Both heads share this architecture but not their weights.
            return nn.Sequential(
                nn.Linear(512, 512),
                nn.Dropout(0.1),
                nn.LayerNorm(512),
                nn.GELU(),
                nn.Linear(512, latent_dim)
            )

        self.fc_mu = latent_head()
        self.fc_logvar = latent_head()

        # Column / row index of every cell in row-major order.
        self.col_list = [x for _ in range(height) for x in range(width)]
        self.row_list = [y for y in range(height) for _ in range(width)]

    def forward(self, x: torch.Tensor):
        """Encode ``x`` ([B, H, W] tile ids) and return ``(mu, logvar)``."""
        B, H, W = x.shape
        flat_map = torch.flatten(x, start_dim=1)
        hidden = torch.zeros(B, self.rnn_hidden).to(self.device)
        output = torch.zeros(B, H * W, self.logits_dim).to(self.device)
        col_list = torch.IntTensor(self.col_list).to(self.device).expand(B, -1)
        row_list = torch.IntTensor(self.row_list).to(self.device).expand(B, -1)
        embed = self.embedding(flat_map, col_list, row_list)

        # Run the GRU over the cells one by one, keeping every step's output.
        for step in range(len(self.col_list)):
            step_output, hidden = self.rnn(embed[:, step], hidden)
            output[:, step] = step_output

        pooled = self.fusion(output)
        return self.fc_mu(pooled), self.fc_logvar(pooled)
if __name__ == "__main__":
    # Smoke test: encode one random 13x13 map and print timing, output
    # shapes and parameter counts.
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    input = torch.randint(0, 32, [1, 13, 13]).to(device)
    # Build the model
    model = VAEEncoder(device).to(device)
    # NOTE(review): print_memory comes from ..utils; its signature is not
    # visible here — confirm it accepts (device, label).
    print_memory(device, "初始化后")
    # Forward pass
    start = time.perf_counter()
    mu, logvar = model(input)
    end = time.perf_counter()
    print_memory(device, "前向传播后")
    print(f"推理耗时: {end - start}")
    print(f"输出形状: mu={mu.shape}, logvar={logvar.shape}")
    print(f"Embedding parameters: {sum(p.numel() for p in model.embedding.parameters())}")
    print(f"RNN parameters: {sum(p.numel() for p in model.rnn.parameters())}")
    print(f"Fusion parameters: {sum(p.numel() for p in model.fusion.parameters())}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

View File

@ -1,20 +0,0 @@
import torch
import torch.nn.functional as F
class VAELoss:
    """Reconstruction (cross-entropy) plus beta-weighted KL loss for the VAE."""

    def __init__(self):
        self.num_classes = 32

    def vae_loss(self, logits, target, mu, logvar, beta=0.1):
        """Return ``(total, recon_loss, kl_loss)``.

        logits: [B, L + 1, C] class scores
        target: [B, L] tile ids; token 15 is appended as the last target
        mu, logvar: latent statistics; the KL term is averaged over all of
            their elements
        """
        batch, _ = target.shape
        end_column = torch.full((batch, 1), 15, dtype=torch.long, device=logits.device)
        full_target = torch.cat([target, end_column], dim=1)
        # cross_entropy expects the class dimension second: [B, C, L + 1].
        recon_loss = F.cross_entropy(logits.permute(0, 2, 1), full_target)
        kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
        kl_loss = -0.5 * torch.mean(kl_terms)
        total = recon_loss + beta * kl_loss
        return total, recon_loss, kl_loss

View File

@ -1,43 +0,0 @@
import torch
class VAEScheduler(torch.optim.lr_scheduler.ReduceLROnPlateau):
    """ReduceLROnPlateau that also raises the learning rate when ``prob`` grows.

    Whenever ``step`` is called with a ``prob`` larger than the last one
    seen, the plateau bookkeeping is reset and every learning rate is
    multiplied by ``increase_factor`` (capped at ``max_lr``). Otherwise the
    call is delegated to ``ReduceLROnPlateau.step``.
    """

    def __init__(
        self, optimizer, mode="min", factor=0.1, patience=10, threshold=0.0001,
        threshold_mode="rel", cooldown=0, min_lr=0, eps=1e-8, verbose="deprecated",
        max_lr=1e-2, increase_factor=2, start_prob=0
    ):
        super().__init__(
            optimizer, mode, factor, patience, threshold,
            threshold_mode, cooldown, min_lr, eps, verbose
        )
        self.max_lr = max_lr
        self.increase_factor = increase_factor
        self.last_prob = start_prob

        # Expand max_lr into one cap per parameter group.
        if not isinstance(max_lr, (list, tuple)):
            self.default_max_lr = max_lr
            self.max_lrs = [max_lr] * len(optimizer.param_groups)
            return
        if len(max_lr) != len(optimizer.param_groups):
            raise ValueError(
                f"expected {len(optimizer.param_groups)} max_lrs, got {len(max_lr)}"
            )
        self.default_max_lr = None
        self.max_lrs = list(max_lr)

    def step(self, metrics, prob: float, epoch=None):
        """Raise the learning rate if ``prob`` increased, else do a plateau step."""
        if not prob > self.last_prob:
            return super().step(metrics, epoch)
        # A higher prob starts a new phase: forget the plateau history.
        self.best = metrics
        self.num_bad_epochs = 0
        self.last_prob = prob
        self._increase_lr()
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def _increase_lr(self):
        """Multiply each group's lr by ``increase_factor``, capped by its max lr."""
        for group, cap in zip(self.optimizer.param_groups, self.max_lrs):
            current = float(group["lr"])
            raised = min(current * self.increase_factor, cap)
            # Skip changes smaller than eps, mirroring the parent's lr reduction.
            if raised - current > self.eps:
                group["lr"] = raised

View File

@ -1,47 +0,0 @@
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from .encoder import VAEEncoder
from .decoder import VAEDecoder
from ..utils import print_memory
class GinkaVAE(nn.Module):
    """Recurrent VAE: ``VAEEncoder`` -> reparameterised latent -> ``VAEDecoder``."""

    def __init__(self, device, tile_classes=32, latent_dim=32):
        super().__init__()
        self.encoder = VAEEncoder(device, tile_classes, latent_dim)
        self.decoder = VAEDecoder(device, map_vec_dim=latent_dim)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + noise * std`` with ``std = exp(0.5 * logvar)``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, target_map: torch.Tensor, use_self_probility=0):
        """Encode ``target_map``, sample a latent, decode; return ``(logits, mu, logvar)``."""
        mu, logvar = self.encoder(target_map)
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent, target_map, use_self_probility), mu, logvar
if __name__ == "__main__":
    # Smoke test: one full encode/decode pass on a random 13x13 map on the
    # CPU, printing timing, output shapes and parameter counts.
    device = torch.device("cpu")
    input = torch.randint(0, 32, [1, 13, 13]).to(device)
    # Build the model
    model = GinkaVAE(device).to(device)
    # NOTE(review): other scripts in this chunk call print_memory(device, label);
    # the helper's signature is not visible here — confirm which form is right.
    print_memory("初始化后")
    # Forward pass
    start = time.perf_counter()
    logits, mu, logvar = model(input)
    end = time.perf_counter()
    print_memory("前向传播后")
    print(f"推理耗时: {end - start}")
    print(f"输出形状: logits= {logits.shape}, mu={mu.shape}, logvar={logvar.shape}")
    print(f"Encoder parameters: {sum(p.numel() for p in model.encoder.parameters())}")
    print(f"Decoder parameters: {sum(p.numel() for p in model.decoder.parameters())}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

View File

@ -1,85 +0,0 @@
import os
import json
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from .critic.model import MinamoModel
from .dataset import GinkaDataset
from .generator.loss import GinkaLoss
from .generator.model import GinkaModel
from shared.image import matrix_to_image_cv
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Import-time side effect: make sure the directory validate() writes the
# rendered map images into exists.
os.makedirs('result/ginka_img', exist_ok=True)
def validate():
    """Run the generator over the evaluation set and save the results.

    Loads ``result/ginka.pth`` and ``result/minamo.pth``, generates a map
    for every sample of ``ginka-eval.json``, writes each map as
    ``result/ginka_img/<idx>.png``, dumps all generated matrices to
    ``result/ginka_val.json`` and reports the mean loss per batch.
    """
    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
    model = GinkaModel()
    state = torch.load("result/ginka.pth", map_location=device)["model_state"]
    model.load_state_dict(state)
    model.to(device)

    for name, param in model.named_parameters():
        print(f"Layer: {name}, Params: {param.numel()}")

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params}")

    minamo = MinamoModel(32)
    minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
    minamo.to(device)

    # Prepare the dataset
    val_dataset = GinkaDataset("ginka-eval.json", device, minamo)
    val_loader = DataLoader(
        val_dataset,
        batch_size=32,
        shuffle=True
    )

    criterion = GinkaLoss(minamo)

    # Tile images keyed by file name without extension, used for rendering.
    tile_dict = dict()
    val_output = dict()
    for file in os.listdir('tiles'):
        name = os.path.splitext(file)[0]
        tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)

    minamo.eval()
    model.eval()
    val_loss = 0
    idx = 0
    with torch.no_grad():
        for batch in tqdm(val_loader):
            # Move the batch to the target device
            target = batch["target"].to(device)
            target_vision_feat = batch["target_vision_feat"].to(device)
            target_topo_feat = batch["target_topo_feat"].to(device)
            feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
            # Forward pass
            output, output_softmax = model(feat_vec)

            map_matrix = torch.argmax(output, dim=1)
            for matrix in map_matrix[:].cpu():
                image = matrix_to_image_cv(matrix.numpy(), tile_dict)
                cv2.imwrite(f"result/ginka_img/{idx}.png", image)
                val_output[f"val_{idx}"] = matrix.tolist()
                idx += 1

            # Accumulate the loss
            _, loss = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
            val_loss += loss.item()

    # An empty evaluation set used to raise ZeroDivisionError here; report
    # NaN instead so the JSON dump below still happens.
    num_batches = len(val_loader)
    avg_val_loss = val_loss / num_batches if num_batches else float("nan")
    tqdm.write(f"Validation::loss: {avg_val_loss:.6f}")

    with open('result/ginka_val.json', 'w') as f:
        json.dump(val_output, f)
if __name__ == "__main__":
    # Limit PyTorch to two intra-op CPU threads before validating.
    torch.set_num_threads(2)
    validate()

View File

@ -1,12 +0,0 @@
import argparse
def parse_arguments(from_default: str, train_default: str, val_default: str):
    """Parse the training command-line options.

    Args:
        from_default: default for ``--from_state``.
        train_default: default for ``--train``.
        val_default: default for ``--validate``.

    Returns:
        The ``argparse.Namespace`` with ``resume``, ``from_state``,
        ``load_optim``, ``train``, ``validate`` and ``epochs``.
    """

    def _str2bool(value):
        # argparse's type=bool calls bool() on the raw string, so any
        # non-empty text — including "False" — became True. Parse the
        # text explicitly instead.
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ("true", "t", "yes", "y", "1"):
            return True
        if text in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description="training codes")
    parser.add_argument("--resume", type=_str2bool, default=False)
    parser.add_argument("--from_state", type=str, default=from_default)
    parser.add_argument("--load_optim", type=_str2bool, default=False)
    parser.add_argument("--train", type=str, default=train_default)
    parser.add_argument("--validate", type=str, default=val_default)
    parser.add_argument("--epochs", type=int, default=150)
    args = parser.parse_args()
    return args

View File

@ -1,73 +0,0 @@
import torch
import torch.nn as nn
class ChannelAttention(nn.Module):
    """Channel attention: gate every channel by a weight computed from its global average."""

    def __init__(self, channels, reduction=8):
        super().__init__()
        squeezed = channels//reduction
        # Global average pool -> bottleneck 1x1 convs -> per-channel gate in (0, 1).
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, squeezed, 1),
            nn.ELU(),
            nn.Conv2d(squeezed, channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` scaled by its channel gate (the gated tensor, not the gate)."""
        return x * self.channel_att(x)
class SpatialAttention(nn.Module):
    """Spatial attention: gate every location by a weight computed from channel statistics."""

    def __init__(self):
        super().__init__()
        # 7x7 conv over the [max, mean] channel statistics -> gate in (0, 1).
        self.spatial_att = nn.Sequential(
            nn.Conv2d(2, 1, 7, padding=3),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` scaled by its spatial gate (the gated tensor, not the gate)."""
        channel_max = torch.max(x, dim=1, keepdim=True)[0]
        channel_mean = torch.mean(x, dim=1, keepdim=True)
        stats = torch.cat([channel_max, channel_mean], dim=1)
        return x * self.spatial_att(stats)
class CBAM(nn.Module):
    """Channel attention followed by spatial attention."""

    def __init__(self, channels, reduction=8):
        super().__init__()
        # Channel attention
        self.channel_att = ChannelAttention(channels, reduction)
        # Spatial attention
        self.spatial_att = SpatialAttention()

    def forward(self, x):
        """Apply the channel gate, then the spatial gate, and return the result.

        Both sub-modules already return the gated tensor (input times gate).
        Their results used to be multiplied by the input once more, which
        squared the activations at each stage; they are now chained directly.
        """
        x = self.channel_att(x)
        return self.spatial_att(x)
class SEBlock(nn.Module):
    """Squeeze-and-excitation: rescale channels by a gate computed from their global average."""

    def __init__(self, channel, reduction=4):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        bottleneck = channel // reduction
        self.fc = nn.Sequential(
            nn.Linear(channel, bottleneck),
            nn.GELU(),
            nn.Linear(bottleneck, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` ([B, C, H, W]) multiplied by its per-channel gate."""
        batch, channels, _, _ = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        gate = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gate

View File

@ -1,6 +0,0 @@
# Shared dimension and weight constants.
# NOTE(review): the names suggest vision / topology feature sizes and their
# relative weights; no use site is visible in this chunk — confirm.
VIS_DIM = 512
TOPO_DIM = 512
# Numerically equal to VIS_DIM + TOPO_DIM.
FEAT_DIM = 1024
# The two weights sum to 1.
VISION_WEIGHT = 0.2
TOPO_WEIGHT = 0.8

View File

@ -1,66 +0,0 @@
import torch
from torch_geometric.data import Data, Batch
from torch_geometric.utils import add_self_loops, grid
def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data:
    """
    Differentiable conversion of one soft map into a PyG ``Data`` object.

    map_probs: [C, H, W]
    Returns Data(x=[N, C], edge_index=[2, E], num_nodes=N) with N = H * W.
    The node features are a transposed view of ``map_probs``, so gradients
    flow back into it; the edges come from ``grid(H, W)``.
    """
    channels, height, width = map_probs.shape
    node_count = height * width

    # 1. Node features: one row of class probabilities per cell.
    node_features = map_probs.view(channels, height * width).T  # [N, C]

    grid_edges, _ = grid(height, width)
    grid_edges = grid_edges.to(map_probs.device)

    # 2. Disabled alternative: build all edges by hand
    # node_indices = torch.arange(N, device=device).view(H, W)

    # # Horizontal links (right neighbour)
    # right_src = node_indices[:, :-1].flatten()
    # right_dst = node_indices[:, 1:].flatten()

    # # Vertical links (bottom neighbour)
    # down_src = node_indices[:-1, :].flatten()
    # down_dst = node_indices[1:, :].flatten()

    # # Merge the edge lists (both directions)
    # edge_src = torch.cat([right_src, down_src])
    # edge_dst = torch.cat([right_dst, down_dst])
    # edge_index = torch.cat([
    #     torch.stack([edge_src, edge_dst], dim=0),
    #     torch.stack([edge_dst, edge_src], dim=0)  # reverse links
    # ], dim=1).to(device, dtype=torch.long)

    # # 3. Edge features
    # src_feat = map_probs[:, edge_src // W, edge_src % W].T  # [E, C]
    # dst_feat = map_probs[:, edge_dst // W, edge_dst % W].T  # [E, C]
    # edge_attr = (src_feat + dst_feat) / 2  # [E, C]

    # edge_index, edge_attr = add_self_loops(edge_index, edge_attr)

    return Data(
        x=node_features,
        edge_index=grid_edges,
        # edge_attr=edge_attr,
        num_nodes=node_count
    )
def batch_convert_soft_map_to_graph(batch_map_probs):
    """
    Convert a batch of soft maps [B, C, H, W] into one PyG ``Batch``.
    """
    B, C, H, W = batch_map_probs.shape  # also asserts the input is 4-D
    # One graph per sample, merged into a single batched graph.
    graphs = [
        differentiable_convert_to_data(batch_map_probs[sample])
        for sample in range(B)
    ]
    return Batch.from_data_list(graphs)

View File

@ -1,273 +0,0 @@
# Converted Python version of the JS code
import math
from typing import Dict, Set, List, Tuple, Union
from collections import deque, defaultdict
# Topological similarity; translated from TypeScript by ChatGPT-4o.
class ResourceArea:
    """A connected group of resource tiles."""

    def __init__(self):
        self.type = 'resource'
        # tile id -> number of such tiles in the area
        self.resources: Dict[int, int] = dict()
        # flat cell indices (y * width + x) belonging to the area
        self.members: Set[int] = set()
        # flat cell indices of branch tiles adjacent to the area
        self.neighbor: Set[int] = set()
class BranchNode:
    """Graph node for a branch tile."""

    def __init__(self, tile: int):
        self.type = 'branch'
        # flat cell indices of the nodes this branch connects to
        self.neighbor: Set[int] = set()
        # tile id found at the node's position
        self.tile = tile
class ResourceNode:
    """Graph node for one resource tile inside a ``ResourceArea``."""

    def __init__(self, resource_type: int, area: ResourceArea):
        self.type = 'resource'
        self.resourceType = resource_type
        self.resourceArea = area
        # Shared with the area: the very same set object, not a copy.
        self.neighbor = area.neighbor
# Any node stored in GinkaGraph.graph: a branch node or a resource node.
GinkaNode = Union[BranchNode, ResourceNode]
class GinkaGraph:
    """Topological graph of the part of a map explored from one entrance."""

    def __init__(self):
        # flat cell index -> node
        self.graph: Dict[int, GinkaNode] = dict()
        # flat cell index -> index into areaMap
        self.resourceMap: Dict[int, int] = dict()
        self.areaMap: List[ResourceArea] = list()
        # entrance cells and all cells reached while exploring
        self.visitedEntrance: Set[int] = set()
        self.visited: Set[int] = set()
class GinkaTopologicalGraphs:
    """All per-entrance graphs of one map plus the cells no entrance reaches."""

    def __init__(self):
        self.graphs: List[GinkaGraph] = list()
        # entrance cell index -> graph that contains it
        self.entranceMap: Dict[int, GinkaGraph] = dict()
        # flat indices of non-wall cells not reached from any entrance
        self.unreachable: Set[int] = set()
# Tile ids 0..12.
TILE_TYPE = set(range(13))
# Tiles that become branch nodes of the graph.
BRANCH_TYPE = {6, 7, 8, 9}
# Tiles from which a graph is explored.
ENTRANCE_TYPE = {10, 11}
# Tiles that are grouped into resource areas.
# NOTE(review): 13 is listed here but is not in TILE_TYPE — confirm.
RESOURCE_TYPE = {0, 2, 3, 4, 5, 10, 11, 12, 13}

# (dx, dy) offsets of the four orthogonal neighbours.
directions: List[Tuple[int, int]] = [
    (-1, 0), (1, 0), (0, -1), (0, 1)
]
def find_resource_nodes(map_: List[List[int]]):
    """Group orthogonally connected resource tiles into areas.

    Args:
        map_: rectangular tile matrix indexed as ``map_[y][x]``.

    Returns:
        ``(areas, resource_map)``: the list of ``ResourceArea`` objects and
        a dict mapping every resource cell's flat index (``y * width + x``)
        to the index of its area. An empty map yields ``([], {})``.
    """
    if not map_ or not map_[0]:
        return [], {}
    width, height = len(map_[0]), len(map_)
    visited = set()
    areas = []
    resource_map = {}

    for ny in range(height):
        for nx in range(width):
            index = ny * width + nx
            if index in visited or map_[ny][nx] not in RESOURCE_TYPE:
                continue

            # Breadth-first flood fill from this unvisited resource cell.
            # The seed is counted when it is popped, like every other
            # cell; pre-counting it here used to count it twice.
            area = ResourceArea()
            area_id = len(areas)
            queue = deque([(nx, ny)])

            while queue:
                cx, cy = queue.popleft()
                cindex = cy * width + cx
                if cindex in visited:
                    continue
                ctile = map_[cy][cx]
                if ctile not in RESOURCE_TYPE:
                    continue

                visited.add(cindex)
                area.resources[ctile] = area.resources.get(ctile, 0) + 1
                area.members.add(cindex)
                resource_map[cindex] = area_id

                for dx, dy in directions:
                    px, py = cx + dx, cy + dy
                    if 0 <= px < width and 0 <= py < height:
                        queue.append((px, py))

            areas.append(area)

    return areas, resource_map
def build_graph_from_entrance(map_: List[List[int]], entrance: int, resource_map: Dict[int, int], area_map: List[ResourceArea]) -> GinkaGraph:
    """Build the graph of everything reachable from one entrance.

    Args:
        map_: tile matrix indexed as ``map_[y][x]``.
        entrance: flat index (``y * width + x``) of the starting cell.
        resource_map: flat cell index -> index into ``area_map``.
        area_map: resource areas; their ``neighbor`` sets are MUTATED here
            and the list itself is stored on the returned graph.

    Returns:
        A ``GinkaGraph`` whose ``visited`` / ``visitedEntrance`` sets hold
        the explored cells and whose ``graph`` holds branch and resource nodes.
    """
    width, height = len(map_[0]), len(map_)
    graph = GinkaGraph()
    # Stored by reference, not copied.
    graph.resourceMap = resource_map
    graph.areaMap = area_map

    visited = graph.visited
    visited_entrance = graph.visitedEntrance
    visited_entrance.add(entrance)
    branch_nodes = set()

    # Phase 1: breadth-first search over every cell that is not tile 1,
    # collecting the entrances and branch tiles met on the way.
    queue = deque([(entrance % width, entrance // width)])
    while queue:
        x, y = queue.popleft()
        index = y * width + x
        if index in visited:
            continue
        tile = map_[y][x]
        if tile in ENTRANCE_TYPE:
            visited_entrance.add(index)
        if tile in BRANCH_TYPE:
            branch_nodes.add(index)
        visited.add(index)

        for dx, dy in directions:
            px, py = x + dx, y + dy
            if 0 <= px < width and 0 <= py < height and map_[py][px] != 1:
                queue.append((px, py))

    # Phase 2: create a node per branch tile and connect it to adjacent
    # branch tiles and to every member of adjacent resource areas.
    for v in branch_nodes:
        x, y = v % width, v // width
        if v not in graph.graph:
            graph.graph[v] = BranchNode(map_[y][x])
        node = graph.graph[v]

        for dx, dy in directions:
            px, py = x + dx, y + dy
            if 0 <= px < width and 0 <= py < height:
                index = py * width + px
                if index in branch_nodes:
                    node.neighbor.add(index)
                elif index in resource_map:
                    area = area_map[resource_map[index]]
                    # Back-link: the area remembers the adjacent branch.
                    area.neighbor.add(v)
                    for m in area.members:
                        node.neighbor.add(m)

    # Phase 3: add a resource node for every non-zero tile of EVERY area in
    # area_map (not only the areas reached in phase 1).
    for area in area_map:
        for index in area.members:
            x, y = index % width, index // width
            tile = map_[y][x]
            if tile == 0:
                continue
            node = ResourceNode(tile, area)
            graph.graph[index] = node

    return graph
def build_topological_graph(map_: List[List[int]]) -> GinkaTopologicalGraphs:
    """Build one graph per group of connected entrances and list unreachable cells.

    Entrances already covered by an earlier graph are skipped; every
    non-wall cell (tile != 1) that no graph visited ends up in
    ``unreachable``.
    """
    width, height = len(map_[0]), len(map_)
    entrances = {y * width + x for y in range(height) for x in range(width) if map_[y][x] in ENTRANCE_TYPE}

    area_map, resource_map = find_resource_nodes(map_)
    top_graph = GinkaTopologicalGraphs()
    covered_entrances = set()
    reached = set()

    for entrance in entrances:
        if entrance in covered_entrances:
            continue
        graph = build_graph_from_entrance(map_, entrance, resource_map, area_map)
        top_graph.graphs.append(graph)
        for linked in graph.visitedEntrance:
            covered_entrances.add(linked)
            top_graph.entranceMap[linked] = graph
        reached.update(graph.visited)

    for index in range(width * height):
        y, x = divmod(index, width)
        if map_[y][x] != 1 and index not in reached:
            top_graph.unreachable.add(index)

    return top_graph
class WLNode:
    """Node wrapper used by the Weisfeiler-Lehman label refinement."""

    def __init__(self, pos: int, label: str):
        # Position and label as given; currentLabel is the one that gets refined.
        self.originalPos = pos
        self.originalLabel = self.currentLabel = label
        self.neighbors: List['WLNode'] = list()
def encode_node_labels(graph: GinkaGraph) -> List[WLNode]:
    """Wrap every graph node in a ``WLNode`` and link the wrappers.

    Branch nodes are labelled ``B:<tile>``, all other nodes
    ``R:<resourceType>``. Neighbours without a node in the graph are skipped.
    """
    node_map = {
        pos: WLNode(pos, f"B:{node.tile}" if node.type == 'branch' else f"R:{node.resourceType}")
        for pos, node in graph.graph.items()
    }
    nodes = list(node_map.values())

    for wl_node in nodes:
        for neighbor in graph.graph[wl_node.originalPos].neighbor:
            if neighbor in node_map:
                wl_node.neighbors.append(node_map[neighbor])

    return nodes
def weisfeiler_lehman_iteration(nodes: List[WLNode], iterations: int, decay: float = 0.6) -> Dict[str, float]:
    """Run label refinement on ``nodes`` and return a weighted label histogram.

    Each round replaces a node's label with
    ``"<own label>|<sorted neighbour labels joined by ','>"`` truncated to 8192
    characters. Labels of round ``k`` (0-based) are counted with weight
    ``decay ** k``; the original labels are counted last, with weight
    ``decay ** iterations``.

    Note: ``currentLabel`` of every node is mutated in place.

    Returns:
        Mapping from label string to its accumulated weight.
    """
    rounds = []
    for _ in range(iterations):
        # Build every new label from the previous round's labels before writing
        # any of them back, so the update is simultaneous for all nodes.
        refined = []
        for node in nodes:
            around = ','.join(sorted(nb.currentLabel for nb in node.neighbors))
            refined.append(f"{node.currentLabel}|{around}"[:8192])
        for node, label in zip(nodes, refined):
            node.currentLabel = label
        rounds.append(refined)

    histogram = defaultdict(float)
    weight = 1.0
    for refined in rounds:
        for label in refined:
            histogram[label] += weight
        weight *= decay
    # The original labels use whatever weight is left after all rounds.
    for node in nodes:
        histogram[node.originalLabel] += weight
    return dict(histogram)
def vectorize_features(features: Dict[str, float], vocab: List[str]) -> List[float]:
    """Project a label->weight mapping onto the positions given by ``vocab``.

    Args:
        features: Weight per label.
        vocab: Ordered list of labels; position ``i`` of the result belongs to
            ``vocab[i]``.

    Returns:
        A list of ``len(vocab)`` floats. Labels absent from ``vocab`` are
        dropped. If a label appears several times in ``vocab``, its weight goes
        to the first occurrence (the same slot ``list.index`` would pick).
    """
    # One pass to build the lookup replaces a linear ``in`` test plus a linear
    # ``index`` call per feature; ``setdefault`` keeps the first occurrence.
    slot_of = {}
    for idx, label in enumerate(vocab):
        slot_of.setdefault(label, idx)
    vec = [0.0] * len(vocab)
    for label, count in features.items():
        idx = slot_of.get(label)
        if idx is not None:
            vec[idx] += count
    return vec
def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Return the cosine of the angle between vectors ``a`` and ``b``.

    The dot product runs over ``zip(a, b)``, so extra trailing elements of the
    longer vector only affect its norm. If either vector has zero norm the
    result is ``0.0``.
    """
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    # A zero vector has no direction; report "no similarity" rather than divide by 0.
    if norm_a == 0 or norm_b == 0:
        return 0.0
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (norm_a * norm_b)
def wl_kernel(graph_a: GinkaGraph, graph_b: GinkaGraph, iterations: int = 3) -> float:
    """Compare two graphs by the cosine similarity of their label histograms.

    Both graphs are encoded with ``encode_node_labels``, refined for
    ``iterations`` rounds by ``weisfeiler_lehman_iteration``, projected onto the
    union of their labels and compared with ``cosine_similarity``.
    """
    features_a = weisfeiler_lehman_iteration(encode_node_labels(graph_a), iterations)
    features_b = weisfeiler_lehman_iteration(encode_node_labels(graph_b), iterations)
    # Shared vocabulary: every label seen in either graph gets one vector slot.
    vocab = list(set(features_a) | set(features_b))
    return cosine_similarity(
        vectorize_features(features_a, vocab),
        vectorize_features(features_b, vocab),
    )
def overall_similarity(a: GinkaTopologicalGraphs, b: GinkaTopologicalGraphs) -> float:
    """Score how similar two sets of topological graphs are.

    Each graph of ``a`` is greedily paired with the not-yet-paired graph of
    ``b`` that has the highest ``wl_kernel`` similarity. The result is
    ``sqrt(mean best similarity over a.graphs)`` multiplied by
    ``1 / (1 + |len(a.unreachable) - len(b.unreachable)|)``.

    Returns:
        ``0.0`` when ``a`` has no graphs, otherwise the score described above.
    """
    graphs_a = a.graphs
    graphs_b = b.graphs
    if not graphs_a:
        return 0.0
    total_similarity = 0.0
    compared_graphs = set()
    for ga in graphs_a:
        max_similarity = 0.0
        max_graph = None
        for gb in graphs_b:
            if gb in compared_graphs:
                continue
            min_nodes = min(len(ga.graph), len(gb.graph))
            # math.log raises ValueError for 0, so an empty graph falls back to
            # a single refinement round.
            iterations = max(1, math.ceil(math.log(min_nodes))) if min_nodes > 0 else 1
            similarity = wl_kernel(ga, gb, iterations)
            if similarity > max_similarity and not math.isnan(similarity):
                max_similarity = similarity
                max_graph = gb
            # Nothing can beat a perfect match; ``>=`` also covers a value that
            # rounding pushed marginally above 1.
            if similarity >= 1:
                break
        total_similarity += max_similarity
        # Compare against None explicitly so the pairing does not depend on the
        # truthiness of the graph object.
        if max_graph is not None:
            compared_graphs.add(max_graph)
    # Penalise a differing number of unreachable tiles.
    reduction = 1 / (1 + abs(len(a.unreachable) - len(b.unreachable)))
    return math.sqrt(total_similarity / len(graphs_a)) * reduction

View File

@ -1,75 +0,0 @@
from typing import List, Dict
import math
import numpy as np
# 视觉相似度,由 ChatGPT-4o 从 ts 转译而来
class VisualSimilarityConfig:
    """Tunable settings for ``calculate_visual_similarity``."""

    def __init__(self):
        # Switches for the two spatial weighting terms.
        self.enable_visual_focus: bool = True
        self.enable_density_awareness: bool = True
        # Importance of each tile type id; ids missing here fall back to 0.5
        # at the lookup sites.
        self.type_weights: Dict[int, float] = {
            0: 0.2,
            1: 0.3,
            2: 0.6,
            3: 0.7,
            4: 0.7,
            5: 0.5,
            6: 0.4,
            7: 0.5,
            8: 0.6,
            9: 0.6,
            10: 0.4,
            11: 0.4,
            12: 0.7,
        }
def generate_focus_weights(rows: int, cols: int) -> List[List[float]]:
    """Return a ``rows x cols`` grid of weights that peak near the grid centre.

    Each weight is ``1.0 + 0.6 * exp(-d^2 / (2 * 0.3^2))`` where ``d`` is the
    distance of cell ``(i, j)`` from ``(rows / 2, cols / 2)``, with each axis
    divided by the grid size along that axis.
    """
    # NOTE(review): the centre is rows/2, cols/2 while cells use integer indices,
    # so the peak sits half a cell past the middle and the grid is not symmetric.
    # Kept as-is to preserve existing scores.
    center_x = cols / 2
    center_y = rows / 2

    def cell_weight(i, j):
        dx = (j - center_x) / cols
        dy = (i - center_y) / rows
        distance = math.sqrt(dx ** 2 + dy ** 2)
        gaussian = math.exp(-(distance ** 2) / (2 * 0.3 ** 2))
        return 1.0 + 0.6 * gaussian

    return [[cell_weight(i, j) for j in range(cols)] for i in range(rows)]
def calculate_density_impact(map1: List[List[int]], map2: List[List[int]], type_weights: Dict[int, float]) -> List[List[float]]:
    """Return a per-cell multiplier based on the tile weights in a 3x3 window.

    For every cell, the average of the two maps' type weights (default 0.5 for
    unknown types) is summed over the in-bounds part of the 3x3 neighbourhood.
    The multiplier is ``1.0 + 0.4 * (sum / 9)``; the divisor stays 9 at the
    borders, so border cells get a smaller multiplier for the same tiles.

    ``map2`` is indexed with ``map1``'s dimensions.
    """
    rows, cols = len(map1), len(map1[0])
    window_size = 3
    half_window = window_size // 2
    density_map = []
    for i in range(rows):
        # Clamp the window to the grid instead of testing bounds per neighbour.
        row_lo = max(0, i - half_window)
        row_hi = min(rows, i + half_window + 1)
        out_row = []
        for j in range(cols):
            col_lo = max(0, j - half_window)
            col_hi = min(cols, j + half_window + 1)
            density = 0
            for ni in range(row_lo, row_hi):
                for nj in range(col_lo, col_hi):
                    weight1 = type_weights.get(map1[ni][nj], 0.5)
                    weight2 = type_weights.get(map2[ni][nj], 0.5)
                    density += (weight1 + weight2) / 2
            out_row.append(1.0 + 0.4 * (density / (window_size ** 2)))
        density_map.append(out_row)
    return density_map
def calculate_visual_similarity(map1: List[List[int]], map2: List[List[int]], config: VisualSimilarityConfig = None) -> float:
    """Return a weighted fraction of cells whose tile type matches in both maps.

    Every cell contributes ``base_weight * spatial_weight`` to the maximum
    score, and the same amount to the achieved score when the two tile types
    are equal. ``base_weight`` is the larger of the two tiles' type weights
    (default 0.5); ``spatial_weight`` is the product of the focus weight and
    the density multiplier, each replaced by 1.0 when disabled in ``config``.

    Args:
        map1: First grid of tile type ids.
        map2: Second grid of tile type ids.
        config: Weighting options; a default ``VisualSimilarityConfig`` is
            created when omitted.

    Returns:
        A score in ``[0, 1]``. ``0.0`` is returned when either map is empty,
        when the row counts differ, or when the first rows differ in length.
    """
    if config is None:
        config = VisualSimilarityConfig()
    # Two empty maps used to pass the length comparison and then crash on
    # ``map1[0]``; there is nothing to compare, so report 0.0.
    if not map1 or not map2:
        return 0.0
    if len(map1) != len(map2) or len(map1[0]) != len(map2[0]):
        return 0.0
    rows, cols = len(map1), len(map1[0])
    total_score = 0.0
    max_possible_score = 0.0
    if config.enable_visual_focus:
        focus_weights = generate_focus_weights(rows, cols)
    else:
        focus_weights = [[1.0 for _ in range(cols)] for _ in range(rows)]
    if config.enable_density_awareness:
        density_map = calculate_density_impact(map1, map2, config.type_weights)
    else:
        density_map = [[1.0 for _ in range(cols)] for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            type1 = map1[i][j]
            type2 = map2[i][j]
            base_weight = max(config.type_weights.get(type1, 0.5), config.type_weights.get(type2, 0.5))
            spatial_weight = focus_weights[i][j] * density_map[i][j]
            type_score = 1.0 if type1 == type2 else 0.0
            total_score += type_score * base_weight * spatial_weight
            max_possible_score += base_weight * spatial_weight
    return total_score / max_possible_score if max_possible_score > 0 else 0.0

View File

@ -1,17 +0,0 @@
import torch
def random_smooth_onehot(onehot_map, min_main=0.75, max_main=1.0, epsilon=0.25):
    """Randomly soften a one-hot map so the main class probability varies per pixel.

    For every pixel the main class receives a probability drawn uniformly from
    ``[min_main, max_main)``, and the remaining mass ``1 - p`` is split at
    random among the other classes, so the class axis sums to 1.

    Args:
        onehot_map: Tensor unpacked as ``(C, H, W)``, with the class axis first.
        min_main: Lower bound of the main-class probability.
        max_main: Upper bound of the main-class probability.
        epsilon: Unused, kept for backward compatibility. In the previous
            formula it scaled the noise right before a normalisation that
            divided it back out.

    Returns:
        A floating-point tensor of the same shape, on the same device as
        ``onehot_map``. With a single class there is nothing to redistribute
        and a float copy of the input is returned.
    """
    C, H, W = onehot_map.shape
    onehot = onehot_map if onehot_map.is_floating_point() else onehot_map.float()
    if C < 2:
        return onehot.clone()
    # Create the random tensors on the input's device so GPU inputs do not
    # trigger a device mismatch.
    main_prob = torch.rand(H, W, device=onehot.device, dtype=onehot.dtype) * (max_main - min_main) + min_main
    # Random shares for the non-main classes only (the main class is masked out).
    noise = torch.rand(C, H, W, device=onehot.device, dtype=onehot.dtype) * (1 - onehot)
    # Normalise over the class axis (dim 0), not over H, so each pixel's shares sum to 1.
    denom = noise.sum(dim=0, keepdim=True).clamp_min(torch.finfo(onehot.dtype).tiny)
    noise = noise / denom
    # Main class gets p, the other classes share the residual 1 - p.
    smooth_onehot = onehot * main_prob + noise * (1 - main_prob)
    return smooth_onehot

View File

@ -63,7 +63,7 @@ def convert_dataset_to_images(
if __name__ == "__main__": if __name__ == "__main__":
convert_dataset_to_images( convert_dataset_to_images(
json_path="data/result.json", # 数据集文件 json_path="data/result.json", # 数据集文件
tile_folder="tiles2", # 贴图文件夹 tile_folder="tiles", # 贴图文件夹
output_folder="map_images", # 输出文件夹 output_folder="map_images", # 输出文件夹
tile_size=32 # tile 尺寸 tile_size=32 # tile 尺寸
) )

Binary file not shown.

Before

Width:  |  Height:  |  Size: 406 B

After

Width:  |  Height:  |  Size: 699 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 291 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 341 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 396 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 289 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 344 B

After

Width:  |  Height:  |  Size: 414 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 419 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 305 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 358 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 441 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 847 B

After

Width:  |  Height:  |  Size: 426 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 442 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 448 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 436 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 448 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 643 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 389 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 353 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 382 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 453 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 699 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 426 B

After

Width:  |  Height:  |  Size: 368 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 323 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 311 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 312 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 308 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 420 B

After

Width:  |  Height:  |  Size: 406 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 422 B

After

Width:  |  Height:  |  Size: 396 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 678 B

After

Width:  |  Height:  |  Size: 419 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 368 B

After

Width:  |  Height:  |  Size: 441 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 365 B

After

Width:  |  Height:  |  Size: 448 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 377 B

After

Width:  |  Height:  |  Size: 353 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 414 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.4 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 576 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 699 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 414 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 426 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 368 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 406 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 396 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 419 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 441 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 448 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 353 B

View File

@ -1,8 +1,3 @@
# 从头训练 # MaskGIT
python3 -u -m ginka.train_wgan --epochs 20 --curr_epoch 1 --checkpoint 1 >> output.log python3 -u -m ginka.train_maskGIT --epochs 150 --checkpoint 10 >> output_maskGIT.log
# 接续训练 python3 -u -m ginka.train_maskGIT --resume true --epochs 150 --checkpoint 10 --state_ginka "result/transformer/ginka-100.pth" >> output_maskGIT.log
python3 -u -m ginka.train_wgan --resume true --epochs 300 --state_ginka "result/wgan/ginka-100.pth" --state_minamo "result/wgan/minamo-100.pth" >> output.log
# rnn
python3 -u -m ginka.train_rnn --epochs 150 --checkpoint 10 >> output_rnn.log
python3 -u -m ginka.train_rnn --resume true --epochs 150 --checkpoint 10 --state_ginka "result/rnn/ginka-100.pth" >> output_rnn.log