Fix: 5.1 llama
This commit is contained in:
@@ -543,7 +543,7 @@ class LLaMA2Model(nn.Module):
|
||||
return logits
|
||||
```
|
||||
|
||||
同样大家可以使用下面的代码来对`LLaMA2Model`模块进行测试,可以看到代码最终输出的形状为`torch.Size([1, 50, 32000])`,与我们输入的形状一致,说明模块的实现是正确的。
|
||||
同样大家可以使用下面的代码来对`LLaMA2Model`模块进行测试,可以看到代码最终输出的形状为`torch.Size([1, 1, 32000])`,与我们输入的形状一致,说明模块的实现是正确的。
|
||||
|
||||
```python
|
||||
# LLaMA2Model.forward 接受两个参数,tokens和targets,其中tokens是输入的张量, 应为int类型
|
||||
|
||||
Reference in New Issue
Block a user