Update 第二章 Transformer架构.md
@@ -483,9 +483,9 @@ class EncoderLayer(nn.Module):
 
     def forward(self, x):
         # Layer Norm
-        x = self.attention_norm(x)
+        norm_x = self.attention_norm(x)
         # Self-attention
-        h = x + self.attention.forward(x, x, x)
+        h = x + self.attention.forward(norm_x, norm_x, norm_x)
         # Pass through the feed-forward network
         out = h + self.feed_forward.forward(self.fnn_norm(h))
         return out
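The change keeps the un-normalized `x` on the residual path and feeds only the LayerNorm output (`norm_x`) into the self-attention, i.e. the standard pre-norm sublayer `x + Attn(LN(x))`; before the fix, the normalized tensor overwrote `x`, so the residual branch also carried the normalized value. Below is a minimal runnable sketch of the resulting pre-norm `EncoderLayer`, assuming PyTorch's `nn.MultiheadAttention` and a small `nn.Sequential` MLP as stand-ins for the chapter's own attention and feed-forward modules; `dim`, `n_heads`, and `hidden_dim` are illustrative names, not taken from the book's code.

```python
import torch
import torch.nn as nn

class EncoderLayer(nn.Module):
    # Pre-norm encoder layer: x + Attn(LN(x)), then h + FFN(LN(h)).
    def __init__(self, dim, n_heads, hidden_dim):
        super().__init__()
        self.attention_norm = nn.LayerNorm(dim)
        self.attention = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.fnn_norm = nn.LayerNorm(dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, dim)
        )

    def forward(self, x):
        # Normalize first, attend over the normalized tensor,
        # and add the *un-normalized* x back as the residual.
        norm_x = self.attention_norm(x)
        h = x + self.attention(norm_x, norm_x, norm_x, need_weights=False)[0]
        # Same pre-norm pattern for the feed-forward sublayer.
        out = h + self.feed_forward(self.fnn_norm(h))
        return out

x = torch.randn(2, 16, 64)                 # (batch, seq_len, dim)
print(EncoderLayer(64, 8, 256)(x).shape)   # torch.Size([2, 16, 64])
```

Compared with the original post-norm Transformer (LayerNorm applied after the residual add), this pre-norm arrangement tends to give more stable training in deep stacks, which is why it is the common choice in LLM-style implementations.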
@@ -533,12 +533,12 @@ class DecoderLayer(nn.Module):
 
     def forward(self, x, enc_out):
         # Layer Norm
-        x = self.attention_norm_1(x)
+        norm_x = self.attention_norm_1(x)
         # Masked self-attention
-        x = x + self.mask_attention.forward(x, x, x)
+        x = x + self.mask_attention.forward(norm_x, norm_x, norm_x)
         # Multi-head attention
-        x = self.attention_norm_2(x)
-        h = x + self.attention.forward(x, enc_out, enc_out)
+        norm_x = self.attention_norm_2(x)
+        h = x + self.attention.forward(norm_x, enc_out, enc_out)
         # Pass through the feed-forward network
         out = h + self.feed_forward.forward(self.fnn_norm(h))
         return out
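The decoder fix applies the same pre-norm pattern twice: `norm_x` feeds the masked self-attention and then the cross-attention over `enc_out`, while the residual additions keep the preceding un-normalized activations. A minimal sketch under the same assumptions as above (PyTorch's `nn.MultiheadAttention` standing in for the chapter's modules; the explicit causal-mask construction is illustrative):

```python
import torch
import torch.nn as nn

class DecoderLayer(nn.Module):
    # Pre-norm decoder layer: masked self-attention, cross-attention, feed-forward.
    def __init__(self, dim, n_heads, hidden_dim):
        super().__init__()
        self.attention_norm_1 = nn.LayerNorm(dim)
        self.mask_attention = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.attention_norm_2 = nn.LayerNorm(dim)
        self.attention = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.fnn_norm = nn.LayerNorm(dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, dim)
        )

    def forward(self, x, enc_out):
        # Causal mask: True above the diagonal blocks attention to future positions.
        seq_len = x.size(1)
        causal = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        # Masked self-attention over the normalized decoder input.
        norm_x = self.attention_norm_1(x)
        x = x + self.mask_attention(norm_x, norm_x, norm_x,
                                    attn_mask=causal, need_weights=False)[0]
        # Cross-attention: queries from the decoder, keys/values from the encoder output.
        norm_x = self.attention_norm_2(x)
        h = x + self.attention(norm_x, enc_out, enc_out, need_weights=False)[0]
        # Feed-forward sublayer with its own pre-norm.
        out = h + self.feed_forward(self.fnn_norm(h))
        return out

tgt = torch.randn(2, 10, 64)               # decoder input
memory = torch.randn(2, 16, 64)            # encoder output
print(DecoderLayer(64, 8, 256)(tgt, memory).shape)   # torch.Size([2, 10, 64])
```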