From bd3fb6cf48fec1f0eeaff0e62c2260eebc956813 Mon Sep 17 00:00:00 2001
From: Logan Zou <74288839+logan-zou@users.noreply.github.com>
Date: Mon, 23 Jun 2025 10:48:56 +0800
Subject: [PATCH] Update transformer.py

Fix a dimension bug in the comment on the output projection: the weight
matrix of wo is dim x n_embd, not n_embd x n_embd.
---
 docs/chapter2/code/transformer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/chapter2/code/transformer.py b/docs/chapter2/code/transformer.py
index 15879cb..80ec21a 100644
--- a/docs/chapter2/code/transformer.py
+++ b/docs/chapter2/code/transformer.py
@@ -39,7 +39,7 @@ class MultiHeadAttention(nn.Module):
         self.wq = nn.Linear(args.n_embd, args.n_heads * self.head_dim, bias=False)
         self.wk = nn.Linear(args.n_embd, args.n_heads * self.head_dim, bias=False)
         self.wv = nn.Linear(args.n_embd, args.n_heads * self.head_dim, bias=False)
-        # Output weight matrix, of dimension n_embd x n_embd (head_dim = n_embd / n_heads)
+        # Output weight matrix, of dimension dim x n_embd (head_dim = n_embd / n_heads)
         self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
         # Dropout applied to the attention weights
         self.attn_dropout = nn.Dropout(args.dropout)
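
For context, a minimal runnable sketch of the shape reasoning behind the
corrected comment. The ModelArgs fields (n_embd, n_heads, dim, dropout) are
assumptions inferred from the hunk; this is an illustration, not the
repository's actual transformer.py.

from dataclasses import dataclass
import torch
import torch.nn as nn

@dataclass
class ModelArgs:          # assumed config; fields inferred from the diff context
    n_embd: int = 512     # model embedding width
    n_heads: int = 8      # number of attention heads
    dim: int = 512        # output width of the attention block
    dropout: float = 0.1

args = ModelArgs()
head_dim = args.n_embd // args.n_heads  # head_dim = n_embd / n_heads

# Same constructor call as in the diff: in_features = n_heads * head_dim
# (= n_embd), out_features = dim. nn.Linear stores its weight transposed,
# as (out_features, in_features), hence the shape (dim, n_embd).
wo = nn.Linear(args.n_heads * head_dim, args.dim, bias=False)
print(wo.weight.shape)   # torch.Size([512, 512]) == (dim, n_embd)

# Applying wo to the concatenated head outputs maps n_embd -> dim.
x = torch.randn(2, 16, args.n_heads * head_dim)  # (batch, seq, n_embd)
print(wo(x).shape)                               # (2, 16, dim)

Because nn.Linear keeps its weight as (out_features, in_features), the old
comment's "n_embd x n_embd" only described the square case where dim equals
n_embd; "dim x n_embd" is the shape in general.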