mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-24 05:01:41 +08:00
fix: 指定训练集和验证集
This commit is contained in:
parent
f9211965db
commit
c9c52109ed
@ -34,8 +34,8 @@ def train():
|
|||||||
# param.requires_grad = False
|
# param.requires_grad = False
|
||||||
|
|
||||||
# 准备数据集
|
# 准备数据集
|
||||||
dataset = GinkaDataset("ginka-dataset.json", device, minamo)
|
dataset = GinkaDataset(args.train, device, minamo)
|
||||||
dataset_val = GinkaDataset("ginka-eval.json", device, minamo)
|
dataset_val = GinkaDataset(args.validate, device, minamo)
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
|
|||||||
@ -41,8 +41,8 @@ def train():
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
# 准备数据集
|
# 准备数据集
|
||||||
dataset = MinamoDataset("minamo-dataset.json")
|
dataset = MinamoDataset(args.train)
|
||||||
val_dataset = MinamoDataset("minamo-eval.json")
|
val_dataset = MinamoDataset(args.validate)
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=64,
|
batch_size=64,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user