fix config bug
@@ -472,7 +472,7 @@ out = h + self.feed_forward.forward(self.fnn_norm(h))
 ```python
 class EncoderLayer(nn.Module):
     '''Encoder layer'''
-    def __init__(self, config):
+    def __init__(self, args):
         super().__init__()
         # Each Layer has two LayerNorms, one before the Attention and one before the MLP
         self.attention_norm = LayerNorm(args.n_embd)
@@ -529,7 +529,7 @@ class DecoderLayer(nn.Module):
         self.attention = MultiHeadAttention(args, is_causal=False)
         self.ffn_norm = LayerNorm(args.n_embd)
         # The third component is the MLP
-        self.feed_forward = MLP(config)
+        self.feed_forward = MLP(args)
 
     def forward(self, x, enc_out):
         # Layer Norm
@@ -830,7 +830,7 @@ class Transformer(nn.Module):
         # The input idx has shape (batch size, sequence length, 1); targets is the target sequence used to compute the loss
         device = idx.device
         b, t = idx.size()
-        assert t <= self.config.block_size, f"Cannot process this sequence: its length is {t}, but the maximum sequence length is only {self.config.block_size}"
+        assert t <= self.args.block_size, f"Cannot process this sequence: its length is {t}, but the maximum sequence length is only {self.args.block_size}"
 
         # Pass through self.transformer
         # First pass the input idx through the Embedding layer to get shape (batch size, sequence length, n_embd)
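The change is a naming fix: the surrounding class bodies read hyperparameters from a parameter named `args` (`args.n_embd`, `self.args.block_size`), so the leftover `config` references would fail at runtime (for example, a `NameError` for the undefined name `config` when constructing a layer). Below is a minimal sketch of the consistent pattern the fix restores. The `ModelArgs` dataclass, its field values, and the `nn.Sequential` stand-in for the project's `MLP(args)` are illustrative assumptions, not code from this repository.

```python
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class ModelArgs:
    # Hypothetical fields; the real project defines its own hyperparameter object.
    n_embd: int = 64
    block_size: int = 128


class EncoderLayer(nn.Module):
    def __init__(self, args):
        super().__init__()
        # Every submodule reads from the same `args` object, so there is no
        # stray `config` name left to raise a NameError during construction.
        self.args = args
        self.attention_norm = nn.LayerNorm(args.n_embd)
        self.feed_forward = nn.Sequential(  # stand-in for the project's MLP(args)
            nn.Linear(args.n_embd, 4 * args.n_embd),
            nn.GELU(),
            nn.Linear(4 * args.n_embd, args.n_embd),
        )

    def forward(self, x):
        # Simplified residual block (attention omitted) just to exercise the args wiring.
        return x + self.feed_forward(self.attention_norm(x))


args = ModelArgs()
layer = EncoderLayer(args)
x = torch.randn(2, 16, args.n_embd)
print(layer(x).shape)  # torch.Size([2, 16, 64])
```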