init
This commit is contained in:
566
docs/chapter5/5.1 模型结构-LLaMA.md
Normal file
566
docs/chapter5/5.1 模型结构-LLaMA.md
Normal file
@@ -0,0 +1,566 @@
|
||||
# 5.1 动手写一个 LLaMA2 模型
|
||||
|
||||
Meta(原Facebook)于2023年2月发布第一款基于Transformer结构的大型语言模型-LLaMA,并于同年7月发布同系列模型-LLaMA2。我们在第四章已经学习了解的了LLM,记忆如何训练LLM等等。那本小节我们就来学习,如何动手写一个LLaMA2模型。
|
||||
|
||||
|
||||
## 5.1.1 定义超参数
|
||||
|
||||
首先我们需要定义一些超参数,这些超参数包括模型的大小、层数、头数、词嵌入维度、隐藏层维度等等。这些超参数可以根据实际情况进行调整。
|
||||
|
||||
这里我们自定义一个`ModelArgs`类,来存储和记录我们的超参数,方便后续修改和直接倒入。
|
||||
|
||||
```python
|
||||
class ModelArgs:
|
||||
# 自定义超参数
|
||||
dim: int = 288 # 模型维度
|
||||
n_layers: int = 6 # Transformer层数
|
||||
n_heads: int = 6 # 注意力机制的头数
|
||||
n_kv_heads: Optional[int] = 6 # 键/值头数,如果未指定,则默认为n_heads
|
||||
vocab_size: int = 32000 # 词汇表大小
|
||||
hidden_dim: Optional[int] = None # 隐藏层维度,如果未指定,则使用其他规则确定
|
||||
multiple_of: int = 32 # MLP隐藏层大小是这个数的倍数
|
||||
norm_eps: float = 1e-5 # 归一化层的epsilon值
|
||||
max_seq_len: int = 256 # 最大序列长度
|
||||
dropout: float = 0.0 # 丢弃率
|
||||
```
|
||||
|
||||
我们来看一下其中的一些超参数的含义,比如`dim`是模型维度,`n_layers`是Transformer的层数,`n_heads`是注意力机制的头数,`vocab_size`是词汇表大小,`max_seq_len`是输入的最大序列长度等等。上面的代码中也对每一个参数做了详细的注释,在后面的代码中我们会根据这些超参数来构建我们的模型。
|
||||
|
||||
## 5.1.2 构建LLaMA2RMSNorm
|
||||
|
||||
`LLaMA2RMSNorm`可以用如下的数学公式表示:
|
||||
|
||||
$$
|
||||
\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{n}\sum_{i=1}^{n}w_i^2 + \epsilon}}
|
||||
$$
|
||||
|
||||
其中:
|
||||
- ( $x$ ) 是层的输入。
|
||||
- ( $w_i$ ) 代表层的权重。
|
||||
- ( $n$ ) 是权重的数量。
|
||||
- ( $\epsilon$ ) 是一个小常数,用于数值稳定性(以避免除以零的情况)。
|
||||
|
||||
这种归一化有助于通过确保权重的规模不会变得过大或过小来稳定学习过程,这在具有许多层的深度学习模型中特别有用。
|
||||
|
||||
我们可以通过如下代码实现`LLaMA2RMSNorm`:
|
||||
|
||||
```python
|
||||
class LLaMA2RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float):
|
||||
super().__init__()
|
||||
# eps是为了防止除以0的情况
|
||||
self.eps = eps
|
||||
# weight是一个可学习的参数,全部初始化为1
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x):
|
||||
# 计算RMSNorm的核心部分
|
||||
# x.pow(2).mean(-1, keepdim=True)计算了输入x的平方的均值
|
||||
# torch.rsqrt是平方根的倒数,这样就得到了RMSNorm的分母部分,再加上eps防止分母为0
|
||||
# 最后乘以x,得到RMSNorm的结果
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
# forward函数是模型的前向传播
|
||||
# 首先将输入x转为float类型,然后进行RMSNorm,最后再转回原来的数据类型
|
||||
# 最后乘以weight,这是RMSNorm的一个可学习的缩放因子
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
return output * self.weight
|
||||
```
|
||||
|
||||
并且,我们可以用下面的代码来对`LLaMA2RMSNorm`模块进行测试,可以看到代码最终输出的形状为`torch.Size([1, 50, 288])`,与我们输入的形状一致,说明模块的实现是正确的,归一化并不会改变输入的形状。
|
||||
|
||||
```python
|
||||
norm = LLaMA2RMSNorm(args.dim, args.norm_eps)
|
||||
x = torch.randn(1, 50, args.dim)
|
||||
output = norm(x)
|
||||
print(output.shape)
|
||||
|
||||
out:
|
||||
orch.Size([1, 50, 288])
|
||||
```
|
||||
|
||||
## 5.1.3 构建 LLaMA2 Attention
|
||||
|
||||
在 LLaMA2 模型中,虽然只有 LLaMA2-70B模型使用了GQA(Group Query Attention),但我们选择使用GQA来构建我们的 LLaMA Attention 模块,它可以提高模型的效率,并节省一些显存占用。
|
||||
|
||||
### 5.1.3.1 repeat_kv
|
||||
|
||||
在 LLaMA2 模型中,我们需要将键和值的维度扩展到和查询的维度一样,这样才能进行注意力计算。我们可以通过如下代码实现`repeat_kv`:
|
||||
|
||||
```python
|
||||
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
# 获取输入张量的形状:批量大小、序列长度、键/值对头的数量、每个头的维度大小
|
||||
bs, slen, n_kv_heads, head_dim = x.shape
|
||||
|
||||
# 如果重复次数为1,则不需要重复,直接返回原始张量
|
||||
if n_rep == 1:
|
||||
return x
|
||||
|
||||
# 对张量进行扩展和重塑操作以重复键值对
|
||||
return (
|
||||
x[:, :, :, None, :] # 在第四个维度(头的维度前)添加一个新的维度
|
||||
.expand(bs, slen, n_kv_heads, n_rep, head_dim) # 将新添加的维度扩展到n_rep大小,实现重复的效果
|
||||
.reshape(bs, slen, n_kv_heads * n_rep, head_dim) # 重新塑形,合并键/值对头的数量和重复次数的维度
|
||||
)
|
||||
```
|
||||
|
||||
在上述代码中:
|
||||
|
||||
- 首先,获取输入张量的形状:首先,代码通过 x.shape 获取输入张量的形状,包括批量大小(bs)、序列长度(slen)、键/值对头的数量(n_kv_heads)以及每个头的维度大小(head_dim)。
|
||||
|
||||
- 然后,检查重复次数:接着,代码检查重复次数 n_rep 是否为1。如果是1,则说明不需要对键和值进行重复,直接返回原始张量 x。
|
||||
|
||||
- 最后,扩展和重塑张量:
|
||||
- 在第三个维度(即键/值对头的维度)之后添加一个新的维度,形成 `x[:, :, :, None, :]`。
|
||||
- 使用 `expand` 方法将新添加的维度扩展到 `n_rep` 大小,实现键/值对的重复效果。
|
||||
- 最后,通过 reshape 方法重新塑形,将扩展后的维度合并回键/值对头的数量中,即 `x.reshape(bs, slen, n_kv_heads * n_rep, head_dim)`,这样最终的张量形状就达到了与查询维度一致的效果。
|
||||
|
||||
### 5.1.3.2 旋转嵌入
|
||||
|
||||
接着我们来实现旋转嵌入,旋转嵌入是 LLaMA2 模型中的一个重要组件,它可以为attention机制提供更强的上下文信息,从而提高模型的性能。
|
||||
|
||||
首先,我们要构造获得旋转嵌入的实部和虚部的函数:
|
||||
|
||||
```python
|
||||
# 注意:此处的dim应为 dim//n_head,因为我们是对每个head进行旋转嵌入
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
|
||||
# torch.arange(0, dim, 2)[: (dim // 2)].float()生成了一个从0开始,步长为2的序列,长度为dim的一半
|
||||
# 然后每个元素除以dim,再取theta的倒数,得到频率
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
# 生成一个从0到end的序列,长度为end
|
||||
t = torch.arange(end, device=freqs.device)
|
||||
# 计算外积,得到一个二维矩阵,每一行是t的元素乘以freqs的元素
|
||||
freqs = torch.outer(t, freqs).float()
|
||||
# 计算频率的余弦值,得到实部
|
||||
freqs_cos = torch.cos(freqs)
|
||||
# 计算频率的正弦值,得到虚部
|
||||
freqs_sin = torch.sin(freqs)
|
||||
return freqs_cos, freqs_sin
|
||||
```
|
||||
|
||||
- 计算频率序列:
|
||||
- `torch.arange(0, dim, 2)[: (dim // 2)].float()` 生成了一个从0开始,步长为2的序列,其长度为`dim`的一半。
|
||||
- 每个元素除以`dim`后取`theta`的倒数,得到一个频率序列 `freqs`。这一步是为了生成适合旋转嵌入的频率。
|
||||
- 生成时间序列:
|
||||
- `t = torch.arange(end, device=freqs.device)` 生成一个从`0`到`end`的序列,长度为`end`。`end`通常是序列的最大长度。
|
||||
- 计算频率的外积
|
||||
- `freqs = torch.outer(t, freqs).float()` 计算时间序列 `t` 和频率序列 `freqs` 的外积,得到一个二维矩阵 `freqs`。每一行是时间序列 `t` 的元素乘以频率序列 `freqs` 的元素。
|
||||
- 计算实部和虚部
|
||||
- `freqs_cos = torch.cos(freqs)` 计算频率矩阵 `freqs` 的余弦值,得到旋转嵌入的实部。
|
||||
- `freqs_sin = torch.sin(freqs)` 计算频率矩阵 `freqs` 的正弦值,得到旋转嵌入的虚部。
|
||||
|
||||
最终,该函数返回两个矩阵 `freqs_cos` 和 `freqs_sin`,分别表示旋转嵌入的实部和虚部,用于后续的计算。
|
||||
|
||||
接着,我们来构造调整张量形状的`reshape_for_broadcast`函数,这个函数的主要目的是调整 `freqs_cis` 的形状,使其在进行广播操作时与 `x` 的维度对齐,从而能够进行正确的张量运算。
|
||||
|
||||
```python
|
||||
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
||||
# 获取x的维度数
|
||||
ndim = x.ndim
|
||||
|
||||
# 断言,确保1在x的维度范围内
|
||||
assert 0 <= 1 < ndim
|
||||
|
||||
# 断言,确保freqs_cis的形状与x的第二维和最后一维相同
|
||||
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
||||
|
||||
# 构造一个新的形状,除了第二维和最后一维,其他维度都为1,这样做是为了能够将freqs_cis与x进行广播操作
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
|
||||
# 将freqs_cis调整为新的形状,并返回
|
||||
return freqs_cis.view(shape)
|
||||
```
|
||||
|
||||
最后,我们可以通过如下代码实现旋转嵌入:
|
||||
|
||||
```python
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cos: torch.Tensor,
|
||||
freqs_sin: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
# 将查询和键张量转换为浮点数,并重塑形状以分离实部和虚部
|
||||
xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
|
||||
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
|
||||
|
||||
# 重新塑形频率张量以进行广播
|
||||
freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
|
||||
freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
|
||||
|
||||
# 应用旋转,分别计算旋转后的实部和虚部
|
||||
xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
|
||||
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
|
||||
xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
|
||||
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
|
||||
|
||||
# 将最后两个维度合并,并还原为原始张量的形状
|
||||
xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
|
||||
xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
|
||||
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
```
|
||||
|
||||
这里我们给出可以测试`apply_rotary_emb`函数的代码,大家也可以尝试在代码中添加断点,来查看每一步的计算结果。
|
||||
|
||||
```python
|
||||
xq = torch.randn(1, 50, 6, 48) # bs, seq_len, dim//n_head, n_head_dim
|
||||
xk = torch.randn(1, 50, 6, 48) # bs, seq_len, dim//n_head, n_head_dim
|
||||
|
||||
# 使用 precompute_freqs_cis 函数获取 sin和cos
|
||||
cos, sin = precompute_freqs_cis(288//6, 50)
|
||||
print(cos.shape, sin.shape)
|
||||
xq_out, xk_out = apply_rotary_emb(xq, xk, cos, sin)
|
||||
|
||||
xq_out.shape, xk_out.shape
|
||||
```
|
||||
|
||||
OUT:
|
||||
```
|
||||
torch.Size([50, 24]) torch.Size([50, 24])
|
||||
```
|
||||
|
||||
### 5.1.3.3 组装 LLaMA2 Attention
|
||||
|
||||
在上面我们已经完成了旋转嵌入的实现,接下来我们就可以构建 LLaMA2 Attention 模块了。
|
||||
|
||||
```python
|
||||
class LLaMA2Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
# 根据是否指定n_kv_heads,确定用于键(key)和值(value)的头的数量。
|
||||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||
# 确保总头数可以被键值头数整除。
|
||||
assert args.n_heads % self.n_kv_heads == 0
|
||||
|
||||
# 模型并行处理大小,默认为1。
|
||||
model_parallel_size = 1
|
||||
# 本地计算头数,等于总头数除以模型并行处理大小。
|
||||
self.n_local_heads = args.n_heads // model_parallel_size
|
||||
# 本地键值头数,等于键值头数除以模型并行处理大小。
|
||||
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
|
||||
# 重复次数,用于扩展键和值的尺寸。
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
# 每个头的维度,等于模型维度除以头的总数。
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
|
||||
# 定义权重矩阵。
|
||||
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
|
||||
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||
# 输出权重矩阵。
|
||||
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
|
||||
|
||||
# 定义dropout。
|
||||
self.attn_dropout = nn.Dropout(args.dropout)
|
||||
self.resid_dropout = nn.Dropout(args.dropout)
|
||||
# 保存dropout概率。
|
||||
self.dropout = args.dropout
|
||||
|
||||
# 检查是否使用Flash Attention(需要PyTorch >= 2.0)。
|
||||
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
|
||||
if not self.flash:
|
||||
# 若不支持Flash Attention,则使用手动实现的注意力机制,并设置mask。
|
||||
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
|
||||
# 创建一个上三角矩阵,用于遮蔽未来信息。
|
||||
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
|
||||
mask = torch.triu(mask, diagonal=1)
|
||||
# 注册为模型的缓冲区
|
||||
self.register_buffer("mask", mask)
|
||||
|
||||
def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):
|
||||
# 获取批次大小和序列长度,[batch_size, seq_len, dim]
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
# 计算查询(Q)、键(K)、值(V)。
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
# 调整形状以适应头的维度。
|
||||
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
|
||||
# 应用旋转位置嵌入(RoPE)。
|
||||
xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
|
||||
|
||||
# 对键和值进行扩展以适应重复次数。
|
||||
xk = repeat_kv(xk, self.n_rep)
|
||||
xv = repeat_kv(xv, self.n_rep)
|
||||
|
||||
# 将头作为批次维度处理。
|
||||
xq = xq.transpose(1, 2)
|
||||
xk = xk.transpose(1, 2)
|
||||
xv = xv.transpose(1, 2)
|
||||
|
||||
# 根据是否支持Flash Attention,选择实现方式。
|
||||
if self.flash:
|
||||
# 使用Flash Attention。
|
||||
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
|
||||
else:
|
||||
# 使用手动实现的注意力机制。
|
||||
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
assert hasattr(self, 'mask')
|
||||
scores = scores + self.mask[:, :, :seqlen, :seqlen]
|
||||
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
|
||||
scores = self.attn_dropout(scores)
|
||||
output = torch.matmul(scores, xv)
|
||||
|
||||
# 恢复时间维度并合并头。
|
||||
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
||||
|
||||
# 最终投影回残差流。
|
||||
output = self.wo(output)
|
||||
output = self.resid_dropout(output)
|
||||
return output
|
||||
```
|
||||
|
||||
同样大家可以使用下面的代码来对`LLaMA2Attention`模块进行测试,可以看到代码最终输出的形状为`torch.Size([1, 50, 288])`,与我们输入的形状一致,说明模块的实现是正确的。
|
||||
|
||||
```python
|
||||
# 创建Attention实例
|
||||
attention_model = LLaMA2Attention(args)
|
||||
|
||||
# 模拟输入数据
|
||||
batch_size = 1
|
||||
seq_len = 50 # 假设实际使用的序列长度为50
|
||||
dim = args.dim
|
||||
x = torch.rand(batch_size, seq_len, dim) # 随机生成输入张量
|
||||
# freqs_cos = torch.rand(seq_len, dim // 2) # 模拟cos频率,用于RoPE
|
||||
# freqs_sin = torch.rand(seq_len, dim // 2) # 模拟sin频率,用于RoPE
|
||||
|
||||
freqs_cos, freqs_sin = precompute_freqs_cis(dim//args.n_heads, seq_len)
|
||||
|
||||
# 运行Attention模型
|
||||
output = attention_model(x, freqs_cos, freqs_sin)
|
||||
|
||||
# attention出来之后的形状 依然是[batch_size, seq_len, dim]
|
||||
print("Output shape:", output.shape)
|
||||
```
|
||||
|
||||
OUT:
|
||||
```
|
||||
Output shape: torch.Size([1, 50, 288])
|
||||
```
|
||||
|
||||
## 5.1.4 构建 LLaMA2 MLP模块
|
||||
|
||||
相对于前面我们实现的LLaMA2 Attention模块,LLaMA2 MLP模块的实现要简单一些。我们可以通过如下代码实现`LLaMA2MLP`:
|
||||
|
||||
```python
|
||||
class LLaMA2MLP(nn.Module):
|
||||
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
|
||||
super().__init__()
|
||||
# 如果没有指定隐藏层的维度,我们将其设置为输入维度的4倍
|
||||
# 然后将其减少到2/3,最后确保它是multiple_of的倍数
|
||||
if hidden_dim is None:
|
||||
hidden_dim = 4 * dim
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
# 定义第一层线性变换,从输入维度到隐藏维度
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
# 定义第二层线性变换,从隐藏维度到输入维度
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
# 定义第三层线性变换,从输入维度到隐藏维度
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
# 定义dropout层,用于防止过拟合
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
# 前向传播函数
|
||||
# 首先,输入x通过第一层线性变换和SILU激活函数
|
||||
# 然后,结果乘以输入x通过第三层线性变换的结果
|
||||
# 最后,通过第二层线性变换和dropout层
|
||||
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
|
||||
```
|
||||
|
||||
我们着重观察一下`forward`函数的实现,首先,输入 `x` 通过第一层线性变换 `self.w1` 和 `SILU` 激活函数,然后,结果乘以输入 `x` 通过第三层线性变换 `self.w3` 的结果,最后,通过第二层线性变换 `self.w2` 和 `dropout` 层,得到最终输出。
|
||||
|
||||
同样大家可以使用下面的代码来对`LLaMAMLP`模块进行测试,可以看到代码最终输出的形状为`torch.Size([1, 50, 288])`,与我们输入的形状一致,说明模块的实现是正确的。
|
||||
|
||||
```python
|
||||
# 创建MLP实例
|
||||
mlp = LLaMA2MLP(args.dim, args.hidden_dim, args.multiple_of, args.dropout)
|
||||
# 随机生成数据
|
||||
x = torch.randn(1, 50, 288)
|
||||
# 运行MLP模型
|
||||
output = mlp(x)
|
||||
print(output.shape)
|
||||
```
|
||||
|
||||
OUT:
|
||||
```
|
||||
torch.Size([1, 50, 288])
|
||||
```
|
||||
|
||||
## 5.1.5 LLaMA2 Decoder Layer
|
||||
|
||||
到这里,我们已经实现了`LLaMA2`模型的`Attention`模块和`MLP`模块,接下来我们就可以构建`LLaMA2`的`Decoder Layer`了。
|
||||
|
||||
```python
|
||||
class LLaMA2DecoderLayer(nn.Module):
|
||||
def __init__(self, layer_id: int, args: ModelArgs):
|
||||
super().__init__()
|
||||
# 定义多头注意力的头数
|
||||
self.n_heads = args.n_heads
|
||||
# 定义输入维度
|
||||
self.dim = args.dim
|
||||
# 定义每个头的维度,等于输入维度除以头数
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
# 定义LLaMA2Attention对象,用于进行多头注意力计算
|
||||
self.attention = LLaMA2Attention(args)
|
||||
# 定义LLaMAMLP对象,用于进行前馈神经网络计算
|
||||
self.feed_forward = LLaMA2MLP(
|
||||
dim=args.dim,
|
||||
hidden_dim=args.hidden_dim,
|
||||
multiple_of=args.multiple_of,
|
||||
dropout=args.dropout,
|
||||
)
|
||||
# 定义层的ID
|
||||
self.layer_id = layer_id
|
||||
# 定义注意力计算的归一化层
|
||||
self.attention_norm = LLaMA2RMSNorm(args.dim, eps=args.norm_eps)
|
||||
# 定义前馈神经网络计算的归一化层
|
||||
self.ffn_norm = LLaMA2RMSNorm(args.dim, eps=args.norm_eps)
|
||||
|
||||
def forward(self, x, freqs_cos, freqs_sin):
|
||||
# 前向传播函数
|
||||
# 首先,输入x经过注意力归一化层,然后进行注意力计算,结果与输入x相加得到h
|
||||
# 然后,h经过前馈神经网络归一化层,然后进行前馈神经网络计算,结果与h相加得到输出
|
||||
h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
|
||||
out = h + self.feed_forward.forward(self.ffn_norm(h))
|
||||
return out
|
||||
```
|
||||
|
||||
`DecoderLayer`就是把我们上面完成的`Attention`模块和`MLP`模块组合在一起,实现了一个完整的`Transformer`模块。
|
||||
|
||||
同样大家可以使用下面的代码来对`LLaMA2DecoderLayer`模块进行测试,可以看到代码最终输出的形状为`torch.Size([1, 50, 288])`,与我们输入的形状一致,说明模块的实现是正确的。
|
||||
|
||||
```python
|
||||
# 创建LLaMADecoderLayer实例
|
||||
decoderlayer = LLaMA2DecoderLayer(0, args)
|
||||
|
||||
# 模拟输入数据
|
||||
dim = args.dim
|
||||
seq_len = 50
|
||||
|
||||
x = torch.randn(1, seq_len, dim) # [bs, seq_len, dim]
|
||||
|
||||
freqs_cos, freqs_sin = precompute_freqs_cis(dim//args.n_heads, seq_len)
|
||||
|
||||
out = decoderlayer(x, freqs_cos, freqs_sin)
|
||||
|
||||
print(out.shape) # 形状和输入的x一样 [batch_size, seq_len, dim]
|
||||
```
|
||||
|
||||
OUT:
|
||||
```
|
||||
torch.Size([1, 50, 288])
|
||||
```
|
||||
|
||||
## 5.1.6 构建 LLaMA2 模型
|
||||
|
||||
好了,我们已经完了上述所有的模块的实现,接下来就是激动人心的时刻,我们可以构建`LLaMA2`模型了。,`LLaMA2`模型就是将`LLaMA2DecoderLayer`模块堆叠起来,构成一个完整的`Transformer`模型。
|
||||
|
||||
```python
|
||||
class LLaMA2Model(nn.Module):
|
||||
last_loss: Optional[torch.Tensor]
|
||||
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
# 初始化模型参数
|
||||
self.args = args
|
||||
# 词汇表大小
|
||||
self.vocab_size = args.vocab_size
|
||||
# 层数
|
||||
self.n_layers = args.n_layers
|
||||
|
||||
# 词嵌入层
|
||||
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
|
||||
# Dropout层
|
||||
self.dropout = nn.Dropout(args.dropout)
|
||||
# Decoder层
|
||||
self.layers = torch.nn.ModuleList()
|
||||
for layer_id in range(args.n_layers):
|
||||
self.layers.append(LLaMA2DecoderLayer(layer_id, args))
|
||||
# 归一化层
|
||||
self.norm = LLaMA2RMSNorm(args.dim, eps=args.norm_eps)
|
||||
# 输出层
|
||||
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
|
||||
|
||||
# 将词嵌入层的权重与输出层的权重共享
|
||||
self.tok_embeddings.weight = self.output.weight
|
||||
|
||||
# 预计算相对位置嵌入的频率
|
||||
freqs_cos, freqs_sin = precompute_freqs_cis(self.args.dim // self.args.n_heads, self.args.max_seq_len)
|
||||
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
|
||||
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
|
||||
|
||||
# 初始化所有权重
|
||||
self.apply(self._init_weights)
|
||||
# 对残差投影进行特殊的缩放初始化
|
||||
for pn, p in self.named_parameters():
|
||||
if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
|
||||
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * args.n_layers))
|
||||
|
||||
# 初始化最后一次前向传播的损失属性
|
||||
self.last_loss = None
|
||||
|
||||
def _init_weights(self, module):
|
||||
# 初始化权重的函数
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
if module.bias is not None:
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
# 前向传播函数
|
||||
_bsz, seqlen = tokens.shape
|
||||
# 通过词嵌入层和Dropout层
|
||||
h = self.tok_embeddings(tokens)
|
||||
h = self.dropout(h)
|
||||
# 获取相对位置嵌入的频率
|
||||
freqs_cos = self.freqs_cos[:seqlen]
|
||||
freqs_sin = self.freqs_sin[:seqlen]
|
||||
|
||||
# 通过Decoder层
|
||||
for layer in self.layers:
|
||||
h = layer(h, freqs_cos, freqs_sin)
|
||||
# 通过归一化层
|
||||
h = self.norm(h)
|
||||
|
||||
if targets is not None:
|
||||
# 如果给定了目标,计算损失
|
||||
logits = self.output(h)
|
||||
self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
|
||||
else:
|
||||
# 推理时的小优化:只对最后一个位置的输出进行前向传播
|
||||
logits = self.output(h[:, [-1], :])
|
||||
self.last_loss = None
|
||||
|
||||
return logits
|
||||
```
|
||||
|
||||
同样大家可以使用下面的代码来对`LLaMA2Model`模块进行测试,可以看到代码最终输出的形状为`torch.Size([1, 50, 32000])`,与我们输入的形状一致,说明模块的实现是正确的。
|
||||
|
||||
```python
|
||||
# LLaMA2Model.forward 接受两个参数,tokens和targets,其中tokens是输入的张量, 应为int类型
|
||||
x = torch.randint(0, 32000, (1, 50)) # [bs, seq_len]
|
||||
# 实例化LLaMA2Model
|
||||
model = LLaMA2Model(args=args)
|
||||
# 计算model的全部参数
|
||||
num_params = sum(p.numel() for p in model.parameters())
|
||||
print('Number of parameters:', num_params)
|
||||
|
||||
out = model(x)
|
||||
print(out.shape) # [batch_size, 1, vocab_size]
|
||||
```
|
||||
|
||||
OUT:
|
||||
```
|
||||
Number of parameters: 15191712
|
||||
torch.Size([1, 1, 32000])
|
||||
```
|
||||
|
||||
641
docs/chapter5/llama2.ipynb
Normal file
641
docs/chapter5/llama2.ipynb
Normal file
@@ -0,0 +1,641 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import math\n",
|
||||
"import struct\n",
|
||||
"import inspect\n",
|
||||
"from dataclasses import dataclass\n",
|
||||
"from typing import Any, Optional, Tuple\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"from torch import nn"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class ModelArgs:\n",
|
||||
" # 自定义超参数\n",
|
||||
" dim: int = 288 # 模型维度\n",
|
||||
" n_layers: int = 6 # Transformer层数\n",
|
||||
" n_heads: int = 6 # 注意力机制的头数\n",
|
||||
" n_kv_heads: Optional[int] = 6 # 键/值头数,如果未指定,则默认为n_heads\n",
|
||||
" vocab_size: int = 32000 # 词汇表大小\n",
|
||||
" hidden_dim: Optional[int] = None # 隐藏层维度,如果未指定,则使用其他规则确定\n",
|
||||
" multiple_of: int = 32 # MLP隐藏层大小是这个数的倍数\n",
|
||||
" norm_eps: float = 1e-5 # 归一化层的epsilon值\n",
|
||||
" max_seq_len: int = 256 # 最大序列长度\n",
|
||||
" dropout: float = 0.0 # 丢弃率\n",
|
||||
"\n",
|
||||
"args = ModelArgs()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class LLaMA2RMSNorm(nn.Module):\n",
|
||||
" def __init__(self, dim: int, eps: float):\n",
|
||||
" super().__init__()\n",
|
||||
" # eps是为了防止除以0的情况\n",
|
||||
" self.eps = eps\n",
|
||||
" # weight是一个可学习的参数,全部初始化为1\n",
|
||||
" self.weight = nn.Parameter(torch.ones(dim))\n",
|
||||
"\n",
|
||||
" def _norm(self, x):\n",
|
||||
" # 计算RMSNorm的核心部分\n",
|
||||
" # x.pow(2).mean(-1, keepdim=True)计算了输入x的平方的均值\n",
|
||||
" # torch.rsqrt是平方根的倒数,这样就得到了RMSNorm的分母部分,再加上eps防止分母为0\n",
|
||||
" # 最后乘以x,得到RMSNorm的结果\n",
|
||||
" return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" # forward函数是模型的前向传播\n",
|
||||
" # 首先将输入x转为float类型,然后进行RMSNorm,最后再转回原来的数据类型\n",
|
||||
" # 最后乘以weight,这是RMSNorm的一个可学习的缩放因子\n",
|
||||
" output = self._norm(x.float()).type_as(x)\n",
|
||||
" return output * self.weight"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([1, 50, 288])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"norm = LLaMA2RMSNorm(args.dim, args.norm_eps)\n",
|
||||
"x = torch.randn(1, 50, args.dim)\n",
|
||||
"output = norm(x)\n",
|
||||
"print(output.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 获得旋转嵌入的实部和虚部\n",
|
||||
"# 注意:此处的dim应为 dim//n_head,因为我们是对每个head进行旋转嵌入\n",
|
||||
"def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):\n",
|
||||
" # torch.arange(0, dim, 2)[: (dim // 2)].float()生成了一个从0开始,步长为2的序列,长度为dim的一半\n",
|
||||
" # 然后每个元素除以dim,再取theta的倒数,得到频率\n",
|
||||
" freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))\n",
|
||||
" # 生成一个从0到end的序列,长度为end\n",
|
||||
" t = torch.arange(end, device=freqs.device)\n",
|
||||
" # 计算外积,得到一个二维矩阵,每一行是t的元素乘以freqs的元素\n",
|
||||
" freqs = torch.outer(t, freqs).float()\n",
|
||||
" # 计算频率的余弦值,得到实部\n",
|
||||
" freqs_cos = torch.cos(freqs)\n",
|
||||
" # 计算频率的正弦值,得到虚部\n",
|
||||
" freqs_sin = torch.sin(freqs)\n",
|
||||
" return freqs_cos, freqs_sin"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([50, 24]) torch.Size([50, 24])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"x = torch.randn(1, 50, 288)\n",
|
||||
"freqs_cos, freqs_sin = precompute_freqs_cis(288//6, 50)\n",
|
||||
"print(freqs_cos.shape, freqs_sin.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 此函数的作用是将freqs_cis调整为与x的形状相同,以便能够与x进行广播操作\n",
|
||||
"def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):\n",
|
||||
" # 获取x的维度数\n",
|
||||
" ndim = x.ndim\n",
|
||||
" # 断言,确保1在x的维度范围内\n",
|
||||
" assert 0 <= 1 < ndim\n",
|
||||
" # 断言,确保freqs_cis的形状与x的第二维和最后一维相同\n",
|
||||
" assert freqs_cis.shape == (x.shape[1], x.shape[-1])\n",
|
||||
" # 构造一个新的形状,除了第二维和最后一维,其他维度都为1,这样做是为了能够将freqs_cis与x进行广播操作\n",
|
||||
" shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]\n",
|
||||
" # 将freqs_cis调整为新的形状,并返回\n",
|
||||
" return freqs_cis.view(shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def apply_rotary_emb(\n",
|
||||
" xq: torch.Tensor,\n",
|
||||
" xk: torch.Tensor,\n",
|
||||
" freqs_cos: torch.Tensor,\n",
|
||||
" freqs_sin: torch.Tensor\n",
|
||||
") -> Tuple[torch.Tensor, torch.Tensor]:\n",
|
||||
"\n",
|
||||
" # 将查询和键张量转换为浮点数,并重塑形状以分离实部和虚部\n",
|
||||
" xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)\n",
|
||||
" xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)\n",
|
||||
"\n",
|
||||
" # 重新塑形频率张量以进行广播\n",
|
||||
" freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)\n",
|
||||
" freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)\n",
|
||||
"\n",
|
||||
" # 应用旋转,分别计算旋转后的实部和虚部\n",
|
||||
" xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin\n",
|
||||
" xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos\n",
|
||||
" xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin\n",
|
||||
" xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos\n",
|
||||
"\n",
|
||||
" # 将最后两个维度合并,并还原为原始张量的形状\n",
|
||||
" xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)\n",
|
||||
" xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)\n",
|
||||
"\n",
|
||||
" return xq_out.type_as(xq), xk_out.type_as(xk)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([50, 24]) torch.Size([50, 24])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(torch.Size([1, 50, 6, 48]), torch.Size([1, 50, 6, 48]))"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"xq = torch.randn(1, 50, 6, 48) # bs, seq_len, dim//n_head, n_head_dim\n",
|
||||
"xk = torch.randn(1, 50, 6, 48) # bs, seq_len, dim//n_head, n_head_dim\n",
|
||||
"\n",
|
||||
"# 使用 precompute_freqs_cis 函数获取 sin和cos\n",
|
||||
"cos, sin = precompute_freqs_cis(288//6, 50)\n",
|
||||
"print(cos.shape, sin.shape)\n",
|
||||
"xq_out, xk_out = apply_rotary_emb(xq, xk, cos, sin)\n",
|
||||
"\n",
|
||||
"xq_out.shape, xk_out.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:\n",
|
||||
" # 获取输入张量的形状:批量大小、序列长度、键/值对头的数量、每个头的维度大小\n",
|
||||
" bs, slen, n_kv_heads, head_dim = x.shape\n",
|
||||
" \n",
|
||||
" # 如果重复次数为1,则不需要重复,直接返回原始张量\n",
|
||||
" if n_rep == 1:\n",
|
||||
" return x\n",
|
||||
" \n",
|
||||
" # 对张量进行扩展和重塑操作以重复键值对\n",
|
||||
" return (\n",
|
||||
" x[:, :, :, None, :] # 在第四个维度(头的维度前)添加一个新的维度\n",
|
||||
" .expand(bs, slen, n_kv_heads, n_rep, head_dim) # 将新添加的维度扩展到n_rep大小,实现重复的效果\n",
|
||||
" .reshape(bs, slen, n_kv_heads * n_rep, head_dim) # 重新塑形,合并键/值对头的数量和重复次数的维度\n",
|
||||
" )\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class LLaMA2Attention(nn.Module):\n",
|
||||
" def __init__(self, args: ModelArgs):\n",
|
||||
" super().__init__()\n",
|
||||
" # 根据是否指定n_kv_heads,确定用于键(key)和值(value)的头的数量。\n",
|
||||
" self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads\n",
|
||||
" # 确保总头数可以被键值头数整除。\n",
|
||||
" assert args.n_heads % self.n_kv_heads == 0\n",
|
||||
"\n",
|
||||
" # 模型并行处理大小,默认为1。\n",
|
||||
" model_parallel_size = 1\n",
|
||||
" # 本地计算头数,等于总头数除以模型并行处理大小。\n",
|
||||
" self.n_local_heads = args.n_heads // model_parallel_size\n",
|
||||
" # 本地键值头数,等于键值头数除以模型并行处理大小。\n",
|
||||
" self.n_local_kv_heads = self.n_kv_heads // model_parallel_size\n",
|
||||
" # 重复次数,用于扩展键和值的尺寸。\n",
|
||||
" self.n_rep = self.n_local_heads // self.n_local_kv_heads\n",
|
||||
" # 每个头的维度,等于模型维度除以头的总数。\n",
|
||||
" self.head_dim = args.dim // args.n_heads\n",
|
||||
"\n",
|
||||
" # 定义权重矩阵。\n",
|
||||
" self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)\n",
|
||||
" self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)\n",
|
||||
" self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)\n",
|
||||
" # 输出权重矩阵。\n",
|
||||
" self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)\n",
|
||||
"\n",
|
||||
" # 定义dropout。\n",
|
||||
" self.attn_dropout = nn.Dropout(args.dropout)\n",
|
||||
" self.resid_dropout = nn.Dropout(args.dropout)\n",
|
||||
" # 保存dropout概率。\n",
|
||||
" self.dropout = args.dropout\n",
|
||||
"\n",
|
||||
" # 检查是否使用Flash Attention(需要PyTorch >= 2.0)。\n",
|
||||
" self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')\n",
|
||||
" if not self.flash:\n",
|
||||
" # 若不支持Flash Attention,则使用手动实现的注意力机制,并设置mask。\n",
|
||||
" print(\"WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0\")\n",
|
||||
" # 创建一个上三角矩阵,用于遮蔽未来信息。\n",
|
||||
" mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float(\"-inf\"))\n",
|
||||
" mask = torch.triu(mask, diagonal=1)\n",
|
||||
" # 注册为模型的缓冲区\n",
|
||||
" self.register_buffer(\"mask\", mask)\n",
|
||||
"\n",
|
||||
" def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):\n",
|
||||
" # 获取批次大小和序列长度,[batch_size, seq_len, dim]\n",
|
||||
" bsz, seqlen, _ = x.shape\n",
|
||||
"\n",
|
||||
" # 计算查询(Q)、键(K)、值(V)。\n",
|
||||
" xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)\n",
|
||||
" # 调整形状以适应头的维度。\n",
|
||||
" xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)\n",
|
||||
" xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)\n",
|
||||
" xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)\n",
|
||||
"\n",
|
||||
" # 应用旋转位置嵌入(RoPE)。\n",
|
||||
" xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)\n",
|
||||
"\n",
|
||||
" # 对键和值进行扩展以适应重复次数。\n",
|
||||
" xk = repeat_kv(xk, self.n_rep)\n",
|
||||
" xv = repeat_kv(xv, self.n_rep)\n",
|
||||
"\n",
|
||||
" # 将头作为批次维度处理。\n",
|
||||
" xq = xq.transpose(1, 2)\n",
|
||||
" xk = xk.transpose(1, 2)\n",
|
||||
" xv = xv.transpose(1, 2)\n",
|
||||
"\n",
|
||||
" # 根据是否支持Flash Attention,选择实现方式。\n",
|
||||
" if self.flash:\n",
|
||||
" # 使用Flash Attention。\n",
|
||||
" output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)\n",
|
||||
" else:\n",
|
||||
" # 使用手动实现的注意力机制。\n",
|
||||
" scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)\n",
|
||||
" assert hasattr(self, 'mask')\n",
|
||||
" scores = scores + self.mask[:, :, :seqlen, :seqlen]\n",
|
||||
" scores = F.softmax(scores.float(), dim=-1).type_as(xq)\n",
|
||||
" scores = self.attn_dropout(scores)\n",
|
||||
" output = torch.matmul(scores, xv)\n",
|
||||
"\n",
|
||||
" # 恢复时间维度并合并头。\n",
|
||||
" output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)\n",
|
||||
"\n",
|
||||
" # 最终投影回残差流。\n",
|
||||
" output = self.wo(output)\n",
|
||||
" output = self.resid_dropout(output)\n",
|
||||
" return output"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([50, 24]) torch.Size([50, 24])\n",
|
||||
"Output shape: torch.Size([1, 50, 288])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 创建Attention实例\n",
|
||||
"attention_model = LLaMA2Attention(args)\n",
|
||||
"\n",
|
||||
"# 模拟输入数据\n",
|
||||
"batch_size = 1\n",
|
||||
"seq_len = 50 # 假设实际使用的序列长度为50\n",
|
||||
"dim = args.dim\n",
|
||||
"x = torch.rand(batch_size, seq_len, dim) # 随机生成输入张量\n",
|
||||
"# freqs_cos = torch.rand(seq_len, dim // 2) # 模拟cos频率,用于RoPE\n",
|
||||
"# freqs_sin = torch.rand(seq_len, dim // 2) # 模拟sin频率,用于RoPE\n",
|
||||
"\n",
|
||||
"freqs_cos, freqs_sin = precompute_freqs_cis(dim//args.n_heads, seq_len)\n",
|
||||
"\n",
|
||||
"print(freqs_cos.shape, freqs_sin.shape)\n",
|
||||
"\n",
|
||||
"# 运行Attention模型\n",
|
||||
"output = attention_model(x, freqs_cos, freqs_sin)\n",
|
||||
"\n",
|
||||
"# attention出来之后的形状 依然是[batch_size, seq_len, dim]\n",
|
||||
"print(\"Output shape:\", output.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class LLaMA2MLP(nn.Module):\n",
|
||||
" def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):\n",
|
||||
" super().__init__()\n",
|
||||
" # 如果没有指定隐藏层的维度,我们将其设置为输入维度的4倍\n",
|
||||
" # 然后将其减少到2/3,最后确保它是multiple_of的倍数\n",
|
||||
" if hidden_dim is None:\n",
|
||||
" hidden_dim = 4 * dim\n",
|
||||
" hidden_dim = int(2 * hidden_dim / 3)\n",
|
||||
" hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)\n",
|
||||
" # 定义第一层线性变换,从输入维度到隐藏维度\n",
|
||||
" self.w1 = nn.Linear(dim, hidden_dim, bias=False)\n",
|
||||
" # 定义第二层线性变换,从隐藏维度到输入维度\n",
|
||||
" self.w2 = nn.Linear(hidden_dim, dim, bias=False)\n",
|
||||
" # 定义第三层线性变换,从输入维度到隐藏维度\n",
|
||||
" self.w3 = nn.Linear(dim, hidden_dim, bias=False)\n",
|
||||
" # 定义dropout层,用于防止过拟合\n",
|
||||
" self.dropout = nn.Dropout(dropout)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" # 前向传播函数\n",
|
||||
" # 首先,输入x通过第一层线性变换和SILU激活函数\n",
|
||||
" # 然后,结果乘以输入x通过第三层线性变换的结果\n",
|
||||
" # 最后,通过第二层线性变换和dropout层\n",
|
||||
" return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([1, 50, 288])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 创建MLP实例\n",
|
||||
"mlp = LLaMAMLP(args.dim, args.hidden_dim, args.multiple_of, args.dropout)\n",
|
||||
"# 随机生成数据\n",
|
||||
"x = torch.randn(1, 50, 288)\n",
|
||||
"# 运行MLP模型\n",
|
||||
"output = mlp(x)\n",
|
||||
"print(output.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class LLaMA2DecoderLayer(nn.Module):\n",
|
||||
" def __init__(self, layer_id: int, args: ModelArgs):\n",
|
||||
" super().__init__()\n",
|
||||
" # 定义多头注意力的头数\n",
|
||||
" self.n_heads = args.n_heads\n",
|
||||
" # 定义输入维度\n",
|
||||
" self.dim = args.dim\n",
|
||||
" # 定义每个头的维度,等于输入维度除以头数\n",
|
||||
" self.head_dim = args.dim // args.n_heads\n",
|
||||
" # 定义LLaMA2Attention对象,用于进行多头注意力计算\n",
|
||||
" self.attention = LLaMA2Attention(args)\n",
|
||||
" # 定义LLaMAMLP对象,用于进行前馈神经网络计算\n",
|
||||
" self.feed_forward = LLaMA2MLP(\n",
|
||||
" dim=args.dim,\n",
|
||||
" hidden_dim=args.hidden_dim,\n",
|
||||
" multiple_of=args.multiple_of,\n",
|
||||
" dropout=args.dropout,\n",
|
||||
" )\n",
|
||||
" # 定义层的ID\n",
|
||||
" self.layer_id = layer_id\n",
|
||||
" # 定义注意力计算的归一化层\n",
|
||||
" self.attention_norm = LLaMA2RMSNorm(args.dim, eps=args.norm_eps)\n",
|
||||
" # 定义前馈神经网络计算的归一化层\n",
|
||||
" self.ffn_norm = LLaMA2RMSNorm(args.dim, eps=args.norm_eps)\n",
|
||||
"\n",
|
||||
" def forward(self, x, freqs_cos, freqs_sin):\n",
|
||||
" # 前向传播函数\n",
|
||||
" # 首先,输入x经过注意力归一化层,然后进行注意力计算,结果与输入x相加得到h\n",
|
||||
" # 然后,h经过前馈神经网络归一化层,然后进行前馈神经网络计算,结果与h相加得到输出\n",
|
||||
" h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)\n",
|
||||
" out = h + self.feed_forward.forward(self.ffn_norm(h))\n",
|
||||
" return out"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([1, 50, 288]) torch.Size([50, 24]) torch.Size([50, 24])\n",
|
||||
"torch.Size([1, 50, 288])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# LLaMADecoderLayer.forward 函数的输入是 x, freqs_cos, freqs_sin, 其中x的形状是[batch_size, seq_len, dim]\n",
|
||||
"# 由于llama2使用了GQA Attention,所以precompute_freqs_cis函数输入参数应该为dim//n_heads,seq_len、\n",
|
||||
"\n",
|
||||
"# 创建LLaMADecoderLayer实例\n",
|
||||
"decoderlayer = LLaMA2DecoderLayer(0, args)\n",
|
||||
"\n",
|
||||
"# 模拟输入数据\n",
|
||||
"dim = args.dim\n",
|
||||
"seq_len = 50\n",
|
||||
"\n",
|
||||
"x = torch.randn(1, seq_len, dim) # [bs, seq_len, dim]\n",
|
||||
"\n",
|
||||
"freqs_cos, freqs_sin = precompute_freqs_cis(dim//args.n_heads, seq_len)\n",
|
||||
"print(x.shape, freqs_cos.shape, freqs_sin.shape)\n",
|
||||
"\n",
|
||||
"out = decoderlayer(x, freqs_cos, freqs_sin)\n",
|
||||
"\n",
|
||||
"print(out.shape) # 形状和输入的x一样 [batch_size, seq_len, dim]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class LLaMA2Model(nn.Module):\n",
|
||||
" last_loss: Optional[torch.Tensor]\n",
|
||||
"\n",
|
||||
" def __init__(self, args: ModelArgs):\n",
|
||||
" super().__init__()\n",
|
||||
" # 初始化模型参数\n",
|
||||
" self.args = args\n",
|
||||
" # 词汇表大小\n",
|
||||
" self.vocab_size = args.vocab_size\n",
|
||||
" # 层数\n",
|
||||
" self.n_layers = args.n_layers\n",
|
||||
"\n",
|
||||
" # 词嵌入层\n",
|
||||
" self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)\n",
|
||||
" # Dropout层\n",
|
||||
" self.dropout = nn.Dropout(args.dropout)\n",
|
||||
" # Decoder层\n",
|
||||
" self.layers = torch.nn.ModuleList()\n",
|
||||
" for layer_id in range(args.n_layers):\n",
|
||||
" self.layers.append(LLaMA2DecoderLayer(layer_id, args))\n",
|
||||
" # 归一化层\n",
|
||||
" self.norm = LLaMA2RMSNorm(args.dim, eps=args.norm_eps)\n",
|
||||
" # 输出层\n",
|
||||
" self.output = nn.Linear(args.dim, args.vocab_size, bias=False)\n",
|
||||
"\n",
|
||||
" # 将词嵌入层的权重与输出层的权重共享\n",
|
||||
" self.tok_embeddings.weight = self.output.weight \n",
|
||||
"\n",
|
||||
" # 预计算相对位置嵌入的频率\n",
|
||||
" freqs_cos, freqs_sin = precompute_freqs_cis(self.args.dim // self.args.n_heads, self.args.max_seq_len)\n",
|
||||
" self.register_buffer(\"freqs_cos\", freqs_cos, persistent=False)\n",
|
||||
" self.register_buffer(\"freqs_sin\", freqs_sin, persistent=False)\n",
|
||||
"\n",
|
||||
" # 初始化所有权重\n",
|
||||
" self.apply(self._init_weights)\n",
|
||||
" # 对残差投影进行特殊的缩放初始化\n",
|
||||
" for pn, p in self.named_parameters():\n",
|
||||
" if pn.endswith('w3.weight') or pn.endswith('wo.weight'):\n",
|
||||
" torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * args.n_layers))\n",
|
||||
"\n",
|
||||
" # 初始化最后一次前向传播的损失属性\n",
|
||||
" self.last_loss = None\n",
|
||||
"\n",
|
||||
" def _init_weights(self, module):\n",
|
||||
" # 初始化权重的函数\n",
|
||||
" if isinstance(module, nn.Linear):\n",
|
||||
" torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
|
||||
" if module.bias is not None:\n",
|
||||
" torch.nn.init.zeros_(module.bias)\n",
|
||||
" elif isinstance(module, nn.Embedding):\n",
|
||||
" torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
|
||||
" \n",
|
||||
" def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:\n",
|
||||
" # 前向传播函数\n",
|
||||
" _bsz, seqlen = tokens.shape\n",
|
||||
" # 通过词嵌入层和Dropout层\n",
|
||||
" h = self.tok_embeddings(tokens)\n",
|
||||
" h = self.dropout(h)\n",
|
||||
" # 获取相对位置嵌入的频率\n",
|
||||
" freqs_cos = self.freqs_cos[:seqlen]\n",
|
||||
" freqs_sin = self.freqs_sin[:seqlen]\n",
|
||||
"\n",
|
||||
" # 通过Decoder层\n",
|
||||
" for layer in self.layers:\n",
|
||||
" h = layer(h, freqs_cos, freqs_sin)\n",
|
||||
" # 通过归一化层\n",
|
||||
" h = self.norm(h)\n",
|
||||
"\n",
|
||||
" if targets is not None:\n",
|
||||
" # 如果给定了目标,计算损失\n",
|
||||
" logits = self.output(h)\n",
|
||||
" self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)\n",
|
||||
" else:\n",
|
||||
" # 推理时的小优化:只对最后一个位置的输出进行前向传播\n",
|
||||
" logits = self.output(h[:, [-1], :]) \n",
|
||||
" self.last_loss = None\n",
|
||||
"\n",
|
||||
" return logits"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Number of parameters: 15191712\n",
|
||||
"torch.Size([1, 1, 32000])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# LLaMA2Model.forward 接受两个参数,tokens和targets,其中tokens是输入的张量, 应为int类型\n",
|
||||
"x = torch.randint(0, 32000, (1, 50)) # [bs, seq_len]\n",
|
||||
"# 实例化LLaMA2Model\n",
|
||||
"model = LLaMA2Model(args=args)\n",
|
||||
"# 计算model的全部参数\n",
|
||||
"num_params = sum(p.numel() for p in model.parameters())\n",
|
||||
"print('Number of parameters:', num_params)\n",
|
||||
"\n",
|
||||
"out = model(x)\n",
|
||||
"print(out.shape) # [batch_size, 1, vocab_size]"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "nlp",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
367
docs/chapter5/llama2_model.py
Normal file
367
docs/chapter5/llama2_model.py
Normal file
@@ -0,0 +1,367 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
@File : llama2_model.py
|
||||
@Time : 2024/04/14 22:26:35
|
||||
@Author : 不要葱姜蒜
|
||||
@Version : 1.0
|
||||
@Desc : 部分代码借鉴llama2.c仓库代码
|
||||
'''
|
||||
|
||||
import math
|
||||
import struct
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
# 自定义超参数
|
||||
dim: int = 288 # 模型维度
|
||||
n_layers: int = 6 # Transformer层数
|
||||
n_heads: int = 6 # 注意力机制的头数
|
||||
n_kv_heads: Optional[int] = 6 # 键/值头数,如果未指定,则默认为n_heads
|
||||
vocab_size: int = 32000 # 词汇表大小
|
||||
hidden_dim: Optional[int] = None # 隐藏层维度,如果未指定,则使用其他规则确定
|
||||
multiple_of: int = 32 # MLP隐藏层大小是这个数的倍数
|
||||
norm_eps: float = 1e-5 # 归一化层的epsilon值
|
||||
max_seq_len: int = 256 # 最大序列长度
|
||||
dropout: float = 0.0 # 丢弃率
|
||||
|
||||
|
||||
class LLaMA2RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float):
|
||||
super().__init__()
|
||||
# eps是为了防止除以0的情况
|
||||
self.eps = eps
|
||||
# weight是一个可学习的参数,全部初始化为1
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x):
|
||||
# 计算RMSNorm的核心部分
|
||||
# x.pow(2).mean(-1, keepdim=True)计算了输入x的平方的均值
|
||||
# torch.rsqrt是平方根的倒数,这样就得到了RMSNorm的分母部分,再加上eps防止分母为0
|
||||
# 最后乘以x,得到RMSNorm的结果
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
# forward函数是模型的前向传播
|
||||
# 首先将输入x转为float类型,然后进行RMSNorm,最后再转回原来的数据类型
|
||||
# 最后乘以weight,这是RMSNorm的一个可学习的缩放因子
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
return output * self.weight
|
||||
|
||||
# 获得旋转嵌入的实部和虚部
|
||||
# 注意:此处的dim应为 dim//n_head,因为我们是对每个head进行旋转嵌入
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
|
||||
# torch.arange(0, dim, 2)[: (dim // 2)].float()生成了一个从0开始,步长为2的序列,长度为dim的一半
|
||||
# 然后每个元素除以dim,再取theta的倒数,得到频率
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
# 生成一个从0到end的序列,长度为end
|
||||
t = torch.arange(end, device=freqs.device)
|
||||
# 计算外积,得到一个二维矩阵,每一行是t的元素乘以freqs的元素
|
||||
freqs = torch.outer(t, freqs).float()
|
||||
# 计算频率的余弦值,得到实部
|
||||
freqs_cos = torch.cos(freqs)
|
||||
# 计算频率的正弦值,得到虚部
|
||||
freqs_sin = torch.sin(freqs)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
# 此函数的作用是将freqs_cis调整为与x的形状相同,以便能够与x进行广播操作
|
||||
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
||||
# 获取x的维度数
|
||||
ndim = x.ndim
|
||||
# 断言,确保1在x的维度范围内
|
||||
assert 0 <= 1 < ndim
|
||||
# 断言,确保freqs_cis的形状与x的第二维和最后一维相同
|
||||
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
||||
# 构造一个新的形状,除了第二维和最后一维,其他维度都为1,这样做是为了能够将freqs_cis与x进行广播操作
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
# 将freqs_cis调整为新的形状,并返回
|
||||
return freqs_cis.view(shape)
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cos: torch.Tensor,
|
||||
freqs_sin: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
# 将查询和键张量转换为浮点数,并重塑形状以分离实部和虚部
|
||||
xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
|
||||
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
|
||||
|
||||
# 重新塑形频率张量以进行广播
|
||||
freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
|
||||
freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
|
||||
|
||||
# 应用旋转,分别计算旋转后的实部和虚部
|
||||
xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
|
||||
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
|
||||
xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
|
||||
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
|
||||
|
||||
# 将最后两个维度合并,并还原为原始张量的形状
|
||||
xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
|
||||
xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
|
||||
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
# 获取输入张量的形状:批量大小、序列长度、键/值对头的数量、每个头的维度大小
|
||||
bs, slen, n_kv_heads, head_dim = x.shape
|
||||
|
||||
# 如果重复次数为1,则不需要重复,直接返回原始张量
|
||||
if n_rep == 1:
|
||||
return x
|
||||
|
||||
# 对张量进行扩展和重塑操作以重复键值对
|
||||
return (
|
||||
x[:, :, :, None, :] # 在第四个维度(头的维度前)添加一个新的维度
|
||||
.expand(bs, slen, n_kv_heads, n_rep, head_dim) # 将新添加的维度扩展到n_rep大小,实现重复的效果
|
||||
.reshape(bs, slen, n_kv_heads * n_rep, head_dim) # 重新塑形,合并键/值对头的数量和重复次数的维度
|
||||
)
|
||||
|
||||
class LLaMA2Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
# 根据是否指定n_kv_heads,确定用于键(key)和值(value)的头的数量。
|
||||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||
# 确保总头数可以被键值头数整除。
|
||||
assert args.n_heads % self.n_kv_heads == 0
|
||||
|
||||
# 模型并行处理大小,默认为1。
|
||||
model_parallel_size = 1
|
||||
# 本地计算头数,等于总头数除以模型并行处理大小。
|
||||
self.n_local_heads = args.n_heads // model_parallel_size
|
||||
# 本地键值头数,等于键值头数除以模型并行处理大小。
|
||||
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
|
||||
# 重复次数,用于扩展键和值的尺寸。
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
# 每个头的维度,等于模型维度除以头的总数。
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
|
||||
# 定义权重矩阵。
|
||||
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
|
||||
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||
# 输出权重矩阵。
|
||||
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
|
||||
|
||||
# 定义dropout。
|
||||
self.attn_dropout = nn.Dropout(args.dropout)
|
||||
self.resid_dropout = nn.Dropout(args.dropout)
|
||||
# 保存dropout概率。
|
||||
self.dropout = args.dropout
|
||||
|
||||
# 检查是否使用Flash Attention(需要PyTorch >= 2.0)。
|
||||
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
|
||||
if not self.flash:
|
||||
# 若不支持Flash Attention,则使用手动实现的注意力机制,并设置mask。
|
||||
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
|
||||
# 创建一个上三角矩阵,用于遮蔽未来信息。
|
||||
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
|
||||
mask = torch.triu(mask, diagonal=1)
|
||||
# 注册为模型的缓冲区
|
||||
self.register_buffer("mask", mask)
|
||||
|
||||
def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):
|
||||
# 获取批次大小和序列长度,[batch_size, seq_len, dim]
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
# 计算查询(Q)、键(K)、值(V)。
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
# 调整形状以适应头的维度。
|
||||
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
|
||||
# 应用旋转位置嵌入(RoPE)。
|
||||
xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
|
||||
|
||||
# 对键和值进行扩展以适应重复次数。
|
||||
xk = repeat_kv(xk, self.n_rep)
|
||||
xv = repeat_kv(xv, self.n_rep)
|
||||
|
||||
# 将头作为批次维度处理。
|
||||
xq = xq.transpose(1, 2)
|
||||
xk = xk.transpose(1, 2)
|
||||
xv = xv.transpose(1, 2)
|
||||
|
||||
# 根据是否支持Flash Attention,选择实现方式。
|
||||
if self.flash:
|
||||
# 使用Flash Attention。
|
||||
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
|
||||
else:
|
||||
# 使用手动实现的注意力机制。
|
||||
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
assert hasattr(self, 'mask')
|
||||
scores = scores + self.mask[:, :, :seqlen, :seqlen]
|
||||
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
|
||||
scores = self.attn_dropout(scores)
|
||||
output = torch.matmul(scores, xv)
|
||||
|
||||
# 恢复时间维度并合并头。
|
||||
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
||||
|
||||
# 最终投影回残差流。
|
||||
output = self.wo(output)
|
||||
output = self.resid_dropout(output)
|
||||
return output
|
||||
|
||||
class LLaMA2MLP(nn.Module):
|
||||
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
|
||||
super().__init__()
|
||||
# 如果没有指定隐藏层的维度,我们将其设置为输入维度的4倍
|
||||
# 然后将其减少到2/3,最后确保它是multiple_of的倍数
|
||||
if hidden_dim is None:
|
||||
hidden_dim = 4 * dim
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
# 定义第一层线性变换,从输入维度到隐藏维度
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
# 定义第二层线性变换,从隐藏维度到输入维度
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
# 定义第三层线性变换,从输入维度到隐藏维度
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
# 定义dropout层,用于防止过拟合
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
# 前向传播函数
|
||||
# 首先,输入x通过第一层线性变换和SILU激活函数
|
||||
# 然后,结果乘以输入x通过第三层线性变换的结果
|
||||
# 最后,通过第二层线性变换和dropout层
|
||||
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
|
||||
|
||||
|
||||
class LLaMA2DecoderLayer(nn.Module):
|
||||
def __init__(self, layer_id: int, args: ModelArgs):
|
||||
super().__init__()
|
||||
# 定义多头注意力的头数
|
||||
self.n_heads = args.n_heads
|
||||
# 定义输入维度
|
||||
self.dim = args.dim
|
||||
# 定义每个头的维度,等于输入维度除以头数
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
# 定义LLaMA2Attention对象,用于进行多头注意力计算
|
||||
self.attention = LLaMA2Attention(args)
|
||||
# 定义LLaMAMLP对象,用于进行前馈神经网络计算
|
||||
self.feed_forward = LLaMA2MLP(
|
||||
dim=args.dim,
|
||||
hidden_dim=args.hidden_dim,
|
||||
multiple_of=args.multiple_of,
|
||||
dropout=args.dropout,
|
||||
)
|
||||
# 定义层的ID
|
||||
self.layer_id = layer_id
|
||||
# 定义注意力计算的归一化层
|
||||
self.attention_norm = LLaMA2RMSNorm(args.dim, eps=args.norm_eps)
|
||||
# 定义前馈神经网络计算的归一化层
|
||||
self.ffn_norm = LLaMA2RMSNorm(args.dim, eps=args.norm_eps)
|
||||
|
||||
def forward(self, x, freqs_cos, freqs_sin):
|
||||
# 前向传播函数
|
||||
# 首先,输入x经过注意力归一化层,然后进行注意力计算,结果与输入x相加得到h
|
||||
# 然后,h经过前馈神经网络归一化层,然后进行前馈神经网络计算,结果与h相加得到输出
|
||||
h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
|
||||
out = h + self.feed_forward.forward(self.ffn_norm(h))
|
||||
return out
|
||||
|
||||
class LLaMA2Model(nn.Module):
|
||||
last_loss: Optional[torch.Tensor]
|
||||
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
# 初始化模型参数
|
||||
self.args = args
|
||||
# 词汇表大小
|
||||
self.vocab_size = args.vocab_size
|
||||
# 层数
|
||||
self.n_layers = args.n_layers
|
||||
|
||||
# 词嵌入层
|
||||
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
|
||||
# Dropout层
|
||||
self.dropout = nn.Dropout(args.dropout)
|
||||
# Decoder层
|
||||
self.layers = torch.nn.ModuleList()
|
||||
for layer_id in range(args.n_layers):
|
||||
self.layers.append(LLaMA2DecoderLayer(layer_id, args))
|
||||
# 归一化层
|
||||
self.norm = LLaMA2RMSNorm(args.dim, eps=args.norm_eps)
|
||||
# 输出层
|
||||
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
|
||||
|
||||
# 将词嵌入层的权重与输出层的权重共享
|
||||
self.tok_embeddings.weight = self.output.weight
|
||||
|
||||
# 预计算相对位置嵌入的频率
|
||||
freqs_cos, freqs_sin = precompute_freqs_cis(self.args.dim // self.args.n_heads, self.args.max_seq_len)
|
||||
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
|
||||
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
|
||||
|
||||
# 初始化所有权重
|
||||
self.apply(self._init_weights)
|
||||
# 对残差投影进行特殊的缩放初始化
|
||||
for pn, p in self.named_parameters():
|
||||
if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
|
||||
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * args.n_layers))
|
||||
|
||||
# 初始化最后一次前向传播的损失属性
|
||||
self.last_loss = None
|
||||
|
||||
def _init_weights(self, module):
|
||||
# 初始化权重的函数
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
if module.bias is not None:
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
# 前向传播函数
|
||||
_bsz, seqlen = tokens.shape
|
||||
# 通过词嵌入层和Dropout层
|
||||
h = self.tok_embeddings(tokens)
|
||||
h = self.dropout(h)
|
||||
# 获取相对位置嵌入的频率
|
||||
freqs_cos = self.freqs_cos[:seqlen]
|
||||
freqs_sin = self.freqs_sin[:seqlen]
|
||||
|
||||
# 通过Decoder层
|
||||
for layer in self.layers:
|
||||
h = layer(h, freqs_cos, freqs_sin)
|
||||
# 通过归一化层
|
||||
h = self.norm(h)
|
||||
|
||||
if targets is not None:
|
||||
# 如果给定了目标,计算损失
|
||||
logits = self.output(h)
|
||||
self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
|
||||
else:
|
||||
# 推理时的小优化:只对最后一个位置的输出进行前向传播
|
||||
logits = self.output(h[:, [-1], :])
|
||||
self.last_loss = None
|
||||
|
||||
return logits
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = ModelArgs()
|
||||
# LLaMA2Model.forward 接受两个参数,tokens和targets,其中tokens是输入的张量, 应为int类型
|
||||
x = torch.randint(0, 32000, (1, 50)) # [bs, seq_len]
|
||||
# 实例化LLaMA2Model
|
||||
model = LLaMA2Model(args=args)
|
||||
# 计算model的全部参数
|
||||
num_params = sum(p.numel() for p in model.parameters())
|
||||
print('Number of parameters:', num_params)
|
||||
|
||||
out = model(x)
|
||||
print(out.shape) # [batch_size, 1, vocab_size]
|
||||
Reference in New Issue
Block a user