fix config bug: replace leftover `config` references with `args` in EncoderLayer, DecoderLayer, and Transformer

commit bbede84054
parent 6522014bc3
Author: Logan Zou
Date:   2025-06-03 14:10:23 +08:00
Committed by: GitHub

@@ -472,7 +472,7 @@ out = h + self.feed_forward.forward(self.fnn_norm(h))
 ```python
 class EncoderLayer(nn.Module):
     '''Encoder layer'''
-    def __init__(self, config):
+    def __init__(self, args):
         super().__init__()
         # Each layer has two LayerNorms: one before the Attention and one before the MLP
         self.attention_norm = LayerNorm(args.n_embd)
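
For context: the constructor's parameter was named `config` while its body referenced `args`, so instantiating the layer raised a `NameError` (the `MLP(config)` call fixed in `DecoderLayer` below is the reverse of the same mismatch). A minimal sketch of the fixed pattern, using a hypothetical `ModelArgs` dataclass and PyTorch's built-in `nn.LayerNorm` in place of the repo's own `LayerNorm` class:

```python
from dataclasses import dataclass

import torch.nn as nn

@dataclass
class ModelArgs:
    # Hypothetical hyperparameter container; only n_embd appears in this hunk.
    n_embd: int = 768

class EncoderLayer(nn.Module):
    def __init__(self, args):  # parameter renamed from `config` to `args`
        super().__init__()
        # The body's `args.n_embd` now resolves against the parameter name.
        self.attention_norm = nn.LayerNorm(args.n_embd)

layer = EncoderLayer(ModelArgs())  # before the fix: NameError: name 'args' is not defined
```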
@@ -529,7 +529,7 @@ class DecoderLayer(nn.Module):
         self.attention = MultiHeadAttention(args, is_causal=False)
         self.ffn_norm = LayerNorm(args.n_embd)
         # The third component is the MLP
-        self.feed_forward = MLP(config)
+        self.feed_forward = MLP(args)
     def forward(self, x, enc_out):
         # Layer Norm
@@ -830,7 +830,7 @@ class Transformer(nn.Module):
         # The input is idx with shape (batch size, sequence length, 1); targets is the target sequence, used to compute the loss
         device = idx.device
         b, t = idx.size()
-        assert t <= self.config.block_size, f"Cannot process this sequence: its length is {t}, but the maximum sequence length is only {self.config.block_size}"
+        assert t <= self.args.block_size, f"Cannot process this sequence: its length is {t}, but the maximum sequence length is only {self.args.block_size}"
         # Pass through self.transformer
         # First, pass the input idx through the Embedding layer to get shape (batch size, sequence length, n_embd)
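
The mismatch in `Transformer.forward` would surface differently: since the model stores its hyperparameters as `self.args`, looking up `self.config` raises an `AttributeError` at runtime. A minimal sketch of the corrected length check; `ModelArgs`, its `block_size` default, and the surrounding scaffolding are assumptions beyond what the hunk shows:

```python
from dataclasses import dataclass

import torch
import torch.nn as nn

@dataclass
class ModelArgs:
    block_size: int = 1024  # hypothetical default maximum sequence length

class Transformer(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args  # hyperparameters live on `self.args`; `self.config` does not exist

    def forward(self, idx, targets=None):
        b, t = idx.size()
        # Corrected attribute lookup: self.args.block_size, not self.config.block_size.
        assert t <= self.args.block_size, (
            f"Cannot process this sequence: its length is {t}, "
            f"but the maximum sequence length is only {self.args.block_size}"
        )
        # ... embedding, encoder/decoder stack, and loss computation would follow ...

model = Transformer(ModelArgs())
model(torch.zeros(2, 8, dtype=torch.long))  # passes the check: 8 <= 1024
```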