mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-23 20:41:12 +08:00
feat: 从 json 加载训练集
This commit is contained in:
parent
cd5df7a742
commit
280dd469e2
3
.gitignore
vendored
3
.gitignore
vendored
@ -4,4 +4,5 @@ node_modules
|
|||||||
ginka-dataset.json
|
ginka-dataset.json
|
||||||
ginka-eval.json
|
ginka-eval.json
|
||||||
minamo-dataset.json
|
minamo-dataset.json
|
||||||
minamo-eval.json
|
minamo-eval.json
|
||||||
|
datasets
|
||||||
@ -1,11 +1,12 @@
|
|||||||
import { writeFile } from 'fs-extra';
|
import { writeFile } from 'fs-extra';
|
||||||
import { FloorData, getAllFloors, parseTowerInfo } from './utils';
|
import { FloorData, readOne, getAllFloors, parseTowerInfo } from './utils';
|
||||||
import { compareMap } from './topology/compare';
|
import { compareMap } from './topology/compare';
|
||||||
import { mirrorMapX, mirrorMapY, rotateMap } from './topology/transform';
|
import { mirrorMapX, mirrorMapY, rotateMap } from './topology/transform';
|
||||||
import { directions, tileType } from './topology/graph';
|
import { directions, tileType } from './topology/graph';
|
||||||
import { calculateVisualSimilarity } from './vision/similarity';
|
import { calculateVisualSimilarity } from './vision/similarity';
|
||||||
import { BaseConfig } from './types';
|
import { BaseConfig } from './types';
|
||||||
import { Presets, SingleBar } from 'cli-progress';
|
import { Presets, SingleBar } from 'cli-progress';
|
||||||
|
import { log } from 'console';
|
||||||
|
|
||||||
interface MinamoConfig extends BaseConfig {}
|
/** Minamo configuration; currently declares no fields beyond those of BaseConfig. */
interface MinamoConfig extends BaseConfig {}
|
||||||
|
|
||||||
@ -23,6 +24,9 @@ interface MinamoDataset {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const [output, ...list] = process.argv.slice(2);
|
// First CLI argument is the output file; the remaining ones are collected in `list`.
const [output, ...list] = process.argv.slice(2);
// Detect "assigned" mode (trailing `assigned` argument). In this mode only the first two
// towers are processed: maps are compared between those two towers, and maps belonging
// to the same tower are not compared with each other.
const assigned = list[list.length - 1] === 'assigned';
// Remove the flag so that `list` no longer contains it.
if (assigned) {
    list.pop();
}
|
||||||
|
|
||||||
function chooseFrom<T>(arr: T[], n: number): T[] {
|
function chooseFrom<T>(arr: T[], n: number): T[] {
|
||||||
const copy = arr.slice();
|
const copy = arr.slice();
|
||||||
@ -33,6 +37,15 @@ function chooseFrom<T>(arr: T[], n: number): T[] {
|
|||||||
return copy.slice(0, n);
|
return copy.slice(0, n);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Builds the index list `[0, 1, ..., maxCount - 1]` and hands it to `chooseFrom`
 * together with `n`.
 * @param maxCount Length of the index list to build.
 * @param n Count forwarded to `chooseFrom`.
 * @returns Whatever `chooseFrom` returns for that index list.
 */
function chooseN(maxCount: number, n: number) {
    const indices = new Array<number>(maxCount);
    for (let i = 0; i < indices.length; i++) {
        indices[i] = i;
    }
    return chooseFrom(indices, n);
}
|
||||||
|
|
||||||
function choosePair(n: number, max: number = 1000) {
|
function choosePair(n: number, max: number = 1000) {
|
||||||
const totalCount = Math.round((n * (n - 1)) / 2);
|
const totalCount = Math.round((n * (n - 1)) / 2);
|
||||||
const count = Math.min(totalCount, max);
|
const count = Math.min(totalCount, max);
|
||||||
@ -204,6 +217,59 @@ function generateSimilarData(id: string, map: number[][]) {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Writes every training sample derived from one pair of maps into `data`.
 *
 * Samples are added in this order:
 * 1. the pair itself under the key `id1:id2`;
 * 2. optionally the self-pairs `id1:id1` / `id2:id2` (both similarities fixed to 1),
 *    for whichever of the two keys `chooseFrom` returns and that is not in `data` yet;
 * 3. the entries produced by `generateTransformData`;
 * 4. the entries produced by `generateSimilarData` for the first map only.
 *
 * @param data Accumulator that is mutated in place.
 * @param id1 Identifier of the first map.
 * @param id2 Identifier of the second map.
 * @param map1 First map.
 * @param map2 Second map.
 * @param size Size stored on every sample created directly in this function.
 */
function generatePair(
    data: Record<string, MinamoTrainData>,
    id1: string,
    id2: string,
    map1: number[][],
    map2: number[][],
    size: [number, number]
) {
    const topoSimilarity = compareMap(id1, id2, map1, map2);
    const visionSimilarity = calculateVisualSimilarity(map1, map2);
    const pairSample: MinamoTrainData = {
        map1,
        map2,
        topoSimilarity,
        visionSimilarity,
        size
    };
    data[`${id1}:${id2}`] = pairSample;

    // Self-comparison samples, so that the model outputs 1 for identical maps.
    // Between zero and two of the self keys are requested from chooseFrom.
    const selfKey1 = `${id1}:${id1}`;
    const selfKey2 = `${id2}:${id2}`;
    const pickedSelfKeys = chooseFrom(
        [selfKey1, selfKey2],
        Math.floor(Math.random() * 3)
    );
    const addSelfSample = (key: string, map: number[][]) => {
        // Skip keys that were not picked or that already have a sample.
        if (!pickedSelfKeys.includes(key) || data[key]) return;
        const selfSample: MinamoTrainData = {
            map1: map,
            map2: map,
            topoSimilarity: 1,
            visionSimilarity: 1,
            size
        };
        data[key] = selfSample;
    };
    addSelfSample(selfKey1, map1);
    addSelfSample(selfKey2, map2);

    // Mirrored / rotated samples.
    const transformed = Object.fromEntries(
        generateTransformData(id1, id2, map1, map2, topoSimilarity)
    );
    Object.assign(data, transformed);

    // Slightly altered variants of the first map.
    const similar = Object.fromEntries(generateSimilarData(id1, map1));
    Object.assign(data, similar);
}
|
||||||
|
|
||||||
function generateDataset(
|
function generateDataset(
|
||||||
floors: Map<string, FloorData>,
|
floors: Map<string, FloorData>,
|
||||||
pairs: number[],
|
pairs: number[],
|
||||||
@ -226,53 +292,7 @@ function generateDataset(
|
|||||||
const [w1, h1] = [map1[0].length, map1.length];
|
const [w1, h1] = [map1[0].length, map1.length];
|
||||||
const [w2, h2] = [map2[0].length, map2.length];
|
const [w2, h2] = [map2[0].length, map2.length];
|
||||||
if (w1 !== w2 || h1 !== h2) return;
|
if (w1 !== w2 || h1 !== h2) return;
|
||||||
const topoSimilarity = compareMap(id1, id2, map1, map2);
|
generatePair(data, id1, id2, map1, map2, [w1, h1]);
|
||||||
const visionSimilarity = calculateVisualSimilarity(map1, map2);
|
|
||||||
const train: MinamoTrainData = {
|
|
||||||
map1,
|
|
||||||
map2,
|
|
||||||
topoSimilarity,
|
|
||||||
visionSimilarity,
|
|
||||||
size: [w1, h1]
|
|
||||||
};
|
|
||||||
data[`${id1}:${id2}`] = train;
|
|
||||||
// 自身与自身对比的训练集,保证模型对相同地图输出 1
|
|
||||||
const self1 = `${id1}:${id1}`;
|
|
||||||
const self2 = `${id2}:${id2}`;
|
|
||||||
const selfTrain = chooseFrom(
|
|
||||||
[self1, self2],
|
|
||||||
Math.floor(Math.random() * 3)
|
|
||||||
);
|
|
||||||
if (selfTrain.includes(self1) && !data[`${id1}:${id1}`]) {
|
|
||||||
const selfTrain1: MinamoTrainData = {
|
|
||||||
map1: map1,
|
|
||||||
map2: map1,
|
|
||||||
topoSimilarity: 1,
|
|
||||||
visionSimilarity: 1,
|
|
||||||
size: [w1, h1]
|
|
||||||
};
|
|
||||||
data[`${id1}:${id1}`] = selfTrain1;
|
|
||||||
}
|
|
||||||
if (selfTrain.includes(self2) && !data[`${id2}:${id2}`]) {
|
|
||||||
const selfTrain2: MinamoTrainData = {
|
|
||||||
map1: map2,
|
|
||||||
map2: map2,
|
|
||||||
topoSimilarity: 1,
|
|
||||||
visionSimilarity: 1,
|
|
||||||
size: [w1, h1]
|
|
||||||
};
|
|
||||||
data[`${id2}:${id2}`] = selfTrain2;
|
|
||||||
}
|
|
||||||
// 翻转、旋转训练集
|
|
||||||
Object.assign(
|
|
||||||
data,
|
|
||||||
Object.fromEntries(
|
|
||||||
generateTransformData(id1, id2, map1, map2, topoSimilarity)
|
|
||||||
)
|
|
||||||
);
|
|
||||||
// 地图微调训练集
|
|
||||||
Object.assign(data, Object.fromEntries(generateSimilarData(id1, map1)));
|
|
||||||
// Object.assign(data, Object.fromEntries(generateSimilarData(id2, map2)));
|
|
||||||
progress.update(i + 1);
|
progress.update(i + 1);
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -301,13 +321,76 @@ function parseAllData(data: Map<string, FloorData>): MinamoDataset {
|
|||||||
return dataset;
|
return dataset;
|
||||||
}
|
}
|
||||||
|
|
||||||
(async () => {
|
/**
 * Builds a dataset from two floor collections by pairing floors of the first with
 * floors of the second.
 *
 * Up to 100 keys are requested from each collection via `chooseFrom`, and every chosen
 * key of `data1` is combined with every chosen key of `data2`. A combination is passed
 * to `generatePair` only when both maps exist, are non-empty and have the same size;
 * all other combinations are skipped.
 *
 * @param data1 Floors of the first tower.
 * @param data2 Floors of the second tower.
 * @returns A dataset with a random `datasetId` and the generated samples.
 */
function generateAssignedData(
    data1: Map<string, FloorData>,
    data2: Map<string, FloorData>
): MinamoDataset {
    const length = data1.size + data2.size;
    const totalCount = data1.size * data2.size;
    const count1 = Math.min(100, data1.size);
    const count2 = Math.min(100, data2.size);
    const keys1 = [...data1.keys()];
    const keys2 = [...data2.keys()];
    const choose1 = chooseFrom(keys1, count1);
    const choose2 = chooseFrom(keys2, count2);
    const pairCount = count1 * count2;

    const trainData: Record<string, MinamoTrainData> = {};

    console.log(
        `✅ 共发现 ${length} 个楼层,共 ${totalCount} 种组合,选取 ${pairCount} 个组合`
    );
    const progress = new SingleBar({}, Presets.shades_classic);
    progress.start(pairCount, 0);
    let examined = 0;

    for (const key1 of choose1) {
        // The first map does not depend on the inner loop, so look it up once.
        const map1 = data1.get(key1)?.map;
        for (const key2 of choose2) {
            const map2 = data2.get(key2)?.map;
            if (map1 && map2) {
                const [w1, h1] = [map1[0]?.length ?? 0, map1.length];
                const [w2, h2] = [map2[0]?.length ?? 0, map2.length];
                // Empty maps and maps of different sizes cannot be compared.
                if (h1 > 0 && w1 === w2 && h1 === h2) {
                    generatePair(trainData, key1, key2, map1, map2, [w1, h1]);
                }
            }
            // Count skipped combinations too: the bar's total is the number of
            // combinations examined, so it must advance on every iteration.
            examined++;
            progress.update(examined);
        }
    }

    progress.stop();

    const dataset: MinamoDataset = {
        datasetId: Math.floor(Math.random() * 1e12),
        data: trainData
    };

    return dataset;
}
|
||||||
|
|
||||||
|
// Entry point: builds the dataset in normal or "assigned" mode, writes it to `output`
// as JSON and prints a summary.
(async () => {
    let results: MinamoDataset;
    let towerCount: number;

    if (assigned) {
        // Assigned mode works on exactly two towers; extra entries in `list` are ignored.
        const [tower1, tower2] = list;
        if (!tower1 || !tower2) {
            console.log(`⚠️ assigned 模式下必须传入两个塔!`);
            return;
        }
        const data1 = await readOne(tower1);
        const data2 = await readOne(tower2);
        results = generateAssignedData(data1, data2);
        // Report the number of towers actually processed rather than `list.length`.
        towerCount = 2;
    } else {
        const towers = await Promise.all(
            list.map(v => parseTowerInfo(v, 'minamo-config.json'))
        );
        const floors = await getAllFloors(...towers);
        results = parseAllData(floors);
        towerCount = list.length;
    }

    // Shared tail of both modes: persist the dataset and log how many samples it holds.
    await writeFile(output, JSON.stringify(results, void 0), 'utf-8');
    const size = Object.keys(results.data).length;
    console.log(`✅ 已处理 ${towerCount} 个塔,共 ${size} 个组合`);
})();
|
||||||
|
|||||||
@ -144,3 +144,34 @@ export function mergeFloorIds(...info: TowerInfo[]) {
|
|||||||
});
|
});
|
||||||
return ids;
|
return ids;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Loads the floors of one tower from `path`.
 *
 * A path ending in `.json` is read with `fromJSON`; any other path is passed to
 * `parseTowerInfo` (with `'minamo-config.json'`) and the result to `getAllFloors`.
 *
 * @param path Location of the tower data.
 * @returns The floors produced by the selected loader.
 */
export async function readOne(path: string) {
    if (!path.endsWith('.json')) {
        const info = await parseTowerInfo(path, 'minamo-config.json');
        return getAllFloors(info);
    }
    return fromJSON(path);
}
|
||||||
|
|
||||||
|
/**
 * Reads a JSON file that maps floor ids to number grids and wraps each grid in a
 * `FloorData`.
 *
 * All floors share one config whose clip defaults are `[0, 0, 0, 0]` with no special
 * entries. Keys of the returned map have the form `<name>:<floorId>`, where `<name>`
 * is a random numeric string generated once per call.
 *
 * @param path Path of the JSON file to read.
 * @returns A map from `<name>:<floorId>` to the floor data.
 */
export async function fromJSON(path: string) {
    const file = await readFile(path, 'utf-8');
    // NOTE(review): the parsed content is trusted to have this shape; it is not validated.
    const data = JSON.parse(file) as Record<string, number[][]>;
    const clip: Record<string, [number, number, number, number]> = {};
    const config: BaseConfig = {
        clip: {
            defaults: [0, 0, 0, 0],
            special: clip
        }
    };
    // Use a wide random range (same magnitude as the dataset id) so that two files
    // loaded in one run practically never receive the same name; a range of 12 made
    // the `<name>:<floorId>` keys of different files collide frequently.
    const name = (Math.random() * 1e12).toFixed(0);
    const floorMap = new Map<string, FloorData>();
    for (const [key, value] of Object.entries(data)) {
        const floorData: FloorData = {
            map: value,
            id: key,
            config
        };
        floorMap.set(`${name}:${key}`, floorData);
    }
    return floorMap;
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user