Update 第二章 Transformer架构.md
@@ -266,7 +266,7 @@ class MultiHeadAttention(nn.Module):
         self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
         self.wk = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
-        # Output weight matrix, with shape n_embd x n_embd (head_dim = n_embeds / n_heads)
+        # Output weight matrix, with shape dim x n_embd (head_dim = n_embeds / n_heads)
         self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
         # Dropout applied to the attention weights
         self.attn_dropout = nn.Dropout(args.dropout)
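For context (the hunk above only rewords a comment), below is a minimal, self-contained sketch of how these wq/wk/wv/wo projections are typically wired into a multi-head attention forward pass. It assumes a simple ModelArgs dataclass with dim, n_heads, and dropout fields, mirroring the args accessed in the diff; the reshape/softmax plumbing and the absence of masking or key/value caching are illustrative assumptions, not the chapter's actual implementation.

```python
# Sketch only: hypothetical ModelArgs and forward pass around the projections
# shown in the diff; not the file's actual code.
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class ModelArgs:
    dim: int = 512        # model (embedding) dimension
    n_heads: int = 8      # number of attention heads
    dropout: float = 0.1  # dropout probability


class MultiHeadAttention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.head_dim = args.dim // args.n_heads
        # Q/K/V projections: dim -> n_heads * head_dim (equal to dim here)
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        # Output projection back to the model dimension
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seqlen, _ = x.shape
        # Project and split into heads: (bsz, n_heads, seqlen, head_dim)
        q = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.wk(x).view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.wv(x).view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention per head
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        attn = self.attn_dropout(F.softmax(scores, dim=-1))
        out = attn @ v
        # Merge heads and project back to (bsz, seqlen, dim)
        out = out.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(out)


# Usage: a (2, 16, 512) input comes back with the same shape.
# x = torch.randn(2, 16, 512); print(MultiHeadAttention(ModelArgs())(x).shape)
```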