Fix: 5.1 llama

This commit is contained in:
KMnO4-zx
2024-05-28 16:18:07 +08:00
parent dbced843e5
commit 73ff50be2b

View File

@@ -543,7 +543,7 @@ class LLaMA2Model(nn.Module):
return logits
```
同样大家可以使用下面的代码来对`LLaMA2Model`模块进行测试,可以看到代码最终输出的形状为`torch.Size([1, 50, 32000])`,与我们输入的形状一致,说明模块的实现是正确的。
同样大家可以使用下面的代码来对`LLaMA2Model`模块进行测试,可以看到代码最终输出的形状为`torch.Size([1, 1, 32000])`,与我们输入的形状一致,说明模块的实现是正确的。
```python
# LLaMA2Model.forward 接受两个参数tokens和targets其中tokens是输入的张量, 应为int类型