refactor: 仅保留 SOTA
72
README.md
@ -1,73 +1,3 @@
|
|||||||
# GINKA 地图生成器
|
# GINKA 地图生成器
|
||||||
|
|
||||||
GINKA Model 是一个用于生成网格状魔塔地图的模型,采用 UNet 网络。
|
Ginka 地图生成器是专门训练用来生成魔塔地图的工具,采用 `MaskGIT` 模型及生成方法。
|
||||||
|
|
||||||
GINKA Model 内部集成了 Minamo Model 用做判别器,与 Ginka Model 对抗训练,训练使用 Wasserstein GAN 训练方式。
|
|
||||||
|
|
||||||
## 贡献 GINKA Model 数据集
|
|
||||||
|
|
||||||
对于 HTML5 魔塔,如果你想要贡献数据集,需要对你的魔塔进行手动数据处理,流程如下:
|
|
||||||
|
|
||||||
1. 在 `project` 文件夹下创建 `ginka-config.json` 文件,双击进入编辑,粘贴如下模板:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"clip": {
|
|
||||||
"defaults": [0, 0, 13, 13],
|
|
||||||
"special": {}
|
|
||||||
},
|
|
||||||
"mapping": {
|
|
||||||
"redGem": {
|
|
||||||
"27": 1
|
|
||||||
},
|
|
||||||
"blueGem": {
|
|
||||||
"28": 1
|
|
||||||
},
|
|
||||||
"greenGem": {
|
|
||||||
"29": 1
|
|
||||||
},
|
|
||||||
"yellowGem": {
|
|
||||||
"30": 1
|
|
||||||
},
|
|
||||||
"item": {
|
|
||||||
"47": 1,
|
|
||||||
"49": 1,
|
|
||||||
"50": 0,
|
|
||||||
"51": 1,
|
|
||||||
"52": 1,
|
|
||||||
"53": 2
|
|
||||||
},
|
|
||||||
"potion": {
|
|
||||||
"31": 100,
|
|
||||||
"32": 200,
|
|
||||||
"33": 400,
|
|
||||||
"34": 800
|
|
||||||
},
|
|
||||||
"key": {
|
|
||||||
"21": 0,
|
|
||||||
"22": 1,
|
|
||||||
"23": 2,
|
|
||||||
"24": 2,
|
|
||||||
"25": 2
|
|
||||||
},
|
|
||||||
"door": {
|
|
||||||
"81": 0,
|
|
||||||
"82": 1,
|
|
||||||
"83": 2,
|
|
||||||
"84": 2,
|
|
||||||
"85": 3,
|
|
||||||
"86": 2
|
|
||||||
},
|
|
||||||
"wall": [1, 17],
|
|
||||||
"decoration": [],
|
|
||||||
"floor": [87, 88],
|
|
||||||
"arrow": [91, 92, 93, 94]
|
|
||||||
},
|
|
||||||
"data": {}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
其中,`clip` 属性表示你的每张地图的哪一部分会被当成数据集,例如填写 `[0, 0, 13, 13]` 就会让左上角坐标为 `(0, 0)`、宽高为 `(13, 13)` 的矩形区域作为数据集。`special` 不用管。注意装饰所使用的贴图是白墙,如果白墙是墙壁的话,需要将白墙设置为墙壁。注意不要忘记保存。
|
|
||||||
|
|
||||||
2. 使用 [在线工具](https://unanmed.github.io/ginka-process) 处理数据,需要给每个地图添加标签,为每个图块分配种类,有一些图块包含多种等级,需要填写正确。
|
|
||||||
3. 将 `project` 文件夹打包发给我
|
|
||||||
|
|||||||
@ -1,295 +0,0 @@
|
|||||||
import { GinkaConfig } from './types';
|
|
||||||
|
|
||||||
/**
 * Built-in tile-number translation applied before any per-tower config.
 * Keys are source tile numbers, values are the output codes written by
 * `convert`: 0, 1 and 2 map to themselves, 87/88 map to 29 and 91-94 map to 30.
 */
const numMap: Record<number, number> = {
    // Codes that are kept as-is.
    0: 0,
    1: 1,
    2: 2,
    // Tiles collapsed onto code 29.
    87: 29,
    88: 29,
    // Tiles collapsed onto code 30.
    91: 30,
    92: 30,
    93: 30,
    94: 30
};
|
|
||||||
|
|
||||||
/**
 * Numeric attributes of one enemy. `convert` ranks enemies by
 * `(atk + def) * hp` to decide which strength bucket a tile falls into.
 */
export interface Enemy {
    /** Tile number the enemy is drawn with. */
    num: number;
    /** Hit points. */
    hp: number;
    /** Attack value. */
    atk: number;
    /** Defence value. */
    def: number;
}
|
|
||||||
|
|
||||||
/**
 * Buckets `value` into one of `levels` equal-width bands across `[min, max]`.
 * Returns 0 when the range is degenerate (`max - min < 1e-8`) and clamps the
 * top of the range into the last band (`levels - 1`).
 */
function scaleLevel(
    value: number,
    min: number,
    max: number,
    levels: number
): number {
    const span = max - min;
    if (span < 1e-8) return 0;
    return Math.min(Math.floor(((value - min) / span) * levels), levels - 1);
}

/**
 * Clips a rectangle out of `map` and rewrites every tile into the compact
 * output code space:
 *
 * - 0: anything unmapped
 * - 1 / 2: `wall` / `decoration`
 * - 3 + level: `door`, 7 + level: `key`
 * - 10-12 / 13-15 / 16-18: red / blue / green gems (three strength levels)
 * - 19-22: potions (four strength levels)
 * - 22 + level: `item`
 * - 26 / 27 / 28: weak / medium / strong enemy
 * - 29 / 30: `floor` / `arrow`
 *
 * Gem, potion and enemy levels are relative to the values that actually
 * occur inside the clipped area. Yellow gems are randomly turned into a
 * red (40%), blue (40%) or green (20%) gem.
 *
 * @param map Source map, indexed as `map[row][column]`.
 * @param clip `[x, y, w, h]`: left column, top row, width and height.
 * @param config Tower configuration; only `config.mapping` is read.
 * @param enemyMap Enemy attributes keyed by tile number.
 * @returns A new `h` x `w` grid of output codes.
 */
function convert(
    map: number[][],
    [x, y, w, h]: [number, number, number, number],
    config: GinkaConfig,
    enemyMap: Record<number, Enemy>
) {
    // 1. Clip: rows run over the height, columns over the width.
    const clipped: number[][] = [];
    for (let ny = y; ny < y + h; ny++) {
        const row: number[] = [];
        for (let nx = x; nx < x + w; nx++) {
            row.push(map[ny][nx]);
        }
        clipped.push(row);
    }

    const res: number[][] = Array.from({ length: h }, () =>
        Array.from({ length: w }, () => 0)
    );

    // Visits every clipped tile with its column (nx) and row (ny).
    const forEachTile = (
        visit: (tile: number, nx: number, ny: number) => void
    ) => {
        for (let ny = 0; ny < h; ny++) {
            for (let nx = 0; nx < w; nx++) {
                visit(clipped[ny][nx], nx, ny);
            }
        }
    };

    // 2 + 3. Plain tiles: built-in defaults first, then the tower's own
    // configuration, which overrides the defaults on conflict.
    const dict = config.mapping;
    const mapping: Record<number, number> = { ...numMap };
    dict.wall.forEach(v => (mapping[v] = 1));
    dict.decoration.forEach(v => (mapping[v] = 2));
    dict.floor.forEach(v => (mapping[v] = 29));
    dict.arrow.forEach(v => (mapping[v] = 30));
    forEachTile((tile, nx, ny) => {
        if (mapping[tile] !== void 0) res[ny][nx] = mapping[tile];
    });

    // 4. Levelled tiles. First collect the values present in the clip so
    // that levels can be scaled to the local min/max.
    const redGemSet = new Set<number>();
    const blueGemSet = new Set<number>();
    const greenGemSet = new Set<number>();
    const potionSet = new Set<number>();
    forEachTile(tile => {
        if (dict.redGem[tile] !== void 0) {
            redGemSet.add(dict.redGem[tile]);
        } else if (dict.blueGem[tile] !== void 0) {
            blueGemSet.add(dict.blueGem[tile]);
        } else if (dict.greenGem[tile] !== void 0) {
            greenGemSet.add(dict.greenGem[tile]);
        } else if (dict.yellowGem[tile] !== void 0) {
            // A yellow gem may become any colour, so it widens all ranges.
            redGemSet.add(dict.yellowGem[tile]);
            blueGemSet.add(dict.yellowGem[tile]);
            greenGemSet.add(dict.yellowGem[tile]);
        } else if (dict.potion[tile] !== void 0) {
            potionSet.add(dict.potion[tile]);
        }
    });
    const minRedGem = Math.min(...redGemSet);
    const maxRedGem = Math.max(...redGemSet);
    const minBlueGem = Math.min(...blueGemSet);
    const maxBlueGem = Math.max(...blueGemSet);
    const minGreenGem = Math.min(...greenGemSet);
    const maxGreenGem = Math.max(...greenGemSet);
    const minPotion = Math.min(...potionSet);
    const maxPotion = Math.max(...potionSet);

    const redCode = (v: number) => 10 + scaleLevel(v, minRedGem, maxRedGem, 3);
    const blueCode = (v: number) =>
        13 + scaleLevel(v, minBlueGem, maxBlueGem, 3);
    const greenCode = (v: number) =>
        16 + scaleLevel(v, minGreenGem, maxGreenGem, 3);

    forEachTile((tile, nx, ny) => {
        if (dict.redGem[tile] !== void 0) {
            res[ny][nx] = redCode(dict.redGem[tile]);
        } else if (dict.blueGem[tile] !== void 0) {
            res[ny][nx] = blueCode(dict.blueGem[tile]);
        } else if (dict.greenGem[tile] !== void 0) {
            res[ny][nx] = greenCode(dict.greenGem[tile]);
        } else if (dict.yellowGem[tile] !== void 0) {
            const rand = Math.random();
            const value = dict.yellowGem[tile];
            if (rand < 2 / 5) {
                res[ny][nx] = redCode(value);
            } else if (rand < 4 / 5) {
                res[ny][nx] = blueCode(value);
            } else {
                res[ny][nx] = greenCode(value);
            }
        } else if (dict.potion[tile] !== void 0) {
            res[ny][nx] =
                19 + scaleLevel(dict.potion[tile], minPotion, maxPotion, 4);
        } else if (dict.door[tile] !== void 0) {
            res[ny][nx] = 3 + dict.door[tile];
        } else if (dict.key[tile] !== void 0) {
            res[ny][nx] = 7 + dict.key[tile];
        } else if (dict.item[tile] !== void 0) {
            res[ny][nx] = 22 + dict.item[tile];
        }
    });

    // 5. Enemies: rank by (atk + def) * hp among the enemies present in the
    // clip and split the range into weak / medium / strong thirds.
    const attrs: number[] = [];
    forEachTile(tile => {
        const enemy = enemyMap[tile];
        if (enemy) attrs.push((enemy.atk + enemy.def) * enemy.hp);
    });
    const minAttr = Math.min(...attrs);
    const delta = Math.max(...attrs) - minAttr;
    forEachTile((tile, nx, ny) => {
        const enemy = enemyMap[tile];
        if (!enemy) return;
        const ad = (enemy.atk + enemy.def) * enemy.hp - minAttr;
        if (ad < delta / 3 || delta === 0) {
            res[ny][nx] = 26;
        } else if (ad < (delta * 2) / 3) {
            res[ny][nx] = 27;
        } else {
            res[ny][nx] = 28;
        }
    });

    return res;
}
|
|
||||||
|
|
||||||
/**
 * Public entry point for converting one floor; forwards all arguments to
 * `convert` unchanged and returns its result.
 *
 * @param map Source map, indexed as `map[row][column]`.
 * @param clip `[x, y, w, h]` rectangle to extract.
 * @param config Tower configuration.
 * @param enemyMap Enemy attributes keyed by tile number.
 */
export function convertFloor(
    map: number[][],
    clip: [number, number, number, number],
    config: GinkaConfig,
    enemyMap: Record<number, Enemy>
) {
    const converted = convert(map, clip, config, enemyMap);
    return converted;
}
|
|
||||||
|
|
||||||
/**
 * Counts how many cells of `map` hold one of the values listed in `tiles`.
 *
 * @param map Grid of tile codes.
 * @param tiles Tile codes to count.
 * @returns Number of matching cells.
 */
export function getCount(map: number[][], tiles: number[]) {
    return map
        .flat()
        .reduce((total, cell) => (tiles.includes(cell) ? total + 1 : total), 0);
}
|
|
||||||
|
|
||||||
/**
 * Fraction of the map's cells that hold one of `tiles`. The area is taken
 * as `rows * length of the first row`, so the map must have at least one row.
 *
 * @param map Grid of tile codes.
 * @param tiles Tile codes to count.
 */
export function getRatio(map: number[][], tiles: number[]) {
    const rows = map.length;
    const columns = map[0].length;
    return getCount(map, tiles) / (rows * columns);
}
|
|
||||||
|
|
||||||
/**
 * Returns the consecutive numbers `from, from + 1, ...` up to but not
 * including `to`. Yields an empty array when `to <= from`.
 */
function range(from: number, to: number) {
    const span = to - from;
    return Array.from({ length: span }).map((_unused, offset) => from + offset);
}
|
|
||||||
|
|
||||||
/**
 * Builds the 16-element statistics vector for a converted map.
 *
 * Slots 0-8 are area ratios and slots 9-10 are absolute counts of the tile
 * codes listed beside each entry below; slots 11-15 are left at 0.
 *
 * @param map Grid of output tile codes.
 */
export function getGinkaRatio(map: number[][]): number[] {
    const stats: number[] = [
        getRatio(map, [1, ...range(3, 32)]), // 0: code 1 plus codes 3-31
        getRatio(map, [1]), // 1: code 1
        getRatio(map, [2]), // 2: code 2
        getRatio(map, [3, 4, 5, 6]), // 3: codes 3-6
        getRatio(map, [26, 27, 28]), // 4: codes 26-28
        getRatio(map, range(7, 26)), // 5: codes 7-25
        getRatio(map, range(10, 19)), // 6: codes 10-18
        getRatio(map, range(19, 23)), // 7: codes 19-22
        getRatio(map, [7, 8, 9]), // 8: codes 7-9
        getCount(map, [23, 24, 25]), // 9: count of codes 23-25
        getCount(map, [29, 30]) // 10: count of codes 29-30
    ];
    // Pad the unused tail so the vector always has 16 entries.
    while (stats.length < 16) stats.push(0);
    return stats;
}
|
|
||||||
177
data/src/gan.ts
@ -1,177 +0,0 @@
|
|||||||
import { createConnection, Socket } from 'net';
|
|
||||||
import { chooseFrom, FloorData, readOne } from './utils';
|
|
||||||
import { MinamoTrainData } from './types';
|
|
||||||
import { generateTrainData } from './process/minamo';
|
|
||||||
|
|
||||||
/** Socket path handed to `createConnection` below. */
const SOCKET_FILE = '../tmp/ginka_uds';

// Command line: <refer> [replayPath]
const cliArgs = process.argv.slice(2);
const refer = cliArgs[0];
const replayPath = cliArgs[1] ?? '../datasets/replay.bin';

/** Running counter used to build ids for generated maps. */
let id = 0;
|
|
||||||
|
|
||||||
/**
 * Unpacks a flat list of tile values into `count` grids of `h` rows by `w`
 * columns. Values are laid out map by map, row by row; cells not covered by
 * `arr` stay 0.
 *
 * @param count Number of maps to allocate.
 * @param arr Flat tile values.
 * @param h Rows per map.
 * @param w Columns per map.
 */
function readMap(count: number, arr: number[], h: number, w: number) {
    const cellsPerMap = w * h;

    // Pre-allocate count x h x w grids filled with zeros.
    const maps: number[][][] = [];
    for (let n = 0; n < count; n++) {
        const grid: number[][] = [];
        for (let row = 0; row < h; row++) {
            grid.push(Array.from({ length: w }, () => 0));
        }
        maps.push(grid);
    }

    arr.forEach((value, index) => {
        const mapIndex = Math.floor(index / cellsPerMap);
        const row = Math.floor((index % cellsPerMap) / w);
        const column = index % w;
        maps[mapIndex][row][column] = value;
    });

    return maps;
}
|
|
||||||
|
|
||||||
/**
 * Pairs one received map with reference floors picked via
 * `chooseFrom(keys, 4)` and produces the comparison records for each pair.
 * Reference floors that are missing or whose dimensions differ from `map`
 * contribute nothing.
 *
 * @param keys Ids of all reference floors.
 * @param refer Reference floors by id.
 * @param map The received map to compare.
 * @returns Flat list of comparison records for this map.
 */
function generateGANData(
    keys: string[],
    refer: Map<string, FloorData>,
    map: number[][]
) {
    // Each received map gets a fresh id of the form "$<n>".
    const id2 = `$${id++}`;
    const toTrain = chooseFrom(keys, 4);

    return toTrain.flatMap<MinamoTrainData>(v => {
        const floor = refer.get(v);
        if (!floor) return [];

        const sameWidth = floor.map[0].length === map[0].length;
        const sameHeight = floor.map.length === map.length;
        if (!sameWidth || !sameHeight) return [];

        const size: [number, number] = [floor.map[0].length, floor.map.length];
        // Only the plain pair is wanted: no self, transform or similar data.
        return generateTrainData(v, id2, floor.map, map, size, false, false, false);
    });
}
|
|
||||||
|
|
||||||
/** Which part of an incoming message `DataReceiver` expects next. */
const enum ReceiverStatus {
    /** The next chunk starts with the 4-byte header. */
    Header = 0,
    /** The header has been read; chunks carry map content only. */
    Content = 1
}
|
|
||||||
|
|
||||||
/**
 * Reassembles one incoming message that may arrive split over several
 * socket chunks.
 *
 * Input protocol, in bytes: 2 - tensor count; 1 - map height; 1 - map width;
 * N*1*H*W - map tensor, int8.
 *
 * NOTE(review): the header is read from the first chunk only, so that chunk
 * must hold at least 4 bytes, and completion is detected with a strict
 * length equality — confirm the sender never coalesces two messages.
 */
class DataReceiver {
    /** Receiver of the message currently in flight, if any. */
    static active?: DataReceiver;

    /** Receiving state. */
    private status: ReceiverStatus = ReceiverStatus.Header;

    /** Content bytes collected so far. */
    private received: number[] = [];
    private count: number = 0;
    private h: number = 0;
    private w: number = 0;

    /**
     * Feeds one chunk into the receiver.
     *
     * @returns `[maps, count, h, w]` once exactly `count * h * w` content
     * bytes have been collected, otherwise `null`.
     */
    receive(buf: Buffer): [number[][][], number, number, number] | null {
        if (this.status === ReceiverStatus.Header) {
            this.count = buf.readInt16BE();
            this.h = buf.readInt8(2);
            this.w = buf.readInt8(3);
            // Whatever follows the header is already content.
            this.received.push(...buf.subarray(4));
            this.status = ReceiverStatus.Content;
        } else {
            this.received.push(...buf);
        }

        const expected = this.count * this.h * this.w;
        if (this.received.length !== expected) return null;

        // Message complete: release the slot for the next message.
        delete DataReceiver.active;
        const maps = readMap(this.count, this.received, this.h, this.w);
        return [maps, this.count, this.h, this.w];
    }

    /**
     * Routes a chunk to the in-flight receiver, creating one when no
     * message is currently being assembled.
     */
    static check(buf: Buffer) {
        if (!this.active) {
            this.active = new DataReceiver();
        }
        return this.active.receive(buf);
    }
}
|
|
||||||
|
|
||||||
// Entry point: load the reference tower, connect to the socket and answer
// every complete incoming message with comparison data.
(async () => {
    const referTower = await readOne(refer);
    const keys = Array.from(referTower.keys());

    const client = createConnection(SOCKET_FILE, () => {
        console.log('UDS IPC connected successfully.');
    });

    client.on('data', async buffer => {
        const data = DataReceiver.check(buffer);
        // Message still incomplete; wait for more chunks.
        if (data === null) return;

        const [map, count, h, w] = data;
        const simData = map.map(v => generateGANData(keys, referTower, v));
        const compareData = simData.flat();
        const pairCount = compareData.length;
        // Replay data is not sent at the moment.
        const rc = 0;

        // Output protocol, in bytes:
        // 2 - tensor count; 2 - replay count (replay follows the train data);
        // 1*tc - compare count for every map tensor delivered;
        // 2*4*(N+rc) - vision and topo similarity, as vis, topo, vis, topo;
        // N*1*H*W - compare map per map tensor; rc*2*H*W - replay map tensor.
        const similarityBytes = 2 * 4 * (pairCount + rc);
        const compareMapBytes = pairCount * 1 * h * w;
        const replayMapBytes = rc * 2 * h * w;
        const toSend = Buffer.alloc(
            2 + 2 + 1 * count + similarityBytes + compareMapBytes + replayMapBytes,
            0
        );
        console.log(
            2,
            2,
            count,
            similarityBytes,
            compareMapBytes,
            replayMapBytes,
            pairCount,
            rc
        );

        toSend.writeInt16BE(count); // Tensor count
        toSend.writeInt16BE(0, 2); // Replay count
        let offset = 2 + 2;

        // Compare count, one byte per received map.
        toSend.set(
            simData.map(v => v.length),
            offset
        );
        offset += 1 * count;

        // Similarity data; writeFloatBE returns the offset after the write.
        for (const v of compareData) {
            offset = toSend.writeFloatBE(v.visionSimilarity, offset);
            offset = toSend.writeFloatBE(v.topoSimilarity, offset);
        }

        // Compare map.
        const compareMaps = compareData.map(v => v.map1).flat(3);
        toSend.set(new Uint8Array(compareMaps), offset);

        client.write(toSend);
    });

    client.on('end', () => {
        console.log('Connection lose.');
    });

    client.on('error', () => {
        client.end();
    });
})();
|
|
||||||
@ -1,16 +0,0 @@
|
|||||||
import { writeFile } from 'fs-extra';
|
|
||||||
import { getAllFloors, parseTowerInfo } from './utils';
|
|
||||||
import { parseGinka } from './process/ginka';
|
|
||||||
|
|
||||||
// Command line: <output> <tower>...
const [output, ...list] = process.argv.slice(2);

// Parse every tower, collect all floors and write the dataset as JSON.
(async () => {
    const pending = list.map(v => parseTowerInfo(v, 'ginka-config.json'));
    const towers = await Promise.all(pending);
    const floors = await getAllFloors(...towers);
    const results = parseGinka(floors);

    const serialized = JSON.stringify(results);
    await writeFile(output, serialized, 'utf-8');

    const size = Object.keys(results.data).length;
    console.log(`✅ 已处理 ${list.length} 个塔,共 ${size} 个地图`);
})();
|
|
||||||
@ -1,40 +0,0 @@
|
|||||||
import { writeFile } from 'fs-extra';
|
|
||||||
import { readOne, getAllFloors, parseTowerInfo } from './utils';
|
|
||||||
import { generateAssignedData, parseMinamo } from './process/minamo';
|
|
||||||
|
|
||||||
// Command line: <output> <tower>... [assigned<sep><a>:<b>]
const [output, ...list] = process.argv.slice(2);

// Detect "assigned" mode. In this mode only the first two towers are
// processed and compared against each other; floors of the same tower are
// not compared with one another.
const lastArg = list.at(-1);
const assigned = lastArg?.startsWith('assigned') ?? false;
// With no arguments at all `lastArg` is undefined; fall back to an empty
// string (which parses to the defaults) instead of crashing.
const assignedCount = parseAssigned(lastArg ?? '');
// The mode flag is not a tower path, so drop it from the list.
if (assigned) list.pop();
|
|
||||||
|
|
||||||
/**
 * Parses the two counts out of an `assigned` argument. The first nine
 * characters are skipped and the rest is split on `:`; a part that is
 * missing, not a number, or 0 becomes 100.
 *
 * @param arg Raw command-line argument.
 * @returns The two counts as a tuple.
 */
function parseAssigned(arg: string): [number, number] {
    const toCount = (raw: string) => parseInt(raw) || 100;
    const [first, second] = arg.slice(9).split(':');
    return [toCount(first), toCount(second)];
}
|
|
||||||
|
|
||||||
// Build the dataset and write it to `output` as JSON.
(async () => {
    if (assigned) {
        // Assigned mode: compare the floors of exactly two towers.
        const [tower1, tower2] = list;
        if (!tower1 || !tower2) {
            console.log(`⚠️ assigned 模式下必须传入两个塔!`);
            return;
        }
        const data1 = await readOne(tower1);
        const data2 = await readOne(tower2);
        const results = generateAssignedData(data1, data2, assignedCount);
        await writeFile(output, JSON.stringify(results), 'utf-8');
        const size = Object.keys(results.data).length;
        console.log(`✅ 已处理 ${list.length} 个塔,共 ${size} 个组合`);
        return;
    }

    // Default mode: pool the floors of every tower and pair them up.
    const pending = list.map(v => parseTowerInfo(v, 'minamo-config.json'));
    const towers = await Promise.all(pending);
    const floors = await getAllFloors(...towers);
    const results = parseMinamo(floors);
    await writeFile(output, JSON.stringify(results), 'utf-8');
    const size = Object.keys(results.data).length;
    console.log(`✅ 已处理 ${list.length} 个塔,共 ${size} 个组合`);
})();
|
|
||||||
@ -1,36 +0,0 @@
|
|||||||
import { SingleBar, Presets } from 'cli-progress';
|
|
||||||
import { getGinkaRatio } from 'src/floor';
|
|
||||||
import { GinkaTrainData, GinkaConfig, GinkaDataset } from 'src/types';
|
|
||||||
import { FloorData } from 'src/utils';
|
|
||||||
|
|
||||||
/**
 * Turns every floor into one training record (map, size, tag vector and
 * statistics vector) while reporting progress on the console.
 *
 * Floors without an entry in their config's `data` get a tag vector of
 * 64 zeros.
 *
 * @param data Floors keyed by id.
 * @returns The dataset, stamped with a random `datasetId`.
 */
export function parseGinka(data: Map<string, FloorData>) {
    const resolved: Record<string, GinkaTrainData> = {};

    const progress = new SingleBar({}, Presets.shades_classic);
    progress.start(data.size, 0);

    let done = 0;
    for (const [key, floor] of data) {
        const config = floor.config as GinkaConfig;
        const entry = config.data[floor.id] ?? {
            tag: Array(64).fill(0)
        };
        const width = floor.map[0].length;
        const height = floor.map.length;
        resolved[key] = {
            map: floor.map,
            size: [width, height],
            tag: entry.tag,
            val: getGinkaRatio(floor.map)
        };
        done += 1;
        progress.update(done);
    }

    const dataset: GinkaDataset = {
        datasetId: Math.floor(Math.random() * 1e12),
        data: resolved
    };

    progress.stop();

    return dataset;
}
|
|
||||||
@ -1,406 +0,0 @@
|
|||||||
import { SingleBar, Presets } from 'cli-progress';
|
|
||||||
import { compareMap } from 'src/topology/compare';
|
|
||||||
import { directions, tileType } from 'src/topology/graph';
|
|
||||||
import { rotateMap, mirrorMapX, mirrorMapY } from 'src/topology/transform';
|
|
||||||
import { MinamoTrainData, MinamoDataset } from 'src/types';
|
|
||||||
import { chooseFrom, FloorData } from 'src/utils';
|
|
||||||
import { calculateVisualSimilarity } from 'src/vision/similarity';
|
|
||||||
|
|
||||||
/**
 * Hands the indices `0 .. maxCount - 1` to `chooseFrom` together with `n`.
 *
 * @param maxCount Size of the index pool.
 * @param n Second argument forwarded to `chooseFrom`.
 */
function chooseN(maxCount: number, n: number) {
    const indices = [...Array(maxCount).keys()];
    return chooseFrom(indices, n);
}
|
|
||||||
|
|
||||||
/**
 * Randomly selects unordered index pairs out of `n` items.
 *
 * Each pair `(i, j)` with `i < j` is encoded as the single number
 * `i * n + j`. All pairs are shuffled and the first `min(n(n-1)/2, max)`
 * are returned.
 *
 * @param n Number of items.
 * @param max Upper bound on the number of pairs returned.
 */
function choosePair(n: number, max: number = 1000) {
    const pairs: number[] = [];
    for (let first = 0; first < n; first++) {
        for (let second = first + 1; second < n; second++) {
            pairs.push(first * n + second);
        }
    }

    // Shuffle in place, walking from the tail towards the head.
    let tail = pairs.length - 1;
    while (tail > 0) {
        const pick = Math.floor(Math.random() * (tail + 1));
        const held = pairs[tail];
        pairs[tail] = pairs[pick];
        pairs[pick] = held;
        tail--;
    }

    const limit = Math.min(Math.round((n * (n - 1)) / 2), max);
    return pairs.slice(0, limit);
}
|
|
||||||
|
|
||||||
/**
 * Applies `rot` successive `rotateMap` calls, then mirrors according to the
 * two low bits of `flip`: bit 0 applies `mirrorMapX`, bit 1 applies
 * `mirrorMapY`. With `rot` 0 and `flip` 0 the input is returned unchanged.
 */
function transform(map: number[][], rot: number, flip: number) {
    let current = map;
    for (let done = 0; done < rot; done += 1) {
        current = rotateMap(current);
    }
    const flipX = (flip & 0b01) !== 0;
    const flipY = (flip & 0b10) !== 0;
    if (flipX) current = mirrorMapX(current);
    if (flipY) current = mirrorMapY(current);
    return current;
}
|
|
||||||
|
|
||||||
/**
 * Builds augmentation records in which one map of a pair has been rotated
 * and/or mirrored.
 *
 * At most one of the 15 non-identity (rotation, flip) combinations is
 * picked. For it, four candidate comparisons exist — transformed map1 or
 * map2 against original map1 or map2 — and a random subset of those is
 * emitted. A transformed map compared with its own original keeps
 * topological similarity 1; across maps it keeps `simi`.
 *
 * The previous code requested `Math.floor(Math.random() * 1)` combinations,
 * which is always 0, so nothing was ever generated; it now requests 0 or 1.
 *
 * NOTE(review): assumes `chooseFrom(items, n)` returns `n` randomly chosen
 * items — confirm in `src/utils`.
 *
 * @param id1 Id of the first map.
 * @param id2 Id of the second map.
 * @param map1 First map.
 * @param map2 Second map.
 * @param simi Topological similarity of the untransformed pair.
 * @returns `[key, record]` entries, keyed as `<id>.<rot>.<flip>:<id>`.
 */
function generateTransformData(
    id1: string,
    id2: string,
    map1: number[][],
    map2: number[][],
    simi: number
) {
    const types: [rot: number, flip: number][] = [];
    for (const rot of [0, 1, 2, 3]) {
        for (const flip of [0b00, 0b01, 0b10, 0b11]) {
            // Skip the identity transform.
            if (rot === 0 && flip === 0) continue;
            types.push([rot, flip]);
        }
    }
    // Randomly pick at most one transform (0 or 1).
    const trans = chooseFrom(types, Math.floor(Math.random() * 2));
    return trans
        .map(([rot, flip]) => {
            // The four possible comparisons for this transform, in the
            // order they are emitted.
            const candidates = [
                {
                    key: `${id1}.${rot}.${flip}:${id1}`,
                    source: map1,
                    target: map1,
                    topo: 1,
                    vision: (t: number[][]) =>
                        calculateVisualSimilarity(map1, t)
                },
                {
                    key: `${id1}.${rot}.${flip}:${id2}`,
                    source: map1,
                    target: map2,
                    topo: simi,
                    vision: (t: number[][]) =>
                        calculateVisualSimilarity(t, map2)
                },
                {
                    key: `${id2}.${rot}.${flip}:${id1}`,
                    source: map2,
                    target: map1,
                    topo: simi,
                    vision: (t: number[][]) =>
                        calculateVisualSimilarity(t, map1)
                },
                {
                    key: `${id2}.${rot}.${flip}:${id2}`,
                    source: map2,
                    target: map2,
                    topo: 1,
                    vision: (t: number[][]) =>
                        calculateVisualSimilarity(t, map2)
                }
            ];
            const choose = chooseFrom(
                candidates.map(c => c.key),
                Math.floor(Math.random() * 2)
            );
            const res: [id: string, data: MinamoTrainData][] = [];
            for (const candidate of candidates) {
                if (!choose.includes(candidate.key)) continue;
                const t = transform(candidate.source, rot, flip);
                res.push([
                    candidate.key,
                    {
                        map1: t,
                        map2: candidate.target,
                        topoSimilarity: candidate.topo,
                        visionSimilarity: candidate.vision(t),
                        size: [map1[0].length, map1.length]
                    }
                ]);
            }

            return res;
        })
        .flat();
}
|
|
||||||
|
|
||||||
/**
 * Builds records that compare `map` with randomly perturbed copies of
 * itself.
 *
 * Zero, one or two copies are produced. For each copy a perturbation
 * probability below 0.3 is drawn; every cell is then touched with that
 * probability: 20% of touched cells are swapped with a random neighbour
 * from `directions` (skipped when the neighbour is outside the map), the
 * other 80% are replaced by a random tile below `tileType.size`.
 *
 * The previous code drew the copy count with `Math.floor(Math.random() * 2)`,
 * which can only yield 0 or 1 although up to two copies were intended.
 *
 * @param id Id of the map.
 * @param map The map to perturb; it is not modified.
 * @returns `[key, record]` entries, keyed as `<id>:<id>.S<index>`.
 */
function generateSimilarData(id: string, map: number[][]) {
    // Generate at most two perturbed maps (0, 1 or 2).
    const width = map[0].length;
    const height = map.length;
    const num = Math.floor(Math.random() * 3);
    const res: [id: string, data: MinamoTrainData][] = [];

    for (let i = 0; i < num; i++) {
        const clone = map.map(v => v.slice());
        const prob = Math.random() * 0.3;
        for (let ny = 0; ny < height; ny++) {
            for (let nx = 0; nx < width; nx++) {
                if (Math.random() > prob) {
                    // Only a fraction of the cells is perturbed.
                    continue;
                }
                if (Math.random() < 0.2) {
                    // 20%: swap with a neighbouring cell.
                    const [dx, dy] =
                        directions[
                            Math.floor(Math.random() * directions.length)
                        ];
                    const px = nx + dx;
                    const py = ny + dy;
                    if (px < 0 || px >= width || py < 0 || py >= height) {
                        continue;
                    }
                    [clone[ny][nx], clone[py][px]] = [
                        clone[py][px],
                        clone[ny][nx]
                    ];
                } else {
                    // 80%: replace the cell with a random tile.
                    clone[ny][nx] = Math.floor(Math.random() * tileType.size);
                }
            }
        }
        const id2 = `${id}.S${i}`;
        const sid = `${id}:${id2}`;
        const simi = compareMap(id, id2, map, clone);

        res.push([
            sid,
            {
                map1: map,
                map2: clone,
                size: [width, height],
                topoSimilarity: simi,
                visionSimilarity: calculateVisualSimilarity(map, clone)
            }
        ]);
    }
    return res;
}
|
|
||||||
|
|
||||||
/**
 * Builds the training records for one pair of maps.
 *
 * The plain pair is always included. Depending on the flags, the result is
 * extended with a self-comparison, transformed variants and perturbed
 * variants.
 *
 * The self-comparison count used to be `Math.floor(Math.random() * 1)`,
 * which is always 0 and left that branch dead; it is now 0 or 1.
 *
 * NOTE(review): assumes `chooseFrom(items, n)` returns `n` randomly chosen
 * items — confirm in `src/utils`.
 *
 * @param id1 Id of the first map.
 * @param id2 Id of the second map.
 * @param map1 First map.
 * @param map2 Second map.
 * @param size Size stored on the plain and self-comparison records.
 * @param hasSelf Whether a map-against-itself record may be added.
 * @param hasTransform Whether rotated/mirrored variants may be added.
 * @param hasSimilar Whether perturbed variants of `map1` may be added.
 * @returns The generated records, plain pair first.
 */
export function generateTrainData(
    id1: string,
    id2: string,
    map1: number[][],
    map2: number[][],
    size: [number, number],
    hasSelf: boolean = true,
    hasTransform: boolean = true,
    hasSimilar: boolean = true
) {
    const topoSimilarity = compareMap(id1, id2, map1, map2);
    const visionSimilarity = calculateVisualSimilarity(map1, map2);
    const train: MinamoTrainData = {
        map1,
        map2,
        topoSimilarity,
        visionSimilarity,
        size: size
    };
    const data: MinamoTrainData[] = [train];

    if (hasSelf) {
        // A map compared with itself, so the model learns to output 1 for
        // identical maps. At most one of the two maps is used.
        const self1 = `${id1}:${id1}`;
        const self2 = `${id2}:${id2}`;
        const selfTrain = chooseFrom(
            [self1, self2],
            Math.floor(Math.random() * 2)
        );
        if (selfTrain.includes(self1)) {
            data.push({
                map1: map1,
                map2: map1,
                topoSimilarity: 1,
                visionSimilarity: 1,
                size: size
            });
        }
        if (selfTrain.includes(self2)) {
            data.push({
                map1: map2,
                map2: map2,
                topoSimilarity: 1,
                visionSimilarity: 1,
                size: size
            });
        }
    }
    if (hasTransform) {
        const transformed = generateTransformData(
            id1,
            id2,
            map1,
            map2,
            topoSimilarity
        );
        data.push(...transformed.map(v => v[1]));
    }
    if (hasSimilar) {
        const similar = generateSimilarData(id1, map1);
        data.push(...similar.map(v => v[1]));
    }
    return data;
}
|
|
||||||
|
|
||||||
/**
 * Adds the training records for one pair of maps to `data`, keyed by
 * `<id1>:<id2>`.
 *
 * Besides the plain pair this may add a self-comparison (never overwriting
 * an existing one), transformed variants and perturbed variants of `map1`.
 *
 * The self-comparison count used to be `Math.floor(Math.random() * 1)`,
 * which is always 0 and left that branch dead; it is now 0 or 1.
 *
 * NOTE(review): assumes `chooseFrom(items, n)` returns `n` randomly chosen
 * items — confirm in `src/utils`.
 *
 * @param data Record collection to extend in place.
 * @param id1 Id of the first map.
 * @param id2 Id of the second map.
 * @param map1 First map.
 * @param map2 Second map.
 * @param size Size stored on the plain and self-comparison records.
 */
export function generatePair(
    data: Record<string, MinamoTrainData>,
    id1: string,
    id2: string,
    map1: number[][],
    map2: number[][],
    size: [number, number]
) {
    const topoSimilarity = compareMap(id1, id2, map1, map2);
    const visionSimilarity = calculateVisualSimilarity(map1, map2);
    data[`${id1}:${id2}`] = {
        map1,
        map2,
        topoSimilarity,
        visionSimilarity,
        size: size
    };

    // A map compared with itself, so the model learns to output 1 for
    // identical maps. At most one of the two maps is used.
    const self1 = `${id1}:${id1}`;
    const self2 = `${id2}:${id2}`;
    const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 2));
    if (selfTrain.includes(self1) && !data[self1]) {
        data[self1] = {
            map1: map1,
            map2: map1,
            topoSimilarity: 1,
            visionSimilarity: 1,
            size: size
        };
    }
    if (selfTrain.includes(self2) && !data[self2]) {
        data[self2] = {
            map1: map2,
            map2: map2,
            topoSimilarity: 1,
            visionSimilarity: 1,
            size: size
        };
    }

    // Rotated / mirrored variants.
    Object.assign(
        data,
        Object.fromEntries(
            generateTransformData(id1, id2, map1, map2, topoSimilarity)
        )
    );
    // Perturbed variants.
    Object.assign(data, Object.fromEntries(generateSimilarData(id1, map1)));
}
|
|
||||||
|
|
||||||
/**
 * Generates the training records for every selected floor pair while
 * reporting progress on the console.
 *
 * Each entry of `pairs` encodes two indices into `floorIds` as
 * `first * floorIds.length + second`. Pairs whose floors are missing or
 * differ in size are skipped (and do not advance the progress bar).
 *
 * @param floors Floors keyed by id.
 * @param pairs Encoded index pairs.
 * @param floorIds Floor ids in the order the pair indices refer to.
 */
function generateDataset(
    floors: Map<string, FloorData>,
    pairs: number[],
    floorIds: string[]
): Record<string, MinamoTrainData> {
    const data: Record<string, MinamoTrainData> = {};

    const progress = new SingleBar({}, Presets.shades_classic);
    progress.start(pairs.length, 0);

    for (const [index, encoded] of pairs.entries()) {
        // Decode the pair back into two floor ids.
        const id1 = floorIds[Math.floor(encoded / floorIds.length)];
        const id2 = floorIds[encoded % floorIds.length];
        const map1 = floors.get(id1)?.map;
        const map2 = floors.get(id2)?.map;
        if (!map1 || !map2) continue;

        const sameWidth = map1[0].length === map2[0].length;
        const sameHeight = map1.length === map2.length;
        if (!sameWidth || !sameHeight) continue;

        generatePair(data, id1, id2, map1, map2, [map1[0].length, map1.length]);
        progress.update(index + 1);
    }

    progress.stop();

    return data;
}
|
|
||||||
|
|
||||||
/**
 * Creates a Minamo dataset from a set of floors by sampling floor pairs.
 *
 * Pairs are selected with `choosePair(floorCount, 10000)`; a summary line is
 * logged before the samples are generated.
 *
 * @param data Floor data keyed by floor id.
 * @returns A dataset with a random `datasetId` and the generated samples.
 */
export function parseMinamo(data: Map<string, FloorData>): MinamoDataset {
    const floorCount = data.size;
    // Number of unordered pairs of distinct floors.
    const combinationCount = Math.round((floorCount * (floorCount - 1)) / 2);
    const pairs = choosePair(floorCount, 10000);

    console.log(
        `✅ 共发现 ${floorCount} 个楼层,共 ${combinationCount} 种组合,选取 ${pairs.length} 个组合`
    );

    const floorIds = Array.from(data.keys());
    const trainData = generateDataset(data, pairs, floorIds);
    const datasetId = Math.floor(Math.random() * 1e12);

    return { datasetId, data: trainData };
}
|
|
||||||
|
|
||||||
/**
 * Creates a Minamo dataset by pairing floors from two separate collections.
 *
 * Up to `count[0]` floors are drawn from `data1`; for each of them up to
 * `count[1]` floors are drawn afresh from `data2`. Pairs whose maps are
 * missing or differ in size are skipped.
 *
 * @param data1 First floor collection.
 * @param data2 Second floor collection.
 * @param count Number of floors to draw from each collection.
 * @returns A dataset with a random `datasetId` and the generated samples.
 */
export function generateAssignedData(
    data1: Map<string, FloorData>,
    data2: Map<string, FloorData>,
    count: [number, number]
): MinamoDataset {
    const floorTotal = data1.size + data2.size;
    const combinationTotal = data1.size * data2.size;
    // Never ask for more floors than a collection holds.
    const take1 = Math.min(count[0], data1.size);
    const take2 = Math.min(count[1], data2.size);
    const ids1 = Array.from(data1.keys());
    const ids2 = Array.from(data2.keys());
    const picked1 = chooseFrom(ids1, take1);

    const trainData: Record<string, MinamoTrainData> = {};

    console.log(
        `✅ 共发现 ${floorTotal} 个楼层,共 ${combinationTotal} 种组合,选取 ${
            take1 * take2
        } 个组合`
    );

    const bar = new SingleBar({}, Presets.shades_classic);
    bar.start(take1 * take2, 0);
    let done = 0;

    picked1.forEach(key1 => {
        const picked2 = chooseFrom(ids2, take2);
        picked2.forEach(key2 => {
            const map1 = data1.get(key1)!.map;
            const map2 = data2.get(key2)!.map;
            if (!map1 || !map2) return;

            const width = map1[0].length;
            const height = map1.length;
            if (width !== map2[0].length || height !== map2.length) return;

            generatePair(trainData, key1, key2, map1, map2, [width, height]);
            done += 1;
            bar.update(done);
        });
    });

    bar.stop();

    const datasetId = Math.floor(Math.random() * 1e12);
    return { datasetId, data: trainData };
}
|
|
||||||
@ -1,38 +0,0 @@
|
|||||||
import { readFile, writeFile } from 'fs-extra';
|
|
||||||
import { chooseFrom, DatasetMergable, mergeDataset } from './utils';
|
|
||||||
|
|
||||||
const [target, ...review] = process.argv.slice(2);
|
|
||||||
const n = getNum();
|
|
||||||
|
|
||||||
/**
 * Reads the optional trailing sample-count argument.
 *
 * If the last entry of `review` is a positive decimal integer it is removed
 * from `review` and returned; otherwise `review` is left untouched and the
 * default of 1000 is returned.
 *
 * Only strings made up entirely of digits are accepted. A bare `parseInt`
 * would also accept digit-prefixed file names such as `2024-data.json`, which
 * would then be silently dropped from the list of files to review.
 *
 * @returns The number of samples to pick for review.
 */
function getNum(): number {
    const defaultCount = 1000;
    const last = review.at(-1);
    if (last === undefined || !/^\d+$/.test(last)) return defaultCount;

    const count = Number.parseInt(last, 10);
    // A literal zero is not a usable count; fall back without consuming it.
    if (count === 0) return defaultCount;

    review.pop();
    return count;
}
|
|
||||||
|
|
||||||
// Entry point: merges the review datasets, samples `n` of their entries and
// merges that sample into the target dataset file, which is rewritten in place.
(async () => {
    const loadDataset = async (path: string) => {
        const text = await readFile(path, 'utf-8');
        return JSON.parse(text) as DatasetMergable<any>;
    };

    const datas = await Promise.all(review.map(path => loadDataset(path)));
    const targetData = await loadDataset(target);

    const merged = mergeDataset(true, ...datas);
    const pickedKeys = chooseFrom(Object.keys(merged.data), n);
    const pickedEntries = pickedKeys.map(key => [key, merged.data[key]]);
    const reviewData: DatasetMergable<any> = {
        datasetId: Math.floor(Math.random() * 1e12),
        data: Object.fromEntries(pickedEntries)
    };

    const reviewed = mergeDataset(false, targetData, reviewData);
    await writeFile(target, JSON.stringify(reviewed), 'utf-8');
})();
|
|
||||||
@ -1,95 +0,0 @@
|
|||||||
import { buildTopologicalGraph } from './graph';
|
|
||||||
import { mirrorMapX, mirrorMapY, rotateMap } from './transform';
|
|
||||||
import { overallSimilarity } from './similarity';
|
|
||||||
|
|
||||||
// Manual check script: prints the topological similarity of two sample maps
// against each other and against their mirrored / rotated variants.
(() => {
    // MT3
    const map1 = [
        [1, 1, 1, 1, 1, 1, 1],
        [1, 3, 1, 10, 7, 3, 1],
        [1, 6, 1, 5, 1, 7, 1],
        [1, 9, 1, 8, 1, 6, 1],
        [1, 5, 8, 0, 1, 7, 1],
        [1, 2, 1, 5, 7, 10, 1],
        [1, 1, 1, 1, 1, 1, 1]
    ];
    // MT6
    const map2 = [
        [1, 1, 1, 1, 1, 1, 1],
        [1, 5, 6, 4, 1, 10, 1],
        [1, 6, 1, 9, 1, 5, 1],
        [1, 8, 0, 6, 0, 8, 1],
        [1, 5, 1, 10, 1, 2, 1],
        [1, 9, 3, 1, 4, 9, 1],
        [1, 1, 1, 1, 1, 1, 1]
    ];
    // MT8
    // const map2 = [
    //     [1, 1, 1, 1, 1, 1, 1],
    //     [1, 5, 8, 10, 7, 2, 1],
    //     [1, 2, 1, 5, 1, 7, 1],
    //     [1, 3, 1, 3, 6, 4, 1],
    //     [1, 6, 1, 6, 1, 8, 1],
    //     [1, 10, 7, 5, 1, 5, 1],
    //     [1, 1, 1, 1, 1, 1, 1]
    // ];
    // MT3 with one tile changed
    // const map2 = [
    //     [1, 1, 1, 1, 1, 1, 1],
    //     [1, 3, 1, 10, 7, 3, 1],
    //     [1, 6, 1, 5, 1, 7, 1],
    //     [1, 9, 1, 8, 1, 6, 1],
    //     [1, 5, 8, 0, 1, 7, 1],
    //     [1, 2, 1, 4, 7, 10, 1],
    //     [1, 1, 1, 1, 1, 1, 1]
    // ];

    // 1. Each map compared with itself.
    const graph1 = buildTopologicalGraph(map1);
    const graph2 = buildTopologicalGraph(map2);

    console.log(`map1 vs map1: ${overallSimilarity(graph1, graph1)}`);
    console.log(`map2 vs map2: ${overallSimilarity(graph2, graph2)}`);

    // 2. The two maps compared with each other, in both argument orders.
    console.log(`map1 vs map2: ${overallSimilarity(graph1, graph2)}`);
    console.log(`map2 vs map1: ${overallSimilarity(graph2, graph1)}`);

    // 3. X-mirrored variants.
    const xFlipped1 = mirrorMapX(map1);
    const xFlipped2 = mirrorMapX(map2);
    const graphX1 = buildTopologicalGraph(xFlipped1);
    const graphX2 = buildTopologicalGraph(xFlipped2);
    console.log(`map1:x vs map1: ${overallSimilarity(graphX1, graph1)}`);
    console.log(`map1:x vs map2: ${overallSimilarity(graphX1, graph2)}`);
    console.log(`map1 vs map2:x: ${overallSimilarity(graph1, graphX2)}`);
    console.log(`map2:x vs map2: ${overallSimilarity(graphX2, graph2)}`);
    // Compares against graph1, matching the label (this line previously
    // repeated the map2:x vs map2 comparison).
    console.log(`map2:x vs map1: ${overallSimilarity(graphX2, graph1)}`);
    console.log(`map2 vs map1:x: ${overallSimilarity(graph2, graphX1)}`);

    // 4. Y-mirrored variants.
    const yFlipped1 = mirrorMapY(map1);
    const yFlipped2 = mirrorMapY(map2);
    const graphY1 = buildTopologicalGraph(yFlipped1);
    const graphY2 = buildTopologicalGraph(yFlipped2);
    console.log(`map1:y vs map1: ${overallSimilarity(graphY1, graph1)}`);
    console.log(`map1:y vs map2: ${overallSimilarity(graphY1, graph2)}`);
    console.log(`map1 vs map2:y: ${overallSimilarity(graph1, graphY2)}`);
    console.log(`map2:y vs map2: ${overallSimilarity(graphY2, graph2)}`);
    console.log(`map2:y vs map1: ${overallSimilarity(graphY2, graph1)}`);
    console.log(`map2 vs map1:y: ${overallSimilarity(graph2, graphY1)}`);

    // 5. Mixed X / Y mirrored variants.
    console.log(`map1:x vs map1:y: ${overallSimilarity(graphX1, graphY1)}`);
    console.log(`map1:y vs map2:x: ${overallSimilarity(graphY1, graphX2)}`);
    console.log(`map1:x vs map2:x: ${overallSimilarity(graphX1, graphX2)}`);
    console.log(`map1:x vs map2:y: ${overallSimilarity(graphX1, graphY2)}`);

    // 6. Rotated variants.
    const rot901 = rotateMap(map1);
    const rot902 = rotateMap(map2);
    const graph901 = buildTopologicalGraph(rot901);
    const graph902 = buildTopologicalGraph(rot902);
    console.log(`map1:90 vs map1: ${overallSimilarity(graph1, graph901)}`);
    console.log(`map2:90 vs map2: ${overallSimilarity(graph2, graph902)}`);
})();
|
|
||||||
@ -1,13 +0,0 @@
|
|||||||
/**
 * Mirrors a map horizontally: every row is returned in reverse column order.
 * The input is not modified; each row of the result is a new array.
 *
 * @param map Tile grid to mirror.
 * @returns A new grid with each row reversed.
 */
export function mirrorMapX(map: number[][]) {
    return map.map(row => {
        const reversed = row.slice();
        reversed.reverse();
        return reversed;
    });
}
|
|
||||||
|
|
||||||
/**
 * Mirrors a map vertically: rows are returned in reverse order.
 *
 * Each row is copied, so the result shares no arrays with the input. The same
 * guarantee is given by `mirrorMapX`; without the copy, editing a tile of the
 * mirrored map would also edit the source map.
 *
 * @param map Tile grid to mirror.
 * @returns A new grid with the row order reversed.
 */
export function mirrorMapY(map: number[][]) {
    return map.map(row => [...row]).reverse();
}
|
|
||||||
|
|
||||||
/**
 * Rotates a map by 90 degrees counter-clockwise (transpose, then reverse the
 * row order). The input is not modified.
 *
 * An empty map yields an empty map instead of throwing on the missing first
 * row.
 *
 * @param map Tile grid to rotate; rows are expected to have equal length.
 * @returns A new grid holding the rotated map.
 */
export function rotateMap(map: number[][]): number[][] {
    const firstRow = map[0];
    if (firstRow === undefined) return [];
    // Column `col` of the input becomes a row of the transposed grid.
    const transposed = firstRow.map((_, col) => map.map(row => row[col]));
    return transposed.reverse();
}
|
|
||||||
@ -1,7 +1,6 @@
|
|||||||
import { readFile } from 'fs-extra';
|
import { readFile } from 'fs-extra';
|
||||||
import { join } from 'path';
|
import { join } from 'path';
|
||||||
import { BaseConfig, GinkaConfig, TowerInfo } from './types';
|
import { BaseConfig, GinkaConfig, TowerInfo } from './types';
|
||||||
import { convertFloor } from './floor';
|
|
||||||
|
|
||||||
export interface DatasetMergable<T> {
|
export interface DatasetMergable<T> {
|
||||||
datasetId: number;
|
datasetId: number;
|
||||||
@ -77,81 +76,6 @@ export async function parseTowerInfo(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
 * Loads and converts every floor of the given towers.
 *
 * For each tower, `enemys.js` and `maps.js` are read from the tower directory
 * and parsed as JSON after dropping their first line. Tiles whose `cls` is
 * `'enemys'` are mapped from tile number to enemy data. Each floor file is
 * then parsed (everything after the first `=`) and passed through
 * `convertFloor` with the clip area configured for that floor.
 *
 * @param info Towers to load.
 * @returns Converted floors keyed by `${towerName}::${floorId}`.
 * @throws Rethrows any parse / conversion error after logging which tower and
 *   floor was being processed.
 */
export async function getAllFloors(...info: TowerInfo[]) {
    // The data files start with one line of JavaScript that is not JSON.
    const dropFirstLine = (text: string) =>
        text.split('\n').slice(1).join('\n');

    const convertedTowers = await Promise.all(
        info.map(async tower => {
            // Load the tile and enemy tables needed for conversion.
            const enemyText = await readFile(
                join(tower.path, 'enemys.js'),
                'utf-8'
            );
            const tileText = await readFile(
                join(tower.path, 'maps.js'),
                'utf-8'
            );
            const enemyMap = JSON.parse(dropFirstLine(enemyText)) as Record<
                string,
                any
            >;
            const mapData = JSON.parse(dropFirstLine(tileText)) as Record<
                number,
                any
            >;

            // Map enemy tile numbers to their enemy data.
            const enemyNumMap: Record<number, any> = {};
            Object.entries(mapData).forEach(([tileNumber, tile]) => {
                if (tile.cls === 'enemys') {
                    enemyNumMap[parseInt(tileNumber)] = enemyMap[tile.id];
                }
            });

            const floorJobs = tower.floorIds.map(async id => {
                const floorFile = await readFile(
                    join(tower.path, 'floors', `${id}.js`),
                    'utf-8'
                );
                try {
                    const jsonStart = floorFile.indexOf('=') + 1;
                    const data = JSON.parse(floorFile.slice(jsonStart));
                    const map = data.map as number[][];

                    // Pick the clip area: per-floor override, else default.
                    const { clip } = tower.config;
                    const area = clip.special[id] ?? clip.defaults;

                    return convertFloor(
                        map,
                        area,
                        tower.config as GinkaConfig,
                        enemyNumMap
                    );
                } catch (e) {
                    console.log(
                        `Error when processing '${tower.name}' '${id}'`
                    );
                    throw e;
                }
            });
            return Promise.all(floorJobs);
        })
    );

    const maps: Map<string, FloorData> = new Map();
    for (let towerIndex = 0; towerIndex < convertedTowers.length; towerIndex++) {
        const towerInfo = info[towerIndex];
        const name = towerInfo.name;
        const towerFloors = convertedTowers[towerIndex];
        for (let floorIndex = 0; floorIndex < towerFloors.length; floorIndex++) {
            const floorId = towerInfo.floorIds[floorIndex];
            maps.set(`${name}::${floorId}`, {
                map: towerFloors[floorIndex],
                id: floorId,
                config: towerInfo.config
            });
        }
    }
    return maps;
}
|
|
||||||
|
|
||||||
export function mergeFloorIds(...info: TowerInfo[]) {
|
export function mergeFloorIds(...info: TowerInfo[]) {
|
||||||
const ids: string[] = [];
|
const ids: string[] = [];
|
||||||
info.forEach(v => {
|
info.forEach(v => {
|
||||||
@ -160,14 +84,6 @@ export function mergeFloorIds(...info: TowerInfo[]) {
|
|||||||
return ids;
|
return ids;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
 * Loads floors from a single path.
 *
 * Paths ending in `.json` are read with `fromJSON`; any other path is treated
 * as a tower directory configured by `minamo-config.json`.
 *
 * @param path JSON file or tower directory.
 */
export async function readOne(path: string) {
    if (!path.endsWith('.json')) {
        const tower = await parseTowerInfo(path, 'minamo-config.json');
        return getAllFloors(tower);
    }
    return fromJSON(path);
}
|
|
||||||
|
|
||||||
export async function fromJSON(path: string) {
|
export async function fromJSON(path: string) {
|
||||||
const file = await readFile(path, 'utf-8');
|
const file = await readFile(path, 'utf-8');
|
||||||
const data = JSON.parse(file) as Record<string, number[][]>;
|
const data = JSON.parse(file) as Record<string, number[][]>;
|
||||||
|
|||||||
@ -1,145 +0,0 @@
|
|||||||
/** Options controlling how visual similarity between two maps is scored. */
interface VisualSimilarityConfig {
    /** Importance weight per tile type (tune to the game's settings). */
    typeWeights: Record<number, number>;
    /** Whether cells are additionally weighted by a visual-focus map. */
    enableVisualFocus: boolean;
    /** Whether cells are additionally weighted by local tile density. */
    enableDensityAwareness: boolean;
}
|
|
||||||
|
|
||||||
/** Default scoring options: both spatial weightings on, per-type weights below. */
const DEFAULT_CONFIG: VisualSimilarityConfig = {
    typeWeights: {
        // Terrain
        0: 0.2, // floor
        1: 0.3, // wall
        // Resources
        2: 0.6, // key
        3: 0.7, // red gem
        4: 0.7, // blue gem
        5: 0.5, // potion
        // Obstacles
        6: 0.4, // door
        7: 0.5, // weak enemy
        8: 0.6, // medium enemy
        9: 0.6, // strong enemy
        // Navigation
        10: 0.4, // stairs
        11: 0.4, // arrow
        // Special
        12: 0.7 // item
    },
    enableVisualFocus: true,
    enableDensityAwareness: true
};
|
|
||||||
|
|
||||||
/**
 * Scores how visually similar two equally sized maps are.
 *
 * Every cell contributes a weight made of the larger of the two tiles' type
 * weights, multiplied by the optional focus and density weights for that
 * cell. The result is the weight of the matching cells divided by the total
 * weight.
 *
 * Changes from the previous version:
 * - two empty maps return 0 instead of throwing on the missing first row;
 * - a type weight of 0 is honoured (only a missing weight falls back to 0.5).
 *
 * @param map1 First tile grid.
 * @param map2 Second tile grid.
 * @param config Scoring options; defaults to `DEFAULT_CONFIG`.
 * @returns The weighted match ratio, or 0 when the maps differ in size, are
 *   empty, or the total weight is not positive.
 */
export function calculateVisualSimilarity(
    map1: number[][],
    map2: number[][],
    config = DEFAULT_CONFIG
): number {
    // Size check (only the first row's width is compared).
    if (map1.length !== map2.length || map1[0]?.length !== map2[0]?.length) {
        return 0;
    }

    const rows = map1.length;
    const cols = map1[0]?.length ?? 0;
    if (rows === 0 || cols === 0) return 0;

    const weightOf = (type: number) => config.typeWeights[type] ?? 0.5;
    const uniformWeights = () =>
        Array.from({ length: rows }, () => Array<number>(cols).fill(1));

    // Per-cell focus weights, or all ones when disabled.
    const focusWeights = config.enableVisualFocus
        ? generateFocusWeights(rows, cols)
        : uniformWeights();

    // Per-cell density weights, or all ones when disabled.
    const densityMap = config.enableDensityAwareness
        ? calculateDensityImpact(map1, map2, config.typeWeights)
        : uniformWeights();

    let totalScore = 0;
    let maxPossibleScore = 0;

    for (let i = 0; i < rows; i++) {
        for (let j = 0; j < cols; j++) {
            const type1 = map1[i][j];
            const type2 = map2[i][j];

            // The more important of the two tiles decides the base weight.
            const baseWeight = Math.max(weightOf(type1), weightOf(type2));
            const spatialWeight = focusWeights[i][j] * densityMap[i][j];
            const cellWeight = baseWeight * spatialWeight;

            // Only identical tile types score.
            if (type1 === type2) totalScore += cellWeight;
            maxPossibleScore += cellWeight;
        }
    }

    return maxPossibleScore > 0 ? totalScore / maxPossibleScore : 0;
}
|
|
||||||
|
|
||||||
/**
 * Builds a per-cell focus weight map using a Gaussian falloff.
 *
 * Offsets are measured from `(cols / 2, rows / 2)` in index coordinates and
 * normalised by the map size. Weights range from just above 1.0 far from that
 * point up to 1.6 at zero offset.
 *
 * @param rows Number of rows in the map.
 * @param cols Number of columns in the map.
 * @returns A `rows` x `cols` grid of weights.
 */
function generateFocusWeights(rows: number, cols: number): number[][] {
    const centerX = cols / 2;
    const centerY = rows / 2;
    // Standard deviation of the Gaussian, in normalised map units.
    const sigma = 0.3;

    return Array.from({ length: rows }, (_row, i) =>
        Array.from({ length: cols }, (_col, j) => {
            const dx = (j - centerX) / cols;
            const dy = (i - centerY) / rows;
            const distance = Math.sqrt(dx ** 2 + dy ** 2);
            const gaussian = Math.exp(-(distance ** 2) / (2 * sigma ** 2));
            return 1.0 + 0.6 * gaussian;
        })
    );
}
|
|
||||||
|
|
||||||
/**
 * Builds a per-cell density weight map from the tile weights around each cell.
 *
 * For every cell, the average of the two maps' type weights is summed over
 * the 3x3 window centred on it (cells outside the map are ignored). The sum
 * is divided by the window area and mapped to `1.0 + 0.4 * density`, so
 * regions dense in heavily weighted tiles count for more.
 *
 * Changes from the previous version:
 * - an empty map yields an empty grid instead of throwing;
 * - a type weight of 0 is honoured (only a missing weight falls back to 0.5).
 *
 * @param map1 First tile grid.
 * @param map2 Second tile grid, expected to have the same size as `map1`.
 * @param typeWeights Importance weight per tile type.
 * @returns A grid of weights with the same dimensions as `map1`.
 */
function calculateDensityImpact(
    map1: number[][],
    map2: number[][],
    typeWeights: { [key: number]: number }
): number[][] {
    const rows = map1.length;
    const cols = map1[0]?.length ?? 0;
    const windowSize = 3;
    const halfWindow = Math.floor(windowSize / 2);
    const weightOf = (type: number) => typeWeights[type] ?? 0.5;

    const densityMap: number[][] = [];
    for (let i = 0; i < rows; i++) {
        const rowWeights: number[] = [];
        for (let j = 0; j < cols; j++) {
            let density = 0;
            for (let di = -halfWindow; di <= halfWindow; di++) {
                for (let dj = -halfWindow; dj <= halfWindow; dj++) {
                    const ni = i + di;
                    const nj = j + dj;
                    const inside = ni >= 0 && ni < rows && nj >= 0 && nj < cols;
                    if (!inside) continue;
                    density +=
                        (weightOf(map1[ni][nj]) + weightOf(map2[ni][nj])) / 2;
                }
            }
            // Denser neighbourhoods get a higher weight.
            rowWeights.push(1.0 + 0.4 * (density / windowSize ** 2));
        }
        densityMap.push(rowWeights);
    }
    return densityMap;
}
|
|
||||||
@ -1,74 +0,0 @@
|
|||||||
import { calculateVisualSimilarity } from './similarity';
|
|
||||||
|
|
||||||
// Manual check script: prints the visual similarity of five sample maps —
// self comparisons, every unordered pair, and two swapped-order comparisons.
(() => {
    // Index k holds the map printed as `map${k + 1}`.
    const maps: number[][][] = [
        // map1: MT3
        [
            [1, 1, 1, 1, 1, 1, 1],
            [1, 3, 1, 10, 7, 3, 1],
            [1, 6, 1, 5, 1, 7, 1],
            [1, 9, 1, 8, 1, 6, 1],
            [1, 5, 8, 0, 1, 7, 1],
            [1, 2, 1, 5, 7, 10, 1],
            [1, 1, 1, 1, 1, 1, 1]
        ],
        // map2: MT6
        [
            [1, 1, 1, 1, 1, 1, 1],
            [1, 5, 6, 4, 1, 10, 1],
            [1, 6, 1, 9, 1, 5, 1],
            [1, 8, 0, 6, 0, 8, 1],
            [1, 5, 1, 10, 1, 2, 1],
            [1, 9, 3, 1, 4, 9, 1],
            [1, 1, 1, 1, 1, 1, 1]
        ],
        // map3: MT8
        [
            [1, 1, 1, 1, 1, 1, 1],
            [1, 5, 8, 10, 7, 2, 1],
            [1, 2, 1, 5, 1, 7, 1],
            [1, 3, 1, 3, 6, 4, 1],
            [1, 6, 1, 6, 1, 8, 1],
            [1, 10, 7, 5, 1, 5, 1],
            [1, 1, 1, 1, 1, 1, 1]
        ],
        // map4: MT3 with one tile changed
        [
            [1, 1, 1, 1, 1, 1, 1],
            [1, 3, 1, 10, 7, 3, 1],
            [1, 6, 1, 5, 1, 7, 1],
            [1, 9, 1, 8, 1, 6, 1],
            [1, 5, 8, 0, 1, 7, 1],
            [1, 2, 1, 4, 7, 10, 1],
            [1, 1, 1, 1, 1, 1, 1]
        ],
        // map5: MT10
        [
            [1, 1, 1, 1, 1, 1, 1],
            [1, 5, 1, 10, 1, 5, 1],
            [1, 6, 7, 7, 7, 6, 1],
            [1, 1, 6, 5, 6, 1, 1],
            [1, 4, 5, 9, 5, 4, 1],
            [1, 3, 1, 1, 1, 3, 1],
            [1, 1, 1, 1, 1, 1, 1]
        ]
    ];

    // Prints one comparison using 1-based map numbers.
    const report = (a: number, b: number) => {
        const score = calculateVisualSimilarity(maps[a - 1], maps[b - 1]);
        console.log(`map${a} vs map${b}: ${score}`);
    };

    // Self comparisons (map1 to map4).
    for (let k = 1; k <= 4; k++) {
        report(k, k);
    }
    // Every unordered pair.
    for (let a = 1; a <= maps.length; a++) {
        for (let b = a + 1; b <= maps.length; b++) {
            report(a, b);
        }
    }
    // Argument-order (commutativity) checks.
    report(2, 1);
    report(4, 2);
})();
|
|
||||||
@ -1,133 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch_geometric.nn import GCNConv, TransformerConv
|
|
||||||
from torch_geometric.utils import grid
|
|
||||||
|
|
||||||
def batch_edge_index(B, edge_index, num_nodes_per_batch):
    """Replicate a single-graph edge index for a batch of B graphs.

    Copy ``i`` of the edge index is shifted by ``i * num_nodes_per_batch`` and
    the copies are concatenated along dim 1. The input tensor is not modified.

    Args:
        B: Number of graphs in the batch.
        edge_index: Edge index of one graph (the source notes it as ``[2, E]``).
        num_nodes_per_batch: Node-id offset between consecutive graphs.

    Returns:
        The concatenated, offset edge index.
    """
    shifted = [edge_index + graph * num_nodes_per_batch for graph in range(B)]
    return torch.cat(shifted, dim=1)
|
|
||||||
|
|
||||||
class DoubleConvBlock(nn.Module):
    """Two stacked 3x3 convolution stages, each followed by InstanceNorm and GELU.

    Convolutions use replicate padding of 1, so the spatial size is unchanged.

    Args:
        feats: ``(in_channels, hidden_channels, out_channels)``.
    """

    def __init__(self, feats: tuple[int, int, int]):
        super().__init__()
        stages = ((feats[0], feats[1]), (feats[1], feats[2]))
        layers = []
        for src_ch, dst_ch in stages:
            layers.append(
                nn.Conv2d(src_ch, dst_ch, 3, padding=1, padding_mode='replicate')
            )
            layers.append(nn.InstanceNorm2d(dst_ch))
            layers.append(nn.GELU())
        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        return self.cnn(x)
|
|
||||||
|
|
||||||
class GCNBlock(nn.Module):
    """Three GCNConv layers over a grid graph, each followed by LayerNorm and GELU.

    The feature map is flattened to one node per cell, processed on the edge
    index returned by ``grid(h, w)``, and reshaped back to a feature map.

    Args:
        in_ch: Input channels.
        hidden_ch: Channels of the first two layers.
        out_ch: Output channels.
        w: Grid width passed to ``grid``.
        h: Grid height passed to ``grid``.
    """

    def __init__(self, in_ch, hidden_ch, out_ch, w, h):
        super().__init__()
        self.conv1 = GCNConv(in_ch, hidden_ch)
        self.conv2 = GCNConv(hidden_ch, hidden_ch)
        self.conv3 = GCNConv(hidden_ch, out_ch)
        self.norm1 = nn.LayerNorm(hidden_ch)
        self.norm2 = nn.LayerNorm(hidden_ch)
        self.norm3 = nn.LayerNorm(out_ch)
        # Edge index of one map; offset per sample at forward time.
        self.single_edge_index, _ = grid(h, w)

    def forward(self, x):
        """Run the graph layers on ``x`` of shape ``[B, C, H, W]``."""
        B, C, H, W = x.shape

        # One node per grid cell: [B * H * W, C].
        nodes = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
        edges = batch_edge_index(
            B, self.single_edge_index.to(nodes.device), H * W
        )

        layers = (
            (self.conv1, self.norm1),
            (self.conv2, self.norm2),
            (self.conv3, self.norm3),
        )
        for conv, norm in layers:
            nodes = F.gelu(norm(conv(nodes, edges)))

        # Back to [B, C_out, H, W].
        return nodes.view(B, H, W, -1).permute(0, 3, 1, 2)
|
|
||||||
|
|
||||||
class TransformerGCNBlock(nn.Module):
    """Two TransformerConv layers over a grid graph, each with LayerNorm and GELU.

    The first layer uses 8 concatenated heads of ``hidden_ch // 8`` channels and
    is normalised with ``LayerNorm(hidden_ch)``, so ``hidden_ch`` should be a
    multiple of 8. The second layer uses a single head.

    Args:
        in_ch: Input channels.
        hidden_ch: Channels after the first layer.
        out_ch: Output channels.
        w: Grid width passed to ``grid``.
        h: Grid height passed to ``grid``.
    """

    def __init__(self, in_ch, hidden_ch, out_ch, w, h):
        super().__init__()
        self.conv1 = TransformerConv(in_ch, hidden_ch // 8, heads=8, concat=True)
        self.conv2 = TransformerConv(hidden_ch, out_ch, heads=1)
        self.norm1 = nn.LayerNorm(hidden_ch)
        self.norm2 = nn.LayerNorm(out_ch)
        # Edge index of one map; offset per sample at forward time.
        self.single_edge_index, _ = grid(h, w)

    def forward(self, x):
        """Run the graph layers on ``x`` of shape ``[B, C, H, W]``."""
        B, C, H, W = x.shape

        # One node per grid cell: [B * H * W, C].
        nodes = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
        edges = batch_edge_index(
            B, self.single_edge_index.to(nodes.device), H * W
        )

        for conv, norm in ((self.conv1, self.norm1), (self.conv2, self.norm2)):
            nodes = F.gelu(norm(conv(nodes, edges)))

        # Back to [B, C_out, H, W].
        return nodes.view(B, H, W, -1).permute(0, 3, 1, 2)
|
|
||||||
|
|
||||||
class ConvFusionModule(nn.Module):
    """Runs a CNN branch and a graph branch in parallel and fuses them.

    Both branches map ``in_ch`` channels back to ``in_ch`` channels; their
    outputs are concatenated on the channel axis and reduced to ``out_ch`` by a
    final convolution block.

    Args:
        in_ch: Input channels.
        hidden_ch: Hidden channels used inside every branch.
        out_ch: Output channels of the fused result.
        w: Grid width for the graph branch.
        h: Grid height for the graph branch.
    """

    def __init__(self, in_ch, hidden_ch, out_ch, w: int, h: int):
        super().__init__()
        self.cnn = DoubleConvBlock((in_ch, hidden_ch, in_ch))
        self.gcn = TransformerGCNBlock(in_ch, hidden_ch, in_ch, w, h)
        self.fusion = DoubleConvBlock((in_ch * 2, hidden_ch, out_ch))

    def forward(self, x):
        """Fuse the CNN and graph features of ``x``."""
        cnn_feat = self.cnn(x)
        graph_feat = self.gcn(x)
        return self.fusion(torch.cat([cnn_feat, graph_feat], dim=1))
|
|
||||||
|
|
||||||
class DoubleFCModule(nn.Module):
    """Two Linear layers, each followed by LayerNorm and GELU.

    Args:
        in_dim: Input feature size.
        hidden_dim: Feature size after the first layer.
        out_dim: Output feature size.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        layers = []
        for src_dim, dst_dim in ((in_dim, hidden_dim), (hidden_dim, out_dim)):
            layers.extend(
                [nn.Linear(src_dim, dst_dim), nn.LayerNorm(dst_dim), nn.GELU()]
            )
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both fully connected stages to ``x``."""
        return self.fc(x)
|
|
||||||
|
|
||||||
@ -1,50 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from .common import DoubleFCModule
|
|
||||||
|
|
||||||
class ConditionEncoder(nn.Module):
    """Encodes the conditioning inputs into a single feature vector.

    The value and stage inputs are embedded separately, stacked as a
    two-token sequence, passed through a 4-layer Transformer encoder,
    mean-pooled over the tokens and projected to ``out_dim``.

    ``tag_embed`` is constructed but not used by ``forward``.

    Args:
        tag_dim: Size of the tag input.
        val_dim: Size of the value input.
        hidden_dim: Embedding / Transformer width (the encoder uses 8 heads).
        out_dim: Size of the returned condition vector.
    """

    def __init__(self, tag_dim=64, val_dim=16, hidden_dim=256, out_dim=256):
        super().__init__()
        self.tag_embed = DoubleFCModule(tag_dim, hidden_dim, hidden_dim)
        self.val_embed = DoubleFCModule(val_dim, hidden_dim, hidden_dim)
        self.stage_embed = DoubleFCModule(1, hidden_dim, hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=8,
            dim_feedforward=hidden_dim * 4,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, tag, val, stage):
        """Return the condition vector; ``tag`` is accepted but ignored."""
        # One token per embedded input, stacked on dim 1.
        tokens = torch.stack(
            [self.val_embed(val), self.stage_embed(stage)], dim=1
        )
        pooled = torch.mean(self.encoder(tokens), dim=1)
        return self.fusion(pooled)
|
|
||||||
|
|
||||||
class ConditionInjector(nn.Module):
    """Modulates a feature map with a scale and shift predicted from a condition.

    Computes ``x * gamma + beta``, where ``gamma`` and ``beta`` are linear
    projections of the condition, broadcast over the two trailing axes of ``x``.

    Args:
        cond_dim: Size of the condition vector.
        out_dim: Number of channels of the feature map being modulated.
    """

    def __init__(self, cond_dim, out_dim):
        super().__init__()
        self.gamma_layer = nn.Sequential(nn.Linear(cond_dim, out_dim))
        self.beta_layer = nn.Sequential(nn.Linear(cond_dim, out_dim))

    def forward(self, x, cond):
        """Apply the condition-dependent scale and shift to ``x``."""
        # Add two trailing singleton axes so the terms broadcast over x.
        scale = self.gamma_layer(cond).unsqueeze(2).unsqueeze(3)
        shift = self.beta_layer(cond).unsqueeze(2).unsqueeze(3)
        return x * scale + shift
|
|
||||||
@ -1,292 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.nn.utils import spectral_norm
|
|
||||||
from torch_geometric.nn import global_max_pool, GCNConv, TransformerConv
|
|
||||||
from torch_geometric.utils import grid
|
|
||||||
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
|
|
||||||
from .vision import MinamoVisionModel
|
|
||||||
from .topo import MinamoTopoModel
|
|
||||||
|
|
||||||
def print_memory(tag=""):
    """Print the current and peak CUDA memory allocated, in MB, prefixed by ``tag``."""
    bytes_per_mb = 1024**2
    current_mb = torch.cuda.memory_allocated() / bytes_per_mb
    peak_mb = torch.cuda.max_memory_allocated() / bytes_per_mb
    print(f"{tag} | 当前显存: {current_mb:.2f} MB, 最大显存: {peak_mb:.2f} MB")
|
|
||||||
|
|
||||||
def batch_edge_index(B, edge_index, num_nodes_per_batch):
    """Tile a single-graph edge index across a batch of B graphs.

    The i-th copy has ``i * num_nodes_per_batch`` added to every node id; all
    copies are joined along dim 1. The input tensor is left unchanged.

    Args:
        B: Batch size (number of graph copies).
        edge_index: Edge index of one graph (the source notes it as ``[2, E]``).
        num_nodes_per_batch: Number of nodes in one graph.

    Returns:
        The batched edge index.
    """
    offsets = (sample * num_nodes_per_batch for sample in range(B))
    return torch.cat([edge_index + offset for offset in offsets], dim=1)
|
|
||||||
|
|
||||||
class DoubleConvBlock(nn.Module):
    """Two spectrally normalised 3x3 convolutions, each followed by GELU.

    Convolutions use replicate padding of 1, so the spatial size is unchanged.

    Args:
        feats: ``(in_channels, hidden_channels, out_channels)``.
    """

    def __init__(self, feats: tuple[int, int, int]):
        super().__init__()
        stages = ((feats[0], feats[1]), (feats[1], feats[2]))
        layers = []
        for src_ch, dst_ch in stages:
            conv = nn.Conv2d(src_ch, dst_ch, 3, padding=1, padding_mode='replicate')
            layers.append(spectral_norm(conv))
            layers.append(nn.GELU())
        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        return self.cnn(x)
|
|
||||||
|
|
||||||
class TransformerGCNBlock(nn.Module):
    """Two TransformerConv layers over a grid graph, each followed by GELU.

    The first layer uses 8 concatenated heads of ``hidden_ch // 8`` channels;
    the second uses a single head. No normalisation layers are applied.

    Args:
        in_ch: Input channels.
        hidden_ch: Nominal channels after the first layer.
        out_ch: Output channels.
        w: Grid width passed to ``grid``.
        h: Grid height passed to ``grid``.
    """

    def __init__(self, in_ch, hidden_ch, out_ch, w, h):
        super().__init__()
        self.conv1 = TransformerConv(in_ch, hidden_ch // 8, heads=8, concat=True)
        self.conv2 = TransformerConv(hidden_ch, out_ch, heads=1)
        # Edge index of one map; offset per sample at forward time.
        self.single_edge_index, _ = grid(h, w)

    def forward(self, x):
        """Run the graph layers on ``x`` of shape ``[B, C, H, W]``."""
        B, C, H, W = x.shape
        # One node per grid cell: [B * H * W, C].
        nodes = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
        edges = batch_edge_index(
            B, self.single_edge_index.to(nodes.device), H * W
        )
        for conv in (self.conv1, self.conv2):
            nodes = F.gelu(conv(nodes, edges))
        # Back to [B, C_out, H, W].
        return nodes.view(B, H, W, -1).permute(0, 3, 1, 2)
|
|
||||||
|
|
||||||
class ConvFusionModule(nn.Module):
    """Combines a convolutional branch and a graph branch of the same input.

    Each branch keeps ``in_ch`` channels; the two results are concatenated on
    the channel axis and reduced to ``out_ch`` by a final convolution block.

    Args:
        in_ch: Input channels.
        hidden_ch: Hidden channels used inside every branch.
        out_ch: Output channels of the fused result.
        w: Grid width for the graph branch.
        h: Grid height for the graph branch.
    """

    def __init__(self, in_ch, hidden_ch, out_ch, w: int, h: int):
        super().__init__()
        self.cnn = DoubleConvBlock((in_ch, hidden_ch, in_ch))
        self.gcn = TransformerGCNBlock(in_ch, hidden_ch, in_ch, w, h)
        self.fusion = DoubleConvBlock((in_ch * 2, hidden_ch, out_ch))

    def forward(self, x):
        """Fuse the convolutional and graph features of ``x``."""
        branches = [self.cnn(x), self.gcn(x)]
        return self.fusion(torch.cat(branches, dim=1))
|
|
||||||
|
|
||||||
class DoubleFCModule(nn.Module):
    """Two spectrally normalised Linear layers, each followed by GELU.

    Args:
        in_dim: Input feature size.
        hidden_dim: Feature size after the first layer.
        out_dim: Output feature size.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        layers = []
        for src_dim, dst_dim in ((in_dim, hidden_dim), (hidden_dim, out_dim)):
            layers.append(spectral_norm(nn.Linear(src_dim, dst_dim)))
            layers.append(nn.GELU())
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both fully connected stages to ``x``."""
        return self.fc(x)
|
|
||||||
|
|
||||||
class ConditionEncoder(nn.Module):
    """Encodes tag, value and stage inputs into a single condition vector.

    Each input is embedded separately; the three embeddings are stacked as a
    three-token sequence, passed through a 4-layer Transformer encoder,
    mean-pooled over the tokens and projected to ``out_dim`` by spectrally
    normalised linear layers.

    Args:
        tag_dim: Size of the tag input.
        val_dim: Size of the value input.
        hidden_dim: Embedding / Transformer width (the encoder uses 8 heads).
        out_dim: Size of the returned condition vector.
    """

    def __init__(self, tag_dim, val_dim, hidden_dim, out_dim):
        super().__init__()
        self.tag_embed = DoubleFCModule(tag_dim, hidden_dim, hidden_dim)
        self.val_embed = DoubleFCModule(val_dim, hidden_dim, hidden_dim)
        self.stage_embed = DoubleFCModule(1, hidden_dim, hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=8,
            dim_feedforward=hidden_dim * 4,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
        self.fusion = nn.Sequential(
            spectral_norm(nn.Linear(hidden_dim, hidden_dim)),
            nn.GELU(),
            spectral_norm(nn.Linear(hidden_dim, out_dim)),
        )

    def forward(self, tag, val, stage):
        """Return the condition vector for the given inputs."""
        # One token per embedded input, stacked on dim 1.
        embedded = [
            self.tag_embed(tag),
            self.val_embed(val),
            self.stage_embed(stage),
        ]
        tokens = torch.stack(embedded, dim=1)
        pooled = torch.mean(self.encoder(tokens), dim=1)
        return self.fusion(pooled)
|
|
||||||
|
|
||||||
class ConditionInjector(nn.Module):
    """Per-channel scale and shift of a feature map, both predicted from a condition vector."""

    def __init__(self, cond_dim, out_dim):
        super().__init__()
        self.gamma_layer = nn.Sequential(
            spectral_norm(nn.Linear(cond_dim, out_dim))
        )
        self.beta_layer = nn.Sequential(
            spectral_norm(nn.Linear(cond_dim, out_dim))
        )

    def forward(self, x, cond):
        # Add two trailing singleton axes so scale and shift broadcast over
        # the spatial dimensions of x.
        scale = self.gamma_layer(cond)[:, :, None, None]
        shift = self.beta_layer(cond)[:, :, None, None]
        return x * scale + shift
|
|
||||||
|
|
||||||
class CNNHead(nn.Module):
    """Convolutional critic head with a projection term on the condition.

    The feature map goes through one unpadded 3x3 convolution and is max-pooled
    to 2x2, then flattened. The score is a linear readout of those features
    plus their inner product with a linear projection of the condition vector.

    Args:
        in_ch: Channel count of the incoming feature map.
        cond_dim: Width of the condition vector. Defaults to 256, the value
            that used to be hard-coded.
    """

    def __init__(self, in_ch, cond_dim=256):
        super().__init__()
        self.cnn = nn.Sequential(
            spectral_norm(nn.Conv2d(in_ch, in_ch, 3)),
            nn.GELU(),
            nn.AdaptiveMaxPool2d((2, 2))
        )
        self.fc = nn.Sequential(
            spectral_norm(nn.Linear(in_ch*2*2, 1))
        )
        self.proj = spectral_norm(nn.Linear(cond_dim, in_ch*2*2))

    def forward(self, x, cond):
        """Return a ``[B, 1]`` score for feature map ``x`` and condition ``cond``."""
        feat = self.cnn(x).flatten(1)  # [B, in_ch * 4]
        projection = torch.sum(feat * self.proj(cond), dim=1, keepdim=True)
        return self.fc(feat) + projection
|
|
||||||
|
|
||||||
class GCNHead(nn.Module):
    """Graph critic head with a projection term on the condition.

    Node features pass through one ``GCNConv`` and GELU and are max-pooled per
    graph via ``graph.batch``. The score is a linear readout of the pooled
    features plus their inner product with a projection of the condition.

    Args:
        in_dim: Node feature width.
        cond_dim: Width of the condition vector. Defaults to 256, the value
            that used to be hard-coded.
    """

    def __init__(self, in_dim, cond_dim=256):
        super().__init__()
        self.gcn = GCNConv(in_dim, in_dim)
        self.proj = spectral_norm(nn.Linear(cond_dim, in_dim))
        self.fc = nn.Sequential(
            spectral_norm(nn.Linear(in_dim, 1))
        )

    def forward(self, x, graph, cond):
        """Return one score per graph in the batch."""
        node_feat = F.gelu(self.gcn(x, graph.edge_index))
        pooled = global_max_pool(node_feat, graph.batch)
        projection = torch.sum(pooled * self.proj(cond), dim=1, keepdim=True)
        return self.fc(pooled) + projection
|
|
||||||
|
|
||||||
class MinamoScoreHead(nn.Module):
    """Pair of critic heads: a ``CNNHead`` for vision features and a ``GCNHead`` for topology features."""

    def __init__(self, vision_dim, topo_dim):
        super().__init__()
        self.vision_head = CNNHead(vision_dim)
        self.topo_head = GCNHead(topo_dim)

    def forward(self, vis, topo, graph, cond):
        """Return ``(vision_score, topo_score)``."""
        return self.vision_head(vis, cond), self.topo_head(topo, graph, cond)
|
|
||||||
|
|
||||||
class MinamoModel(nn.Module):
    """Critic combining a vision backbone, a topology backbone and one score head per stage.

    ``forward`` returns ``(score, vision_score, topo_score)`` where ``score``
    is ``VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score``.
    """

    def __init__(self, tile_types=32):
        super().__init__()
        self.topo_model = MinamoTopoModel(tile_types)
        self.vision_model = MinamoVisionModel(tile_types)
        self.cond = ConditionEncoder(64, 16, 256, 256)
        # Output layers: one head pair per training stage.
        self.head1 = MinamoScoreHead(512, 512)
        self.head2 = MinamoScoreHead(512, 512)
        self.head3 = MinamoScoreHead(512, 512)

    def forward(self, map, graph, stage, tag_cond, val_cond):
        """Score a map and its graph for the given stage.

        Args:
            map: Image-like input for the vision backbone. The name shadows
                the builtin but is kept because callers may pass it by keyword.
            graph: Graph input for the topology backbone and head.
            stage: Critic stage; must equal 1, 2 or 3.
            tag_cond, val_cond: Condition tensors; ``tag_cond`` must be 2-D.

        Raises:
            RuntimeError: If ``stage`` is not a known critic stage.
        """
        # Choose the head first so an invalid stage is rejected before the
        # backbones run.
        stage_heads = ((1, self.head1), (2, self.head2), (3, self.head3))
        head = next((h for s, h in stage_heads if stage == s), None)
        if head is None:
            raise RuntimeError("Unknown critic stage.")

        B, _ = tag_cond.shape
        # Build the stage scalar on the target device instead of creating it
        # on the host and copying it over on every call.
        stage_tensor = torch.full(
            (B, 1), float(stage), dtype=torch.float32, device=map.device
        )
        vision = self.vision_model(map)
        topo = self.topo_model(graph)
        cond = self.cond(tag_cond, val_cond, stage_tensor)
        vision_score, topo_score = head(vision, topo, graph, cond)
        score = VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score
        return score, vision_score, topo_score
|
|
||||||
|
|
||||||
class MinamoHead2(nn.Module):
    """Critic head: fusion block, global max pool, then linear readout plus a condition projection term.

    The fusion block is built for a 13x13 grid and the condition vector is
    expected to have 256 features.
    """

    def __init__(self, in_ch, hidden_ch):
        super().__init__()
        self.conv = ConvFusionModule(in_ch, hidden_ch, hidden_ch, 13, 13)
        self.pool = nn.AdaptiveMaxPool2d(1)
        self.proj = spectral_norm(nn.Linear(256, hidden_ch))
        self.fc = spectral_norm(nn.Linear(hidden_ch, 1))

    def forward(self, x, cond):
        pooled = self.pool(self.conv(x))
        feat = pooled.squeeze(3).squeeze(2)  # [B, hidden_ch]
        cond_proj = self.proj(cond)
        inner = torch.sum(feat * cond_proj, dim=1, keepdim=True)
        return self.fc(feat) + inner
|
|
||||||
|
|
||||||
class MinamoModel2(nn.Module):
    """Critic made of three fusion blocks, one condition injection and one head per stage.

    Stage 0 uses ``head0`` (the head for the random-input stage); stages 1-3
    use ``head1``..``head3``.
    """

    def __init__(self, tile_types=32):
        super().__init__()
        self.cond = ConditionEncoder(64, 16, 256, 256)

        self.conv1 = ConvFusionModule(tile_types, 256, 256, 13, 13)
        self.conv2 = ConvFusionModule(256, 512, 256, 13, 13)
        self.conv3 = ConvFusionModule(256, 512, 256, 13, 13)

        self.head0 = MinamoHead2(256, 256)  # critic head for the random-input stage
        self.head1 = MinamoHead2(256, 256)
        self.head2 = MinamoHead2(256, 256)
        self.head3 = MinamoHead2(256, 256)

        # Injection after conv1 and conv2 is currently disabled:
        # self.inject1 = ConditionInjector(256, 256)
        # self.inject2 = ConditionInjector(256, 256)
        self.inject3 = ConditionInjector(256, 256)

    def forward(self, x, stage, tag_cond, val_cond):
        """Return the critic score of ``x`` for ``stage``.

        Args:
            x: Input map tensor fed to the first fusion block.
            stage: Critic stage; must equal 0, 1, 2 or 3.
            tag_cond, val_cond: Condition tensors; ``tag_cond`` must be 2-D.

        Raises:
            RuntimeError: If ``stage`` is not a known critic stage.
        """
        # Choose the head first so an invalid stage is rejected before any
        # convolution runs.
        stage_heads = (
            (0, self.head0), (1, self.head1), (2, self.head2), (3, self.head3)
        )
        head = next((h for s, h in stage_heads if stage == s), None)
        if head is None:
            raise RuntimeError("Unknown critic stage.")

        B, _ = tag_cond.shape
        # Build the stage scalar on the target device instead of creating it
        # on the host and copying it over on every call.
        stage_tensor = torch.full(
            (B, 1), float(stage), dtype=torch.float32, device=x.device
        )
        cond = self.cond(tag_cond, val_cond, stage_tensor)

        x = self.conv1(x)
        # x = self.inject1(x, cond)
        x = self.conv2(x)
        # x = self.inject2(x, cond)
        x = self.conv3(x)
        x = self.inject3(x, cond)

        return head(x, cond)
|
|
||||||
|
|
||||||
# Check GPU memory usage around one forward pass.
if __name__ == "__main__":
    def _param_count(module):
        """Total number of scalar parameters in ``module``."""
        return sum(p.numel() for p in module.parameters())

    feat = torch.randn((1, 32, 13, 13)).cuda()
    tag = torch.rand(1, 64).cuda()
    val = torch.rand(1, 16).cuda()

    # Build the model.
    model = MinamoModel2().cuda()
    print_memory("初始化后")

    # Forward pass.
    output = model(feat, 1, tag, val)
    print_memory("前向传播后")

    print(f"输入形状: feat={feat.shape}")
    print(f"输出形状: output={output.shape}")
    # print(f"Vision parameters: {sum(p.numel() for p in model.vision_model.parameters())}")
    # print(f"Topo parameters: {sum(p.numel() for p in model.topo_model.parameters())}")
    print(f"Cond parameters: {_param_count(model.cond)}")
    print(f"Head parameters: {_param_count(model.head1)}")
    print(f"Total parameters: {_param_count(model)}")
|
|
||||||
@ -1,36 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.nn.utils import spectral_norm
|
|
||||||
from torch_geometric.nn import GATConv, TransformerConv
|
|
||||||
from torch_geometric.data import Data
|
|
||||||
|
|
||||||
class MinamoTopoModel(nn.Module):
    """Topology backbone: a linear input projection followed by three ``TransformerConv`` layers.

    Node features are projected directly (the original comment says they are
    softmax probabilities). The first two graph layers use 8 concatenated
    heads, the last a single head producing ``out_dim`` features per node.
    """

    def __init__(
        self, tile_types=32, emb_dim=128, hidden_dim=128, out_dim=512
    ):
        super().__init__()
        # Direct mapping of the per-node input vector.
        self.input_proj = nn.Sequential(
            spectral_norm(nn.Linear(tile_types, emb_dim)),
            nn.LeakyReLU(0.2)
        )
        # Graph convolution layers.
        self.conv1 = TransformerConv(emb_dim, hidden_dim, heads=8)
        self.conv2 = TransformerConv(hidden_dim*8, hidden_dim, heads=8)
        self.conv3 = TransformerConv(hidden_dim*8, out_dim, heads=1)

    def forward(self, graph: Data):
        feat = self.input_proj(graph.x)
        for conv in (self.conv1, self.conv2, self.conv3):
            feat = F.leaky_relu(conv(feat, graph.edge_index), 0.2)
        return feat
|
|
||||||
|
|
||||||
@ -1,22 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.nn.utils import spectral_norm
|
|
||||||
|
|
||||||
class MinamoVisionModel(nn.Module):
    """Vision backbone: three unpadded, spectrally-normalised 3x3 convolutions with LeakyReLU.

    Each convolution trims one pixel per border, so a 13x13 input becomes
    11x11, then 9x9, then 7x7.
    """

    def __init__(self, in_ch=32, out_ch=512):
        super().__init__()
        widths = [in_ch, in_ch*2, in_ch*8, out_ch]
        layers = []
        for src, dst in zip(widths[:-1], widths[1:]):
            layers.append(spectral_norm(nn.Conv2d(src, dst, 3)))
            layers.append(nn.LeakyReLU(0.2))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
|
|
||||||
221
ginka/dataset.py
@ -8,13 +8,6 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
# Per-stage class-index lists for curriculum masking.
# STAGEn_MASK is passed as mask_classes (classes that may be masked out) and
# STAGEn_REMOVE as remove_classes (classes replaced by class 0 beforehand).
STAGE1_MASK = [0, 1, 2, 29, 30]
STAGE1_REMOVE = list(range(3, 29))
STAGE2_MASK = [3, 4, 5, 6, 26, 27, 28]
STAGE2_REMOVE = list(range(7, 26))
STAGE3_MASK = list(range(7, 26))
# Stage 3 removes nothing.
STAGE3_REMOVE = []
|
|
||||||
|
|
||||||
def load_data(path: str):
|
def load_data(path: str):
|
||||||
with open(path, 'r', encoding="utf-8") as f:
|
with open(path, 'r', encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
@ -25,220 +18,6 @@ def load_data(path: str):
|
|||||||
|
|
||||||
return data_list
|
return data_list
|
||||||
|
|
||||||
def load_minamo_gan_data(data: list):
    """Turn similarity records into ``(map1, map2, visionSimilarity, topoSimilarity, True)`` tuples."""
    fields = ('map1', 'map2', 'visionSimilarity', 'topoSimilarity')
    return [tuple(entry[key] for key in fields) + (True,) for entry in data]
|
|
||||||
|
|
||||||
def apply_curriculum_remove(
    maps: torch.Tensor,
    remove_classes: List[int],  # class indices to remove
):
    """Return a copy of ``maps`` in which every pixel of a removed class becomes class 0.

    Args:
        maps: One-hot style map of shape ``[C, H, W]``; it is not modified.
        remove_classes: Channel indices to remove. An empty list yields a
            plain copy, matching the guard in ``apply_curriculum_mask``.

    Returns:
        A new tensor on the same device as ``maps``.
    """
    removed_maps = maps.clone()
    if not remove_classes:
        return removed_maps

    # Pixels at which any of the removed channels is active: [H, W].
    hit = removed_maps[remove_classes].sum(dim=0) > 0
    removed_maps[:, hit] = 0
    removed_maps[0, hit] = 1  # channel 0 is the "empty floor" class
    return removed_maps
|
|
||||||
|
|
||||||
def apply_curriculum_mask(
    maps: torch.Tensor,  # [C, H, W]
    mask_classes: List[int],  # class indices to mask
    remove_classes: List[int],  # class indices to remove
    mask_ratio: float  # fraction masked, 0~1
) -> tuple[torch.Tensor, torch.Tensor]:
    """Remove some classes, then randomly mask a fraction of others.

    Args:
        maps: One-hot style map of shape ``[C, H, W]``; it is not modified.
        mask_classes: For each of these channels, ``int(count * mask_ratio)``
            randomly chosen active pixels are turned into class 0.
        remove_classes: Channels whose active pixels are all turned into
            class 0 before masking.
        mask_ratio: Fraction of each masked class's pixels to hide.

    Returns:
        ``(removed_maps, masked_maps)``: the map after removal only, and the
        map after removal plus random masking.
    """
    C, H, W = maps.shape
    masked_maps = maps.clone()

    # Step 1: turn every pixel of a removed class into class 0.
    if remove_classes:
        remove_mask = masked_maps[remove_classes].sum(dim=0) > 0  # [H, W]
        masked_maps[:, remove_mask] = 0
        masked_maps[0, remove_mask] = 1  # channel 0 is the "empty floor" class

    removed_maps = masked_maps.clone()

    # Step 2: randomly mask the requested classes.
    for cls in mask_classes:
        # Coordinates of all pixels currently belonging to this class.
        indices = (masked_maps[cls] > 0).nonzero(as_tuple=False)
        num_mask = int(len(indices) * mask_ratio)
        if num_mask > 0:
            selected = indices[torch.randperm(len(indices))[:num_mask]]
            rows, cols = selected[:, 0], selected[:, 1]
            masked_maps[cls, rows, cols] = 0
            masked_maps[0, rows, cols] = 1  # becomes "empty floor"

    return removed_maps, masked_maps
|
|
||||||
|
|
||||||
def apply_curriculum_wall_mask(
    maps: torch.Tensor,  # [C, H, W]
    mask_classes: List[int],  # class indices to mask
    remove_classes: List[int],  # class indices to remove
    mask_ratio: float  # fraction masked, 0~1
) -> tuple[torch.Tensor, torch.Tensor]:
    """Remove some classes, then blank out one random square region.

    The square has side ``floor(sqrt(H * W * mask_ratio))``, limited to the
    shorter map edge. Inside it the ``mask_classes`` channels are cleared and
    channel 0 is set.

    Args:
        maps: One-hot style map of shape ``[C, H, W]``; it is not modified.
        mask_classes: Channels cleared inside the square.
        remove_classes: Channels whose active pixels are all turned into
            class 0 before masking.
        mask_ratio: Target fraction of the map area covered by the square.

    Returns:
        ``(removed_maps, masked_maps)``: the map after removal only, and the
        map after removal plus the square mask.
    """
    C, H, W = maps.shape
    masked_maps = maps.clone()

    # Step 1: turn every pixel of a removed class into class 0.
    if remove_classes:
        remove_mask = masked_maps[remove_classes].sum(dim=0) > 0  # [H, W]
        masked_maps[:, remove_mask] = 0
        masked_maps[0, remove_mask] = 1  # channel 0 is the "empty floor" class

    removed_maps = masked_maps.clone()

    # Step 2: pick the square. Limiting the side to min(H, W) keeps the
    # randint ranges non-negative on non-square maps.
    side = min(math.floor(math.sqrt(H * W * mask_ratio)), H, W)
    # Dim 1 is rows (H) and dim 2 is columns (W), so each offset is drawn from
    # the range of its own axis. The row offset is drawn first, so square maps
    # consume the random stream exactly as before.
    top = random.randint(0, H - side)
    left = random.randint(0, W - side)
    masked_maps[mask_classes, top:top+side, left:left+side] = 0
    masked_maps[0, top:top+side, left:left+side] = 1

    return removed_maps, masked_maps
|
|
||||||
|
|
||||||
class GinkaWGANDataset(Dataset):
    """Dataset producing curriculum samples for the staged WGAN training.

    ``train_stage`` (1-4) selects which ``handle_stageN`` builds the sample.
    Every sample is a dict with the keys ``rand``, ``real0``..``real3``,
    ``masked1``..``masked3``, ``tag_cond`` and ``val_cond``.
    """

    def __init__(self, data_path: str, device):
        self.data = load_data(data_path)  # project-specific loader
        self.device = device
        self.train_stage = 1
        self.mask_ratio1 = self.mask_ratio2 = self.mask_ratio3 = 0.1

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _bundle(rand, stage1, stage2, stage3, tag_cond, val_cond):
        """Assemble the sample dict; each ``stageN`` is a ``(real, masked)`` pair."""
        return {
            "rand": rand,
            "real0": stage1[0],
            "real1": stage1[0],
            "masked1": stage1[1],
            "real2": stage2[0],
            "masked2": stage2[1],
            "real3": stage3[0],
            "masked3": stage3[1],
            "tag_cond": tag_cond,
            "val_cond": val_cond
        }

    def handle_stage1(self, target, tag_cond, val_cond):
        # Curriculum stage 1: fill in masks of fixed ratio.
        pair1 = apply_curriculum_wall_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1)
        pair2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2)
        pair3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3)
        rand = torch.rand(32, 32, 32, device=target.device)
        return self._bundle(rand, pair1, pair2, pair3, tag_cond, val_cond)

    def handle_stage2(self, target, tag_cond, val_cond):
        # Curriculum stage 2: fully random mask ratios.
        pair1 = apply_curriculum_wall_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
        # The later two stages keep some classes, so fully random masking is enough.
        pair2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 1))
        pair3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 1))
        rand = torch.rand(32, 32, 32, device=target.device)
        return self._bundle(rand, pair1, pair2, pair3, tag_cond, val_cond)

    def handle_stage3(self, target, tag_cond, val_cond):
        # Stage 3: joint generation; the stage-1 input is the removed map itself.
        removed1 = apply_curriculum_remove(target, STAGE1_REMOVE)
        removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
        removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
        rand = torch.rand(32, 32, 32, device=target.device)
        return self._bundle(
            rand,
            (removed1, removed1),
            (removed2, torch.zeros_like(target)),
            (removed3, torch.zeros_like(target)),
            tag_cond,
            val_cond,
        )

    def handle_stage4(self, target, tag_cond, val_cond):
        # Stage 4: fully random input; the stage-1 input is the random tensor.
        removed1 = apply_curriculum_remove(target, STAGE1_REMOVE)
        removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
        removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
        rand = torch.rand(32, 32, 32, device=target.device)
        return self._bundle(
            rand,
            (removed1, rand),
            (removed2, torch.zeros_like(target)),
            (removed3, torch.zeros_like(target)),
            tag_cond,
            val_cond,
        )

    def __getitem__(self, idx):
        record = self.data[idx]

        # One-hot encode the tile ids and move classes first: [32, H, W].
        target = F.one_hot(torch.LongTensor(record['map']), num_classes=32).permute(2, 0, 1).float()
        _, H, W = target.shape
        tag_cond = torch.FloatTensor(record['tag'])
        val_cond = torch.FloatTensor(record['val'])
        # Entries 9 and 10 are divided by the map area.
        for slot in (9, 10):
            val_cond[slot] = val_cond[slot] / H / W

        handlers = (
            (1, self.handle_stage1),
            (2, self.handle_stage2),
            (3, self.handle_stage3),
            (4, self.handle_stage4),
        )
        for stage, handler in handlers:
            if self.train_stage == stage:
                return handler(target, tag_cond, val_cond)

        raise RuntimeError(f"Invalid train stage: {self.train_stage}")
|
|
||||||
|
|
||||||
class GinkaRNNDataset(Dataset):
    """Dataset yielding the raw integer map together with its tag and value conditions."""

    def __init__(self, data_path: str, device):
        self.data = load_data(data_path)  # project-specific loader
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        target = torch.LongTensor(record['map'])  # [H, W]
        # Unpacking requires the map to be exactly two-dimensional.
        _rows, _cols = target.shape
        return {
            "tag_cond": torch.FloatTensor(record['tag']),
            "val_cond": torch.FloatTensor(record['val']),
            "target_map": target
        }
|
|
||||||
|
|
||||||
class GinkaMaskGITDataset(Dataset):
|
class GinkaMaskGITDataset(Dataset):
|
||||||
def __init__(self, data_path: str, device):
|
def __init__(self, data_path: str, device):
|
||||||
self.data = load_data(data_path) # 自定义数据加载函数
|
self.data = load_data(data_path) # 自定义数据加载函数
|
||||||
|
|||||||
@ -1,37 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch_geometric.nn import GCNConv, TransformerConv
|
|
||||||
from torch_geometric.utils import grid
|
|
||||||
from ..common.cond import ConditionInjector
|
|
||||||
|
|
||||||
# 考虑使用 GCN 作为生成器主路径,暂时先留着
|
|
||||||
|
|
||||||
class GCNBlock(nn.Module):
    """Two ``GCNConv`` layers, each followed by LayerNorm and ELU.

    ``feats`` gives the (input, hidden, output) feature widths.
    """

    def __init__(self, feats: tuple[int, int, int]):
        super().__init__()
        in_dim, mid_dim, out_dim = feats[0], feats[1], feats[2]
        self.conv1 = GCNConv(in_dim, mid_dim)
        self.conv2 = GCNConv(mid_dim, out_dim)

        self.norm1 = nn.LayerNorm(mid_dim)
        self.norm2 = nn.LayerNorm(out_dim)

    def forward(self, x, edge_index):
        for conv, norm in ((self.conv1, self.norm1), (self.conv2, self.norm2)):
            x = F.elu(norm(conv(x, edge_index)))
        return x
|
|
||||||
|
|
||||||
class GinkaGCNEncoder(nn.Module):
    """Placeholder for a GCN-based encoder; it defines no layers and no forward pass yet."""
    def __init__(self):
        super().__init__()
|
|
||||||
|
|
||||||
class GinkaGCNDecoder(nn.Module):
    """Placeholder for a GCN-based decoder; it defines no layers and no forward pass yet."""
    def __init__(self):
        super().__init__()
|
|
||||||
|
|
||||||
class GinkaGCNModel(nn.Module):
    """Placeholder for a GCN-based generator model; it defines no layers and no forward pass yet."""
    def __init__(self):
        super().__init__()
|
|
||||||
@ -1,58 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from ..common.common import ConvFusionModule
|
|
||||||
from ..common.cond import ConditionInjector
|
|
||||||
from .unet import GinkaEncoderPath, GinkaDecoderPath
|
|
||||||
|
|
||||||
class RandomInputHead(nn.Module):
    """Encoder/decoder over a random input, followed by pooling and two convolutions.

    The output stage max-pools to 15x15, applies an unpadded 3x3 convolution
    (giving 13x13) and a final 1x1 convolution back to 32 channels.
    """

    def __init__(self):
        super().__init__()
        self.enc = GinkaEncoderPath(32, 32)
        self.dec = GinkaDecoderPath(32)
        self.out_conv = nn.Sequential(
            nn.AdaptiveMaxPool2d((15, 15)),
            nn.Conv2d(32, 64, 3, padding=0),
            nn.InstanceNorm2d(64),
            nn.GELU(),

            nn.Conv2d(64, 32, 1),
        )

    def forward(self, x, cond):
        # The encoder path yields four feature maps that the decoder consumes.
        f1, f2, f3, f4 = self.enc(x, cond)
        decoded = self.dec(f1, f2, f3, f4, cond)
        return self.out_conv(decoded)
|
|
||||||
|
|
||||||
class InputUpsample(nn.Module):
    """Upsample a 13x13 map to 32x32 through three fusion blocks (13 -> 26 -> 32)."""

    def __init__(self, in_ch, hidden_ch=64, out_ch=64):
        super().__init__()
        stages = [
            ConvFusionModule(in_ch, hidden_ch, hidden_ch, 13, 13),
            nn.Upsample(scale_factor=2, mode='nearest'),  # 13x13 -> 26x26
            ConvFusionModule(hidden_ch, hidden_ch, hidden_ch, 26, 26),
            nn.Upsample(size=(32, 32), mode='nearest'),  # 26x26 -> 32x32
            ConvFusionModule(hidden_ch, hidden_ch, out_ch, 32, 32),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):  # [B, C, 13, 13]
        return self.net(x)  # [B, C, 32, 32]
|
|
||||||
|
|
||||||
class GinkaInput(nn.Module):
    """Input stem: upsample, inject condition, encode, inject condition again.

    ``in_size`` is accepted but not used by this class. The condition vector
    is expected to have 256 features.
    """

    def __init__(self, in_ch=32, out_ch=64, in_size=(13, 13), out_size=(32, 32)):
        super().__init__()
        self.out_size = out_size
        self.upsample = InputUpsample(in_ch, in_ch*2, out_ch)
        self.enc = ConvFusionModule(out_ch, out_ch*2, out_ch, out_size[0], out_size[1])
        self.inject1 = ConditionInjector(256, out_ch)
        self.inject2 = ConditionInjector(256, out_ch)

    def forward(self, x, cond):
        upsampled = self.inject1(self.upsample(x), cond)
        return self.inject2(self.enc(upsampled), cond)
|
|
||||||
@ -1,420 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch_geometric.data import Data
|
|
||||||
|
|
||||||
# Number of class channels; get_not_allowed iterates over range(0, CLASS_NUM).
CLASS_NUM = 32
# Class indices above this value are only reported by get_not_allowed when
# include_illegal is set.
ILLEGAL_MAX_NUM = 30

# Class indices grouped per stage, indexed by stage number; index 0 is empty.
STAGE_CHANGEABLE = [
    [],
    [0, 1, 2, 29, 30],
    [3, 4, 5, 6, 26, 27, 28],
    list(range(7, 26))
]

# Cumulative union of STAGE_CHANGEABLE up to and including each stage.
STAGE_ALLOWED = [
    [],
    STAGE_CHANGEABLE[1],
    [*STAGE_CHANGEABLE[1], *STAGE_CHANGEABLE[2]],
    [*STAGE_CHANGEABLE[1], *STAGE_CHANGEABLE[2], *STAGE_CHANGEABLE[3]]
]

# Class groups whose density is compared in compute_multi_density_loss;
# group i is matched against column i of target_densities.
DENSITY_MAP = [
    [1, *list(range(3, 30))],
    [1],
    [2],
    [3, 4, 5, 6],
    [26, 27, 28],
    list(range(7, 26)),
    list(range(10, 19)),
    [19, 20, 21, 22],
    [7, 8, 9],
    [23, 24, 25],
    [29, 30]
]

# Loss weight applied to the matching DENSITY_MAP group.
DENSITY_WEIGHTS = [
    1,
    1.5,
    0.5,
    5,
    4,
    3,
    3,
    3,
    5,
    10,
    20
]

# NOTE(review): presumably the DENSITY_MAP group indices active in each stage
# (the tile_list argument of compute_multi_density_loss) - confirm against callers.
DENSITY_STAGE = [
    [],
    [1, 2],
    [1, 2, 3, 4],
    list(range(0, 10))
]
|
|
||||||
|
|
||||||
def get_not_allowed(classes: list[int], include_illegal=False):
    """List the class indices that are absent from ``classes``, in ascending order.

    Indices above ``ILLEGAL_MAX_NUM`` are included only when
    ``include_illegal`` is true.
    """
    upper = CLASS_NUM if include_illegal else min(CLASS_NUM, ILLEGAL_MAX_NUM + 1)
    return [num for num in range(upper) if num not in classes]
|
|
||||||
|
|
||||||
def inner_constraint_loss(pred: torch.Tensor, allowed=list(range(0, 30))):
    """Penalise probability mass of disallowed classes inside the map interior.

    Args:
        pred (torch.Tensor): Model output distribution of shape [B, C, H, W].
        allowed (list, optional): Classes permitted in the interior, i.e.
            everywhere except the outermost ring of pixels.

    Returns:
        Mean squared disallowed probability over all interior pixels.
    """
    B, C, H, W = pred.shape

    # Interior mask [H, W]: true everywhere except the first/last row and column.
    interior = torch.zeros((H, W), dtype=torch.bool, device=pred.device)
    interior[1:-1, 1:-1] = True

    # Summed probability of every disallowed class: [B, H, W].
    banned = get_not_allowed(allowed, include_illegal=True)
    banned_probs = pred[:, banned, :, :].sum(dim=1)

    # Keep the interior pixels only: [B, N_pixels].
    inner_banned = banned_probs[:, interior]

    return F.mse_loss(inner_banned, torch.zeros_like(inner_banned))
|
|
||||||
|
|
||||||
def _create_distance_kernel(size):
|
|
||||||
"""生成一个环状衰减核"""
|
|
||||||
y, x = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij')
|
|
||||||
center = size // 2
|
|
||||||
dist = torch.sqrt((x - center)**2 + (y - center)**2)
|
|
||||||
kernel = 1 / (dist + 1)
|
|
||||||
kernel /= kernel.sum() # 归一化
|
|
||||||
return kernel.unsqueeze(0).unsqueeze(0) # [1,1,H,W]
|
|
||||||
|
|
||||||
def entrance_constraint_loss(
    pred: torch.Tensor,
    entrance_classes=[29, 30],
    min_distance=9,
    presence_threshold=0.8,
    lambda_presence=1.0,
    lambda_spacing=0.5
):
    """Entrance constraint loss.

    Args:
        pred: Model output distribution of shape [B, C, H, W].
        entrance_classes: Class indices counted as entrances.
        min_distance: Minimum spacing; used as the size of the density kernel.
        presence_threshold: Probability the strongest entrance pixel should reach.
        lambda_presence: Weight of the presence term.
        lambda_spacing: Weight of the spacing term.

    Returns:
        Weighted sum of the presence and spacing terms.

    NOTE(review): the spacing term is reduced to a scalar before it is
    multiplied by the centre weights, so the weights act only as a constant
    factor (their mean) - confirm whether per-pixel weighting was intended.
    """
    B, C, H, W = pred.shape
    entrance_probs = pred[:, entrance_classes, :, :].sum(dim=1)  # [B, H, W]

    # Presence term: encourage at least one confident entrance per sample.
    strongest = entrance_probs.view(B, -1).max(dim=1).values  # [B]
    presence_loss = F.relu(presence_threshold - strongest).mean()

    # Spatial weights that decay away from the map centre: [H, W].
    rows, cols = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    radial = torch.sqrt(((cols-W//2)/W*2)**2 + ((rows-H//2)/H*2)**2)
    center_weight = (1 - radial).clamp(0,1).to(pred.device)

    # Entrance density: entrance probability convolved with the distance kernel.
    kernel = _create_distance_kernel(min_distance).to(pred.device)
    density_map = F.conv2d(entrance_probs.unsqueeze(1), kernel, padding=min_distance-1)
    spacing_loss = density_map.mean()

    weighted_spacing = (spacing_loss * center_weight).mean()
    return lambda_presence * presence_loss + lambda_spacing * weighted_spacing
|
|
||||||
|
|
||||||
def input_head_illegal_loss(input_map, allowed_classes=[0, 1, 2]):
    """Mean absolute value of all channels of ``input_map`` that are not in ``allowed_classes``."""
    banned = get_not_allowed(allowed_classes, include_illegal=True)
    banned_maps = input_map[:, banned, :, :]
    return F.l1_loss(banned_maps, torch.zeros_like(banned_maps))
|
|
||||||
|
|
||||||
def input_head_wall_loss(input_map, max_wall_ratio=0.2, wall_class=[1, 2]):
    """Penalise the amount by which the mean wall-channel value exceeds ``max_wall_ratio``.

    The mean is taken over the selected channels as well as over batch and
    space. The result is zero while the mean stays at or below the limit.
    """
    mean_wall = input_map[:, wall_class].mean()
    excess = mean_wall - max_wall_ratio
    return torch.clamp(excess, min=0.0)
|
|
||||||
|
|
||||||
def compute_multi_density_loss(probs, target_densities, tile_list):
    """Weighted sum of density errors over the selected ``DENSITY_MAP`` groups.

    Args:
        probs: Class probabilities of shape [B, C, H, W].
        target_densities: [B, N] target density per ``DENSITY_MAP`` group.
        tile_list: Indices of the ``DENSITY_MAP`` groups that contribute.

    Returns:
        Sum of ``DENSITY_WEIGHTS[i] * mse(pred_density_i, target_i)`` over the
        selected groups; the integer 0 when no group is selected.
    """
    weighted = []
    for group_idx, classes in enumerate(DENSITY_MAP):
        if group_idx not in tile_list:
            continue
        # Mean over the group's channels and all pixels: one value per sample.
        pred_density = probs[:, classes, :, :].mean(dim=(1, 2, 3))
        group_loss = F.mse_loss(pred_density, target_densities[:, group_idx])
        weighted.append(group_loss * DENSITY_WEIGHTS[group_idx])
    return sum(weighted)
|
|
||||||
|
|
||||||
# Interpolate between real and fake image data.
def interpolate_data(real_data, fake_data, epsilon):
    """Return ``epsilon * real_data + (1 - epsilon) * fake_data``."""
    real_part = epsilon * real_data
    fake_part = (1 - epsilon) * fake_data
    return real_part + fake_part
|
|
||||||
|
|
||||||
# Interpolate node features while keeping the edge structure unchanged.
def interpolate_graph_features(real_graph, fake_graph, epsilon=0.5):
    """Blend the node features of two graphs; edges are taken from ``real_graph``.

    Returns a new ``Data`` whose ``x`` is
    ``epsilon * real_graph.x + (1 - epsilon) * fake_graph.x`` and whose
    ``edge_index`` and ``edge_attr`` are those of ``real_graph``.
    """
    blended_x = epsilon * real_graph.x + (1 - epsilon) * fake_graph.x
    return Data(
        x=blended_x,
        edge_index=real_graph.edge_index,  # connectivity is not interpolated
        edge_attr=real_graph.edge_attr,  # edge features, if any, are kept as they are
    )
|
|
||||||
|
|
||||||
def js_divergence(p, q, eps=1e-6, softmax=False):
    """Return ``log1p`` of the Jensen-Shannon divergence between ``p`` and ``q``.

    Args:
        p, q: Tensors of identical shape holding probabilities, or logits when
            ``softmax`` is true (softmax is then applied over dim 1).
        eps: Added before taking logarithms so zero probabilities stay finite.
        softmax: Whether to convert the inputs to probabilities first.

    Returns:
        Scalar tensor ``log1p(0.5 * (KL(p || m) + KL(q || m)))`` with
        ``m = (p + q) / 2``; both KL terms use the ``batchmean`` reduction.
    """
    if softmax:
        p = F.softmax(p, dim=1)
        q = F.softmax(q, dim=1)

    m = 0.5 * (p + q)

    log_p = torch.log(p + eps)
    log_q = torch.log(q + eps)
    log_m = torch.log(m + eps)

    # F.kl_div(input, target) computes KL(target || input), so the mixture is
    # passed as `input` and the distribution as `target`.
    kl_pm = F.kl_div(log_m, log_p, reduction='batchmean', log_target=True)  # KL(p || m)
    kl_qm = F.kl_div(log_m, log_q, reduction='batchmean', log_target=True)  # KL(q || m)

    return torch.log1p(0.5 * (kl_pm + kl_qm))
|
|
||||||
|
|
||||||
def immutable_penalty_loss(
    pred: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int]
) -> torch.Tensor:
    """Penalise the model for changing tiles of classes it may not modify.

    Args:
        pred: Model output of shape [B, C, H, W], before softmax.
        input: Original input map of shape [B, C, H, W].
        modifiable_classes: Classes the model is allowed to change.

    Returns:
        Cross-entropy over the non-modifiable channels against the input's
        arg-max class, reduced by a 0.2 margin and floored at zero.
    """
    frozen = get_not_allowed(modifiable_classes, include_illegal=True)
    pred_frozen = pred[:, frozen, :, :]
    with torch.no_grad():
        # Hard target: the strongest non-modifiable class of the input at each pixel.
        winner = torch.argmax(input[:, frozen, :, :], dim=1)
        target_frozen = F.one_hot(winner, num_classes=len(frozen)).permute(0, 3, 1, 2).float()

    # Only disagreement beyond the margin is penalised.
    ce = F.cross_entropy(pred_frozen, target_frozen)
    return torch.clamp(ce - 0.2, min=0)
|
|
||||||
|
|
||||||
def modifiable_penalty_loss(
    probs: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int]
) -> torch.Tensor:
    """Penalise probability that the prediction loses on the modifiable channels.

    Only the shortfall ``input - probs`` counts (clamped to [0, 1]); the
    result is its mean square over the selected channels.
    """
    shortfall = torch.clamp(
        input[:, modifiable_classes, :, :] - probs[:, modifiable_classes, :, :],
        min=0.0,
        max=1.0,
    )
    return F.mse_loss(shortfall, torch.zeros_like(shortfall))
|
|
||||||
|
|
||||||
def illegal_penalty_loss(pred: torch.Tensor, legal_classes: list[int]):
    """
    Penalise probability mass placed on classes that are not legal.

    The previous implementation took a cross-entropy against an all-zero
    soft target. With class-probability targets the loss is
    ``-sum(target * log_softmax(pred))``, which is identically zero for a
    zero target, so the term never produced a gradient.

    Args:
        pred: per-class scores indexed as [B, C, H, W]. The caller in this
            module passes softmax probabilities; the result is only a
            meaningful "mass" under that assumption.
        legal_classes: class indices that are allowed.

    Returns:
        Scalar tensor: mean over batch and positions of the summed score on
        the disallowed channels (0 when nothing is disallowed).
    """
    not_allowed = get_not_allowed(legal_classes, include_illegal=True)
    illegal_mass = pred[:, not_allowed, :, :].sum(dim=1)
    return illegal_mass.mean()
|
|
||||||
|
|
||||||
class WGANGinkaLoss:
    """
    Loss bundle for WGAN-GP style training of the generator and its critic.

    ``weight`` holds the multipliers of the generator terms, by index:
        0. critic (adversarial) term
        1. cross-entropy term
        2. immutable-class / illegal-tile penalty
        3. inner constraint term
        4. entrance-existence term (only added when ``stage == 1``)
        5. diversity term (negated JS divergence between the two batch halves)
        6. density term

    The generator methods split the batch in two with ``chunk(2, dim=0)``,
    so they need a batch that actually yields two chunks.
    """

    # Immutable template; each instance gets its own list so that editing
    # ``self.weight`` on one instance cannot leak into another.
    _DEFAULT_WEIGHT = (1, 0.4, 20, 0.2, 0.2, 0.05, 0.4)

    def __init__(self, lambda_gp=100, weight=None):
        """
        Args:
            lambda_gp: gradient-penalty coefficient.
            weight: optional sequence of seven term weights (see class
                docstring). When omitted a fresh copy of the defaults is used.
        """
        self.lambda_gp = lambda_gp
        self.weight = list(self._DEFAULT_WEIGHT) if weight is None else weight

    def compute_gradient_penalty(self, critic, stage, real_data, fake_data, tag_cond, val_cond):
        """Return the mean of ``(||grad critic(x_interp)||_2 - 1)^2`` over the batch."""
        batch_size = real_data.size(0)
        # One random mixing coefficient per sample.
        epsilon_data = torch.rand(batch_size, 1, 1, 1, device=real_data.device)
        interp_data = interpolate_data(real_data, fake_data, epsilon_data).to(real_data.device)

        # Gradients are taken with respect to the interpolated sample itself.
        interp_data.requires_grad_()
        d_score = critic(interp_data, stage, tag_cond, val_cond)

        # create_graph=True keeps the penalty differentiable w.r.t. critic weights.
        grad = torch.autograd.grad(
            outputs=d_score,
            inputs=interp_data,
            grad_outputs=torch.ones_like(d_score),
            create_graph=True,
            retain_graph=True,
        )[0]

        grad_norm = grad.reshape(batch_size, -1).norm(2, dim=1)
        return ((grad_norm - 1.0) ** 2).mean()

    def discriminator_loss(
        self, critic, stage: int, real_data: torch.Tensor, fake_data: torch.Tensor,
        tag_cond: torch.Tensor, val_cond: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Critic loss.

        Returns:
            ``(total_loss, d_loss)`` where ``d_loss`` is
            ``mean(critic(fake)) - mean(critic(real))`` and ``total_loss``
            adds ``lambda_gp`` times the gradient penalty.
        """
        fake_data = F.softmax(fake_data, dim=1)
        real_scores = critic(real_data, stage, tag_cond, val_cond)
        fake_scores = critic(fake_data, stage, tag_cond, val_cond)

        # Wasserstein distance estimate.
        d_loss = fake_scores.mean() - real_scores.mean()
        grad_loss = self.compute_gradient_penalty(critic, stage, real_data, fake_data, tag_cond, val_cond)

        total_loss = d_loss + self.lambda_gp * grad_loss
        return total_loss, d_loss

    def _critic_term(self, critic, stage, probs_fake, tag_cond, val_cond):
        """Unweighted adversarial term: negated mean critic score of the fake batch."""
        return -torch.mean(critic(probs_fake, stage, tag_cond, val_cond))

    def _diversity_term(self, fake):
        """Weighted diversity term: negated JS divergence between the two batch halves."""
        fake_a, fake_b = fake.chunk(2, dim=0)
        return -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5]

    def _total(self, losses, stage, probs_fake):
        """Add the stage-1 entrance-existence term when applicable and sum all terms."""
        if stage == 1:
            losses.append(entrance_constraint_loss(probs_fake) * self.weight[4])
        return sum(losses)

    def generator_loss(self, critic, stage, mask_ratio, real, fake: torch.Tensor, input, tag_cond, val_cond) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Generator loss with a supervised target.

        Returns:
            ``(total_loss, ce_loss)``; ``ce_loss`` is already scaled by
            ``1 - mask_ratio`` (a larger mask lowers the cross-entropy weight)
            but not by ``weight[1]``.
        """
        probs_fake = F.softmax(fake, dim=1)

        minamo_loss = self._critic_term(critic, stage, probs_fake, tag_cond, val_cond)
        ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio)
        immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage])
        constraint_loss = inner_constraint_loss(probs_fake)
        density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])

        losses = [
            minamo_loss * self.weight[0],
            ce_loss * self.weight[1],
            immutable_loss * self.weight[2],
            constraint_loss * self.weight[3],
            self._diversity_term(fake),
            density_loss * self.weight[6],
        ]
        return self._total(losses, stage, probs_fake), ce_loss

    def generator_loss_total(self, critic, stage, fake, tag_cond, val_cond) -> torch.Tensor:
        """Generator loss without a reference input: the illegal-tile penalty takes slot 2."""
        probs_fake = F.softmax(fake, dim=1)

        minamo_loss = self._critic_term(critic, stage, probs_fake, tag_cond, val_cond)
        illegal_loss = illegal_penalty_loss(probs_fake, STAGE_ALLOWED[stage])
        constraint_loss = inner_constraint_loss(probs_fake)
        density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])

        losses = [
            minamo_loss * self.weight[0],
            illegal_loss * self.weight[2],
            constraint_loss * self.weight[3],
            self._diversity_term(fake),
            density_loss * self.weight[6],
        ]
        return self._total(losses, stage, probs_fake)

    def generator_loss_total_with_input(self, critic, stage, fake, input, tag_cond, val_cond) -> torch.Tensor:
        """Generator loss with a reference input but no cross-entropy target."""
        probs_fake = F.softmax(fake, dim=1)

        minamo_loss = self._critic_term(critic, stage, probs_fake, tag_cond, val_cond)
        immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage])
        constraint_loss = inner_constraint_loss(probs_fake)
        density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])

        losses = [
            minamo_loss * self.weight[0],
            immutable_loss * self.weight[2],
            constraint_loss * self.weight[3],
            self._diversity_term(fake),
            density_loss * self.weight[6],
        ]
        return self._total(losses, stage, probs_fake)

    def generator_input_head_loss(self, critic, probs: torch.Tensor, tag_cond, val_cond) -> torch.Tensor:
        """
        Loss for the stage-0 input head.

        ``probs`` is used as-is (no softmax is applied here); the term
        multipliers 50 and 0.5 are fixed and independent of ``self.weight``.
        """
        head_scores = -torch.mean(critic(probs, 0, tag_cond, val_cond))
        probs_a, probs_b = probs.chunk(2, dim=0)

        losses = [
            head_scores,
            input_head_illegal_loss(probs) * 50,
            -js_divergence(probs_a, probs_b, softmax=False) * 0.5,
        ]
        return sum(losses)
|
|
||||||
|
|
||||||
class RNNGinkaLoss:
    """Cross-entropy loss wrapper for the RNN generator."""

    def __init__(self, num_classes, device):
        self.num_classes = num_classes
        # Per-class weights: classes 0 and 1 are down-weighted, the rest stay at 1.
        # NOTE(review): rnn_loss below does not pass this tensor to
        # cross_entropy - confirm whether that is intentional.
        class_weight = torch.ones(num_classes)
        class_weight[0] = 0.3
        class_weight[1] = 0.5
        self.weight = class_weight.to(device)

    def rnn_loss(self, fake, target):
        """
        fake: [B, C, H, W] logits
        target: [B, H, W] integer class indices

        Returns the label-smoothed (0.1) cross-entropy against the one-hot target.
        """
        soft_target = F.one_hot(target, num_classes=self.num_classes).float()
        # Move the class axis next to the batch axis to match ``fake``.
        soft_target = soft_target.permute(0, 3, 1, 2)
        return F.cross_entropy(fake, soft_target, label_smoothing=0.1)
|
|
||||||
@ -1,66 +0,0 @@
|
|||||||
import time
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from .unet import GinkaUNet
|
|
||||||
from .output import GinkaOutput
|
|
||||||
from .input import GinkaInput, RandomInputHead
|
|
||||||
from ..common.cond import ConditionEncoder
|
|
||||||
|
|
||||||
def print_memory(tag=""):
    """Print the current and peak CUDA memory allocated, in MB, prefixed by ``tag``."""
    mib = 1024**2
    current = torch.cuda.memory_allocated() / mib
    peak = torch.cuda.max_memory_allocated() / mib
    print(f"{tag} | 当前显存: {current:.2f} MB, 最大显存: {peak:.2f} MB")
|
|
||||||
|
|
||||||
class GinkaModel(nn.Module):
    def __init__(self, base_ch=64, out_ch=32):
        """Ginka Model definition: condition encoder, input heads, UNet and output heads."""
        super().__init__()
        self.head = RandomInputHead()
        self.cond = ConditionEncoder(64, 16, 256, 256)
        self.input = GinkaInput(32, 64, (13, 13), (32, 32))
        self.unet = GinkaUNet(64, base_ch, base_ch)
        self.output = GinkaOutput(base_ch, out_ch, (13, 13))

    def forward(self, x, stage, tag_cond, val_cond):
        """
        Run one generation stage.

        Stage 0 is handled by the random input head alone; any other stage
        goes through the input adapter, the UNet and the stage output head.
        """
        batch, _ = tag_cond.shape
        # The stage index is handed to the condition encoder as a [B, 1] float tensor.
        stage_tensor = torch.Tensor([stage]).expand(batch, 1).to(x.device)
        cond = self.cond(tag_cond, val_cond, stage_tensor)

        if stage == 0:
            return self.head(x, cond)

        x = self.input(x, cond)
        x = self.unet(x, cond)
        return self.output(x, stage, cond)
|
|
||||||
|
|
||||||
# 检查显存占用
|
|
||||||
if __name__ == "__main__":
    # Smoke test: measure memory use and timing of a few chained forward passes.
    sample = torch.rand(1, 32, 32, 32).cuda()
    tag = torch.rand(1, 64).cuda()
    val = torch.rand(1, 16).cuda()

    # Build the model.
    model = GinkaModel().cuda()
    print_memory("初始化后")

    def _count(module):
        """Total number of parameter elements in ``module``."""
        return sum(p.numel() for p in module.parameters())

    # Forward passes: stage 0 once, then stage 1 fed with its own softmaxed output.
    start = time.perf_counter()
    fake0 = model(sample, 0, tag, val)
    fake1 = model(F.softmax(fake0, dim=1), 1, tag, val)
    fake2 = model(F.softmax(fake1, dim=1), 1, tag, val)
    fake3 = model(F.softmax(fake2, dim=1), 1, tag, val)
    end = time.perf_counter()

    print_memory("前向传播后")

    print(f"推理耗时: {end - start}")
    print(f"输入形状: feat={sample.shape}")
    print(f"输出形状: output={fake3.shape}")
    print(f"Random parameters: {_count(model.head)}")
    print(f"Cond parameters: {_count(model.cond)}")
    print(f"Input parameters: {_count(model.input)}")
    print(f"UNet parameters: {_count(model.unet)}")
    print(f"Output parameters: {_count(model.output)}")
    print(f"Total parameters: {_count(model)}")
|
|
||||||
|
|
||||||
@ -1,45 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from ..common.common import ConvFusionModule
|
|
||||||
from ..common.cond import ConditionInjector
|
|
||||||
|
|
||||||
class StageHead(nn.Module):
    """Output head for one stage: two conditioned fusion blocks, then a pooling projection."""

    def __init__(self, in_ch, out_ch, out_size=(13, 13)):
        super().__init__()
        wide = in_ch * 2
        self.dec1 = ConvFusionModule(in_ch, wide, wide, 32, 32)
        self.dec2 = ConvFusionModule(wide, wide, wide, 32, 32)
        # Refine, adaptively max-pool to ``out_size``, then project to ``out_ch`` with a 1x1 conv.
        self.pool = nn.Sequential(
            ConvFusionModule(wide, wide, wide, 32, 32),
            ConvFusionModule(wide, wide, in_ch, 32, 32),
            nn.AdaptiveMaxPool2d(out_size),
            nn.Conv2d(in_ch, out_ch, 1),
        )
        self.inject1 = ConditionInjector(256, wide)
        self.inject2 = ConditionInjector(256, wide)

    def forward(self, x, cond):
        """Apply both conditioned blocks to ``x`` and return the pooled projection."""
        x = self.inject1(self.dec1(x), cond)
        x = self.inject2(self.dec2(x), cond)
        return self.pool(x)
|
|
||||||
|
|
||||||
class GinkaOutput(nn.Module):
    """Holds one ``StageHead`` per generation stage (1, 2 and 3)."""

    def __init__(self, in_ch=64, out_ch=32, out_size=(13, 13)):
        super().__init__()
        self.head1 = StageHead(in_ch, out_ch, out_size)
        self.head2 = StageHead(in_ch, out_ch, out_size)
        self.head3 = StageHead(in_ch, out_ch, out_size)

    def forward(self, x, stage, cond):
        """
        Route ``x`` through the head that belongs to ``stage``.

        Raises:
            RuntimeError: if ``stage`` is not 1, 2 or 3.
        """
        for index, head in ((1, self.head1), (2, self.head2), (3, self.head3)):
            if stage == index:
                return head(x, cond)
        raise RuntimeError("Unknown generate stage.")
|
|
||||||
@ -1,258 +0,0 @@
|
|||||||
import time
|
|
||||||
import random
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
class RNNConditionEncoder(nn.Module):
    """Encode the scalar condition vector into an ``output_dim`` feature vector."""

    def __init__(self, val_dim=16, output_dim=256, width=13, height=13):
        # ``width`` and ``height`` are accepted but not used by this encoder.
        super().__init__()
        hidden = output_dim * 4

        # Condition encoding: expand, normalise, activate.
        self.val_fc = nn.Sequential(
            nn.Linear(val_dim, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
        )
        # Project back down to the output width.
        self.fusion = nn.Sequential(
            nn.Linear(hidden, output_dim)
        )

    def forward(self, val_cond: torch.Tensor):
        """Map ``val_cond`` of shape [B, val_dim] to [B, output_dim]."""
        return self.fusion(self.val_fc(val_cond))
|
|
||||||
|
|
||||||
class GinkaMapPatch(nn.Module):
    """
    Encode the 5x5 neighbourhood that precedes the current cell.

    The window spans the four rows above plus the current row, and two
    columns on either side of ``x``. Cells outside the map, the current cell
    and the two cells to its right are zeroed and flagged 0 in an extra mask
    channel. The patch is one-hot encoded and reduced by a small CNN to a
    256-dimensional feature.
    """

    def __init__(self, tile_classes=32, width=13, height=13):
        super().__init__()

        self.width = width
        self.height = height
        # Honour the constructor argument (it used to be overwritten with 32).
        self.tile_classes = tile_classes

        # Local convolution over the patch to capture local structure.
        # Input channels: one per tile class plus the validity mask.
        self.patch_cnn = nn.Sequential(
            nn.Conv2d(tile_classes + 1, 256, 3, padding=1),
            nn.Dropout(0.2),
            nn.BatchNorm2d(256),
            nn.GELU(),

            nn.Conv2d(256, 512, 3),
            nn.Dropout(0.2),
            nn.BatchNorm2d(512),
            nn.GELU(),

            nn.Flatten()
        )
        # 5x5 -> (padding=1) 5x5 -> (no padding) 3x3 spatial size.
        self.fc = nn.Linear(512 * 3 * 3, 256)

    def forward(self, map: torch.Tensor, x: int, y: int):
        """
        map: [B, H, W] integer tile ids
        x, y: column and row of the cell about to be generated

        Returns a [B, 256] feature tensor.
        """
        B, H, W = map.shape
        mask = torch.zeros([B, 5, 5], device=map.device)
        result = torch.zeros([B, 5, 5], dtype=torch.long, device=map.device)

        # Source window in map coordinates, clipped to the map.
        left = x - 2 if x >= 2 else 0
        right = x + 3 if x < self.width - 2 else self.width
        top = y - 4 if y >= 4 else 0
        bottom = y + 1

        # Matching destination window inside the 5x5 patch.
        res_left = left - (x - 2)
        res_right = right - (x + 3) + 5
        res_top = top - (y - 4)
        res_bottom = 5

        result[:, res_top:res_bottom, res_left:res_right] = map[:, top:bottom, left:right]
        mask[:, res_top:res_bottom, res_left:res_right] = 1

        # The current cell and everything to its right are not drawn yet.
        result[:, 4, 2:] = 0
        mask[:, 4, 2:] = 0

        # [B, 5, 5, C] -> [B, C, row, col]. The previous (0, 3, 2, 1) permute
        # transposed rows and columns relative to the mask channel.
        tiles = F.one_hot(result, num_classes=self.tile_classes).permute(0, 3, 1, 2).float()
        masked_result = torch.cat([tiles, mask.unsqueeze(1)], dim=1)

        feat = self.patch_cnn(masked_result)
        feat = self.fc(feat)
        return feat
|
|
||||||
|
|
||||||
class GinkaTileEmbedding(nn.Module):
    """Learned embedding of a tile id (used for the previously drawn tile)."""

    def __init__(self, tile_classes=32, embed_dim=256):
        super().__init__()
        # One ``embed_dim`` vector per tile class.
        self.embedding = nn.Embedding(tile_classes, embed_dim)

    def forward(self, tile: torch.Tensor):
        """Look up the embedding of each id in ``tile``."""
        embedded = self.embedding(tile)
        return embedded
|
|
||||||
|
|
||||||
class GinkaPosEmbedding(nn.Module):
    """Learned positional embeddings for the two map coordinates."""

    def __init__(self, width=13, height=13, embed_dim=256):
        super().__init__()

        self.width = width
        self.height = height

        # ``row_embedding`` has ``width`` entries and is indexed by ``x``;
        # ``col_embedding`` has ``height`` entries and is indexed by ``y``.
        self.row_embedding = nn.Embedding(width, embed_dim)
        self.col_embedding = nn.Embedding(height, embed_dim)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """
        Embed the coordinate tensors and drop their axis 1 when it has size 1.

        Returns the pair ``(row, col)``.
        """
        x_embed = self.row_embedding(x)
        y_embed = self.col_embedding(y)
        return x_embed.squeeze(1), y_embed.squeeze(1)
|
|
||||||
|
|
||||||
class GinkaInputFusion(nn.Module):
    """Fuse the five per-step feature vectors with a small Transformer encoder."""

    def __init__(self, d_model=256):
        super().__init__()

        # Transformer used to integrate the information of the five tokens.
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=2, dim_feedforward=d_model*2, batch_first=True,
            dropout=0.2
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=4)

    def forward(
        self, tile_embed: torch.Tensor, cond_vec: torch.Tensor,
        row_embed: torch.Tensor, col_embed: torch.Tensor, patch_vec: torch.Tensor
    ):
        """
        tile_embed: [B, 256]
        cond_vec: [B, 256]
        row_embed: [B, 256]
        col_embed: [B, 256]
        patch_vec: [B, 256]

        Returns the encoded first token (the tile-embedding slot), shape [B, d_model].
        """
        tokens = (tile_embed, cond_vec, row_embed, col_embed, patch_vec)
        sequence = torch.stack(tokens, dim=1)  # [B, 5, d_model]
        encoded = self.transformer(sequence)
        return encoded[:, 0]
|
|
||||||
|
|
||||||
class GinkaRNN(nn.Module):
    """One GRU step followed by an MLP that scores every tile class."""

    def __init__(self, tile_classes=32, input_dim=256, hidden_dim=512):
        super().__init__()

        # Recurrent cell, with dropout applied to its new hidden state.
        self.gru = nn.GRUCell(input_dim, hidden_dim)
        self.drop = nn.Dropout(0.2)
        # Classification head.
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),

            nn.Linear(hidden_dim, tile_classes)
        )

    def forward(self, feat_fusion: torch.Tensor, hidden: torch.Tensor):
        """
        feat_fusion: [B, input_dim]
        hidden: [B, hidden_dim]

        Returns ``(logits, hidden)``; the returned hidden state is the
        dropout-processed one that the logits were computed from.
        """
        next_hidden = self.gru(feat_fusion, hidden)
        next_hidden = self.drop(next_hidden)
        return self.fc(next_hidden), next_hidden
|
|
||||||
|
|
||||||
class GinkaRNNModel(nn.Module):
    """Autoregressive map generator that draws one tile per step in row-major order."""

    def __init__(self, device: torch.device, start_tile=31, width=13, height=13):
        super().__init__()

        self.device = device
        self.width = width
        self.height = height
        self.start_tile = start_tile

        self.rnn_hidden = 512
        self.tile_classes = 32

        # Model structure.
        self.cond = RNNConditionEncoder()
        self.tile_embedding = GinkaTileEmbedding(tile_classes=self.tile_classes)
        self.pos_embedding = GinkaPosEmbedding()
        self.map_patch = GinkaMapPatch(tile_classes=self.tile_classes)
        self.feat_fusion = GinkaInputFusion()
        self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)

    def forward(self, val_cond: torch.Tensor, target_map: torch.Tensor, use_self_probility=0):
        """
        val_cond: [B, val_dim]
        target_map: [B, H, W]
        use_self_probility: per-step probability of feeding the model its own
            previous output instead of the ground-truth map (teacher forcing
            when the draw fails)

        Returns ``(logits, map)``: logits as [B, tile_classes, H, W] and the
        argmax map as [B, H, W] (int32).
        """
        batch, _ = val_cond.shape
        dev = self.device

        # Running state: previous tile, map drawn so far, collected logits, GRU state.
        prev_tile = torch.full((batch,), self.start_tile, dtype=torch.long, device=dev)
        drawn = torch.zeros([batch, self.height, self.width], dtype=torch.int32, device=dev)
        all_logits = torch.zeros(
            [batch, self.height, self.width, self.tile_classes], device=dev
        )
        state: torch.Tensor = torch.zeros(batch, self.rnn_hidden, device=dev)

        # The condition is global, so it is encoded only once.
        cond = self.cond(val_cond)

        for step in range(self.height * self.width):
            y, x = divmod(step, self.width)
            x_idx = torch.full((batch, 1), x, dtype=torch.long, device=dev)
            y_idx = torch.full((batch, 1), y, dtype=torch.long, device=dev)

            # Tile, position and local-patch encodings.
            tile_embed = self.tile_embedding(prev_tile)
            row_embed, col_embed = self.pos_embedding(x_idx, y_idx)
            use_self = random.random() < use_self_probility
            context = drawn if use_self else target_map
            patch_feat = self.map_patch(context, x, y)

            # Fuse the encodings and advance the RNN by one step.
            fused = self.feat_fusion(tile_embed, cond, row_embed, col_embed, patch_feat)
            logits, state = self.rnn(fused, state)

            # Record the output and choose the tile fed to the next step.
            all_logits[:, y, x] = logits
            tile_id = torch.argmax(logits, dim=1).detach()
            drawn[:, y, x] = tile_id
            prev_tile = tile_id if use_self else target_map[:, y, x].detach()

        return all_logits.permute(0, 3, 1, 2), drawn
|
|
||||||
|
|
||||||
def print_memory(device, tag=""):
    """Print current and peak CUDA memory of ``device`` in MB, or a notice when CUDA is unavailable."""
    if not torch.cuda.is_available():
        print("当前设备不支持 cuda.")
        return

    mib = 1024**2
    current = torch.cuda.memory_allocated(device) / mib
    peak = torch.cuda.max_memory_allocated(device) / mib
    print(f"{tag} | 当前显存: {current:.2f} MB, 最大显存: {peak:.2f} MB")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Smoke test: one teacher-forced forward pass plus parameter counts.
    device = torch.device("cpu")

    input = torch.randint(0, 32, [1, 13, 13]).to(device)
    cond = torch.rand(1, 16).to(device)

    # Build the model; pass the device object rather than a bare string.
    model = GinkaRNNModel(device).to(device)

    # print_memory takes (device, tag): the tag used to be passed as the
    # device, which breaks as soon as CUDA is available.
    print_memory(device, "初始化后")

    # Forward pass with use_self_probility=0, i.e. always teacher-forced.
    start = time.perf_counter()
    fake_logits, fake_map = model(cond, input, 0)
    end = time.perf_counter()

    print_memory(device, "前向传播后")

    print(f"推理耗时: {end - start}")
    print(f"输出形状: fake_logits={fake_logits.shape}, fake_map={fake_map.shape}")
    print(f"Condition Encoder parameters: {sum(p.numel() for p in model.cond.parameters())}")
    print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}")
    print(f"Position Embedding parameters: {sum(p.numel() for p in model.pos_embedding.parameters())}")
    print(f"Map Patch parameters: {sum(p.numel() for p in model.map_patch.parameters())}")
    print(f"Feature Fusion parameters: {sum(p.numel() for p in model.feat_fusion.parameters())}")
    print(f"RNN parameters: {sum(p.numel() for p in model.rnn.parameters())}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
|
|
||||||
@ -1,188 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from shared.attention import ChannelAttention
|
|
||||||
from ..common.common import GCNBlock, TransformerGCNBlock, DoubleConvBlock, ConvFusionModule
|
|
||||||
from ..common.cond import ConditionInjector
|
|
||||||
|
|
||||||
class GinkaTransformerEncoder(nn.Module):
    """
    Token-wise Transformer encoder.

    ``in_dim``, ``hidden_dim`` and ``out_dim`` are given as totals and are
    integer-divided by ``token_size`` to obtain the per-token widths.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, token_size, ff_dim, num_heads=8, num_layers=6):
        super().__init__()
        token_in = in_dim // token_size
        token_hidden = hidden_dim // token_size
        token_out = out_dim // token_size

        self.embedding = nn.Sequential(
            nn.Linear(token_in, token_hidden),
            nn.LayerNorm(token_hidden)
        )
        # Learned positional embedding, one vector per token.
        self.pos_embedding = nn.Parameter(torch.randn(1, token_size, token_hidden))
        encoder_layer = nn.TransformerEncoderLayer(
            token_hidden, num_heads, dim_feedforward=ff_dim, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Sequential(
            nn.Linear(token_hidden, token_out),
            nn.LayerNorm(token_out)
        )

    def forward(self, x):
        """Map [B, L, in_dim] to [B, L, out_dim] (per-token widths)."""
        tokens = self.embedding(x) + self.pos_embedding  # [B, L, hidden_dim]
        tokens = self.transformer(tokens)                # [B, L, hidden_dim]
        return self.fc(tokens)                           # [B, L, out_dim]
|
|
||||||
|
|
||||||
class ConvBlock(nn.Module):
    """Thin wrapper around ``DoubleConvBlock([in_ch, out_ch, out_ch])``."""

    def __init__(self, in_ch, out_ch, attn=True):
        # ``attn`` is accepted for interface compatibility but has no effect here.
        super().__init__()
        channels = [in_ch, out_ch, out_ch]
        self.conv = DoubleConvBlock(channels)

    def forward(self, x):
        """Apply the wrapped double-convolution block to ``x``."""
        out = self.conv(x)
        return out
|
|
||||||
|
|
||||||
class FusionModule(nn.Module):
    """Concatenate two feature maps along the channel axis and mix them with a 3x3 conv."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # ``in_ch`` is the channel count after concatenation.
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate')

    def forward(self, x1, x2):
        """Return ``conv(cat([x1, x2], dim=1))``."""
        merged = torch.cat((x1, x2), dim=1)
        return self.conv(merged)
|
|
||||||
|
|
||||||
class GinkaUNetInput(nn.Module):
    """First UNet level: fusion convolution followed by condition injection (no pooling)."""

    def __init__(self, in_ch, out_ch, w, h):
        super().__init__()
        self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h)
        self.inject = ConditionInjector(256, out_ch)

    def forward(self, x, cond):
        """Convolve ``x`` and inject ``cond`` into the result."""
        features = self.conv(x)
        return self.inject(features, cond)
|
|
||||||
|
|
||||||
class GinkaEncoder(nn.Module):
    """Encoder level: 2x max-pool, fusion convolution, then condition injection."""

    def __init__(self, in_ch, out_ch, w, h):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h)
        self.inject = ConditionInjector(256, out_ch)

    def forward(self, x, cond):
        """Downsample ``x``, convolve it and inject ``cond``."""
        features = self.conv(self.pool(x))
        return self.inject(features, cond)
|
|
||||||
|
|
||||||
class GinkaUpSample(nn.Module):
    """2x upsampling: stride-2 transposed conv followed by a 3x3 refinement conv."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = [
            # Doubles the spatial size.
            nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2),
            nn.InstanceNorm2d(out_ch),
            nn.GELU(),
            # Size-preserving refinement.
            nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
            nn.InstanceNorm2d(out_ch),
            nn.GELU(),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Return the upsampled feature map."""
        upsampled = self.conv(x)
        return upsampled
|
|
||||||
|
|
||||||
class GinkaDecoder(nn.Module):
    """Decoder level: upsample, merge with the skip feature, convolve, inject the condition."""

    def __init__(self, in_ch, out_ch, w, h):
        super().__init__()
        # Upsampling halves the channels; concatenating the skip feature is
        # then mixed by a 1x1 conv that expects ``in_ch`` channels.
        self.upsample = GinkaUpSample(in_ch, in_ch // 2)
        self.fusion = nn.Conv2d(in_ch, in_ch, 1)
        self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h)
        self.inject = ConditionInjector(256, out_ch)

    def forward(self, x, feat, cond):
        """``x`` is the deeper feature, ``feat`` the skip connection from the encoder."""
        merged = torch.cat([self.upsample(x), feat], dim=1)
        merged = self.fusion(merged)
        merged = self.conv(merged)
        return self.inject(merged, cond)
|
|
||||||
|
|
||||||
class GinkaBottleneck(nn.Module):
    """UNet bottleneck: a channel-preserving fusion convolution plus condition injection."""

    def __init__(self, module_ch, w, h):
        super().__init__()
        self.conv = ConvFusionModule(module_ch, module_ch, module_ch, w, h)
        self.inject = ConditionInjector(256, module_ch)

    def forward(self, x, cond):
        """Convolve ``x`` and inject ``cond`` into the result."""
        features = self.conv(x)
        return self.inject(features, cond)
|
|
||||||
|
|
||||||
class GinkaEncoderPath(nn.Module):
    """Four-level encoder; channels double and the configured size halves at every level."""

    def __init__(self, in_ch, base_ch):
        super().__init__()
        self.down1 = GinkaUNetInput(in_ch, base_ch, 32, 32)
        self.down2 = GinkaEncoder(base_ch, base_ch*2, 16, 16)
        self.down3 = GinkaEncoder(base_ch*2, base_ch*4, 8, 8)
        self.down4 = GinkaEncoder(base_ch*4, base_ch*8, 4, 4)

    def forward(self, x, cond):
        """Return the features of all four levels, shallowest first."""
        # Shapes for base_ch=64: [B,64,32,32], [B,128,16,16], [B,256,8,8], [B,512,4,4]
        levels = []
        for stage in (self.down1, self.down2, self.down3, self.down4):
            x = stage(x, cond)
            levels.append(x)
        x1, x2, x3, x4 = levels
        return x1, x2, x3, x4
|
|
||||||
|
|
||||||
class GinkaDecoderPath(nn.Module):
    """Three-level decoder that consumes the encoder features deepest first."""

    def __init__(self, base_ch):
        super().__init__()
        self.up1 = GinkaDecoder(base_ch*8, base_ch*4, 8, 8)
        self.up2 = GinkaDecoder(base_ch*4, base_ch*2, 16, 16)
        self.up3 = GinkaDecoder(base_ch*2, base_ch, 32, 32)

    def forward(self, x1, x2, x3, x4, cond):
        """Decode from ``x4`` upwards using ``x3``, ``x2``, ``x1`` as skip features."""
        # Shapes for base_ch=64: [B,256,8,8] -> [B,128,16,16] -> [B,64,32,32]
        out = x4
        for stage, skip in ((self.up1, x3), (self.up2, x2), (self.up3, x1)):
            out = stage(out, skip, cond)
        return out
|
|
||||||
|
|
||||||
class GinkaUNet(nn.Module):
    def __init__(self, in_ch=32, base_ch=32, out_ch=32):
        """UNet part of the Ginka Model: encoder path, bottleneck, decoder path, final conv."""
        super().__init__()
        self.enc = GinkaEncoderPath(in_ch, base_ch)
        self.bottleneck = GinkaBottleneck(base_ch*8, 4, 4)
        self.dec = GinkaDecoderPath(base_ch)

        self.final = ConvFusionModule(base_ch, base_ch, out_ch, 32, 32)

    def forward(self, x, cond):
        """Encode, pass the deepest feature through the bottleneck, decode and project."""
        skip1, skip2, skip3, deepest = self.enc(x, cond)
        deepest = self.bottleneck(deepest, cond)
        decoded = self.dec(skip1, skip2, skip3, deepest, cond)
        return self.final(decoded)
|
|
||||||
@ -1,16 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import torch
|
|
||||||
|
|
||||||
def to_deployment(path: str, output: str):
    """
    Strip a training checkpoint down to its model weights.

    Loads the checkpoint at ``path`` and writes a new file at ``output`` that
    contains only the ``"model_state"`` entry.

    The checkpoint is mapped to the CPU while loading so the conversion also
    works on a machine that lacks the GPU the checkpoint was saved from; the
    exported tensors are therefore stored as CPU tensors.

    Raises:
        KeyError: if the checkpoint has no ``"model_state"`` entry.
    """
    state = torch.load(path, map_location="cpu")
    torch.save({
        "model_state": state["model_state"]
    }, output)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
    # Command-line entry point: convert a training checkpoint into a deployment file.
    cli = argparse.ArgumentParser()
    for flag, default in (
        ("--input", "result/ginka.pth"),
        ("--output", "result/ginka_deploy.pth"),
    ):
        cli.add_argument(flag, type=str, default=default)
    options = cli.parse_args()
    to_deployment(options.input, options.output)
|
|
||||||
|
|
||||||
@ -89,9 +89,9 @@ def train():
|
|||||||
|
|
||||||
# 用于生成图片
|
# 用于生成图片
|
||||||
tile_dict = dict()
|
tile_dict = dict()
|
||||||
for file in os.listdir('tiles2'):
|
for file in os.listdir('tiles'):
|
||||||
name = os.path.splitext(file)[0]
|
name = os.path.splitext(file)[0]
|
||||||
tile_dict[name] = cv2.imread(f"tiles2/{file}", cv2.IMREAD_UNCHANGED)
|
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
|
||||||
|
|
||||||
# 接续训练
|
# 接续训练
|
||||||
if args.resume:
|
if args.resume:
|
||||||
|
|||||||
@ -1,188 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from datetime import datetime
|
|
||||||
import torch
|
|
||||||
import torch.optim as optim
|
|
||||||
import cv2
|
|
||||||
from torch_geometric.loader import DataLoader
|
|
||||||
from tqdm import tqdm
|
|
||||||
from .generator.rnn import GinkaRNNModel
|
|
||||||
from .dataset import GinkaRNNDataset
|
|
||||||
from .generator.loss import RNNGinkaLoss
|
|
||||||
from shared.image import matrix_to_image_cv
|
|
||||||
|
|
||||||
# 手工标注标签定义(暂时不用):
|
|
||||||
# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层,
|
|
||||||
# 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具风
|
|
||||||
# 16. 区域入口, 17. 区域连接, 18. 有机关门, 19. 道具层, 20. 斜向对称, 21. 左右通道, 22. 上下通道, 23. 多机关门
|
|
||||||
# 24. 中心对称, 25. 部分对称, 26. 鱼骨
|
|
||||||
|
|
||||||
# 自动标注标签定义(暂时不用):
|
|
||||||
# 0. 左右对称, 1. 上下对称, 2. 中心对称, 3. 斜向对称, 4. 伪对称, 5. 多房间, 6. 多走廊
|
|
||||||
# 32. 平面塔, 33. 转换塔, 34. 道具塔
|
|
||||||
|
|
||||||
# 标量值定义:
|
|
||||||
# 0. 整体密度,非空白图块/地图面积,空白图块还包括装饰图块
|
|
||||||
# 1. 墙体密度,墙壁/地图面积
|
|
||||||
# 2. 装饰密度,装饰数量/地图面积
|
|
||||||
# 3. 门密度,门数量/地图面积
|
|
||||||
# 4. 怪物密度,怪物数量/地图面积
|
|
||||||
# 5. 资源密度,资源数量/地图面积
|
|
||||||
# 6. 宝石密度,宝石数量/地图面积
|
|
||||||
# 7. 血瓶密度,血瓶数量/地图面积
|
|
||||||
# 8. 钥匙密度,钥匙数量/地图面积
|
|
||||||
# 9. 道具密度,道具数量/地图面积
|
|
||||||
# 10. 入口数量
|
|
||||||
# 11. 机关门数量
|
|
||||||
# 12. 咸鱼门数量(多层咸鱼门只算一个)
|
|
||||||
|
|
||||||
# 图块定义:
|
|
||||||
# 0. 空地, 1. 墙壁, 2. 装饰(用于野外装饰,视为空地),
|
|
||||||
# 3. 黄门, 4. 蓝门, 5. 红门, 6. 机关门, 其余种类的门如绿门都视为红门
|
|
||||||
# 7-9. 黄蓝红门钥匙,机关门不使用钥匙开启
|
|
||||||
# 10-12. 三种等级的红宝石
|
|
||||||
# 13-15. 三种等级的蓝宝石
|
|
||||||
# 16-18. 三种等级的绿宝石
|
|
||||||
# 19-22. 四种等级的血瓶
|
|
||||||
# 23-25. 三种等级的道具
|
|
||||||
# 26-28. 三种等级的怪物
|
|
||||||
# 29. 入口,不区分楼梯和箭头
|
|
||||||
|
|
||||||
BATCH_SIZE = 96

# Training runs on the second CUDA device when CUDA is available, else on the CPU.
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

# Output directories for checkpoints and preview images.
for _out_dir in ("result", "result/rnn", "result/ginka_rnn_img"):
    os.makedirs(_out_dir, exist_ok=True)
del _out_dir

# Progress bars are disabled when stdout is not a terminal.
disable_tqdm = not sys.stdout.isatty()
|
|
||||||
|
|
||||||
def gt_prob(epoch: int, max_epoch: int) -> float:
    """Linear schedule that rises from 0.1 at epoch 0 to 1.0 at ``max_epoch``.

    Args:
        epoch: Current (0-based) epoch index.
        max_epoch: Total number of epochs in the schedule.

    Returns:
        ``0.1 + 0.9 * epoch / max_epoch``, with the progress fraction clamped
        to ``[0, 1]``. A non-positive ``max_epoch`` is treated as a finished
        schedule (returns 1.0) instead of raising ``ZeroDivisionError``.
    """
    if max_epoch <= 0:
        return 1.0
    progress = min(max(epoch / max_epoch, 0.0), 1.0)
    return 0.1 + 0.9 * progress
|
|
||||||
|
|
||||||
def _str2bool(value):
    """Convert a command-line token such as "true"/"False"/"0" to a bool.

    ``type=bool`` would treat every non-empty string (including "False") as
    True, so boolean flags need an explicit converter.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments():
    """Parse the command-line options of the RNN training script.

    Returns:
        argparse.Namespace with ``resume``, ``state_ginka``, ``train``,
        ``validate``, ``epochs``, ``checkpoint`` and ``load_optim``.
    """
    parser = argparse.ArgumentParser(description="training codes")
    parser.add_argument("--resume", type=_str2bool, default=False)
    parser.add_argument("--state_ginka", type=str, default="result/rnn/ginka-100.pth")
    parser.add_argument("--train", type=str, default="ginka-dataset.json")
    parser.add_argument("--validate", type=str, default="ginka-eval.json")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--checkpoint", type=int, default=5)
    parser.add_argument("--load_optim", type=_str2bool, default=True)
    return parser.parse_args()
|
|
||||||
|
|
||||||
def train():
    """Train the Ginka RNN generator, checkpointing and validating periodically.

    Options come from ``parse_arguments()``. Side effects:
      * every ``args.checkpoint`` epochs: writes
        ``result/rnn/ginka-{epoch}.pth`` (model + optimizer state), runs the
        validation set and writes one preview image per validation batch to
        ``result/ginka_rnn_img/{idx}.png``;
      * at the end: writes the final weights to ``result/ginka_rnn.pth``.
    """
    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")

    args = parse_arguments()

    ginka_rnn = GinkaRNNModel(device).to(device)

    dataset = GinkaRNNDataset(args.train, device)
    dataset_val = GinkaRNNDataset(args.validate, device)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 8)

    optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)

    criterion = RNNGinkaLoss(32, device)

    # Tile images (keyed by file name without extension) used to render maps.
    tile_dict = dict()
    for file in os.listdir('tiles'):
        name = os.path.splitext(file)[0]
        tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)

    if args.resume:
        data_ginka = torch.load(args.state_ginka, map_location=device)

        ginka_rnn.load_state_dict(data_ginka["model_state"], strict=False)

        # Older/final checkpoints may not carry optimizer state.
        if args.load_optim and data_ginka.get("optim_state") is not None:
            optimizer_ginka.load_state_dict(data_ginka["optim_state"])

        print("Train from loaded state.")

    for epoch in tqdm(range(args.epochs), desc="RNN Training", disable=disable_tqdm):
        loss_total_ginka = torch.Tensor([0]).to(device)

        for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
            val_cond = batch["val_cond"].to(device)
            target_map = batch["target_map"].to(device)

            # Clear gradients left over from the previous batch; without this
            # every step would apply the sum of all gradients seen so far.
            optimizer_ginka.zero_grad()

            # Third argument is 1 - gt_prob(...); presumably the probability of
            # feeding the model its own prediction -- TODO confirm in GinkaRNNModel.
            fake_logits, fake_map = ginka_rnn(val_cond, target_map, 1 - gt_prob(epoch, args.epochs))

            loss = criterion.rnn_loss(fake_logits, target_map)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(ginka_rnn.parameters(), max_norm=1.0)
            optimizer_ginka.step()
            loss_total_ginka += loss.detach()

        avg_loss_ginka = loss_total_ginka.item() / len(dataloader)
        tqdm.write(
            f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
            f"E: {epoch + 1} | Loss: {avg_loss_ginka:.6f} | " +
            f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
        )

        scheduler_ginka.step()

        # Every `args.checkpoint` epochs: save a checkpoint and dump images.
        if (epoch + 1) % args.checkpoint == 0:
            # Save checkpoint.
            torch.save({
                "model_state": ginka_rnn.state_dict(),
                "optim_state": optimizer_ginka.state_dict(),
            }, f"result/rnn/ginka-{epoch + 1}.pth")

            val_loss_total = torch.Tensor([0]).to(device)
            with torch.no_grad():
                idx = 0
                for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
                    val_cond = batch["val_cond"].to(device)
                    target_map = batch["target_map"].to(device)

                    fake_logits, fake_map = ginka_rnn(val_cond, target_map, 1 - gt_prob(epoch, args.epochs))

                    val_loss_total += criterion.rnn_loss(fake_logits, target_map).detach()

                    # Render only the first map of each validation batch.
                    fake_map = fake_map.cpu().numpy()
                    fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
                    cv2.imwrite(f"result/ginka_rnn_img/{idx}.png", fake_img)

                    idx += 1

            avg_loss_val = val_loss_total.item() / len(dataloader_val)
            tqdm.write(
                f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch + 1} | " +
                f"Loss: {avg_loss_val:.6f}"
            )

    print("Train ended.")
    torch.save({
        "model_state": ginka_rnn.state_dict(),
    }, f"result/ginka_rnn.pth")


if __name__ == "__main__":
    torch.set_num_threads(4)
    train()
|
|
||||||
@ -1,271 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import random
|
|
||||||
from datetime import datetime
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch.optim as optim
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from torch_geometric.loader import DataLoader
|
|
||||||
from tqdm import tqdm
|
|
||||||
from .transformer.vae import GinkaTransformerVAE
|
|
||||||
from .vae_rnn.loss import VAELoss
|
|
||||||
from .vae_rnn.scheduler import VAEScheduler
|
|
||||||
from .dataset import GinkaRNNDataset
|
|
||||||
from shared.image import matrix_to_image_cv
|
|
||||||
|
|
||||||
# 手工标注标签定义(暂时不用):
|
|
||||||
# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层,
|
|
||||||
# 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具风
|
|
||||||
# 16. 区域入口, 17. 区域连接, 18. 有机关门, 19. 道具层, 20. 斜向对称, 21. 左右通道, 22. 上下通道, 23. 多机关门
|
|
||||||
# 24. 中心对称, 25. 部分对称, 26. 鱼骨
|
|
||||||
|
|
||||||
# 自动标注标签定义(暂时不用):
|
|
||||||
# 0. 左右对称, 1. 上下对称, 2. 中心对称, 3. 斜向对称, 4. 伪对称, 5. 多房间, 6. 多走廊
|
|
||||||
# 32. 平面塔, 33. 转换塔, 34. 道具塔
|
|
||||||
|
|
||||||
# 标量值定义:
|
|
||||||
# 0. 整体密度,非空白图块/地图面积,空白图块还包括装饰图块
|
|
||||||
# 1. 墙体密度,墙壁/地图面积
|
|
||||||
# 2. 装饰密度,装饰数量/地图面积
|
|
||||||
# 3. 门密度,门数量/地图面积
|
|
||||||
# 4. 怪物密度,怪物数量/地图面积
|
|
||||||
# 5. 资源密度,资源数量/地图面积
|
|
||||||
# 6. 宝石密度,宝石数量/地图面积
|
|
||||||
# 7. 血瓶密度,血瓶数量/地图面积
|
|
||||||
# 8. 钥匙密度,钥匙数量/地图面积
|
|
||||||
# 9. 道具密度,道具数量/地图面积
|
|
||||||
# 10. 入口数量
|
|
||||||
# 11. 机关门数量
|
|
||||||
# 12. 咸鱼门数量(多层咸鱼门只算一个)
|
|
||||||
|
|
||||||
# 图块定义:
|
|
||||||
# 0. 空地, 1. 墙壁, 2. 门, 3. 钥匙, 4. 红宝石, 5. 蓝宝石, 6. 绿宝石, 7. 血瓶
|
|
||||||
# 8. 道具, 9. 怪物, 10. 入口, 14. 起始 token, 15. 终止 token
|
|
||||||
|
|
||||||
# Training batch size; the validation loader uses BATCH_SIZE // VAL_BATCH_DIVIDER.
BATCH_SIZE = 128
# Size of the VAE latent vector (also used when sampling z ~ N(0, I)).
LATENT_DIM = 32
# Weight passed to the loss for the KL term.
KL_BETA = 0.01
# Average training loss below which an epoch counts towards raising self_prob.
SELF_GATE = 0.5
# Number of consecutive gated epochs required before self_prob is raised.
GATE_EPOCH = 5
VAL_BATCH_DIVIDER = 128
# Increment applied to self_prob each time the gate opens.
PROB_STEP = 0.05
# Number of token classes handed to the model.
NUM_CLASSES = 16

# Prefer the second GPU when there is one. On a single-GPU host "cuda:1" is
# not a valid device, so fall back to the first GPU; then MPS; then the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

os.makedirs("result", exist_ok=True)
os.makedirs("result/vae", exist_ok=True)
os.makedirs("result/ginka_vae_img", exist_ok=True)

# Hide progress bars when stdout is not an interactive terminal (e.g. logs).
disable_tqdm = not sys.stdout.isatty()
|
|
||||||
|
|
||||||
def _str2bool(value):
    """Convert a command-line token such as "true"/"False"/"0" to a bool.

    ``type=bool`` would treat every non-empty string (including "False") as
    True, so boolean flags need an explicit converter.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments():
    """Parse the command-line options of the Transformer-VAE training script.

    Returns:
        argparse.Namespace with ``resume``, ``state_ginka``, ``train``,
        ``validate``, ``epochs``, ``checkpoint`` and ``load_optim``.
    """
    parser = argparse.ArgumentParser(description="training codes")
    parser.add_argument("--resume", type=_str2bool, default=False)
    parser.add_argument("--state_ginka", type=str, default="result/vae/ginka-100.pth")
    parser.add_argument("--train", type=str, default="ginka-dataset.json")
    parser.add_argument("--validate", type=str, default="ginka-eval.json")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--checkpoint", type=int, default=5)
    parser.add_argument("--load_optim", type=_str2bool, default=True)
    return parser.parse_args()
|
|
||||||
|
|
||||||
def train():
    """Train the Transformer VAE, checkpointing and visualising periodically.

    Options come from ``parse_arguments()``. Side effects:
      * every ``args.checkpoint`` epochs: writes
        ``result/vae/ginka-{epoch}.pth`` (model + optimizer state) and, under
        ``result/ginka_vae_img/``, reconstruction previews (``{idx}.png``),
        random samples (``{i}_rand.png``) and a latent interpolation between
        two random validation maps (``{i}_linspace.png``);
      * at the end: writes the final weights to ``result/ginka_transformer.pth``.
    """
    print(f"Using {device.type} to train model.")

    args = parse_arguments()

    vae = GinkaTransformerVAE(num_classes=NUM_CLASSES, latent_dim=LATENT_DIM).to(device)

    dataset = GinkaRNNDataset(args.train, device)
    dataset_val = GinkaRNNDataset(args.validate, device)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True)

    optimizer_ginka = optim.AdamW(vae.parameters(), lr=3e-4, weight_decay=1e-2, betas=(0.9, 0.95))
    # Custom scheduler: it can reset its state and raise the learning rate
    # when self_prob increases, so the model can adapt to the harder task.
    scheduler_ginka = VAEScheduler(
        optimizer_ginka, factor=0.9, increase_factor=2, patience=10, max_lr=1e-4, min_lr=1e-6
    )

    criterion = VAELoss()

    self_prob = 0
    prob_epochs = 0

    # Tile images (keyed by file name without extension) used to render maps.
    tile_dict = dict()
    for file in os.listdir('tiles2'):
        name = os.path.splitext(file)[0]
        tile_dict[name] = cv2.imread(f"tiles2/{file}", cv2.IMREAD_UNCHANGED)

    if args.resume:
        data_ginka = torch.load(args.state_ginka, map_location=device)

        vae.load_state_dict(data_ginka["model_state"], strict=False)

        # Older/final checkpoints may not carry optimizer state.
        if args.load_optim and data_ginka.get("optim_state") is not None:
            optimizer_ginka.load_state_dict(data_ginka["optim_state"])

        print("Train from loaded state.")

    for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm):
        loss_total = torch.Tensor([0]).to(device)
        reco_loss_total = torch.Tensor([0]).to(device)
        kl_loss_total = torch.Tensor([0]).to(device)

        # The checkpoint branch below switches to eval mode; switch back so
        # later epochs do not silently train in eval mode.
        vae.train()
        vae.teacher_forcing()
        for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
            target_map = batch["target_map"].to(device)
            B, H, W = target_map.shape
            flat_map = target_map.view(B, H * W)  # flatten each map to a token sequence

            optimizer_ginka.zero_grad()
            fake_logits, mu, logvar = vae(flat_map, self_prob)

            loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, flat_map, mu, logvar, KL_BETA)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)
            optimizer_ginka.step()
            loss_total += loss.detach()
            reco_loss_total += reco_loss.detach()
            kl_loss_total += kl_loss.detach()

        avg_loss = loss_total.item() / len(dataloader)
        avg_reco_loss = reco_loss_total.item() / len(dataloader)
        avg_kl_loss = kl_loss_total.item() / len(dataloader)
        tqdm.write(
            f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
            f"E: {epoch + 1} | Loss: {avg_loss:.6f} | Reco: {avg_reco_loss:.6f} | " +
            f"KL: {avg_kl_loss:.6f} | Prob: {self_prob:.2f} | LR: {scheduler_ginka.get_last_lr()[0]:.6f}"
        )

        # Gate on the training loss for now: overfitting is severe, so the
        # validation loss is not used yet (to be revisited).
        if avg_loss < SELF_GATE:
            prob_epochs += 1
        else:
            prob_epochs = 0

        if prob_epochs >= GATE_EPOCH and self_prob < 1:
            self_prob += PROB_STEP
            prob_epochs = 0
            if self_prob > 1:
                self_prob = 1

        # NOTE(review): this unconditional override forces self_prob to 1 after
        # the first epoch and makes the gating above ineffective. Kept as in
        # the original; confirm whether it is a leftover experiment.
        self_prob = 1

        scheduler_ginka.step(avg_loss, self_prob)

        # Every `args.checkpoint` epochs: save a checkpoint and dump images.
        if (epoch + 1) % args.checkpoint == 0:
            # Save checkpoint. Written to result/vae: that is the directory this
            # script creates and the default location --state_ginka resumes from.
            torch.save({
                "model_state": vae.state_dict(),
                "optim_state": optimizer_ginka.state_dict(),
            }, f"result/vae/ginka-{epoch + 1}.pth")

            val_loss_total = torch.Tensor([0]).to(device)
            val_reco_loss_total = torch.Tensor([0]).to(device)
            val_kl_loss_total = torch.Tensor([0]).to(device)
            vae.eval()
            with torch.no_grad():
                idx = 0
                gap = 5
                color = (255, 255, 255)  # white
                vline = np.full((416, gap, 3), color, dtype=np.uint8)  # vertical separator

                # Reconstruction preview.
                vae.teacher_forcing()
                for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
                    target_map = batch["target_map"].to(device)
                    B, H, W = target_map.shape
                    flat_map = target_map.view(B, H * W)

                    fake_logits, mu, logvar = vae(flat_map, self_prob)

                    loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, flat_map, mu, logvar, KL_BETA)
                    val_loss_total += loss.detach()
                    val_reco_loss_total += reco_loss.detach()
                    val_kl_loss_total += kl_loss.detach()

                    # Side-by-side image: real map | separator | reconstruction.
                    fake_map = torch.argmax(fake_logits, dim=2)[:, 0:169].view(B, H, W).cpu().numpy()
                    fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
                    real_map = target_map.cpu().numpy()
                    real_img = matrix_to_image_cv(real_map[0], tile_dict)
                    img = np.block([[real_img], [vline], [fake_img]])
                    cv2.imwrite(f"result/ginka_vae_img/{idx}.png", img)

                    idx += 1

                # Random sampling from the prior.
                vae.autoregressive()
                for i in range(0, 8):
                    z = torch.randn(1, LATENT_DIM).to(device)

                    fake_logits = vae.decoder(z, torch.zeros(1, 169).to(device))
                    fake_map = fake_logits[:, 0:169].view(-1, 13, 13).cpu().numpy()
                    fake_img = matrix_to_image_cv(fake_map[0], tile_dict)

                    cv2.imwrite(f"result/ginka_vae_img/{i}_rand.png", fake_img)

                # Latent interpolation between two random validation maps.
                val_length = len(dataset_val.data)
                index1 = random.randint(0, val_length - 1)
                index2 = random.randint(0, val_length - 1)
                map1 = torch.LongTensor(dataset_val.data[index1]["map"]).to(device).view(1, 169)
                map2 = torch.LongTensor(dataset_val.data[index2]["map"]).to(device).view(1, 169)
                mu1, logvar1 = vae.encoder(map1)
                mu2, logvar2 = vae.encoder(map2)
                z1 = vae.reparameterize(mu1, logvar1)
                z2 = vae.reparameterize(mu2, logvar2)
                real_img1 = matrix_to_image_cv(map1[0].view(13, 13).cpu().numpy(), tile_dict)
                real_img2 = matrix_to_image_cv(map2[0].view(13, 13).cpu().numpy(), tile_dict)
                for i, t in enumerate(torch.linspace(0, 1, 8)):
                    # t already runs over [0, 1]; dividing it by 8 again would
                    # keep every frame within 1/8 of z1.
                    z = z1 * (1 - t) + z2 * t
                    fake_logits = vae.decoder(z, torch.zeros(1, 169).to(device))
                    fake_map = fake_logits[:, 0:169].view(-1, 13, 13).cpu().numpy()
                    fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
                    img = np.block([[real_img1], [vline], [fake_img], [vline], [real_img2]])

                    cv2.imwrite(f"result/ginka_vae_img/{i}_linspace.png", img)

            avg_loss_val = val_loss_total.item() / len(dataloader_val)
            avg_reco_loss_val = val_reco_loss_total.item() / len(dataloader_val)
            avg_kl_loss_val = val_kl_loss_total.item() / len(dataloader_val)
            tqdm.write(
                f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch + 1} | " +
                f"Loss: {avg_loss_val:.6f} | Reco: {avg_reco_loss_val:.6f} | KL: {avg_kl_loss_val:.6f}"
            )

    print("Train ended.")
    torch.save({
        "model_state": vae.state_dict(),
    }, f"result/ginka_transformer.pth")


if __name__ == "__main__":
    torch.set_num_threads(4)
    train()
|
|
||||||
@ -1,270 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import random
|
|
||||||
from datetime import datetime
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch.optim as optim
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from torch_geometric.loader import DataLoader
|
|
||||||
from tqdm import tqdm
|
|
||||||
from .vae_rnn.vae import GinkaVAE
|
|
||||||
from .vae_rnn.loss import VAELoss
|
|
||||||
from .vae_rnn.scheduler import VAEScheduler
|
|
||||||
from .dataset import GinkaRNNDataset
|
|
||||||
from shared.image import matrix_to_image_cv
|
|
||||||
|
|
||||||
# 手工标注标签定义(暂时不用):
|
|
||||||
# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层,
|
|
||||||
# 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具风
|
|
||||||
# 16. 区域入口, 17. 区域连接, 18. 有机关门, 19. 道具层, 20. 斜向对称, 21. 左右通道, 22. 上下通道, 23. 多机关门
|
|
||||||
# 24. 中心对称, 25. 部分对称, 26. 鱼骨
|
|
||||||
|
|
||||||
# 自动标注标签定义(暂时不用):
|
|
||||||
# 0. 左右对称, 1. 上下对称, 2. 中心对称, 3. 斜向对称, 4. 伪对称, 5. 多房间, 6. 多走廊
|
|
||||||
# 32. 平面塔, 33. 转换塔, 34. 道具塔
|
|
||||||
|
|
||||||
# 标量值定义:
|
|
||||||
# 0. 整体密度,非空白图块/地图面积,空白图块还包括装饰图块
|
|
||||||
# 1. 墙体密度,墙壁/地图面积
|
|
||||||
# 2. 装饰密度,装饰数量/地图面积
|
|
||||||
# 3. 门密度,门数量/地图面积
|
|
||||||
# 4. 怪物密度,怪物数量/地图面积
|
|
||||||
# 5. 资源密度,资源数量/地图面积
|
|
||||||
# 6. 宝石密度,宝石数量/地图面积
|
|
||||||
# 7. 血瓶密度,血瓶数量/地图面积
|
|
||||||
# 8. 钥匙密度,钥匙数量/地图面积
|
|
||||||
# 9. 道具密度,道具数量/地图面积
|
|
||||||
# 10. 入口数量
|
|
||||||
# 11. 机关门数量
|
|
||||||
# 12. 咸鱼门数量(多层咸鱼门只算一个)
|
|
||||||
|
|
||||||
# 图块定义:
|
|
||||||
# 0. 空地, 1. 墙壁, 2. 装饰(用于野外装饰,视为空地),
|
|
||||||
# 3. 黄门, 4. 蓝门, 5. 红门, 6. 机关门, 其余种类的门如绿门都视为红门
|
|
||||||
# 7-9. 黄蓝红门钥匙,机关门不使用钥匙开启
|
|
||||||
# 10-12. 三种等级的红宝石
|
|
||||||
# 13-15. 三种等级的蓝宝石
|
|
||||||
# 16-18. 三种等级的绿宝石
|
|
||||||
# 19-22. 四种等级的血瓶
|
|
||||||
# 23-25. 三种等级的道具
|
|
||||||
# 26-28. 三种等级的怪物
|
|
||||||
# 29. 入口,不区分楼梯和箭头
|
|
||||||
|
|
||||||
# Training batch size; the validation loader uses BATCH_SIZE // VAL_BATCH_DIVIDER.
BATCH_SIZE = 128
# Size of the VAE latent vector (also used when sampling z ~ N(0, I)).
LATENT_DIM = 64
# Weight passed to the loss for the KL term.
KL_BETA = 0.01
# Average training loss below which an epoch counts towards raising self_prob.
SELF_GATE = 0.3
# Number of consecutive gated epochs required before self_prob is raised.
GATE_EPOCH = 10
VAL_BATCH_DIVIDER = 128
# Increment applied to self_prob each time the gate opens.
PROB_STEP = 0.05

# Prefer the second GPU when there is one. On a single-GPU host "cuda:1" is
# not a valid device, so fall back to the first GPU; then MPS; then the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

os.makedirs("result", exist_ok=True)
os.makedirs("result/vae", exist_ok=True)
os.makedirs("result/ginka_vae_img", exist_ok=True)

# Hide progress bars when stdout is not an interactive terminal (e.g. logs).
disable_tqdm = not sys.stdout.isatty()
|
|
||||||
|
|
||||||
def _str2bool(value):
    """Convert a command-line token such as "true"/"False"/"0" to a bool.

    ``type=bool`` would treat every non-empty string (including "False") as
    True, so boolean flags need an explicit converter.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments():
    """Parse the command-line options of the VAE training script.

    Returns:
        argparse.Namespace with ``resume``, ``state_ginka``, ``train``,
        ``validate``, ``epochs``, ``checkpoint`` and ``load_optim``.
    """
    parser = argparse.ArgumentParser(description="training codes")
    parser.add_argument("--resume", type=_str2bool, default=False)
    parser.add_argument("--state_ginka", type=str, default="result/vae/ginka-100.pth")
    parser.add_argument("--train", type=str, default="ginka-dataset.json")
    parser.add_argument("--validate", type=str, default="ginka-eval.json")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--checkpoint", type=int, default=5)
    parser.add_argument("--load_optim", type=_str2bool, default=True)
    return parser.parse_args()
|
|
||||||
|
|
||||||
def train():
    """Train the Ginka VAE, checkpointing and visualising periodically.

    Options come from ``parse_arguments()``. Side effects:
      * every ``args.checkpoint`` epochs: writes
        ``result/vae/ginka-{epoch}.pth`` (model + optimizer state) and, under
        ``result/ginka_vae_img/``, reconstruction previews (``{idx}.png``),
        random samples (``{i}_rand.png``) and a latent interpolation between
        two random validation maps (``{i}_linspace.png``);
      * at the end: writes the final weights to ``result/ginka_rnn.pth``.
    """
    print(f"Using {device.type} to train model.")

    args = parse_arguments()

    vae = GinkaVAE(device, latent_dim=LATENT_DIM).to(device)

    dataset = GinkaRNNDataset(args.train, device)
    dataset_val = GinkaRNNDataset(args.validate, device)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True)

    optimizer_ginka = optim.AdamW(vae.parameters(), lr=3e-4, weight_decay=1e-4)
    # Custom scheduler: it can reset its state and raise the learning rate
    # when self_prob increases, so the model can adapt to the harder task.
    scheduler_ginka = VAEScheduler(
        optimizer_ginka, factor=0.9, increase_factor=1.5, patience=20, max_lr=3e-4, min_lr=1e-6
    )

    criterion = VAELoss()

    self_prob = 0
    prob_epochs = 0

    # Tile images (keyed by file name without extension) used to render maps.
    tile_dict = dict()
    for file in os.listdir('tiles'):
        name = os.path.splitext(file)[0]
        tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)

    if args.resume:
        data_ginka = torch.load(args.state_ginka, map_location=device)

        vae.load_state_dict(data_ginka["model_state"], strict=False)

        # Older/final checkpoints may not carry optimizer state.
        if args.load_optim and data_ginka.get("optim_state") is not None:
            optimizer_ginka.load_state_dict(data_ginka["optim_state"])

        print("Train from loaded state.")

    for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm):
        loss_total = torch.Tensor([0]).to(device)
        reco_loss_total = torch.Tensor([0]).to(device)
        kl_loss_total = torch.Tensor([0]).to(device)

        vae.train()
        for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
            target_map = batch["target_map"].to(device)

            optimizer_ginka.zero_grad()
            fake_logits, mu, logvar = vae(target_map, self_prob)

            loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)
            optimizer_ginka.step()
            loss_total += loss.detach()
            reco_loss_total += reco_loss.detach()
            kl_loss_total += kl_loss.detach()

        avg_loss = loss_total.item() / len(dataloader)
        avg_reco_loss = reco_loss_total.item() / len(dataloader)
        avg_kl_loss = kl_loss_total.item() / len(dataloader)
        tqdm.write(
            f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
            f"E: {epoch + 1} | Loss: {avg_loss:.6f} | Reco: {avg_reco_loss:.6f} | " +
            f"KL: {avg_kl_loss:.6f} | Prob: {self_prob:.2f} | LR: {scheduler_ginka.get_last_lr()[0]:.6f}"
        )

        # Gate on the training loss for now: overfitting is severe, so the
        # validation loss is not used yet (to be revisited).
        if avg_loss < SELF_GATE:
            prob_epochs += 1
        else:
            prob_epochs = 0

        # After GATE_EPOCH consecutive gated epochs, raise self_prob by
        # PROB_STEP, capped at 1.
        if prob_epochs >= GATE_EPOCH and self_prob < 1:
            self_prob += PROB_STEP
            prob_epochs = 0
            if self_prob > 1:
                self_prob = 1

        scheduler_ginka.step(avg_loss, self_prob)

        # Every `args.checkpoint` epochs: save a checkpoint and dump images.
        if (epoch + 1) % args.checkpoint == 0:
            vae.eval()
            # Save checkpoint. Written to result/vae: that is the directory this
            # script creates and the default location --state_ginka resumes from.
            torch.save({
                "model_state": vae.state_dict(),
                "optim_state": optimizer_ginka.state_dict(),
            }, f"result/vae/ginka-{epoch + 1}.pth")

            val_loss_total = torch.Tensor([0]).to(device)
            val_reco_loss_total = torch.Tensor([0]).to(device)
            val_kl_loss_total = torch.Tensor([0]).to(device)
            with torch.no_grad():
                idx = 0
                gap = 5
                color = (255, 255, 255)  # white
                vline = np.full((416, gap, 3), color, dtype=np.uint8)  # vertical separator

                # Reconstruction preview.
                for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
                    target_map = batch["target_map"].to(device)

                    fake_logits, mu, logvar = vae(target_map, self_prob)

                    loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
                    val_loss_total += loss.detach()
                    val_reco_loss_total += reco_loss.detach()
                    val_kl_loss_total += kl_loss.detach()

                    # Side-by-side image: real map | separator | reconstruction.
                    fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy()
                    fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
                    real_map = target_map.cpu().numpy()
                    real_img = matrix_to_image_cv(real_map[0], tile_dict)
                    img = np.block([[real_img], [vline], [fake_img]])
                    cv2.imwrite(f"result/ginka_vae_img/{idx}.png", img)

                    idx += 1

                # Random sampling from the prior.
                for i in range(0, 8):
                    z = torch.randn(1, LATENT_DIM).to(device)

                    fake_logits = vae.decoder(z, torch.zeros(1, 13, 13).to(device), 1)
                    fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy()
                    fake_img = matrix_to_image_cv(fake_map[0], tile_dict)

                    cv2.imwrite(f"result/ginka_vae_img/{i}_rand.png", fake_img)

                # Latent interpolation between two random validation maps.
                val_length = len(dataset_val.data)
                index1 = random.randint(0, val_length - 1)
                index2 = random.randint(0, val_length - 1)
                map1 = torch.LongTensor(dataset_val.data[index1]["map"]).to(device).reshape(1, 13, 13)
                map2 = torch.LongTensor(dataset_val.data[index2]["map"]).to(device).reshape(1, 13, 13)
                mu1, logvar1 = vae.encoder(map1)
                mu2, logvar2 = vae.encoder(map2)
                z1 = vae.reparameterize(mu1, logvar1)
                z2 = vae.reparameterize(mu2, logvar2)
                real_img1 = matrix_to_image_cv(map1[0].cpu().numpy(), tile_dict)
                real_img2 = matrix_to_image_cv(map2[0].cpu().numpy(), tile_dict)
                for i, t in enumerate(torch.linspace(0, 1, 8)):
                    # t already runs over [0, 1]; dividing it by 8 again would
                    # keep every frame within 1/8 of z1.
                    z = z1 * (1 - t) + z2 * t
                    fake_logits = vae.decoder(z, torch.zeros(1, 13, 13).to(device), 1)
                    fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy()
                    fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
                    img = np.block([[real_img1], [vline], [fake_img], [vline], [real_img2]])

                    cv2.imwrite(f"result/ginka_vae_img/{i}_linspace.png", img)

            avg_loss_val = val_loss_total.item() / len(dataloader_val)
            avg_reco_loss_val = val_reco_loss_total.item() / len(dataloader_val)
            avg_kl_loss_val = val_kl_loss_total.item() / len(dataloader_val)
            tqdm.write(
                f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch + 1} | " +
                f"Loss: {avg_loss_val:.6f} | Reco: {avg_reco_loss_val:.6f} | KL: {avg_kl_loss_val:.6f}"
            )

    print("Train ended.")
    # NOTE(review): this file name is the same one the RNN trainer writes its
    # final weights to; kept unchanged, but confirm it is intended for the VAE.
    torch.save({
        "model_state": vae.state_dict(),
    }, f"result/ginka_rnn.pth")


if __name__ == "__main__":
    torch.set_num_threads(4)
    train()
|
|
||||||
@ -1,428 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from datetime import datetime
|
|
||||||
import torch
|
|
||||||
import torch.optim as optim
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from torch_geometric.loader import DataLoader
|
|
||||||
from tqdm import tqdm
|
|
||||||
from .generator.model import GinkaModel
|
|
||||||
from .dataset import GinkaWGANDataset
|
|
||||||
from .generator.loss import WGANGinkaLoss
|
|
||||||
from .critic.model import MinamoModel2
|
|
||||||
from shared.image import matrix_to_image_cv
|
|
||||||
|
|
||||||
# 手工标注标签定义:
|
|
||||||
# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层,
|
|
||||||
# 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具风
|
|
||||||
# 16. 区域入口, 17. 区域连接, 18. 有机关门, 19. 道具层, 20. 斜向对称, 21. 左右通道, 22. 上下通道, 23. 多机关门
|
|
||||||
# 24. 中心对称, 25. 部分对称, 26. 鱼骨
|
|
||||||
|
|
||||||
# 自动标注标签定义:
|
|
||||||
# 0. 左右对称, 1. 上下对称, 2. 中心对称, 3. 斜向对称, 4. 伪对称, 5. 多房间, 6. 多走廊
|
|
||||||
# 32. 平面塔, 33. 转换塔, 34. 道具塔
|
|
||||||
|
|
||||||
# 标量值定义:
|
|
||||||
# 0. 整体密度,非空白图块/地图面积,空白图块还包括装饰图块
|
|
||||||
# 1. 墙体密度,墙壁/地图面积
|
|
||||||
# 2. 装饰密度,装饰数量/地图面积
|
|
||||||
# 3. 门密度,门数量/地图面积
|
|
||||||
# 4. 怪物密度,怪物数量/地图面积
|
|
||||||
# 5. 资源密度,资源数量/地图面积
|
|
||||||
# 6. 宝石密度,宝石数量/地图面积
|
|
||||||
# 7. 血瓶密度,血瓶数量/地图面积
|
|
||||||
# 8. 钥匙密度,钥匙数量/地图面积
|
|
||||||
# 9. 道具密度,道具数量/地图面积
|
|
||||||
# 10. 入口数量
|
|
||||||
# 11. 机关门数量
|
|
||||||
# 12. 咸鱼门数量(多层咸鱼门只算一个)
|
|
||||||
|
|
||||||
# 图块定义:
|
|
||||||
# 0. 空地, 1. 墙壁, 2. 装饰(用于野外装饰,视为空地),
|
|
||||||
# 3. 黄门, 4. 蓝门, 5. 红门, 6. 机关门, 其余种类的门如绿门都视为红门
|
|
||||||
# 7-9. 黄蓝红门钥匙,机关门不使用钥匙开启
|
|
||||||
# 10-12. 三种等级的红宝石
|
|
||||||
# 13-15. 三种等级的蓝宝石
|
|
||||||
# 16-18. 三种等级的绿宝石
|
|
||||||
# 19-22. 四种等级的血瓶
|
|
||||||
# 23-25. 三种等级的道具
|
|
||||||
# 26-28. 三种等级的怪物
|
|
||||||
# 29. 入口,不区分楼梯和箭头
|
|
||||||
|
|
||||||
# Samples per batch for both the training and the validation loader.
BATCH_SIZE = 6

# Run on the default CUDA device when one is available, otherwise on the CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda") if _has_cuda else torch.device("cpu")

# Make sure the output folders for this script exist.
for _out_dir in ("result", "result/wgan"):
    os.makedirs(_out_dir, exist_ok=True)

# Progress bars are only shown when stdout is an interactive terminal.
disable_tqdm = not sys.stdout.isatty()
|
|
||||||
|
|
||||||
def _str2bool(value):
    """Convert a command-line token such as "true"/"False"/"0" to a bool.

    ``type=bool`` would treat every non-empty string (including "False") as
    True, so boolean flags need an explicit converter.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments():
    """Parse the command-line options of the WGAN training script.

    Returns:
        argparse.Namespace with ``resume``, ``state_ginka``, ``state_minamo``,
        ``train``, ``validate``, ``epochs``, ``checkpoint``, ``load_optim``,
        ``curr_epoch`` and ``tuning``.
    """
    parser = argparse.ArgumentParser(description="training codes")
    parser.add_argument("--resume", type=_str2bool, default=False)
    parser.add_argument("--state_ginka", type=str, default="result/wgan/ginka-100.pth")
    parser.add_argument("--state_minamo", type=str, default="result/wgan/minamo-100.pth")
    parser.add_argument("--train", type=str, default="ginka-dataset.json")
    parser.add_argument("--validate", type=str, default="ginka-eval.json")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--checkpoint", type=int, default=5)
    parser.add_argument("--load_optim", type=_str2bool, default=True)
    # Minimum number of epochs spent in curriculum learning.
    parser.add_argument("--curr_epoch", type=int, default=20)
    parser.add_argument("--tuning", type=_str2bool, default=False)
    return parser.parse_args()
|
|
||||||
|
|
||||||
def gen_curriculum(gen, masked1, masked2, masked3, tag, val, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the generator once per curriculum stage on its own masked input.

    ``gen`` is called as ``gen(masked_k, k, tag, val)`` for k = 1, 2, 3, in
    that order.

    Args:
        gen: Generator callable.
        masked1, masked2, masked3: Inputs for stages 1, 2 and 3.
        tag, val: Conditioning arguments forwarded unchanged to ``gen``.
        detach: When True, every output is detached from the autograd graph.

    Returns:
        The three stage outputs as a tuple ``(fake1, fake2, fake3)``.
    """
    stage_inputs = (masked1, masked2, masked3)
    outputs = tuple(
        gen(masked, stage, tag, val)
        for stage, masked in enumerate(stage_inputs, start=1)
    )
    if not detach:
        return outputs
    return tuple(out.detach() for out in outputs)
|
|
||||||
|
|
||||||
def gen_total(gen, input, tag, val, progress_detach=True, result_detach=False, random=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Chain the generator through stages 1 -> 2 -> 3, feeding each stage the
    softmax (over dim 1) of the previous stage's output.

    Args:
        gen: callable generator taking (input, stage, tag, val).
        input: starting tensor. With ``random=True`` it is first passed
            through stage 0 and the softmax of that result starts the chain;
            otherwise it is used directly as the stage-1 input.
        tag, val: conditioning arguments forwarded unchanged to ``gen``.
        progress_detach: when true, the tensor handed from one stage to the
            next is detached, so no gradient flows between stages.
        result_detach: when true, all four returned tensors are detached.
        random: whether ``input`` has to go through stage 0 first.

    Returns:
        ``(fake1, fake2, fake3, fake0)`` where ``fake0`` is the raw stage-0
        output if ``random`` is true and ``input`` itself otherwise.
    """
    if random:
        fake0 = gen(input, 0, tag, val)
        x_in = F.softmax(fake0, dim=1)
    else:
        fake0 = input
        x_in = input

    def _handoff(tensor):
        # Optionally cut the graph between consecutive stages.
        return tensor.detach() if progress_detach else tensor

    fake1 = gen(_handoff(x_in), 1, tag, val)
    fake2 = gen(F.softmax(_handoff(fake1), dim=1), 2, tag, val)
    fake3 = gen(F.softmax(_handoff(fake2), dim=1), 3, tag, val)

    if result_detach:
        return fake1.detach(), fake2.detach(), fake3.detach(), fake0.detach()
    return fake1, fake2, fake3, fake0
|
|
||||||
|
|
||||||
def train():
    """Adversarially train the Ginka generator against the Minamo critic.

    Per batch the critic is updated ``c_steps`` times and the generator
    ``g_steps`` times. In stage 1 the mask ratio grows by 0.2 each time the
    current curriculum step has lasted long enough; once it reaches 0.8 the
    stage switches to 4. Every ``--checkpoint`` epochs both models are saved
    under ``result/wgan`` and validation previews are written to
    ``result/ginka_img``. The final weights go to ``result/ginka.pth`` and
    ``result/minamo.pth``.
    """
    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")

    args = parse_arguments()

    c_steps = 2
    g_steps = 1
    # Training stage.
    train_stage = 1
    mask_ratio = 0.2  # size of the masked region
    stage_epoch = 0  # epochs spent in the current curriculum step
    total_epoch = 0

    ginka = GinkaModel().to(device)
    minamo = MinamoModel2().to(device)

    dataset = GinkaWGANDataset(args.train, device)
    dataset_val = GinkaWGANDataset(args.validate, device)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)

    optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
    optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-4, betas=(0.0, 0.9))

    scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2)
    scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2)

    criterion = WGANGinkaLoss()

    # cv2.imwrite does not create missing directories and fails silently, so
    # make sure the preview directory exists before any image is written.
    os.makedirs("result/ginka_img", exist_ok=True)

    # Tile images used to render preview pictures.
    tile_dict = dict()
    for file in os.listdir('tiles'):
        name = os.path.splitext(file)[0]
        tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)

    if args.resume:
        data_ginka = torch.load(args.state_ginka, map_location=device)
        data_minamo = torch.load(args.state_minamo, map_location=device)

        ginka.load_state_dict(data_ginka["model_state"], strict=False)
        minamo.load_state_dict(data_minamo["model_state"], strict=False)

        # if data_ginka.get("c_steps") is not None and data_ginka.get("g_steps") is not None:
        #     c_steps = data_ginka["c_steps"]
        #     g_steps = data_ginka["g_steps"]

        if data_ginka.get("mask_ratio") is not None:
            mask_ratio = data_ginka["mask_ratio"]

        if data_ginka.get("stage_epoch") is not None:
            stage_epoch = data_ginka["stage_epoch"]

        if data_ginka.get("stage") is not None:
            train_stage = data_ginka["stage"]

        if data_ginka.get("total_epoch") is not None:
            # Read the key that was just checked ("data_ginka" never exists).
            total_epoch = data_ginka["total_epoch"]

        if args.load_optim:
            if data_ginka.get("optim_state") is not None:
                optimizer_ginka.load_state_dict(data_ginka["optim_state"])
            if data_minamo.get("optim_state") is not None:
                optimizer_minamo.load_state_dict(data_minamo["optim_state"])

        print("Train from loaded state.")

    curr_epoch = args.curr_epoch
    first_curr = curr_epoch * 3  # the first curriculum step lasts three times longer

    if args.tuning:
        # Fine-tuning restarts the curriculum with a four times shorter schedule.
        train_stage = 1
        curr_epoch = curr_epoch // 4
        first_curr = first_curr // 4
        stage_epoch = 0
        mask_ratio = 0.2

    dataset.train_stage = train_stage
    dataset.mask_ratio1 = mask_ratio
    dataset.mask_ratio2 = mask_ratio
    dataset.mask_ratio3 = mask_ratio

    dataset_val.train_stage = train_stage
    dataset_val.mask_ratio1 = mask_ratio
    dataset_val.mask_ratio2 = mask_ratio
    dataset_val.mask_ratio3 = mask_ratio

    for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm):
        loss_total_minamo = torch.Tensor([0]).to(device)
        loss_total_ginka = torch.Tensor([0]).to(device)
        dis_total = torch.Tensor([0]).to(device)
        loss_ce_total = torch.Tensor([0]).to(device)

        iters = 0

        for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
            rand = batch["rand"].to(device)
            real0 = batch["real0"].to(device)
            real1 = batch["real1"].to(device)
            masked1 = batch["masked1"].to(device)
            real2 = batch["real2"].to(device)
            masked2 = batch["masked2"].to(device)
            real3 = batch["real3"].to(device)
            masked3 = batch["masked3"].to(device)
            tag_cond = batch["tag_cond"].to(device)
            val_cond = batch["val_cond"].to(device)

            # ---------- Train the discriminator (critic)
            for _ in range(c_steps):
                optimizer_minamo.zero_grad()
                optimizer_ginka.zero_grad()

                # Fake samples are produced without tracking gradients; only
                # the critic is updated in this loop.
                with torch.no_grad():
                    if train_stage == 1 or train_stage == 2:
                        fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)
                    elif train_stage == 3 or train_stage == 4:
                        fake1, fake2, fake3, fake0 = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)

                    if train_stage < 4:
                        fake0 = ginka(rand, 0, tag_cond, val_cond)

                loss_d0, dis0 = criterion.discriminator_loss(minamo, 0, real0, fake0, tag_cond, val_cond)
                loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1, tag_cond, val_cond)
                loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2, tag_cond, val_cond)
                loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3, tag_cond, val_cond)

                dis = [dis0, dis1, dis2, dis3]
                loss_d = [loss_d0, loss_d1, loss_d2, loss_d3]

                dis_avg = sum(dis) / len(dis)
                loss_d_avg = sum(loss_d) / len(loss_d)

                # Backpropagate and update the critic.
                loss_d_avg.backward()

                optimizer_minamo.step()

                loss_total_minamo += loss_d_avg.detach()
                dis_total += dis_avg.detach()

            # ---------- Train the generator
            for _ in range(g_steps):
                optimizer_minamo.zero_grad()
                optimizer_ginka.zero_grad()
                if train_stage == 1 or train_stage == 2:
                    fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, False)

                    loss_g1, loss_ce_g1 = criterion.generator_loss(minamo, 1, mask_ratio, real1, fake1, masked1, tag_cond, val_cond)
                    loss_g2, loss_ce_g2 = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2, tag_cond, val_cond)
                    loss_g3, loss_ce_g3 = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3, tag_cond, val_cond)

                    # Stage 1 carries three times the weight of stages 2 and 3.
                    loss_g = (loss_g1 * 3.0 + loss_g2 + loss_g3) / 5.0
                    loss_ce = max(loss_ce_g1, loss_ce_g2, loss_ce_g3)

                    loss_ce_total += loss_ce.detach()

                elif train_stage == 3 or train_stage == 4:
                    fake1, fake2, fake3, fake0 = gen_total(ginka, masked1, tag_cond, val_cond, True, False, train_stage == 4)
                    if train_stage == 4:
                        fake0 = F.softmax(fake0, dim=1)

                    loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, fake0, tag_cond, val_cond)
                    loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1, tag_cond, val_cond)
                    loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2, tag_cond, val_cond)

                    loss_g = (loss_g1 * 3.0 + loss_g2 + loss_g3) / 5.0

                if train_stage < 4:
                    fake0 = F.softmax(ginka(rand, 0, tag_cond, val_cond), dim=1)

                loss_g0 = criterion.generator_input_head_loss(minamo, fake0, tag_cond, val_cond)
                loss_g += loss_g0

                loss_g.backward()
                optimizer_ginka.step()
                loss_total_ginka += loss_g.detach()

            iters += 1

            if iters % 50 == 0:
                # Running averages over the iterations seen so far this epoch.
                avg_loss_ginka = loss_total_ginka.item() / iters / g_steps
                avg_loss_minamo = loss_total_minamo.item() / iters / c_steps
                avg_loss_ce = loss_ce_total.item() / iters / g_steps
                avg_dis = dis_total.item() / iters / c_steps
                tqdm.write(
                    f"[Iters {iters} {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
                    f"E: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
                    f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " +
                    f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | " +
                    f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
                )

        avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
        avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
        avg_loss_ce = loss_ce_total.item() / len(dataloader) / g_steps
        avg_dis = dis_total.item() / len(dataloader) / c_steps
        tqdm.write(
            f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
            f"E: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
            f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " +
            f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | " +
            f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
        )

        # Every few epochs, save a checkpoint and render preview images.
        if (epoch + 1) % args.checkpoint == 0:
            # Save the checkpoint.
            torch.save({
                "model_state": ginka.state_dict(),
                "optim_state": optimizer_ginka.state_dict(),
                "c_steps": c_steps,
                "g_steps": g_steps,
                "stage": train_stage,
                "mask_ratio": mask_ratio,
                "stage_epoch": stage_epoch,
                # Stored so that resuming can restore the epoch counter.
                "total_epoch": total_epoch,
            }, f"result/wgan/ginka-{epoch + 1}.pth")
            torch.save({
                "model_state": minamo.state_dict(),
                "optim_state": optimizer_minamo.state_dict()
            }, f"result/wgan/minamo-{epoch + 1}.pth")

            idx = 0
            gap = 5
            color = (255, 255, 255)  # white
            with torch.no_grad():
                for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
                    real1 = batch["real1"].to(device)
                    masked1 = batch["masked1"].to(device)
                    real2 = batch["real2"].to(device)
                    masked2 = batch["masked2"].to(device)
                    real3 = batch["real3"].to(device)
                    masked3 = batch["masked3"].to(device)
                    tag_cond = batch["tag_cond"].to(device)
                    val_cond = batch["val_cond"].to(device)

                    if train_stage == 1 or train_stage == 2:
                        fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)

                    elif train_stage == 3 or train_stage == 4:
                        fake1, fake2, fake3, fake0 = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
                        fake0 = torch.argmax(fake0, dim=1).cpu().numpy()

                    # Collapse the class dimension into tile ids for rendering.
                    fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
                    fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
                    fake3 = torch.argmax(fake3, dim=1).cpu().numpy()
                    masked1 = torch.argmax(masked1, dim=1).cpu().numpy()
                    masked2 = torch.argmax(masked2, dim=1).cpu().numpy()
                    masked3 = torch.argmax(masked3, dim=1).cpu().numpy()

                    for i in range(fake1.shape[0]):
                        fake1_img = matrix_to_image_cv(fake1[i], tile_dict)
                        fake2_img = matrix_to_image_cv(fake2[i], tile_dict)
                        fake3_img = matrix_to_image_cv(fake3[i], tile_dict)
                        if train_stage == 1 or train_stage == 2:
                            # Top row: masked inputs; bottom row: generated maps.
                            vline = np.full((416, gap, 3), color, dtype=np.uint8)  # vertical separator
                            hline = np.full((gap, 3 * 416 + gap * 2, 3), color, dtype=np.uint8)  # horizontal separator
                            in1_img = matrix_to_image_cv(masked1[i], tile_dict)
                            in2_img = matrix_to_image_cv(masked2[i], tile_dict)
                            in3_img = matrix_to_image_cv(masked3[i], tile_dict)
                            img = np.block([
                                [[in1_img], [vline], [in2_img], [vline], [in3_img]],
                                [[hline]],
                                [[fake1_img], [vline], [fake2_img], [vline], [fake3_img]]
                            ])
                        elif train_stage == 3 or train_stage == 4:
                            # 2x2 grid: chain input and the three stage outputs.
                            vline = np.full((416, gap, 3), color, dtype=np.uint8)  # vertical separator
                            hline = np.full((gap, 2 * 416 + gap, 3), color, dtype=np.uint8)  # horizontal separator
                            in_img = matrix_to_image_cv(fake0[i], tile_dict)
                            img = np.block([
                                [[in_img], [vline], [fake1_img]],
                                [[hline]],
                                [[fake2_img], [vline], [fake3_img]]
                            ])

                        cv2.imwrite(f"result/ginka_img/{idx}.png", img)

                        idx += 1

        # Training schedule control.

        # if train_stage >= 2:
        #     # train_stage = 4
        #     if (epoch + 1) % 100 == 5:
        #         train_stage = 3
        #     elif (epoch + 1) % 100 == 20:
        #         train_stage = 4
        #     elif (epoch + 1) % 100 == 0:
        #         train_stage = 2

        if train_stage == 1:
            if (mask_ratio < 0.3 and stage_epoch >= first_curr) or \
               (mask_ratio > 0.3 and stage_epoch >= curr_epoch):
                mask_ratio += 0.2
                mask_ratio = min(mask_ratio, 0.8)

                stage_epoch = 0
                if mask_ratio >= 0.8:
                    train_stage = 4

        stage_epoch += 1
        total_epoch += 1

        dataset.train_stage = train_stage
        dataset_val.train_stage = train_stage
        dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
        dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio

        scheduler_ginka.step()
        scheduler_minamo.step()

    print("Train ended.")
    torch.save({
        "model_state": ginka.state_dict(),
    }, f"result/ginka.pth")
    torch.save({
        "model_state": minamo.state_dict(),
    }, f"result/minamo.pth")
|
|
||||||
|
|
||||||
# Script entry point: cap torch's intra-op CPU thread count at 4, then train.
if __name__ == "__main__":
    torch.set_num_threads(4)
    train()
|
|
||||||
@ -1,106 +0,0 @@
|
|||||||
import time
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from ..utils import print_memory
|
|
||||||
|
|
||||||
class GinkaTransformerDecoder(nn.Module):
    """Transformer decoder that turns a latent vector into a flattened tile map.

    The latent vector is encoded into a single memory token that the decoder
    cross-attends to. With ``autoregressive = False`` (the default) the target
    map is fed in with a start token prepended (teacher forcing) and logits
    are returned; with ``autoregressive = True`` tiles are generated greedily
    one position at a time.
    """

    # Token id fed as the first decoder input.
    START_TOKEN = 31

    def __init__(self, num_classes=32, dim_ff=256, nhead=4, num_layers=4, map_size=13*13):
        super().__init__()
        self.autoregressive = False
        self.dim_ff = dim_ff
        self.map_size = map_size
        self.embedding = nn.Embedding(num_classes, dim_ff)
        # One extra position for the prepended start token.
        self.pos_embedding = nn.Embedding(map_size + 1, dim_ff)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim_ff, dim_feedforward=dim_ff, nhead=nhead, batch_first=True),
            num_layers=max(num_layers // 2, 1)
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=dim_ff, dim_feedforward=dim_ff, nhead=nhead, batch_first=True),
            num_layers=num_layers
        )
        self.fc = nn.Sequential(
            nn.Linear(dim_ff, num_classes)
        )

    def forward(self, z: torch.Tensor, target_map: torch.Tensor):
        """
        Args:
            z: [B, dim_ff] latent features.
            target_map: [B, L] tile ids; in autoregressive mode only its shape
                is used.

        Returns:
            Teacher forcing: logits of shape [B, L + 1, num_classes].
            Autoregressive: int tensor [B, L + 1]; column ``i`` holds the tile
            predicted at decoding step ``i``, columns that are never generated
            stay 0.
        """
        B, L = target_map.shape
        dev = z.device

        memory = self.encoder(z.unsqueeze(1))  # [B, 1, dim_ff]
        # Boolean mask, True = blocked. diagonal=1 keeps the diagonal open so
        # every position can attend to itself and to earlier positions; with
        # the diagonal blocked the first row would have nothing to attend to.
        mask = torch.triu(torch.ones(L + 1, L + 1, dtype=torch.bool), diagonal=1).to(dev)
        pos_embed = self.pos_embedding(torch.arange(L + 1, dtype=torch.long).to(dev))
        first_token = torch.tensor([self.START_TOKEN], dtype=torch.long).to(dev).repeat(B, 1)

        # Training: teacher forcing.
        if not self.autoregressive:
            with_first = torch.cat([first_token, target_map], dim=1)
            hidden = self.embedding(with_first) + pos_embed  # [B, L + 1, dim_ff]
            decoded = self.decoder(hidden, memory, tgt_mask=mask)
            return self.fc(decoded)

        # Evaluation: greedy autoregressive generation.
        # NOTE(review): assumes the training loss aligns logits[:, i] with
        # target_map[:, i], i.e. position i predicts the tile that is fed in
        # at position i + 1 - confirm against the training loop.
        tokens = torch.cat([first_token, torch.zeros([B, L], dtype=torch.long).to(dev)], dim=1)
        output = torch.zeros([B, L + 1], dtype=torch.int).to(dev)
        for idx in range(min(self.map_size, L + 1)):
            hidden = self.embedding(tokens) + pos_embed
            logits = self.fc(self.decoder(hidden, memory, tgt_mask=mask))
            next_tile = torch.argmax(logits[:, idx, :], dim=1)
            output[:, idx] = next_tile
            if idx < L:
                # The prediction becomes the decoder input of the next step.
                tokens[:, idx + 1] = next_tile

        return output
|
|
||||||
|
|
||||||
class GinkaTransformerVAEDecoder(nn.Module):
    """Decoder half of the transformer VAE.

    A small MLP lifts the latent vector from ``latent_dim`` to ``dim_ff``
    features, which are then handed to ``GinkaTransformerDecoder`` together
    with the tile map.
    """

    def __init__(
        self, latent_dim=32, num_classes=32, dim_ff=256, nhead=4, num_layers=4,
        map_size=13*13
    ):
        super().__init__()
        self.map_size = map_size
        # Latent projection: Linear -> Dropout -> LayerNorm -> ReLU -> Linear.
        projection = [
            nn.Linear(latent_dim, dim_ff),
            nn.Dropout(0.3),
            nn.LayerNorm(dim_ff),
            nn.ReLU(),
            nn.Linear(dim_ff, dim_ff),
        ]
        self.input = nn.Sequential(*projection)
        self.decoder = GinkaTransformerDecoder(
            num_classes=num_classes,
            dim_ff=dim_ff,
            nhead=nhead,
            num_layers=num_layers,
            map_size=map_size,
        )

    def forward(self, z: torch.Tensor, map: torch.Tensor):
        """Project ``z`` and return whatever the inner decoder produces for ``map``."""
        return self.decoder(self.input(z), map)
|
|
||||||
|
|
||||||
# Manual smoke test: one forward pass on CPU with timing and parameter counts.
if __name__ == "__main__":
    def _count_params(module):
        # Total number of scalar parameters held by a module.
        return sum(p.numel() for p in module.parameters())

    device = torch.device("cpu")

    latent = torch.randn(1, 32).to(device)
    tiles = torch.randint(0, 32, [1, 169]).to(device)

    # Build the model.
    model = GinkaTransformerVAEDecoder().to(device)

    print_memory("初始化后")

    # Forward pass.
    start = time.perf_counter()
    output = model(latent, tiles)
    end = time.perf_counter()

    print_memory("前向传播后")

    print(f"推理耗时: {end - start}")
    print(f"输出形状: output={output.shape}")
    print(f"Input Embedding parameters: {_count_params(model.input)}")
    print(f"Decoder parameters: {_count_params(model.decoder)}")
    print(f"Total parameters: {_count_params(model)}")
|
|
||||||
@ -1,97 +0,0 @@
|
|||||||
import time
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from ..utils import print_memory
|
|
||||||
|
|
||||||
class GinkaTransformerEncoder(nn.Module):
    """Encode a token sequence and pool it into one ``dim_ff`` vector.

    The sequence goes through a transformer encoder; a single query token
    drawn from a standard normal on every call then cross-attends to the
    encoded sequence through a transformer decoder.
    """

    def __init__(self, dim_ff=256, nhead=4, num_layers=4):
        super().__init__()
        self.dim_ff = dim_ff
        # Both stacks share the same layer hyper-parameters.
        layer_kwargs = dict(
            d_model=dim_ff, dim_feedforward=dim_ff, nhead=nhead,
            batch_first=True, activation=F.gelu
        )
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_kwargs),
            num_layers=num_layers
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(**layer_kwargs),
            num_layers=max(num_layers // 2, 1)
        )

    def forward(self, x: torch.Tensor):
        """x: [B, L, S] -> pooled features [B, dim_ff]."""
        batch, _, _ = x.shape
        # A fresh random query token is drawn for every forward pass.
        query = torch.randn(batch, 1, self.dim_ff).to(x.device)
        encoded = self.encoder(x)
        pooled = self.decoder(query, encoded)
        return pooled.squeeze(1)
|
|
||||||
|
|
||||||
class GinkaTransformerBottleneck(nn.Module):
    """VAE bottleneck: maps pooled features to the mean and log-variance
    of the latent distribution through a shared hidden layer."""

    def __init__(self, dim_ff=256, hidden_dim=512, latent_dim=32):
        super().__init__()
        # Shared trunk: Linear -> Dropout -> LayerNorm -> GELU.
        trunk = [
            nn.Linear(dim_ff, hidden_dim),
            nn.Dropout(0.3),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
        ]
        self.fc = nn.Sequential(*trunk)
        # Two independent linear heads on top of the trunk.
        self.fc_mu = nn.Sequential(nn.Linear(hidden_dim, latent_dim))
        self.fc_logvar = nn.Sequential(nn.Linear(hidden_dim, latent_dim))

    def forward(self, x):
        """x: [B, dim_ff] -> (mu, logvar), each [B, latent_dim]."""
        shared = self.fc(x)
        return self.fc_mu(shared), self.fc_logvar(shared)
|
|
||||||
|
|
||||||
class GinkaTransformerVAEEncoder(nn.Module):
    """Encoder half of the transformer VAE.

    Tile ids are embedded, summed with a learned position embedding, pooled
    by ``GinkaTransformerEncoder`` and mapped to ``(mu, logvar)`` by
    ``GinkaTransformerBottleneck``.
    """

    def __init__(
        self, num_classes=32, latent_dim=32, bottleneck_dim=512, dim_ff=256,
        nhead=4, num_layers=4, map_size=13*13
    ):
        super().__init__()
        self.map_size = map_size
        self.embedding = nn.Embedding(num_classes, dim_ff)
        self.pos_embedding = nn.Embedding(map_size, dim_ff)
        self.encoder = GinkaTransformerEncoder(
            dim_ff=dim_ff, nhead=nhead, num_layers=num_layers
        )
        self.bottleneck = GinkaTransformerBottleneck(
            dim_ff=dim_ff, hidden_dim=bottleneck_dim, latent_dim=latent_dim
        )

    def forward(self, x: torch.Tensor):
        """x: [B, map_size] tile ids -> (mu, logvar)."""
        # Positions always span the full map_size, so x must have that length.
        positions = torch.arange(self.map_size, dtype=torch.long).to(x.device)
        tokens = self.embedding(x) + self.pos_embedding(positions)
        pooled = self.encoder(tokens)
        return self.bottleneck(pooled)
|
|
||||||
|
|
||||||
# Manual smoke test: one forward pass on CPU with timing and parameter counts.
if __name__ == "__main__":
    def _count_params(module):
        # Total number of scalar parameters held by a module.
        return sum(p.numel() for p in module.parameters())

    device = torch.device("cpu")

    tiles = torch.randint(0, 32, [1, 169]).to(device)

    # Build the model.
    model = GinkaTransformerVAEEncoder().to(device)

    print_memory("初始化后")

    # Forward pass.
    start = time.perf_counter()
    mu, logvar = model(tiles)
    end = time.perf_counter()

    print_memory("前向传播后")

    print(f"推理耗时: {end - start}")
    print(f"输出形状: mu={mu.shape}, logvar={logvar.shape}")
    print(f"Embedding parameters: {_count_params(model.embedding)}")
    print(f"Position Embedding parameters: {_count_params(model.pos_embedding)}")
    print(f"Encoder parameters: {_count_params(model.encoder)}")
    print(f"bottleneck parameters: {_count_params(model.bottleneck)}")
    print(f"Total parameters: {_count_params(model)}")
|
|
||||||
|
|
||||||
@ -1,22 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
class FSQ(nn.Module):
    """Scalar quantizer with a straight-through gradient.

    Inputs are squashed into (-1, 1) with tanh, snapped to multiples of
    ``1 / scale`` where ``scale = (levels - 1) / 2``, and the rounding step
    is hidden from autograd.
    """

    def __init__(self, levels=7):
        super().__init__()

        self.levels = levels
        self.scale = (levels - 1) / 2

    def forward(self, z):
        # Bound the values.
        bounded = torch.tanh(z)
        # Snap to the grid.
        snapped = torch.round(bounded * self.scale) / self.scale
        # Straight-through estimator: the forward value is ``snapped``, the
        # gradient is that of ``bounded``.
        residual = (snapped - bounded).detach()
        return bounded + residual
|
|
||||||
@ -1,54 +0,0 @@
|
|||||||
import time
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from .encoder import GinkaTransformerVAEEncoder
|
|
||||||
from .decoder import GinkaTransformerVAEDecoder
|
|
||||||
from ..utils import print_memory
|
|
||||||
|
|
||||||
class GinkaTransformerVAE(nn.Module):
    """Transformer VAE over flattened tile maps: encoder, reparameterised
    latent sample, decoder."""

    def __init__(self, num_classes=32, latent_dim=32):
        super().__init__()
        self.encoder = GinkaTransformerVAEEncoder(
            num_classes=num_classes, latent_dim=latent_dim
        )
        # NOTE(review): num_classes is not forwarded to the decoder, which
        # therefore keeps its own default - confirm this is intended.
        self.decoder = GinkaTransformerVAEDecoder(latent_dim=latent_dim)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps`` ~ N(0, I)."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def autoregressive(self):
        """Put the inner decoder into autoregressive generation mode."""
        self.decoder.decoder.autoregressive = True

    def teacher_forcing(self):
        """Put the inner decoder back into teacher-forcing mode."""
        self.decoder.decoder.autoregressive = False

    def forward(self, target_map: torch.Tensor, use_self_probility=0):
        """
        target_map: [B, H * W] tile ids.
        use_self_probility: accepted but not used by this method.

        Returns (decoder output, mu, logvar).
        """
        mu, logvar = self.encoder(target_map)  # [B, latent_dim]
        latent = self.reparameterize(mu, logvar)
        # [B, H * W, num_classes] | [B, H * W]
        decoded = self.decoder(latent, target_map)
        return decoded, mu, logvar
|
|
||||||
|
|
||||||
# Manual smoke test: one forward pass on CPU with timing and parameter counts.
if __name__ == "__main__":
    def _count_params(module):
        # Total number of scalar parameters held by a module.
        return sum(p.numel() for p in module.parameters())

    device = torch.device("cpu")

    tiles = torch.randint(0, 32, [1, 169]).to(device)

    # Build the model.
    model = GinkaTransformerVAE().to(device)

    print_memory("初始化后")

    # Forward pass.
    start = time.perf_counter()
    logits, mu, logvar = model(tiles)
    end = time.perf_counter()

    print_memory("前向传播后")

    print(f"推理耗时: {end - start}")
    print(f"输出形状: logits= {logits.shape}, mu={mu.shape}, logvar={logvar.shape}")
    print(f"Encoder parameters: {_count_params(model.encoder)}")
    print(f"Decoder parameters: {_count_params(model.decoder)}")
    print(f"Total parameters: {_count_params(model)}")
|
|
||||||
@ -1,264 +0,0 @@
|
|||||||
import time
|
|
||||||
import random
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from ..utils import print_memory
|
|
||||||
|
|
||||||
class DecoderMapPatch(nn.Module):
    """Local 5x5 convolutional view of the map around the tile being drawn.

    The window covers columns ``x - 2 .. x + 2`` and rows ``y - 4 .. y``, so
    the current cell sits at patch position (row 4, column 2). The window is
    one-hot encoded, extended with a validity-mask channel and reduced to a
    256-dimensional feature vector.
    """

    def __init__(self, tile_classes=32, width=13, height=13):
        super().__init__()

        self.width = width
        self.height = height
        # Honour the constructor argument instead of a hard-coded 32.
        self.tile_classes = tile_classes

        # tile_classes one-hot planes + 1 mask plane; 5x5 -> 5x5 -> 3x3.
        self.patch_cnn = nn.Sequential(
            nn.Conv2d(tile_classes + 1, 64, 3, padding=1),
            nn.Dropout(0.2),
            nn.BatchNorm2d(64),
            nn.GELU(),

            nn.Conv2d(64, 128, 3),
            nn.Dropout(0.2),
            nn.BatchNorm2d(128),
            nn.GELU(),

            nn.Flatten()
        )
        self.fc = nn.Linear(128 * 3 * 3, 256)

    def forward(self, map: torch.Tensor, x: int, y: int):
        """
        Args:
            map: [B, H, W] tile ids.
            x, y: column and row of the tile that is about to be drawn.

        Returns:
            [B, 256] patch features.
        """
        B, H, W = map.shape
        mask = torch.zeros([B, 5, 5]).to(map.device)
        result = torch.zeros([B, 5, 5], dtype=torch.long).to(map.device)

        # Window bounds in map coordinates, clipped to the map.
        left = max(x - 2, 0)
        right = min(x + 3, self.width)
        top = max(y - 4, 0)
        bottom = y + 1

        # The same window expressed in patch coordinates.
        res_left = left - (x - 2)
        res_right = right - (x + 3) + 5
        res_top = top - (y - 4)
        res_bottom = 5

        result[:, res_top:res_bottom, res_left:res_right] = map[:, top:bottom, left:right]
        mask[:, res_top:res_bottom, res_left:res_right] = 1
        # The current cell and the cells to its right are not drawn yet.
        result[:, 4, 2:5] = 0
        mask[:, 4, 2:5] = 0

        masked_result = torch.zeros([B, self.tile_classes + 1, 5, 5]).to(map.device)
        # [B, 5, 5, C] -> [B, C, 5, 5]; rows and columns keep their order so
        # the one-hot planes stay aligned with the mask plane.
        masked_result[:, 0:self.tile_classes] = (
            F.one_hot(result, num_classes=self.tile_classes).permute(0, 3, 1, 2).float()
        )
        masked_result[:, self.tile_classes] = mask

        feat = self.patch_cnn(masked_result)
        feat = self.fc(feat)
        return feat
|
|
||||||
|
|
||||||
class DecoderTileEmbedding(nn.Module):
    """Embedding lookup for the previously drawn tile."""

    def __init__(self, tile_classes=32, embed_dim=256):
        super().__init__()
        # One learned vector per tile class.
        self.embedding = nn.Embedding(tile_classes, embed_dim)

    def forward(self, tile: torch.Tensor):
        """Map tile ids to their embedding vectors."""
        embedded = self.embedding(tile)
        return embedded
|
|
||||||
|
|
||||||
class DecoderPosEmbedding(nn.Module):
    """Learned position encoding with separate row and column tables."""

    def __init__(self, width=13, height=13, embed_dim=256):
        super().__init__()

        self.width = width
        self.height = height

        # One table indexed by row (y), one indexed by column (x).
        self.row_embedding = nn.Embedding(height, embed_dim)
        self.col_embedding = nn.Embedding(width, embed_dim)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Return ``(row_embedding(y), col_embedding(x))`` - row first."""
        row_vec = self.row_embedding(y)
        col_vec = self.col_embedding(x)
        return row_vec, col_vec
|
|
||||||
|
|
||||||
class DecoderInputFusion(nn.Module):
    """Fuse the five per-step feature vectors into a single vector.

    The vectors are treated as a 5-token sequence, mixed by a small
    transformer encoder, mean- and max-pooled over the tokens, and the
    concatenated poolings are projected back to ``d_model``.
    """

    def __init__(self, d_model=256):
        super().__init__()

        # Transformer used to exchange information between the five tokens.
        mixer_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=2, dim_feedforward=d_model, batch_first=True,
            dropout=0.2
        )
        self.transformer = nn.TransformerEncoder(mixer_layer, num_layers=2)
        self.norm = nn.LayerNorm(d_model)

        wide = d_model * 2
        self.fusion = nn.Sequential(
            nn.Linear(wide, wide),
            nn.Dropout(0.2),
            nn.LayerNorm(wide),
            nn.GELU(),

            nn.Linear(wide, d_model),
            nn.Dropout(0.1),
            nn.LayerNorm(d_model),
            nn.GELU()
        )

    def forward(
        self, tile_embed: torch.Tensor, cond_vec: torch.Tensor,
        col_embed: torch.Tensor, row_embed: torch.Tensor, patch_vec: torch.Tensor
    ):
        """
        All five inputs are [B, d_model]; the result is [B, d_model].
        """
        tokens = torch.stack(
            [tile_embed, cond_vec, col_embed, row_embed, patch_vec], dim=1
        )
        mixed = self.norm(self.transformer(tokens))
        # Pool over the token dimension in two ways and concatenate.
        pooled_mean = mixed.mean(dim=1)
        pooled_max = mixed.max(dim=1).values
        return self.fusion(torch.cat([pooled_mean, pooled_max], dim=1))
|
|
||||||
|
|
||||||
class DecoderRNN(nn.Module):
    """One GRU step followed by an MLP head that scores every tile class."""

    def __init__(self, tile_classes=32, input_dim=256, hidden_dim=512):
        super().__init__()

        # Recurrent cell and the dropout applied to its new state.
        self.gru = nn.GRUCell(input_dim, hidden_dim)
        self.drop = nn.Dropout(0.2)
        # Classification head on top of the hidden state.
        head = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.Dropout(0.1),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),

            nn.Linear(hidden_dim, tile_classes),
        ]
        self.fc = nn.Sequential(*head)

    def forward(self, feat_fusion: torch.Tensor, hidden: torch.Tensor):
        """
        feat_fusion: [B, input_dim]
        hidden: [B, hidden_dim]

        Returns ``(logits, new_hidden)``; the returned hidden state is the
        GRU output after dropout.
        """
        new_hidden = self.gru(feat_fusion, hidden)
        new_hidden = self.drop(new_hidden)
        return self.fc(new_hidden), new_hidden
|
|
||||||
|
|
||||||
class VAEDecoder(nn.Module):
    """
    Autoregressive map decoder.

    Tiles are produced one at a time in row-major order. Each step fuses the
    previous tile embedding, the projected latent vector, the position
    embeddings and a local map patch, then runs one ``DecoderRNN`` step.
    """

    def __init__(self, device: torch.device, start_tile=31, map_vec_dim=32, width=13, height=13):
        super().__init__()

        self.device = device
        self.width = width
        self.height = height
        self.start_tile = start_tile

        self.rnn_hidden = 512
        self.tile_classes = 32

        # Model structure
        self.map_vec_fc = nn.Sequential(
            nn.Linear(map_vec_dim, 128),
            nn.Dropout(0.1),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Linear(128, 256),
        )
        self.tile_embedding = DecoderTileEmbedding(tile_classes=self.tile_classes)
        self.pos_embedding = DecoderPosEmbedding()
        self.map_patch = DecoderMapPatch(tile_classes=self.tile_classes)
        self.feat_fusion = DecoderInputFusion()
        self.rnn = DecoderRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)

        # Row-major (x, y) coordinates of every cell, kept as plain lists
        self.col_list = [x for _ in range(height) for x in range(width)]
        self.row_list = [y for y in range(height) for _ in range(width)]

    def forward(self, map_vec: torch.Tensor, target_map: torch.Tensor, use_self_probility=0):
        """
        map_vec: [B, vec_dim]
        target_map: [B, H, W]
        use_self_probility: probability, drawn independently at every step,
            of feeding the model's own previous prediction (and its own
            partially generated map) into the next step instead of the
            ground truth from ``target_map``

        Returns the logits as ``[B, tile_classes, H, W]``.
        """
        batch, _ = map_vec.shape
        dev = self.device

        # Tensor declarations
        now_tile = torch.LongTensor([self.start_tile]).to(dev).expand(batch)
        generated = torch.zeros([batch, self.height, self.width], dtype=torch.int32).to(dev)
        output_logits = torch.zeros(
            [batch, self.height, self.width, self.tile_classes]
        ).to(dev)
        hidden: torch.Tensor = torch.zeros(batch, self.rnn_hidden).to(dev)

        col_list = torch.IntTensor(self.col_list).to(dev).expand(batch, -1)
        row_list = torch.IntTensor(self.row_list).to(dev).expand(batch, -1)
        col_embed, row_embed = self.pos_embedding(col_list, row_list)

        map_vec = self.map_vec_fc(map_vec)

        for idx in range(self.height * self.width):
            y, x = divmod(idx, self.width)

            # Tile embedding and local map patch encoding
            tile_embed = self.tile_embedding(now_tile)
            use_self = random.random() < use_self_probility
            patch_source = generated if use_self else target_map
            map_patch = self.map_patch(patch_source, x, y)

            # Fuse the encoded features
            feat = self.feat_fusion(
                tile_embed, map_vec, col_embed[:, idx], row_embed[:, idx], map_patch
            )

            # RNN step
            logits, hidden = self.rnn(feat, hidden)

            # Record the output and pick the tile fed to the next step
            output_logits[:, y, x] = logits
            tile_id = logits.argmax(dim=1).detach()
            generated[:, y, x] = tile_id
            now_tile = tile_id if use_self else target_map[:, y, x].detach()

        return output_logits.permute(0, 3, 1, 2)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Manual smoke test: build the decoder, run one forward pass and report
    # timing, memory and parameter counts.
    if torch.cuda.is_available():
        # The script prefers the second GPU. "cuda:1" is an invalid device
        # ordinal when only one GPU is visible, so fall back to the first.
        device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
    else:
        device = torch.device("cpu")

    sample_map = torch.randint(0, 32, [1, 13, 13]).to(device)
    map_vec = torch.rand(1, 32).to(device)

    # Initialise the model
    model = VAEDecoder(device).to(device)

    print_memory(device, "初始化后")

    # Forward pass
    start = time.perf_counter()
    fake_logits = model(map_vec, sample_map, 0)
    end = time.perf_counter()

    print_memory(device, "前向传播后")

    print(f"推理耗时: {end - start}")
    print(f"输出形状: fake_logits={fake_logits.shape}")
    print(f"Map Vector FC parameters: {sum(p.numel() for p in model.map_vec_fc.parameters())}")
    print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}")
    print(f"Position Embedding parameters: {sum(p.numel() for p in model.pos_embedding.parameters())}")
    print(f"Map Patch parameters: {sum(p.numel() for p in model.map_patch.parameters())}")
    print(f"Feature Fusion parameters: {sum(p.numel() for p in model.feat_fusion.parameters())}")
    print(f"RNN parameters: {sum(p.numel() for p in model.rnn.parameters())}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
|
|
||||||
@ -1,161 +0,0 @@
|
|||||||
import time
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from ..utils import print_memory
|
|
||||||
|
|
||||||
class EncoderEmbedding(nn.Module):
    """Embeds tile ids and their column/row indices, then fuses the three."""

    def __init__(self, tile_classes=32, width=13, height=13, hidden_dim=128, output_dim=256):
        super().__init__()

        def lookup(num_entries):
            # Embedding table followed by LayerNorm and GELU
            return nn.Sequential(
                nn.Embedding(num_entries, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
            )

        self.tile_embedding = lookup(tile_classes)
        self.col_embedding = lookup(width)
        self.row_embedding = lookup(height)
        self.fusion = nn.Linear(hidden_dim * 3, output_dim)

    def forward(self, tile, x, y):
        """
        Embed ``tile``, ``x`` (column indices) and ``y`` (row indices),
        concatenate the results along dim 2 and project to ``output_dim``.
        The concatenation axis means the index tensors must be 2-D.
        """
        parts = (
            self.tile_embedding(tile),
            self.col_embedding(x),
            self.row_embedding(y),
        )
        return self.fusion(torch.cat(parts, dim=2))
|
|
||||||
|
|
||||||
class EncoderGRU(nn.Module):
    """One encoding step: a GRU cell followed by an MLP projection."""

    def __init__(self, input_dim=256, hidden_dim=512, output_dim=256):
        super().__init__()

        # Recurrent cell and the dropout applied to its new hidden state
        self.gru = nn.GRUCell(input_dim, hidden_dim)
        self.drop = nn.Dropout(0.2)

        # Projection from the hidden state to the per-step output vector
        head_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.Dropout(0.1),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, feat: torch.Tensor, hidden: torch.Tensor):
        """
        feat: [B, input_dim]
        hidden: [B, hidden_dim]

        Returns ``(logits, hidden)`` where ``hidden`` is the dropped-out GRU
        state that was also used to compute ``logits``.
        """
        next_hidden = self.gru(feat, hidden)
        next_hidden = self.drop(next_hidden)
        return self.fc(next_hidden), next_hidden
|
|
||||||
|
|
||||||
class EncoderFusion(nn.Module):
    """Summarises a sequence with a transformer encoder plus mean/max pooling."""

    def __init__(self, d_model=256):
        super().__init__()

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, dim_feedforward=d_model*2, nhead=2, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=3)
        self.norm = nn.LayerNorm(d_model)
        self.fc = nn.Sequential(
            nn.Linear(d_model * 2, d_model * 2),
            nn.Dropout(0.1),
            nn.LayerNorm(d_model * 2),
            nn.GELU(),
        )

    def forward(self, logits):
        """
        Encode the batch-first sequence ``logits``, pool over dim 1 with both
        mean and max, and project the concatenated ``2 * d_model`` vector.
        """
        encoded = self.norm(self.transformer(logits))
        pooled = torch.cat(
            (encoded.mean(dim=1), encoded.max(dim=1).values), dim=1
        )
        return self.fc(pooled)
|
|
||||||
|
|
||||||
class VAEEncoder(nn.Module):
    """
    Encodes a tile map into the mean and log-variance of a latent Gaussian.

    Cells are embedded together with their coordinates, scanned in row-major
    order by a GRU, summarised by ``EncoderFusion`` and mapped to ``mu`` and
    ``logvar`` by two separate MLP heads.
    """

    def __init__(self, device, tile_classes=32, latent_dim=32, width=13, height=13):
        super().__init__()
        self.device = device

        self.rnn_hidden = 512
        self.logits_dim = 256

        self.embedding = EncoderEmbedding(tile_classes, width, height, 128, 256)
        self.rnn = EncoderGRU(256, self.rnn_hidden, self.logits_dim)
        self.fusion = EncoderFusion(256)

        def latent_head():
            # Both heads share this layout; the 512-wide input matches the
            # 2 * 256 output of EncoderFusion(256).
            return nn.Sequential(
                nn.Linear(512, 512),
                nn.Dropout(0.1),
                nn.LayerNorm(512),
                nn.GELU(),
                nn.Linear(512, latent_dim),
            )

        self.fc_mu = latent_head()
        self.fc_logvar = latent_head()

        # Row-major (x, y) coordinates of every cell, kept as plain lists
        self.col_list = [x for _ in range(height) for x in range(width)]
        self.row_list = [y for y in range(height) for _ in range(width)]

    def forward(self, x: torch.Tensor):
        """
        x: [B, H, W] tile ids.

        Returns ``(mu, logvar)``.
        """
        batch, rows, cols = x.shape
        dev = self.device

        flat_map = torch.flatten(x, start_dim=1)
        hidden = torch.zeros(batch, self.rnn_hidden).to(dev)
        output = torch.zeros(batch, rows * cols, self.logits_dim).to(dev)

        col_list = torch.IntTensor(self.col_list).to(dev).expand(batch, -1)
        row_list = torch.IntTensor(self.row_list).to(dev).expand(batch, -1)
        embed = self.embedding(flat_map, col_list, row_list)

        # Scan the cells in order, keeping every step's output
        for idx in range(len(self.col_list)):
            step_out, hidden = self.rnn(embed[:, idx], hidden)
            output[:, idx] = step_out

        summary = self.fusion(output)
        mu = self.fc_mu(summary)
        logvar = self.fc_logvar(summary)
        return mu, logvar
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Manual smoke test: build the encoder, run one forward pass and report
    # timing, memory and parameter counts.
    if torch.cuda.is_available():
        # The script prefers the second GPU. "cuda:1" is an invalid device
        # ordinal when only one GPU is visible, so fall back to the first.
        device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
    else:
        device = torch.device("cpu")

    sample_map = torch.randint(0, 32, [1, 13, 13]).to(device)

    # Initialise the model
    model = VAEEncoder(device).to(device)

    print_memory(device, "初始化后")

    # Forward pass
    start = time.perf_counter()
    mu, logvar = model(sample_map)
    end = time.perf_counter()

    print_memory(device, "前向传播后")

    print(f"推理耗时: {end - start}")
    print(f"输出形状: mu={mu.shape}, logvar={logvar.shape}")
    print(f"Embedding parameters: {sum(p.numel() for p in model.embedding.parameters())}")
    print(f"RNN parameters: {sum(p.numel() for p in model.rnn.parameters())}")
    print(f"Fusion parameters: {sum(p.numel() for p in model.fusion.parameters())}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
|
|
||||||
@ -1,20 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
class VAELoss:
    """Reconstruction (cross-entropy) plus beta-weighted KL loss for the VAE."""

    def __init__(self):
        self.num_classes = 32

    def vae_loss(self, logits, target, mu, logvar, beta=0.1):
        """
        Compute ``recon + beta * kl`` and return ``(total, recon, kl)``.

        logits: [B, L + 1, C] — the class axis is moved to dim 1 before the
            cross-entropy call.
        target: [B, L] class ids; an end token with id 15 is appended to
            every row, which is why ``logits`` needs one more step.
        mu, logvar: latent Gaussian parameters; the KL term is averaged over
            all of their elements.

        NOTE(review): the original comment gave logits as [B, 169, 16] and
        target as [B, 169], which does not line up with the appended end
        token — confirm the intended shapes against the caller.
        """
        batch, _ = target.shape
        end_token = torch.full((batch, 1), 15, dtype=torch.long, device=logits.device)
        padded_target = torch.cat((target, end_token), dim=1)
        recon_loss = F.cross_entropy(logits.permute(0, 2, 1), padded_target)

        kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
        kl_loss = -0.5 * kl_terms.mean()

        total = recon_loss + beta * kl_loss
        return total, recon_loss, kl_loss
|
|
||||||
@ -1,43 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
class VAEScheduler(torch.optim.lr_scheduler.ReduceLROnPlateau):
    """
    ``ReduceLROnPlateau`` that can also raise the learning rate.

    ``step`` takes an extra ``prob`` value. Whenever ``prob`` exceeds the
    largest value seen so far, the plateau tracking is reset and every
    parameter group's learning rate is multiplied by ``increase_factor``
    (capped at ``max_lr``). Otherwise the call behaves like the base class.

    Args:
        max_lr: upper bound for the learning rate, either one value or one
            per parameter group.
        increase_factor: multiplier applied when ``prob`` increases.
        start_prob: initial value ``prob`` is compared against.
        verbose: forwarded to the base class only when explicitly set.
        (remaining arguments are passed to ``ReduceLROnPlateau`` unchanged)

    Raises:
        ValueError: if ``max_lr`` is a sequence whose length differs from
            the number of parameter groups.
    """

    def __init__(
        self, optimizer, mode="min", factor=0.1, patience=10, threshold=0.0001,
        threshold_mode="rel", cooldown=0, min_lr=0, eps=1e-8, verbose="deprecated",
        max_lr=1e-2, increase_factor=2, start_prob=0
    ):
        base_kwargs = dict(
            mode=mode, factor=factor, patience=patience, threshold=threshold,
            threshold_mode=threshold_mode, cooldown=cooldown, min_lr=min_lr, eps=eps,
        )
        # Only forward `verbose` when the caller actually set it. The
        # "deprecated" sentinel is this signature's default; passing it on
        # unconditionally hands a truthy string to releases that expect a
        # bool, and breaks construction on any release whose scheduler does
        # not take the argument.
        if verbose != "deprecated":
            base_kwargs["verbose"] = verbose
        super().__init__(optimizer, **base_kwargs)

        self.max_lr = max_lr
        self.increase_factor = increase_factor
        self.last_prob = start_prob

        if isinstance(max_lr, (list, tuple)):
            if len(max_lr) != len(optimizer.param_groups):
                raise ValueError(
                    f"expected {len(optimizer.param_groups)} max_lrs, got {len(max_lr)}"
                )
            self.default_max_lr = None
            self.max_lrs = list(max_lr)
        else:
            self.default_max_lr = max_lr
            self.max_lrs = [max_lr] * len(optimizer.param_groups)

    def step(self, metrics, prob: float, epoch=None):
        """
        Advance the scheduler.

        If ``prob`` is larger than the last recorded value, restart plateau
        tracking from ``metrics`` and raise the learning rates; otherwise
        defer to ``ReduceLROnPlateau.step(metrics, epoch)``.
        """
        if prob > self.last_prob:
            self.best = metrics
            self.num_bad_epochs = 0
            self.last_prob = prob
            self._increase_lr()
            self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
        else:
            return super().step(metrics, epoch)

    def _increase_lr(self):
        """Multiply each group's lr by ``increase_factor``, capped at its max."""
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group["lr"])
            new_lr = min(old_lr * self.increase_factor, self.max_lrs[i])
            # Skip changes smaller than eps, mirroring the base class
            if new_lr - old_lr > self.eps:
                param_group["lr"] = new_lr
|
|
||||||
@ -1,47 +0,0 @@
|
|||||||
import time
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from .encoder import VAEEncoder
|
|
||||||
from .decoder import VAEDecoder
|
|
||||||
from ..utils import print_memory
|
|
||||||
|
|
||||||
class GinkaVAE(nn.Module):
    """Variational autoencoder pairing ``VAEEncoder`` with ``VAEDecoder``."""

    def __init__(self, device, tile_classes=32, latent_dim=32):
        super().__init__()
        self.encoder = VAEEncoder(device, tile_classes, latent_dim)
        self.decoder = VAEDecoder(device, map_vec_dim=latent_dim)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``std = exp(0.5 * logvar)``, ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return mu + noise * std

    def forward(self, target_map: torch.Tensor, use_self_probility=0):
        """
        Encode ``target_map``, sample a latent vector and decode it.

        ``use_self_probility`` is passed through to the decoder. Returns
        ``(logits, mu, logvar)``.
        """
        mu, logvar = self.encoder(target_map)
        latent = self.reparameterize(mu, logvar)
        logits = self.decoder(latent, target_map, use_self_probility)
        return logits, mu, logvar
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Manual smoke test: build the full VAE on CPU, run one forward pass and
    # report timing, memory and parameter counts.
    device = torch.device("cpu")

    sample_map = torch.randint(0, 32, [1, 13, 13]).to(device)

    # Initialise the model
    model = GinkaVAE(device).to(device)

    # print_memory is called as (device, label) by the encoder and decoder
    # scripts that import the same helper; the device was missing here.
    print_memory(device, "初始化后")

    # Forward pass
    start = time.perf_counter()
    logits, mu, logvar = model(sample_map)
    end = time.perf_counter()

    print_memory(device, "前向传播后")

    print(f"推理耗时: {end - start}")
    print(f"输出形状: logits= {logits.shape}, mu={mu.shape}, logvar={logvar.shape}")
    print(f"Encoder parameters: {sum(p.numel() for p in model.encoder.parameters())}")
    print(f"Decoder parameters: {sum(p.numel() for p in model.decoder.parameters())}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
|
|
||||||
@ -1,85 +0,0 @@
|
|||||||
import os
|
|
||||||
import json
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch_geometric.loader import DataLoader
|
|
||||||
from tqdm import tqdm
|
|
||||||
from .critic.model import MinamoModel
|
|
||||||
from .dataset import GinkaDataset
|
|
||||||
from .generator.loss import GinkaLoss
|
|
||||||
from .generator.model import GinkaModel
|
|
||||||
from shared.image import matrix_to_image_cv
|
|
||||||
|
|
||||||
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)

# Rendered validation images are written here; create the folder up front.
os.makedirs("result/ginka_img", exist_ok=True)
|
|
||||||
|
|
||||||
def validate():
    """
    Run the trained generator over the evaluation set.

    Loads the generator and the Minamo model from ``result/``, prints the
    parameter counts, then for every batch renders each generated map to
    ``result/ginka_img/<idx>.png``, records its matrix, and accumulates the
    loss. The average loss is printed and all matrices are written to
    ``result/ginka_val.json``.
    """
    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")

    # Generator
    model = GinkaModel()
    state = torch.load("result/ginka.pth", map_location=device)["model_state"]
    model.load_state_dict(state)
    model.to(device)

    for name, param in model.named_parameters():
        print(f"Layer: {name}, Params: {param.numel()}")
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params}")

    # Minamo model, used by both the dataset and the loss
    minamo = MinamoModel(32)
    minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
    minamo.to(device)

    # Prepare the dataset
    val_dataset = GinkaDataset("ginka-eval.json", device, minamo)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True)

    criterion = GinkaLoss(minamo)

    # Tile images keyed by file name without extension
    tile_dict = {
        os.path.splitext(file)[0]: cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
        for file in os.listdir("tiles")
    }
    val_output = dict()

    minamo.eval()
    model.eval()
    val_loss = 0
    image_index = 0
    with torch.no_grad():
        for batch in tqdm(val_loader):
            # Move the batch to the device
            target = batch["target"].to(device)
            target_vision_feat = batch["target_vision_feat"].to(device)
            target_topo_feat = batch["target_topo_feat"].to(device)
            feat_vec = torch.cat(
                [target_vision_feat, target_topo_feat], dim=-1
            ).to(device).squeeze(1)

            # Forward pass
            output, output_softmax = model(feat_vec)
            map_matrix = torch.argmax(output, dim=1)

            # Render and record every generated map
            for matrix in map_matrix.cpu():
                image = matrix_to_image_cv(matrix.numpy(), tile_dict)
                cv2.imwrite(f"result/ginka_img/{image_index}.png", image)
                val_output[f"val_{image_index}"] = matrix.tolist()
                image_index += 1

            # Compute the loss
            _, loss = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    tqdm.write(f"Validation::loss: {avg_val_loss:.6f}")

    with open("result/ginka_val.json", "w") as f:
        json.dump(val_output, f)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Limit intra-op CPU threads to two before running the validation pass.
    torch.set_num_threads(2)

    validate()
|
|
||||||
|
|
||||||
@ -1,12 +0,0 @@
|
|||||||
import argparse
|
|
||||||
|
|
||||||
def parse_arguments(from_default: str, train_default: str, val_default: str):
    """
    Parse the training command-line options from ``sys.argv``.

    Args:
        from_default: default for ``--from_state``.
        train_default: default for ``--train``.
        val_default: default for ``--validate``.

    Returns:
        The ``argparse.Namespace`` with ``resume``, ``from_state``,
        ``load_optim``, ``train``, ``validate`` and ``epochs``.

    ``--resume`` and ``--load_optim`` accept true/false style values
    (``true``, ``false``, ``1``, ``0``, ``yes``, ``no``, ...) and may also be
    given bare, meaning true. Any other value is rejected.
    """
    truthy = {"true", "t", "yes", "y", "1", "on"}
    falsy = {"false", "f", "no", "n", "0", "off", ""}

    def parse_bool(text: str) -> bool:
        # argparse's type=bool calls bool() on the raw string, so every
        # non-empty value, including "False", came out True.
        lowered = text.strip().lower()
        if lowered in truthy:
            return True
        if lowered in falsy:
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")

    parser = argparse.ArgumentParser(description="training codes")
    parser.add_argument("--resume", type=parse_bool, nargs="?", const=True, default=False)
    parser.add_argument("--from_state", type=str, default=from_default)
    parser.add_argument("--load_optim", type=parse_bool, nargs="?", const=True, default=False)
    parser.add_argument("--train", type=str, default=train_default)
    parser.add_argument("--validate", type=str, default=val_default)
    parser.add_argument("--epochs", type=int, default=150)
    args = parser.parse_args()
    return args
|
|
||||||
@ -1,73 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
class ChannelAttention(nn.Module):
    """Channel attention module."""

    def __init__(self, channels, reduction=8):
        super().__init__()
        squeezed = channels // reduction
        # Global average pool, bottleneck of two 1x1 convs, sigmoid gate
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, squeezed, 1),
            nn.ELU(),
            nn.Conv2d(squeezed, channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled by its per-channel gate (gated features, not the gate)."""
        gate = self.channel_att(x)
        return x * gate
|
|
||||||
|
|
||||||
class SpatialAttention(nn.Module):
    """Spatial attention module."""

    def __init__(self):
        super().__init__()
        # 7x7 conv over the [max, mean] channel summaries, then a sigmoid gate
        self.spatial_att = nn.Sequential(
            nn.Conv2d(2, 1, 7, padding=3),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled by its per-pixel gate (gated features, not the gate)."""
        channel_max = x.max(dim=1, keepdim=True).values
        channel_mean = x.mean(dim=1, keepdim=True)
        gate = self.spatial_att(torch.cat([channel_max, channel_mean], dim=1))
        return x * gate
|
|
||||||
|
|
||||||
class CBAM(nn.Module):
    """Channel attention followed by spatial attention."""

    def __init__(self, channels, reduction=8):
        super().__init__()
        # Channel attention
        self.channel_att = ChannelAttention(channels, reduction)
        # Spatial attention
        self.spatial_att = SpatialAttention()

    def forward(self, x):
        """
        Apply channel attention, then spatial attention, and return the result.

        Both sub-modules in this file return the already gated feature map
        (``x * gate``), not the gate. Their outputs are therefore used
        directly; multiplying by the input again, as was done before, squared
        the features at each stage.
        """
        # Channel attention
        x = self.channel_att(x)
        # Spatial attention
        return self.spatial_att(x)
|
|
||||||
|
|
||||||
class SEBlock(nn.Module):
    """Squeeze-and-excitation block: rescales channels by a learned gate."""

    def __init__(self, channel, reduction=4):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        bottleneck = channel // reduction
        self.fc = nn.Sequential(
            nn.Linear(channel, bottleneck),
            nn.GELU(),
            nn.Linear(bottleneck, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Pool ``x`` ([B, C, H, W]) to [B, C], compute the gate, and return ``x * gate``."""
        batch, channels, _, _ = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        gate = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gate
|
|
||||||
@ -1,6 +0,0 @@
|
|||||||
# Feature vector sizes.
VIS_DIM = 512    # vision feature
TOPO_DIM = 512   # topology feature
FEAT_DIM = 1024  # equals VIS_DIM + TOPO_DIM; presumably the concatenated vector (confirm)

# Relative weights of the two feature kinds.
VISION_WEIGHT = 0.2
TOPO_WEIGHT = 0.8
|
|
||||||
@ -1,66 +0,0 @@
|
|||||||
import torch
|
|
||||||
from torch_geometric.data import Data, Batch
|
|
||||||
from torch_geometric.utils import add_self_loops, grid
|
|
||||||
|
|
||||||
def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data:
    """
    Differentiable conversion of one soft map into a PyG ``Data`` graph.

    map_probs: [C, H, W]
    Returns:
        Data(x=[N, C], edge_index=[2, E]) with N = H * W nodes.

    Node features are the per-cell class vectors, taken as a transposed view
    of ``map_probs`` (so the input must be viewable as [C, H*W]). The edge
    list comes from ``torch_geometric.utils.grid(H, W)``.

    No ``edge_attr`` is produced. An earlier hand-built variant (4-neighbour
    bidirectional edges, edge features averaged from both endpoints, plus
    self-loops) was disabled in the source and is not reproduced here.
    """
    channels, height, width = map_probs.shape
    num_nodes = height * width

    # 1. Node features: [N, C]
    node_features = map_probs.view(channels, num_nodes).T

    # 2. Grid connectivity, moved to the same device as the input
    edge_index, _ = grid(height, width)
    edge_index = edge_index.to(map_probs.device)

    return Data(
        x=node_features,
        edge_index=edge_index,
        num_nodes=num_nodes,
    )
|
|
||||||
|
|
||||||
def batch_convert_soft_map_to_graph(batch_map_probs):
    """
    Convert a batch of soft maps ``[B, C, H, W]`` into one PyG ``Batch``.

    Each sample is converted on its own with
    ``differentiable_convert_to_data`` and the resulting graphs are merged.
    """
    # Unpacking four values also rejects inputs that are not 4-D
    batch_size, _, _, _ = batch_map_probs.shape

    graphs = [
        differentiable_convert_to_data(batch_map_probs[i])
        for i in range(batch_size)
    ]

    # Merge all graphs into a single batch
    return Batch.from_data_list(graphs)
|
|
||||||
@ -1,273 +0,0 @@
|
|||||||
# Converted Python version of the JS code
|
|
||||||
import math
|
|
||||||
from typing import Dict, Set, List, Tuple, Union
|
|
||||||
from collections import deque, defaultdict
|
|
||||||
|
|
||||||
# 拓扑相似度,由 ChatGPT-4o 从 ts 转译而来
|
|
||||||
|
|
||||||
class ResourceArea:
    """A connected region of resource tiles."""

    def __init__(self):
        self.type = "resource"
        # tile id -> number of cells with that tile in this area
        self.resources: Dict[int, int] = dict()
        # flattened cell indices (y * width + x) belonging to the area
        self.members: Set[int] = set()
        # flattened indices of branch cells adjacent to the area
        self.neighbor: Set[int] = set()
|
|
||||||
|
|
||||||
class BranchNode:
    """Graph node for a single branch tile."""

    def __init__(self, tile: int):
        self.type = "branch"
        # tile id found at this cell
        self.tile = tile
        # flattened indices of adjacent nodes
        self.neighbor: Set[int] = set()
|
|
||||||
|
|
||||||
class ResourceNode:
    """Graph node for one resource cell inside a ``ResourceArea``."""

    def __init__(self, resource_type: int, area: "ResourceArea"):
        self.type = "resource"
        self.resourceType = resource_type
        # Shares the area's neighbor set (same object, not a copy), so all
        # nodes of one area see the same neighbors.
        self.neighbor = area.neighbor
        self.resourceArea = area
|
|
||||||
|
|
||||||
GinkaNode = Union[BranchNode, ResourceNode]
|
|
||||||
|
|
||||||
class GinkaGraph:
    """Topology graph built from one entrance."""

    def __init__(self):
        # flattened cell index -> node
        self.graph: Dict[int, GinkaNode] = dict()
        # flattened cell index -> index into areaMap
        self.resourceMap: Dict[int, int] = dict()
        self.areaMap: List[ResourceArea] = list()
        # entrances reached while building this graph
        self.visitedEntrance: Set[int] = set()
        # every cell reached while building this graph
        self.visited: Set[int] = set()
|
|
||||||
|
|
||||||
class GinkaTopologicalGraphs:
    """All per-entrance graphs of one map plus the cells none of them reach."""

    def __init__(self):
        self.graphs: List[GinkaGraph] = list()
        # entrance cell index -> the graph that contains it
        self.entranceMap: Dict[int, GinkaGraph] = dict()
        # non-wall cells that no graph visited
        self.unreachable: Set[int] = set()
|
|
||||||
|
|
||||||
# Tile id groups used by the topology builder.
TILE_TYPE = {*range(13)}
BRANCH_TYPE = {6, 7, 8, 9}
ENTRANCE_TYPE = {10, 11}
RESOURCE_TYPE = {0, 2, 3, 4, 5, 10, 11, 12, 13}

# 4-neighbourhood offsets as (dx, dy); the order fixes BFS expansion order.
directions: List[Tuple[int, int]] = [
    (-1, 0),
    (1, 0),
    (0, -1),
    (0, 1),
]
|
|
||||||
|
|
||||||
def find_resource_nodes(map_: List[List[int]]):
    """
    Group 4-connected cells whose tile is in ``RESOURCE_TYPE`` into areas.

    Args:
        map_: row-major grid of tile ids.

    Returns:
        ``(areas, resource_map)``: the list of ``ResourceArea`` objects in
        discovery order, and a dict mapping each member cell's flattened
        index (``y * width + x``) to the index of its area in ``areas``.
        An empty map yields ``([], {})``.
    """
    if not map_ or not map_[0]:
        return [], {}

    width, height = len(map_[0]), len(map_)
    visited = set()
    areas = []
    resource_map = {}

    for ny in range(height):
        for nx in range(width):
            if ny * width + nx in visited or map_[ny][nx] not in RESOURCE_TYPE:
                continue

            # Flood-fill a new area from this cell. The seed is counted when
            # it is popped, like every other member; pre-registering it here
            # as well made its tile count one too high.
            area = ResourceArea()
            area_id = len(areas)
            queue = deque([(nx, ny)])
            while queue:
                cx, cy = queue.popleft()
                cindex = cy * width + cx
                if cindex in visited:
                    continue
                ctile = map_[cy][cx]
                if ctile not in RESOURCE_TYPE:
                    continue
                visited.add(cindex)
                area.resources[ctile] = area.resources.get(ctile, 0) + 1
                area.members.add(cindex)
                resource_map[cindex] = area_id
                for dx, dy in directions:
                    px, py = cx + dx, cy + dy
                    if 0 <= px < width and 0 <= py < height:
                        queue.append((px, py))
            areas.append(area)

    return areas, resource_map
|
|
||||||
|
|
||||||
def build_graph_from_entrance(map_: List[List[int]], entrance: int, resource_map: Dict[int, int], area_map: "List[ResourceArea]") -> "GinkaGraph":
    """
    Build the topology graph reachable from one entrance.

    Args:
        map_: row-major grid of tile ids; tile ``1`` blocks movement.
        entrance: flattened index (``y * width + x``) of the start cell.
        resource_map: cell index -> area index, from ``find_resource_nodes``.
        area_map: the areas referenced by ``resource_map``. Their
            ``neighbor`` sets are updated in place.

    Returns:
        A ``GinkaGraph`` with the visited cells and entrances, a
        ``BranchNode`` for every reached branch tile, and a ``ResourceNode``
        for every non-zero member cell of every area in ``area_map``.
    """
    width, height = len(map_[0]), len(map_)

    graph = GinkaGraph()
    graph.resourceMap = resource_map
    graph.areaMap = area_map

    reached = graph.visited
    reached_entrances = graph.visitedEntrance
    reached_entrances.add(entrance)

    # Phase 1: breadth-first walk over every non-wall cell
    branch_cells = set()
    start_y, start_x = divmod(entrance, width)
    pending = deque([(start_x, start_y)])
    while pending:
        x, y = pending.popleft()
        cell = y * width + x
        if cell in reached:
            continue
        tile = map_[y][x]
        if tile in ENTRANCE_TYPE:
            reached_entrances.add(cell)
        if tile in BRANCH_TYPE:
            branch_cells.add(cell)
        reached.add(cell)
        for dx, dy in directions:
            px, py = x + dx, y + dy
            if 0 <= px < width and 0 <= py < height and map_[py][px] != 1:
                pending.append((px, py))

    # Phase 2: link each branch cell to adjacent branches and resource areas
    for v in branch_cells:
        vy, vx = divmod(v, width)
        if v not in graph.graph:
            graph.graph[v] = BranchNode(map_[vy][vx])
        node = graph.graph[v]
        for dx, dy in directions:
            px, py = vx + dx, vy + dy
            if not (0 <= px < width and 0 <= py < height):
                continue
            adjacent = py * width + px
            if adjacent in branch_cells:
                node.neighbor.add(adjacent)
            elif adjacent in resource_map:
                area = area_map[resource_map[adjacent]]
                area.neighbor.add(v)
                # A branch touching an area neighbours all of its cells
                node.neighbor.update(area.members)

    # Phase 3: one resource node per non-zero cell of every area
    for area in area_map:
        for cell in area.members:
            cy, cx = divmod(cell, width)
            tile = map_[cy][cx]
            if tile == 0:
                continue
            graph.graph[cell] = ResourceNode(tile, area)

    return graph
|
|
||||||
|
|
||||||
def build_topological_graph(map_: List[List[int]]) -> "GinkaTopologicalGraphs":
    """
    Build one graph per group of mutually reachable entrances.

    Every entrance cell not already covered by an earlier graph starts a new
    ``build_graph_from_entrance`` call. Non-wall cells (tile != 1) that no
    graph visited are collected in ``unreachable``.
    """
    width, height = len(map_[0]), len(map_)
    entrances = {
        y * width + x
        for y in range(height)
        for x in range(width)
        if map_[y][x] in ENTRANCE_TYPE
    }
    area_map, resource_map = find_resource_nodes(map_)

    top_graph = GinkaTopologicalGraphs()
    covered_entrances = set()
    covered_cells = set()

    for entrance in entrances:
        if entrance in covered_entrances:
            continue
        graph = build_graph_from_entrance(map_, entrance, resource_map, area_map)
        top_graph.graphs.append(graph)
        for ent in graph.visitedEntrance:
            covered_entrances.add(ent)
            top_graph.entranceMap[ent] = graph
        covered_cells.update(graph.visited)

    # Everything walkable that was never reached
    top_graph.unreachable.update(
        y * width + x
        for y in range(height)
        for x in range(width)
        if y * width + x not in covered_cells and map_[y][x] != 1
    )

    return top_graph
|
|
||||||
|
|
||||||
class WLNode:
    """Node wrapper used during Weisfeiler-Lehman label refinement."""

    def __init__(self, pos: int, label: str):
        self.originalPos = pos
        # The initial label is kept; currentLabel is rewritten every iteration
        self.originalLabel = label
        self.currentLabel = label
        self.neighbors: List["WLNode"] = list()
|
|
||||||
|
|
||||||
def encode_node_labels(graph: "GinkaGraph") -> "List[WLNode]":
    """
    Wrap every node of ``graph`` in a ``WLNode`` with an initial label.

    Branch nodes are labelled ``B:<tile>``, all others ``R:<resourceType>``.
    Neighbour references that do not correspond to a node in the graph are
    dropped.
    """
    by_pos = {}
    for pos, node in graph.graph.items():
        if node.type == 'branch':
            label = f"B:{node.tile}"
        else:
            label = f"R:{node.resourceType}"
        by_pos[pos] = WLNode(pos, label)

    wl_nodes = list(by_pos.values())

    for wl_node in wl_nodes:
        source = graph.graph[wl_node.originalPos]
        wl_node.neighbors.extend(
            by_pos[other] for other in source.neighbor if other in by_pos
        )

    return wl_nodes
|
|
||||||
|
|
||||||
def weisfeiler_lehman_iteration(nodes: "List[WLNode]", iterations: int, decay: float = 0.6) -> Dict[str, float]:
    """
    Run Weisfeiler-Lehman label refinement and return weighted label counts.

    In every iteration each node's new label is its current label joined
    with the sorted current labels of its neighbours (cut to 8192
    characters); all nodes are updated together. Labels from the first
    iteration are counted with weight 1.0 and the weight is multiplied by
    ``decay`` after each iteration. The original labels are counted last,
    with whatever weight remains after all iterations.

    The nodes' ``currentLabel`` attributes are modified in place.
    """
    label_counts = defaultdict(float)
    weight = 1.0

    for _ in range(iterations):
        # Compute every new label first so the update is synchronous
        refined = [
            f"{node.currentLabel}|{','.join(sorted(n.currentLabel for n in node.neighbors))}"[:8192]
            for node in nodes
        ]
        for node, label in zip(nodes, refined):
            node.currentLabel = label
        for label in refined:
            label_counts[label] += weight
        weight *= decay

    for node in nodes:
        label_counts[node.originalLabel] += weight

    return dict(label_counts)
|
|
||||||
|
|
||||||
def vectorize_features(features: Dict[str, float], vocab: List[str]) -> List[float]:
    """Project a label -> weight mapping onto a fixed vocabulary order.

    Args:
        features: Mapping of label to weight.
        vocab: Labels defining the vector layout; position ``i`` of the
            result corresponds to ``vocab[i]``.

    Returns:
        A list of ``len(vocab)`` floats. Labels missing from ``vocab`` are
        ignored. If ``vocab`` contains a label more than once, only its first
        position receives the weight.
    """
    # Build the label -> index table once, so each lookup is O(1) rather than
    # a linear scan of the vocabulary list. setdefault keeps the first index
    # for duplicated labels, matching list.index().
    index_of: Dict[str, int] = {}
    for position, label in enumerate(vocab):
        index_of.setdefault(label, position)

    vec = [0.0] * len(vocab)
    for label, count in features.items():
        idx = index_of.get(label)
        if idx is not None:
            vec[idx] += count
    return vec
|
|
||||||
|
|
||||||
def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Return the cosine of the angle between vectors ``a`` and ``b``.

    The dot product pairs elements with ``zip``, so it only covers the common
    prefix when the lengths differ. Returns 0.0 if either vector has zero
    magnitude.
    """
    magnitude_a = math.sqrt(sum(v * v for v in a))
    magnitude_b = math.sqrt(sum(w * w for w in b))
    # A zero-length vector has no direction; report "no similarity".
    if not (magnitude_a and magnitude_b):
        return 0.0
    inner = sum(v * w for v, w in zip(a, b))
    return inner / (magnitude_a * magnitude_b)
|
|
||||||
|
|
||||||
def wl_kernel(graph_a: GinkaGraph, graph_b: GinkaGraph, iterations: int = 3) -> float:
    """Compare two graphs with a Weisfeiler-Lehman style kernel.

    Each graph is encoded into labelled nodes and refined for ``iterations``
    rounds. The resulting label weights are laid out over the union of both
    label sets and compared by cosine similarity.

    Args:
        graph_a: First graph.
        graph_b: Second graph.
        iterations: Number of refinement rounds for both graphs.

    Returns:
        The cosine similarity of the two feature vectors.
    """
    features_a, features_b = (
        weisfeiler_lehman_iteration(encode_node_labels(g), iterations)
        for g in (graph_a, graph_b)
    )
    # A shared vocabulary gives both vectors the same layout.
    vocab = list(set(features_a.keys()) | set(features_b.keys()))
    return cosine_similarity(
        vectorize_features(features_a, vocab),
        vectorize_features(features_b, vocab),
    )
|
|
||||||
|
|
||||||
def overall_similarity(a: GinkaTopologicalGraphs, b: GinkaTopologicalGraphs) -> float:
    """Score how similar two collections of topological graphs are.

    Every graph in ``a.graphs`` is greedily matched with the not-yet-matched
    graph in ``b.graphs`` that has the highest ``wl_kernel`` similarity. The
    mean of the best scores is square-rooted, then scaled down by the
    difference in the number of unreachable cells.

    Args:
        a: First collection; must expose ``graphs`` and ``unreachable``.
        b: Second collection; must expose ``graphs`` and ``unreachable``.

    Returns:
        The similarity score, or 0.0 when ``a`` has no graphs.
    """
    graphs_a = a.graphs
    graphs_b = b.graphs

    # Nothing to compare; also avoids dividing by zero below.
    if not graphs_a:
        return 0.0

    total_similarity = 0.0
    compared_graphs: Set[GinkaGraph] = set()

    for ga in graphs_a:
        max_similarity = 0.0
        max_graph = None
        for gb in graphs_b:
            if gb in compared_graphs:
                continue
            min_nodes = min(len(ga.graph), len(gb.graph))
            # Refinement depth grows with log(size). math.log(0) raises
            # ValueError, so an empty graph falls back to one round.
            if min_nodes > 0:
                iterations = max(1, math.ceil(math.log(min_nodes)))
            else:
                iterations = 1
            similarity = wl_kernel(ga, gb, iterations)
            if similarity > max_similarity and not math.isnan(similarity):
                max_similarity = similarity
                max_graph = gb
                # A perfect match cannot be beaten; stop searching.
                if similarity == 1:
                    break
        total_similarity += max_similarity
        # Identity check: a matched graph must be recorded even if the graph
        # object happens to evaluate as falsy.
        if max_graph is not None:
            compared_graphs.add(max_graph)

    # Penalise a mismatch in the number of unreachable cells.
    reduction = 1 / (1 + abs(len(a.unreachable) - len(b.unreachable)))
    return math.sqrt(total_similarity / len(graphs_a)) * reduction
|
|
||||||
@ -1,75 +0,0 @@
|
|||||||
from typing import List, Dict
|
|
||||||
import math
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
# 视觉相似度,由 ChatGPT-4o 从 ts 转译而来
|
|
||||||
|
|
||||||
class VisualSimilarityConfig:
    """Settings for the visual similarity computation.

    Attributes:
        type_weights: Weight for each tile type id 0..12.
        enable_visual_focus: Whether centre-focus weighting is used.
        enable_density_awareness: Whether local-density weighting is used.
    """

    def __init__(self):
        # Index in this tuple == tile type id.
        per_type = (0.2, 0.3, 0.6, 0.7, 0.7, 0.5, 0.4, 0.5, 0.6, 0.6, 0.4, 0.4, 0.7)
        self.type_weights: Dict[int, float] = dict(enumerate(per_type))
        self.enable_visual_focus: bool = True
        self.enable_density_awareness: bool = True
|
|
||||||
|
|
||||||
def generate_focus_weights(rows: int, cols: int) -> List[List[float]]:
    """Build a ``rows`` x ``cols`` grid of weights that peak near the centre.

    Each weight is ``1.0 + 0.6 * g``, where ``g`` is a Gaussian (sigma 0.3) of
    the cell's distance from ``(cols / 2, rows / 2)``, with each axis offset
    divided by the grid size along that axis.

    Returns:
        A list of ``rows`` lists, each containing ``cols`` floats.
    """
    center_x = cols / 2
    center_y = rows / 2

    def _cell_weight(i: int, j: int) -> float:
        dx = (j - center_x) / cols
        dy = (i - center_y) / rows
        distance = math.sqrt(dx ** 2 + dy ** 2)
        return 1.0 + 0.6 * math.exp(-(distance ** 2) / (2 * 0.3 ** 2))

    return [[_cell_weight(i, j) for j in range(cols)] for i in range(rows)]
|
|
||||||
|
|
||||||
def calculate_density_impact(map1: List[List[int]], map2: List[List[int]], type_weights: Dict[int, float]) -> List[List[float]]:
    """Compute a per-cell weight from the tile weights in a 3x3 neighbourhood.

    For every cell, the average of the two maps' tile weights is summed over
    the in-bounds cells of the surrounding 3x3 window. The sum is always
    divided by 9 and turned into ``1.0 + 0.4 * (sum / 9)``. Tile types missing
    from ``type_weights`` count as 0.5.

    Args:
        map1: First tile grid; its dimensions define the output size.
        map2: Second tile grid, indexed at the same coordinates as ``map1``.
        type_weights: Weight per tile type.

    Returns:
        A grid of floats with the same dimensions as ``map1``.
    """
    rows, cols = len(map1), len(map1[0])
    window_size = 3
    reach = window_size // 2
    cells_in_window = window_size ** 2

    result = []
    for i in range(rows):
        # Clamp the window to the grid rather than testing each offset.
        row_lo, row_hi = max(0, i - reach), min(rows, i + reach + 1)
        out_row = []
        for j in range(cols):
            col_lo, col_hi = max(0, j - reach), min(cols, j + reach + 1)
            acc = 0
            for ni in range(row_lo, row_hi):
                for nj in range(col_lo, col_hi):
                    first = type_weights.get(map1[ni][nj], 0.5)
                    second = type_weights.get(map2[ni][nj], 0.5)
                    acc += (first + second) / 2
            out_row.append(1.0 + 0.4 * (acc / cells_in_window))
        result.append(out_row)
    return result
|
|
||||||
|
|
||||||
def calculate_visual_similarity(map1: List[List[int]], map2: List[List[int]], config: "VisualSimilarityConfig | None" = None) -> float:
    """Return a weighted fraction of cells whose tile type matches.

    Each cell contributes ``base_weight * spatial_weight`` to the maximum
    score, and the same amount to the achieved score when both maps hold the
    same tile type there. ``base_weight`` is the larger of the two tiles'
    type weights (0.5 for unknown types). ``spatial_weight`` is the product
    of the optional centre-focus and density weights.

    Args:
        map1: First tile grid.
        map2: Second tile grid.
        config: Weighting options; a default ``VisualSimilarityConfig`` is
            created when omitted.

    Returns:
        A float in [0, 1]. Returns 0.0 when either map is empty, when the
        row count or first-row length differs, or when the maximum possible
        score is zero.
    """
    if config is None:
        config = VisualSimilarityConfig()

    # Empty maps used to crash on map1[0]; treat them like a zero-score case.
    if not map1 or not map2:
        return 0.0

    # Only the row count and the first row's length are compared.
    if len(map1) != len(map2) or len(map1[0]) != len(map2[0]):
        return 0.0

    rows, cols = len(map1), len(map1[0])
    type_weights = config.type_weights

    if config.enable_visual_focus:
        focus_weights = generate_focus_weights(rows, cols)
    else:
        focus_weights = [[1.0 for _ in range(cols)] for _ in range(rows)]

    if config.enable_density_awareness:
        density_map = calculate_density_impact(map1, map2, type_weights)
    else:
        density_map = [[1.0 for _ in range(cols)] for _ in range(rows)]

    total_score = 0.0
    max_possible_score = 0.0
    for i in range(rows):
        for j in range(cols):
            type1 = map1[i][j]
            type2 = map2[i][j]
            # Weight the cell by the more prominent of the two tiles.
            base_weight = max(type_weights.get(type1, 0.5), type_weights.get(type2, 0.5))
            spatial_weight = focus_weights[i][j] * density_map[i][j]
            type_score = 1.0 if type1 == type2 else 0.0

            total_score += type_score * base_weight * spatial_weight
            max_possible_score += base_weight * spatial_weight

    return total_score / max_possible_score if max_possible_score > 0 else 0.0
|
|
||||||
@ -1,17 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
def random_smooth_onehot(onehot_map, min_main=0.75, max_main=1.0, epsilon=0.25):
    """Randomly smooth a one-hot map so the main-class probability varies.

    For every pixel, the class marked in ``onehot_map`` receives a random
    probability drawn uniformly from ``[min_main, max_main)``. The remaining
    ``1 - main_prob`` is split at random between the other classes, so each
    pixel of the result sums to 1 over the class axis.

    Args:
        onehot_map: Tensor of shape (C, H, W) holding a one-hot encoding
            along axis 0.
        min_main: Lower bound of the main-class probability.
        max_main: Upper bound of the main-class probability.
        epsilon: Kept only for backward compatibility. The previous
            normalisation cancelled it out, so it never affected the result.

    Returns:
        A floating-point tensor of shape (C, H, W).
    """
    C, H, W = onehot_map.shape
    # Create random tensors on the input's device to avoid a device mismatch.
    device = onehot_map.device

    # Random main-class probability per pixel.
    main_prob = torch.rand(H, W, device=device) * (max_main - min_main) + min_main

    # Random noise restricted to the non-main classes.
    noise = torch.rand(C, H, W, device=device) * (1 - onehot_map)
    # Normalise over the class axis (0), not the height axis. The clamp
    # guards against division by zero when there is no other class (C == 1).
    noise_total = noise.sum(dim=0, keepdim=True).clamp_min(1e-12)
    noise = noise / noise_total * (1 - main_prob)

    # Main class gets main_prob; the others share the remainder.
    smooth_onehot = onehot_map * main_prob + noise
    return smooth_onehot
|
|
||||||
@ -63,7 +63,7 @@ def convert_dataset_to_images(
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
convert_dataset_to_images(
|
convert_dataset_to_images(
|
||||||
json_path="data/result.json", # 数据集文件
|
json_path="data/result.json", # 数据集文件
|
||||||
tile_folder="tiles2", # 贴图文件夹
|
tile_folder="tiles", # 贴图文件夹
|
||||||
output_folder="map_images", # 输出文件夹
|
output_folder="map_images", # 输出文件夹
|
||||||
tile_size=32 # tile 尺寸
|
tile_size=32 # tile 尺寸
|
||||||
)
|
)
|
||||||
BIN
tiles/10.png
|
Before Width: | Height: | Size: 406 B After Width: | Height: | Size: 699 B |
BIN
tiles/11.png
|
Before Width: | Height: | Size: 291 B |
BIN
tiles/12.png
|
Before Width: | Height: | Size: 341 B |
BIN
tiles/13.png
|
Before Width: | Height: | Size: 396 B |
BIN
tiles/14.png
|
Before Width: | Height: | Size: 289 B |
BIN
tiles/15.png
|
Before Width: | Height: | Size: 344 B After Width: | Height: | Size: 414 B |
BIN
tiles/16.png
|
Before Width: | Height: | Size: 419 B |
BIN
tiles/17.png
|
Before Width: | Height: | Size: 305 B |
BIN
tiles/18.png
|
Before Width: | Height: | Size: 358 B |
BIN
tiles/19.png
|
Before Width: | Height: | Size: 441 B |
BIN
tiles/2.png
|
Before Width: | Height: | Size: 847 B After Width: | Height: | Size: 426 B |
BIN
tiles/20.png
|
Before Width: | Height: | Size: 442 B |
BIN
tiles/21.png
|
Before Width: | Height: | Size: 448 B |
BIN
tiles/22.png
|
Before Width: | Height: | Size: 436 B |
BIN
tiles/23.png
|
Before Width: | Height: | Size: 448 B |
BIN
tiles/24.png
|
Before Width: | Height: | Size: 643 B |
BIN
tiles/25.png
|
Before Width: | Height: | Size: 389 B |
BIN
tiles/26.png
|
Before Width: | Height: | Size: 353 B |
BIN
tiles/27.png
|
Before Width: | Height: | Size: 382 B |
BIN
tiles/28.png
|
Before Width: | Height: | Size: 453 B |
BIN
tiles/29.png
|
Before Width: | Height: | Size: 699 B |
BIN
tiles/3.png
|
Before Width: | Height: | Size: 426 B After Width: | Height: | Size: 368 B |
BIN
tiles/30_1.png
|
Before Width: | Height: | Size: 323 B |
BIN
tiles/30_2.png
|
Before Width: | Height: | Size: 311 B |
BIN
tiles/30_3.png
|
Before Width: | Height: | Size: 312 B |
BIN
tiles/30_4.png
|
Before Width: | Height: | Size: 308 B |
BIN
tiles/4.png
|
Before Width: | Height: | Size: 420 B After Width: | Height: | Size: 406 B |
BIN
tiles/5.png
|
Before Width: | Height: | Size: 422 B After Width: | Height: | Size: 396 B |
BIN
tiles/6.png
|
Before Width: | Height: | Size: 678 B After Width: | Height: | Size: 419 B |
BIN
tiles/7.png
|
Before Width: | Height: | Size: 368 B After Width: | Height: | Size: 441 B |
BIN
tiles/8.png
|
Before Width: | Height: | Size: 365 B After Width: | Height: | Size: 448 B |
BIN
tiles/9.png
|
Before Width: | Height: | Size: 377 B After Width: | Height: | Size: 353 B |
BIN
tiles/999.png
|
Before Width: | Height: | Size: 414 B |
BIN
tiles2/0.png
|
Before Width: | Height: | Size: 1.4 KiB |
BIN
tiles2/1.png
|
Before Width: | Height: | Size: 576 B |
BIN
tiles2/10.png
|
Before Width: | Height: | Size: 699 B |
BIN
tiles2/15.png
|
Before Width: | Height: | Size: 414 B |
BIN
tiles2/2.png
|
Before Width: | Height: | Size: 426 B |
BIN
tiles2/3.png
|
Before Width: | Height: | Size: 368 B |
BIN
tiles2/4.png
|
Before Width: | Height: | Size: 406 B |
BIN
tiles2/5.png
|
Before Width: | Height: | Size: 396 B |
BIN
tiles2/6.png
|
Before Width: | Height: | Size: 419 B |
BIN
tiles2/7.png
|
Before Width: | Height: | Size: 441 B |
BIN
tiles2/8.png
|
Before Width: | Height: | Size: 448 B |
BIN
tiles2/9.png
|
Before Width: | Height: | Size: 353 B |
11
train.sh
@ -1,8 +1,3 @@
|
|||||||
# 从头训练
|
# MaskGIT
|
||||||
python3 -u -m ginka.train_wgan --epochs 20 --curr_epoch 1 --checkpoint 1 >> output.log
|
python3 -u -m ginka.train_maskGIT --epochs 150 --checkpoint 10 >> output_maskGIT.log
|
||||||
# 接续训练
|
python3 -u -m ginka.train_maskGIT --resume true --epochs 150 --checkpoint 10 --state_ginka "result/transformer/ginka-100.pth" >> output_maskGIT.log
|
||||||
python3 -u -m ginka.train_wgan --resume true --epochs 300 --state_ginka "result/wgan/ginka-100.pth" --state_minamo "result/wgan/minamo-100.pth" >> output.log
|
|
||||||
|
|
||||||
# rnn
|
|
||||||
python3 -u -m ginka.train_rnn --epochs 150 --checkpoint 10 >> output_rnn.log
|
|
||||||
python3 -u -m ginka.train_rnn --resume true --epochs 150 --checkpoint 10 --state_ginka "result/rnn/ginka-100.pth" >> output_rnn.log
|
|
||||||
|
|||||||