fix ch2 demo bug-n_head
This commit is contained in:
@@ -252,7 +252,7 @@ class MultiHeadAttention(nn.Module):
|
||||
# args: 配置对象
|
||||
super().__init__()
|
||||
# 隐藏层维度必须是头数的整数倍,因为后面我们会将输入拆成头数个矩阵
|
||||
assert args.n_embd % args.n_head == 0
|
||||
assert args.n_embd % args.n_heads == 0
|
||||
# 模型并行处理大小,默认为1。
|
||||
model_parallel_size = 1
|
||||
# 本地计算头数,等于总头数除以模型并行处理大小。
|
||||
|
||||
Reference in New Issue
Block a user