Files
happy-llm/docs/chapter5/llama2_model.py
logan_zou dbced843e5 init
2024-05-28 12:25:44 +08:00

367 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File : llama2_model.py
@Time : 2024/04/14 22:26:35
@Author : 不要葱姜蒜
@Version : 1.0
@Desc : 部分代码借鉴llama2.c仓库代码
'''
import math
import struct
import inspect
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
@dataclass
class ModelArgs:
# 自定义超参数
dim: int = 288 # 模型维度
n_layers: int = 6 # Transformer层数
n_heads: int = 6 # 注意力机制的头数
n_kv_heads: Optional[int] = 6 # 键/值头数如果未指定则默认为n_heads
vocab_size: int = 32000 # 词汇表大小
hidden_dim: Optional[int] = None # 隐藏层维度,如果未指定,则使用其他规则确定
multiple_of: int = 32 # MLP隐藏层大小是这个数的倍数
norm_eps: float = 1e-5 # 归一化层的epsilon值
max_seq_len: int = 256 # 最大序列长度
dropout: float = 0.0 # 丢弃率
class LLaMA2RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float):
super().__init__()
# eps是为了防止除以0的情况
self.eps = eps
# weight是一个可学习的参数全部初始化为1
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
# 计算RMSNorm的核心部分
# x.pow(2).mean(-1, keepdim=True)计算了输入x的平方的均值
# torch.rsqrt是平方根的倒数这样就得到了RMSNorm的分母部分再加上eps防止分母为0
# 最后乘以x得到RMSNorm的结果
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
# forward函数是模型的前向传播
# 首先将输入x转为float类型然后进行RMSNorm最后再转回原来的数据类型
# 最后乘以weight这是RMSNorm的一个可学习的缩放因子
output = self._norm(x.float()).type_as(x)
return output * self.weight
# 获得旋转嵌入的实部和虚部
# 注意此处的dim应为 dim//n_head因为我们是对每个head进行旋转嵌入
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
# torch.arange(0, dim, 2)[: (dim // 2)].float()生成了一个从0开始步长为2的序列长度为dim的一半
# 然后每个元素除以dim再取theta的倒数得到频率
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
# 生成一个从0到end的序列长度为end
t = torch.arange(end, device=freqs.device)
# 计算外积得到一个二维矩阵每一行是t的元素乘以freqs的元素
freqs = torch.outer(t, freqs).float()
# 计算频率的余弦值,得到实部
freqs_cos = torch.cos(freqs)
# 计算频率的正弦值,得到虚部
freqs_sin = torch.sin(freqs)
return freqs_cos, freqs_sin
# 此函数的作用是将freqs_cis调整为与x的形状相同以便能够与x进行广播操作
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
# 获取x的维度数
ndim = x.ndim
# 断言确保1在x的维度范围内
assert 0 <= 1 < ndim
# 断言确保freqs_cis的形状与x的第二维和最后一维相同
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
# 构造一个新的形状除了第二维和最后一维其他维度都为1这样做是为了能够将freqs_cis与x进行广播操作
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
# 将freqs_cis调整为新的形状并返回
return freqs_cis.view(shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# 将查询和键张量转换为浮点数,并重塑形状以分离实部和虚部
xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
# 重新塑形频率张量以进行广播
freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
# 应用旋转,分别计算旋转后的实部和虚部
xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
# 将最后两个维度合并,并还原为原始张量的形状
xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
# 获取输入张量的形状:批量大小、序列长度、键/值对头的数量、每个头的维度大小
bs, slen, n_kv_heads, head_dim = x.shape
# 如果重复次数为1则不需要重复直接返回原始张量
if n_rep == 1:
return x
# 对张量进行扩展和重塑操作以重复键值对
return (
x[:, :, :, None, :] # 在第四个维度(头的维度前)添加一个新的维度
.expand(bs, slen, n_kv_heads, n_rep, head_dim) # 将新添加的维度扩展到n_rep大小实现重复的效果
.reshape(bs, slen, n_kv_heads * n_rep, head_dim) # 重新塑形,合并键/值对头的数量和重复次数的维度
)
class LLaMA2Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
# 根据是否指定n_kv_heads确定用于键key和值value的头的数量。
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
# 确保总头数可以被键值头数整除。
assert args.n_heads % self.n_kv_heads == 0
# 模型并行处理大小默认为1。
model_parallel_size = 1
# 本地计算头数,等于总头数除以模型并行处理大小。
self.n_local_heads = args.n_heads // model_parallel_size
# 本地键值头数,等于键值头数除以模型并行处理大小。
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
# 重复次数,用于扩展键和值的尺寸。
self.n_rep = self.n_local_heads // self.n_local_kv_heads
# 每个头的维度,等于模型维度除以头的总数。
self.head_dim = args.dim // args.n_heads
# 定义权重矩阵。
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
# 输出权重矩阵。
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
# 定义dropout。
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
# 保存dropout概率。
self.dropout = args.dropout
# 检查是否使用Flash Attention需要PyTorch >= 2.0)。
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if not self.flash:
# 若不支持Flash Attention则使用手动实现的注意力机制并设置mask。
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
# 创建一个上三角矩阵,用于遮蔽未来信息。
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
# 注册为模型的缓冲区
self.register_buffer("mask", mask)
def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):
# 获取批次大小和序列长度,[batch_size, seq_len, dim]
bsz, seqlen, _ = x.shape
# 计算查询Q、键K、值V
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
# 调整形状以适应头的维度。
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
# 应用旋转位置嵌入RoPE
xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
# 对键和值进行扩展以适应重复次数。
xk = repeat_kv(xk, self.n_rep)
xv = repeat_kv(xv, self.n_rep)
# 将头作为批次维度处理。
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)
# 根据是否支持Flash Attention选择实现方式。
if self.flash:
# 使用Flash Attention。
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
else:
# 使用手动实现的注意力机制。
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
assert hasattr(self, 'mask')
scores = scores + self.mask[:, :, :seqlen, :seqlen]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = torch.matmul(scores, xv)
# 恢复时间维度并合并头。
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
# 最终投影回残差流。
output = self.wo(output)
output = self.resid_dropout(output)
return output
class LLaMA2MLP(nn.Module):
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
super().__init__()
# 如果没有指定隐藏层的维度我们将其设置为输入维度的4倍
# 然后将其减少到2/3最后确保它是multiple_of的倍数
if hidden_dim is None:
hidden_dim = 4 * dim
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
# 定义第一层线性变换,从输入维度到隐藏维度
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
# 定义第二层线性变换,从隐藏维度到输入维度
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
# 定义第三层线性变换,从输入维度到隐藏维度
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
# 定义dropout层用于防止过拟合
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# 前向传播函数
# 首先输入x通过第一层线性变换和SILU激活函数
# 然后结果乘以输入x通过第三层线性变换的结果
# 最后通过第二层线性变换和dropout层
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class LLaMA2DecoderLayer(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
# 定义多头注意力的头数
self.n_heads = args.n_heads
# 定义输入维度
self.dim = args.dim
# 定义每个头的维度,等于输入维度除以头数
self.head_dim = args.dim // args.n_heads
# 定义LLaMA2Attention对象用于进行多头注意力计算
self.attention = LLaMA2Attention(args)
# 定义LLaMAMLP对象用于进行前馈神经网络计算
self.feed_forward = LLaMA2MLP(
dim=args.dim,
hidden_dim=args.hidden_dim,
multiple_of=args.multiple_of,
dropout=args.dropout,
)
# 定义层的ID
self.layer_id = layer_id
# 定义注意力计算的归一化层
self.attention_norm = LLaMA2RMSNorm(args.dim, eps=args.norm_eps)
# 定义前馈神经网络计算的归一化层
self.ffn_norm = LLaMA2RMSNorm(args.dim, eps=args.norm_eps)
def forward(self, x, freqs_cos, freqs_sin):
# 前向传播函数
# 首先输入x经过注意力归一化层然后进行注意力计算结果与输入x相加得到h
# 然后h经过前馈神经网络归一化层然后进行前馈神经网络计算结果与h相加得到输出
h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
out = h + self.feed_forward.forward(self.ffn_norm(h))
return out
class LLaMA2Model(nn.Module):
last_loss: Optional[torch.Tensor]
def __init__(self, args: ModelArgs):
super().__init__()
# 初始化模型参数
self.args = args
# 词汇表大小
self.vocab_size = args.vocab_size
# 层数
self.n_layers = args.n_layers
# 词嵌入层
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
# Dropout层
self.dropout = nn.Dropout(args.dropout)
# Decoder层
self.layers = torch.nn.ModuleList()
for layer_id in range(args.n_layers):
self.layers.append(LLaMA2DecoderLayer(layer_id, args))
# 归一化层
self.norm = LLaMA2RMSNorm(args.dim, eps=args.norm_eps)
# 输出层
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
# 将词嵌入层的权重与输出层的权重共享
self.tok_embeddings.weight = self.output.weight
# 预计算相对位置嵌入的频率
freqs_cos, freqs_sin = precompute_freqs_cis(self.args.dim // self.args.n_heads, self.args.max_seq_len)
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
# 初始化所有权重
self.apply(self._init_weights)
# 对残差投影进行特殊的缩放初始化
for pn, p in self.named_parameters():
if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * args.n_layers))
# 初始化最后一次前向传播的损失属性
self.last_loss = None
def _init_weights(self, module):
# 初始化权重的函数
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
# 前向传播函数
_bsz, seqlen = tokens.shape
# 通过词嵌入层和Dropout层
h = self.tok_embeddings(tokens)
h = self.dropout(h)
# 获取相对位置嵌入的频率
freqs_cos = self.freqs_cos[:seqlen]
freqs_sin = self.freqs_sin[:seqlen]
# 通过Decoder层
for layer in self.layers:
h = layer(h, freqs_cos, freqs_sin)
# 通过归一化层
h = self.norm(h)
if targets is not None:
# 如果给定了目标,计算损失
logits = self.output(h)
self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
else:
# 推理时的小优化:只对最后一个位置的输出进行前向传播
logits = self.output(h[:, [-1], :])
self.last_loss = None
return logits
if __name__ == '__main__':
args = ModelArgs()
# LLaMA2Model.forward 接受两个参数tokens和targets其中tokens是输入的张量, 应为int类型
x = torch.randint(0, 32000, (1, 50)) # [bs, seq_len]
# 实例化LLaMA2Model
model = LLaMA2Model(args=args)
# 计算model的全部参数
num_params = sum(p.numel() for p in model.parameters())
print('Number of parameters:', num_params)
out = model(x)
print(out.shape) # [batch_size, 1, vocab_size]