diff --git a/docs/chapter5/第五章 动手搭建大模型.md b/docs/chapter5/第五章 动手搭建大模型.md
index ca0a52d..6b74489 100644
--- a/docs/chapter5/第五章 动手搭建大模型.md
+++ b/docs/chapter5/第五章 动手搭建大模型.md
@@ -2,7 +2,7 @@
## 5.1 动手实现一个 LLaMA2 大模型
-Meta(原Facebook)于2023年2月发布第一款基于Transformer结构的大型语言模型-LLaMA,并于同年7月发布同系列模型-LLaMA2。我们在第四章已经学习了解的了LLM,记忆如何训练LLM等等。那本小节我们就来学习,如何动手写一个LLaMA2模型。
+Meta(原Facebook)于2023年2月发布第一款基于Transformer结构的大型语言模型LLaMA,并于同年7月发布同系列模型LLaMA2。我们在第四章已经学习了解的了LLM,记忆如何训练LLM等等。那本小节我们就来学习,如何动手写一个LLaMA2模型。
### 5.1.1 定义超参数
@@ -102,7 +102,7 @@ orch.Size([1, 50, 288])
### 5.1.3 构建 LLaMA2 Attention
-在 LLaMA2 模型中,虽然只有 LLaMA2-70B模型使用了GQA(Group Query Attention),但我们选择使用GQA来构建我们的 LLaMA Attention 模块,它可以提高模型的效率,并节省一些显存占用。
+在 LLaMA2 模型中,虽然只有 LLaMA2-70B模型使用了分组查询注意力机制(Grouped-Query Attention,GQA),但我们依然选择使用 GQA 来构建我们的 LLaMA Attention 模块,它可以提高模型的效率,并节省一些显存占用。
#### 5.1.3.1 repeat_kv
@@ -138,7 +138,7 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
#### 5.1.3.2 旋转嵌入
-接着我们来实现旋转嵌入,旋转嵌入是 LLaMA2 模型中的一个重要组件,它可以为attention机制提供更强的上下文信息,从而提高模型的性能。
+接着我们来实现旋转嵌入,旋转嵌入是 LLaMA2 模型中的一个重要组件,它可以为注意力机制提供更强的上下文信息,从而提高模型的性能。
首先,我们要构造获得旋转嵌入的实部和虚部的函数:
@@ -337,7 +337,7 @@ class Attention(nn.Module):
return output
```
-同样大家可以使用下面的代码来对`Attention`模块进行测试,可以看到代码最终输出的形状为`torch.Size([1, 50, 768])`,与我们输入的形状一致,说明模块的实现是正确的。
+同样大家可以使用下面的代码来对注意力模块进行测试,可以看到代码最终输出的形状为`torch.Size([1, 50, 768])`,与我们输入的形状一致,说明模块的实现是正确的。
```python
# 创建Attention实例
@@ -1054,7 +1054,6 @@ with open('BelleGroup_sft.jsonl', 'a', encoding='utf-8') as sft:
```bash
python code/train_tokenizer.py
-
```
```python
@@ -1324,9 +1323,12 @@ class PretrainDataset(Dataset):
return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)
```
-在以上代码可以看出,我们的 `Pretrain Dataset` 主要是将 `text` 通过 `tokenizer` 转换成 `input_id`,然后将 `input_id` 拆分成 `X` 和 `Y`,其中 `X` 为 `input_id` 的前 n-1 个元素,`Y` 为 `input_id` 的后 n-1 `个元素。loss_mask` 主要是用来标记哪些位置需要计算损失,哪些位置不需要计算损失。如果你不太能明白,可以看下面的示意图。
+在以上代码和图5.1可以看出,`Pretrain Dataset` 主要是将 `text` 通过 `tokenizer` 转换成 `input_id`,然后将 `input_id` 拆分成 `X` 和 `Y`,其中 `X` 为 `input_id` 的前 n-1 个元素,`Y` 为 `input_id` 的后 n-1 `个元素。loss_mask` 主要是用来标记哪些位置需要计算损失,哪些位置不需要计算损失。
-
+
+

+
图5.1 预训练损失函数计算
+
图中示例展示了当`max_length=9`时的处理过程:
- **输入序列**:`[BOS, T1, T2, T3, T4, T5, T6, T7, EOS]`
@@ -1408,9 +1410,12 @@ class SFTDataset(Dataset):
return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)
```
-在 SFT 阶段,我这里使用的是多轮对话数据集,所以就需要区分哪些位置需要计算损失,哪些位置不需要计算损失。在上面的代码中,我使用了一个 `generate_loss_mask` 函数来生成 `loss_mask`。这个函数主要是用来生成 `loss_mask`,其中 `loss_mask` 的生成规则是:当遇到 `|assistant\n` 时,就开始计算损失,直到遇到 `|` 为止。这样就可以保证我们的模型在 SFT 阶段只计算当前轮的对话内容。那我也给出一个示意图,帮助大家理解。
+在 SFT 阶段,这里使用的是多轮对话数据集,所以就需要区分哪些位置需要计算损失,哪些位置不需要计算损失。在上面的代码中,我使用了一个 `generate_loss_mask` 函数来生成 `loss_mask`。这个函数主要是用来生成 `loss_mask`,其中 `loss_mask` 的生成规则是:当遇到 `|assistant\n` 时,就开始计算损失,直到遇到 `|` 为止。这样就可以保证我们的模型在 SFT 阶段只计算当前轮的对话内容,如图5.2所示。
-
+
+

+
图5.2 SFT 损失函数计算
+
可以看到,其实 SFT Dataset 和 Pretrain Dataset 的 `X` 和 `Y` 是一样的,只是在 SFT Dataset 中我们需要生成一个 `loss_mask` 来标记哪些位置需要计算损失,哪些位置不需要计算损失。 图中 `Input ids` 中的蓝色小方格就是AI的回答,所以是需要模型学习的地方。所以在 `loss_mask` 中,蓝色小方格对应的位置是黄色,其他位置是灰色。在代码 `loss_mask` 中的 1 对应的位置计算损失,0 对应的位置不计算损失。
@@ -2017,10 +2022,16 @@ Sample 2:
**参考资料**
-- [llama2.c](https://github.com/karpathy/llama2.c)
-- [llm.c](https://github.com/karpathy/llm.c)
-- [tokenizers](https://huggingface.co/docs/tokenizers/index)
-- [SkyWork 150B](https://huggingface.co/datasets/Skywork/SkyPile-150B)
-- [BelleGroup](https://huggingface.co/datasets/BelleGroup/train_3.5M_CN)
-- [minimind](https://github.com/jingyaogong/minimind)
-- [出门问问序列猴子开源数据集](https://github.com/mobvoi/seq-monkey-data)
\ No newline at end of file
+[1] Andrej Karpathy. (2023). *llama2.c: Fullstack Llama 2 LLM solution in pure C*. GitHub repository. https://github.com/karpathy/llama2.c
+
+[2] Andrej Karpathy. (2023). *llm.c: GPT-2/GPT-3 pretraining in C/CUDA*. GitHub repository. https://github.com/karpathy/llm.c
+
+[3] Hugging Face. (2023). *Tokenizers documentation*. https://huggingface.co/docs/tokenizers/index
+
+[4] Skywork Team. (2023). *SkyPile-150B: A large-scale bilingual dataset*. Hugging Face dataset. https://huggingface.co/datasets/Skywork/SkyPile-150B
+
+[5] BelleGroup. (2022). *train_3.5M_CN: Chinese dialogue dataset*. Hugging Face dataset. https://huggingface.co/datasets/BelleGroup/train_3.5M_CN
+
+[6] Jingyao Gong. (2023). *minimind: Minimalist LLM implementation*. GitHub repository. https://github.com/jingyaogong/minimind
+
+[7] Mobvoi. (2023). *seq-monkey-data: Llama2 training/inference data*. GitHub repository. https://github.com/mobvoi/seq-monkey-data
\ No newline at end of file