445 lines
19 KiB
Python
445 lines
19 KiB
Python
import math
|
||
import inspect
|
||
from dataclasses import dataclass
|
||
from typing import Any, Optional, Tuple
|
||
import torch
|
||
import torch.nn.functional as F
|
||
from torch import nn
|
||
|
||
from transformers import PreTrainedModel, AutoTokenizer
|
||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||
from transformers import PretrainedConfig
|
||
|
||
|
||
class ModelConfig(PretrainedConfig):
|
||
model_type = "Tiny-K"
|
||
def __init__(
|
||
self,
|
||
dim: int = 768,
|
||
n_layers: int = 12,
|
||
n_heads: int = 16,
|
||
n_kv_heads: int = 8,
|
||
vocab_size: int = 6144,
|
||
hidden_dim: int = None,
|
||
multiple_of: int = 64,
|
||
norm_eps: float = 1e-5,
|
||
max_seq_len: int = 512,
|
||
dropout: float = 0.0,
|
||
flash_attn: bool = True,
|
||
**kwargs,
|
||
):
|
||
self.dim = dim
|
||
self.n_layers = n_layers
|
||
self.n_heads = n_heads
|
||
self.n_kv_heads = n_kv_heads
|
||
self.vocab_size = vocab_size
|
||
self.hidden_dim = hidden_dim
|
||
self.multiple_of = multiple_of
|
||
self.norm_eps = norm_eps
|
||
self.max_seq_len = max_seq_len
|
||
self.dropout = dropout
|
||
self.flash_attn = flash_attn
|
||
super().__init__(**kwargs)
|
||
|
||
class RMSNorm(nn.Module):
|
||
def __init__(self, dim: int, eps: float):
|
||
super().__init__()
|
||
# eps是为了防止除以0的情况
|
||
self.eps = eps
|
||
# weight是一个可学习的参数,全部初始化为1
|
||
self.weight = nn.Parameter(torch.ones(dim))
|
||
|
||
def _norm(self, x):
|
||
# 计算RMSNorm的核心部分
|
||
# x.pow(2).mean(-1, keepdim=True)计算了输入x的平方的均值
|
||
# torch.rsqrt是平方根的倒数,这样就得到了RMSNorm的分母部分,再加上eps防止分母为0
|
||
# 最后乘以x,得到RMSNorm的结果
|
||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||
|
||
def forward(self, x):
|
||
# forward函数是模型的前向传播
|
||
# 首先将输入x转为float类型,然后进行RMSNorm,最后再转回原来的数据类型
|
||
# 最后乘以weight,这是RMSNorm的一个可学习的缩放因子
|
||
output = self._norm(x.float()).type_as(x)
|
||
return output * self.weight
|
||
|
||
# 获得旋转嵌入的实部和虚部
|
||
# 注意:此处的dim应为 dim//n_head,因为我们是对每个head进行旋转嵌入
|
||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
|
||
# torch.arange(0, dim, 2)[: (dim // 2)].float()生成了一个从0开始,步长为2的序列,长度为dim的一半
|
||
# 然后每个元素除以dim,再取theta的倒数,得到频率
|
||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||
# 生成一个从0到end的序列,长度为end
|
||
t = torch.arange(end, device=freqs.device)
|
||
# 计算外积,得到一个二维矩阵,每一行是t的元素乘以freqs的元素
|
||
freqs = torch.outer(t, freqs).float()
|
||
# 计算频率的余弦值,得到实部
|
||
freqs_cos = torch.cos(freqs)
|
||
# 计算频率的正弦值,得到虚部
|
||
freqs_sin = torch.sin(freqs)
|
||
return freqs_cos, freqs_sin
|
||
|
||
# 此函数的作用是将freqs_cis调整为与x的形状相同,以便能够与x进行广播操作
|
||
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
||
# 获取x的维度数
|
||
ndim = x.ndim
|
||
# 断言,确保1在x的维度范围内
|
||
assert 0 <= 1 < ndim
|
||
# 断言,确保freqs_cis的形状与x的第二维和最后一维相同
|
||
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
||
# 构造一个新的形状,除了第二维和最后一维,其他维度都为1,这样做是为了能够将freqs_cis与x进行广播操作
|
||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||
# 将freqs_cis调整为新的形状,并返回
|
||
return freqs_cis.view(shape)
|
||
|
||
def apply_rotary_emb(
|
||
xq: torch.Tensor,
|
||
xk: torch.Tensor,
|
||
freqs_cos: torch.Tensor,
|
||
freqs_sin: torch.Tensor
|
||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||
|
||
# 将查询和键张量转换为浮点数,并重塑形状以分离实部和虚部
|
||
xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
|
||
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
|
||
|
||
# 重新塑形频率张量以进行广播
|
||
freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
|
||
freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
|
||
|
||
# 应用旋转,分别计算旋转后的实部和虚部
|
||
xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
|
||
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
|
||
xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
|
||
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
|
||
|
||
# 将最后两个维度合并,并还原为原始张量的形状
|
||
xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
|
||
xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
|
||
|
||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||
|
||
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||
# 获取输入张量的形状:批量大小、序列长度、键/值对头的数量、每个头的维度大小
|
||
bs, slen, n_kv_heads, head_dim = x.shape
|
||
|
||
# 如果重复次数为1,则不需要重复,直接返回原始张量
|
||
if n_rep == 1:
|
||
return x
|
||
|
||
# 对张量进行扩展和重塑操作以重复键值对
|
||
return (
|
||
x[:, :, :, None, :] # 在第四个维度(头的维度前)添加一个新的维度
|
||
.expand(bs, slen, n_kv_heads, n_rep, head_dim) # 将新添加的维度扩展到n_rep大小,实现重复的效果
|
||
.reshape(bs, slen, n_kv_heads * n_rep, head_dim) # 重新塑形,合并键/值对头的数量和重复次数的维度
|
||
)
|
||
|
||
class Attention(nn.Module):
|
||
def __init__(self, args: ModelConfig):
|
||
super().__init__()
|
||
# 根据是否指定n_kv_heads,确定用于键(key)和值(value)的头的数量。
|
||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||
# 确保总头数可以被键值头数整除。
|
||
assert args.n_heads % self.n_kv_heads == 0
|
||
|
||
# 模型并行处理大小,默认为1。
|
||
model_parallel_size = 1
|
||
# 本地计算头数,等于总头数除以模型并行处理大小。
|
||
self.n_local_heads = args.n_heads // model_parallel_size
|
||
# 本地键值头数,等于键值头数除以模型并行处理大小。
|
||
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
|
||
# 重复次数,用于扩展键和值的尺寸。
|
||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||
# 每个头的维度,等于模型维度除以头的总数。
|
||
self.head_dim = args.dim // args.n_heads
|
||
|
||
# 定义权重矩阵。
|
||
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
|
||
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
|
||
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
|
||
# 输出权重矩阵。
|
||
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
|
||
|
||
# 定义dropout。
|
||
self.attn_dropout = nn.Dropout(args.dropout)
|
||
self.resid_dropout = nn.Dropout(args.dropout)
|
||
# 保存dropout概率。
|
||
self.dropout = args.dropout
|
||
|
||
# 检查是否使用Flash Attention(需要PyTorch >= 2.0)。
|
||
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
|
||
if not self.flash:
|
||
# 若不支持Flash Attention,则使用手动实现的注意力机制,并设置mask。
|
||
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
|
||
# 创建一个上三角矩阵,用于遮蔽未来信息。
|
||
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
|
||
mask = torch.triu(mask, diagonal=1)
|
||
# 注册为模型的缓冲区
|
||
self.register_buffer("mask", mask)
|
||
|
||
def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):
|
||
# 获取批次大小和序列长度,[batch_size, seq_len, dim]
|
||
bsz, seqlen, _ = x.shape
|
||
|
||
# 计算查询(Q)、键(K)、值(V)。
|
||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||
# 调整形状以适应头的维度。
|
||
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||
|
||
# 应用旋转位置嵌入(RoPE)。
|
||
xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
|
||
|
||
# 对键和值进行扩展以适应重复次数。
|
||
xk = repeat_kv(xk, self.n_rep)
|
||
xv = repeat_kv(xv, self.n_rep)
|
||
|
||
# 将头作为批次维度处理。
|
||
xq = xq.transpose(1, 2)
|
||
xk = xk.transpose(1, 2)
|
||
xv = xv.transpose(1, 2)
|
||
|
||
# 根据是否支持Flash Attention,选择实现方式。
|
||
if self.flash:
|
||
# 使用Flash Attention。
|
||
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
|
||
else:
|
||
# 使用手动实现的注意力机制。
|
||
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||
assert hasattr(self, 'mask')
|
||
scores = scores + self.mask[:, :, :seqlen, :seqlen]
|
||
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
|
||
scores = self.attn_dropout(scores)
|
||
output = torch.matmul(scores, xv)
|
||
|
||
# 恢复时间维度并合并头。
|
||
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
||
|
||
# 最终投影回残差流。
|
||
output = self.wo(output)
|
||
output = self.resid_dropout(output)
|
||
return output
|
||
|
||
class MLP(nn.Module):
|
||
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
|
||
super().__init__()
|
||
# 如果没有指定隐藏层的维度,我们将其设置为输入维度的4倍
|
||
# 然后将其减少到2/3,最后确保它是multiple_of的倍数
|
||
if hidden_dim is None:
|
||
hidden_dim = 4 * dim
|
||
hidden_dim = int(2 * hidden_dim / 3)
|
||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||
# 定义第一层线性变换,从输入维度到隐藏维度
|
||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||
# 定义第二层线性变换,从隐藏维度到输入维度
|
||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||
# 定义第三层线性变换,从输入维度到隐藏维度
|
||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||
# 定义dropout层,用于防止过拟合
|
||
self.dropout = nn.Dropout(dropout)
|
||
|
||
def forward(self, x):
|
||
# 前向传播函数
|
||
# 首先,输入x通过第一层线性变换和SILU激活函数
|
||
# 然后,结果乘以输入x通过第三层线性变换的结果
|
||
# 最后,通过第二层线性变换和dropout层
|
||
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
|
||
|
||
|
||
class DecoderLayer(nn.Module):
|
||
def __init__(self, layer_id: int, args: ModelConfig):
|
||
super().__init__()
|
||
# 定义多头注意力的头数
|
||
self.n_heads = args.n_heads
|
||
# 定义输入维度
|
||
self.dim = args.dim
|
||
# 定义每个头的维度,等于输入维度除以头数
|
||
self.head_dim = args.dim // args.n_heads
|
||
# 定义LLaMA2Attention对象,用于进行多头注意力计算
|
||
self.attention = Attention(args)
|
||
# 定义LLaMAMLP对象,用于进行前馈神经网络计算
|
||
self.feed_forward = MLP(
|
||
dim=args.dim,
|
||
hidden_dim=args.hidden_dim,
|
||
multiple_of=args.multiple_of,
|
||
dropout=args.dropout,
|
||
)
|
||
# 定义层的ID
|
||
self.layer_id = layer_id
|
||
# 定义注意力计算的归一化层
|
||
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||
# 定义前馈神经网络计算的归一化层
|
||
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||
|
||
def forward(self, x, freqs_cos, freqs_sin):
|
||
# 前向传播函数
|
||
# 首先,输入x经过注意力归一化层,然后进行注意力计算,结果与输入x相加得到h
|
||
# 然后,h经过前馈神经网络归一化层,然后进行前馈神经网络计算,结果与h相加得到输出
|
||
h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
|
||
out = h + self.feed_forward.forward(self.ffn_norm(h))
|
||
return out
|
||
|
||
class Transformer(PreTrainedModel):
|
||
config_class = ModelConfig # 配置类
|
||
last_loss: Optional[torch.Tensor] # 记录最后一次计算的损失
|
||
|
||
def __init__(self, args: ModelConfig = None):
|
||
super().__init__(args)
|
||
# 初始化模型参数
|
||
self.args = args
|
||
# 词汇表大小
|
||
self.vocab_size = args.vocab_size
|
||
# 层数
|
||
self.n_layers = args.n_layers
|
||
|
||
# 词嵌入层
|
||
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
|
||
# Dropout层
|
||
self.dropout = nn.Dropout(args.dropout)
|
||
# Decoder层
|
||
self.layers = torch.nn.ModuleList()
|
||
for layer_id in range(args.n_layers):
|
||
self.layers.append(DecoderLayer(layer_id, args))
|
||
# 归一化层
|
||
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||
# 输出层
|
||
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
|
||
|
||
# 将词嵌入层的权重与输出层的权重共享
|
||
self.tok_embeddings.weight = self.output.weight
|
||
|
||
# 预计算相对位置嵌入的频率
|
||
freqs_cos, freqs_sin = precompute_freqs_cis(self.args.dim // self.args.n_heads, self.args.max_seq_len)
|
||
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
|
||
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
|
||
|
||
# 初始化所有权重
|
||
self.apply(self._init_weights)
|
||
# 对残差投影进行特殊的缩放初始化
|
||
for pn, p in self.named_parameters():
|
||
if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
|
||
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * args.n_layers))
|
||
|
||
# 初始化最后一次前向传播的损失属性
|
||
self.last_loss = None
|
||
self.OUT = CausalLMOutputWithPast() # 输出容器
|
||
self._no_split_modules = [name for name, _ in self.named_modules()] # 不分割的模块列表
|
||
|
||
def _init_weights(self, module):
|
||
# 初始化权重的函数
|
||
if isinstance(module, nn.Linear):
|
||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||
if module.bias is not None:
|
||
torch.nn.init.zeros_(module.bias)
|
||
elif isinstance(module, nn.Embedding):
|
||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||
|
||
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **keyargs) -> torch.Tensor:
|
||
"""
|
||
- tokens: Optional[torch.Tensor], 输入 token 张量。
|
||
- targets: Optional[torch.Tensor], 目标 token 张量。
|
||
- kv_cache: bool, 是否使用键值缓存。
|
||
- keyargs: 其他关键字参数。
|
||
|
||
- self.OUT: CausalLMOutputWithPast, 包含 logits 和损失。
|
||
"""
|
||
|
||
if 'input_ids' in keyargs:
|
||
tokens = keyargs['input_ids']
|
||
if 'attention_mask' in keyargs:
|
||
targets = keyargs['attention_mask']
|
||
|
||
# 前向传播函数
|
||
_bsz, seqlen = tokens.shape
|
||
# 通过词嵌入层和Dropout层
|
||
h = self.tok_embeddings(tokens)
|
||
h = self.dropout(h)
|
||
# 获取相对位置嵌入的频率
|
||
freqs_cos = self.freqs_cos[:seqlen]
|
||
freqs_sin = self.freqs_sin[:seqlen]
|
||
|
||
# 通过Decoder层
|
||
for layer in self.layers:
|
||
h = layer(h, freqs_cos, freqs_sin)
|
||
# 通过归一化层
|
||
h = self.norm(h)
|
||
|
||
if targets is not None:
|
||
# 如果给定了目标,计算损失
|
||
logits = self.output(h)
|
||
self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=0, reduction='none')
|
||
else:
|
||
# 推理时的小优化:只对最后一个位置的输出进行前向传播
|
||
logits = self.output(h[:, [-1], :])
|
||
self.last_loss = None
|
||
|
||
# 设置输出
|
||
self.OUT.__setitem__('logits', logits)
|
||
self.OUT.__setitem__('last_loss', self.last_loss)
|
||
return self.OUT
|
||
|
||
|
||
@torch.inference_mode()
|
||
def generate(self, idx, stop_id=None, max_new_tokens=256, temperature=1.0, top_k=None):
|
||
"""
|
||
给定输入序列 idx(形状为 (bz,seq_len) 的长整型张量),通过多次生成新 token 来完成序列。
|
||
在 model.eval() 模式下运行。效率较低的采样版本,没有使用键k/v cache。
|
||
"""
|
||
index = idx.shape[1]
|
||
for _ in range(max_new_tokens):
|
||
# 如果序列上下文过长,截断它到最大长度
|
||
idx_cond = idx if idx.size(1) <= self.args.max_seq_len else idx[:, -self.args.max_seq_len:]
|
||
|
||
# 前向传播获取序列中最后一个位置的 logits
|
||
logits = self(idx_cond).logits
|
||
logits = logits[:, -1, :] # 只保留最后一个时间步的输出
|
||
|
||
if temperature == 0.0:
|
||
# 选择最有可能的索引
|
||
_, idx_next = torch.topk(logits, k=1, dim=-1)
|
||
else:
|
||
# 缩放 logits 并应用 softmax
|
||
logits = logits / temperature
|
||
if top_k is not None:
|
||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||
logits[logits < v[:, [-1]]] = -float('Inf')
|
||
probs = F.softmax(logits, dim=-1)
|
||
idx_next = torch.multinomial(probs, num_samples=1)
|
||
|
||
|
||
if idx_next == stop_id:
|
||
break
|
||
|
||
# 将采样的索引添加到序列中并继续
|
||
idx = torch.cat((idx, idx_next), dim=1)
|
||
|
||
return idx[:, index:] # 只返回生成的token
|
||
|
||
if __name__ == '__main__':
|
||
tokenizer = AutoTokenizer.from_pretrained("tokenizer_k")
|
||
args = ModelConfig(
|
||
dim=1024,
|
||
n_layers=18,
|
||
)
|
||
# 实例化LLaMA2Model
|
||
model = Transformer(args=args)
|
||
# 计算model的全部参数
|
||
num_params = sum(p.numel() for p in model.parameters())
|
||
print(f'LLM总参数量:{num_params / 1e6:.3f} 百万')
|
||
|
||
prompt = "你好呀,今天吃什么呢?你过得怎么样嘞?"
|
||
text = f"{tokenizer.bos_token}{prompt}{tokenizer.eos_token}"
|
||
print(f"Input text: {text}")
|
||
|
||
input_id = tokenizer(text).data['input_ids']
|
||
print("input_ids :", input_id)
|
||
print("dcode_str :", tokenizer.decode(input_id))
|
||
|
||
X = torch.tensor(input_id[:-1]).unsqueeze(0)
|
||
Y = torch.tensor(input_id[1:]).unsqueeze(0)
|
||
print("X shape :", X.shape)
|
||
print("Y shape :", Y.shape)
|
||
|
||
# 将输入张量传入模型
|
||
output = model(X, Y) |