From 3950b06a5f38a61ffde75e80f640041453bf4d14 Mon Sep 17 00:00:00 2001
From: Logan Zou <74288839+logan-zou@users.noreply.github.com>
Date: Mon, 23 Jun 2025 10:53:25 +0800
Subject: [PATCH] Update transformer.py

fix arg bug
---
 docs/chapter2/code/transformer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/chapter2/code/transformer.py b/docs/chapter2/code/transformer.py
index 80ec21a..958de50 100644
--- a/docs/chapter2/code/transformer.py
+++ b/docs/chapter2/code/transformer.py
@@ -25,7 +25,7 @@ class MultiHeadAttention(nn.Module):
         # args: the configuration object
         super().__init__()
         # The hidden dimension must be an integer multiple of the number of heads, since we will later split the input into n_heads matrices
-        assert args.n_embd % args.n_heads == 0
+        assert args.dim % args.n_heads == 0
         # Model parallel size, defaults to 1.
         model_parallel_size = 1
         # Number of local heads, equal to the total number of heads divided by the model parallel size.
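
For context, here is a minimal sketch of what this fix assumes: the config object exposes a `dim` field (not `n_embd`), so the old assert would raise an AttributeError. The `ModelArgs` name, its defaults, and the `head_dim` line below are illustrative assumptions, not taken verbatim from the repo.

# Sketch only: "ModelArgs" and its default values are assumptions for
# illustration; the repo's actual config class may differ.
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class ModelArgs:
    dim: int = 768      # hidden dimension -- the field the fixed assert reads
    n_heads: int = 12   # number of attention heads


class MultiHeadAttention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        # dim must divide evenly by n_heads so the input can be split
        # into n_heads equal-sized per-head matrices.
        assert args.dim % args.n_heads == 0
        self.head_dim = args.dim // args.n_heads

With args.dim = 768 and args.n_heads = 12, head_dim = 64 and the assert passes; before the patch, accessing the nonexistent args.n_embd raised an AttributeError, which is the arg bug this commit fixes.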