fix ch2 demo bug-n_head
This commit is contained in:
@@ -252,7 +252,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
# args: 配置对象
|
# args: 配置对象
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 隐藏层维度必须是头数的整数倍,因为后面我们会将输入拆成头数个矩阵
|
# 隐藏层维度必须是头数的整数倍,因为后面我们会将输入拆成头数个矩阵
|
||||||
assert args.n_embd % args.n_head == 0
|
assert args.n_embd % args.n_heads == 0
|
||||||
# 模型并行处理大小,默认为1。
|
# 模型并行处理大小,默认为1。
|
||||||
model_parallel_size = 1
|
model_parallel_size = 1
|
||||||
# 本地计算头数,等于总头数除以模型并行处理大小。
|
# 本地计算头数,等于总头数除以模型并行处理大小。
|
||||||
|
|||||||
Reference in New Issue
Block a user