Files
happy-llm/docs/chapter2/code/transformer.py
Logan Zou 3950b06a5f Update transformer.py
fix arg bug
2025-06-23 10:53:25 +08:00

349 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import math
from torch import nn
from dataclasses import dataclass
from transformers import BertTokenizer
import torch.nn.functional as F
@dataclass
class ModelArgs:
n_embd: int # 嵌入维度
n_heads: int # 头数
dim: int # 模型维度
dropout: float
max_seq_len: int
vocab_size: int
block_size: int
n_layer: int
class MultiHeadAttention(nn.Module):
def __init__(self, args: ModelArgs, is_causal=False):
# 构造函数
# args: 配置对象
super().__init__()
# 隐藏层维度必须是头数的整数倍,因为后面我们会将输入拆成头数个矩阵
assert args.dim % args.n_heads == 0
# 模型并行处理大小默认为1。
model_parallel_size = 1
# 本地计算头数,等于总头数除以模型并行处理大小。
self.n_local_heads = args.n_heads // model_parallel_size
# 每个头的维度,等于模型维度除以头的总数。
self.head_dim = args.dim // args.n_heads
# Wq, Wk, Wv 参数矩阵,每个参数矩阵为 n_embd x n_embd
# 这里通过三个组合矩阵来代替了n个参数矩阵的组合其逻辑在于矩阵内积再拼接其实等同于拼接矩阵再内积
# 不理解的读者可以自行模拟一下每一个线性层其实相当于n个参数矩阵的拼接
self.wq = nn.Linear(args.n_embd, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.n_embd, args.n_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.n_embd, args.n_heads * self.head_dim, bias=False)
# 输出权重矩阵,维度为 dim x n_embdhead_dim = n_embeds / n_heads
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
# 注意力的 dropout
self.attn_dropout = nn.Dropout(args.dropout)
# 残差连接的 dropout
self.resid_dropout = nn.Dropout(args.dropout)
self.is_causal = is_causal
# 创建一个上三角矩阵,用于遮蔽未来信息
# 注意因为是多头注意力Mask 矩阵比之前我们定义的多一个维度
if is_causal:
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
# 注册为模型的缓冲区
self.register_buffer("mask", mask)
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
# 获取批次大小和序列长度,[batch_size, seq_len, dim]
bsz, seqlen, _ = q.shape
# 计算查询Q、键K、值V,输入通过参数矩阵层,维度为 (B, T, n_embed) x (n_embed, n_embed) -> (B, T, n_embed)
xq, xk, xv = self.wq(q), self.wk(k), self.wv(v)
# 将 Q、K、V 拆分成多头,维度为 (B, T, n_head, C // n_head),然后交换维度,变成 (B, n_head, T, C // n_head)
# 因为在注意力计算中我们是取了后两个维度参与计算
# 为什么要先按B*T*n_head*C//n_head展开再互换1、2维度而不是直接按注意力输入展开是因为view的展开方式是直接把输入全部排开
# 然后按要求构造,可以发现只有上述操作能够实现我们将每个头对应部分取出来的目标
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)
# 注意力计算
# 计算 QK^T / sqrt(d_k),维度为 (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
# 掩码自注意力必须有注意力掩码
if self.is_causal:
assert hasattr(self, 'mask')
# 这里截取到序列长度,因为有些序列可能比 max_seq_len 短
scores = scores + self.mask[:, :, :seqlen, :seqlen]
# 计算 softmax维度为 (B, nh, T, T)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
# 做 Dropout
scores = self.attn_dropout(scores)
# V * Score维度为(B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
output = torch.matmul(scores, xv)
# 恢复时间维度并合并头。
# 将多头的结果拼接起来, 先交换维度为 (B, T, n_head, C // n_head),再拼接成 (B, T, n_head * C // n_head)
# contiguous 函数用于重新开辟一块新内存存储因为Pytorch设置先transpose再view会报错
# 因为view直接基于底层存储得到然而transpose并不会改变底层存储因此需要额外存储
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
# 最终投影回残差流。
output = self.wo(output)
output = self.resid_dropout(output)
return output
class LayerNorm(nn.Module):
''' Layer Norm 层'''
def __init__(self, features, eps=1e-6):
super(LayerNorm, self).__init__()
# 线性矩阵做映射
self.a_2 = nn.Parameter(torch.ones(features))
self.b_2 = nn.Parameter(torch.zeros(features))
self.eps = eps
def forward(self, x):
# 在统计每个样本所有维度的值,求均值和方差
mean = x.mean(-1, keepdim=True) # mean: [bsz, max_len, 1]
std = x.std(-1, keepdim=True) # std: [bsz, max_len, 1]
# 注意这里也在最后一个维度发生了广播
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
class MLP(nn.Module):
'''前馈神经网络'''
def __init__(self, dim: int, hidden_dim: int, dropout: float):
super().__init__()
# 定义第一层线性变换,从输入维度到隐藏维度
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
# 定义第二层线性变换,从隐藏维度到输入维度
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
# 定义dropout层用于防止过拟合
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# 前向传播函数
# 首先输入x通过第一层线性变换和RELU激活函数
# 然后结果乘以输入x通过第三层线性变换的结果
# 最后通过第二层线性变换和dropout层
return self.dropout(self.w2(F.relu(self.w1(x))))
class EncoderLayer(nn.Module):
def __init__(self, args):
super().__init__()
# 一个 Layer 中有两个 LayerNorm分别在 Attention 之前和 MLP 之前
self.attention_norm = LayerNorm(args.n_embd)
# Encoder 不需要掩码,传入 is_causal=False
self.attention = MultiHeadAttention(args, is_causal=False)
self.fnn_norm = LayerNorm(args.n_embd)
self.feed_forward = MLP(args.dim, args.dim, args.dropout)
def forward(self, x):
# Layer Norm
x = self.attention_norm(x)
# 自注意力
h = x + self.attention.forward(x, x, x)
# 经过前馈神经网络
out = h + self.feed_forward.forward(self.fnn_norm(h))
return out
class Encoder(nn.Module):
'''Encoder 块'''
def __init__(self, args):
super(Encoder, self).__init__()
# 一个 Encoder 由 N 个 Encoder Layer 组成
self.layers = nn.ModuleList([EncoderLayer(args) for _ in range(args.n_layer)])
self.norm = LayerNorm(args.n_embd)
def forward(self, x):
"分别通过 N 层 Encoder Layer"
for layer in self.layers:
x = layer(x)
return self.norm(x)
class DecoderLayer(nn.Module):
'''Decoder 层'''
def __init__(self, args):
super().__init__()
# 一个 Layer 中有三个 LayerNorm分别在 Mask Attention 之前、Self Attention 之前和 MLP 之前
self.attention_norm_1 = LayerNorm(args.n_embd)
# Decoder 的第一个部分是 Mask Attention传入 is_causal=True
self.mask_attention = MultiHeadAttention(args, is_causal=True)
self.attention_norm_2 = LayerNorm(args.n_embd)
# Decoder 的第二个部分是 类似于 Encoder 的 Attention传入 is_causal=False
self.attention = MultiHeadAttention(args, is_causal=False)
self.ffn_norm = LayerNorm(args.n_embd)
# 第三个部分是 MLP
self.feed_forward = MLP(args.dim, args.dim, args.dropout)
def forward(self, x, enc_out):
# Layer Norm
x = self.attention_norm_1(x)
# 掩码自注意力
x = x + self.mask_attention.forward(x, x, x)
# 多头注意力
x = self.attention_norm_2(x)
h = x + self.attention.forward(x, enc_out, enc_out)
# 经过前馈神经网络
out = h + self.feed_forward.forward(self.ffn_norm(h))
return out
class Decoder(nn.Module):
'''解码器'''
def __init__(self, args):
super(Decoder, self).__init__()
# 一个 Decoder 由 N 个 Decoder Layer 组成
self.layers = nn.ModuleList([DecoderLayer(args) for _ in range(args.n_layer)])
self.norm = LayerNorm(args.n_embd)
def forward(self, x, enc_out):
"Pass the input (and mask) through each layer in turn."
for layer in self.layers:
x = layer(x, enc_out)
return self.norm(x)
class PositionalEncoding(nn.Module):
'''位置编码模块'''
def __init__(self, args):
super(PositionalEncoding, self).__init__()
# Dropout 层
self.dropout = nn.Dropout(p=args.dropout)
# block size 是序列的最大长度
pe = torch.zeros(args.block_size, args.n_embd)
position = torch.arange(0, args.block_size).unsqueeze(1)
# 计算 theta
div_term = torch.exp(
torch.arange(0, args.n_embd, 2) * -(math.log(10000.0) / args.n_embd)
)
# 分别计算 sin、cos 结果
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer("pe", pe)
def forward(self, x):
# 将位置编码加到 Embedding 结果上
x = x + self.pe[:, : x.size(1)].requires_grad_(False)
return self.dropout(x)
class Transformer(nn.Module):
'''整体模型'''
def __init__(self, args):
super().__init__()
# 必须输入词表大小和 block size
assert args.vocab_size is not None
assert args.block_size is not None
self.args = args
self.transformer = nn.ModuleDict(dict(
wte=nn.Embedding(args.vocab_size, args.n_embd),
wpe=PositionalEncoding(args),
drop=nn.Dropout(args.dropout),
encoder=Encoder(args),
decoder=Decoder(args),
))
# 最后的线性层,输入是 n_embd输出是词表大小
self.lm_head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
# 初始化所有的权重
self.apply(self._init_weights)
# 查看所有参数的数量
print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))
'''统计所有参数的数量'''
def get_num_params(self, non_embedding=False):
# non_embedding: 是否统计 embedding 的参数
n_params = sum(p.numel() for p in self.parameters())
# 如果不统计 embedding 的参数,就减去
if non_embedding:
n_params -= self.transformer.wpe.weight.numel()
return n_params
'''初始化权重'''
def _init_weights(self, module):
# 线性层和 Embedding 层初始化为正则分布
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
'''前向计算函数'''
def forward(self, idx, targets=None):
# 输入为 idx维度为 (batch size, sequence length, 1)targets 为目标序列,用于计算 loss
device = idx.device
b, t = idx.size()
assert t <= self.args.block_size, f"不能计算该序列,该序列长度为 {t}, 最大序列长度只有 {self.args.block_size}"
# 通过 self.transformer
# 首先将输入 idx 通过 Embedding 层,得到维度为 (batch size, sequence length, n_embd)
print("idx", idx.size())
# 通过 Embedding 层
tok_emb = self.transformer.wte(idx)
print("tok_emb", tok_emb.size())
# 然后通过位置编码
pos_emb = self.transformer.wpe(tok_emb)
# 再进行 Dropout
x = self.transformer.drop(pos_emb)
# 然后通过 Encoder
print("x after wpe:", x.size())
enc_out = self.transformer.encoder(x)
print("enc_out:", enc_out.size())
# 再通过 Decoder
x = self.transformer.decoder(x, enc_out)
print("x after decoder:", x.size())
if targets is not None:
# 训练阶段,如果我们给了 targets就计算 loss
# 先通过最后的 Linear 层,得到维度为 (batch size, sequence length, vocab size)
logits = self.lm_head(x)
# 再跟 targets 计算交叉熵
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
else:
# 推理阶段,我们只需要 logitsloss 为 None
# 取 -1 是只取序列中的最后一个作为输出
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
loss = None
return logits, loss
def main():
args = ModelArgs(100, 10, 100, 0.1, 512, 1000, 1000, 2)
text = "我喜欢快乐地学习大模型"
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
inputs_token = tokenizer(
text,
return_tensors='pt',
max_length=args.max_seq_len,
truncation=True,
padding='max_length'
)
args.vocab_size = tokenizer.vocab_size
transformer = Transformer(args)
inputs_id = inputs_token['input_ids']
logits, loss = transformer.forward(inputs_id)
print(logits)
predicted_ids = torch.argmax(logits, dim=-1).item()
output = tokenizer.decode(predicted_ids)
print(output)
if __name__ == "__main__":
print("开始")
main()