diff --git a/docs/chapter5/第五章 动手搭建大模型.md b/docs/chapter5/第五章 动手搭建大模型.md index ca0a52d..6b74489 100644 --- a/docs/chapter5/第五章 动手搭建大模型.md +++ b/docs/chapter5/第五章 动手搭建大模型.md @@ -2,7 +2,7 @@ ## 5.1 动手实现一个 LLaMA2 大模型 -Meta(原Facebook)于2023年2月发布第一款基于Transformer结构的大型语言模型-LLaMA,并于同年7月发布同系列模型-LLaMA2。我们在第四章已经学习了解的了LLM,记忆如何训练LLM等等。那本小节我们就来学习,如何动手写一个LLaMA2模型。 +Meta(原Facebook)于2023年2月发布第一款基于Transformer结构的大型语言模型LLaMA,并于同年7月发布同系列模型LLaMA2。我们在第四章已经学习了解的了LLM,记忆如何训练LLM等等。那本小节我们就来学习,如何动手写一个LLaMA2模型。 ### 5.1.1 定义超参数 @@ -102,7 +102,7 @@ orch.Size([1, 50, 288]) ### 5.1.3 构建 LLaMA2 Attention -在 LLaMA2 模型中,虽然只有 LLaMA2-70B模型使用了GQA(Group Query Attention),但我们选择使用GQA来构建我们的 LLaMA Attention 模块,它可以提高模型的效率,并节省一些显存占用。 +在 LLaMA2 模型中,虽然只有 LLaMA2-70B模型使用了分组查询注意力机制(Grouped-Query Attention,GQA),但我们依然选择使用 GQA 来构建我们的 LLaMA Attention 模块,它可以提高模型的效率,并节省一些显存占用。 #### 5.1.3.1 repeat_kv @@ -138,7 +138,7 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: #### 5.1.3.2 旋转嵌入 -接着我们来实现旋转嵌入,旋转嵌入是 LLaMA2 模型中的一个重要组件,它可以为attention机制提供更强的上下文信息,从而提高模型的性能。 +接着我们来实现旋转嵌入,旋转嵌入是 LLaMA2 模型中的一个重要组件,它可以为注意力机制提供更强的上下文信息,从而提高模型的性能。 首先,我们要构造获得旋转嵌入的实部和虚部的函数: @@ -337,7 +337,7 @@ class Attention(nn.Module): return output ``` -同样大家可以使用下面的代码来对`Attention`模块进行测试,可以看到代码最终输出的形状为`torch.Size([1, 50, 768])`,与我们输入的形状一致,说明模块的实现是正确的。 +同样大家可以使用下面的代码来对注意力模块进行测试,可以看到代码最终输出的形状为`torch.Size([1, 50, 768])`,与我们输入的形状一致,说明模块的实现是正确的。 ```python # 创建Attention实例 @@ -1054,7 +1054,6 @@ with open('BelleGroup_sft.jsonl', 'a', encoding='utf-8') as sft: ```bash python code/train_tokenizer.py - ``` ```python @@ -1324,9 +1323,12 @@ class PretrainDataset(Dataset): return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask) ``` -在以上代码可以看出,我们的 `Pretrain Dataset` 主要是将 `text` 通过 `tokenizer` 转换成 `input_id`,然后将 `input_id` 拆分成 `X` 和 `Y`,其中 `X` 为 `input_id` 的前 n-1 个元素,`Y` 为 `input_id` 的后 n-1 `个元素。loss_mask` 主要是用来标记哪些位置需要计算损失,哪些位置不需要计算损失。如果你不太能明白,可以看下面的示意图。 +在以上代码和图5.1可以看出,`Pretrain Dataset` 主要是将 `text` 通过 `tokenizer` 转换成 `input_id`,然后将 `input_id` 拆分成 `X` 和 `Y`,其中 `X` 为 `input_id` 的前 n-1 个元素,`Y` 为 `input_id` 的后 n-1 `个元素。loss_mask` 主要是用来标记哪些位置需要计算损失,哪些位置不需要计算损失。 -![alt text](./images/pretrain_dataset.png) +
+ alt text +

图5.1 预训练损失函数计算

+
图中示例展示了当`max_length=9`时的处理过程: - **输入序列**:`[BOS, T1, T2, T3, T4, T5, T6, T7, EOS]` @@ -1408,9 +1410,12 @@ class SFTDataset(Dataset): return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask) ``` -在 SFT 阶段,我这里使用的是多轮对话数据集,所以就需要区分哪些位置需要计算损失,哪些位置不需要计算损失。在上面的代码中,我使用了一个 `generate_loss_mask` 函数来生成 `loss_mask`。这个函数主要是用来生成 `loss_mask`,其中 `loss_mask` 的生成规则是:当遇到 `|assistant\n` 时,就开始计算损失,直到遇到 `|` 为止。这样就可以保证我们的模型在 SFT 阶段只计算当前轮的对话内容。那我也给出一个示意图,帮助大家理解。 +在 SFT 阶段,这里使用的是多轮对话数据集,所以就需要区分哪些位置需要计算损失,哪些位置不需要计算损失。在上面的代码中,我使用了一个 `generate_loss_mask` 函数来生成 `loss_mask`。这个函数主要是用来生成 `loss_mask`,其中 `loss_mask` 的生成规则是:当遇到 `|assistant\n` 时,就开始计算损失,直到遇到 `|` 为止。这样就可以保证我们的模型在 SFT 阶段只计算当前轮的对话内容,如图5.2所示。 -![alt text](./images/sftdataset.png) +
+ alt text +

图5.2 SFT 损失函数计算

+
可以看到,其实 SFT Dataset 和 Pretrain Dataset 的 `X` 和 `Y` 是一样的,只是在 SFT Dataset 中我们需要生成一个 `loss_mask` 来标记哪些位置需要计算损失,哪些位置不需要计算损失。 图中 `Input ids` 中的蓝色小方格就是AI的回答,所以是需要模型学习的地方。所以在 `loss_mask` 中,蓝色小方格对应的位置是黄色,其他位置是灰色。在代码 `loss_mask` 中的 1 对应的位置计算损失,0 对应的位置不计算损失。 @@ -2017,10 +2022,16 @@ Sample 2: **参考资料** -- [llama2.c](https://github.com/karpathy/llama2.c) -- [llm.c](https://github.com/karpathy/llm.c) -- [tokenizers](https://huggingface.co/docs/tokenizers/index) -- [SkyWork 150B](https://huggingface.co/datasets/Skywork/SkyPile-150B) -- [BelleGroup](https://huggingface.co/datasets/BelleGroup/train_3.5M_CN) -- [minimind](https://github.com/jingyaogong/minimind) -- [出门问问序列猴子开源数据集](https://github.com/mobvoi/seq-monkey-data) \ No newline at end of file +[1] Andrej Karpathy. (2023). *llama2.c: Fullstack Llama 2 LLM solution in pure C*. GitHub repository. https://github.com/karpathy/llama2.c + +[2] Andrej Karpathy. (2023). *llm.c: GPT-2/GPT-3 pretraining in C/CUDA*. GitHub repository. https://github.com/karpathy/llm.c + +[3] Hugging Face. (2023). *Tokenizers documentation*. https://huggingface.co/docs/tokenizers/index + +[4] Skywork Team. (2023). *SkyPile-150B: A large-scale bilingual dataset*. Hugging Face dataset. https://huggingface.co/datasets/Skywork/SkyPile-150B + +[5] BelleGroup. (2022). *train_3.5M_CN: Chinese dialogue dataset*. Hugging Face dataset. https://huggingface.co/datasets/BelleGroup/train_3.5M_CN + +[6] Jingyao Gong. (2023). *minimind: Minimalist LLM implementation*. GitHub repository. https://github.com/jingyaogong/minimind + +[7] Mobvoi. (2023). *seq-monkey-data: Llama2 training/inference data*. GitHub repository. https://github.com/mobvoi/seq-monkey-data \ No newline at end of file