Files
happy-llm/docs/chapter5/code/train_vocab.py
2024-09-22 16:02:14 +08:00

147 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import glob
import json
import os
from tqdm import tqdm
import requests
import sentencepiece as spm
import argparse
DATA_CACHE_DIR = 'data'
def download_file(url: str, fname: str, chunk_size=1024):
"""发送HTTP GET请求以流式方式获取文件"""
resp = requests.get(url, stream=True)
# 获取文件的总大小以字节为单位默认为0如果没有提供'content-length'头信息
total = int(resp.headers.get("content-length", 0))
# 以写二进制模式打开一个文件以保存下载的内容
with open(fname, "wb") as file, tqdm(
desc=fname, # 进度条前面的描述信息(通常是文件名)
total=total, # 总的字节数,用于设置进度条的总长度
unit="iB", # 进度条的单位,'iB'代表二进制字节
unit_scale=True, # 启用单位缩放如KB、MB等
unit_divisor=1024, # 设置单位换算的除数这里为1024
) as bar:
# 逐块读取响应内容并写入文件
for data in resp.iter_content(chunk_size=chunk_size):
size = file.write(data) # 写入数据块到文件
bar.update(size) # 更新进度条
def download():
"""在DATA_CACHE_DIR中创建目录如果目录不存在则创建"""
os.makedirs(DATA_CACHE_DIR, exist_ok=True)
# 定义TinyStories数据集的下载URL和保存的文件名
data_url = "https://www.modelscope.cn/datasets/AI-ModelScope/TinyStories/resolve/master/TinyStories_all_data.tar.gz"
data_filename = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data.tar.gz")
# 检查数据集是否已经下载,如果没有下载则进行下载
if not os.path.exists(data_filename):
print(f"Downloading {data_url} to {data_filename}...")
download_file(data_url, data_filename) # 使用之前定义的download_file函数进行下载
else:
print(f"{data_filename} already exists, skipping download...")
# 定义解压缩后的数据目录
data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
# 检查数据目录是否存在,如果不存在则解压缩数据集
if not os.path.exists(data_dir):
os.makedirs(data_dir, exist_ok=True) # 创建数据目录
print(f"Unpacking {data_filename}...")
os.system(f"tar -xzf {data_filename} -C {data_dir}") # 使用系统命令解压缩.tar.gz文件
else:
print(f"{data_dir} already exists, skipping unpacking...")
# 查找解压后的所有JSON文件排序后获取文件名列表
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
# 打开第一个JSON文件并读取内容
with open(shard_filenames[0], "r") as f:
data = json.load(f) # 将JSON文件内容加载到变量data中
print("Download done.") # 下载完成信息
print(f"Number of shards: {len(shard_filenames)}") # 打印解压后数据分片的数量
print(f"Example story:\n{data[0]}") # 打印第一个分片中的一个示例故事
def load_text_from_files(path):
path_list = glob.glob(path)
text_data = []
for file_path in path_list:
with open(file_path, 'r', encoding='utf-8') as file:
text_data.extend(file.readlines())
return text_data
def batch_iterator(text_data, batch_size=648):
for i in range(0, len(text_data), batch_size):
yield text_data[i:i + batch_size]
def train_vocab(vocab_size: int=32000, num_shards: int=20):
"""
vocab_size: int, 词汇表的大小,决定分词器的词汇量。
num_shards: int, 用于加快词汇表训练的效率,指定要处理的分片数量。
"""
# 确保词汇表大小为正数
assert vocab_size > 0, "Vocab size must be positive"
# SentencePiece 模型的前缀路径,将用于保存分词器
prefix = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}")
# 1) 将多个分片中的文本导出为单个文本文件 tiny.txt
tiny_file = os.path.join(DATA_CACHE_DIR, "tiny.txt")
data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
# 创建 tiny.txt 文件并写入指定数量的分片中的文本
print(f"Writing temporary file {tiny_file} with {num_shards} shards...")
with open(tiny_file, "w", encoding="utf-8") as of:
# 遍历前 num_shards 个分片
for shard in tqdm(shard_filenames[:num_shards]):
with open(shard, "r") as f:
data = json.load(f) # 读取分片中的JSON数据
# 遍历每个例子,将其中的故事文本写入 tiny.txt 文件
for example in data:
text = example["story"]
text = text.strip() # 去除文本首尾的空白字符
of.write(text + "\n") # 每个文本写入一行
# 输出生成的 tiny.txt 文件的大小
print(f"Size is: {os.path.getsize(tiny_file) / 1024 / 1024:.2f} MB")
# 2) 使用 SentencePiece 训练分词器
print("Will now train the vocab...")
spm.SentencePieceTrainer.train(
input=tiny_file, # 输入文件为之前生成的 tiny.txt
model_prefix=prefix, # 模型前缀路径
model_type="bpe", # 使用 Byte-Pair Encoding (BPE) 训练分词器
vocab_size=vocab_size, # 词汇表大小
self_test_sample_size=0, # 自测样本大小设置为 0
input_format="text", # 输入文件格式为纯文本
character_coverage=1.0, # 覆盖所有字符(包括非常见字符)
num_threads=os.cpu_count(), # 使用 CPU 的线程数
split_digits=True, # 拆分数字
allow_whitespace_only_pieces=True, # 允许仅由空格组成的词元
byte_fallback=True, # 启用字节级回退
unk_surface=r" \342\201\207 ", # UNK token 表示未知字符的方式
normalization_rule_name="identity" # 使用“identity”归一化规则
)
# 3) 可选的清理操作,询问用户是否删除临时文件 tiny.txt
dec = input(f"Delete the temporary file {tiny_file}? [y/N] ")
if dec.lower() == "y":
os.remove(tiny_file) # 删除临时文件
print(f"Deleted {tiny_file}")
# 输出模型保存的路径
print(f"Trained tokenizer is in {prefix}.model")
print("Done.")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--download", type=bool, default=True, help="download the dataset")
parser.add_argument("--vocab_size", type=int, default=4096, help="vocab size")
args = parser.parse_args()
if args.download:
download()
train_vocab(args.vocab_size)