diff --git a/README.md b/README.md index a7766fb..d9fcd9b 100644 --- a/README.md +++ b/README.md @@ -1,73 +1,3 @@ # GINKA 地图生成器 -GINKA Model 是一个用于生成网格状魔塔地图的模型,采用 UNet 网络。 - -GINKA Model 内部集成了 Minamo Model 用做判别器,与 Ginka Model 对抗训练,训练使用 Wasserstein GAN 训练方式。 - -## 贡献 GINKA Model 数据集 - -对于 HTML5 魔塔,如果你想要贡献数据集,需要对你的魔塔进行手动数据处理,流程如下: - -1. 在 `project` 文件夹下创建 `ginka-config.json` 文件,双击进入编辑,粘贴如下模板: - -```json -{ - "clip": { - "defaults": [0, 0, 13, 13], - "special": {} - }, - "mapping": { - "redGem": { - "27": 1 - }, - "blueGem": { - "28": 1 - }, - "greenGem": { - "29": 1 - }, - "yellowGem": { - "30": 1 - }, - "item": { - "47": 1, - "49": 1, - "50": 0, - "51": 1, - "52": 1, - "53": 2 - }, - "potion": { - "31": 100, - "32": 200, - "33": 400, - "34": 800 - }, - "key": { - "21": 0, - "22": 1, - "23": 2, - "24": 2, - "25": 2 - }, - "door": { - "81": 0, - "82": 1, - "83": 2, - "84": 2, - "85": 3, - "86": 2 - }, - "wall": [1, 17], - "decoration": [], - "floor": [87, 88], - "arrow": [91, 92, 93, 94] - }, - "data": {} -} -``` - -其中,`clip` 属性表示你的每张地图的那一部分会被当成数据集,例如填写 `[0, 0, 13, 13]` 就会让坐标为 `(0, 0)`,长宽为 `(13, 13)` 的矩形内容作为数据集。`special` 不用管。注意装饰所使用的贴图是白墙,如果白墙是墙壁的话,需要将白墙设置为墙壁。注意不要忘记保存 - -2. 使用 [在线工具](https://unanmed.github.io/ginka-process) 处理数据,需要给每个地图添加标签,为每个图块分配种类,有一些图块包含多种等级,需要填写正确。 -3. 将 `project` 文件夹打包发给我 +Ginka 地图生成器是专门训练用来生成魔塔地图的工具,采用 `MaskGIT` 模型及生成方法。 diff --git a/data/src/floor.ts b/data/src/floor.ts deleted file mode 100644 index 4f74d81..0000000 --- a/data/src/floor.ts +++ /dev/null @@ -1,295 +0,0 @@ -import { GinkaConfig } from './types'; - -const numMap: Record = { - 0: 0, - 1: 1, - 2: 2, - 91: 30, - 92: 30, - 93: 30, - 94: 30, - 87: 29, - 88: 29 -}; - -export interface Enemy { - num: number; - hp: number; - atk: number; - def: number; -} - -function convert( - map: number[][], - [x, y, w, h]: [number, number, number, number], - config: GinkaConfig, - enemyMap: Record -) { - const clipped: number[][] = []; - - // 1. 裁剪 - for (let ny = y; ny < y + w; ny++) { - const row: number[] = []; - for (let nx = y; nx < x + h; nx++) { - row.push(map[ny][nx]); - } - clipped.push(row); - } - - const res: number[][] = Array.from({ length: clipped.length }, () => - Array.from({ length: clipped[0].length }, () => 0) - ); - - // 2. 初步映射 - for (let nx = 0; nx < w; nx++) { - for (let ny = 0; ny < h; ny++) { - const tile = clipped[ny][nx]; - if (numMap[tile] !== void 0) { - res[ny][nx] = numMap[tile]; - } - } - } - - // 3. 转换一般图块 - const mapping: Record = {}; - const dict = config.mapping; - dict.wall.forEach(v => (mapping[v] = 1)); - dict.decoration.forEach(v => (mapping[v] = 2)); - dict.floor.forEach(v => (mapping[v] = 29)); - dict.arrow.forEach(v => (mapping[v] = 30)); - for (let nx = 0; nx < w; nx++) { - for (let ny = 0; ny < h; ny++) { - const tile = clipped[ny][nx]; - if (mapping[tile] !== void 0) res[ny][nx] = mapping[tile]; - } - } - - // 4. 转换含等级图块 - const redGemSet = new Set(); - const blueGemSet = new Set(); - const greenGemSet = new Set(); - const potionSet = new Set(); - for (let nx = 0; nx < w; nx++) { - for (let ny = 0; ny < h; ny++) { - const tile = clipped[ny][nx]; - if (dict.redGem[tile] !== void 0) { - redGemSet.add(dict.redGem[tile]); - } else if (dict.blueGem[tile] !== void 0) { - blueGemSet.add(dict.blueGem[tile]); - } else if (dict.greenGem[tile] !== void 0) { - greenGemSet.add(dict.greenGem[tile]); - } else if (dict.yellowGem[tile] !== void 0) { - redGemSet.add(dict.yellowGem[tile]); - blueGemSet.add(dict.yellowGem[tile]); - greenGemSet.add(dict.yellowGem[tile]); - } else if (dict.potion[tile] !== void 0) { - potionSet.add(dict.potion[tile]); - } - } - } - const minRedGem = Math.min(...redGemSet); - const maxRedGem = Math.max(...redGemSet); - const minBlueGem = Math.min(...blueGemSet); - const maxBlueGem = Math.max(...blueGemSet); - const minGreenGem = Math.min(...greenGemSet); - const maxGreenGem = Math.max(...greenGemSet); - const minPotion = Math.min(...potionSet); - const maxPotion = Math.max(...potionSet); - - for (let nx = 0; nx < w; nx++) { - for (let ny = 0; ny < h; ny++) { - const tile = clipped[ny][nx]; - if (dict.redGem[tile] !== void 0) { - const value = dict.redGem[tile]; - if (maxRedGem - minRedGem < 1e-8) { - res[ny][nx] = 10; - } else { - const level = Math.min( - Math.floor( - ((value - minRedGem) / (maxRedGem - minRedGem)) * 3 - ), - 2 - ); - res[ny][nx] = 10 + level; - } - } else if (dict.blueGem[tile] !== void 0) { - const value = dict.blueGem[tile]; - if (maxBlueGem - minBlueGem < 1e-8) { - res[ny][nx] = 13; - } else { - const level = Math.min( - Math.floor( - ((value - minBlueGem) / (maxBlueGem - minBlueGem)) * - 3 - ), - 2 - ); - res[ny][nx] = 13 + level; - } - } else if (dict.greenGem[tile] !== void 0) { - const value = dict.greenGem[tile]; - if (maxGreenGem - minGreenGem < 1e-8) { - res[ny][nx] = 16; - } else { - const level = Math.min( - Math.floor( - ((value - minGreenGem) / - (maxGreenGem - minGreenGem)) * - 3 - ), - 2 - ); - res[ny][nx] = 16 + level; - } - } else if (dict.yellowGem[tile] !== void 0) { - const rand = Math.random(); - const value = dict.yellowGem[tile]; - if (rand < 2 / 5) { - if (maxRedGem - minRedGem < 1e-8) { - res[ny][nx] = 10; - } else { - const level = Math.min( - Math.floor( - ((value - minRedGem) / - (maxRedGem - minRedGem)) * - 3 - ), - 2 - ); - res[ny][nx] = 10 + level; - } - } else if (rand < 4 / 5) { - if (maxBlueGem - minBlueGem < 1e-8) { - res[ny][nx] = 13; - } else { - const level = Math.min( - Math.floor( - ((value - minBlueGem) / - (maxBlueGem - minBlueGem)) * - 3 - ), - 2 - ); - res[ny][nx] = 13 + level; - } - } else { - if (maxGreenGem - minGreenGem < 1e-8) { - res[ny][nx] = 16; - } else { - const level = Math.min( - Math.floor( - ((value - minGreenGem) / - (maxGreenGem - minGreenGem)) * - 3 - ), - 2 - ); - res[ny][nx] = 16 + level; - } - } - } else if (dict.potion[tile] !== void 0) { - const value = dict.potion[tile]; - if (maxPotion - minPotion < 1e-8) { - res[ny][nx] = 19; - } else { - const level = Math.min( - Math.floor( - ((value - minPotion) / (maxPotion - minPotion)) * 4 - ), - 3 - ); - res[ny][nx] = 19 + level; - } - } else if (dict.door[tile] !== void 0) { - const level = dict.door[tile]; - res[ny][nx] = 3 + level; - } else if (dict.key[tile] !== void 0) { - const level = dict.key[tile]; - res[ny][nx] = 7 + level; - } else if (dict.item[tile] !== void 0) { - const level = dict.item[tile]; - res[ny][nx] = 22 + level; - } - } - } - - // 5. 转换怪物 - const enemySet = new Set(); - for (let nx = 0; nx < w; nx++) { - for (let ny = 0; ny < h; ny++) { - const tile = clipped[ny][nx]; - const enemy = enemyMap[tile]; - if (!enemy) continue; - enemySet.add({ ...enemy, num: tile }); - } - } - const enemyArr = [...enemySet]; - enemyArr.sort((a, b) => a.num - b.num); - - const attrs = [...enemySet].map(v => (v.atk + v.def) * v.hp); - const maxAttr = Math.max(...attrs); - const minAttr = Math.min(...attrs); - const delta = maxAttr - minAttr; - for (let ny = 0; ny < w; ny++) { - for (let nx = 0; nx < h; nx++) { - const tile = clipped[ny][nx]; - const enemy = enemyMap[tile]; - if (!enemy) continue; - // 替换为弱怪/中怪/强怪 - const attr = (enemy.atk + enemy.def) * enemy.hp; - const ad = attr - minAttr; - if (ad < delta / 3 || delta === 0) { - res[ny][nx] = 26; - } else if (ad < (delta * 2) / 3) { - res[ny][nx] = 27; - } else { - res[ny][nx] = 28; - } - } - } - - return res; -} - -export function convertFloor( - map: number[][], - clip: [number, number, number, number], - config: GinkaConfig, - enemyMap: Record -) { - return convert(map, clip, config, enemyMap); -} - -export function getCount(map: number[][], tiles: number[]) { - let n = 0; - map.flat().forEach(v => { - if (tiles.includes(v)) n++; - }); - return n; -} - -export function getRatio(map: number[][], tiles: number[]) { - const area = map.length * map[0].length; - return getCount(map, tiles) / area; -} - -function range(from: number, to: number) { - const length = to - from; - return Array.from({ length }, (_, i) => i + from); -} - -export function getGinkaRatio(map: number[][]): number[] { - const arr: number[] = Array(16).fill(0); - arr[0] = getRatio(map, [1, ...range(3, 32)]); - arr[1] = getRatio(map, [1]); - arr[2] = getRatio(map, [2]); - arr[3] = getRatio(map, [3, 4, 5, 6]); - arr[4] = getRatio(map, [26, 27, 28]); - arr[5] = getRatio(map, range(7, 26)); - arr[6] = getRatio(map, range(10, 19)); - arr[7] = getRatio(map, range(19, 23)); - arr[8] = getRatio(map, [7, 8, 9]); - arr[9] = getCount(map, [23, 24, 25]); - arr[10] = getCount(map, [29, 30]); - return arr; -} diff --git a/data/src/gan.ts b/data/src/gan.ts deleted file mode 100644 index 24ec44e..0000000 --- a/data/src/gan.ts +++ /dev/null @@ -1,177 +0,0 @@ -import { createConnection, Socket } from 'net'; -import { chooseFrom, FloorData, readOne } from './utils'; -import { MinamoTrainData } from './types'; -import { generateTrainData } from './process/minamo'; - -const SOCKET_FILE = '../tmp/ginka_uds'; -const [refer, replayPath = '../datasets/replay.bin'] = process.argv.slice(2); - -let id = 0; - -function readMap(count: number, arr: number[], h: number, w: number) { - const area = w * h; - - const maps: number[][][] = Array.from({ - length: count - }).map(() => { - return Array.from({ length: h }).map(() => { - return Array.from({ length: w }).fill(0); - }); - }); - - arr.forEach((v, i) => { - const n = Math.floor(i / area); - const y = Math.floor((i % area) / w); - const x = i % w; - maps[n][y][x] = v; - }); - - return maps; -} - -function generateGANData( - keys: string[], - refer: Map, - map: number[][] -) { - const id2 = `$${id++}`; - const toTrain = chooseFrom(keys, 4); - const data = toTrain.map(v => { - const floor = refer.get(v); - if (!floor) return []; - const size1: [number, number] = [floor.map[0].length, floor.map.length]; - const size2: [number, number] = [map[0].length, map.length]; - if (size1[0] !== size2[0] || size1[1] !== size2[1]) return []; - - return generateTrainData(v, id2, floor.map, map, size1, false, false, false); - }); - return data.flat(); -} - -const enum ReceiverStatus { - Header, - Content -} - -class DataReceiver { - static active?: DataReceiver - /** 接收状态 */ - private status: ReceiverStatus = ReceiverStatus.Header; - - private received: number[] = [] - private count: number = 0; - private h: number = 0; - private w: number = 0; - - receive(buf: Buffer): [number[][][], number, number, number] | null { - // 数据通讯 node 输入协议,单位字节: - // 2 - Tensor count; 1 - Map height; 1 - Map Width; N*1*H*W - Map tensor, int8 type. - switch (this.status) { - case ReceiverStatus.Header: - this.count = buf.readInt16BE(); - this.h = buf.readInt8(2); - this.w = buf.readInt8(3); - this.received.push(...buf.subarray(4)); - this.status = ReceiverStatus.Content; - break; - case ReceiverStatus.Content: - this.received.push(...buf); - break - } - if (this.received.length === this.count * this.h * this.w) { - delete DataReceiver.active; - return [readMap(this.count, this.received, this.h, this.w), this.count, this.h, this.w]; - } else { - return null; - } - } - - static check(buf: Buffer) { - if (this.active) { - return this.active.receive(buf); - } else { - this.active = new DataReceiver(); - return this.active.receive(buf); - } - } -} - -(async () => { - const referTower = await readOne(refer); - const keys = [...referTower.keys()]; - - const client = createConnection(SOCKET_FILE, () => { - console.log(`UDS IPC connected successfully.`); - }); - - client.on('data', async buffer => { - const data = DataReceiver.check(buffer); - if (!data) return; - - const [map, count, h, w] = data; - const simData = map.map(v => generateGANData(keys, referTower, v)); - const rc = 0; - const compareData = simData.flat(); - - // 数据通讯 node 输出协议,单位字节: - // 2 - Tensor count; 2 - Replay count. Replay is right behind train data; - // 1*tc - Compare count for every map tensor delivered. - // 2*4*(N+rc) - Vision similarity and topo similarity, like vis, topo, vis, topo; - // N*1*H*W - Compare map for every map tensor. rc*2*H*W - Replay map tensor. - const toSend = Buffer.alloc( - 2 + // Tensor count - 2 + // Replay count - 1 * count + // Compare count - 2 * 4 * (compareData.length + rc) + // Similarity data - compareData.length * 1 * h * w + // Compare map - rc * 2 * h * w, // Replay map - 0 - ); - console.log( - 2, - 2, - count, - 2 * 4 * (compareData.length + rc), - compareData.length * 1 * h * w, - rc * 2 * h * w, - compareData.length, - rc - ); - - let offset = 0; - toSend.writeInt16BE(count); // Tensor count - toSend.writeInt16BE(0, 2); // Replay count - offset += 2 + 2; - // Compare count - toSend.set( - simData.map(v => v.length), - offset - ); - offset += 1 * count; - // Similarity data - compareData.forEach(v => { - // console.log(v.visionSimilarity, v.topoSimilarity); - - toSend.writeFloatBE(v.visionSimilarity, offset); - offset += 4; - toSend.writeFloatBE(v.topoSimilarity, offset); - offset += 4; - }); - // Compare map - toSend.set( - new Uint8Array(compareData.map(v => v.map1).flat(3)), - offset // Set from Compare map - ); - offset += compareData.length * 1 * h * w; - - client.write(toSend); - }); - - client.on('end', () => { - console.log(`Connection lose.`); - }); - - client.on('error', () => { - client.end(); - }); -})(); diff --git a/data/src/ginka.ts b/data/src/ginka.ts deleted file mode 100644 index 39723c2..0000000 --- a/data/src/ginka.ts +++ /dev/null @@ -1,16 +0,0 @@ -import { writeFile } from 'fs-extra'; -import { getAllFloors, parseTowerInfo } from './utils'; -import { parseGinka } from './process/ginka'; - -const [output, ...list] = process.argv.slice(2); - -(async () => { - const towers = await Promise.all( - list.map(v => parseTowerInfo(v, 'ginka-config.json')) - ); - const floors = await getAllFloors(...towers); - const results = parseGinka(floors); - await writeFile(output, JSON.stringify(results, void 0), 'utf-8'); - const size = Object.keys(results.data).length; - console.log(`✅ 已处理 ${list.length} 个塔,共 ${size} 个地图`); -})(); diff --git a/data/src/minamo.ts b/data/src/minamo.ts deleted file mode 100644 index b0cbe56..0000000 --- a/data/src/minamo.ts +++ /dev/null @@ -1,40 +0,0 @@ -import { writeFile } from 'fs-extra'; -import { readOne, getAllFloors, parseTowerInfo } from './utils'; -import { generateAssignedData, parseMinamo } from './process/minamo'; - -const [output, ...list] = process.argv.slice(2); -// 判断 assigned 模式,此模式下只会对前两个塔处理,会在这两个塔之间对比,而单个塔的地图不会对比 -const assigned = list.at(-1)?.startsWith('assigned'); -const assignedCount = parseAssigned(list.at(-1)!); -if (assigned) list.pop(); - -function parseAssigned(arg: string): [number, number] { - const p = arg.slice(9); - const [a, b] = p.split(':'); - return [parseInt(a) || 100, parseInt(b) || 100]; -} - -(async () => { - if (!assigned) { - const towers = await Promise.all( - list.map(v => parseTowerInfo(v, 'minamo-config.json')) - ); - const floors = await getAllFloors(...towers); - const results = parseMinamo(floors); - await writeFile(output, JSON.stringify(results, void 0), 'utf-8'); - const size = Object.keys(results.data).length; - console.log(`✅ 已处理 ${list.length} 个塔,共 ${size} 个组合`); - } else { - const [tower1, tower2] = list; - if (!tower1 || !tower2) { - console.log(`⚠️ assigned 模式下必须传入两个塔!`); - return; - } - const data1 = await readOne(tower1); - const data2 = await readOne(tower2); - const results = generateAssignedData(data1, data2, assignedCount); - await writeFile(output, JSON.stringify(results, void 0), 'utf-8'); - const size = Object.keys(results.data).length; - console.log(`✅ 已处理 ${list.length} 个塔,共 ${size} 个组合`); - } -})(); diff --git a/data/src/process/ginka.ts b/data/src/process/ginka.ts deleted file mode 100644 index 89431d8..0000000 --- a/data/src/process/ginka.ts +++ /dev/null @@ -1,36 +0,0 @@ -import { SingleBar, Presets } from 'cli-progress'; -import { getGinkaRatio } from 'src/floor'; -import { GinkaTrainData, GinkaConfig, GinkaDataset } from 'src/types'; -import { FloorData } from 'src/utils'; - -export function parseGinka(data: Map) { - const resolved: Record = {}; - - const progress = new SingleBar({}, Presets.shades_classic); - progress.start(data.size, 0); - let i = 0; - - data.forEach((floor, key) => { - const config = floor.config as GinkaConfig; - const data = config.data[floor.id] ?? { - tag: Array(64).fill(0) - }; - resolved[key] = { - map: floor.map, - size: [floor.map[0].length, floor.map.length], - tag: data.tag, - val: getGinkaRatio(floor.map) - }; - i++; - progress.update(i); - }); - - const dataset: GinkaDataset = { - datasetId: Math.floor(Math.random() * 1e12), - data: resolved - }; - - progress.stop(); - - return dataset; -} diff --git a/data/src/process/minamo.ts b/data/src/process/minamo.ts deleted file mode 100644 index 085b009..0000000 --- a/data/src/process/minamo.ts +++ /dev/null @@ -1,406 +0,0 @@ -import { SingleBar, Presets } from 'cli-progress'; -import { compareMap } from 'src/topology/compare'; -import { directions, tileType } from 'src/topology/graph'; -import { rotateMap, mirrorMapX, mirrorMapY } from 'src/topology/transform'; -import { MinamoTrainData, MinamoDataset } from 'src/types'; -import { chooseFrom, FloorData } from 'src/utils'; -import { calculateVisualSimilarity } from 'src/vision/similarity'; - -function chooseN(maxCount: number, n: number) { - return chooseFrom( - Array(maxCount) - .fill(0) - .map((_, i) => i), - n - ); -} - -function choosePair(n: number, max: number = 1000) { - const totalCount = Math.round((n * (n - 1)) / 2); - const count = Math.min(totalCount, max); - const pairs: number[] = []; - for (let i = 0; i < n; i++) { - for (let j = i + 1; j < n; j++) { - pairs.push(i * n + j); - } - } - // 直接打乱后取前 count 个 - for (let i = pairs.length - 1; i > 0; i--) { - let randIndex = Math.floor(Math.random() * (i + 1)); - [pairs[i], pairs[randIndex]] = [pairs[randIndex], pairs[i]]; - } - - return pairs.slice(0, count); -} - -function transform(map: number[][], rot: number, flip: number) { - let res = map; - for (let i = 0; i < rot; i++) { - res = rotateMap(res); - } - if (flip & 0b01) { - res = mirrorMapX(res); - } - if (flip & 0b10) { - res = mirrorMapY(res); - } - return res; -} - -function generateTransformData( - id1: string, - id2: string, - map1: number[][], - map2: number[][], - simi: number -) { - const types: [rot: number, flip: number][] = []; - for (const rot of [0, 1, 2, 3]) { - for (const flip of [0b00, 0b01, 0b10, 0b11]) { - if (rot === 0 && flip === 0) continue; - types.push([rot, flip]); - } - } - // 随机抽取最多一个 - const trans = chooseFrom(types, Math.floor(Math.random() * 1)); - return trans - .map(([rot, flip]) => { - const com1 = `${id1}.${rot}.${flip}:${id1}`; - const com2 = `${id1}.${rot}.${flip}:${id2}`; - const com3 = `${id2}.${rot}.${flip}:${id1}`; - const com4 = `${id2}.${rot}.${flip}:${id2}`; - const choose = chooseFrom( - [com1, com2, com3, com4], - Math.floor(Math.random() * 2) - ); - const res: [id: string, data: MinamoTrainData][] = []; - if (choose.includes(com1)) { - const t = transform(map1, rot, flip); - res.push([ - com1, - { - map1: t, - map2: map1, - topoSimilarity: 1, - visionSimilarity: calculateVisualSimilarity(map1, t), - size: [map1[0].length, map1.length] - } - ]); - } - if (choose.includes(com2)) { - const t = transform(map1, rot, flip); - res.push([ - com2, - { - map1: t, - map2: map2, - topoSimilarity: simi, - visionSimilarity: calculateVisualSimilarity(t, map2), - size: [map1[0].length, map1.length] - } - ]); - } - if (choose.includes(com3)) { - const t = transform(map2, rot, flip); - res.push([ - com3, - { - map1: t, - map2: map1, - topoSimilarity: simi, - visionSimilarity: calculateVisualSimilarity(t, map1), - size: [map1[0].length, map1.length] - } - ]); - } - if (choose.includes(com4)) { - const t = transform(map2, rot, flip); - res.push([ - com4, - { - map1: t, - map2: map2, - topoSimilarity: 1, - visionSimilarity: calculateVisualSimilarity(t, map2), - size: [map1[0].length, map1.length] - } - ]); - } - - return res; - }) - .flat(); -} - -function generateSimilarData(id: string, map: number[][]) { - // 生成最多两个微调地图 - const width = map[0].length; - const height = map.length; - const num = Math.floor(Math.random() * 2); - const res: [id: string, data: MinamoTrainData][] = []; - - for (let i = 0; i < num; i++) { - const clone = map.map(v => v.slice()); - const prob = Math.random() * 0.3; - for (let ny = 0; ny < height; ny++) { - for (let nx = 0; nx < width; nx++) { - if (Math.random() > prob) { - // 有一定的概率进行微调 - continue; - } - if (Math.random() < 0.2) { - // 20% 概率与旁边图块互换位置 - const [dx, dy] = - directions[ - Math.floor(Math.random() * directions.length) - ]; - const px = nx + dx; - const py = ny + dy; - if (px < 0 || px >= width || py < 0 || py >= height) { - continue; - } - [clone[ny][nx], clone[py][px]] = [ - clone[py][px], - clone[ny][nx] - ]; - } else { - // 80% 概率替换当前图块 - clone[ny][nx] = Math.floor(Math.random() * tileType.size); - } - } - } - const id2 = `${id}.S${i}`; - const sid = `${id}:${id2}`; - const simi = compareMap(id, id2, map, clone); - - res.push([ - sid, - { - map1: map, - map2: clone, - size: [width, height], - topoSimilarity: simi, - visionSimilarity: calculateVisualSimilarity(map, clone) - } - ]); - } - return res; -} - -export function generateTrainData( - id1: string, - id2: string, - map1: number[][], - map2: number[][], - size: [number, number], - hasSelf: boolean = true, - hasTransform: boolean = true, - hasSimilar: boolean = true -) { - const topoSimilarity = compareMap(id1, id2, map1, map2); - const visionSimilarity = calculateVisualSimilarity(map1, map2); - const train: MinamoTrainData = { - map1, - map2, - topoSimilarity, - visionSimilarity, - size: size - }; - const data: MinamoTrainData[] = []; - data.push(train); - if (hasSelf) { - // 自身与自身对比的训练集,保证模型对相同地图输出 1 - const self1 = `${id1}:${id1}`; - const self2 = `${id2}:${id2}`; - const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1)); - if (selfTrain.includes(self1)) { - const selfTrain1: MinamoTrainData = { - map1: map1, - map2: map1, - topoSimilarity: 1, - visionSimilarity: 1, - size: size - }; - data.push(selfTrain1); - } - if (selfTrain.includes(self2)) { - const selfTrain2: MinamoTrainData = { - map1: map2, - map2: map2, - topoSimilarity: 1, - visionSimilarity: 1, - size: size - }; - data.push(selfTrain2); - } - } - if (hasTransform) { - const transform = generateTransformData( - id1, - id2, - map1, - map2, - topoSimilarity - ); - data.push(...transform.map(v => v[1])) - } - if (hasSimilar) { - const similar = generateSimilarData(id1, map1); - data.push(...similar.map(v => v[1])) - } - return data; -} - -export function generatePair( - data: Record, - id1: string, - id2: string, - map1: number[][], - map2: number[][], - size: [number, number] -) { - const topoSimilarity = compareMap(id1, id2, map1, map2); - const visionSimilarity = calculateVisualSimilarity(map1, map2); - const train: MinamoTrainData = { - map1, - map2, - topoSimilarity, - visionSimilarity, - size: size - }; - data[`${id1}:${id2}`] = train; - // 自身与自身对比的训练集,保证模型对相同地图输出 1 - const self1 = `${id1}:${id1}`; - const self2 = `${id2}:${id2}`; - const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1)); - if (selfTrain.includes(self1) && !data[`${id1}:${id1}`]) { - const selfTrain1: MinamoTrainData = { - map1: map1, - map2: map1, - topoSimilarity: 1, - visionSimilarity: 1, - size: size - }; - data[`${id1}:${id1}`] = selfTrain1; - } - if (selfTrain.includes(self2) && !data[`${id2}:${id2}`]) { - const selfTrain2: MinamoTrainData = { - map1: map2, - map2: map2, - topoSimilarity: 1, - visionSimilarity: 1, - size: size - }; - data[`${id2}:${id2}`] = selfTrain2; - } - // 翻转、旋转训练集 - Object.assign( - data, - Object.fromEntries( - generateTransformData(id1, id2, map1, map2, topoSimilarity) - ) - ); - // 地图微调训练集 - Object.assign(data, Object.fromEntries(generateSimilarData(id1, map1))); -} - -function generateDataset( - floors: Map, - pairs: number[], - floorIds: string[] -): Record { - const data: Record = {}; - - const progress = new SingleBar({}, Presets.shades_classic); - - progress.start(pairs.length, 0); - - pairs.forEach((v, i) => { - const num1 = Math.floor(v / floorIds.length); - const num2 = v % floorIds.length; - const id1 = floorIds[num1]; - const id2 = floorIds[num2]; - const map1 = floors.get(id1)?.map; - const map2 = floors.get(id2)?.map; - if (!map1 || !map2) return; - const [w1, h1] = [map1[0].length, map1.length]; - const [w2, h2] = [map2[0].length, map2.length]; - if (w1 !== w2 || h1 !== h2) return; - generatePair(data, id1, id2, map1, map2, [w1, h1]); - progress.update(i + 1); - }); - - progress.stop(); - - return data; -} - -export function parseMinamo(data: Map): MinamoDataset { - const length = data.size; - const totalCount = Math.round((length * (length - 1)) / 2); - - const pairs = choosePair(length, 10000); - - console.log( - `✅ 共发现 ${length} 个楼层,共 ${totalCount} 种组合,选取 ${pairs.length} 个组合` - ); - - const trainData = generateDataset(data, pairs, [...data.keys()]); - - const dataset: MinamoDataset = { - datasetId: Math.floor(Math.random() * 1e12), - data: trainData - }; - - return dataset; -} - -export function generateAssignedData( - data1: Map, - data2: Map, - count: [number, number] -): MinamoDataset { - const length = data1.size + data2.size; - const totalCount = data1.size * data2.size; - const count1 = Math.min(count[0], data1.size); - const count2 = Math.min(count[1], data2.size); - const keys1 = [...data1.keys()]; - const keys2 = [...data2.keys()]; - const choose1 = chooseFrom(keys1, count1); - - const trainData: Record = {}; - - console.log( - `✅ 共发现 ${length} 个楼层,共 ${totalCount} 种组合,选取 ${ - count1 * count2 - } 个组合` - ); - - const progress = new SingleBar({}, Presets.shades_classic); - progress.start(count1 * count2, 0); - let n = 0; - - for (const key1 of choose1) { - const choose2 = chooseFrom(keys2, count2); - for (const key2 of choose2) { - const { map: map1 } = data1.get(key1)!; - const { map: map2 } = data2.get(key2)!; - if (!map1 || !map2) continue; - const [w1, h1] = [map1[0].length, map1.length]; - const [w2, h2] = [map2[0].length, map2.length]; - if (w1 !== w2 || h1 !== h2) continue; - generatePair(trainData, key1, key2, map1, map2, [w1, h1]); - n++; - progress.update(n); - } - } - - progress.stop(); - - const dataset: MinamoDataset = { - datasetId: Math.floor(Math.random() * 1e12), - data: trainData - }; - - return dataset; -} diff --git a/data/src/review.ts b/data/src/review.ts deleted file mode 100644 index 989da4a..0000000 --- a/data/src/review.ts +++ /dev/null @@ -1,38 +0,0 @@ -import { readFile, writeFile } from 'fs-extra'; -import { chooseFrom, DatasetMergable, mergeDataset } from './utils'; - -const [target, ...review] = process.argv.slice(2); -const n = getNum(); - -function getNum() { - const last = review.at(-1); - if (!last) return 1000; - else { - const n = parseInt(last); - if (!n) return 1000; - else { - review.pop(); - return n; - } - } -} - -(async () => { - const datas = await Promise.all( - review.map(async v => { - const file = await readFile(v, 'utf-8'); - return JSON.parse(file) as DatasetMergable; - }) - ); - const targetFile = await readFile(target, 'utf-8'); - const targetData = JSON.parse(targetFile) as DatasetMergable; - const merged = mergeDataset(true, ...datas); - const keys = Object.keys(merged.data); - const toReview = chooseFrom(keys, n); - const reviewData: DatasetMergable = { - datasetId: Math.floor(Math.random() * 1e12), - data: Object.fromEntries(toReview.map(v => [v, merged.data[v]])) - }; - const reviewed = mergeDataset(false, targetData, reviewData); - await writeFile(target, JSON.stringify(reviewed), 'utf-8'); -})(); diff --git a/data/src/topology/test.ts b/data/src/topology/test.ts deleted file mode 100644 index 24d5ffb..0000000 --- a/data/src/topology/test.ts +++ /dev/null @@ -1,95 +0,0 @@ -import { buildTopologicalGraph } from './graph'; -import { mirrorMapX, mirrorMapY, rotateMap } from './transform'; -import { overallSimilarity } from './similarity'; - -(() => { - // MT3 - const map1 = [ - [1, 1, 1, 1, 1, 1, 1], - [1, 3, 1, 10, 7, 3, 1], - [1, 6, 1, 5, 1, 7, 1], - [1, 9, 1, 8, 1, 6, 1], - [1, 5, 8, 0, 1, 7, 1], - [1, 2, 1, 5, 7, 10, 1], - [1, 1, 1, 1, 1, 1, 1] - ]; - // MT6 - const map2 = [ - [1, 1, 1, 1, 1, 1, 1], - [1, 5, 6, 4, 1, 10, 1], - [1, 6, 1, 9, 1, 5, 1], - [1, 8, 0, 6, 0, 8, 1], - [1, 5, 1, 10, 1, 2, 1], - [1, 9, 3, 1, 4, 9, 1], - [1, 1, 1, 1, 1, 1, 1] - ]; - // MT8 - // const map2 = [ - // [1, 1, 1, 1, 1, 1, 1], - // [1, 5, 8, 10, 7, 2, 1], - // [1, 2, 1, 5, 1, 7, 1], - // [1, 3, 1, 3, 6, 4, 1], - // [1, 6, 1, 6, 1, 8, 1], - // [1, 10, 7, 5, 1, 5, 1], - // [1, 1, 1, 1, 1, 1, 1] - // ]; - // MT3 微调 - // const map2 = [ - // [1, 1, 1, 1, 1, 1, 1], - // [1, 3, 1, 10, 7, 3, 1], - // [1, 6, 1, 5, 1, 7, 1], - // [1, 9, 1, 8, 1, 6, 1], - // [1, 5, 8, 0, 1, 7, 1], - // [1, 2, 1, 4, 7, 10, 1], - // [1, 1, 1, 1, 1, 1, 1] - // ]; - - // 1. 两张图与自身对比 - const graph1 = buildTopologicalGraph(map1); - const graph2 = buildTopologicalGraph(map2); - - console.log(`map1 vs map1: ${overallSimilarity(graph1, graph1)}`); - console.log(`map2 vs map2: ${overallSimilarity(graph2, graph2)}`); - - // 2. 两张图相互对比 - console.log(`map1 vs map2: ${overallSimilarity(graph1, graph2)}`); - console.log(`map2 vs map1: ${overallSimilarity(graph2, graph1)}`); - - // 3. x镜像对比 - const xFlipped1 = mirrorMapX(map1); - const xFlipped2 = mirrorMapX(map2); - const graphX1 = buildTopologicalGraph(xFlipped1); - const graphX2 = buildTopologicalGraph(xFlipped2); - console.log(`map1:x vs map1: ${overallSimilarity(graphX1, graph1)}`); - console.log(`map1:x vs map2: ${overallSimilarity(graphX1, graph2)}`); - console.log(`map1 vs map2:x: ${overallSimilarity(graph1, graphX2)}`); - console.log(`map2:x vs map2: ${overallSimilarity(graphX2, graph2)}`); - console.log(`map2:x vs map1: ${overallSimilarity(graphX2, graph2)}`); - console.log(`map2 vs map1:x: ${overallSimilarity(graph2, graphX1)}`); - - // 4. y镜像对比 - const yFlipped1 = mirrorMapY(map1); - const yFlipped2 = mirrorMapY(map2); - const graphY1 = buildTopologicalGraph(yFlipped1); - const graphY2 = buildTopologicalGraph(yFlipped2); - console.log(`map1:y vs map1: ${overallSimilarity(graphY1, graph1)}`); - console.log(`map1:y vs map2: ${overallSimilarity(graphY1, graph2)}`); - console.log(`map1 vs map2:y: ${overallSimilarity(graph1, graphY2)}`); - console.log(`map2:y vs map2: ${overallSimilarity(graphY2, graph2)}`); - console.log(`map2:y vs map1: ${overallSimilarity(graphY2, graph1)}`); - console.log(`map2 vs map1:y: ${overallSimilarity(graph2, graphY1)}`); - - // 5. xy 镜像混合对比 - console.log(`map1:x vs map1:y: ${overallSimilarity(graphX1, graphY1)}`); - console.log(`map1:y vs map2:x: ${overallSimilarity(graphY1, graphX2)}`); - console.log(`map1:x vs map2:x: ${overallSimilarity(graphX1, graphX2)}`); - console.log(`map1:x vs map2:y: ${overallSimilarity(graphX1, graphY2)}`); - - // 6. 旋转对比 - const rot901 = rotateMap(map1); - const rot902 = rotateMap(map2); - const graph901 = buildTopologicalGraph(rot901); - const graph902 = buildTopologicalGraph(rot902); - console.log(`map1:90 vs map1: ${overallSimilarity(graph1, graph901)}`); - console.log(`map2:90 vs map2: ${overallSimilarity(graph2, graph902)}`); -})(); diff --git a/data/src/topology/transform.ts b/data/src/topology/transform.ts deleted file mode 100644 index 9afef83..0000000 --- a/data/src/topology/transform.ts +++ /dev/null @@ -1,13 +0,0 @@ -export function mirrorMapX(map: number[][]) { - return map.map(v => [...v].reverse()); -} - -export function mirrorMapY(map: number[][]) { - return [...map].reverse(); -} - -export function rotateMap(map: number[][]) { - return [ - ...map[0].map((_, colIndex) => map.map(row => row[colIndex])) - ].reverse(); -} diff --git a/data/src/utils.ts b/data/src/utils.ts index d11a879..a6ef14f 100644 --- a/data/src/utils.ts +++ b/data/src/utils.ts @@ -1,7 +1,6 @@ import { readFile } from 'fs-extra'; import { join } from 'path'; import { BaseConfig, GinkaConfig, TowerInfo } from './types'; -import { convertFloor } from './floor'; export interface DatasetMergable { datasetId: number; @@ -77,81 +76,6 @@ export async function parseTowerInfo( }; } -export async function getAllFloors(...info: TowerInfo[]) { - const floorData = await Promise.all( - info.map(async tower => { - // 获取必要信息 - const enemyFile = await readFile( - join(tower.path, 'enemys.js'), - 'utf-8' - ); - const mapFile = await readFile( - join(tower.path, 'maps.js'), - 'utf-8' - ); - const enemyMap = JSON.parse( - enemyFile.split('\n').slice(1).join('\n') - ) as Record; - const mapData = JSON.parse( - mapFile.split('\n').slice(1).join('\n') - ) as Record; - const enemyNumMap: Record = {}; - // 将怪物转化为数字映射 - for (const [key, value] of Object.entries(mapData)) { - if (value.cls === 'enemys') { - enemyNumMap[parseInt(key)] = enemyMap[value.id]; - } - } - - return Promise.all( - tower.floorIds.map(async id => { - const floorFile = await readFile( - join(tower.path, 'floors', `${id}.js`), - 'utf-8' - ); - try { - const data = JSON.parse( - floorFile - // .replaceAll("'", '"') - .slice(floorFile.indexOf('=') + 1) - ); - - const map = data.map as number[][]; - // 裁剪地图 - const { clip } = tower.config; - const area = clip.special[id] ?? clip.defaults; - - return convertFloor( - map, - area, - tower.config as GinkaConfig, - enemyNumMap - ); - } catch (e) { - console.log( - `Error when processing '${tower.name}' '${id}'` - ); - throw e; - } - }) - ); - }) - ); - const maps: Map = new Map(); - floorData.forEach((tower, tid) => { - const name = info[tid].name; - tower.forEach((map, mid) => { - const floorId = info[tid].floorIds[mid]; - maps.set(`${name}::${floorId}`, { - map, - id: floorId, - config: info[tid].config - }); - }); - }); - return maps; -} - export function mergeFloorIds(...info: TowerInfo[]) { const ids: string[] = []; info.forEach(v => { @@ -160,14 +84,6 @@ export function mergeFloorIds(...info: TowerInfo[]) { return ids; } -export async function readOne(path: string) { - if (path.endsWith('.json')) { - return fromJSON(path); - } else { - return getAllFloors(await parseTowerInfo(path, 'minamo-config.json')); - } -} - export async function fromJSON(path: string) { const file = await readFile(path, 'utf-8'); const data = JSON.parse(file) as Record; diff --git a/data/src/vision/similarity.ts b/data/src/vision/similarity.ts deleted file mode 100644 index 5caab99..0000000 --- a/data/src/vision/similarity.ts +++ /dev/null @@ -1,145 +0,0 @@ -interface VisualSimilarityConfig { - // 类型重要性权重(需根据游戏设定调整) - typeWeights: { [key: number]: number }; - // 是否启用视觉焦点增强 - enableVisualFocus: boolean; - // 是否启用密度感知 - enableDensityAwareness: boolean; -} - -const DEFAULT_CONFIG: VisualSimilarityConfig = { - typeWeights: { - 0: 0.2, // 空地 - 1: 0.3, // 墙壁 - 2: 0.6, // 钥匙 - 3: 0.7, // 红宝石 - 4: 0.7, // 蓝宝石 - 5: 0.5, // 血瓶 - 6: 0.4, // 门 - 7: 0.5, // 弱怪 - 8: 0.6, // 中怪 - 9: 0.6, // 强怪 - 10: 0.4, // 楼梯 - 11: 0.4, // 箭头 - 12: 0.7 // 道具 - }, - enableVisualFocus: true, - enableDensityAwareness: true -}; - -export function calculateVisualSimilarity( - map1: number[][], - map2: number[][], - config = DEFAULT_CONFIG -): number { - // 尺寸校验 - if (map1.length !== map2.length || map1[0]?.length !== map2[0]?.length) { - return 0; // 或抛出异常 - } - - const rows = map1.length; - const cols = map1[0].length; - let totalScore = 0; - let maxPossibleScore = 0; - - // 视觉焦点权重图 - const focusWeights = config.enableVisualFocus - ? generateFocusWeights(rows, cols) - : Array(rows) - .fill(1) - .map(() => Array(cols).fill(1)); - - // 类型密度分布计算 - const densityMap = config.enableDensityAwareness - ? calculateDensityImpact(map1, map2, config.typeWeights) - : Array(rows) - .fill(1) - .map(() => Array(cols).fill(1)); - - for (let i = 0; i < rows; i++) { - for (let j = 0; j < cols; j++) { - const type1 = map1[i][j]; - const type2 = map2[i][j]; - - // 基础类型权重 - const baseWeight = Math.max( - config.typeWeights[type1] || 0.5, - config.typeWeights[type2] || 0.5 - ); - - // 空间权重组合 - const spatialWeight = focusWeights[i][j] * densityMap[i][j]; - - // 类型匹配得分 - const typeScore = type1 === type2 ? 1 : 0; - - totalScore += typeScore * baseWeight * spatialWeight; - maxPossibleScore += baseWeight * spatialWeight; - } - } - - return maxPossibleScore > 0 ? totalScore / maxPossibleScore : 0; -} - -/** - * 生成视觉焦点权重图(基于人类视觉注意力分布) - */ -function generateFocusWeights(rows: number, cols: number): number[][] { - const weights = []; - const centerX = cols / 2; - const centerY = rows / 2; - const maxDist = Math.sqrt(centerX ** 2 + centerY ** 2) * 0.7; - - for (let i = 0; i < rows; i++) { - const rowWeights = []; - for (let j = 0; j < cols; j++) { - // 使用高斯分布模拟视觉焦点 - const dx = (j - centerX) / cols; - const dy = (i - centerY) / rows; - const distance = Math.sqrt(dx ** 2 + dy ** 2); - const gaussian = Math.exp(-(distance ** 2) / (2 * 0.3 ** 2)); - rowWeights.push(1.0 + 0.6 * gaussian); // 中心区域最高1.6倍权重 - } - weights.push(rowWeights); - } - return weights; -} - -/** - * 计算类型密度影响权重 - */ -function calculateDensityImpact( - map1: number[][], - map2: number[][], - typeWeights: { [key: number]: number } -): number[][] { - const rows = map1.length; - const cols = map1[0].length; - const densityMap = Array(rows) - .fill(0) - .map(() => Array(cols).fill(0)); - - // 滑动窗口分析局部密度 - const windowSize = 3; - const halfWindow = Math.floor(windowSize / 2); - - for (let i = 0; i < rows; i++) { - for (let j = 0; j < cols; j++) { - let density = 0; - for (let di = -halfWindow; di <= halfWindow; di++) { - for (let dj = -halfWindow; dj <= halfWindow; dj++) { - const ni = i + di; - const nj = j + dj; - if (ni >= 0 && ni < rows && nj >= 0 && nj < cols) { - const weight1 = typeWeights[map1[ni][nj]] || 0.5; - const weight2 = typeWeights[map2[ni][nj]] || 0.5; - density += (weight1 + weight2) / 2; - } - } - } - // 密度权重:高密度区域增强对比度 - densityMap[i][j] = 1.0 + 0.4 * (density / windowSize ** 2); - } - } - return densityMap; -} diff --git a/data/src/vision/test.ts b/data/src/vision/test.ts deleted file mode 100644 index 388d355..0000000 --- a/data/src/vision/test.ts +++ /dev/null @@ -1,74 +0,0 @@ -import { calculateVisualSimilarity } from './similarity'; - -(() => { - // MT3 - const map1 = [ - [1, 1, 1, 1, 1, 1, 1], - [1, 3, 1, 10, 7, 3, 1], - [1, 6, 1, 5, 1, 7, 1], - [1, 9, 1, 8, 1, 6, 1], - [1, 5, 8, 0, 1, 7, 1], - [1, 2, 1, 5, 7, 10, 1], - [1, 1, 1, 1, 1, 1, 1] - ]; - // MT6 - const map2 = [ - [1, 1, 1, 1, 1, 1, 1], - [1, 5, 6, 4, 1, 10, 1], - [1, 6, 1, 9, 1, 5, 1], - [1, 8, 0, 6, 0, 8, 1], - [1, 5, 1, 10, 1, 2, 1], - [1, 9, 3, 1, 4, 9, 1], - [1, 1, 1, 1, 1, 1, 1] - ]; - // MT8 - const map3 = [ - [1, 1, 1, 1, 1, 1, 1], - [1, 5, 8, 10, 7, 2, 1], - [1, 2, 1, 5, 1, 7, 1], - [1, 3, 1, 3, 6, 4, 1], - [1, 6, 1, 6, 1, 8, 1], - [1, 10, 7, 5, 1, 5, 1], - [1, 1, 1, 1, 1, 1, 1] - ]; - // MT3 微调 - const map4 = [ - [1, 1, 1, 1, 1, 1, 1], - [1, 3, 1, 10, 7, 3, 1], - [1, 6, 1, 5, 1, 7, 1], - [1, 9, 1, 8, 1, 6, 1], - [1, 5, 8, 0, 1, 7, 1], - [1, 2, 1, 4, 7, 10, 1], - [1, 1, 1, 1, 1, 1, 1] - ]; - // MT10 - const map5 = [ - [1, 1, 1, 1, 1, 1, 1], - [1, 5, 1, 10, 1, 5, 1], - [1, 6, 7, 7, 7, 6, 1], - [1, 1, 6, 5, 6, 1, 1], - [1, 4, 5, 9, 5, 4, 1], - [1, 3, 1, 1, 1, 3, 1], - [1, 1, 1, 1, 1, 1, 1] - ]; - - // 测试自我对比 - console.log(`map1 vs map1: ${calculateVisualSimilarity(map1, map1)}`); - console.log(`map2 vs map2: ${calculateVisualSimilarity(map2, map2)}`); - console.log(`map3 vs map3: ${calculateVisualSimilarity(map3, map3)}`); - console.log(`map4 vs map4: ${calculateVisualSimilarity(map4, map4)}`); - // 两两测试 - console.log(`map1 vs map2: ${calculateVisualSimilarity(map1, map2)}`); - console.log(`map1 vs map3: ${calculateVisualSimilarity(map1, map3)}`); - console.log(`map1 vs map4: ${calculateVisualSimilarity(map1, map4)}`); - console.log(`map1 vs map5: ${calculateVisualSimilarity(map1, map5)}`); - console.log(`map2 vs map3: ${calculateVisualSimilarity(map2, map3)}`); - console.log(`map2 vs map4: ${calculateVisualSimilarity(map2, map4)}`); - console.log(`map2 vs map5: ${calculateVisualSimilarity(map2, map5)}`); - console.log(`map3 vs map4: ${calculateVisualSimilarity(map3, map4)}`); - console.log(`map3 vs map5: ${calculateVisualSimilarity(map3, map5)}`); - console.log(`map4 vs map5: ${calculateVisualSimilarity(map4, map5)}`); - // 测试交换性 - console.log(`map2 vs map1: ${calculateVisualSimilarity(map2, map1)}`); - console.log(`map4 vs map2: ${calculateVisualSimilarity(map4, map2)}`); -})(); diff --git a/ginka/common/common.py b/ginka/common/common.py deleted file mode 100644 index 2ac3c7c..0000000 --- a/ginka/common/common.py +++ /dev/null @@ -1,133 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch_geometric.nn import GCNConv, TransformerConv -from torch_geometric.utils import grid - -def batch_edge_index(B, edge_index, num_nodes_per_batch): - # 批次偏移 edge_index - edge_index = edge_index.clone() # [2, E] - batch_edge_index = [] - for i in range(B): - offset = i * num_nodes_per_batch - batch_edge_index.append(edge_index + offset) - return torch.cat(batch_edge_index, dim=1) - -class DoubleConvBlock(nn.Module): - def __init__(self, feats: tuple[int, int, int]): - super().__init__() - self.cnn = nn.Sequential( - nn.Conv2d(feats[0], feats[1], 3, padding=1, padding_mode='replicate'), - nn.InstanceNorm2d(feats[1]), - nn.GELU(), - - nn.Conv2d(feats[1], feats[2], 3, padding=1, padding_mode='replicate'), - nn.InstanceNorm2d(feats[2]), - nn.GELU(), - ) - - def forward(self, x): - x = self.cnn(x) - return x - -class GCNBlock(nn.Module): - def __init__(self, in_ch, hidden_ch, out_ch, w, h): - super().__init__() - self.conv1 = GCNConv(in_ch, hidden_ch) - self.conv2 = GCNConv(hidden_ch, hidden_ch) - self.conv3 = GCNConv(hidden_ch, out_ch) - self.norm1 = nn.LayerNorm(hidden_ch) - self.norm2 = nn.LayerNorm(hidden_ch) - self.norm3 = nn.LayerNorm(out_ch) - self.single_edge_index, _ = grid(h, w) # [2, E] for a single map - - def forward(self, x): - # x: [B, C, H, W] - B, C, H, W = x.shape - - # Reshape to [B * H * W, C] - x = x.permute(0, 2, 3, 1).reshape(B * H * W, C) - - # Construct batched edge index - device = x.device - edge_index = batch_edge_index(B, self.single_edge_index.to(device), H * W) - - # Batch vector for PyG (not strictly needed for GCNConv, but useful if you switch to GAT/Pooling) - # batch = torch.arange(B, device=device).repeat_interleave(H * W) - - # GCN forward - x = self.conv1(x, edge_index) - x = F.gelu(self.norm1(x)) - x = self.conv2(x, edge_index) - x = F.gelu(self.norm2(x)) - x = self.conv3(x, edge_index) - x = F.gelu(self.norm3(x)) - - # Reshape back to [B, C, H, W] - x = x.view(B, H, W, -1).permute(0, 3, 1, 2) - return x - -class TransformerGCNBlock(nn.Module): - def __init__(self, in_ch, hidden_ch, out_ch, w, h): - super().__init__() - self.conv1 = TransformerConv(in_ch, hidden_ch // 8, heads=8, concat=True) - self.conv2 = TransformerConv(hidden_ch, out_ch, heads=1) - self.norm1 = nn.LayerNorm(hidden_ch) - self.norm2 = nn.LayerNorm(out_ch) - self.single_edge_index, _ = grid(h, w) # [2, E] for a single map - - def forward(self, x): - # x: [B, C, H, W] - B, C, H, W = x.shape - - # Reshape to [B * H * W, C] - x = x.permute(0, 2, 3, 1).reshape(B * H * W, C) - - # Construct batched edge index - device = x.device - edge_index = batch_edge_index(B, self.single_edge_index.to(device), H * W) - - # Batch vector for PyG (not strictly needed for GCNConv, but useful if you switch to GAT/Pooling) - # batch = torch.arange(B, device=device).repeat_interleave(H * W) - - # GCN forward - x = self.conv1(x, edge_index) - x = F.gelu(self.norm1(x)) - x = self.conv2(x, edge_index) - x = F.gelu(self.norm2(x)) - - # Reshape back to [B, C, H, W] - x = x.view(B, H, W, -1).permute(0, 3, 1, 2) - return x - -class ConvFusionModule(nn.Module): - def __init__(self, in_ch, hidden_ch, out_ch, w: int, h: int): - super().__init__() - self.cnn = DoubleConvBlock([in_ch, hidden_ch, in_ch]) - self.gcn = TransformerGCNBlock(in_ch, hidden_ch, in_ch, w, h) - self.fusion = DoubleConvBlock([in_ch*2, hidden_ch, out_ch]) - - def forward(self, x): - x1 = self.cnn(x) - x2 = self.gcn(x) - x = torch.cat([x1, x2], dim=1) - x = self.fusion(x) - return x - -class DoubleFCModule(nn.Module): - def __init__(self, in_dim, hidden_dim, out_dim): - super().__init__() - self.fc = nn.Sequential( - nn.Linear(in_dim, hidden_dim), - nn.LayerNorm(hidden_dim), - nn.GELU(), - - nn.Linear(hidden_dim, out_dim), - nn.LayerNorm(out_dim), - nn.GELU() - ) - - def forward(self, x): - x = self.fc(x) - return x - \ No newline at end of file diff --git a/ginka/common/cond.py b/ginka/common/cond.py deleted file mode 100644 index ddd7b44..0000000 --- a/ginka/common/cond.py +++ /dev/null @@ -1,50 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from .common import DoubleFCModule - -class ConditionEncoder(nn.Module): - def __init__(self, tag_dim=64, val_dim=16, hidden_dim=256, out_dim=256): - super().__init__() - self.tag_embed = DoubleFCModule(tag_dim, hidden_dim, hidden_dim) - self.val_embed = DoubleFCModule(val_dim, hidden_dim, hidden_dim) - self.stage_embed = DoubleFCModule(1, hidden_dim, hidden_dim) - self.encoder = nn.TransformerEncoder( - nn.TransformerEncoderLayer( - d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4, - batch_first=True - ), - num_layers=4 - ) - self.fusion = nn.Sequential( - nn.Linear(hidden_dim, hidden_dim), - nn.LayerNorm(hidden_dim), - nn.GELU(), - - nn.Linear(hidden_dim, out_dim) - ) - - def forward(self, tag, val, stage): - # tag = self.tag_embed(tag) - val = self.val_embed(val) - stage = self.stage_embed(stage) - feat = torch.stack([val, stage], dim=1) - feat = self.encoder(feat) - feat = torch.mean(feat, dim=1) - feat = self.fusion(feat) - return feat - -class ConditionInjector(nn.Module): - def __init__(self, cond_dim, out_dim): - super().__init__() - self.gamma_layer = nn.Sequential( - nn.Linear(cond_dim, out_dim) - ) - self.beta_layer = nn.Sequential( - nn.Linear(cond_dim, out_dim) - ) - - def forward(self, x, cond): - gamma = self.gamma_layer(cond).unsqueeze(2).unsqueeze(3) - beta = self.beta_layer(cond).unsqueeze(2).unsqueeze(3) - return x * gamma + beta diff --git a/ginka/critic/model.py b/ginka/critic/model.py deleted file mode 100644 index 04c2cf2..0000000 --- a/ginka/critic/model.py +++ /dev/null @@ -1,292 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn.utils import spectral_norm -from torch_geometric.nn import global_max_pool, GCNConv, TransformerConv -from torch_geometric.utils import grid -from shared.constant import VISION_WEIGHT, TOPO_WEIGHT -from .vision import MinamoVisionModel -from .topo import MinamoTopoModel - -def print_memory(tag=""): - print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB") - -def batch_edge_index(B, edge_index, num_nodes_per_batch): - # 批次偏移 edge_index - edge_index = edge_index.clone() # [2, E] - batch_edge_index = [] - for i in range(B): - offset = i * num_nodes_per_batch - batch_edge_index.append(edge_index + offset) - return torch.cat(batch_edge_index, dim=1) - -class DoubleConvBlock(nn.Module): - def __init__(self, feats: tuple[int, int, int]): - super().__init__() - self.cnn = nn.Sequential( - spectral_norm(nn.Conv2d(feats[0], feats[1], 3, padding=1, padding_mode='replicate')), - nn.GELU(), - - spectral_norm(nn.Conv2d(feats[1], feats[2], 3, padding=1, padding_mode='replicate')), - nn.GELU(), - ) - - def forward(self, x): - x = self.cnn(x) - return x - -class TransformerGCNBlock(nn.Module): - def __init__(self, in_ch, hidden_ch, out_ch, w, h): - super().__init__() - self.conv1 = TransformerConv(in_ch, hidden_ch // 8, heads=8, concat=True) - self.conv2 = TransformerConv(hidden_ch, out_ch, heads=1) - self.single_edge_index, _ = grid(h, w) # [2, E] for a single map - - def forward(self, x): - B, C, H, W = x.shape - x = x.permute(0, 2, 3, 1).reshape(B * H * W, C) - device = x.device - edge_index = batch_edge_index(B, self.single_edge_index.to(device), H * W) - x = self.conv1(x, edge_index) - x = F.gelu(x) - x = self.conv2(x, edge_index) - x = F.gelu(x) - x = x.view(B, H, W, -1).permute(0, 3, 1, 2) - return x - -class ConvFusionModule(nn.Module): - def __init__(self, in_ch, hidden_ch, out_ch, w: int, h: int): - super().__init__() - self.cnn = DoubleConvBlock([in_ch, hidden_ch, in_ch]) - self.gcn = TransformerGCNBlock(in_ch, hidden_ch, in_ch, w, h) - self.fusion = DoubleConvBlock([in_ch*2, hidden_ch, out_ch]) - - def forward(self, x): - x1 = self.cnn(x) - x2 = self.gcn(x) - x = torch.cat([x1, x2], dim=1) - x = self.fusion(x) - return x - -class DoubleFCModule(nn.Module): - def __init__(self, in_dim, hidden_dim, out_dim): - super().__init__() - self.fc = nn.Sequential( - spectral_norm(nn.Linear(in_dim, hidden_dim)), - nn.GELU(), - - spectral_norm(nn.Linear(hidden_dim, out_dim)), - nn.GELU() - ) - - def forward(self, x): - x = self.fc(x) - return x - -class ConditionEncoder(nn.Module): - def __init__(self, tag_dim, val_dim, hidden_dim, out_dim): - super().__init__() - self.tag_embed = DoubleFCModule(tag_dim, hidden_dim, hidden_dim) - self.val_embed = DoubleFCModule(val_dim, hidden_dim, hidden_dim) - self.stage_embed = DoubleFCModule(1, hidden_dim, hidden_dim) - self.encoder = nn.TransformerEncoder( - nn.TransformerEncoderLayer( - d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4, - batch_first=True - ), - num_layers=4 - ) - self.fusion = nn.Sequential( - spectral_norm(nn.Linear(hidden_dim, hidden_dim)), - nn.GELU(), - - spectral_norm(nn.Linear(hidden_dim, out_dim)) - ) - - def forward(self, tag, val, stage): - tag = self.tag_embed(tag) - val = self.val_embed(val) - stage = self.stage_embed(stage) - feat = torch.stack([tag, val, stage], dim=1) - feat = self.encoder(feat) - feat = torch.mean(feat, dim=1) - feat = self.fusion(feat) - return feat - -class ConditionInjector(nn.Module): - def __init__(self, cond_dim, out_dim): - super().__init__() - self.gamma_layer = nn.Sequential( - spectral_norm(nn.Linear(cond_dim, out_dim)) - ) - self.beta_layer = nn.Sequential( - spectral_norm(nn.Linear(cond_dim, out_dim)) - ) - - def forward(self, x, cond): - gamma = self.gamma_layer(cond).unsqueeze(2).unsqueeze(3) - beta = self.beta_layer(cond).unsqueeze(2).unsqueeze(3) - return x * gamma + beta - -class CNNHead(nn.Module): - def __init__(self, in_ch): - super().__init__() - self.cnn = nn.Sequential( - spectral_norm(nn.Conv2d(in_ch, in_ch, 3)), - nn.GELU(), - - nn.AdaptiveMaxPool2d((2, 2)) - ) - self.fc = nn.Sequential( - spectral_norm(nn.Linear(in_ch*2*2, 1)) - ) - self.proj = spectral_norm(nn.Linear(256, in_ch*2*2)) - - def forward(self, x, cond): - x = self.cnn(x) - B, C, H, W = x.shape - x = x.view(B, -1) - cond = self.proj(cond) - proj = torch.sum(x * cond, dim=1, keepdim=True) - x = self.fc(x) + proj - return x - -class GCNHead(nn.Module): - def __init__(self, in_dim): - super().__init__() - self.gcn = GCNConv(in_dim, in_dim) - self.proj = spectral_norm(nn.Linear(256, in_dim)) - self.fc = nn.Sequential( - spectral_norm(nn.Linear(in_dim, 1)) - ) - - def forward(self, x, graph, cond): - x = self.gcn(x, graph.edge_index) - x = F.gelu(x) - x = global_max_pool(x, graph.batch) - cond = self.proj(cond) - proj = torch.sum(x * cond, dim=1, keepdim=True) - x = self.fc(x) + proj - return x - -class MinamoScoreHead(nn.Module): - def __init__(self, vision_dim, topo_dim): - super().__init__() - self.vision_head = CNNHead(vision_dim) - self.topo_head = GCNHead(topo_dim) - - def forward(self, vis, topo, graph, cond): - vis_score = self.vision_head(vis, cond) - topo_score = self.topo_head(topo, graph, cond) - return vis_score, topo_score - -class MinamoModel(nn.Module): - def __init__(self, tile_types=32): - super().__init__() - self.topo_model = MinamoTopoModel(tile_types) - self.vision_model = MinamoVisionModel(tile_types) - self.cond = ConditionEncoder(64, 16, 256, 256) - # 输出层 - self.head1 = MinamoScoreHead(512, 512) - self.head2 = MinamoScoreHead(512, 512) - self.head3 = MinamoScoreHead(512, 512) - - def forward(self, map, graph, stage, tag_cond, val_cond): - B, D = tag_cond.shape - stage_tensor = torch.Tensor([stage]).expand(B, 1).to(map.device) - vision = self.vision_model(map) - topo = self.topo_model(graph) - cond = self.cond(tag_cond, val_cond, stage_tensor) - if stage == 1: - vision_score, topo_score = self.head1(vision, topo, graph, cond) - elif stage == 2: - vision_score, topo_score = self.head2(vision, topo, graph, cond) - elif stage == 3: - vision_score, topo_score = self.head3(vision, topo, graph, cond) - else: - raise RuntimeError("Unknown critic stage.") - score = VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score - return score, vision_score, topo_score - -class MinamoHead2(nn.Module): - def __init__(self, in_ch, hidden_ch): - super().__init__() - self.conv = ConvFusionModule(in_ch, hidden_ch, hidden_ch, 13, 13) - self.pool = nn.AdaptiveMaxPool2d(1) - self.proj = spectral_norm(nn.Linear(256, hidden_ch)) - self.fc = spectral_norm(nn.Linear(hidden_ch, 1)) - - def forward(self, x, cond): - x = self.conv(x) - x = self.pool(x) - x = x.squeeze(3).squeeze(2) - cond = self.proj(cond) - proj = torch.sum(x * cond, dim=1, keepdim=True) - x = self.fc(x) + proj - return x - -class MinamoModel2(nn.Module): - def __init__(self, tile_types=32): - super().__init__() - self.cond = ConditionEncoder(64, 16, 256, 256) - - self.conv1 = ConvFusionModule(tile_types, 256, 256, 13, 13) - self.conv2 = ConvFusionModule(256, 512, 256, 13, 13) - self.conv3 = ConvFusionModule(256, 512, 256, 13, 13) - - self.head0 = MinamoHead2(256, 256) # 随机头的判别头 - self.head1 = MinamoHead2(256, 256) - self.head2 = MinamoHead2(256, 256) - self.head3 = MinamoHead2(256, 256) - - # self.inject1 = ConditionInjector(256, 256) - # self.inject2 = ConditionInjector(256, 256) - self.inject3 = ConditionInjector(256, 256) - - def forward(self, x, stage, tag_cond, val_cond): - B, D = tag_cond.shape - stage_tensor = torch.Tensor([stage]).expand(B, 1).to(x.device) - cond = self.cond(tag_cond, val_cond, stage_tensor) - x = self.conv1(x) - # x = self.inject1(x, cond) - x = self.conv2(x) - # x = self.inject2(x, cond) - x = self.conv3(x) - x = self.inject3(x, cond) - - if stage == 0: - score = self.head0(x, cond) - elif stage == 1: - score = self.head1(x, cond) - elif stage == 2: - score = self.head2(x, cond) - elif stage == 3: - score = self.head3(x, cond) - else: - raise RuntimeError("Unknown critic stage.") - - return score - -# 检查显存占用 -if __name__ == "__main__": - input = torch.randn((1, 32, 13, 13)).cuda() - tag = torch.rand(1, 64).cuda() - val = torch.rand(1, 16).cuda() - - # 初始化模型 - model = MinamoModel2().cuda() - - print_memory("初始化后") - - # 前向传播 - output = model(input, 1, tag, val) - - print_memory("前向传播后") - - print(f"输入形状: feat={input.shape}") - print(f"输出形状: output={output.shape}") - # print(f"Vision parameters: {sum(p.numel() for p in model.vision_model.parameters())}") - # print(f"Topo parameters: {sum(p.numel() for p in model.topo_model.parameters())}") - print(f"Cond parameters: {sum(p.numel() for p in model.cond.parameters())}") - print(f"Head parameters: {sum(p.numel() for p in model.head1.parameters())}") - print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") diff --git a/ginka/critic/topo.py b/ginka/critic/topo.py deleted file mode 100644 index 43b7eaf..0000000 --- a/ginka/critic/topo.py +++ /dev/null @@ -1,36 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn.utils import spectral_norm -from torch_geometric.nn import GATConv, TransformerConv -from torch_geometric.data import Data - -class MinamoTopoModel(nn.Module): - def __init__( - self, tile_types=32, emb_dim=128, hidden_dim=128, out_dim=512 - ): - super().__init__() - # 传入 softmax 概率值,直接映射 - self.input_proj = nn.Sequential( - spectral_norm(nn.Linear(tile_types, emb_dim)), - nn.LeakyReLU(0.2) - ) - # 图卷积层 - self.conv1 = TransformerConv(emb_dim, hidden_dim, heads=8) - self.conv2 = TransformerConv(hidden_dim*8, hidden_dim, heads=8) - self.conv3 = TransformerConv(hidden_dim*8, out_dim, heads=1) - - def forward(self, graph: Data): - x = self.input_proj(graph.x) - - x = self.conv1(x, graph.edge_index) - x = F.leaky_relu(x, 0.2) - - x = self.conv2(x, graph.edge_index) - x = F.leaky_relu(x, 0.2) - - x = self.conv3(x, graph.edge_index) - x = F.leaky_relu(x, 0.2) - - return x - \ No newline at end of file diff --git a/ginka/critic/vision.py b/ginka/critic/vision.py deleted file mode 100644 index 6e7b847..0000000 --- a/ginka/critic/vision.py +++ /dev/null @@ -1,22 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn.utils import spectral_norm - -class MinamoVisionModel(nn.Module): - def __init__(self, in_ch=32, out_ch=512): - super().__init__() - self.conv = nn.Sequential( - spectral_norm(nn.Conv2d(in_ch, in_ch*2, 3)), # 11*11 - nn.LeakyReLU(0.2), - - spectral_norm(nn.Conv2d(in_ch*2, in_ch*8, 3)), #9*9 - nn.LeakyReLU(0.2), - - spectral_norm(nn.Conv2d(in_ch*8, out_ch, 3)), # 7*7 - nn.LeakyReLU(0.2), - ) - - def forward(self, x): - x = self.conv(x) - return x diff --git a/ginka/dataset.py b/ginka/dataset.py index b4a6873..512d0a2 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -8,13 +8,6 @@ import torch import torch.nn.functional as F from typing import List -STAGE1_MASK = [0, 1, 2, 29, 30] -STAGE1_REMOVE = list(range(3, 29)) -STAGE2_MASK = [3, 4, 5, 6, 26, 27, 28] -STAGE2_REMOVE = list(range(7, 26)) -STAGE3_MASK = list(range(7, 26)) -STAGE3_REMOVE = [] - def load_data(path: str): with open(path, 'r', encoding="utf-8") as f: data = json.load(f) @@ -24,220 +17,6 @@ def load_data(path: str): data_list.append(value) return data_list - -def load_minamo_gan_data(data: list): - res = list() - for one in data: - res.append((one['map1'], one['map2'], one['visionSimilarity'], one['topoSimilarity'], True)) - return res - -def apply_curriculum_remove( - maps: torch.Tensor, - remove_classes: List[int], # 要移除的类别索引 -): - C, H, W = maps.shape - device = maps.device - removed_maps = maps.clone() - - remove_mask = removed_maps[remove_classes, :, :].sum(dim=0, keepdim=True) > 0 - removed_maps[:, :, :][remove_mask.expand(C, -1, -1)] = 0 - removed_maps[0][remove_mask[0, :, :]] = 1 # 设置为“空地” - - return removed_maps.to(device) - -def apply_curriculum_mask( - maps: torch.Tensor, # [C, H, W] - mask_classes: List[int], # 要遮挡的类别索引 - remove_classes: List[int], # 要移除的类别索引 - mask_ratio: float # 遮挡比例 0~1 -) -> torch.Tensor: - C, H, W = maps.shape - masked_maps = maps.clone() - - # Step 1: 移除不需要的类别(全设为 0 类) - if remove_classes: - remove_mask = masked_maps[remove_classes, :, :].sum(dim=0, keepdim=True) > 0 - masked_maps[:, :, :][remove_mask.expand(C, -1, -1)] = 0 - masked_maps[0][remove_mask[0, :, :]] = 1 # 设置为“空地” - - removed_maps = masked_maps.clone() - - # Step 2: 对指定类别随机遮挡 - for cls in mask_classes: - cls_mask = masked_maps[cls] > 0 # 目标类别的像素布尔掩码 [H, W] - indices = cls_mask.nonzero(as_tuple=False) # 所有该类像素坐标 - num_mask = int(len(indices) * mask_ratio) - if num_mask > 0: - selected = indices[torch.randperm(len(indices))[:num_mask]] - masked_maps[cls, selected[:, 0], selected[:, 1]] = 0 - masked_maps[0, selected[:, 0], selected[:, 1]] = 1 # 置为“空地” - - return removed_maps, masked_maps - -def apply_curriculum_wall_mask( - maps: torch.Tensor, # [C, H, W] - mask_classes: List[int], # 要遮挡的类别索引 - remove_classes: List[int], # 要移除的类别索引 - mask_ratio: float # 遮挡比例 0~1 -) -> torch.Tensor: - C, H, W = maps.shape - masked_maps = maps.clone() - - # Step 1: 移除不需要的类别(全设为 0 类) - if remove_classes: - remove_mask = masked_maps[remove_classes, :, :].sum(dim=0, keepdim=True) > 0 - masked_maps[:, :, :][remove_mask.expand(C, -1, -1)] = 0 - masked_maps[0][remove_mask[0, :, :]] = 1 # 设置为“空地” - - removed_maps = masked_maps.clone() - - area = H * W * mask_ratio - l = math.floor(math.sqrt(area)) - nx = random.randint(0, W - l) - ny = random.randint(0, H - l) - masked_maps[mask_classes, nx:nx+l, ny:ny+l] = 0 - masked_maps[0, nx:nx+l, ny:ny+l] = 1 - - return removed_maps, masked_maps - -class GinkaWGANDataset(Dataset): - def __init__(self, data_path: str, device): - self.data = load_data(data_path) # 自定义数据加载函数 - self.device = device - self.train_stage = 1 - self.mask_ratio1 = 0.1 - self.mask_ratio2 = 0.1 - self.mask_ratio3 = 0.1 - - def __len__(self): - return len(self.data) - - def handle_stage1(self, target, tag_cond, val_cond): - # 课程学习第一阶段,蒙版填充 - removed1, masked1 = apply_curriculum_wall_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1) - removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2) - removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3) - rand = torch.rand(32, 32, 32, device=target.device) - - return { - "rand": rand, - "real0": removed1, - "real1": removed1, - "masked1": masked1, - "real2": removed2, - "masked2": masked2, - "real3": removed3, - "masked3": masked3, - "tag_cond": tag_cond, - "val_cond": val_cond - } - - def handle_stage2(self, target, tag_cond, val_cond): - # 课程学习第二阶段,完全随机蒙版 - removed1, masked1 = apply_curriculum_wall_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9)) - # 后面两个阶段由于会保留一些类别,所以完全随机遮挡即可 - removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 1)) - removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 1)) - rand = torch.rand(32, 32, 32, device=target.device) - - return { - "rand": rand, - "real0": removed1, - "real1": removed1, - "masked1": masked1, - "real2": removed2, - "masked2": masked2, - "real3": removed3, - "masked3": masked3, - "tag_cond": tag_cond, - "val_cond": val_cond - } - - def handle_stage3(self, target, tag_cond, val_cond): - # 第三阶段,联合生成,输入随机蒙版 - removed1 = apply_curriculum_remove(target, STAGE1_REMOVE) - removed2 = apply_curriculum_remove(target, STAGE2_REMOVE) - removed3 = apply_curriculum_remove(target, STAGE3_REMOVE) - rand = torch.rand(32, 32, 32, device=target.device) - - return { - "rand": rand, - "real0": removed1, - "real1": removed1, - "masked1": removed1, - "real2": removed2, - "masked2": torch.zeros_like(target), - "real3": removed3, - "masked3": torch.zeros_like(target), - "tag_cond": tag_cond, - "val_cond": val_cond - } - - def handle_stage4(self, target, tag_cond, val_cond): - # 第四阶段,完全随机输入 - removed1 = apply_curriculum_remove(target, STAGE1_REMOVE) - removed2 = apply_curriculum_remove(target, STAGE2_REMOVE) - removed3 = apply_curriculum_remove(target, STAGE3_REMOVE) - rand = torch.rand(32, 32, 32, device=target.device) - - return { - "rand": rand, - "real0": removed1, - "real1": removed1, - "masked1": rand, - "real2": removed2, - "masked2": torch.zeros_like(target), - "real3": removed3, - "masked3": torch.zeros_like(target), - "tag_cond": tag_cond, - "val_cond": val_cond - } - - def __getitem__(self, idx): - item = self.data[idx] - - target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] - C, H, W = target.shape - tag_cond = torch.FloatTensor(item['tag']) - val_cond = torch.FloatTensor(item['val']) - val_cond[9] = val_cond[9] / H / W - val_cond[10] = val_cond[10] / H / W - - if self.train_stage == 1: - return self.handle_stage1(target, tag_cond, val_cond) - - elif self.train_stage == 2: - return self.handle_stage2(target, tag_cond, val_cond) - - elif self.train_stage == 3: - return self.handle_stage3(target, tag_cond, val_cond) - - elif self.train_stage == 4: - return self.handle_stage4(target, tag_cond, val_cond) - - raise RuntimeError(f"Invalid train stage: {self.train_stage}") - -class GinkaRNNDataset(Dataset): - def __init__(self, data_path: str, device): - self.data = load_data(data_path) # 自定义数据加载函数 - self.device = device - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - item = self.data[idx] - - target = torch.LongTensor(item['map']) # [H, W] - H, W = target.shape - tag_cond = torch.FloatTensor(item['tag']) - val_cond = torch.FloatTensor(item['val']) - - return { - "tag_cond": tag_cond, - "val_cond": val_cond, - "target_map": target - } class GinkaMaskGITDataset(Dataset): def __init__(self, data_path: str, device): diff --git a/ginka/generator/gcn.py b/ginka/generator/gcn.py deleted file mode 100644 index cb4d24a..0000000 --- a/ginka/generator/gcn.py +++ /dev/null @@ -1,37 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch_geometric.nn import GCNConv, TransformerConv -from torch_geometric.utils import grid -from ..common.cond import ConditionInjector - -# 考虑使用 GCN 作为生成器主路径,暂时先留着 - -class GCNBlock(nn.Module): - def __init__(self, feats: tuple[int, int, int]): - super().__init__() - self.conv1 = GCNConv(feats[0], feats[1]) - self.conv2 = GCNConv(feats[1], feats[2]) - - self.norm1 = nn.LayerNorm(feats[1]) - self.norm2 = nn.LayerNorm(feats[2]) - - def forward(self, x, edge_index): - x = self.conv1(x, edge_index) - x = F.elu(self.norm1(x)) - - x = self.conv2(x, edge_index) - x = F.elu(self.norm2(x)) - return x - -class GinkaGCNEncoder(nn.Module): - def __init__(self): - super().__init__() - -class GinkaGCNDecoder(nn.Module): - def __init__(self): - super().__init__() - -class GinkaGCNModel(nn.Module): - def __init__(self): - super().__init__() \ No newline at end of file diff --git a/ginka/generator/input.py b/ginka/generator/input.py deleted file mode 100644 index c6fcef6..0000000 --- a/ginka/generator/input.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch -import torch.nn as nn -from ..common.common import ConvFusionModule -from ..common.cond import ConditionInjector -from .unet import GinkaEncoderPath, GinkaDecoderPath - -class RandomInputHead(nn.Module): - def __init__(self): - super().__init__() - self.enc = GinkaEncoderPath(32, 32) - self.dec = GinkaDecoderPath(32) - self.out_conv = nn.Sequential( - nn.AdaptiveMaxPool2d((15, 15)), - nn.Conv2d(32, 64, 3, padding=0), - nn.InstanceNorm2d(64), - nn.GELU(), - - nn.Conv2d(64, 32, 1), - ) - - def forward(self, x, cond): - x1, x2, x3, x4 = self.enc(x, cond) - x = self.dec(x1, x2, x3, x4, cond) - x = self.out_conv(x) - return x - -class InputUpsample(nn.Module): - def __init__(self, in_ch, hidden_ch=64, out_ch=64): - super().__init__() - self.net = nn.Sequential( - ConvFusionModule(in_ch, hidden_ch, hidden_ch, 13, 13), - - nn.Upsample(scale_factor=2, mode='nearest'), # 13x13 → 26x26 - ConvFusionModule(hidden_ch, hidden_ch, hidden_ch, 26, 26), - - nn.Upsample(size=(32, 32), mode='nearest'), # 26x26 → 32x32 - ConvFusionModule(hidden_ch, hidden_ch, out_ch, 32, 32), - ) - - def forward(self, x): # [B, C, 13, 13] - x = self.net(x) # [B, C, 32, 32] - return x - -class GinkaInput(nn.Module): - def __init__(self, in_ch=32, out_ch=64, in_size=(13, 13), out_size=(32, 32)): - super().__init__() - self.out_size = out_size - self.upsample = InputUpsample(in_ch, in_ch*2, out_ch) - self.enc = ConvFusionModule(out_ch, out_ch*2, out_ch, out_size[0], out_size[1]) - self.inject1 = ConditionInjector(256, out_ch) - self.inject2 = ConditionInjector(256, out_ch) - - def forward(self, x, cond): - x = self.upsample(x) - x = self.inject1(x, cond) - x = self.enc(x) - x = self.inject2(x, cond) - return x diff --git a/ginka/generator/loss.py b/ginka/generator/loss.py deleted file mode 100644 index 66a68ee..0000000 --- a/ginka/generator/loss.py +++ /dev/null @@ -1,420 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch_geometric.data import Data - -CLASS_NUM = 32 -ILLEGAL_MAX_NUM = 30 - -STAGE_CHANGEABLE = [ - [], - [0, 1, 2, 29, 30], - [3, 4, 5, 6, 26, 27, 28], - list(range(7, 26)) -] - -STAGE_ALLOWED = [ - [], - STAGE_CHANGEABLE[1], - [*STAGE_CHANGEABLE[1], *STAGE_CHANGEABLE[2]], - [*STAGE_CHANGEABLE[1], *STAGE_CHANGEABLE[2], *STAGE_CHANGEABLE[3]] -] - -DENSITY_MAP = [ - [1, *list(range(3, 30))], - [1], - [2], - [3, 4, 5, 6], - [26, 27, 28], - list(range(7, 26)), - list(range(10, 19)), - [19, 20, 21, 22], - [7, 8, 9], - [23, 24, 25], - [29, 30] -] - -DENSITY_WEIGHTS = [ - 1, - 1.5, - 0.5, - 5, - 4, - 3, - 3, - 3, - 5, - 10, - 20 -] - -DENSITY_STAGE = [ - [], - [1, 2], - [1, 2, 3, 4], - list(range(0, 10)) -] - -def get_not_allowed(classes: list[int], include_illegal=False): - res = list() - for num in range(0, CLASS_NUM): - if not num in classes: - if num > ILLEGAL_MAX_NUM: - if include_illegal: - res.append(num) - else: - res.append(num) - - return res - -def inner_constraint_loss(pred: torch.Tensor, allowed=list(range(0, 30))): - """限定内部允许出现的图块种类 - - Args: - pred (torch.Tensor): 模型输出的概率分布 [B, C, H, W] - allowed (list, optional): 在地图中部(除最外圈)允许出现的图块种类 - """ - B, C, H, W = pred.shape - - # 创建内部 mask [H, W] - mask = torch.ones((H, W), dtype=torch.bool, device=pred.device) - mask[0, :] = False # 第一行 - mask[-1, :] = False # 最后一行 - mask[:, 0] = False # 第一列 - mask[:, -1] = False # 最后一列 - - # 提取所有允许和不允许类别的概率和 [B, H, W] - unallowed_probs = pred[:, get_not_allowed(allowed, include_illegal=True), :, :].sum(dim=1) - - # 获取外圈区域允许类别的概率 [B, N_pixels] - inner_unallowed = unallowed_probs[:, mask] - - target = torch.zeros_like(inner_unallowed) - loss_unallowed = F.mse_loss(inner_unallowed, target) - - return loss_unallowed - -def _create_distance_kernel(size): - """生成一个环状衰减核""" - y, x = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij') - center = size // 2 - dist = torch.sqrt((x - center)**2 + (y - center)**2) - kernel = 1 / (dist + 1) - kernel /= kernel.sum() # 归一化 - return kernel.unsqueeze(0).unsqueeze(0) # [1,1,H,W] - -def entrance_constraint_loss( - pred: torch.Tensor, - entrance_classes=[29, 30], - min_distance=9, - presence_threshold=0.8, - lambda_presence=1.0, - lambda_spacing=0.5 -): - """ - 入口约束损失函数 - - 参数: - pred: 模型输出的概率分布 [B, C, H, W] - entrance_classes: 入口类别列表 - min_distance: 最小间隔距离(对应卷积核尺寸) - presence_threshold: 存在性概率阈值 - lambda_presence: 存在性损失权重 - lambda_spacing: 间距约束权重 - - 返回: - total_loss: 综合损失值 - """ - B, C, H, W = pred.shape - entrance_probs = pred[:, entrance_classes, :, :].sum(dim=1) # [B, H, W] - - # 计算存在性损失:鼓励至少有一个高置信度入口 - max_per_sample = entrance_probs.view(B, -1).max(dim=1)[0] # [B, H*W] -> [B, 1] - presence_loss = F.relu(presence_threshold - max_per_sample).mean() - - # 生成空间权重掩码(中心衰减) - y, x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij') - center_weight = 1 - torch.sqrt(((x-W//2)/W*2)**2 + ((y-H//2)/H*2)**2) - center_weight = center_weight.clamp(0,1).to(pred.device) # [H,W] - - # 概率密度感知的间距计算 - kernel = _create_distance_kernel(min_distance) # 自定义函数生成权重核 - kernel = kernel.to(pred.device) - density_map = F.conv2d(entrance_probs.unsqueeze(1), kernel, padding=min_distance-1) - - spacing_loss = density_map.mean() - - # 区域加权综合损失 - total_loss = ( - lambda_presence * presence_loss + - lambda_spacing * (spacing_loss * center_weight).mean() - ) - return total_loss - -def input_head_illegal_loss(input_map, allowed_classes=[0, 1, 2]): - C = input_map.shape[1] - unallowed = get_not_allowed(allowed_classes, include_illegal=True) - illegal = input_map[:, unallowed, :, :] - penalty = F.l1_loss(illegal, torch.zeros_like(illegal, device=illegal.device)) - - return penalty - -def input_head_wall_loss(input_map, max_wall_ratio=0.2, wall_class=[1, 2]): - wall_prob = input_map[:, wall_class] # [B, H, W] - wall_ratio = wall_prob.mean() # 计算平均墙体占比 - wall_penalty = torch.clamp(wall_ratio - max_wall_ratio, min=0.0) # 超过则惩罚 - - return wall_penalty - -def compute_multi_density_loss(probs, target_densities, tile_list): - """ - pred: [B, C, H, W] - target_densities: [B, N] - N 个目标类别密度 - class_indices: [N] - 对应类别通道索引 - """ - losses = [] - for i, classes in enumerate(DENSITY_MAP): - class_map = probs[:, classes, :, :] - pred_density = torch.mean(class_map, dim=(1, 2, 3)) - if i in tile_list: - loss = F.mse_loss(pred_density, target_densities[:, i]) - losses.append(loss * DENSITY_WEIGHTS[i]) - return sum(losses) - -# 对图像数据进行插值 -def interpolate_data(real_data, fake_data, epsilon): - return epsilon * real_data + (1 - epsilon) * fake_data - -# 对节点特征进行插值,但保持边连接关系不变 -def interpolate_graph_features(real_graph, fake_graph, epsilon=0.5): - # 插值节点特征 - x_real, x_fake = real_graph.x, fake_graph.x - x_interp = epsilon * x_real + (1 - epsilon) * x_fake - - # 保持边连接关系和边特征不变 - edge_index_interp = real_graph.edge_index # 保持边连接关系 - edge_attr_interp = real_graph.edge_attr # 如果有边特征,保持不变 - - return Data(x=x_interp, edge_index=edge_index_interp, edge_attr=edge_attr_interp) - -def js_divergence(p, q, eps=1e-6, softmax=False): - if softmax: - p = F.softmax(p, dim=1) - q = F.softmax(q, dim=1) - # softmax 后变成概率分布 - m = 0.5 * (p + q) - - # log_softmax 以供 kl_div 使用 - log_p = torch.log(p + eps) - log_q = torch.log(q + eps) - log_m = torch.log(m + eps) - - kl_pm = F.kl_div(log_p, log_m, reduction='batchmean', log_target=True) # KL(p || m) - kl_qm = F.kl_div(log_q, log_m, reduction='batchmean', log_target=True) # KL(q || m) - - return torch.log1p(0.5 * (kl_pm + kl_qm)) - -def immutable_penalty_loss( - pred: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int] -) -> torch.Tensor: - """ - 惩罚模型修改不可更改区域的损失。 - - Args: - input: 模型输出 [B, C, H, W],概率分布 (softmax 前) - target: 原始输入图 [B, C, H, W],概率分布 (softmax 前) - modifiable_classes: 允许被修改的类别列表 - """ - not_allowed = get_not_allowed(modifiable_classes, include_illegal=True) - input_mask = pred[:, not_allowed, :, :] - with torch.no_grad(): - target_mask = torch.argmax(input[:, not_allowed, :, :], dim=1) - target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float() - - # 差异区域(模型试图改变的地方) - penalty = torch.clamp(F.cross_entropy(input_mask, target_mask) - 0.2, min=0) - - return penalty - -def modifiable_penalty_loss( - probs: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int] -) -> torch.Tensor: - target_modifiable = input[:, modifiable_classes, :, :] - pred_modifiable = probs[:, modifiable_classes, :, :] - existed = torch.clamp(target_modifiable - pred_modifiable, min=0.0, max=1.0) - penalty = F.mse_loss(existed, torch.zeros_like(existed, device=existed.device)) - - return penalty - -def illegal_penalty_loss(pred: torch.Tensor, legal_classes: list[int]): - not_allowed = get_not_allowed(legal_classes, include_illegal=True) - input_mask = pred[:, not_allowed, :, :] - target = torch.zeros_like(input_mask) - penalty = F.cross_entropy(input_mask, target) - return penalty - -class WGANGinkaLoss: - def __init__(self, lambda_gp=100, weight=[1, 0.4, 20, 0.2, 0.2, 0.05, 0.4]): - # weight: - # 1. 判别器损失及图块维持损失(可修改部分的已有内容不可修改) - # 2. CE 损失 - # 3. 不可修改类型损失和非法图块损失 - # 4. 图块类型损失 - # 5. 入口存在性损失 - # 6. 多样性损失 - # 7. 密度损失 - self.lambda_gp = lambda_gp # 梯度惩罚系数 - self.weight = weight - - def compute_gradient_penalty(self, critic, stage, real_data, fake_data, tag_cond, val_cond): - # 进行插值 - batch_size = real_data.size(0) - epsilon_data = torch.rand(batch_size, 1, 1, 1, device=real_data.device) - interp_data = interpolate_data(real_data, fake_data, epsilon_data).to(real_data.device) - - # 对图像进行反向传播并计算梯度 - interp_data.requires_grad_() - - d_score = critic(interp_data, stage, tag_cond, val_cond) - - # 计算梯度 - grad = torch.autograd.grad( - outputs=d_score, inputs=interp_data, - grad_outputs=torch.ones_like(d_score), - create_graph=True, retain_graph=True, only_inputs=True - )[0] - - # 计算梯度的 L2 范数 - grad_norm = grad.reshape(batch_size, -1).norm(2, dim=1) - # 计算梯度惩罚项 - gp_loss = ((grad_norm - 1.0) ** 2).mean() - # print(grad_norm_topo.mean().item(), grad_norm_vis.mean().item()) - - return gp_loss - - def discriminator_loss( - self, critic, stage: int, real_data: torch.Tensor, fake_data: torch.Tensor, - tag_cond: torch.Tensor, val_cond: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - """ 判别器损失函数 """ - fake_data = F.softmax(fake_data, dim=1) - real_scores = critic(real_data, stage, tag_cond, val_cond) - fake_scores = critic(fake_data, stage, tag_cond, val_cond) - - # Wasserstein 距离 - d_loss = fake_scores.mean() - real_scores.mean() - grad_loss = self.compute_gradient_penalty(critic, stage, real_data, fake_data, tag_cond, val_cond) - - total_loss = d_loss + self.lambda_gp * grad_loss - - return total_loss, d_loss - - def generator_loss(self, critic, stage, mask_ratio, real, fake: torch.Tensor, input, tag_cond, val_cond) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ 生成器损失函数 """ - probs_fake = F.softmax(fake, dim=1) - - fake_scores = critic(probs_fake, stage, tag_cond, val_cond) - minamo_loss = -torch.mean(fake_scores) - ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小 - immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage]) - constraint_loss = inner_constraint_loss(probs_fake) - density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage]) - - fake_a, fake_b = fake.chunk(2, dim=0) - - losses = [ - minamo_loss * self.weight[0], - ce_loss * self.weight[1], - immutable_loss * self.weight[2], - constraint_loss * self.weight[3], - -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5], - density_loss * self.weight[6], - ] - - if stage == 1: - # 第一个阶段检查入口存在性 - entrance_loss = entrance_constraint_loss(probs_fake) - losses.append(entrance_loss * self.weight[4]) - - return sum(losses), ce_loss - - def generator_loss_total(self, critic, stage, fake, tag_cond, val_cond) -> torch.Tensor: - probs_fake = F.softmax(fake, dim=1) - - fake_scores = critic(probs_fake, stage, tag_cond, val_cond) - minamo_loss = -torch.mean(fake_scores) - illegal_loss = illegal_penalty_loss(probs_fake, STAGE_ALLOWED[stage]) - constraint_loss = inner_constraint_loss(probs_fake) - density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage]) - - fake_a, fake_b = fake.chunk(2, dim=0) - - losses = [ - minamo_loss * self.weight[0], - illegal_loss * self.weight[2], - constraint_loss * self.weight[3], - -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5], - density_loss * self.weight[6], - ] - - if stage == 1: - # 第一个阶段检查入口存在性 - entrance_loss = entrance_constraint_loss(probs_fake) - losses.append(entrance_loss * self.weight[4]) - - return sum(losses) - - def generator_loss_total_with_input(self, critic, stage, fake, input, tag_cond, val_cond) -> torch.Tensor: - probs_fake = F.softmax(fake, dim=1) - - fake_scores = critic(probs_fake, stage, tag_cond, val_cond) - minamo_loss = -torch.mean(fake_scores) - immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage]) - constraint_loss = inner_constraint_loss(probs_fake) - density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage]) - - fake_a, fake_b = fake.chunk(2, dim=0) - - losses = [ - minamo_loss * self.weight[0], - immutable_loss * self.weight[2], - constraint_loss * self.weight[3], - -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5], - density_loss * self.weight[6], - ] - - if stage == 1: - # 第一个阶段检查入口存在性 - entrance_loss = entrance_constraint_loss(probs_fake) - losses.append(entrance_loss * self.weight[4]) - - return sum(losses) - - def generator_input_head_loss(self, critic, probs: torch.Tensor, tag_cond, val_cond) -> torch.Tensor: - head_scores = -torch.mean(critic(probs, 0, tag_cond, val_cond)) - probs_a, probs_b = probs.chunk(2, dim=0) - - losses = [ - head_scores, - input_head_illegal_loss(probs) * 50, - -js_divergence(probs_a, probs_b, softmax=False) * 0.5 - ] - - return sum(losses) - -class RNNGinkaLoss: - def __init__(self, num_classes, device): - self.num_classes = num_classes - weight = torch.ones(self.num_classes) - weight[0] = 0.3 - weight[1] = 0.5 - self.weight = weight.to(device) - pass - - def rnn_loss(self, fake, target): - """ - fake: [B, C, H, W] - target: [B, H, W] - """ - target = F.one_hot(target, num_classes=self.num_classes).float().permute(0, 3, 1, 2) - return F.cross_entropy(fake, target, label_smoothing=0.1) diff --git a/ginka/generator/model.py b/ginka/generator/model.py deleted file mode 100644 index 8244bd7..0000000 --- a/ginka/generator/model.py +++ /dev/null @@ -1,66 +0,0 @@ -import time -import torch -import torch.nn as nn -import torch.nn.functional as F -from .unet import GinkaUNet -from .output import GinkaOutput -from .input import GinkaInput, RandomInputHead -from ..common.cond import ConditionEncoder - -def print_memory(tag=""): - print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB") - -class GinkaModel(nn.Module): - def __init__(self, base_ch=64, out_ch=32): - """Ginka Model 模型定义部分 - """ - super().__init__() - self.head = RandomInputHead() - self.cond = ConditionEncoder(64, 16, 256, 256) - self.input = GinkaInput(32, 64, (13, 13), (32, 32)) - self.unet = GinkaUNet(64, base_ch, base_ch) - self.output = GinkaOutput(base_ch, out_ch, (13, 13)) - - def forward(self, x, stage, tag_cond, val_cond): - B, D = tag_cond.shape - stage_tensor = torch.Tensor([stage]).expand(B, 1).to(x.device) - cond = self.cond(tag_cond, val_cond, stage_tensor) - if stage == 0: - x = self.head(x, cond) - else: - x = self.input(x, cond) - x = self.unet(x, cond) - x = self.output(x, stage, cond) - return x - -# 检查显存占用 -if __name__ == "__main__": - input = torch.rand(1, 32, 32, 32).cuda() - tag = torch.rand(1, 64).cuda() - val = torch.rand(1, 16).cuda() - - # 初始化模型 - model = GinkaModel().cuda() - - print_memory("初始化后") - - # 前向传播 - start = time.perf_counter() - fake0 = model(input, 0, tag, val) - fake1 = model(F.softmax(fake0, dim=1), 1, tag, val) - fake2 = model(F.softmax(fake1, dim=1), 1, tag, val) - fake3 = model(F.softmax(fake2, dim=1), 1, tag, val) - end = time.perf_counter() - - print_memory("前向传播后") - - print(f"推理耗时: {end - start}") - print(f"输入形状: feat={input.shape}") - print(f"输出形状: output={fake3.shape}") - print(f"Random parameters: {sum(p.numel() for p in model.head.parameters())}") - print(f"Cond parameters: {sum(p.numel() for p in model.cond.parameters())}") - print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}") - print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}") - print(f"Output parameters: {sum(p.numel() for p in model.output.parameters())}") - print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") - \ No newline at end of file diff --git a/ginka/generator/output.py b/ginka/generator/output.py deleted file mode 100644 index 25e038c..0000000 --- a/ginka/generator/output.py +++ /dev/null @@ -1,45 +0,0 @@ -import torch -import torch.nn as nn -from ..common.common import ConvFusionModule -from ..common.cond import ConditionInjector - -class StageHead(nn.Module): - def __init__(self, in_ch, out_ch, out_size=(13, 13)): - super().__init__() - self.dec1 = ConvFusionModule(in_ch, in_ch*2, in_ch*2, 32, 32) - self.dec2 = ConvFusionModule(in_ch*2, in_ch*2, in_ch*2, 32, 32) - self.pool = nn.Sequential( - ConvFusionModule(in_ch*2, in_ch*2, in_ch*2, 32, 32), - ConvFusionModule(in_ch*2, in_ch*2, in_ch, 32, 32), - - nn.AdaptiveMaxPool2d(out_size), - nn.Conv2d(in_ch, out_ch, 1) - ) - self.inject1 = ConditionInjector(256, in_ch*2) - self.inject2 = ConditionInjector(256, in_ch*2) - - def forward(self, x, cond): - x = self.dec1(x) - x = self.inject1(x, cond) - x = self.dec2(x) - x = self.inject2(x, cond) - x = self.pool(x) - return x - -class GinkaOutput(nn.Module): - def __init__(self, in_ch=64, out_ch=32, out_size=(13, 13)): - super().__init__() - self.head1 = StageHead(in_ch, out_ch, out_size) - self.head2 = StageHead(in_ch, out_ch, out_size) - self.head3 = StageHead(in_ch, out_ch, out_size) - - def forward(self, x, stage, cond): - if stage == 1: - x = self.head1(x, cond) - elif stage == 2: - x = self.head2(x, cond) - elif stage == 3: - x = self.head3(x, cond) - else: - raise RuntimeError("Unknown generate stage.") - return x diff --git a/ginka/generator/rnn.py b/ginka/generator/rnn.py deleted file mode 100644 index 7a5a409..0000000 --- a/ginka/generator/rnn.py +++ /dev/null @@ -1,258 +0,0 @@ -import time -import random -import torch -import torch.nn as nn -import torch.nn.functional as F - -class RNNConditionEncoder(nn.Module): - def __init__(self, val_dim=16, output_dim=256, width=13, height=13): - super().__init__() - - # 条件编码 - self.val_fc = nn.Sequential( - nn.Linear(val_dim, output_dim * 4), - nn.LayerNorm(output_dim * 4), - nn.GELU(), - ) - self.fusion = nn.Sequential( - nn.Linear(output_dim * 4, output_dim) - ) - - def forward(self, val_cond: torch.Tensor): - val_hidden = self.val_fc(val_cond) - return self.fusion(val_hidden) - -class GinkaMapPatch(nn.Module): - def __init__(self, tile_classes=32, width=13, height=13): - super().__init__() - - # 地图局部卷积,用于捕获局部结构信息 - - self.width = width - self.height = height - self.tile_classes = 32 - - self.patch_cnn = nn.Sequential( - nn.Conv2d(tile_classes + 1, 256, 3, padding=1), - nn.Dropout(0.2), - nn.BatchNorm2d(256), - nn.GELU(), - - nn.Conv2d(256, 512, 3), - nn.Dropout(0.2), - nn.BatchNorm2d(512), - nn.GELU(), - - nn.Flatten() - ) - self.fc = nn.Linear(512 * 3 * 3, 256) - - def forward(self, map: torch.Tensor, x: int, y: int): - """ - map: [B, H, W] - """ - B, H, W = map.shape - mask = torch.zeros([B, 5, 5]).to(map.device) - result = torch.zeros([B, 5, 5], dtype=torch.long).to(map.device) - left = x - 2 if x >= 2 else 0 - right = x + 3 if x < self.width - 2 else self.width - top = y - 4 if y >= 4 else 0 - bottom = y + 1 - - res_left = left - (x - 2) - res_right = right - (x + 3) + 5 - res_top = top - (y - 4) - res_bottom = 5 - - result[:, res_top:res_bottom, res_left:res_right] = map[:, top:bottom, left:right] - # 没画到的地方要置为 0 - result[:, 4, 2] = 0 - result[:, 4, 3] = 0 - result[:, 4, 4] = 0 - mask[:, res_top:res_bottom, res_left:res_right] = 1 - mask[:, 4, 2] = 0 - mask[:, 4, 3] = 0 - mask[:, 4, 4] = 0 - masked_result = torch.zeros([B, self.tile_classes + 1, 5, 5]).to(map.device) - masked_result[:, 0:32] = F.one_hot(result, num_classes=32).permute(0, 3, 2, 1).float() - masked_result[:, 32] = mask - - feat = self.patch_cnn(masked_result) - feat = self.fc(feat) - return feat - -class GinkaTileEmbedding(nn.Module): - def __init__(self, tile_classes=32, embed_dim=256): - super().__init__() - - # 图块编码,上一次画的图块 - - self.embedding = nn.Embedding(tile_classes, embed_dim) - - def forward(self, tile: torch.Tensor): - return self.embedding(tile) - -class GinkaPosEmbedding(nn.Module): - def __init__(self, width=13, height=13, embed_dim=256): - super().__init__() - - # 位置编码 - - self.width = width - self.height = height - - self.row_embedding = nn.Embedding(width, embed_dim) - self.col_embedding = nn.Embedding(height, embed_dim) - - def forward(self, x: torch.Tensor, y: torch.Tensor): - row = self.row_embedding(x).squeeze(1) - col = self.col_embedding(y).squeeze(1) - - return row, col - -class GinkaInputFusion(nn.Module): - def __init__(self, d_model=256): - super().__init__() - - # 使用 Transformer 进行信息整合 - - self.transformer = nn.TransformerEncoder( - nn.TransformerEncoderLayer( - d_model=d_model, nhead=2, dim_feedforward=d_model*2, batch_first=True, - dropout=0.2 - ), - num_layers=4 - ) - - def forward( - self, tile_embed: torch.Tensor, cond_vec: torch.Tensor, - row_embed: torch.Tensor, col_embed: torch.Tensor, patch_vec: torch.Tensor - ): - """ - tile_embed: [B, 256] - cond_vec: [B, 256] - row_embed: [B, 256] - col_embed: [B, 256] - patch_vec: [B, 256] - """ - vec = torch.stack([tile_embed, cond_vec, row_embed, col_embed, patch_vec], dim=1) - feat = self.transformer(vec) - return feat[:, 0] - -class GinkaRNN(nn.Module): - def __init__(self, tile_classes=32, input_dim=256, hidden_dim=512): - super().__init__() - - # GRU - self.gru = nn.GRUCell(input_dim, hidden_dim) - self.drop = nn.Dropout(0.2) - self.fc = nn.Sequential( - nn.Linear(hidden_dim, hidden_dim), - nn.LayerNorm(hidden_dim), - nn.GELU(), - - nn.Linear(hidden_dim, tile_classes) - ) - - def forward(self, feat_fusion: torch.Tensor, hidden: torch.Tensor): - """ - feat_fusion: [B, input_dim] - hidden: [B, hidden_dim] - """ - hidden = self.drop(self.gru(feat_fusion, hidden)) - logits = self.fc(hidden) - return logits, hidden - -class GinkaRNNModel(nn.Module): - def __init__(self, device: torch.device, start_tile=31, width=13, height=13): - super().__init__() - - self.device = device - self.width = width - self.height = height - self.start_tile = start_tile - - self.rnn_hidden = 512 - self.tile_classes = 32 - - # 模型结构 - self.cond = RNNConditionEncoder() - self.tile_embedding = GinkaTileEmbedding(tile_classes=self.tile_classes) - self.pos_embedding = GinkaPosEmbedding() - self.map_patch = GinkaMapPatch(tile_classes=self.tile_classes) - self.feat_fusion = GinkaInputFusion() - self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden) - - def forward(self, val_cond: torch.Tensor, target_map: torch.Tensor, use_self_probility=0): - """ - val_cond: [B, val_dim] - target_map: [B, H, W] - use_self: 是否使用自己生成的上一步结果执行下一步 - """ - B, C = val_cond.shape - - # 张量声明 - now_tile = torch.LongTensor([self.start_tile]).to(self.device).expand(B) - - map = torch.zeros([B, self.height, self.width], dtype=torch.int32).to(self.device) - output_logits = torch.zeros([B, self.height, self.width, self.tile_classes]).to(self.device) - hidden: torch.Tensor = torch.zeros(B, self.rnn_hidden).to(self.device) - - # 条件编码,全局,所以只用一次 - cond = self.cond(val_cond) - - for y in range(0, self.height): - for x in range(0, self.width): - x_tensor = torch.LongTensor([x]).to(self.device).expand(B, -1) - y_tensor = torch.LongTensor([y]).to(self.device).expand(B, -1) - # 位置编码、图块编码、地图局部编码 - tile_embed = self.tile_embedding(now_tile) - row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor) - use_self = random.random() < use_self_probility - map_patch = self.map_patch(map if use_self else target_map, x, y) - # 编码特征融合 - feat = self.feat_fusion(tile_embed, cond, row_embed, col_embed, map_patch) - # RNN 输出 - logits, h = self.rnn(feat, hidden) - # 处理输出 - output_logits[:, y, x] = logits[:] - hidden = h - tile_id = torch.argmax(logits, dim=1).detach() - map[:, y, x] = tile_id[:] - now_tile = tile_id if use_self else target_map[:, y, x].detach() - - return output_logits.permute(0, 3, 1, 2), map - -def print_memory(device, tag=""): - if torch.cuda.is_available(): - print(f"{tag} | 当前显存: {torch.cuda.memory_allocated(device) / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated(device) / 1024**2:.2f} MB") - else: - print("当前设备不支持 cuda.") - -if __name__ == "__main__": - device = torch.device("cpu") - - input = torch.randint(0, 32, [1, 13, 13]).to(device) - cond = torch.rand(1, 16).to(device) - - # 初始化模型 - model = GinkaRNNModel("cpu").to(device) - - print_memory("初始化后") - - # 前向传播 - start = time.perf_counter() - fake_logits, fake_map = model(cond, input, False) - end = time.perf_counter() - - print_memory("前向传播后") - - print(f"推理耗时: {end - start}") - print(f"输出形状: fake_logits={fake_logits.shape}, fake_map={fake_map.shape}") - print(f"Condition Encoder parameters: {sum(p.numel() for p in model.cond.parameters())}") - print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}") - print(f"Position Embedding parameters: {sum(p.numel() for p in model.pos_embedding.parameters())}") - print(f"Map Patch parameters: {sum(p.numel() for p in model.map_patch.parameters())}") - print(f"Feature Fusion parameters: {sum(p.numel() for p in model.feat_fusion.parameters())}") - print(f"RNN parameters: {sum(p.numel() for p in model.rnn.parameters())}") - print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") diff --git a/ginka/generator/unet.py b/ginka/generator/unet.py deleted file mode 100644 index c14db9a..0000000 --- a/ginka/generator/unet.py +++ /dev/null @@ -1,188 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from shared.attention import ChannelAttention -from ..common.common import GCNBlock, TransformerGCNBlock, DoubleConvBlock, ConvFusionModule -from ..common.cond import ConditionInjector - -class GinkaTransformerEncoder(nn.Module): - def __init__(self, in_dim, hidden_dim, out_dim, token_size, ff_dim, num_heads=8, num_layers=6): - super().__init__() - in_dim = in_dim // token_size - hidden_dim = hidden_dim // token_size - out_dim = out_dim // token_size - self.embedding = nn.Sequential( - nn.Linear(in_dim, hidden_dim), - nn.LayerNorm(hidden_dim) - ) - self.pos_embedding = nn.Parameter(torch.randn(1, token_size, hidden_dim)) - self.transformer = nn.TransformerEncoder( - nn.TransformerEncoderLayer(hidden_dim, num_heads, dim_feedforward=ff_dim, batch_first=True), - num_layers=num_layers - ) - self.fc = nn.Sequential( - nn.Linear(hidden_dim, out_dim), - nn.LayerNorm(out_dim) - ) - - def forward(self, x): - # 输入 [B, L, in_dim] - # 输出 [B, L, out_dim] - x = self.embedding(x) # [B, L, hidden_dim] - x = x + self.pos_embedding # [B, L, hidden_dim] - x = self.transformer(x) # [B, L, hidden_dim] - x = self.fc(x) # [B, L, out_dim] - return x - -class ConvBlock(nn.Module): - def __init__(self, in_ch, out_ch, attn=True): - super().__init__() - self.conv = DoubleConvBlock([in_ch, out_ch, out_ch]) - # self.conv = nn.Sequential( - # nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'), - # nn.InstanceNorm2d(out_ch), - # nn.ELU(), - # nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'), - # nn.InstanceNorm2d(out_ch), - # ) - # if attn: - # self.conv.append(ChannelAttention(out_ch)) - # self.conv.append(nn.ELU()) - - def forward(self, x): - return self.conv(x) - -class FusionModule(nn.Module): - def __init__(self, in_ch, out_ch): - super().__init__() - self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate') - - def forward(self, x1, x2): - x = torch.cat([x1, x2], dim=1) - x = self.conv(x) - return x - -class GinkaUNetInput(nn.Module): - def __init__(self, in_ch, out_ch, w, h): - super().__init__() - self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h) - self.inject = ConditionInjector(256, out_ch) - - def forward(self, x, cond): - x = self.conv(x) - x = self.inject(x, cond) - return x - -class GinkaEncoder(nn.Module): - def __init__(self, in_ch, out_ch, w, h): - super().__init__() - self.pool = nn.MaxPool2d(2) - self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h) - self.inject = ConditionInjector(256, out_ch) - - def forward(self, x, cond): - x = self.pool(x) - x = self.conv(x) - x = self.inject(x, cond) - return x - -class GinkaUpSample(nn.Module): - def __init__(self, in_ch, out_ch): - super().__init__() - self.conv = nn.Sequential( - nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2), - nn.InstanceNorm2d(out_ch), - nn.GELU(), - - nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'), - nn.InstanceNorm2d(out_ch), - nn.GELU() - ) - - def forward(self, x): - return self.conv(x) - -class GinkaDecoder(nn.Module): - def __init__(self, in_ch, out_ch, w, h): - super().__init__() - self.upsample = GinkaUpSample(in_ch, in_ch // 2) - self.fusion = nn.Conv2d(in_ch, in_ch, 1) - self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h) - self.inject = ConditionInjector(256, out_ch) - - def forward(self, x, feat, cond): - x = self.upsample(x) - x = torch.cat([x, feat], dim=1) - x = self.fusion(x) - x = self.conv(x) - x = self.inject(x, cond) - return x - -class GinkaBottleneck(nn.Module): - def __init__(self, module_ch, w, h): - super().__init__() - # self.transformer = GinkaTransformerEncoder( - # in_dim=module_ch*w*h, hidden_dim=module_ch*w*h, out_dim=module_ch*w*h, - # token_size=16, ff_dim=1024, num_layers=4 - # ) - # self.gcn = TransformerGCNBlock(module_ch, module_ch*2, module_ch, 4, 4) - # self.fusion = nn.Conv2d(module_ch*3, module_ch, 1) - self.conv = ConvFusionModule(module_ch, module_ch, module_ch, w, h) - self.inject = ConditionInjector(256, module_ch) - - def forward(self, x, cond): - # x1 = x.view(B, 512, 16).permute(0, 2, 1) # [B, 16, in_ch] - # x1 = self.transformer(x1) - # x1 = x1.permute(0, 2, 1).view(B, 512, 4, 4) # [B, out_ch, 4, 4] - x = self.conv(x) - x = self.inject(x, cond) - return x - -class GinkaEncoderPath(nn.Module): - def __init__(self, in_ch, base_ch): - super().__init__() - self.down1 = GinkaUNetInput(in_ch, base_ch, 32, 32) - self.down2 = GinkaEncoder(base_ch, base_ch*2, 16, 16) - self.down3 = GinkaEncoder(base_ch*2, base_ch*4, 8, 8) - self.down4 = GinkaEncoder(base_ch*4, base_ch*8, 4, 4) - - def forward(self, x, cond): - x1 = self.down1(x, cond) # [B, 64, 32, 32] - x2 = self.down2(x1, cond) # [B, 128, 16, 16] - x3 = self.down3(x2, cond) # [B, 256, 8, 8] - x4 = self.down4(x3, cond) # [B, 512, 4, 4] - - return x1, x2, x3, x4 - -class GinkaDecoderPath(nn.Module): - def __init__(self, base_ch): - super().__init__() - self.up1 = GinkaDecoder(base_ch*8, base_ch*4, 8, 8) - self.up2 = GinkaDecoder(base_ch*4, base_ch*2, 16, 16) - self.up3 = GinkaDecoder(base_ch*2, base_ch, 32, 32) - - def forward(self, x1, x2, x3, x4, cond): - x = self.up1(x4, x3, cond) # [B, 256, 8, 8] - x = self.up2(x, x2, cond) # [B, 128, 16, 16] - x = self.up3(x, x1, cond) # [B, 64, 32, 32] - return x - -class GinkaUNet(nn.Module): - def __init__(self, in_ch=32, base_ch=32, out_ch=32): - """Ginka Model UNet 部分 - """ - super().__init__() - self.enc = GinkaEncoderPath(in_ch, base_ch) - self.bottleneck = GinkaBottleneck(base_ch*8, 4, 4) - self.dec = GinkaDecoderPath(base_ch) - - self.final = ConvFusionModule(base_ch, base_ch, out_ch, 32, 32) - - def forward(self, x, cond): - x1, x2, x3, x4 = self.enc(x, cond) - x4 = self.bottleneck(x4, cond) # [B, 512, 4, 4] - x = self.dec(x1, x2, x3, x4, cond) - - x = self.final(x) # [B, 32, 32, 32] - - return x diff --git a/ginka/save.py b/ginka/save.py deleted file mode 100644 index e671f2f..0000000 --- a/ginka/save.py +++ /dev/null @@ -1,16 +0,0 @@ -import argparse -import torch - -def to_deployment(path: str, output: str): - state = torch.load(path) - torch.save({ - "model_state": state["model_state"] - }, output) - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument("--input", type=str, default="result/ginka.pth") - parser.add_argument("--output", type=str, default="result/ginka_deploy.pth") - args = parser.parse_args() - to_deployment(args.input, args.output) - \ No newline at end of file diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py index 8b2a636..8722fd9 100644 --- a/ginka/train_maskGIT.py +++ b/ginka/train_maskGIT.py @@ -89,9 +89,9 @@ def train(): # 用于生成图片 tile_dict = dict() - for file in os.listdir('tiles2'): + for file in os.listdir('tiles'): name = os.path.splitext(file)[0] - tile_dict[name] = cv2.imread(f"tiles2/{file}", cv2.IMREAD_UNCHANGED) + tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) # 接续训练 if args.resume: diff --git a/ginka/train_rnn.py b/ginka/train_rnn.py deleted file mode 100644 index ed921eb..0000000 --- a/ginka/train_rnn.py +++ /dev/null @@ -1,188 +0,0 @@ -import argparse -import os -import sys -from datetime import datetime -import torch -import torch.optim as optim -import cv2 -from torch_geometric.loader import DataLoader -from tqdm import tqdm -from .generator.rnn import GinkaRNNModel -from .dataset import GinkaRNNDataset -from .generator.loss import RNNGinkaLoss -from shared.image import matrix_to_image_cv - -# 手工标注标签定义(暂时不用): -# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层, -# 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具风 -# 16. 区域入口, 17. 区域连接, 18. 有机关门, 19. 道具层, 20. 斜向对称, 21. 左右通道, 22. 上下通道, 23. 多机关门 -# 24. 中心对称, 25. 部分对称, 26. 鱼骨 - -# 自动标注标签定义(暂时不用): -# 0. 左右对称, 1. 上下对称, 2. 中心对称, 3. 斜向对称, 4. 伪对称, 5. 多房间, 6. 多走廊 -# 32. 平面塔, 33. 转换塔, 34. 道具塔 - -# 标量值定义: -# 0. 整体密度,非空白图块/地图面积,空白图块还包括装饰图块 -# 1. 墙体密度,墙壁/地图面积 -# 2. 装饰密度,装饰数量/地图面积 -# 3. 门密度,门数量/地图面积 -# 4. 怪物密度,怪物数量/地图面积 -# 5. 资源密度,资源数量/地图面积 -# 6. 宝石密度,宝石数量/地图面积 -# 7. 血瓶密度,血瓶数量/地图面积 -# 8. 钥匙密度,钥匙数量/地图面积 -# 9. 道具密度,道具数量/地图面积 -# 10. 入口数量 -# 11. 机关门数量 -# 12. 咸鱼门数量(多层咸鱼门只算一个) - -# 图块定义: -# 0. 空地, 1. 墙壁, 2. 装饰(用于野外装饰,视为空地), -# 3. 黄门, 4. 蓝门, 5. 红门, 6. 机关门, 其余种类的门如绿门都视为红门 -# 7-9. 黄蓝红门钥匙,机关门不使用钥匙开启 -# 10-12. 三种等级的红宝石 -# 13-15. 三种等级的蓝宝石 -# 16-18. 三种等级的绿宝石 -# 19-22. 四种等级的血瓶 -# 23-25. 三种等级的道具 -# 26-28. 三种等级的怪物 -# 29. 入口,不区分楼梯和箭头 - -BATCH_SIZE = 96 - -device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") -os.makedirs("result", exist_ok=True) -os.makedirs("result/rnn", exist_ok=True) -os.makedirs("result/ginka_rnn_img", exist_ok=True) - -disable_tqdm = not sys.stdout.isatty() - -def gt_prob(epoch: int, max_epoch: int) -> float: - progress = epoch / max_epoch - return 0.1 + 0.9 * progress - -def parse_arguments(): - parser = argparse.ArgumentParser(description="training codes") - parser.add_argument("--resume", type=bool, default=False) - parser.add_argument("--state_ginka", type=str, default="result/rnn/ginka-100.pth") - parser.add_argument("--train", type=str, default="ginka-dataset.json") - parser.add_argument("--validate", type=str, default="ginka-eval.json") - parser.add_argument("--epochs", type=int, default=100) - parser.add_argument("--checkpoint", type=int, default=5) - parser.add_argument("--load_optim", type=bool, default=True) - args = parser.parse_args() - return args - -def train(): - print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") - - args = parse_arguments() - - ginka_rnn = GinkaRNNModel(device).to(device) - - dataset = GinkaRNNDataset(args.train, device) - dataset_val = GinkaRNNDataset(args.validate, device) - dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) - dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 8) - - optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4) - scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6) - - criterion = RNNGinkaLoss(32, device) - - # 用于生成图片 - tile_dict = dict() - for file in os.listdir('tiles'): - name = os.path.splitext(file)[0] - tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) - - if args.resume: - data_ginka = torch.load(args.state_ginka, map_location=device) - - ginka_rnn.load_state_dict(data_ginka["model_state"], strict=False) - - if args.load_optim: - if data_ginka.get("optim_state") is not None: - optimizer_ginka.load_state_dict(data_ginka["optim_state"]) - - print("Train from loaded state.") - - for epoch in tqdm(range(args.epochs), desc="RNN Training", disable=disable_tqdm): - loss_total_ginka = torch.Tensor([0]).to(device) - - iters = 0 - - for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): - val_cond = batch["val_cond"].to(device) - target_map = batch["target_map"].to(device) - - fake_logits, fake_map = ginka_rnn(val_cond, target_map, 1 - gt_prob(epoch, args.epochs)) - - loss = criterion.rnn_loss(fake_logits, target_map) - - loss.backward() - torch.nn.utils.clip_grad_norm_(ginka_rnn.parameters(), max_norm=1.0) - optimizer_ginka.step() - loss_total_ginka += loss.detach() - - iters += 1 - - # if iters % 50 == 0: - # avg_loss_ginka = loss_total_ginka.item() / iters - - # tqdm.write( - # f"[Iters {iters} {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + - # f"E: {epoch + 1} | Loss: {avg_loss_ginka:.6f} | " + - # f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}" - # ) - - avg_loss_ginka = loss_total_ginka.item() / len(dataloader) - tqdm.write( - f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + - f"E: {epoch + 1} | Loss: {avg_loss_ginka:.6f} | " + - f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}" - ) - - scheduler_ginka.step() - - # 每若干轮输出一次图片,并保存检查点 - if (epoch + 1) % args.checkpoint == 0: - # 保存检查点 - torch.save({ - "model_state": ginka_rnn.state_dict(), - "optim_state": optimizer_ginka.state_dict(), - }, f"result/rnn/ginka-{epoch + 1}.pth") - - val_loss_total = torch.Tensor([0]).to(device) - with torch.no_grad(): - idx = 0 - for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): - val_cond = batch["val_cond"].to(device) - target_map = batch["target_map"].to(device) - - fake_logits, fake_map = ginka_rnn(val_cond, target_map, 1 - gt_prob(epoch, args.epochs)) - - val_loss_total += criterion.rnn_loss(fake_logits, target_map).detach() - - fake_map = fake_map.cpu().numpy() - fake_img = matrix_to_image_cv(fake_map[0], tile_dict) - cv2.imwrite(f"result/ginka_rnn_img/{idx}.png", fake_img) - - idx += 1 - - avg_loss_val = val_loss_total.item() / len(dataloader_val) - tqdm.write( - f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch + 1} | " + - f"Loss: {avg_loss_val:.6f}" - ) - - print("Train ended.") - torch.save({ - "model_state": ginka_rnn.state_dict(), - }, f"result/ginka_rnn.pth") - - -if __name__ == "__main__": - torch.set_num_threads(4) - train() diff --git a/ginka/train_transformer.py b/ginka/train_transformer.py deleted file mode 100644 index e3b5198..0000000 --- a/ginka/train_transformer.py +++ /dev/null @@ -1,271 +0,0 @@ -import argparse -import os -import sys -import random -from datetime import datetime -import torch -import torch.nn.functional as F -import torch.optim as optim -import cv2 -import numpy as np -from torch_geometric.loader import DataLoader -from tqdm import tqdm -from .transformer.vae import GinkaTransformerVAE -from .vae_rnn.loss import VAELoss -from .vae_rnn.scheduler import VAEScheduler -from .dataset import GinkaRNNDataset -from shared.image import matrix_to_image_cv - -# 手工标注标签定义(暂时不用): -# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层, -# 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具风 -# 16. 区域入口, 17. 区域连接, 18. 有机关门, 19. 道具层, 20. 斜向对称, 21. 左右通道, 22. 上下通道, 23. 多机关门 -# 24. 中心对称, 25. 部分对称, 26. 鱼骨 - -# 自动标注标签定义(暂时不用): -# 0. 左右对称, 1. 上下对称, 2. 中心对称, 3. 斜向对称, 4. 伪对称, 5. 多房间, 6. 多走廊 -# 32. 平面塔, 33. 转换塔, 34. 道具塔 - -# 标量值定义: -# 0. 整体密度,非空白图块/地图面积,空白图块还包括装饰图块 -# 1. 墙体密度,墙壁/地图面积 -# 2. 装饰密度,装饰数量/地图面积 -# 3. 门密度,门数量/地图面积 -# 4. 怪物密度,怪物数量/地图面积 -# 5. 资源密度,资源数量/地图面积 -# 6. 宝石密度,宝石数量/地图面积 -# 7. 血瓶密度,血瓶数量/地图面积 -# 8. 钥匙密度,钥匙数量/地图面积 -# 9. 道具密度,道具数量/地图面积 -# 10. 入口数量 -# 11. 机关门数量 -# 12. 咸鱼门数量(多层咸鱼门只算一个) - -# 图块定义: -# 0. 空地, 1. 墙壁, 2. 门, 3. 钥匙, 4. 红宝石, 5. 蓝宝石, 6. 绿宝石, 7. 血瓶 -# 8. 道具, 9. 怪物, 10. 入口, 14. 起始 token, 15. 终止 token - -BATCH_SIZE = 128 -LATENT_DIM = 32 -KL_BETA = 0.01 -SELF_GATE = 0.5 -GATE_EPOCH = 5 -VAL_BATCH_DIVIDER = 128 -PROB_STEP = 0.05 -NUM_CLASSES = 16 - -device = torch.device( - "cuda:1" if torch.cuda.is_available() - else "mps" if torch.mps.is_available() - else "cpu" -) -os.makedirs("result", exist_ok=True) -os.makedirs("result/vae", exist_ok=True) -os.makedirs("result/ginka_vae_img", exist_ok=True) - -disable_tqdm = not sys.stdout.isatty() - -def parse_arguments(): - parser = argparse.ArgumentParser(description="training codes") - parser.add_argument("--resume", type=bool, default=False) - parser.add_argument("--state_ginka", type=str, default="result/vae/ginka-100.pth") - parser.add_argument("--train", type=str, default="ginka-dataset.json") - parser.add_argument("--validate", type=str, default="ginka-eval.json") - parser.add_argument("--epochs", type=int, default=100) - parser.add_argument("--checkpoint", type=int, default=5) - parser.add_argument("--load_optim", type=bool, default=True) - args = parser.parse_args() - return args - -def train(): - print(f"Using {device.type} to train model.") - - args = parse_arguments() - - vae = GinkaTransformerVAE(num_classes=NUM_CLASSES, latent_dim=LATENT_DIM).to(device) - - dataset = GinkaRNNDataset(args.train, device) - dataset_val = GinkaRNNDataset(args.validate, device) - dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) - dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True) - - optimizer_ginka = optim.AdamW(vae.parameters(), lr=3e-4, weight_decay=1e-2, betas=(0.9, 0.95)) - # 自定义调度器允许在 self_prob 提高时重置调度器信息并提高学习率以适应学习 - scheduler_ginka = VAEScheduler( - optimizer_ginka, factor=0.9, increase_factor=2, patience=10, max_lr=1e-4, min_lr=1e-6 - ) - - criterion = VAELoss() - - self_prob = 0 - prob_epochs = 0 - - # 用于生成图片 - tile_dict = dict() - for file in os.listdir('tiles2'): - name = os.path.splitext(file)[0] - tile_dict[name] = cv2.imread(f"tiles2/{file}", cv2.IMREAD_UNCHANGED) - - if args.resume: - data_ginka = torch.load(args.state_ginka, map_location=device) - - vae.load_state_dict(data_ginka["model_state"], strict=False) - - if args.load_optim: - if data_ginka.get("optim_state") is not None: - optimizer_ginka.load_state_dict(data_ginka["optim_state"]) - - print("Train from loaded state.") - - for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm): - loss_total = torch.Tensor([0]).to(device) - reco_loss_total = torch.Tensor([0]).to(device) - kl_loss_total = torch.Tensor([0]).to(device) - - vae.teacher_forcing() - for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): - target_map = batch["target_map"].to(device) - B, H, W = target_map.shape - input = target_map.view(B, H * W) - - optimizer_ginka.zero_grad() - fake_logits, mu, logvar = vae(input, self_prob) - - loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, input, mu, logvar, KL_BETA) - - loss.backward() - torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0) - optimizer_ginka.step() - loss_total += loss.detach() - reco_loss_total += reco_loss.detach() - kl_loss_total += kl_loss.detach() - - avg_loss = loss_total.item() / len(dataloader) - avg_reco_loss = reco_loss_total.item() / len(dataloader) - avg_kl_loss = kl_loss_total.item() / len(dataloader) - tqdm.write( - f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + - f"E: {epoch + 1} | Loss: {avg_loss:.6f} | Reco: {avg_reco_loss:.6f} | " + - f"KL: {avg_kl_loss:.6f} | Prob: {self_prob:.2f} | LR: {scheduler_ginka.get_last_lr()[0]:.6f}" - ) - - # 验证集 - # with torch.no_grad(): - # val_loss_total = torch.Tensor([0]).to(device) - # for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): - # target_map = batch["target_map"].to(device) - - # fake_logits, mu, logvar = vae(target_map, 1 - gt_prob) - - # loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA) - # val_loss_total += loss.detach() - - # avg_loss_val = val_loss_total.item() / len(dataloader_val) - - # 先使用训练集的损失值,因为过拟合比较严重,后续再想办法 - if avg_loss < SELF_GATE: - prob_epochs += 1 - else: - prob_epochs = 0 - - if prob_epochs >= GATE_EPOCH and self_prob < 1: - self_prob += PROB_STEP - prob_epochs = 0 - if self_prob > 1: - self_prob = 1 - - self_prob = 1 - - scheduler_ginka.step(avg_loss, self_prob) - - # 每若干轮输出一次图片,并保存检查点 - if (epoch + 1) % args.checkpoint == 0: - # 保存检查点 - torch.save({ - "model_state": vae.state_dict(), - "optim_state": optimizer_ginka.state_dict(), - }, f"result/rnn/ginka-{epoch + 1}.pth") - - val_loss_total = torch.Tensor([0]).to(device) - val_reco_loss_total = torch.Tensor([0]).to(device) - val_kl_loss_total = torch.Tensor([0]).to(device) - vae.eval() - with torch.no_grad(): - idx = 0 - gap = 5 - color = (255, 255, 255) # 白色 - vline = np.full((416, gap, 3), color, dtype=np.uint8) # 垂直分割线 - # 地图重建展示 - vae.teacher_forcing() - for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): - target_map = batch["target_map"].to(device) - B, H, W = target_map.shape - input = target_map.view(B, H * W) - - fake_logits, mu, logvar = vae(input, self_prob) - - loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, input, mu, logvar, KL_BETA) - val_loss_total += loss.detach() - val_reco_loss_total += reco_loss.detach() - val_kl_loss_total += kl_loss.detach() - - fake_map = torch.argmax(fake_logits, dim=2)[:,0:169].view(B, H, W).cpu().numpy() - fake_img = matrix_to_image_cv(fake_map[0], tile_dict) - real_map = target_map.cpu().numpy() - real_img = matrix_to_image_cv(real_map[0], tile_dict) - img = np.block([[real_img], [vline], [fake_img]]) - cv2.imwrite(f"result/ginka_vae_img/{idx}.png", img) - - idx += 1 - - # 随机采样 - vae.autoregressive() - for i in range(0, 8): - z = torch.randn(1, LATENT_DIM).to(device) - - fake_logits = vae.decoder(z, torch.zeros(1, 169).to(device)) - fake_map = fake_logits[:,0:169].view(-1, 13, 13).cpu().numpy() - fake_img = matrix_to_image_cv(fake_map[0], tile_dict) - - cv2.imwrite(f"result/ginka_vae_img/{i}_rand.png", fake_img) - - # 插值 - val_length = len(dataset_val.data) - index1 = random.randint(0, val_length - 1) - index2 = random.randint(0, val_length - 1) - map1 = torch.LongTensor(dataset_val.data[index1]["map"]).to(device).view(1, 169) - map2 = torch.LongTensor(dataset_val.data[index2]["map"]).to(device).view(1, 169) - mu1, logvar1 = vae.encoder(map1) - mu2, logvar2 = vae.encoder(map2) - z1 = vae.reparameterize(mu1, logvar1) - z2 = vae.reparameterize(mu2, logvar2) - real_img1 = matrix_to_image_cv(map1[0].view(13, 13).cpu().numpy(), tile_dict) - real_img2 = matrix_to_image_cv(map2[0].view(13, 13).cpu().numpy(), tile_dict) - i = 0 - for t in torch.linspace(0, 1, 8): - z = z1 * (1 - t / 8) + z2 * t / 8 - fake_logits = vae.decoder(z, torch.zeros(1, 169).to(device)) - fake_map = fake_logits[:,0:169].view(-1, 13, 13).cpu().numpy() - fake_img = matrix_to_image_cv(fake_map[0], tile_dict) - img = np.block([[real_img1], [vline], [fake_img], [vline], [real_img2]]) - - cv2.imwrite(f"result/ginka_vae_img/{i}_linspace.png", img) - i += 1 - - avg_loss_val = val_loss_total.item() / len(dataloader_val) - avg_reco_loss_val = val_reco_loss_total.item() / len(dataloader_val) - avg_kl_loss_val = val_kl_loss_total.item() / len(dataloader_val) - tqdm.write( - f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch + 1} | " + - f"Loss: {avg_loss_val:.6f} | Reco: {avg_reco_loss_val:.6f} | KL: {avg_kl_loss_val:.6f}" - ) - - print("Train ended.") - torch.save({ - "model_state": vae.state_dict(), - }, f"result/ginka_transformer.pth") - - -if __name__ == "__main__": - torch.set_num_threads(4) - train() diff --git a/ginka/train_vae.py b/ginka/train_vae.py deleted file mode 100644 index ae539d8..0000000 --- a/ginka/train_vae.py +++ /dev/null @@ -1,270 +0,0 @@ -import argparse -import os -import sys -import random -from datetime import datetime -import torch -import torch.nn.functional as F -import torch.optim as optim -import cv2 -import numpy as np -from torch_geometric.loader import DataLoader -from tqdm import tqdm -from .vae_rnn.vae import GinkaVAE -from .vae_rnn.loss import VAELoss -from .vae_rnn.scheduler import VAEScheduler -from .dataset import GinkaRNNDataset -from shared.image import matrix_to_image_cv - -# 手工标注标签定义(暂时不用): -# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层, -# 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具风 -# 16. 区域入口, 17. 区域连接, 18. 有机关门, 19. 道具层, 20. 斜向对称, 21. 左右通道, 22. 上下通道, 23. 多机关门 -# 24. 中心对称, 25. 部分对称, 26. 鱼骨 - -# 自动标注标签定义(暂时不用): -# 0. 左右对称, 1. 上下对称, 2. 中心对称, 3. 斜向对称, 4. 伪对称, 5. 多房间, 6. 多走廊 -# 32. 平面塔, 33. 转换塔, 34. 道具塔 - -# 标量值定义: -# 0. 整体密度,非空白图块/地图面积,空白图块还包括装饰图块 -# 1. 墙体密度,墙壁/地图面积 -# 2. 装饰密度,装饰数量/地图面积 -# 3. 门密度,门数量/地图面积 -# 4. 怪物密度,怪物数量/地图面积 -# 5. 资源密度,资源数量/地图面积 -# 6. 宝石密度,宝石数量/地图面积 -# 7. 血瓶密度,血瓶数量/地图面积 -# 8. 钥匙密度,钥匙数量/地图面积 -# 9. 道具密度,道具数量/地图面积 -# 10. 入口数量 -# 11. 机关门数量 -# 12. 咸鱼门数量(多层咸鱼门只算一个) - -# 图块定义: -# 0. 空地, 1. 墙壁, 2. 装饰(用于野外装饰,视为空地), -# 3. 黄门, 4. 蓝门, 5. 红门, 6. 机关门, 其余种类的门如绿门都视为红门 -# 7-9. 黄蓝红门钥匙,机关门不使用钥匙开启 -# 10-12. 三种等级的红宝石 -# 13-15. 三种等级的蓝宝石 -# 16-18. 三种等级的绿宝石 -# 19-22. 四种等级的血瓶 -# 23-25. 三种等级的道具 -# 26-28. 三种等级的怪物 -# 29. 入口,不区分楼梯和箭头 - -BATCH_SIZE = 128 -LATENT_DIM = 64 -KL_BETA = 0.01 -SELF_GATE = 0.3 -GATE_EPOCH = 10 -VAL_BATCH_DIVIDER = 128 -PROB_STEP = 0.05 - -device = torch.device( - "cuda:1" if torch.cuda.is_available() - else "mps" if torch.mps.is_available() - else "cpu" -) -os.makedirs("result", exist_ok=True) -os.makedirs("result/vae", exist_ok=True) -os.makedirs("result/ginka_vae_img", exist_ok=True) - -disable_tqdm = not sys.stdout.isatty() - -def parse_arguments(): - parser = argparse.ArgumentParser(description="training codes") - parser.add_argument("--resume", type=bool, default=False) - parser.add_argument("--state_ginka", type=str, default="result/vae/ginka-100.pth") - parser.add_argument("--train", type=str, default="ginka-dataset.json") - parser.add_argument("--validate", type=str, default="ginka-eval.json") - parser.add_argument("--epochs", type=int, default=100) - parser.add_argument("--checkpoint", type=int, default=5) - parser.add_argument("--load_optim", type=bool, default=True) - args = parser.parse_args() - return args - -def train(): - print(f"Using {device.type} to train model.") - - args = parse_arguments() - - vae = GinkaVAE(device, latent_dim=LATENT_DIM).to(device) - - dataset = GinkaRNNDataset(args.train, device) - dataset_val = GinkaRNNDataset(args.validate, device) - dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) - dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True) - - optimizer_ginka = optim.AdamW(vae.parameters(), lr=3e-4, weight_decay=1e-4) - # 自定义调度器允许在 self_prob 提高时重置调度器信息并提高学习率以适应学习 - scheduler_ginka = VAEScheduler( - optimizer_ginka, factor=0.9, increase_factor=1.5, patience=20, max_lr=3e-4, min_lr=1e-6 - ) - - criterion = VAELoss() - - self_prob = 0 - prob_epochs = 0 - - # 用于生成图片 - tile_dict = dict() - for file in os.listdir('tiles'): - name = os.path.splitext(file)[0] - tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) - - if args.resume: - data_ginka = torch.load(args.state_ginka, map_location=device) - - vae.load_state_dict(data_ginka["model_state"], strict=False) - - if args.load_optim: - if data_ginka.get("optim_state") is not None: - optimizer_ginka.load_state_dict(data_ginka["optim_state"]) - - print("Train from loaded state.") - - for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm): - loss_total = torch.Tensor([0]).to(device) - reco_loss_total = torch.Tensor([0]).to(device) - kl_loss_total = torch.Tensor([0]).to(device) - - vae.train() - for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): - target_map = batch["target_map"].to(device) - - optimizer_ginka.zero_grad() - fake_logits, mu, logvar = vae(target_map, self_prob) - - loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA) - - loss.backward() - torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0) - optimizer_ginka.step() - loss_total += loss.detach() - reco_loss_total += reco_loss.detach() - kl_loss_total += kl_loss.detach() - - avg_loss = loss_total.item() / len(dataloader) - avg_reco_loss = reco_loss_total.item() / len(dataloader) - avg_kl_loss = kl_loss_total.item() / len(dataloader) - tqdm.write( - f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + - f"E: {epoch + 1} | Loss: {avg_loss:.6f} | Reco: {avg_reco_loss:.6f} | " + - f"KL: {avg_kl_loss:.6f} | Prob: {self_prob:.2f} | LR: {scheduler_ginka.get_last_lr()[0]:.6f}" - ) - - # 验证集 - # with torch.no_grad(): - # val_loss_total = torch.Tensor([0]).to(device) - # for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): - # target_map = batch["target_map"].to(device) - - # fake_logits, mu, logvar = vae(target_map, 1 - gt_prob) - - # loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA) - # val_loss_total += loss.detach() - - # avg_loss_val = val_loss_total.item() / len(dataloader_val) - - # 先使用训练集的损失值,因为过拟合比较严重,后续再想办法 - if avg_loss < SELF_GATE: - prob_epochs += 1 - else: - prob_epochs = 0 - - if prob_epochs >= GATE_EPOCH and self_prob < 1: - self_prob += PROB_STEP - prob_epochs = 0 - if self_prob > 1: - self_prob = 1 - - scheduler_ginka.step(avg_loss, self_prob) - - # 每若干轮输出一次图片,并保存检查点 - if (epoch + 1) % args.checkpoint == 0: - vae.eval() - # 保存检查点 - torch.save({ - "model_state": vae.state_dict(), - "optim_state": optimizer_ginka.state_dict(), - }, f"result/rnn/ginka-{epoch + 1}.pth") - - val_loss_total = torch.Tensor([0]).to(device) - val_reco_loss_total = torch.Tensor([0]).to(device) - val_kl_loss_total = torch.Tensor([0]).to(device) - with torch.no_grad(): - idx = 0 - gap = 5 - color = (255, 255, 255) # 白色 - vline = np.full((416, gap, 3), color, dtype=np.uint8) # 垂直分割线 - # 地图重建展示 - for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): - target_map = batch["target_map"].to(device) - - fake_logits, mu, logvar = vae(target_map, self_prob) - - loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA) - val_loss_total += loss.detach() - val_reco_loss_total += reco_loss.detach() - val_kl_loss_total += kl_loss.detach() - - fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy() - fake_img = matrix_to_image_cv(fake_map[0], tile_dict) - real_map = target_map.cpu().numpy() - real_img = matrix_to_image_cv(real_map[0], tile_dict) - img = np.block([[real_img], [vline], [fake_img]]) - cv2.imwrite(f"result/ginka_vae_img/{idx}.png", img) - - idx += 1 - - # 随机采样 - for i in range(0, 8): - z = torch.randn(1, LATENT_DIM).to(device) - - fake_logits = vae.decoder(z, torch.zeros(1, 13, 13).to(device), 1) - fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy() - fake_img = matrix_to_image_cv(fake_map[0], tile_dict) - - cv2.imwrite(f"result/ginka_vae_img/{i}_rand.png", fake_img) - - # 插值 - val_length = len(dataset_val.data) - index1 = random.randint(0, val_length - 1) - index2 = random.randint(0, val_length - 1) - map1 = torch.LongTensor(dataset_val.data[index1]["map"]).to(device).reshape(1, 13, 13) - map2 = torch.LongTensor(dataset_val.data[index2]["map"]).to(device).reshape(1, 13, 13) - mu1, logvar1 = vae.encoder(map1) - mu2, logvar2 = vae.encoder(map2) - z1 = vae.reparameterize(mu1, logvar1) - z2 = vae.reparameterize(mu2, logvar2) - real_img1 = matrix_to_image_cv(map1[0].cpu().numpy(), tile_dict) - real_img2 = matrix_to_image_cv(map2[0].cpu().numpy(), tile_dict) - i = 0 - for t in torch.linspace(0, 1, 8): - z = z1 * (1 - t / 8) + z2 * t / 8 - fake_logits = vae.decoder(z, torch.zeros(1, 13, 13).to(device), 1) - fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy() - fake_img = matrix_to_image_cv(fake_map[0], tile_dict) - img = np.block([[real_img1], [vline], [fake_img], [vline], [real_img2]]) - - cv2.imwrite(f"result/ginka_vae_img/{i}_linspace.png", img) - i += 1 - - avg_loss_val = val_loss_total.item() / len(dataloader_val) - avg_reco_loss_val = val_reco_loss_total.item() / len(dataloader_val) - avg_kl_loss_val = val_kl_loss_total.item() / len(dataloader_val) - tqdm.write( - f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch + 1} | " + - f"Loss: {avg_loss_val:.6f} | Reco: {avg_reco_loss_val:.6f} | KL: {avg_kl_loss_val:.6f}" - ) - - print("Train ended.") - torch.save({ - "model_state": vae.state_dict(), - }, f"result/ginka_rnn.pth") - - -if __name__ == "__main__": - torch.set_num_threads(4) - train() diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py deleted file mode 100644 index 9f982a1..0000000 --- a/ginka/train_wgan.py +++ /dev/null @@ -1,428 +0,0 @@ -import argparse -import os -import sys -from datetime import datetime -import torch -import torch.optim as optim -import torch.nn.functional as F -import cv2 -import numpy as np -from torch_geometric.loader import DataLoader -from tqdm import tqdm -from .generator.model import GinkaModel -from .dataset import GinkaWGANDataset -from .generator.loss import WGANGinkaLoss -from .critic.model import MinamoModel2 -from shared.image import matrix_to_image_cv - -# 手工标注标签定义: -# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层, -# 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具风 -# 16. 区域入口, 17. 区域连接, 18. 有机关门, 19. 道具层, 20. 斜向对称, 21. 左右通道, 22. 上下通道, 23. 多机关门 -# 24. 中心对称, 25. 部分对称, 26. 鱼骨 - -# 自动标注标签定义: -# 0. 左右对称, 1. 上下对称, 2. 中心对称, 3. 斜向对称, 4. 伪对称, 5. 多房间, 6. 多走廊 -# 32. 平面塔, 33. 转换塔, 34. 道具塔 - -# 标量值定义: -# 0. 整体密度,非空白图块/地图面积,空白图块还包括装饰图块 -# 1. 墙体密度,墙壁/地图面积 -# 2. 装饰密度,装饰数量/地图面积 -# 3. 门密度,门数量/地图面积 -# 4. 怪物密度,怪物数量/地图面积 -# 5. 资源密度,资源数量/地图面积 -# 6. 宝石密度,宝石数量/地图面积 -# 7. 血瓶密度,血瓶数量/地图面积 -# 8. 钥匙密度,钥匙数量/地图面积 -# 9. 道具密度,道具数量/地图面积 -# 10. 入口数量 -# 11. 机关门数量 -# 12. 咸鱼门数量(多层咸鱼门只算一个) - -# 图块定义: -# 0. 空地, 1. 墙壁, 2. 装饰(用于野外装饰,视为空地), -# 3. 黄门, 4. 蓝门, 5. 红门, 6. 机关门, 其余种类的门如绿门都视为红门 -# 7-9. 黄蓝红门钥匙,机关门不使用钥匙开启 -# 10-12. 三种等级的红宝石 -# 13-15. 三种等级的蓝宝石 -# 16-18. 三种等级的绿宝石 -# 19-22. 四种等级的血瓶 -# 23-25. 三种等级的道具 -# 26-28. 三种等级的怪物 -# 29. 入口,不区分楼梯和箭头 - -BATCH_SIZE = 6 - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -os.makedirs("result", exist_ok=True) -os.makedirs("result/wgan", exist_ok=True) - -disable_tqdm = not sys.stdout.isatty() - -def parse_arguments(): - parser = argparse.ArgumentParser(description="training codes") - parser.add_argument("--resume", type=bool, default=False) - parser.add_argument("--state_ginka", type=str, default="result/wgan/ginka-100.pth") - parser.add_argument("--state_minamo", type=str, default="result/wgan/minamo-100.pth") - parser.add_argument("--train", type=str, default="ginka-dataset.json") - parser.add_argument("--validate", type=str, default="ginka-eval.json") - parser.add_argument("--epochs", type=int, default=100) - parser.add_argument("--checkpoint", type=int, default=5) - parser.add_argument("--load_optim", type=bool, default=True) - parser.add_argument("--curr_epoch", type=int, default=20) # 课程学习至少多少 epoch - parser.add_argument("--tuning", type=bool, default=False) - args = parser.parse_args() - return args - -def gen_curriculum(gen, masked1, masked2, masked3, tag, val, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - fake1 = gen(masked1, 1, tag, val) - fake2 = gen(masked2, 2, tag, val) - fake3 = gen(masked3, 3, tag, val) - if detach: - return fake1.detach(), fake2.detach(), fake3.detach() - else: - return fake1, fake2, fake3 - -def gen_total(gen, input, tag, val, progress_detach=True, result_detach=False, random=False) -> torch.Tensor: - if random: - fake0 = gen(input, 0, tag, val) - x_in = F.softmax(fake0, dim=1) - else: - fake0 = input - x_in = input - if progress_detach: - fake1 = gen(x_in.detach(), 1, tag, val) - fake2 = gen(F.softmax(fake1.detach(), dim=1), 2, tag, val) - fake3 = gen(F.softmax(fake2.detach(), dim=1), 3, tag, val) - else: - fake1 = gen(x_in, 1, tag, val) - fake2 = gen(F.softmax(fake1, dim=1), 2, tag, val) - fake3 = gen(F.softmax(fake2, dim=1), 3, tag, val) - if result_detach: - return fake1.detach(), fake2.detach(), fake3.detach(), fake0.detach() - else: - return fake1, fake2, fake3, fake0 - -def train(): - print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") - - args = parse_arguments() - - c_steps = 2 - g_steps = 1 - # 训练阶段 - train_stage = 1 - mask_ratio = 0.2 # 蒙版区域大小 - stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程 - total_epoch = 0 - - ginka = GinkaModel().to(device) - minamo = MinamoModel2().to(device) - - dataset = GinkaWGANDataset(args.train, device) - dataset_val = GinkaWGANDataset(args.validate, device) - dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) - dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE) - - optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9)) - optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-4, betas=(0.0, 0.9)) - - scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2) - scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2) - - criterion = WGANGinkaLoss() - - # 用于生成图片 - tile_dict = dict() - for file in os.listdir('tiles'): - name = os.path.splitext(file)[0] - tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) - - if args.resume: - data_ginka = torch.load(args.state_ginka, map_location=device) - data_minamo = torch.load(args.state_minamo, map_location=device) - - ginka.load_state_dict(data_ginka["model_state"], strict=False) - minamo.load_state_dict(data_minamo["model_state"], strict=False) - - # if data_ginka.get("c_steps") is not None and data_ginka.get("g_steps") is not None: - # c_steps = data_ginka["c_steps"] - # g_steps = data_ginka["g_steps"] - - if data_ginka.get("mask_ratio") is not None: - mask_ratio = data_ginka["mask_ratio"] - - if data_ginka.get("stage_epoch") is not None: - stage_epoch = data_ginka["stage_epoch"] - - if data_ginka.get("stage") is not None: - train_stage = data_ginka["stage"] - - if data_ginka.get("total_epoch") is not None: - total_epoch = data_ginka["data_ginka"] - - if args.load_optim: - if data_ginka.get("optim_state") is not None: - optimizer_ginka.load_state_dict(data_ginka["optim_state"]) - if data_minamo.get("optim_state") is not None: - optimizer_minamo.load_state_dict(data_minamo["optim_state"]) - - print("Train from loaded state.") - - curr_epoch = args.curr_epoch - first_curr = curr_epoch * 3 - - if args.tuning: - train_stage = 1 - curr_epoch = curr_epoch // 4 - first_curr = first_curr // 4 - stage_epoch = 0 - mask_ratio = 0.2 - - dataset.train_stage = train_stage - dataset.mask_ratio1 = mask_ratio - dataset.mask_ratio2 = mask_ratio - dataset.mask_ratio3 = mask_ratio - - dataset_val.train_stage = train_stage - dataset_val.mask_ratio1 = mask_ratio - dataset_val.mask_ratio2 = mask_ratio - dataset_val.mask_ratio3 = mask_ratio - - for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm): - loss_total_minamo = torch.Tensor([0]).to(device) - loss_total_ginka = torch.Tensor([0]).to(device) - dis_total = torch.Tensor([0]).to(device) - loss_ce_total = torch.Tensor([0]).to(device) - - iters = 0 - - for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): - rand = batch["rand"].to(device) - real0 = batch["real0"].to(device) - real1 = batch["real1"].to(device) - masked1 = batch["masked1"].to(device) - real2 = batch["real2"].to(device) - masked2 = batch["masked2"].to(device) - real3 = batch["real3"].to(device) - masked3 = batch["masked3"].to(device) - tag_cond = batch["tag_cond"].to(device) - val_cond = batch["val_cond"].to(device) - - # ---------- 训练判别器 - for _ in range(c_steps): - # 生成假样本 - optimizer_minamo.zero_grad() - optimizer_ginka.zero_grad() - - with torch.no_grad(): - if train_stage == 1 or train_stage == 2: - fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True) - elif train_stage == 3 or train_stage == 4: - fake1, fake2, fake3, fake0 = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4) - - if train_stage < 4: - fake0 = ginka(rand, 0, tag_cond, val_cond) - - loss_d0, dis0 = criterion.discriminator_loss(minamo, 0, real0, fake0, tag_cond, val_cond) - loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1, tag_cond, val_cond) - loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2, tag_cond, val_cond) - loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3, tag_cond, val_cond) - - dis = [dis0, dis1, dis2, dis3] - loss_d = [loss_d0, loss_d1, loss_d2, loss_d3] - - dis_avg = sum(dis) / len(dis) - loss_d_avg = sum(loss_d) / len(loss_d) - - # 反向传播 - loss_d_avg.backward() - - optimizer_minamo.step() - - loss_total_minamo += loss_d_avg.detach() - dis_total += dis_avg.detach() - - # ---------- 训练生成器 - - for _ in range(g_steps): - optimizer_minamo.zero_grad() - optimizer_ginka.zero_grad() - if train_stage == 1 or train_stage == 2: - fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, False) - - loss_g1, loss_ce_g1 = criterion.generator_loss(minamo, 1, mask_ratio, real1, fake1, masked1, tag_cond, val_cond) - loss_g2, loss_ce_g2 = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2, tag_cond, val_cond) - loss_g3, loss_ce_g3 = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3, tag_cond, val_cond) - - loss_g = (loss_g1 * 3.0 + loss_g2 + loss_g3) / 5.0 - loss_ce = max(loss_ce_g1, loss_ce_g2, loss_ce_g3) - - loss_ce_total += loss_ce.detach() - - elif train_stage == 3 or train_stage == 4: - fake1, fake2, fake3, fake0 = gen_total(ginka, masked1, tag_cond, val_cond, True, False, train_stage == 4) - if train_stage == 4: - fake0 = F.softmax(fake0, dim=1) - - loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, fake0, tag_cond, val_cond) - loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1, tag_cond, val_cond) - loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2, tag_cond, val_cond) - - loss_g = (loss_g1 * 3.0 + loss_g2 + loss_g3) / 5.0 - - if train_stage < 4: - fake0 = F.softmax(ginka(rand, 0, tag_cond, val_cond), dim=1) - - loss_g0 = criterion.generator_input_head_loss(minamo, fake0, tag_cond, val_cond) - loss_g += loss_g0 - - loss_g.backward() - optimizer_ginka.step() - loss_total_ginka += loss_g.detach() - - iters += 1 - - if iters % 50 == 0: - avg_loss_ginka = loss_total_ginka.item() / iters / g_steps - avg_loss_minamo = loss_total_minamo.item() / iters / c_steps - avg_loss_ce = loss_ce_total.item() / iters / g_steps - avg_dis = dis_total.item() / iters / c_steps - tqdm.write( - f"[Iters {iters} {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + - f"E: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " + - f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " + - f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | " + - f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}" - ) - - avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps - avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps - avg_loss_ce = loss_ce_total.item() / len(dataloader) / g_steps - avg_dis = dis_total.item() / len(dataloader) / c_steps - tqdm.write( - f"[Epoch {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + - f"E: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " + - f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " + - f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | " + - f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}" - ) - - # 每若干轮输出一次图片,并保存检查点 - if (epoch + 1) % args.checkpoint == 0: - # 保存检查点 - torch.save({ - "model_state": ginka.state_dict(), - "optim_state": optimizer_ginka.state_dict(), - "c_steps": c_steps, - "g_steps": g_steps, - "stage": train_stage, - "mask_ratio": mask_ratio, - "stage_epoch": stage_epoch, - }, f"result/wgan/ginka-{epoch + 1}.pth") - torch.save({ - "model_state": minamo.state_dict(), - "optim_state": optimizer_minamo.state_dict() - }, f"result/wgan/minamo-{epoch + 1}.pth") - - idx = 0 - gap = 5 - color = (255, 255, 255) # 白色 - with torch.no_grad(): - for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): - real1 = batch["real1"].to(device) - masked1 = batch["masked1"].to(device) - real2 = batch["real2"].to(device) - masked2 = batch["masked2"].to(device) - real3 = batch["real3"].to(device) - masked3 = batch["masked3"].to(device) - tag_cond = batch["tag_cond"].to(device) - val_cond = batch["val_cond"].to(device) - - if train_stage == 1 or train_stage == 2: - fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True) - - elif train_stage == 3 or train_stage == 4: - fake1, fake2, fake3, fake0 = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4) - fake0 = torch.argmax(fake0, dim=1).cpu().numpy() - - fake1 = torch.argmax(fake1, dim=1).cpu().numpy() - fake2 = torch.argmax(fake2, dim=1).cpu().numpy() - fake3 = torch.argmax(fake3, dim=1).cpu().numpy() - masked1 = torch.argmax(masked1, dim=1).cpu().numpy() - masked2 = torch.argmax(masked2, dim=1).cpu().numpy() - masked3 = torch.argmax(masked3, dim=1).cpu().numpy() - - for i in range(fake1.shape[0]): - fake1_img = matrix_to_image_cv(fake1[i], tile_dict) - fake2_img = matrix_to_image_cv(fake2[i], tile_dict) - fake3_img = matrix_to_image_cv(fake3[i], tile_dict) - if train_stage == 1 or train_stage == 2: - vline = np.full((416, gap, 3), color, dtype=np.uint8) # 垂直分割线 - hline = np.full((gap, 3 * 416 + gap * 2, 3), color, dtype=np.uint8) # 水平分割线 - in1_img = matrix_to_image_cv(masked1[i], tile_dict) - in2_img = matrix_to_image_cv(masked2[i], tile_dict) - in3_img = matrix_to_image_cv(masked3[i], tile_dict) - img = np.block([ - [[in1_img], [vline], [in2_img], [vline], [in3_img]], - [[hline]], - [[fake1_img], [vline], [fake2_img], [vline], [fake3_img]] - ]) - elif train_stage == 3 or train_stage == 4: - vline = np.full((416, gap, 3), color, dtype=np.uint8) # 垂直分割线 - hline = np.full((gap, 2 * 416 + gap, 3), color, dtype=np.uint8) # 水平分割线 - in_img = matrix_to_image_cv(fake0[i], tile_dict) - img = np.block([ - [[in_img], [vline], [fake1_img]], - [[hline]], - [[fake2_img], [vline], [fake3_img]] - ]) - - cv2.imwrite(f"result/ginka_img/{idx}.png", img) - - idx += 1 - - # 训练流程控制 - - # if train_stage >= 2: - # # train_stage = 4 - # if (epoch + 1) % 100 == 5: - # train_stage = 3 - # elif (epoch + 1) % 100 == 20: - # train_stage = 4 - # elif (epoch + 1) % 100 == 0: - # train_stage = 2 - - if train_stage == 1: - if (mask_ratio < 0.3 and stage_epoch >= first_curr) or \ - (mask_ratio > 0.3 and stage_epoch >= curr_epoch): - mask_ratio += 0.2 - mask_ratio = min(mask_ratio, 0.8) - - stage_epoch = 0 - if mask_ratio >= 0.8: - train_stage = 4 - - stage_epoch += 1 - total_epoch += 1 - - dataset.train_stage = train_stage - dataset_val.train_stage = train_stage - dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio - dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio - - scheduler_ginka.step() - scheduler_minamo.step() - - print("Train ended.") - torch.save({ - "model_state": ginka.state_dict(), - }, f"result/ginka.pth") - torch.save({ - "model_state": minamo.state_dict(), - }, f"result/minamo.pth") - -if __name__ == "__main__": - torch.set_num_threads(4) - train() diff --git a/ginka/transformer/decoder.py b/ginka/transformer/decoder.py deleted file mode 100644 index cb720d0..0000000 --- a/ginka/transformer/decoder.py +++ /dev/null @@ -1,106 +0,0 @@ -import time -import torch -import torch.nn as nn -from ..utils import print_memory - -class GinkaTransformerDecoder(nn.Module): - def __init__(self, num_classes=32, dim_ff=256, nhead=4, num_layers=4, map_size=13*13): - super().__init__() - self.autoregressive = False - self.dim_ff = dim_ff - self.map_size = map_size - self.embedding = nn.Embedding(num_classes, dim_ff) - self.pos_embedding = nn.Embedding(map_size + 1, dim_ff) - self.encoder = nn.TransformerEncoder( - nn.TransformerEncoderLayer(d_model=dim_ff, dim_feedforward=dim_ff, nhead=nhead, batch_first=True), - num_layers=max(num_layers // 2, 1) - ) - self.decoder = nn.TransformerDecoder( - nn.TransformerDecoderLayer(d_model=dim_ff, dim_feedforward=dim_ff, nhead=nhead, batch_first=True), - num_layers=num_layers - ) - self.fc = nn.Sequential( - nn.Linear(dim_ff, num_classes) - ) - - def forward(self, z: torch.Tensor, target_map: torch.Tensor): - # z: [B, dim_ff] - # target_map: [B, H * W] - # training output: [B, H * W, dim_ff] - # evaling output: [B, H * W] - B, L = target_map.shape - - memory = self.encoder(z.unsqueeze(1)) # [B, 1, dim_ff] - mask = torch.triu(torch.ones(L + 1, L + 1, dtype=torch.bool)).to(z.device) # [B, H * W, H * W] - - # when training, use teacher forcing - if not self.autoregressive: - first_token = torch.tensor([31], dtype=torch.long).to(z.device).repeat(B, 1) - with_first = torch.cat([first_token, target_map], dim=1) - map = self.embedding(with_first) - pos_embed = self.pos_embedding(torch.arange(L + 1, dtype=torch.long).to(z.device)) - map = map + pos_embed # [B, H * W, dim_ff] - decoded = self.decoder(map, memory, tgt_mask=mask) # [B, H * W, dim_ff] - output = self.fc(decoded) - return output - - # when evaling, use autoregressive generation - else: - output = torch.zeros([B, L + 1], dtype=torch.int).to(z.device) - for idx in range(0, self.map_size): - embed = self.embedding(output) - pos_embed = self.pos_embedding(torch.IntTensor([idx]).repeat(B, 1).to(z.device)) - map = embed + pos_embed # [B, H * W, dim_ff] - decoded = self.decoder(map, memory, tgt_mask=mask) - decoded = self.fc(decoded) # [B, H * W, dim_ff] - output[:, idx] = torch.argmax(decoded[:, idx, :], dim=1) - - return output - -class GinkaTransformerVAEDecoder(nn.Module): - def __init__( - self, latent_dim=32, num_classes=32, dim_ff=256, nhead=4, num_layers=4, - map_size=13*13 - ): - super().__init__() - self.map_size = map_size - self.input = nn.Sequential( - nn.Linear(latent_dim, dim_ff), - nn.Dropout(0.3), - nn.LayerNorm(dim_ff), - nn.ReLU(), - - nn.Linear(dim_ff, dim_ff) - ) - self.decoder = GinkaTransformerDecoder( - num_classes=num_classes, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers, map_size=map_size - ) - - def forward(self, z: torch.Tensor, map: torch.Tensor): - hidden = self.input(z) - output = self.decoder(hidden, map) - return output - -if __name__ == "__main__": - device = torch.device("cpu") - - input = torch.randn(1, 32).to(device) - map = torch.randint(0, 32, [1, 169]).to(device) - - # 初始化模型 - model = GinkaTransformerVAEDecoder().to(device) - - print_memory("初始化后") - - # 前向传播 - start = time.perf_counter() - output = model(input, map) - end = time.perf_counter() - - print_memory("前向传播后") - - print(f"推理耗时: {end - start}") - print(f"输出形状: output={output.shape}") - print(f"Input Embedding parameters: {sum(p.numel() for p in model.input.parameters())}") - print(f"Decoder parameters: {sum(p.numel() for p in model.decoder.parameters())}") - print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") diff --git a/ginka/transformer/encoder.py b/ginka/transformer/encoder.py deleted file mode 100644 index 9912171..0000000 --- a/ginka/transformer/encoder.py +++ /dev/null @@ -1,97 +0,0 @@ -import time -import torch -import torch.nn as nn -import torch.nn.functional as F -from ..utils import print_memory - -class GinkaTransformerEncoder(nn.Module): - def __init__(self, dim_ff=256, nhead=4, num_layers=4): - super().__init__() - self.dim_ff = dim_ff - self.encoder = nn.TransformerEncoder( - nn.TransformerEncoderLayer(d_model=dim_ff, dim_feedforward=dim_ff, nhead=nhead, batch_first=True, activation=F.gelu), - num_layers=num_layers - ) - self.decoder = nn.TransformerDecoder( - nn.TransformerDecoderLayer(d_model=dim_ff, dim_feedforward=dim_ff, nhead=nhead, batch_first=True, activation=F.gelu), - num_layers=max(num_layers // 2, 1) - ) - - def forward(self, x: torch.Tensor): - # x: [B, H * W, S] - B, L, S = x.shape - first_token = torch.randn(B, 1, self.dim_ff).to(x.device) - x = self.encoder(x) - x = self.decoder(first_token, x) - return x.squeeze(1) - -class GinkaTransformerBottleneck(nn.Module): - def __init__(self, dim_ff=256, hidden_dim=512, latent_dim=32): - super().__init__() - self.fc = nn.Sequential( - nn.Linear(dim_ff, hidden_dim), - nn.Dropout(0.3), - nn.LayerNorm(hidden_dim), - nn.GELU(), - ) - self.fc_mu = nn.Sequential( - nn.Linear(hidden_dim, latent_dim) - ) - self.fc_logvar = nn.Sequential( - nn.Linear(hidden_dim, latent_dim) - ) - - def forward(self, x): - # x: [B, dim_ff] - hidden = self.fc(x) - mu = self.fc_mu(hidden) - logvar = self.fc_logvar(hidden) - return mu, logvar - -class GinkaTransformerVAEEncoder(nn.Module): - def __init__( - self, num_classes=32, latent_dim=32, bottleneck_dim=512, dim_ff=256, - nhead=4, num_layers=4, map_size=13*13 - ): - super().__init__() - self.map_size = map_size - self.embedding = nn.Embedding(num_classes, dim_ff) - self.pos_embedding = nn.Embedding(map_size, dim_ff) - self.encoder = GinkaTransformerEncoder(dim_ff=dim_ff, nhead=nhead, num_layers=num_layers) - self.bottleneck = GinkaTransformerBottleneck( - dim_ff=dim_ff, hidden_dim=bottleneck_dim, latent_dim=latent_dim - ) - - def forward(self, x: torch.Tensor): - # x: [B, map_size] - pos = self.pos_embedding(torch.arange(self.map_size, dtype=torch.long).to(x.device)) - x = self.embedding(x) + pos - x = self.encoder(x) - mu, logvar = self.bottleneck(x) - return mu, logvar - -if __name__ == "__main__": - device = torch.device("cpu") - - input = torch.randint(0, 32, [1, 169]).to(device) - - # 初始化模型 - model = GinkaTransformerVAEEncoder().to(device) - - print_memory("初始化后") - - # 前向传播 - start = time.perf_counter() - mu, logvar = model(input) - end = time.perf_counter() - - print_memory("前向传播后") - - print(f"推理耗时: {end - start}") - print(f"输出形状: mu={mu.shape}, logvar={logvar.shape}") - print(f"Embedding parameters: {sum(p.numel() for p in model.embedding.parameters())}") - print(f"Position Embedding parameters: {sum(p.numel() for p in model.pos_embedding.parameters())}") - print(f"Encoder parameters: {sum(p.numel() for p in model.encoder.parameters())}") - print(f"bottleneck parameters: {sum(p.numel() for p in model.bottleneck.parameters())}") - print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") - diff --git a/ginka/transformer/fsq.py b/ginka/transformer/fsq.py deleted file mode 100644 index 5d39691..0000000 --- a/ginka/transformer/fsq.py +++ /dev/null @@ -1,22 +0,0 @@ -import torch -import torch.nn as nn - -class FSQ(nn.Module): - def __init__(self, levels=7): - super().__init__() - - self.levels = levels - self.scale = (levels - 1) / 2 - - def forward(self, z): - - # 限制范围 - z = torch.tanh(z) - - # 量化 - z_q = torch.round(z * self.scale) / self.scale - - # Straight-through estimator - z_q = z + (z_q - z).detach() - - return z_q diff --git a/ginka/transformer/vae.py b/ginka/transformer/vae.py deleted file mode 100644 index 89abb7a..0000000 --- a/ginka/transformer/vae.py +++ /dev/null @@ -1,54 +0,0 @@ -import time -import torch -import torch.nn as nn -import torch.nn.functional as F -from .encoder import GinkaTransformerVAEEncoder -from .decoder import GinkaTransformerVAEDecoder -from ..utils import print_memory - -class GinkaTransformerVAE(nn.Module): - def __init__(self, num_classes=32, latent_dim=32): - super().__init__() - self.encoder = GinkaTransformerVAEEncoder(num_classes=num_classes, latent_dim=latent_dim) - self.decoder = GinkaTransformerVAEDecoder(latent_dim=latent_dim) - - def reparameterize(self, mu, logvar): - std = torch.exp(0.5 * logvar) - eps = torch.randn_like(std) - return mu + eps * std - - def autoregressive(self): - self.decoder.decoder.autoregressive = True - - def teacher_forcing(self): - self.decoder.decoder.autoregressive = False - - def forward(self, target_map: torch.Tensor, use_self_probility=0): - # target_map: [B, H * W] - mu, logvar = self.encoder(target_map) # [B, latent_dim] - z = self.reparameterize(mu, logvar) - logits = self.decoder(z, target_map) # [B, H * W, num_classes] | [B, H * W] - return logits, mu, logvar - -if __name__ == "__main__": - device = torch.device("cpu") - - input = torch.randint(0, 32, [1, 169]).to(device) - - # 初始化模型 - model = GinkaTransformerVAE().to(device) - - print_memory("初始化后") - - # 前向传播 - start = time.perf_counter() - logits, mu, logvar = model(input) - end = time.perf_counter() - - print_memory("前向传播后") - - print(f"推理耗时: {end - start}") - print(f"输出形状: logits= {logits.shape}, mu={mu.shape}, logvar={logvar.shape}") - print(f"Encoder parameters: {sum(p.numel() for p in model.encoder.parameters())}") - print(f"Decoder parameters: {sum(p.numel() for p in model.decoder.parameters())}") - print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") diff --git a/ginka/vae_rnn/decoder.py b/ginka/vae_rnn/decoder.py deleted file mode 100644 index 09d355e..0000000 --- a/ginka/vae_rnn/decoder.py +++ /dev/null @@ -1,264 +0,0 @@ -import time -import random -import torch -import torch.nn as nn -import torch.nn.functional as F -from ..utils import print_memory - -class DecoderMapPatch(nn.Module): - def __init__(self, tile_classes=32, width=13, height=13): - super().__init__() - - # 地图局部卷积,用于捕获局部结构信息 - - self.width = width - self.height = height - self.tile_classes = 32 - - self.patch_cnn = nn.Sequential( - nn.Conv2d(tile_classes + 1, 64, 3, padding=1), - nn.Dropout(0.2), - nn.BatchNorm2d(64), - nn.GELU(), - - nn.Conv2d(64, 128, 3), - nn.Dropout(0.2), - nn.BatchNorm2d(128), - nn.GELU(), - - nn.Flatten() - ) - self.fc = nn.Linear(128 * 3 * 3, 256) - - def forward(self, map: torch.Tensor, x: int, y: int): - """ - map: [B, H, W] - """ - B, H, W = map.shape - mask = torch.zeros([B, 5, 5]).to(map.device) - result = torch.zeros([B, 5, 5], dtype=torch.long).to(map.device) - left = x - 2 if x >= 2 else 0 - right = x + 3 if x < self.width - 2 else self.width - top = y - 4 if y >= 4 else 0 - bottom = y + 1 - - res_left = left - (x - 2) - res_right = right - (x + 3) + 5 - res_top = top - (y - 4) - res_bottom = 5 - - result[:, res_top:res_bottom, res_left:res_right] = map[:, top:bottom, left:right] - # 没画到的地方要置为 0 - result[:, 4, 2] = 0 - result[:, 4, 3] = 0 - result[:, 4, 4] = 0 - mask[:, res_top:res_bottom, res_left:res_right] = 1 - mask[:, 4, 2] = 0 - mask[:, 4, 3] = 0 - mask[:, 4, 4] = 0 - masked_result = torch.zeros([B, self.tile_classes + 1, 5, 5]).to(map.device) - masked_result[:, 0:32] = F.one_hot(result, num_classes=32).permute(0, 3, 2, 1).float() - masked_result[:, 32] = mask - - feat = self.patch_cnn(masked_result) - feat = self.fc(feat) - return feat - -class DecoderTileEmbedding(nn.Module): - def __init__(self, tile_classes=32, embed_dim=256): - super().__init__() - # 图块编码,上一次画的图块 - self.embedding = nn.Embedding(tile_classes, embed_dim) - - def forward(self, tile: torch.Tensor): - return self.embedding(tile) - -class DecoderPosEmbedding(nn.Module): - def __init__(self, width=13, height=13, embed_dim=256): - super().__init__() - - # 位置编码 - - self.width = width - self.height = height - - self.row_embedding = nn.Embedding(height, embed_dim) - self.col_embedding = nn.Embedding(width, embed_dim) - - def forward(self, x: torch.Tensor, y: torch.Tensor): - row = self.row_embedding(y) - col = self.col_embedding(x) - return row, col - -class DecoderInputFusion(nn.Module): - def __init__(self, d_model=256): - super().__init__() - - # 使用 Transformer 进行信息整合 - - self.transformer = nn.TransformerEncoder( - nn.TransformerEncoderLayer( - d_model=d_model, nhead=2, dim_feedforward=d_model, batch_first=True, - dropout=0.2 - ), - num_layers=2 - ) - self.norm = nn.LayerNorm(d_model) - self.fusion = nn.Sequential( - nn.Linear(d_model * 2, d_model * 2), - nn.Dropout(0.2), - nn.LayerNorm(d_model * 2), - nn.GELU(), - - nn.Linear(d_model * 2, d_model), - nn.Dropout(0.1), - nn.LayerNorm(d_model), - nn.GELU() - ) - - def forward( - self, tile_embed: torch.Tensor, cond_vec: torch.Tensor, - col_embed: torch.Tensor, row_embed: torch.Tensor, patch_vec: torch.Tensor - ): - """ - tile_embed: [B, 256] - cond_vec: [B, 256] - col_embed: [B, 256] - row_embed: [B, 256] - patch_vec: [B, 256] - """ - vec = torch.stack([tile_embed, cond_vec, col_embed, row_embed, patch_vec], dim=1) - feat = self.norm(self.transformer(vec)) - mean = torch.mean(feat, dim=1) - max = torch.max(feat, dim=1).values - hidden = torch.cat([mean, max], dim=1) - fused = self.fusion(hidden) - return fused - -class DecoderRNN(nn.Module): - def __init__(self, tile_classes=32, input_dim=256, hidden_dim=512): - super().__init__() - - # GRU - self.gru = nn.GRUCell(input_dim, hidden_dim) - self.drop = nn.Dropout(0.2) - self.fc = nn.Sequential( - nn.Linear(hidden_dim, hidden_dim), - nn.Dropout(0.1), - nn.LayerNorm(hidden_dim), - nn.GELU(), - - nn.Linear(hidden_dim, tile_classes) - ) - - def forward(self, feat_fusion: torch.Tensor, hidden: torch.Tensor): - """ - feat_fusion: [B, input_dim] - hidden: [B, hidden_dim] - """ - hidden = self.drop(self.gru(feat_fusion, hidden)) - logits = self.fc(hidden) - return logits, hidden - -class VAEDecoder(nn.Module): - def __init__(self, device: torch.device, start_tile=31, map_vec_dim=32, width=13, height=13): - super().__init__() - - self.device = device - self.width = width - self.height = height - self.start_tile = start_tile - - self.rnn_hidden = 512 - self.tile_classes = 32 - - # 模型结构 - self.map_vec_fc = nn.Sequential( - nn.Linear(map_vec_dim, 128), - nn.Dropout(0.1), - nn.LayerNorm(128), - nn.GELU(), - - nn.Linear(128, 256) - ) - self.tile_embedding = DecoderTileEmbedding(tile_classes=self.tile_classes) - self.pos_embedding = DecoderPosEmbedding() - self.map_patch = DecoderMapPatch(tile_classes=self.tile_classes) - self.feat_fusion = DecoderInputFusion() - self.rnn = DecoderRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden) - - self.col_list = [] - self.row_list = [] - for y in range(0, height): - for x in range(0, width): - self.col_list.append(x) - self.row_list.append(y) - - def forward(self, map_vec: torch.Tensor, target_map: torch.Tensor, use_self_probility=0): - """ - map_vec: [B, vec_dim] - target_map: [B, H, W] - use_self_probility: 使用自己生成的上一步结果执行下一步的概率 - """ - B, C = map_vec.shape - - # 张量声明 - now_tile = torch.LongTensor([self.start_tile]).to(self.device).expand(B) - - map = torch.zeros([B, self.height, self.width], dtype=torch.int32).to(self.device) - output_logits = torch.zeros([B, self.height, self.width, self.tile_classes]).to(self.device) - hidden: torch.Tensor = torch.zeros(B, self.rnn_hidden).to(self.device) - - col_list = torch.IntTensor(self.col_list).to(self.device).expand(B, -1) - row_list = torch.IntTensor(self.row_list).to(self.device).expand(B, -1) - col_embed, row_embed = self.pos_embedding(col_list, row_list) - - map_vec = self.map_vec_fc(map_vec) - - for y in range(0, self.height): - for x in range(0, self.width): - idx = y * self.width + x - # 图块编码、地图局部编码 - tile_embed = self.tile_embedding(now_tile) - use_self = random.random() < use_self_probility - map_patch = self.map_patch(map if use_self else target_map, x, y) - # 编码特征融合 - feat = self.feat_fusion(tile_embed, map_vec, col_embed[:, idx], row_embed[:, idx], map_patch) - # RNN 输出 - logits, h = self.rnn(feat, hidden) - # 处理输出 - output_logits[:, y, x] = logits[:] - hidden = h - tile_id = torch.argmax(logits, dim=1).detach() - map[:, y, x] = tile_id[:] - now_tile = tile_id if use_self else target_map[:, y, x].detach() - - return output_logits.permute(0, 3, 1, 2) - -if __name__ == "__main__": - device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") - - input = torch.randint(0, 32, [1, 13, 13]).to(device) - map_vec = torch.rand(1, 32).to(device) - - # 初始化模型 - model = VAEDecoder(device).to(device) - - print_memory(device, "初始化后") - - # 前向传播 - start = time.perf_counter() - fake_logits = model(map_vec, input, 0) - end = time.perf_counter() - - print_memory(device, "前向传播后") - - print(f"推理耗时: {end - start}") - print(f"输出形状: fake_logits={fake_logits.shape}") - print(f"Map Vector FC parameters: {sum(p.numel() for p in model.map_vec_fc.parameters())}") - print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}") - print(f"Position Embedding parameters: {sum(p.numel() for p in model.pos_embedding.parameters())}") - print(f"Map Patch parameters: {sum(p.numel() for p in model.map_patch.parameters())}") - print(f"Feature Fusion parameters: {sum(p.numel() for p in model.feat_fusion.parameters())}") - print(f"RNN parameters: {sum(p.numel() for p in model.rnn.parameters())}") - print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") diff --git a/ginka/vae_rnn/encoder.py b/ginka/vae_rnn/encoder.py deleted file mode 100644 index de8717e..0000000 --- a/ginka/vae_rnn/encoder.py +++ /dev/null @@ -1,161 +0,0 @@ -import time -import torch -import torch.nn as nn -import torch.nn.functional as F -from ..utils import print_memory - -class EncoderEmbedding(nn.Module): - def __init__(self, tile_classes=32, width=13, height=13, hidden_dim=128, output_dim=256): - super().__init__() - self.tile_embedding = nn.Sequential( - nn.Embedding(tile_classes, hidden_dim), - nn.LayerNorm(hidden_dim), - nn.GELU() - ) - self.col_embedding = nn.Sequential( - nn.Embedding(width, hidden_dim), - nn.LayerNorm(hidden_dim), - nn.GELU() - ) - self.row_embedding = nn.Sequential( - nn.Embedding(height, hidden_dim), - nn.LayerNorm(hidden_dim), - nn.GELU() - ) - self.fusion = nn.Linear(hidden_dim * 3, output_dim) - - def forward(self, tile, x, y): - tile_embed = self.tile_embedding(tile) - col_embed = self.col_embedding(x) - row_embed = self.row_embedding(y) - embed = torch.cat([tile_embed, col_embed, row_embed], dim=2) - fused = self.fusion(embed) - return fused - -class EncoderGRU(nn.Module): - def __init__(self, input_dim=256, hidden_dim=512, output_dim=256): - super().__init__() - - # GRU - self.gru = nn.GRUCell(input_dim, hidden_dim) - self.drop = nn.Dropout(0.2) - self.fc = nn.Sequential( - nn.Linear(hidden_dim, hidden_dim), - nn.Dropout(0.1), - nn.LayerNorm(hidden_dim), - nn.GELU(), - - nn.Linear(hidden_dim, output_dim) - ) - - def forward(self, feat: torch.Tensor, hidden: torch.Tensor): - """ - feat: [B, input_dim] - hidden: [B, hidden_dim] - """ - hidden = self.drop(self.gru(feat, hidden)) - logits = self.fc(hidden) - return logits, hidden - -class EncoderFusion(nn.Module): - def __init__(self, d_model=256): - super().__init__() - - self.transformer = nn.TransformerEncoder( - nn.TransformerEncoderLayer( - d_model=d_model, dim_feedforward=d_model*2, nhead=2, batch_first=True - ), - num_layers=3 - ) - self.norm = nn.LayerNorm(d_model) - self.fc = nn.Sequential( - nn.Linear(d_model * 2, d_model * 2), - nn.Dropout(0.1), - nn.LayerNorm(d_model * 2), - nn.GELU() - ) - - def forward(self, logits): - x = self.norm(self.transformer(logits)) - h_mean = torch.mean(x, dim=1) - h_max = torch.max(x, dim=1).values - h = torch.cat([h_mean, h_max], dim=1) - return self.fc(h) - -class VAEEncoder(nn.Module): - def __init__(self, device, tile_classes=32, latent_dim=32, width=13, height=13): - super().__init__() - self.device = device - - self.rnn_hidden = 512 - self.logits_dim = 256 - - self.embedding = EncoderEmbedding(tile_classes, width, height, 128, 256) - self.rnn = EncoderGRU(256, self.rnn_hidden, self.logits_dim) - self.fusion = EncoderFusion(256) - self.fc_mu = nn.Sequential( - nn.Linear(512, 512), - nn.Dropout(0.1), - nn.LayerNorm(512), - nn.GELU(), - nn.Linear(512, latent_dim) - ) - self.fc_logvar = nn.Sequential( - nn.Linear(512, 512), - nn.Dropout(0.1), - nn.LayerNorm(512), - nn.GELU(), - nn.Linear(512, latent_dim) - ) - - self.col_list = [] - self.row_list = [] - for y in range(0, height): - for x in range(0, width): - self.col_list.append(x) - self.row_list.append(y) - - def forward(self, x: torch.Tensor): - B, H, W = x.shape - - map = torch.flatten(x, start_dim=1) - hidden = torch.zeros(B, self.rnn_hidden).to(self.device) - output = torch.zeros(B, H * W, self.logits_dim).to(self.device) - - col_list = torch.IntTensor(self.col_list).to(self.device).expand(B, -1) - row_list = torch.IntTensor(self.row_list).to(self.device).expand(B, -1) - embed = self.embedding(map, col_list, row_list) - - for idx in range(0, len(self.col_list)): - logits, h = self.rnn(embed[:, idx], hidden) - hidden = h - output[:, idx] = logits - - h = self.fusion(output) - mu = self.fc_mu(h) - logvar = self.fc_logvar(h) - return mu, logvar - -if __name__ == "__main__": - device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") - - input = torch.randint(0, 32, [1, 13, 13]).to(device) - - # 初始化模型 - model = VAEEncoder(device).to(device) - - print_memory(device, "初始化后") - - # 前向传播 - start = time.perf_counter() - mu, logvar = model(input) - end = time.perf_counter() - - print_memory(device, "前向传播后") - - print(f"推理耗时: {end - start}") - print(f"输出形状: mu={mu.shape}, logvar={logvar.shape}") - print(f"Embedding parameters: {sum(p.numel() for p in model.embedding.parameters())}") - print(f"RNN parameters: {sum(p.numel() for p in model.rnn.parameters())}") - print(f"Fusion parameters: {sum(p.numel() for p in model.fusion.parameters())}") - print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") diff --git a/ginka/vae_rnn/loss.py b/ginka/vae_rnn/loss.py deleted file mode 100644 index 64ee0ad..0000000 --- a/ginka/vae_rnn/loss.py +++ /dev/null @@ -1,20 +0,0 @@ -import torch -import torch.nn.functional as F - -class VAELoss: - def __init__(self): - self.num_classes = 32 - - def vae_loss(self, logits, target, mu, logvar, beta=0.1): - # logits: [B, 169, 16] - # target: [B, 169] - B, L = target.shape - end_token = torch.tensor([15], dtype=torch.long).to(logits.device).repeat(B, 1) - target = torch.cat([target, end_token], dim=1) - recon_loss = F.cross_entropy(logits.permute(0, 2, 1), target) - - kl_loss = -0.5 * torch.mean( - 1 + logvar - mu.pow(2) - logvar.exp() - ) - - return recon_loss + beta * kl_loss, recon_loss, kl_loss diff --git a/ginka/vae_rnn/scheduler.py b/ginka/vae_rnn/scheduler.py deleted file mode 100644 index b5c40df..0000000 --- a/ginka/vae_rnn/scheduler.py +++ /dev/null @@ -1,43 +0,0 @@ -import torch - -class VAEScheduler(torch.optim.lr_scheduler.ReduceLROnPlateau): - def __init__( - self, optimizer, mode="min", factor=0.1, patience=10, threshold=0.0001, - threshold_mode="rel", cooldown=0, min_lr=0, eps=1e-8, verbose="deprecated", - max_lr=1e-2, increase_factor=2, start_prob=0 - ): - super().__init__( - optimizer, mode, factor, patience, threshold, - threshold_mode, cooldown, min_lr, eps, verbose - ) - self.max_lr = max_lr - self.increase_factor = increase_factor - self.last_prob = start_prob - - if isinstance(max_lr, (list, tuple)): - if len(max_lr) != len(optimizer.param_groups): - raise ValueError( - f"expected {len(optimizer.param_groups)} max_lrs, got {len(max_lr)}" - ) - self.default_max_lr = None - self.max_lrs = list(max_lr) - else: - self.default_max_lr = max_lr - self.max_lrs = [max_lr] * len(optimizer.param_groups) - - def step(self, metrics, prob: float, epoch=None): - if prob > self.last_prob: - self.best = metrics - self.num_bad_epochs = 0 - self.last_prob = prob - self._increase_lr() - self._last_lr = [group["lr"] for group in self.optimizer.param_groups] - else: - return super().step(metrics, epoch) - - def _increase_lr(self): - for i, param_group in enumerate(self.optimizer.param_groups): - old_lr = float(param_group["lr"]) - new_lr = min(old_lr * self.increase_factor, self.max_lrs[i]) - if new_lr - old_lr > self.eps: - param_group["lr"] = new_lr diff --git a/ginka/vae_rnn/vae.py b/ginka/vae_rnn/vae.py deleted file mode 100644 index d6550b9..0000000 --- a/ginka/vae_rnn/vae.py +++ /dev/null @@ -1,47 +0,0 @@ -import time -import torch -import torch.nn as nn -import torch.nn.functional as F -from .encoder import VAEEncoder -from .decoder import VAEDecoder -from ..utils import print_memory - -class GinkaVAE(nn.Module): - def __init__(self, device, tile_classes=32, latent_dim=32): - super().__init__() - self.encoder = VAEEncoder(device, tile_classes, latent_dim) - self.decoder = VAEDecoder(device, map_vec_dim=latent_dim) - - def reparameterize(self, mu, logvar): - std = torch.exp(0.5 * logvar) - eps = torch.randn_like(std) - return mu + eps * std - - def forward(self, target_map: torch.Tensor, use_self_probility=0): - mu, logvar = self.encoder(target_map) - z = self.reparameterize(mu, logvar) - logits = self.decoder(z, target_map, use_self_probility) - return logits, mu, logvar - -if __name__ == "__main__": - device = torch.device("cpu") - - input = torch.randint(0, 32, [1, 13, 13]).to(device) - - # 初始化模型 - model = GinkaVAE(device).to(device) - - print_memory("初始化后") - - # 前向传播 - start = time.perf_counter() - logits, mu, logvar = model(input) - end = time.perf_counter() - - print_memory("前向传播后") - - print(f"推理耗时: {end - start}") - print(f"输出形状: logits= {logits.shape}, mu={mu.shape}, logvar={logvar.shape}") - print(f"Encoder parameters: {sum(p.numel() for p in model.encoder.parameters())}") - print(f"Decoder parameters: {sum(p.numel() for p in model.decoder.parameters())}") - print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") diff --git a/ginka/validate.py b/ginka/validate.py deleted file mode 100644 index 2e61bd6..0000000 --- a/ginka/validate.py +++ /dev/null @@ -1,85 +0,0 @@ -import os -import json -import cv2 -import numpy as np -import torch -import torch.nn.functional as F -from torch_geometric.loader import DataLoader -from tqdm import tqdm -from .critic.model import MinamoModel -from .dataset import GinkaDataset -from .generator.loss import GinkaLoss -from .generator.model import GinkaModel -from shared.image import matrix_to_image_cv - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -os.makedirs('result/ginka_img', exist_ok=True) - -def validate(): - print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.") - model = GinkaModel() - state = torch.load("result/ginka.pth", map_location=device)["model_state"] - model.load_state_dict(state) - model.to(device) - - for name, param in model.named_parameters(): - print(f"Layer: {name}, Params: {param.numel()}") - total_params = sum(p.numel() for p in model.parameters()) - print(f"Total parameters: {total_params}") - - minamo = MinamoModel(32) - minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"]) - minamo.to(device) - - # 准备数据集 - val_dataset = GinkaDataset("ginka-eval.json", device, minamo) - val_loader = DataLoader( - val_dataset, - batch_size=32, - shuffle=True - ) - - criterion = GinkaLoss(minamo) - - tile_dict = dict() - val_output = dict() - - for file in os.listdir('tiles'): - name = os.path.splitext(file)[0] - tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) - - minamo.eval() - model.eval() - val_loss = 0 - idx = 0 - with torch.no_grad(): - for batch in tqdm(val_loader): - # 数据迁移到设备 - target = batch["target"].to(device) - target_vision_feat = batch["target_vision_feat"].to(device) - target_topo_feat = batch["target_topo_feat"].to(device) - feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1) - # 前向传播 - output, output_softmax = model(feat_vec) - map_matrix = torch.argmax(output, dim=1) - - for matrix in map_matrix[:].cpu(): - image = matrix_to_image_cv(matrix.numpy(), tile_dict) - cv2.imwrite(f"result/ginka_img/{idx}.png", image) - val_output[f"val_{idx}"] = matrix.tolist() - idx += 1 - - # 计算损失 - _, loss = criterion(output_softmax, target, target_vision_feat, target_topo_feat) - val_loss += loss.item() - - avg_val_loss = val_loss / len(val_loader) - tqdm.write(f"Validation::loss: {avg_val_loss:.6f}") - - with open('result/ginka_val.json', 'w') as f: - json.dump(val_output, f) - -if __name__ == "__main__": - torch.set_num_threads(2) - validate() - \ No newline at end of file diff --git a/shared/args.py b/shared/args.py deleted file mode 100644 index 3fd8f20..0000000 --- a/shared/args.py +++ /dev/null @@ -1,12 +0,0 @@ -import argparse - -def parse_arguments(from_default: str, train_default: str, val_default: str): - parser = argparse.ArgumentParser(description="training codes") - parser.add_argument("--resume", type=bool, default=False) - parser.add_argument("--from_state", type=str, default=from_default) - parser.add_argument("--load_optim", type=bool, default=False) - parser.add_argument("--train", type=str, default=train_default) - parser.add_argument("--validate", type=str, default=val_default) - parser.add_argument("--epochs", type=int, default=150) - args = parser.parse_args() - return args diff --git a/shared/attention.py b/shared/attention.py deleted file mode 100644 index c4e7c14..0000000 --- a/shared/attention.py +++ /dev/null @@ -1,73 +0,0 @@ -import torch -import torch.nn as nn - -class ChannelAttention(nn.Module): - """通道注意力模块""" - def __init__(self, channels, reduction=8): - super().__init__() - # 通道注意力 - self.channel_att = nn.Sequential( - nn.AdaptiveAvgPool2d(1), - nn.Conv2d(channels, channels//reduction, 1), - nn.ELU(), - nn.Conv2d(channels//reduction, channels, 1), - nn.Sigmoid() - ) - - def forward(self, x): - # 通道注意力 - c_att = self.channel_att(x) - x = x * c_att - return x - -class SpatialAttention(nn.Module): - """空间注意力模块""" - def __init__(self): - super().__init__() - # 空间注意力 - self.spatial_att = nn.Sequential( - nn.Conv2d(2, 1, 7, padding=3), - nn.Sigmoid() - ) - - def forward(self, x): - # 空间注意力 - max_pool = torch.max(x, dim=1, keepdim=True)[0] - avg_pool = torch.mean(x, dim=1, keepdim=True) - s_att = self.spatial_att(torch.cat([max_pool, avg_pool], dim=1)) - return x * s_att - -class CBAM(nn.Module): - """通道与空间注意力结合""" - def __init__(self, channels, reduction=8): - super().__init__() - # 通道注意力 - self.channel_att = ChannelAttention(channels, reduction) - # 空间注意力 - self.spatial_att = SpatialAttention() - - def forward(self, x): - # 通道注意力 - c_att = self.channel_att(x) - x = x * c_att - - # 空间注意力 - s_att = self.spatial_att(x) - return x * s_att - -class SEBlock(nn.Module): - def __init__(self, channel, reduction=4): - super().__init__() - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.fc = nn.Sequential( - nn.Linear(channel, channel // reduction), - nn.GELU(), - nn.Linear(channel // reduction, channel), - nn.Sigmoid() - ) - - def forward(self, x): - b, c, _, _ = x.size() - y = self.avg_pool(x).view(b, c) - y = self.fc(y).view(b, c, 1, 1) - return x * y \ No newline at end of file diff --git a/shared/constant.py b/shared/constant.py deleted file mode 100644 index d194a6c..0000000 --- a/shared/constant.py +++ /dev/null @@ -1,6 +0,0 @@ -VIS_DIM = 512 -TOPO_DIM = 512 -FEAT_DIM = 1024 - -VISION_WEIGHT = 0.2 -TOPO_WEIGHT = 0.8 diff --git a/shared/graph.py b/shared/graph.py deleted file mode 100644 index 1eb0aa6..0000000 --- a/shared/graph.py +++ /dev/null @@ -1,66 +0,0 @@ -import torch -from torch_geometric.data import Data, Batch -from torch_geometric.utils import add_self_loops, grid - -def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data: - """ - 可导的图结构转换(返回 PyG Data 对象) - map_probs: [C, H, W] - 返回: - Data(x=[N, C], edge_index=[2, E], edge_attr=[E, C]) - """ - C, H, W = map_probs.shape - device = map_probs.device - N = H * W - - # 1. 节点特征 - node_features = map_probs.view(C, H * W).T # [N, C] - edge_index, _ = grid(H, W) - edge_index = edge_index.to(device) - - # 2. 构建所有可能的边连接 - # node_indices = torch.arange(N, device=device).view(H, W) - - # # 水平连接(右邻居) - # right_src = node_indices[:, :-1].flatten() - # right_dst = node_indices[:, 1:].flatten() - - # # 垂直连接(下邻居) - # down_src = node_indices[:-1, :].flatten() - # down_dst = node_indices[1:, :].flatten() - - # # 合并边列表(双向) - # edge_src = torch.cat([right_src, down_src]) - # edge_dst = torch.cat([right_dst, down_dst]) - # edge_index = torch.cat([ - # torch.stack([edge_src, edge_dst], dim=0), - # torch.stack([edge_dst, edge_src], dim=0) # 反向连接 - # ], dim=1).to(device, dtype=torch.long) - - # # 3. 计算边特征 - # src_feat = map_probs[:, edge_src // W, edge_src % W].T # [E, C] - # dst_feat = map_probs[:, edge_dst // W, edge_dst % W].T # [E, C] - # edge_attr = (src_feat + dst_feat) / 2 # [E, C] - - # edge_index, edge_attr = add_self_loops(edge_index, edge_attr) - - return Data( - x=node_features, - edge_index=edge_index, - # edge_attr=edge_attr, - num_nodes=N - ) - -def batch_convert_soft_map_to_graph(batch_map_probs): - """ - 处理 batch 维度,将 [B, C, H, W] 转换为批量图结构 Batch - """ - B, C, H, W = batch_map_probs.shape # 获取 batch 维度 - batch_graphs = [] - - for i in range(B): - graph = differentiable_convert_to_data(batch_map_probs[i]) # 处理单个样本 - batch_graphs.append(graph) - - # 合并所有图为批量 Batch - return Batch.from_data_list(batch_graphs) diff --git a/shared/similarity/topo.py b/shared/similarity/topo.py deleted file mode 100644 index 03e97ef..0000000 --- a/shared/similarity/topo.py +++ /dev/null @@ -1,273 +0,0 @@ -# Converted Python version of the JS code -import math -from typing import Dict, Set, List, Tuple, Union -from collections import deque, defaultdict - -# 拓扑相似度,由 ChatGPT-4o 从 ts 转译而来 - -class ResourceArea: - def __init__(self): - self.type = 'resource' - self.resources: Dict[int, int] = {} - self.members: Set[int] = set() - self.neighbor: Set[int] = set() - -class BranchNode: - def __init__(self, tile: int): - self.type = 'branch' - self.tile = tile - self.neighbor: Set[int] = set() - -class ResourceNode: - def __init__(self, resource_type: int, area: ResourceArea): - self.type = 'resource' - self.resourceType = resource_type - self.neighbor = area.neighbor - self.resourceArea = area - -GinkaNode = Union[BranchNode, ResourceNode] - -class GinkaGraph: - def __init__(self): - self.graph: Dict[int, GinkaNode] = {} - self.resourceMap: Dict[int, int] = {} - self.areaMap: List[ResourceArea] = [] - self.visitedEntrance: Set[int] = set() - self.visited: Set[int] = set() - -class GinkaTopologicalGraphs: - def __init__(self): - self.graphs: List[GinkaGraph] = [] - self.entranceMap: Dict[int, GinkaGraph] = {} - self.unreachable: Set[int] = set() - -TILE_TYPE = set(range(13)) -BRANCH_TYPE = {6, 7, 8, 9} -ENTRANCE_TYPE = {10, 11} -RESOURCE_TYPE = {0, 2, 3, 4, 5, 10, 11, 12, 13} - -directions: List[Tuple[int, int]] = [ - (-1, 0), (1, 0), (0, -1), (0, 1) -] - -def find_resource_nodes(map_: List[List[int]]): - width, height = len(map_[0]), len(map_) - visited = set() - areas = [] - resource_map = {} - - for ny in range(height): - for nx in range(width): - tile = map_[ny][nx] - index = ny * width + nx - if index in visited or tile not in RESOURCE_TYPE: - continue - queue = deque([(nx, ny)]) - area = ResourceArea() - area.resources[tile] = 1 - area.members.add(index) - while queue: - cx, cy = queue.popleft() - cindex = cy * width + cx - if cindex in visited: - continue - ctile = map_[cy][cx] - if ctile not in RESOURCE_TYPE: - continue - visited.add(cindex) - area.resources[ctile] = area.resources.get(ctile, 0) + 1 - area.members.add(cindex) - resource_map[cindex] = len(areas) - for dx, dy in directions: - px, py = cx + dx, cy + dy - if 0 <= px < width and 0 <= py < height: - queue.append((px, py)) - areas.append(area) - return areas, resource_map - -def build_graph_from_entrance(map_: List[List[int]], entrance: int, resource_map: Dict[int, int], area_map: List[ResourceArea]) -> GinkaGraph: - width, height = len(map_[0]), len(map_) - graph = GinkaGraph() - graph.resourceMap = resource_map - graph.areaMap = area_map - - visited = graph.visited - visited_entrance = graph.visitedEntrance - visited_entrance.add(entrance) - - branch_nodes = set() - queue = deque([(entrance % width, entrance // width)]) - - while queue: - x, y = queue.popleft() - index = y * width + x - if index in visited: - continue - tile = map_[y][x] - if tile in ENTRANCE_TYPE: - visited_entrance.add(index) - if tile in BRANCH_TYPE: - branch_nodes.add(index) - visited.add(index) - for dx, dy in directions: - px, py = x + dx, y + dy - if 0 <= px < width and 0 <= py < height and map_[py][px] != 1: - queue.append((px, py)) - - for v in branch_nodes: - x, y = v % width, v // width - if v not in graph.graph: - graph.graph[v] = BranchNode(map_[y][x]) - node = graph.graph[v] - for dx, dy in directions: - px, py = x + dx, y + dy - if 0 <= px < width and 0 <= py < height: - index = py * width + px - if index in branch_nodes: - node.neighbor.add(index) - elif index in resource_map: - area = area_map[resource_map[index]] - area.neighbor.add(v) - for m in area.members: - node.neighbor.add(m) - - for area in area_map: - for index in area.members: - x, y = index % width, index // width - tile = map_[y][x] - if tile == 0: - continue - node = ResourceNode(tile, area) - graph.graph[index] = node - - return graph - -def build_topological_graph(map_: List[List[int]]) -> GinkaTopologicalGraphs: - width, height = len(map_[0]), len(map_) - entrances = set() - entrances = {y * width + x for y in range(height) for x in range(width) if map_[y][x] in ENTRANCE_TYPE} - area_map, resource_map = find_resource_nodes(map_) - - top_graph = GinkaTopologicalGraphs() - used_entrance = set() - total_visited = set() - - for entrance in entrances: - if entrance in used_entrance: - continue - graph = build_graph_from_entrance(map_, entrance, resource_map, area_map) - top_graph.graphs.append(graph) - for ent in graph.visitedEntrance: - used_entrance.add(ent) - top_graph.entranceMap[ent] = graph - total_visited.update(graph.visited) - - for y in range(height): - for x in range(width): - index = y * width + x - if index not in total_visited and map_[y][x] != 1: - top_graph.unreachable.add(index) - - return top_graph - -class WLNode: - def __init__(self, pos: int, label: str): - self.originalPos = pos - self.originalLabel = label - self.currentLabel = label - self.neighbors: List['WLNode'] = [] - -def encode_node_labels(graph: GinkaGraph) -> List[WLNode]: - node_map = {} - nodes = [] - for pos, node in graph.graph.items(): - label = f"B:{node.tile}" if node.type == 'branch' else f"R:{node.resourceType}" - wl_node = WLNode(pos, label) - node_map[pos] = wl_node - nodes.append(wl_node) - - for node in nodes: - g_node = graph.graph[node.originalPos] - for neighbor in g_node.neighbor: - if neighbor in node_map: - node.neighbors.append(node_map[neighbor]) - - return nodes - -def weisfeiler_lehman_iteration(nodes: List[WLNode], iterations: int, decay: float = 0.6) -> Dict[str, float]: - label_history = [] - for _ in range(iterations): - new_labels = [] - for node in nodes: - neighbor_labels = sorted(n.currentLabel for n in node.neighbors) - composite = f"{node.currentLabel}|{','.join(neighbor_labels)}"[:8192] - new_labels.append(composite) - for node, new_label in zip(nodes, new_labels): - node.currentLabel = new_label - label_history.append(new_labels[:]) - - weight = 1.0 - label_counts = defaultdict(float) - for layer in label_history: - for label in layer: - label_counts[label] += weight - weight *= decay - for node in nodes: - label_counts[node.originalLabel] += weight - return dict(label_counts) - -def vectorize_features(features: Dict[str, float], vocab: List[str]) -> List[float]: - vec = [0.0] * len(vocab) - for label, count in features.items(): - if label in vocab: - idx = vocab.index(label) - vec[idx] += count - return vec - -def cosine_similarity(a: List[float], b: List[float]) -> float: - dot = sum(x * y for x, y in zip(a, b)) - norm_a = math.sqrt(sum(x * x for x in a)) - norm_b = math.sqrt(sum(y * y for y in b)) - if norm_a == 0 or norm_b == 0: - return 0.0 - return dot / (norm_a * norm_b) - -def wl_kernel(graph_a: GinkaGraph, graph_b: GinkaGraph, iterations: int = 3) -> float: - nodes_a = encode_node_labels(graph_a) - nodes_b = encode_node_labels(graph_b) - features_a = weisfeiler_lehman_iteration(nodes_a, iterations) - features_b = weisfeiler_lehman_iteration(nodes_b, iterations) - vocab = list(set(features_a.keys()) | set(features_b.keys())) - vec_a = vectorize_features(features_a, vocab) - vec_b = vectorize_features(features_b, vocab) - return cosine_similarity(vec_a, vec_b) - -def overall_similarity(a: GinkaTopologicalGraphs, b: GinkaTopologicalGraphs) -> float: - graphs_a = a.graphs - graphs_b = b.graphs - - total_similarity = 0.0 - compared_graphs: Set[GinkaGraph] = set() - - for ga in graphs_a: - max_similarity = 0.0 - max_graph = None - for gb in graphs_b: - if gb in compared_graphs: - continue - min_nodes = min(len(ga.graph), len(gb.graph)) - iterations = max(1, math.ceil(math.log(min_nodes))) - similarity = wl_kernel(ga, gb, iterations) - if similarity > max_similarity and not math.isnan(similarity): - max_similarity = similarity - max_graph = gb - if similarity == 1: - break - total_similarity += max_similarity - if max_graph: - compared_graphs.add(max_graph) - - reduction = 1 / (1 + abs(len(a.unreachable) - len(b.unreachable))) - if not graphs_a: - return 0.0 - return math.sqrt(total_similarity / len(graphs_a)) * reduction diff --git a/shared/similarity/vision.py b/shared/similarity/vision.py deleted file mode 100644 index 09347d0..0000000 --- a/shared/similarity/vision.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import List, Dict -import math -import numpy as np - -# 视觉相似度,由 ChatGPT-4o 从 ts 转译而来 - -class VisualSimilarityConfig: - def __init__(self): - self.type_weights: Dict[int, float] = { - 0: 0.2, 1: 0.3, 2: 0.6, 3: 0.7, 4: 0.7, 5: 0.5, - 6: 0.4, 7: 0.5, 8: 0.6, 9: 0.6, 10: 0.4, 11: 0.4, 12: 0.7 - } - self.enable_visual_focus: bool = True - self.enable_density_awareness: bool = True - -def generate_focus_weights(rows: int, cols: int) -> List[List[float]]: - weights = [] - center_x = cols / 2 - center_y = rows / 2 - for i in range(rows): - row_weights = [] - for j in range(cols): - dx = (j - center_x) / cols - dy = (i - center_y) / rows - distance = math.sqrt(dx ** 2 + dy ** 2) - gaussian = math.exp(-(distance ** 2) / (2 * 0.3 ** 2)) - row_weights.append(1.0 + 0.6 * gaussian) - weights.append(row_weights) - return weights - -def calculate_density_impact(map1: List[List[int]], map2: List[List[int]], type_weights: Dict[int, float]) -> List[List[float]]: - rows, cols = len(map1), len(map1[0]) - density_map = [[0.0 for _ in range(cols)] for _ in range(rows)] - window_size = 3 - half_window = window_size // 2 - - for i in range(rows): - for j in range(cols): - density = 0 - for di in range(-half_window, half_window + 1): - for dj in range(-half_window, half_window + 1): - ni, nj = i + di, j + dj - if 0 <= ni < rows and 0 <= nj < cols: - weight1 = type_weights.get(map1[ni][nj], 0.5) - weight2 = type_weights.get(map2[ni][nj], 0.5) - density += (weight1 + weight2) / 2 - density_map[i][j] = 1.0 + 0.4 * (density / (window_size ** 2)) - return density_map - -def calculate_visual_similarity(map1: List[List[int]], map2: List[List[int]], config: VisualSimilarityConfig = None) -> float: - if config is None: - config = VisualSimilarityConfig() - - if len(map1) != len(map2) or len(map1[0]) != len(map2[0]): - return 0.0 - - rows, cols = len(map1), len(map1[0]) - total_score = 0.0 - max_possible_score = 0.0 - - focus_weights = generate_focus_weights(rows, cols) if config.enable_visual_focus else [[1.0 for _ in range(cols)] for _ in range(rows)] - density_map = calculate_density_impact(map1, map2, config.type_weights) if config.enable_density_awareness else [[1.0 for _ in range(cols)] for _ in range(rows)] - - for i in range(rows): - for j in range(cols): - type1 = map1[i][j] - type2 = map2[i][j] - base_weight = max(config.type_weights.get(type1, 0.5), config.type_weights.get(type2, 0.5)) - spatial_weight = focus_weights[i][j] * density_map[i][j] - type_score = 1.0 if type1 == type2 else 0.0 - - total_score += type_score * base_weight * spatial_weight - max_possible_score += base_weight * spatial_weight - - return total_score / max_possible_score if max_possible_score > 0 else 0.0 diff --git a/shared/utils.py b/shared/utils.py deleted file mode 100644 index 4c00e88..0000000 --- a/shared/utils.py +++ /dev/null @@ -1,17 +0,0 @@ -import torch - -def random_smooth_onehot(onehot_map, min_main=0.75, max_main=1.0, epsilon=0.25): - """ - 生成随机平滑的 one-hot 编码,使主类别概率不再固定,而是随机波动 - """ - C, H, W = onehot_map.shape - # 生成主类别的随机概率 (min_main, max_main) - main_prob = torch.rand(H, W) * (max_main - min_main) + min_main - - # 计算剩余概率并随机分配到其他类别 - noise = torch.rand(C, H, W) * epsilon # 随机噪声 - noise = noise / noise.sum(dim=1, keepdim=True) # 归一化到总和为 epsilon - - # 计算最终平滑 one-hot 结果 - smooth_onehot = onehot_map * main_prob + (1 - onehot_map) * noise - return smooth_onehot \ No newline at end of file diff --git a/shared/visual.py b/shared/visual.py index a7233e5..0db6f1e 100644 --- a/shared/visual.py +++ b/shared/visual.py @@ -63,7 +63,7 @@ def convert_dataset_to_images( if __name__ == "__main__": convert_dataset_to_images( json_path="data/result.json", # 数据集文件 - tile_folder="tiles2", # 贴图文件夹 + tile_folder="tiles", # 贴图文件夹 output_folder="map_images", # 输出文件夹 tile_size=32 # tile 尺寸 ) \ No newline at end of file diff --git a/tiles/10.png b/tiles/10.png index 08409ab..d2eb533 100644 Binary files a/tiles/10.png and b/tiles/10.png differ diff --git a/tiles/11.png b/tiles/11.png deleted file mode 100644 index 378a738..0000000 Binary files a/tiles/11.png and /dev/null differ diff --git a/tiles/12.png b/tiles/12.png deleted file mode 100644 index c2330ec..0000000 Binary files a/tiles/12.png and /dev/null differ diff --git a/tiles/13.png b/tiles/13.png deleted file mode 100644 index 792ed88..0000000 Binary files a/tiles/13.png and /dev/null differ diff --git a/tiles/14.png b/tiles/14.png deleted file mode 100644 index a3a33e1..0000000 Binary files a/tiles/14.png and /dev/null differ diff --git a/tiles/15.png b/tiles/15.png index 56ecb3f..eb62785 100644 Binary files a/tiles/15.png and b/tiles/15.png differ diff --git a/tiles/16.png b/tiles/16.png deleted file mode 100644 index 4b8d3a6..0000000 Binary files a/tiles/16.png and /dev/null differ diff --git a/tiles/17.png b/tiles/17.png deleted file mode 100644 index 835cbde..0000000 Binary files a/tiles/17.png and /dev/null differ diff --git a/tiles/18.png b/tiles/18.png deleted file mode 100644 index b092cb4..0000000 Binary files a/tiles/18.png and /dev/null differ diff --git a/tiles/19.png b/tiles/19.png deleted file mode 100644 index b121323..0000000 Binary files a/tiles/19.png and /dev/null differ diff --git a/tiles/2.png b/tiles/2.png index 64e8e53..83de73a 100644 Binary files a/tiles/2.png and b/tiles/2.png differ diff --git a/tiles/20.png b/tiles/20.png deleted file mode 100644 index dd9a8c5..0000000 Binary files a/tiles/20.png and /dev/null differ diff --git a/tiles/21.png b/tiles/21.png deleted file mode 100644 index d9fe4bf..0000000 Binary files a/tiles/21.png and /dev/null differ diff --git a/tiles/22.png b/tiles/22.png deleted file mode 100644 index 62626ab..0000000 Binary files a/tiles/22.png and /dev/null differ diff --git a/tiles/23.png b/tiles/23.png deleted file mode 100644 index 38d7a35..0000000 Binary files a/tiles/23.png and /dev/null differ diff --git a/tiles/24.png b/tiles/24.png deleted file mode 100644 index 255a39b..0000000 Binary files a/tiles/24.png and /dev/null differ diff --git a/tiles/25.png b/tiles/25.png deleted file mode 100644 index aee20bf..0000000 Binary files a/tiles/25.png and /dev/null differ diff --git a/tiles/26.png b/tiles/26.png deleted file mode 100644 index 1329097..0000000 Binary files a/tiles/26.png and /dev/null differ diff --git a/tiles/27.png b/tiles/27.png deleted file mode 100644 index 2e564fd..0000000 Binary files a/tiles/27.png and /dev/null differ diff --git a/tiles/28.png b/tiles/28.png deleted file mode 100644 index 9404de6..0000000 Binary files a/tiles/28.png and /dev/null differ diff --git a/tiles/29.png b/tiles/29.png deleted file mode 100644 index d2eb533..0000000 Binary files a/tiles/29.png and /dev/null differ diff --git a/tiles/3.png b/tiles/3.png index 83de73a..339c1c3 100644 Binary files a/tiles/3.png and b/tiles/3.png differ diff --git a/tiles/30_1.png b/tiles/30_1.png deleted file mode 100644 index 49af222..0000000 Binary files a/tiles/30_1.png and /dev/null differ diff --git a/tiles/30_2.png b/tiles/30_2.png deleted file mode 100644 index 817f7e0..0000000 Binary files a/tiles/30_2.png and /dev/null differ diff --git a/tiles/30_3.png b/tiles/30_3.png deleted file mode 100644 index 37f1ec6..0000000 Binary files a/tiles/30_3.png and /dev/null differ diff --git a/tiles/30_4.png b/tiles/30_4.png deleted file mode 100644 index 753cbaa..0000000 Binary files a/tiles/30_4.png and /dev/null differ diff --git a/tiles/4.png b/tiles/4.png index bd23e07..08409ab 100644 Binary files a/tiles/4.png and b/tiles/4.png differ diff --git a/tiles/5.png b/tiles/5.png index 0556fb9..792ed88 100644 Binary files a/tiles/5.png and b/tiles/5.png differ diff --git a/tiles/6.png b/tiles/6.png index efce63b..4b8d3a6 100644 Binary files a/tiles/6.png and b/tiles/6.png differ diff --git a/tiles/7.png b/tiles/7.png index 339c1c3..b121323 100644 Binary files a/tiles/7.png and b/tiles/7.png differ diff --git a/tiles/8.png b/tiles/8.png index 9e5888f..38d7a35 100644 Binary files a/tiles/8.png and b/tiles/8.png differ diff --git a/tiles/9.png b/tiles/9.png index d106e37..1329097 100644 Binary files a/tiles/9.png and b/tiles/9.png differ diff --git a/tiles/999.png b/tiles/999.png deleted file mode 100644 index eb62785..0000000 Binary files a/tiles/999.png and /dev/null differ diff --git a/tiles2/0.png b/tiles2/0.png deleted file mode 100644 index 9649930..0000000 Binary files a/tiles2/0.png and /dev/null differ diff --git a/tiles2/1.png b/tiles2/1.png deleted file mode 100644 index f8e7142..0000000 Binary files a/tiles2/1.png and /dev/null differ diff --git a/tiles2/10.png b/tiles2/10.png deleted file mode 100644 index d2eb533..0000000 Binary files a/tiles2/10.png and /dev/null differ diff --git a/tiles2/15.png b/tiles2/15.png deleted file mode 100644 index eb62785..0000000 Binary files a/tiles2/15.png and /dev/null differ diff --git a/tiles2/2.png b/tiles2/2.png deleted file mode 100644 index 83de73a..0000000 Binary files a/tiles2/2.png and /dev/null differ diff --git a/tiles2/3.png b/tiles2/3.png deleted file mode 100644 index 339c1c3..0000000 Binary files a/tiles2/3.png and /dev/null differ diff --git a/tiles2/4.png b/tiles2/4.png deleted file mode 100644 index 08409ab..0000000 Binary files a/tiles2/4.png and /dev/null differ diff --git a/tiles2/5.png b/tiles2/5.png deleted file mode 100644 index 792ed88..0000000 Binary files a/tiles2/5.png and /dev/null differ diff --git a/tiles2/6.png b/tiles2/6.png deleted file mode 100644 index 4b8d3a6..0000000 Binary files a/tiles2/6.png and /dev/null differ diff --git a/tiles2/7.png b/tiles2/7.png deleted file mode 100644 index b121323..0000000 Binary files a/tiles2/7.png and /dev/null differ diff --git a/tiles2/8.png b/tiles2/8.png deleted file mode 100644 index 38d7a35..0000000 Binary files a/tiles2/8.png and /dev/null differ diff --git a/tiles2/9.png b/tiles2/9.png deleted file mode 100644 index 1329097..0000000 Binary files a/tiles2/9.png and /dev/null differ diff --git a/train.sh b/train.sh index fd37c05..fca6024 100644 --- a/train.sh +++ b/train.sh @@ -1,8 +1,3 @@ -# 从头训练 -python3 -u -m ginka.train_wgan --epochs 20 --curr_epoch 1 --checkpoint 1 >> output.log -# 接续训练 -python3 -u -m ginka.train_wgan --resume true --epochs 300 --state_ginka "result/wgan/ginka-100.pth" --state_minamo "result/wgan/minamo-100.pth" >> output.log - -# rnn -python3 -u -m ginka.train_rnn --epochs 150 --checkpoint 10 >> output_rnn.log -python3 -u -m ginka.train_rnn --resume true --epochs 150 --checkpoint 10 --state_ginka "result/rnn/ginka-100.pth" >> output_rnn.log +# MaskGIT +python3 -u -m ginka.train_maskGIT --epochs 150 --checkpoint 10 >> output_maskGIT.log +python3 -u -m ginka.train_maskGIT --resume true --epochs 150 --checkpoint 10 --state_ginka "result/transformer/ginka-100.pth" >> output_maskGIT.log