diff --git a/docs/chapter2/第二章 Transformer架构.md b/docs/chapter2/第二章 Transformer架构.md
index d3915b0..1026216 100644
--- a/docs/chapter2/第二章 Transformer架构.md
+++ b/docs/chapter2/第二章 Transformer架构.md
@@ -266,7 +266,7 @@ class MultiHeadAttention(nn.Module):
         self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
         self.wk = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
-        # output weight matrix, with dimension n_embd x n_embd (head_dim = n_embeds / n_heads)
+        # output weight matrix, with dimension dim x dim (head_dim = dim // n_heads)
         self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
         # dropout on the attention weights
         self.attn_dropout = nn.Dropout(args.dropout)
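As context for why the comment should read `dim x dim` rather than mixing in `n_embd`: `nn.Linear` stores its weight as `(out_features, in_features)`, and since `head_dim = dim // n_heads`, the output projection `wo` maps `n_heads * head_dim = dim` back to `dim`. A minimal runnable sketch follows; the `ModelArgs` fields (`dim`, `n_heads`, `dropout`) and their default values are assumptions inferred from the diff context, not taken from the full chapter source.

```python
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class ModelArgs:
    # Hypothetical defaults; the diff only tells us these fields exist.
    dim: int = 512      # model (embedding) dimension
    n_heads: int = 8    # number of attention heads
    dropout: float = 0.1


args = ModelArgs()
head_dim = args.dim // args.n_heads  # 512 // 8 = 64

# Projections as in the diff context: each maps dim -> n_heads * head_dim.
wq = nn.Linear(args.dim, args.n_heads * head_dim, bias=False)
wo = nn.Linear(args.n_heads * head_dim, args.dim, bias=False)

# nn.Linear's weight has shape (out_features, in_features); because
# n_heads * head_dim == dim, wo.weight is dim x dim -- the shape the
# corrected comment describes.
print(wq.weight.shape)  # torch.Size([512, 512])
print(wo.weight.shape)  # torch.Size([512, 512])
```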