#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File    : llama2_model.py
@Time    : 2024/04/14 22:26:35
@Author  : 不要葱姜蒜
@Version : 1.0
@Desc    : Parts of this code are adapted from the llama2.c repository
'''

import math
import struct
import inspect
from dataclasses import dataclass
from typing import Any, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


@dataclass
class ModelArgs:
    # Customized hyperparameters
    dim: int = 288  # model dimension
    n_layers: int = 6  # number of Transformer layers
    n_heads: int = 6  # number of attention heads
    n_kv_heads: Optional[int] = 6  # number of key/value heads; defaults to n_heads if not set
    vocab_size: int = 32000  # vocabulary size
    hidden_dim: Optional[int] = None  # MLP hidden dimension; derived from dim if not set
    multiple_of: int = 32  # the MLP hidden size is rounded up to a multiple of this value
    norm_eps: float = 1e-5  # epsilon used by the normalization layers
    max_seq_len: int = 256  # maximum sequence length
    dropout: float = 0.0  # dropout rate


class LLaMA2RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float):
        super().__init__()
        # eps guards against division by zero
        self.eps = eps
        # weight is a learnable scale, initialized to all ones
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Core of RMSNorm:
        # x.pow(2).mean(-1, keepdim=True) is the mean of the squared inputs;
        # torch.rsqrt is the reciprocal square root, giving the RMSNorm denominator,
        # with eps added to keep it from being zero.
        # Multiplying by x yields the RMSNorm result.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Forward pass:
        # cast x to float, apply RMSNorm, cast back to the original dtype,
        # then multiply by weight, the learnable scaling factor of RMSNorm.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
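

# Illustrative sanity check added for exposition (not part of the original file;
# the helper name `_demo_rmsnorm` is ours). It compares the module against a
# manual RMSNorm computation on random data.
def _demo_rmsnorm():
    torch.manual_seed(0)
    norm = LLaMA2RMSNorm(dim=8, eps=1e-5)
    x = torch.randn(2, 4, 8)
    manual = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)
    assert torch.allclose(norm(x), manual, atol=1e-6)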


# Obtain the real and imaginary parts of the rotary embedding.
# Note: dim here should be dim // n_heads, because the rotation is applied per head.
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # torch.arange(0, dim, 2)[: (dim // 2)].float() generates a sequence starting
    # at 0 with step 2, of length dim // 2.
    # Each element is divided by dim to form an exponent, and 1 / theta**exponent
    # gives the rotation frequencies.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # Generate the position indices 0, 1, ..., end - 1.
    t = torch.arange(end, device=freqs.device)
    # Outer product: a 2D matrix whose rows are each position times every frequency.
    freqs = torch.outer(t, freqs).float()
    # Cosine of the angles: the real part.
    freqs_cos = torch.cos(freqs)
    # Sine of the angles: the imaginary part.
    freqs_sin = torch.sin(freqs)
    return freqs_cos, freqs_sin
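

# Illustrative sketch added for exposition (not part of the original file): for a
# per-head dimension of 48 and 256 positions, the precomputed tables have shape
# (256, 24) -- one angle per position and per pair of channels -- and satisfy
# cos^2 + sin^2 = 1.
def _demo_precompute_freqs_cis():
    cos, sin = precompute_freqs_cis(dim=48, end=256)
    assert cos.shape == sin.shape == (256, 24)
    assert torch.allclose(cos.pow(2) + sin.pow(2), torch.ones_like(cos))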


# Reshape freqs_cis so that it can be broadcast against x.
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    # Number of dimensions of x.
    ndim = x.ndim
    # Make sure dimension 1 exists in x.
    assert 0 <= 1 < ndim
    # freqs_cis must match x's second and last dimensions.
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    # Build a new shape that keeps the second and last dimensions of x and sets
    # every other dimension to 1, so that freqs_cis can be broadcast against x.
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    # View freqs_cis with the new shape and return it.
    return freqs_cis.view(shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Cast the query and key tensors to float and reshape them to separate the
    # real and imaginary parts.
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    # Reshape the frequency tensors for broadcasting.
    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    # Apply the rotation, computing the rotated real and imaginary parts.
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    # Merge the last two dimensions and restore the original tensor shape.
    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)
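

# Illustrative check added for exposition (not part of the original file): because
# RoPE only rotates each (even, odd) channel pair, it preserves the shape and the
# norm of every query/key vector.
def _demo_rope_preserves_norm():
    torch.manual_seed(0)
    bsz, seqlen, n_heads, head_dim = 2, 16, 6, 48
    xq = torch.randn(bsz, seqlen, n_heads, head_dim)
    xk = torch.randn(bsz, seqlen, n_heads, head_dim)
    cos, sin = precompute_freqs_cis(head_dim, seqlen)
    xq_out, xk_out = apply_rotary_emb(xq, xk, cos, sin)
    assert xq_out.shape == xq.shape and xk_out.shape == xk.shape
    assert torch.allclose(xq_out.norm(dim=-1), xq.norm(dim=-1), atol=1e-5)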


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Shape of the input: batch size, sequence length, number of key/value heads,
    # and per-head dimension.
    bs, slen, n_kv_heads, head_dim = x.shape

    # If the repetition count is 1, nothing needs to be repeated.
    if n_rep == 1:
        return x

    # Expand and reshape the tensor to repeat the key/value heads.
    return (
        x[:, :, :, None, :]  # insert a new axis after the head dimension
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)  # expand the new axis to n_rep copies
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)  # merge the kv-head and repeat axes
    )
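

# Illustrative sketch added for exposition (not part of the original file): with 2
# key/value heads repeated 3 times, each kv head is duplicated along a new axis and
# the result matches 6 "query-sized" heads.
def _demo_repeat_kv():
    x = torch.arange(2 * 5 * 2 * 4, dtype=torch.float32).reshape(2, 5, 2, 4)
    y = repeat_kv(x, n_rep=3)
    assert y.shape == (2, 5, 6, 4)
    # output heads 0-2 are copies of kv head 0, heads 3-5 are copies of kv head 1
    assert torch.equal(y[:, :, 0], y[:, :, 2]) and torch.equal(y[:, :, 3], y[:, :, 5])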


class LLaMA2Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        # Number of heads used for keys and values, depending on whether
        # n_kv_heads was specified.
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        # The total number of heads must be divisible by the number of kv heads.
        assert args.n_heads % self.n_kv_heads == 0

        # Model-parallel size; defaults to 1.
        model_parallel_size = 1
        # Local number of heads: total heads divided by the model-parallel size.
        self.n_local_heads = args.n_heads // model_parallel_size
        # Local number of kv heads: kv heads divided by the model-parallel size.
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        # Repetition factor used to expand the keys and values.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        # Per-head dimension: model dimension divided by the total number of heads.
        self.head_dim = args.dim // args.n_heads

        # Projection matrices.
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        # Output projection matrix.
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        # Dropout layers.
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        # Keep the dropout probability around.
        self.dropout = args.dropout

        # Check whether Flash Attention is available (requires PyTorch >= 2.0).
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            # Fall back to the manual attention implementation and set up a mask.
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # Upper-triangular matrix used to mask out future positions.
            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
            mask = torch.triu(mask, diagonal=1)
            # Register it as a buffer on the module.
            self.register_buffer("mask", mask)

    def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):
        # Batch size and sequence length; x is [batch_size, seq_len, dim].
        bsz, seqlen, _ = x.shape

        # Compute the queries (Q), keys (K) and values (V).
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        # Reshape to expose the head dimension.
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # Apply rotary position embeddings (RoPE).
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

        # Repeat the keys and values to match the number of query heads.
        xk = repeat_kv(xk, self.n_rep)
        xv = repeat_kv(xv, self.n_rep)

        # Move the head dimension into the batch dimension.
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        # Pick the implementation depending on Flash Attention availability.
        if self.flash:
            # Use Flash Attention.
            output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
        else:
            # Manual attention implementation.
            scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
            assert hasattr(self, 'mask')
            scores = scores + self.mask[:, :, :seqlen, :seqlen]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = torch.matmul(scores, xv)

        # Restore the time dimension and merge the heads.
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        # Final projection back into the residual stream.
        output = self.wo(output)
        output = self.resid_dropout(output)
        return output
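

# Illustrative sketch added for exposition (not part of the original file): the
# attention block maps (batch, seq_len, dim) back to (batch, seq_len, dim).
def _demo_attention_shape():
    args = ModelArgs()
    attn = LLaMA2Attention(args)
    x = torch.randn(2, 16, args.dim)
    cos, sin = precompute_freqs_cis(args.dim // args.n_heads, 16)
    out = attn(x, cos, sin)
    assert out.shape == (2, 16, args.dim)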


class LLaMA2MLP(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
        super().__init__()
        # If no hidden dimension is given, start from 4x the input dimension,
        # scale it down to 2/3 of that, and round up to a multiple of multiple_of.
        if hidden_dim is None:
            hidden_dim = 4 * dim
            hidden_dim = int(2 * hidden_dim / 3)
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        # First linear layer: input dimension -> hidden dimension.
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        # Second linear layer: hidden dimension -> input dimension.
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        # Third linear layer: input dimension -> hidden dimension.
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        # Dropout layer to reduce overfitting.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Forward pass:
        # x goes through the first linear layer followed by the SiLU activation,
        # the result is multiplied elementwise by x passed through the third linear
        # layer, and the product goes through the second linear layer and dropout.
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
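

# Illustrative sketch added for exposition (not part of the original file): with
# the default dim=288 and multiple_of=32, the derived hidden_dim is
# int(2 * 4 * 288 / 3) = 768, which is already a multiple of 32, so w1/w3 map
# 288 -> 768 and w2 maps 768 -> 288.
def _demo_mlp_hidden_dim():
    mlp = LLaMA2MLP(dim=288, hidden_dim=None, multiple_of=32, dropout=0.0)
    assert mlp.w1.out_features == 768 and mlp.w3.out_features == 768
    assert mlp.w2.in_features == 768 and mlp.w2.out_features == 288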


class LLaMA2DecoderLayer(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        # Number of attention heads.
        self.n_heads = args.n_heads
        # Input dimension.
        self.dim = args.dim
        # Per-head dimension: input dimension divided by the number of heads.
        self.head_dim = args.dim // args.n_heads
        # LLaMA2Attention module for multi-head attention.
        self.attention = LLaMA2Attention(args)
        # LLaMA2MLP module for the feed-forward network.
        self.feed_forward = LLaMA2MLP(
            dim=args.dim,
            hidden_dim=args.hidden_dim,
            multiple_of=args.multiple_of,
            dropout=args.dropout,
        )
        # Layer ID.
        self.layer_id = layer_id
        # Normalization layer applied before attention.
        self.attention_norm = LLaMA2RMSNorm(args.dim, eps=args.norm_eps)
        # Normalization layer applied before the feed-forward network.
        self.ffn_norm = LLaMA2RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x, freqs_cos, freqs_sin):
        # Forward pass:
        # x is normalized, passed through attention, and added back to x, giving h;
        # h is then normalized, passed through the feed-forward network, and added
        # back to h, giving the output.
        h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out


class LLaMA2Model(nn.Module):
    last_loss: Optional[torch.Tensor]

    def __init__(self, args: ModelArgs):
        super().__init__()
        # Model hyperparameters.
        self.args = args
        # Vocabulary size.
        self.vocab_size = args.vocab_size
        # Number of layers.
        self.n_layers = args.n_layers

        # Token embedding layer.
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        # Dropout layer.
        self.dropout = nn.Dropout(args.dropout)
        # Decoder layers.
        self.layers = torch.nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(LLaMA2DecoderLayer(layer_id, args))
        # Final normalization layer.
        self.norm = LLaMA2RMSNorm(args.dim, eps=args.norm_eps)
        # Output projection.
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

        # Tie the token embedding weights to the output projection weights.
        self.tok_embeddings.weight = self.output.weight

        # Precompute the rotary position embedding frequencies.
        freqs_cos, freqs_sin = precompute_freqs_cis(self.args.dim // self.args.n_heads, self.args.max_seq_len)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

        # Initialize all weights.
        self.apply(self._init_weights)
        # Apply a special scaled initialization to the residual projections.
        for pn, p in self.named_parameters():
            if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * args.n_layers))

        # Loss of the most recent forward pass.
        self.last_loss = None

    def _init_weights(self, module):
        # Weight initialization.
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Forward pass.
        _bsz, seqlen = tokens.shape
        # Token embeddings followed by dropout.
        h = self.tok_embeddings(tokens)
        h = self.dropout(h)
        # Slice the precomputed rotary embedding frequencies to the sequence length.
        freqs_cos = self.freqs_cos[:seqlen]
        freqs_sin = self.freqs_sin[:seqlen]

        # Run the decoder layers.
        for layer in self.layers:
            h = layer(h, freqs_cos, freqs_sin)
        # Final normalization.
        h = self.norm(h)

        if targets is not None:
            # If targets are given, compute the loss.
            logits = self.output(h)
            self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # Inference-time mini-optimization: only project the output at the last position.
            logits = self.output(h[:, [-1], :])
            self.last_loss = None

        return logits


if __name__ == '__main__':
    args = ModelArgs()
    # LLaMA2Model.forward takes two arguments, tokens and targets; tokens is an
    # integer tensor of token ids.
    x = torch.randint(0, 32000, (1, 50))  # [bs, seq_len]
    # Instantiate LLaMA2Model.
    model = LLaMA2Model(args=args)
    # Count all model parameters.
    num_params = sum(p.numel() for p in model.parameters())
    print('Number of parameters:', num_params)

    out = model(x)
    print(out.shape)  # [batch_size, 1, vocab_size]
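
    # Optional sanity checks (added for exposition; the `_demo_*` helpers above
    # are not part of the original file).
    for _demo in (_demo_rmsnorm, _demo_precompute_freqs_cis, _demo_rope_preserves_norm,
                  _demo_repeat_kv, _demo_attention_shape, _demo_mlp_hidden_dim):
        _demo()

    # A minimal greedy-decoding sketch (added for exposition; the original file
    # defines no generate() method, and the weights here are untrained, so the
    # continuation is meaningless). It only shows how the inference path is driven:
    # feed the running sequence, take the argmax of the last position's logits,
    # and append it.
    tokens = x[:, :10]
    for _ in range(5):
        logits = model(tokens[:, -args.max_seq_len:])  # [bs, 1, vocab_size]
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
    print('Greedy continuation shape:', tokens.shape)  # [1, 15]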