Update transformer.py

fix dim bug
Logan Zou
2025-06-23 10:48:56 +08:00
committed by GitHub
parent 3b24a9fd1e
commit bd3fb6cf48


@@ -39,7 +39,7 @@ class MultiHeadAttention(nn.Module):
         self.wq = nn.Linear(args.n_embd, args.n_heads * self.head_dim, bias=False)
         self.wk = nn.Linear(args.n_embd, args.n_heads * self.head_dim, bias=False)
         self.wv = nn.Linear(args.n_embd, args.n_heads * self.head_dim, bias=False)
-        # Output weight matrix, with dimensions n_embd x n_embd (head_dim = n_embeds / n_heads)
+        # Output weight matrix, with dimensions dim x n_embd (head_dim = n_embeds / n_heads)
         self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
         # Dropout applied to the attention weights
         self.attn_dropout = nn.Dropout(args.dropout)
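
Since the commit only touches the comment, a quick shape check may help explain why the corrected wording is "dim x n_embd": PyTorch stores an `nn.Linear` weight as `(out_features, in_features)`, and here `in_features = n_heads * head_dim = n_embd` while `out_features = dim`. The snippet below is a minimal sketch, not part of the commit; the config values (768 embedding size, 12 heads) are assumptions for illustration only.

```python
# Minimal sketch of the wo projection's shapes (assumed config values, not from the diff).
import torch
import torch.nn as nn

n_embd, n_heads, dim = 768, 12, 768   # hypothetical sizes; dim == n_embd in this sketch
head_dim = n_embd // n_heads          # head_dim = n_embd / n_heads = 64

wo = nn.Linear(n_heads * head_dim, dim, bias=False)
print(wo.weight.shape)                # torch.Size([768, 768]) -> (dim, n_embd), as the fixed comment says

# The projection maps concatenated head outputs back to the model dimension.
x = torch.randn(2, 16, n_heads * head_dim)   # (batch, seq_len, n_heads * head_dim)
print(wo(x).shape)                           # torch.Size([2, 16, 768]) -> (batch, seq_len, dim)
```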