diff --git a/docs/chapter2/第二章 Transformer架构.md b/docs/chapter2/第二章 Transformer架构.md index 91c4a59..d3854fe 100644 --- a/docs/chapter2/第二章 Transformer架构.md +++ b/docs/chapter2/第二章 Transformer架构.md @@ -252,7 +252,7 @@ class MultiHeadAttention(nn.Module): # args: 配置对象 super().__init__() # 隐藏层维度必须是头数的整数倍,因为后面我们会将输入拆成头数个矩阵 - assert args.n_embd % args.n_head == 0 + assert args.n_embd % args.n_heads == 0 # 模型并行处理大小,默认为1。 model_parallel_size = 1 # 本地计算头数,等于总头数除以模型并行处理大小。