Update transformer.py
fix arg bug
This commit is contained in:
@@ -25,7 +25,7 @@ class MultiHeadAttention(nn.Module):
|
||||
# args: 配置对象
|
||||
super().__init__()
|
||||
# 隐藏层维度必须是头数的整数倍,因为后面我们会将输入拆成头数个矩阵
|
||||
assert args.n_embd % args.n_heads == 0
|
||||
assert args.dim % args.n_heads == 0
|
||||
# 模型并行处理大小,默认为1。
|
||||
model_parallel_size = 1
|
||||
# 本地计算头数,等于总头数除以模型并行处理大小。
|
||||
|
||||
Reference in New Issue
Block a user