105 lines
6.0 KiB
Python
105 lines
6.0 KiB
Python
import os
|
||
import pickle
|
||
from contextlib import nullcontext
|
||
import torch
|
||
from model import ModelArgs, Transformer
|
||
from tokenizer import Tokenizer
|
||
import argparse
|
||
|
||
class TextGenerator:
|
||
def __init__(self,
|
||
checkpoint='output/ckpt.pt', # 模型检查点路径
|
||
tokenizer_model_path='tok4096.model', # 分词器模型路径
|
||
seed=1337, # 随机种子,确保可重复性
|
||
device=None, # 设备,优先使用 CUDA,如果没有可用的 CUDA,则使用 CPU
|
||
dtype="float32"): # 数据类型,默认为 float32,可以选择 float16 或 bfloat16
|
||
"""
|
||
初始化 TextGenerator 类,加载模型、设置设备和分词器等。
|
||
"""
|
||
# 模型加载配置
|
||
self.checkpoint = checkpoint # 保存的模型检查点路径
|
||
self.tokenizer_model_path = tokenizer_model_path # 分词器模型文件路径
|
||
self.seed = seed # 随机数种子,用于生成的可重复性
|
||
self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu') # 根据硬件条件选择设备
|
||
self.dtype = dtype # 模型的浮点数类型
|
||
self.device_type = 'cuda' if 'cuda' in self.device else 'cpu' # 判断当前设备是否为 CUDA
|
||
|
||
# 设置随机种子,确保生成的可重复性
|
||
torch.manual_seed(seed) # 设置 CPU 随机种子
|
||
torch.cuda.manual_seed(seed) # 设置 CUDA 随机种子
|
||
torch.backends.cuda.matmul.allow_tf32 = True # 允许 CUDA 使用 TF32 精度进行矩阵乘法运算
|
||
torch.backends.cudnn.allow_tf32 = True # 允许 cuDNN 使用 TF32 精度加速
|
||
|
||
# 根据 dtype 选择适当的自动混合精度上下文
|
||
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[self.dtype]
|
||
self.ctx = nullcontext() if self.device_type == 'cpu' else torch.amp.autocast(device_type=self.device_type, dtype=ptdtype)
|
||
|
||
# 加载模型检查点文件
|
||
checkpoint_dict = torch.load(self.checkpoint, map_location=self.device) # 加载模型参数
|
||
gptconf = ModelArgs(**checkpoint_dict['model_args']) # 初始化模型参数
|
||
self.model = Transformer(gptconf) # 实例化 Transformer 模型
|
||
state_dict = checkpoint_dict['model'] # 获取模型状态字典
|
||
|
||
# 去除状态字典中的不必要前缀
|
||
unwanted_prefix = '_orig_mod.' # 这个前缀在保存时可能被添加,现在要去除它
|
||
for k, v in list(state_dict.items()):
|
||
if k.startswith(unwanted_prefix):
|
||
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) # 去除不必要的前缀
|
||
|
||
# 加载模型参数到模型中
|
||
self.model.load_state_dict(state_dict, strict=False)
|
||
# 计算模型参数量
|
||
num_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
||
print(f"Model has {num_params} parameters.")
|
||
# 设置模型为评估模式(evaluation mode),防止训练模式下的 dropout 等操作影响结果
|
||
self.model.eval()
|
||
# 将模型放置到正确的设备上(GPU 或 CPU)
|
||
self.model.to(self.device)
|
||
# 初始化分词器
|
||
self.tokenizer = Tokenizer(tokenizer_model=self.tokenizer_model_path) # 根据指定的路径加载分词器
|
||
|
||
def sample(self,
|
||
start="Hello!", # 生成文本的起始提示词,可以是任意字符串
|
||
num_samples=3, # 生成样本的数量,默认生成 3 个样本
|
||
max_new_tokens=256, # 每个样本生成的最大 token 数,默认最多生成 256 个 token
|
||
temperature=1.0, # 控制生成的随机性,1.0 为标准,值越大越随机
|
||
top_k=300): # 保留概率最高的 top_k 个 token,限制生成时的选择范围
|
||
"""
|
||
根据给定的起始文本生成样本。
|
||
|
||
:param start: 生成文本的起始提示词
|
||
:param num_samples: 要生成的文本样本数
|
||
:param max_new_tokens: 每个样本生成的最大 token 数
|
||
:param temperature: 控制生成的随机性,值越小生成越确定,值越大生成越随机
|
||
:param top_k: 限制生成时选择的 token 范围
|
||
:return: 生成的文本样本列表
|
||
"""
|
||
# 如果 start 是以 'FILE:' 开头,表示从文件中读取起始文本
|
||
if start.startswith('FILE:'):
|
||
with open(start[5:], 'r', encoding='utf-8') as f:
|
||
start = f.read() # 读取文件内容作为起始文本
|
||
|
||
# 将起始文本编码为 token id 序列
|
||
start_ids = self.tokenizer.encode(start, bos=True, eos=False) # bos=True 表示加上句首标记,eos=False 表示不加句尾标记
|
||
x = (torch.tensor(start_ids, dtype=torch.long, device=self.device)[None, ...]) # 将编码后的 token id 转为 PyTorch 张量
|
||
|
||
generated_texts = [] # 用于保存生成的文本样本
|
||
with torch.no_grad(): # 禁用梯度计算,提升效率
|
||
with self.ctx: # 进入自动混合精度的上下文(如果是 GPU 并使用 float16 时)
|
||
for k in range(num_samples): # 循环生成指定数量的样本
|
||
y = self.model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k) # 生成文本
|
||
generated_texts.append(self.tokenizer.decode(y[0].tolist())) # 解码生成的 token 序列为可读文本
|
||
|
||
return generated_texts # 返回生成的文本样本
|
||
|
||
# 示例使用
|
||
if __name__ == "__main__":
|
||
parser = argparse.ArgumentParser()
|
||
parser.add_argument("--prompt", type=str, default="One day, Lily met a Shoggoth")
|
||
args = parser.parse_args()
|
||
|
||
generator = TextGenerator() # 初始化生成器
|
||
samples = generator.sample(start=args.prompt, num_samples=3, max_new_tokens=256) # 生成 3 个样本
|
||
for i, sample in enumerate(samples):
|
||
print(f"\nSample {i+1}:\n{sample}\n{'-'*20}") # 打印生成的样本并用分隔线分割
|