diff --git a/docs/chapter5/5.1 模型结构-LLaMA.md b/docs/chapter5/5.1 模型结构-LLaMA.md index 4577731..f5a538a 100644 --- a/docs/chapter5/5.1 模型结构-LLaMA.md +++ b/docs/chapter5/5.1 模型结构-LLaMA.md @@ -543,7 +543,7 @@ class LLaMA2Model(nn.Module): return logits ``` -同样大家可以使用下面的代码来对`LLaMA2Model`模块进行测试,可以看到代码最终输出的形状为`torch.Size([1, 50, 32000])`,与我们输入的形状一致,说明模块的实现是正确的。 +同样大家可以使用下面的代码来对`LLaMA2Model`模块进行测试,可以看到代码最终输出的形状为`torch.Size([1, 1, 32000])`,与我们输入的形状一致,说明模块的实现是正确的。 ```python # LLaMA2Model.forward 接受两个参数,tokens和targets,其中tokens是输入的张量, 应为int类型