fix ch2 demo bug-n_head

This commit is contained in:
Logan Zou
2025-06-09 11:00:24 +08:00
committed by GitHub
parent d99ad30711
commit 21fc6b7ac6

View File

@@ -252,7 +252,7 @@ class MultiHeadAttention(nn.Module):
# args: 配置对象
super().__init__()
# 隐藏层维度必须是头数的整数倍,因为后面我们会将输入拆成头数个矩阵
assert args.n_embd % args.n_head == 0
assert args.n_embd % args.n_heads == 0
# 模型并行处理大小默认为1。
model_parallel_size = 1
# 本地计算头数,等于总头数除以模型并行处理大小。