Update 第二章 Transformer架构.md
@@ -483,9 +483,9 @@ class EncoderLayer(nn.Module):
 
     def forward(self, x):
         # Layer Norm
-        x = self.attention_norm(x)
+        norm_x = self.attention_norm(x)
         # self-attention
-        h = x + self.attention.forward(x, x, x)
+        h = x + self.attention.forward(norm_x, norm_x, norm_x)
         # pass through the feed-forward network
         out = h + self.feed_forward.forward(self.fnn_norm(h))
         return out
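The encoder hunk fixes the pre-norm residual block: the LayerNorm output is now kept in a separate `norm_x`, so attention operates on the normalized tensor while the residual connection still adds the original `x`. Below is a minimal runnable sketch of the fixed pattern, using `nn.MultiheadAttention` and a small `nn.Sequential` MLP as stand-ins for the chapter's own attention and feed-forward modules; the `__init__` and those stand-ins are assumptions, only the `forward` logic comes from the diff.

```python
import torch
import torch.nn as nn

class EncoderLayer(nn.Module):
    """Pre-norm encoder layer sketch (stand-in modules, not the chapter's own)."""

    def __init__(self, dim: int = 64, n_heads: int = 4):
        super().__init__()
        self.attention_norm = nn.LayerNorm(dim)
        self.attention = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.fnn_norm = nn.LayerNorm(dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.ReLU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x):
        # Pre-norm: attention reads the normalized tensor,
        # while the residual branch keeps the un-normalized x.
        norm_x = self.attention_norm(x)
        h = x + self.attention(norm_x, norm_x, norm_x)[0]
        # Same pattern for the feed-forward sublayer.
        out = h + self.feed_forward(self.fnn_norm(h))
        return out

if __name__ == "__main__":
    layer = EncoderLayer()
    print(layer(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])
```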
@@ -533,12 +533,12 @@ class DecoderLayer(nn.Module):
 
     def forward(self, x, enc_out):
         # Layer Norm
-        x = self.attention_norm_1(x)
+        norm_x = self.attention_norm_1(x)
         # masked self-attention
-        x = x + self.mask_attention.forward(x, x, x)
+        x = x + self.mask_attention.forward(norm_x, norm_x, norm_x)
         # multi-head attention
-        x = self.attention_norm_2(x)
-        h = x + self.attention.forward(x, enc_out, enc_out)
+        norm_x = self.attention_norm_2(x)
+        h = x + self.attention.forward(norm_x, enc_out, enc_out)
         # pass through the feed-forward network
         out = h + self.feed_forward.forward(self.fnn_norm(h))
         return out
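The decoder hunk applies the same fix to both attention sublayers: masked self-attention and cross-attention now read the normalized `norm_x`, while each residual branch keeps the un-normalized activations. A minimal runnable sketch under the same assumptions follows (stand-in `nn.MultiheadAttention` modules, a small MLP, and an explicitly built causal mask, none of which appear in the diff itself).

```python
import torch
import torch.nn as nn

class DecoderLayer(nn.Module):
    """Pre-norm decoder layer sketch mirroring the fixed forward above."""

    def __init__(self, dim: int = 64, n_heads: int = 4):
        super().__init__()
        self.attention_norm_1 = nn.LayerNorm(dim)
        self.mask_attention = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.attention_norm_2 = nn.LayerNorm(dim)
        self.attention = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.fnn_norm = nn.LayerNorm(dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.ReLU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x, enc_out):
        # Masked self-attention on the normalized tensor, residual from x.
        norm_x = self.attention_norm_1(x)
        T = x.size(1)
        causal = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)
        x = x + self.mask_attention(norm_x, norm_x, norm_x, attn_mask=causal)[0]
        # Cross-attention: queries from the normalized decoder state,
        # keys/values from the encoder output.
        norm_x = self.attention_norm_2(x)
        h = x + self.attention(norm_x, enc_out, enc_out)[0]
        # Feed-forward sublayer with its own pre-norm.
        out = h + self.feed_forward(self.fnn_norm(h))
        return out

if __name__ == "__main__":
    dec = DecoderLayer()
    tgt, mem = torch.randn(2, 7, 64), torch.randn(2, 10, 64)
    print(dec(tgt, mem).shape)  # torch.Size([2, 7, 64])
```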