Files
happy-llm/docs/chapter5/code/sample.py
2024-09-22 16:02:14 +08:00

105 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import pickle
from contextlib import nullcontext
import torch
from model import ModelArgs, Transformer
from tokenizer import Tokenizer
import argparse
class TextGenerator:
def __init__(self,
checkpoint='output/ckpt.pt', # 模型检查点路径
tokenizer_model_path='tok4096.model', # 分词器模型路径
seed=1337, # 随机种子,确保可重复性
device=None, # 设备,优先使用 CUDA如果没有可用的 CUDA则使用 CPU
dtype="float32"): # 数据类型,默认为 float32可以选择 float16 或 bfloat16
"""
初始化 TextGenerator 类,加载模型、设置设备和分词器等。
"""
# 模型加载配置
self.checkpoint = checkpoint # 保存的模型检查点路径
self.tokenizer_model_path = tokenizer_model_path # 分词器模型文件路径
self.seed = seed # 随机数种子,用于生成的可重复性
self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu') # 根据硬件条件选择设备
self.dtype = dtype # 模型的浮点数类型
self.device_type = 'cuda' if 'cuda' in self.device else 'cpu' # 判断当前设备是否为 CUDA
# 设置随机种子,确保生成的可重复性
torch.manual_seed(seed) # 设置 CPU 随机种子
torch.cuda.manual_seed(seed) # 设置 CUDA 随机种子
torch.backends.cuda.matmul.allow_tf32 = True # 允许 CUDA 使用 TF32 精度进行矩阵乘法运算
torch.backends.cudnn.allow_tf32 = True # 允许 cuDNN 使用 TF32 精度加速
# 根据 dtype 选择适当的自动混合精度上下文
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[self.dtype]
self.ctx = nullcontext() if self.device_type == 'cpu' else torch.amp.autocast(device_type=self.device_type, dtype=ptdtype)
# 加载模型检查点文件
checkpoint_dict = torch.load(self.checkpoint, map_location=self.device) # 加载模型参数
gptconf = ModelArgs(**checkpoint_dict['model_args']) # 初始化模型参数
self.model = Transformer(gptconf) # 实例化 Transformer 模型
state_dict = checkpoint_dict['model'] # 获取模型状态字典
# 去除状态字典中的不必要前缀
unwanted_prefix = '_orig_mod.' # 这个前缀在保存时可能被添加,现在要去除它
for k, v in list(state_dict.items()):
if k.startswith(unwanted_prefix):
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) # 去除不必要的前缀
# 加载模型参数到模型中
self.model.load_state_dict(state_dict, strict=False)
# 计算模型参数量
num_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
print(f"Model has {num_params} parameters.")
# 设置模型为评估模式evaluation mode防止训练模式下的 dropout 等操作影响结果
self.model.eval()
# 将模型放置到正确的设备上GPU 或 CPU
self.model.to(self.device)
# 初始化分词器
self.tokenizer = Tokenizer(tokenizer_model=self.tokenizer_model_path) # 根据指定的路径加载分词器
def sample(self,
start="Hello!", # 生成文本的起始提示词,可以是任意字符串
num_samples=3, # 生成样本的数量,默认生成 3 个样本
max_new_tokens=256, # 每个样本生成的最大 token 数,默认最多生成 256 个 token
temperature=1.0, # 控制生成的随机性1.0 为标准,值越大越随机
top_k=300): # 保留概率最高的 top_k 个 token限制生成时的选择范围
"""
根据给定的起始文本生成样本。
:param start: 生成文本的起始提示词
:param num_samples: 要生成的文本样本数
:param max_new_tokens: 每个样本生成的最大 token 数
:param temperature: 控制生成的随机性,值越小生成越确定,值越大生成越随机
:param top_k: 限制生成时选择的 token 范围
:return: 生成的文本样本列表
"""
# 如果 start 是以 'FILE:' 开头,表示从文件中读取起始文本
if start.startswith('FILE:'):
with open(start[5:], 'r', encoding='utf-8') as f:
start = f.read() # 读取文件内容作为起始文本
# 将起始文本编码为 token id 序列
start_ids = self.tokenizer.encode(start, bos=True, eos=False) # bos=True 表示加上句首标记eos=False 表示不加句尾标记
x = (torch.tensor(start_ids, dtype=torch.long, device=self.device)[None, ...]) # 将编码后的 token id 转为 PyTorch 张量
generated_texts = [] # 用于保存生成的文本样本
with torch.no_grad(): # 禁用梯度计算,提升效率
with self.ctx: # 进入自动混合精度的上下文(如果是 GPU 并使用 float16 时)
for k in range(num_samples): # 循环生成指定数量的样本
y = self.model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k) # 生成文本
generated_texts.append(self.tokenizer.decode(y[0].tolist())) # 解码生成的 token 序列为可读文本
return generated_texts # 返回生成的文本样本
# 示例使用
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--prompt", type=str, default="One day, Lily met a Shoggoth")
args = parser.parse_args()
generator = TextGenerator() # 初始化生成器
samples = generator.sample(start=args.prompt, num_samples=3, max_new_tokens=256) # 生成 3 个样本
for i, sample in enumerate(samples):
print(f"\nSample {i+1}:\n{sample}\n{'-'*20}") # 打印生成的样本并用分隔线分割