docs:第五章 动手搭建大模型 修改图片、引用格式

This commit is contained in:
KMnO4-zx
2025-05-13 20:10:12 +08:00
parent 9763467812
commit 7127aa48b3

View File

@@ -2,7 +2,7 @@
## 5.1 动手实现一个 LLaMA2 大模型
Meta原Facebook于2023年2月发布第一款基于Transformer结构的大型语言模型-LLaMA并于同年7月发布同系列模型-LLaMA2。我们在第四章已经学习了解的了LLM记忆如何训练LLM等等。那本小节我们就来学习如何动手写一个LLaMA2模型。
Meta原Facebook于2023年2月发布第一款基于Transformer结构的大型语言模型LLaMA并于同年7月发布同系列模型LLaMA2。我们在第四章已经学习了解的了LLM记忆如何训练LLM等等。那本小节我们就来学习如何动手写一个LLaMA2模型。
### 5.1.1 定义超参数
@@ -102,7 +102,7 @@ orch.Size([1, 50, 288])
### 5.1.3 构建 LLaMA2 Attention
在 LLaMA2 模型中,虽然只有 LLaMA2-70B模型使用了GQAGroup Query Attention但我们选择使用GQA来构建我们的 LLaMA Attention 模块,它可以提高模型的效率,并节省一些显存占用。
在 LLaMA2 模型中,虽然只有 LLaMA2-70B模型使用了分组查询注意力机制Grouped-Query AttentionGQA),但我们依然选择使用 GQA 来构建我们的 LLaMA Attention 模块,它可以提高模型的效率,并节省一些显存占用。
#### 5.1.3.1 repeat_kv
@@ -138,7 +138,7 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
#### 5.1.3.2 旋转嵌入
接着我们来实现旋转嵌入,旋转嵌入是 LLaMA2 模型中的一个重要组件,它可以为attention机制提供更强的上下文信息,从而提高模型的性能。
接着我们来实现旋转嵌入,旋转嵌入是 LLaMA2 模型中的一个重要组件,它可以为注意力机制提供更强的上下文信息,从而提高模型的性能。
首先,我们要构造获得旋转嵌入的实部和虚部的函数:
@@ -337,7 +337,7 @@ class Attention(nn.Module):
return output
```
同样大家可以使用下面的代码来对`Attention`模块进行测试,可以看到代码最终输出的形状为`torch.Size([1, 50, 768])`,与我们输入的形状一致,说明模块的实现是正确的。
同样大家可以使用下面的代码来对注意力模块进行测试,可以看到代码最终输出的形状为`torch.Size([1, 50, 768])`,与我们输入的形状一致,说明模块的实现是正确的。
```python
# 创建Attention实例
@@ -1054,7 +1054,6 @@ with open('BelleGroup_sft.jsonl', 'a', encoding='utf-8') as sft:
```bash
python code/train_tokenizer.py
```
```python
@@ -1324,9 +1323,12 @@ class PretrainDataset(Dataset):
return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)
```
在以上代码可以看出,我们的 `Pretrain Dataset` 主要是将 `text` 通过 `tokenizer` 转换成 `input_id`,然后将 `input_id` 拆分成 `X``Y`,其中 `X``input_id` 的前 n-1 个元素,`Y``input_id` 的后 n-1 `个元素。loss_mask` 主要是用来标记哪些位置需要计算损失,哪些位置不需要计算损失。如果你不太能明白,可以看下面的示意图。
在以上代码和图5.1可以看出,`Pretrain Dataset` 主要是将 `text` 通过 `tokenizer` 转换成 `input_id`,然后将 `input_id` 拆分成 `X``Y`,其中 `X``input_id` 的前 n-1 个元素,`Y``input_id` 的后 n-1 `个元素。loss_mask` 主要是用来标记哪些位置需要计算损失,哪些位置不需要计算损失。
![alt text](./images/pretrain_dataset.png)
<div align='center'>
<img src="./images/pretrain_dataset.png" alt="alt text" width="100%">
<p>图5.1 预训练损失函数计算</p>
</div>
图中示例展示了当`max_length=9`时的处理过程:
- **输入序列**`[BOS, T1, T2, T3, T4, T5, T6, T7, EOS]`
@@ -1408,9 +1410,12 @@ class SFTDataset(Dataset):
return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)
```
在 SFT 阶段,这里使用的是多轮对话数据集,所以就需要区分哪些位置需要计算损失,哪些位置不需要计算损失。在上面的代码中,我使用了一个 `generate_loss_mask` 函数来生成 `loss_mask`。这个函数主要是用来生成 `loss_mask`,其中 `loss_mask` 的生成规则是:当遇到 `|<im_start|>assistant\n` 时,就开始计算损失,直到遇到 `|<im_end|>` 为止。这样就可以保证我们的模型在 SFT 阶段只计算当前轮的对话内容。那我也给出一个示意图,帮助大家理解
在 SFT 阶段,这里使用的是多轮对话数据集,所以就需要区分哪些位置需要计算损失,哪些位置不需要计算损失。在上面的代码中,我使用了一个 `generate_loss_mask` 函数来生成 `loss_mask`。这个函数主要是用来生成 `loss_mask`,其中 `loss_mask` 的生成规则是:当遇到 `|<im_start|>assistant\n` 时,就开始计算损失,直到遇到 `|<im_end|>` 为止。这样就可以保证我们的模型在 SFT 阶段只计算当前轮的对话内容如图5.2所示
![alt text](./images/sftdataset.png)
<div align='center'>
<img src="./images/sftdataset.png" alt="alt text" width="90%">
<p>图5.2 SFT 损失函数计算</p>
</div>
可以看到,其实 SFT Dataset 和 Pretrain Dataset 的 `X``Y` 是一样的,只是在 SFT Dataset 中我们需要生成一个 `loss_mask` 来标记哪些位置需要计算损失,哪些位置不需要计算损失。 图中 `Input ids` 中的蓝色小方格就是AI的回答所以是需要模型学习的地方。所以在 `loss_mask` 中,蓝色小方格对应的位置是黄色,其他位置是灰色。在代码 `loss_mask` 中的 1 对应的位置计算损失0 对应的位置不计算损失。
@@ -2017,10 +2022,16 @@ Sample 2:
**参考资料**
- [llama2.c](https://github.com/karpathy/llama2.c)
- [llm.c](https://github.com/karpathy/llm.c)
- [tokenizers](https://huggingface.co/docs/tokenizers/index)
- [SkyWork 150B](https://huggingface.co/datasets/Skywork/SkyPile-150B)
- [BelleGroup](https://huggingface.co/datasets/BelleGroup/train_3.5M_CN)
- [minimind](https://github.com/jingyaogong/minimind)
- [出门问问序列猴子开源数据集](https://github.com/mobvoi/seq-monkey-data)
[1] Andrej Karpathy. (2023). *llama2.c: Fullstack Llama 2 LLM solution in pure C*. GitHub repository. https://github.com/karpathy/llama2.c
[2] Andrej Karpathy. (2023). *llm.c: GPT-2/GPT-3 pretraining in C/CUDA*. GitHub repository. https://github.com/karpathy/llm.c
[3] Hugging Face. (2023). *Tokenizers documentation*. https://huggingface.co/docs/tokenizers/index
[4] Skywork Team. (2023). *SkyPile-150B: A large-scale bilingual dataset*. Hugging Face dataset. https://huggingface.co/datasets/Skywork/SkyPile-150B
[5] BelleGroup. (2022). *train_3.5M_CN: Chinese dialogue dataset*. Hugging Face dataset. https://huggingface.co/datasets/BelleGroup/train_3.5M_CN
[6] Jingyao Gong. (2023). *minimind: Minimalist LLM implementation*. GitHub repository. https://github.com/jingyaogong/minimind
[7] Mobvoi. (2023). *seq-monkey-data: Llama2 training/inference data*. GitHub repository. https://github.com/mobvoi/seq-monkey-data