mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-23 20:41:12 +08:00
feat: 提高参数量
This commit is contained in:
parent
724b6612d3
commit
f169167409
@ -1,4 +1,6 @@
|
|||||||
for i in {$1...$2}
|
start=$1
|
||||||
|
end=$2
|
||||||
|
for ((i=start; i<=end; i=i+1))
|
||||||
do
|
do
|
||||||
sh gan.sh "$i"
|
sh gan.sh "$i"
|
||||||
echo "第 $i 次循环完成"
|
echo "第 $i 次循环完成"
|
||||||
|
|||||||
@ -95,8 +95,8 @@ function generateTransformData(
|
|||||||
types.push([rot, flip]);
|
types.push([rot, flip]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 随机抽取最多两个
|
// 随机抽取最多一个
|
||||||
const trans = chooseFrom(types, Math.floor(Math.random() * 2));
|
const trans = chooseFrom(types, Math.floor(Math.random() * 1));
|
||||||
return trans
|
return trans
|
||||||
.map(([rot, flip]) => {
|
.map(([rot, flip]) => {
|
||||||
const com1 = `${id1}.${rot}.${flip}:${id1}`;
|
const com1 = `${id1}.${rot}.${flip}:${id1}`;
|
||||||
@ -167,10 +167,10 @@ function generateTransformData(
|
|||||||
}
|
}
|
||||||
|
|
||||||
function generateSimilarData(id: string, map: number[][]) {
|
function generateSimilarData(id: string, map: number[][]) {
|
||||||
// 生成最多五个微调地图
|
// 生成最多两个微调地图
|
||||||
const width = map[0].length;
|
const width = map[0].length;
|
||||||
const height = map.length;
|
const height = map.length;
|
||||||
const num = Math.floor(Math.random() * 3);
|
const num = Math.floor(Math.random() * 2);
|
||||||
const res: [id: string, data: MinamoTrainData][] = [];
|
const res: [id: string, data: MinamoTrainData][] = [];
|
||||||
|
|
||||||
for (let i = 0; i < num; i++) {
|
for (let i = 0; i < num; i++) {
|
||||||
@ -241,7 +241,7 @@ function generatePair(
|
|||||||
// 自身与自身对比的训练集,保证模型对相同地图输出 1
|
// 自身与自身对比的训练集,保证模型对相同地图输出 1
|
||||||
const self1 = `${id1}:${id1}`;
|
const self1 = `${id1}:${id1}`;
|
||||||
const self2 = `${id2}:${id2}`;
|
const self2 = `${id2}:${id2}`;
|
||||||
const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 3));
|
const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1));
|
||||||
if (selfTrain.includes(self1) && !data[`${id1}:${id1}`]) {
|
if (selfTrain.includes(self1) && !data[`${id1}:${id1}`]) {
|
||||||
const selfTrain1: MinamoTrainData = {
|
const selfTrain1: MinamoTrainData = {
|
||||||
map1: map1,
|
map1: map1,
|
||||||
|
|||||||
6
gan.sh
6
gan.sh
@ -8,10 +8,10 @@ python3 -m ginka.validate
|
|||||||
mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json"
|
mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json"
|
||||||
mv "minamo-eval.json" "datasets/minamo-eval-$1.json"
|
mv "minamo-eval.json" "datasets/minamo-eval-$1.json"
|
||||||
cd data
|
cd data
|
||||||
pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned:100:40
|
pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned:100:30
|
||||||
pnpm minamo "../minamo-eval.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:10
|
pnpm minamo "../minamo-eval.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:10
|
||||||
pnpm review "../minamo-dataset.json" "../datasets/minamo-dataset-merged.json"
|
|
||||||
pnpm review "../minamo-eval.json" "../datasets/minamo-eval-merged.json"
|
|
||||||
pnpm merge "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-$1.json"
|
pnpm merge "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-$1.json"
|
||||||
pnpm merge "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-$1.json"
|
pnpm merge "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-$1.json"
|
||||||
|
pnpm review "../minamo-dataset.json" "../datasets/minamo-dataset-merged.json"
|
||||||
|
pnpm review "../minamo-eval.json" "../datasets/minamo-eval-merged.json"
|
||||||
cd ..
|
cd ..
|
||||||
|
|||||||
@ -3,7 +3,8 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from minamo.model.model import MinamoModel
|
from minamo.model.model import MinamoModel
|
||||||
from shared.graph import convert_soft_map_to_graph
|
from shared.graph import differentiable_convert_to_data
|
||||||
|
from shared.utils import random_smooth_onehot
|
||||||
|
|
||||||
def load_data(path: str):
|
def load_data(path: str):
|
||||||
with open(path, 'r', encoding="utf-8") as f:
|
with open(path, 'r', encoding="utf-8") as f:
|
||||||
@ -28,8 +29,9 @@ class GinkaDataset(Dataset):
|
|||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
item = self.data[idx]
|
item = self.data[idx]
|
||||||
|
|
||||||
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float().to(self.device) # [32, H, W]
|
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||||
graph = convert_soft_map_to_graph(target).to(self.device)
|
target = random_smooth_onehot(target).to(self.device)
|
||||||
|
graph = differentiable_convert_to_data(target).to(self.device)
|
||||||
vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)
|
vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
@ -10,8 +10,22 @@ class GinkaModel(nn.Module):
|
|||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.base_ch = base_ch
|
self.base_ch = base_ch
|
||||||
|
fc_dim = base_ch * 8 * 4 * 4
|
||||||
self.fc = nn.Sequential(
|
self.fc = nn.Sequential(
|
||||||
nn.Linear(feat_dim, 32 * 32 * base_ch)
|
nn.Linear(feat_dim, fc_dim),
|
||||||
|
nn.BatchNorm1d(fc_dim),
|
||||||
|
nn.ReLU()
|
||||||
|
)
|
||||||
|
self.deconv_layers = nn.Sequential(
|
||||||
|
nn.ConvTranspose2d(base_ch*8, base_ch*4, kernel_size=4, stride=2, padding=1), # Upsample 2x
|
||||||
|
nn.BatchNorm2d(base_ch*4),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.ConvTranspose2d(base_ch*4, base_ch*2, kernel_size=4, stride=2, padding=1), # Upsample 2x
|
||||||
|
nn.BatchNorm2d(base_ch*2),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.ConvTranspose2d(base_ch*2, base_ch, kernel_size=4, stride=2, padding=1), # Upsample 2x
|
||||||
|
nn.BatchNorm2d(base_ch),
|
||||||
|
nn.ReLU(),
|
||||||
)
|
)
|
||||||
self.unet = GinkaUNet(base_ch, num_classes)
|
self.unet = GinkaUNet(base_ch, num_classes)
|
||||||
self.down_sample = MapDownSample(num_classes, num_classes)
|
self.down_sample = MapDownSample(num_classes, num_classes)
|
||||||
@ -25,7 +39,8 @@ class GinkaModel(nn.Module):
|
|||||||
logits: 输出logits [BS, num_classes, H, W]
|
logits: 输出logits [BS, num_classes, H, W]
|
||||||
"""
|
"""
|
||||||
x = self.fc(feat)
|
x = self.fc(feat)
|
||||||
x = x.view(-1, self.base_ch, 32, 32)
|
x = x.view(-1, self.base_ch*8, 4, 4)
|
||||||
|
x = self.deconv_layers(x)
|
||||||
x = self.unet(x)
|
x = self.unet(x)
|
||||||
x = F.interpolate(x, (13, 13), mode='bilinear')
|
x = F.interpolate(x, (13, 13), mode='bilinear')
|
||||||
return x, F.softmax(x, dim=1)
|
return x, F.softmax(x, dim=1)
|
||||||
|
|||||||
@ -48,7 +48,7 @@ def train():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 设定优化器与调度器
|
# 设定优化器与调度器
|
||||||
optimizer = optim.AdamW(model.parameters(), lr=5e-3)
|
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
|
||||||
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
||||||
criterion = GinkaLoss(minamo)
|
criterion = GinkaLoss(minamo)
|
||||||
|
|
||||||
@ -72,7 +72,7 @@ def train():
|
|||||||
target = batch["target"].to(device)
|
target = batch["target"].to(device)
|
||||||
target_vision_feat = batch["target_vision_feat"].to(device)
|
target_vision_feat = batch["target_vision_feat"].to(device)
|
||||||
target_topo_feat = batch["target_topo_feat"].to(device)
|
target_topo_feat = batch["target_topo_feat"].to(device)
|
||||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
|
||||||
# 前向传播
|
# 前向传播
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
_, output_softmax = model(feat_vec)
|
_, output_softmax = model(feat_vec)
|
||||||
@ -84,6 +84,10 @@ def train():
|
|||||||
scaled_losses.backward()
|
scaled_losses.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
total_loss += losses.item()
|
total_loss += losses.item()
|
||||||
|
# for name, param in model.named_parameters():
|
||||||
|
# if param.grad is not None:
|
||||||
|
# print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}")
|
||||||
|
|
||||||
|
|
||||||
avg_loss = total_loss / len(dataloader)
|
avg_loss = total_loss / len(dataloader)
|
||||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
|
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
|
||||||
@ -112,7 +116,7 @@ def train():
|
|||||||
target = batch["target"].to(device)
|
target = batch["target"].to(device)
|
||||||
target_vision_feat = batch["target_vision_feat"].to(device)
|
target_vision_feat = batch["target_vision_feat"].to(device)
|
||||||
target_topo_feat = batch["target_topo_feat"].to(device)
|
target_topo_feat = batch["target_topo_feat"].to(device)
|
||||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
|
||||||
|
|
||||||
# 前向传播
|
# 前向传播
|
||||||
output, output_softmax = model(feat_vec)
|
output, output_softmax = model(feat_vec)
|
||||||
|
|||||||
@ -106,7 +106,7 @@ def validate():
|
|||||||
target = batch["target"].to(device)
|
target = batch["target"].to(device)
|
||||||
target_vision_feat = batch["target_vision_feat"].to(device)
|
target_vision_feat = batch["target_vision_feat"].to(device)
|
||||||
target_topo_feat = batch["target_topo_feat"].to(device)
|
target_topo_feat = batch["target_topo_feat"].to(device)
|
||||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
|
||||||
# 前向传播
|
# 前向传播
|
||||||
output, output_softmax = model(feat_vec)
|
output, output_softmax = model(feat_vec)
|
||||||
map_matrix = torch.argmax(output, dim=1)
|
map_matrix = torch.argmax(output, dim=1)
|
||||||
|
|||||||
@ -3,22 +3,7 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from shared.graph import differentiable_convert_to_data
|
from shared.graph import differentiable_convert_to_data
|
||||||
|
from shared.utils import random_smooth_onehot
|
||||||
def random_smooth_onehot(onehot_map, min_main=0.75, max_main=1.0, epsilon=0.25):
|
|
||||||
"""
|
|
||||||
生成随机平滑的 one-hot 编码,使主类别概率不再固定,而是随机波动
|
|
||||||
"""
|
|
||||||
C, H, W = onehot_map.shape
|
|
||||||
# 生成主类别的随机概率 (min_main, max_main)
|
|
||||||
main_prob = torch.rand(H, W) * (max_main - min_main) + min_main
|
|
||||||
|
|
||||||
# 计算剩余概率并随机分配到其他类别
|
|
||||||
noise = torch.rand(C, H, W) * epsilon # 随机噪声
|
|
||||||
noise = noise / noise.sum(dim=1, keepdim=True) # 归一化到总和为 epsilon
|
|
||||||
|
|
||||||
# 计算最终平滑 one-hot 结果
|
|
||||||
smooth_onehot = onehot_map * main_prob + (1 - onehot_map) * noise
|
|
||||||
return smooth_onehot
|
|
||||||
|
|
||||||
def load_data(path: str):
|
def load_data(path: str):
|
||||||
with open(path, 'r', encoding="utf-8") as f:
|
with open(path, 'r', encoding="utf-8") as f:
|
||||||
|
|||||||
@ -110,7 +110,7 @@ def train():
|
|||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
|
|
||||||
ave_loss = total_loss / len(dataloader)
|
ave_loss = total_loss / len(dataloader)
|
||||||
print(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
|
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
|
||||||
|
|
||||||
# total_norm = 0
|
# total_norm = 0
|
||||||
# for p in model.parameters():
|
# for p in model.parameters():
|
||||||
@ -128,7 +128,7 @@ def train():
|
|||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
# 每十轮推理一次验证集
|
# 每十轮推理一次验证集
|
||||||
if (epoch + 1) % 1 == 0:
|
if (epoch + 1) % 5 == 0:
|
||||||
model.eval()
|
model.eval()
|
||||||
val_loss = 0
|
val_loss = 0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@ -152,7 +152,7 @@ def train():
|
|||||||
val_loss += loss_val.item()
|
val_loss += loss_val.item()
|
||||||
|
|
||||||
avg_val_loss = val_loss / len(val_loader)
|
avg_val_loss = val_loss / len(val_loader)
|
||||||
print(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
|
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
|
||||||
torch.save({
|
torch.save({
|
||||||
"model_state": model.state_dict(),
|
"model_state": model.state_dict(),
|
||||||
"optimizer_state": optimizer.state_dict(),
|
"optimizer_state": optimizer.state_dict(),
|
||||||
|
|||||||
@ -20,7 +20,7 @@ def validate():
|
|||||||
print(f"Total parameters: {total_params}")
|
print(f"Total parameters: {total_params}")
|
||||||
|
|
||||||
# 准备数据集
|
# 准备数据集
|
||||||
val_dataset = MinamoDataset("minamo-eval.json")
|
val_dataset = MinamoDataset("datasets/minamo-eval-1.json")
|
||||||
val_loader = DataLoader(
|
val_loader = DataLoader(
|
||||||
val_dataset,
|
val_dataset,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
@ -44,6 +44,8 @@ def validate():
|
|||||||
vision_feat1, topo_feat1 = model(map1_val, graph1)
|
vision_feat1, topo_feat1 = model(map1_val, graph1)
|
||||||
vision_feat2, topo_feat2 = model(map2_val, graph2)
|
vision_feat2, topo_feat2 = model(map2_val, graph2)
|
||||||
|
|
||||||
|
print(vision_feat1.isnan().any().item(), topo_feat1.isnan().any().item(), vision_feat2.isnan().any().item(), topo_feat2.isnan().any().item())
|
||||||
|
|
||||||
vision_pred_val = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
vision_pred_val = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
||||||
topo_pred_val = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
topo_pred_val = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
||||||
loss_val = criterion(
|
loss_val = criterion(
|
||||||
|
|||||||
17
shared/utils.py
Normal file
17
shared/utils.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
def random_smooth_onehot(onehot_map, min_main=0.75, max_main=1.0, epsilon=0.25):
|
||||||
|
"""
|
||||||
|
生成随机平滑的 one-hot 编码,使主类别概率不再固定,而是随机波动
|
||||||
|
"""
|
||||||
|
C, H, W = onehot_map.shape
|
||||||
|
# 生成主类别的随机概率 (min_main, max_main)
|
||||||
|
main_prob = torch.rand(H, W) * (max_main - min_main) + min_main
|
||||||
|
|
||||||
|
# 计算剩余概率并随机分配到其他类别
|
||||||
|
noise = torch.rand(C, H, W) * epsilon # 随机噪声
|
||||||
|
noise = noise / noise.sum(dim=1, keepdim=True) # 归一化到总和为 epsilon
|
||||||
|
|
||||||
|
# 计算最终平滑 one-hot 结果
|
||||||
|
smooth_onehot = onehot_map * main_prob + (1 - onehot_map) * noise
|
||||||
|
return smooth_onehot
|
||||||
Loading…
Reference in New Issue
Block a user