mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-24 05:01:41 +08:00
feat: 调整 decoder 的位置编码
This commit is contained in:
parent
14f391f4f4
commit
4d244d021a
@ -86,12 +86,15 @@ class GinkaPosEmbedding(nn.Module):
|
|||||||
|
|
||||||
self.row_embedding = nn.Embedding(height, embed_dim)
|
self.row_embedding = nn.Embedding(height, embed_dim)
|
||||||
self.col_embedding = nn.Embedding(width, embed_dim)
|
self.col_embedding = nn.Embedding(width, embed_dim)
|
||||||
|
self.fusion = nn.Linear(embed_dim * 2, embed_dim)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
||||||
row = self.row_embedding(y).squeeze(1)
|
row = self.row_embedding(y)
|
||||||
col = self.col_embedding(x).squeeze(1)
|
col = self.col_embedding(x)
|
||||||
|
embed = torch.cat([row, col], dim=2)
|
||||||
|
fused = self.fusion(embed)
|
||||||
|
|
||||||
return row, col
|
return fused
|
||||||
|
|
||||||
class GinkaInputFusion(nn.Module):
|
class GinkaInputFusion(nn.Module):
|
||||||
def __init__(self, d_model=256):
|
def __init__(self, d_model=256):
|
||||||
@ -109,16 +112,15 @@ class GinkaInputFusion(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, tile_embed: torch.Tensor, cond_vec: torch.Tensor,
|
self, tile_embed: torch.Tensor, cond_vec: torch.Tensor,
|
||||||
row_embed: torch.Tensor, col_embed: torch.Tensor, patch_vec: torch.Tensor
|
pos_embed: torch.Tensor, patch_vec: torch.Tensor
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
tile_embed: [B, 256]
|
tile_embed: [B, 256]
|
||||||
cond_vec: [B, 256]
|
cond_vec: [B, 256]
|
||||||
row_embed: [B, 256]
|
pos_embed: [B, 256]
|
||||||
col_embed: [B, 256]
|
|
||||||
patch_vec: [B, 256]
|
patch_vec: [B, 256]
|
||||||
"""
|
"""
|
||||||
vec = torch.stack([tile_embed, cond_vec, row_embed, col_embed, patch_vec], dim=1)
|
vec = torch.stack([tile_embed, cond_vec, pos_embed, patch_vec], dim=1)
|
||||||
feat = self.transformer(vec)
|
feat = self.transformer(vec)
|
||||||
return feat[:, 0]
|
return feat[:, 0]
|
||||||
|
|
||||||
@ -168,6 +170,13 @@ class VAEDecoder(nn.Module):
|
|||||||
self.feat_fusion = GinkaInputFusion()
|
self.feat_fusion = GinkaInputFusion()
|
||||||
self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)
|
self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)
|
||||||
|
|
||||||
|
self.col_list = []
|
||||||
|
self.row_list = []
|
||||||
|
for y in range(0, height):
|
||||||
|
for x in range(0, width):
|
||||||
|
self.col_list.append(x)
|
||||||
|
self.row_list.append(y)
|
||||||
|
|
||||||
def forward(self, map_vec: torch.Tensor, target_map: torch.Tensor, use_self_probility=0):
|
def forward(self, map_vec: torch.Tensor, target_map: torch.Tensor, use_self_probility=0):
|
||||||
"""
|
"""
|
||||||
map_vec: [B, vec_dim]
|
map_vec: [B, vec_dim]
|
||||||
@ -183,19 +192,21 @@ class VAEDecoder(nn.Module):
|
|||||||
output_logits = torch.zeros([B, self.height, self.width, self.tile_classes]).to(self.device)
|
output_logits = torch.zeros([B, self.height, self.width, self.tile_classes]).to(self.device)
|
||||||
hidden: torch.Tensor = torch.zeros(B, self.rnn_hidden).to(self.device)
|
hidden: torch.Tensor = torch.zeros(B, self.rnn_hidden).to(self.device)
|
||||||
|
|
||||||
|
col_list = torch.IntTensor(self.col_list).to(self.device).expand(B, -1)
|
||||||
|
row_list = torch.IntTensor(self.row_list).to(self.device).expand(B, -1)
|
||||||
|
pos_embed = self.pos_embedding(col_list, row_list)
|
||||||
|
|
||||||
map_vec = self.map_vec_fc(map_vec)
|
map_vec = self.map_vec_fc(map_vec)
|
||||||
|
|
||||||
for y in range(0, self.height):
|
for y in range(0, self.height):
|
||||||
for x in range(0, self.width):
|
for x in range(0, self.width):
|
||||||
x_tensor = torch.LongTensor([x]).to(self.device).expand(B, -1)
|
idx = y * self.width + x
|
||||||
y_tensor = torch.LongTensor([y]).to(self.device).expand(B, -1)
|
# 图块编码、地图局部编码
|
||||||
# 位置编码、图块编码、地图局部编码
|
|
||||||
tile_embed = self.tile_embedding(now_tile)
|
tile_embed = self.tile_embedding(now_tile)
|
||||||
row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor)
|
|
||||||
use_self = random.random() < use_self_probility
|
use_self = random.random() < use_self_probility
|
||||||
map_patch = self.map_patch(map if use_self else target_map, x, y)
|
map_patch = self.map_patch(map if use_self else target_map, x, y)
|
||||||
# 编码特征融合
|
# 编码特征融合
|
||||||
feat = self.feat_fusion(tile_embed, map_vec, row_embed, col_embed, map_patch)
|
feat = self.feat_fusion(tile_embed, map_vec, pos_embed[:, idx], map_patch)
|
||||||
# RNN 输出
|
# RNN 输出
|
||||||
logits, h = self.rnn(feat, hidden)
|
logits, h = self.rnn(feat, hidden)
|
||||||
# 处理输出
|
# 处理输出
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user