mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-24 13:11:10 +08:00
perf: 调优部分超参数
This commit is contained in:
parent
29cfb4d029
commit
8130296e1f
@ -311,7 +311,7 @@ def js_divergence(P, Q, epsilon=1e-10):
|
|||||||
return js.mean() # 标量
|
return js.mean() # 标量
|
||||||
|
|
||||||
class WGANGinkaLoss:
|
class WGANGinkaLoss:
|
||||||
def __init__(self, lambda_gp=10, weight=[0.7, 0.2, 0.1], diversity_lamda=0):
|
def __init__(self, lambda_gp=20, weight=[0.7, 0.2, 0.1], diversity_lamda=0):
|
||||||
self.lambda_gp = lambda_gp # 梯度惩罚系数
|
self.lambda_gp = lambda_gp # 梯度惩罚系数
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
self.diversity_lamda = diversity_lamda
|
self.diversity_lamda = diversity_lamda
|
||||||
|
|||||||
@ -38,7 +38,7 @@ def train():
|
|||||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
||||||
|
|
||||||
c_steps = 1
|
c_steps = 1
|
||||||
g_steps = 3
|
g_steps = 4
|
||||||
|
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
|
|
||||||
@ -103,9 +103,9 @@ def train():
|
|||||||
for _ in range(g_steps):
|
for _ in range(g_steps):
|
||||||
z1 = torch.randn(batch_size, 1024, device=device)
|
z1 = torch.randn(batch_size, 1024, device=device)
|
||||||
z2 = torch.randn(batch_size, 1024, device=device)
|
z2 = torch.randn(batch_size, 1024, device=device)
|
||||||
fake_softmax1, fakse_softmax2 = ginka(z1), ginka(z2)
|
fake_softmax1, fake_softmax2 = ginka(z1), ginka(z2)
|
||||||
|
|
||||||
loss_g = criterion.generator_loss(minamo, fake_softmax1, fakse_softmax2)
|
loss_g = criterion.generator_loss(minamo, fake_softmax1, fake_softmax2)
|
||||||
loss_g.backward()
|
loss_g.backward()
|
||||||
optimizer_ginka.step()
|
optimizer_ginka.step()
|
||||||
|
|
||||||
@ -120,14 +120,19 @@ def train():
|
|||||||
)
|
)
|
||||||
|
|
||||||
if avg_dis < -9:
|
if avg_dis < -9:
|
||||||
g_steps = 7
|
g_steps = 21
|
||||||
elif avg_dis < -6:
|
elif avg_dis < -6:
|
||||||
g_steps = 5
|
g_steps = 14
|
||||||
elif avg_dis < -3:
|
elif avg_dis < -3:
|
||||||
g_steps = 3
|
g_steps = 7
|
||||||
else:
|
else:
|
||||||
g_steps = 1
|
g_steps = 1
|
||||||
|
|
||||||
|
if avg_dis > 3:
|
||||||
|
c_steps = 3
|
||||||
|
else:
|
||||||
|
c_steps = 1
|
||||||
|
|
||||||
# 每五轮输出一次图片,并保存检查点
|
# 每五轮输出一次图片,并保存检查点
|
||||||
if (epoch + 1) % 5 == 0:
|
if (epoch + 1) % 5 == 0:
|
||||||
# 输出 20 张图片,每批次 4 张,一共五批
|
# 输出 20 张图片,每批次 4 张,一共五批
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user