Add: ch5.3 code

This commit is contained in:
KMnO4-zx
2024-09-22 16:02:14 +08:00
parent c579aff59d
commit 9e6d8a3f77
9 changed files with 788 additions and 739 deletions

View File

@@ -0,0 +1,147 @@
import glob
import json
import os
from tqdm import tqdm
import requests
import sentencepiece as spm
import argparse
DATA_CACHE_DIR = 'data'
def download_file(url: str, fname: str, chunk_size=1024):
"""发送HTTP GET请求以流式方式获取文件"""
resp = requests.get(url, stream=True)
# 获取文件的总大小以字节为单位默认为0如果没有提供'content-length'头信息
total = int(resp.headers.get("content-length", 0))
# 以写二进制模式打开一个文件以保存下载的内容
with open(fname, "wb") as file, tqdm(
desc=fname, # 进度条前面的描述信息(通常是文件名)
total=total, # 总的字节数,用于设置进度条的总长度
unit="iB", # 进度条的单位,'iB'代表二进制字节
unit_scale=True, # 启用单位缩放如KB、MB等
unit_divisor=1024, # 设置单位换算的除数这里为1024
) as bar:
# 逐块读取响应内容并写入文件
for data in resp.iter_content(chunk_size=chunk_size):
size = file.write(data) # 写入数据块到文件
bar.update(size) # 更新进度条
def download():
"""在DATA_CACHE_DIR中创建目录如果目录不存在则创建"""
os.makedirs(DATA_CACHE_DIR, exist_ok=True)
# 定义TinyStories数据集的下载URL和保存的文件名
data_url = "https://www.modelscope.cn/datasets/AI-ModelScope/TinyStories/resolve/master/TinyStories_all_data.tar.gz"
data_filename = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data.tar.gz")
# 检查数据集是否已经下载,如果没有下载则进行下载
if not os.path.exists(data_filename):
print(f"Downloading {data_url} to {data_filename}...")
download_file(data_url, data_filename) # 使用之前定义的download_file函数进行下载
else:
print(f"{data_filename} already exists, skipping download...")
# 定义解压缩后的数据目录
data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
# 检查数据目录是否存在,如果不存在则解压缩数据集
if not os.path.exists(data_dir):
os.makedirs(data_dir, exist_ok=True) # 创建数据目录
print(f"Unpacking {data_filename}...")
os.system(f"tar -xzf {data_filename} -C {data_dir}") # 使用系统命令解压缩.tar.gz文件
else:
print(f"{data_dir} already exists, skipping unpacking...")
# 查找解压后的所有JSON文件排序后获取文件名列表
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
# 打开第一个JSON文件并读取内容
with open(shard_filenames[0], "r") as f:
data = json.load(f) # 将JSON文件内容加载到变量data中
print("Download done.") # 下载完成信息
print(f"Number of shards: {len(shard_filenames)}") # 打印解压后数据分片的数量
print(f"Example story:\n{data[0]}") # 打印第一个分片中的一个示例故事
def load_text_from_files(path):
path_list = glob.glob(path)
text_data = []
for file_path in path_list:
with open(file_path, 'r', encoding='utf-8') as file:
text_data.extend(file.readlines())
return text_data
def batch_iterator(text_data, batch_size=648):
for i in range(0, len(text_data), batch_size):
yield text_data[i:i + batch_size]
def train_vocab(vocab_size: int=32000, num_shards: int=20):
"""
vocab_size: int, 词汇表的大小,决定分词器的词汇量。
num_shards: int, 用于加快词汇表训练的效率,指定要处理的分片数量。
"""
# 确保词汇表大小为正数
assert vocab_size > 0, "Vocab size must be positive"
# SentencePiece 模型的前缀路径,将用于保存分词器
prefix = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}")
# 1) 将多个分片中的文本导出为单个文本文件 tiny.txt
tiny_file = os.path.join(DATA_CACHE_DIR, "tiny.txt")
data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
# 创建 tiny.txt 文件并写入指定数量的分片中的文本
print(f"Writing temporary file {tiny_file} with {num_shards} shards...")
with open(tiny_file, "w", encoding="utf-8") as of:
# 遍历前 num_shards 个分片
for shard in tqdm(shard_filenames[:num_shards]):
with open(shard, "r") as f:
data = json.load(f) # 读取分片中的JSON数据
# 遍历每个例子,将其中的故事文本写入 tiny.txt 文件
for example in data:
text = example["story"]
text = text.strip() # 去除文本首尾的空白字符
of.write(text + "\n") # 每个文本写入一行
# 输出生成的 tiny.txt 文件的大小
print(f"Size is: {os.path.getsize(tiny_file) / 1024 / 1024:.2f} MB")
# 2) 使用 SentencePiece 训练分词器
print("Will now train the vocab...")
spm.SentencePieceTrainer.train(
input=tiny_file, # 输入文件为之前生成的 tiny.txt
model_prefix=prefix, # 模型前缀路径
model_type="bpe", # 使用 Byte-Pair Encoding (BPE) 训练分词器
vocab_size=vocab_size, # 词汇表大小
self_test_sample_size=0, # 自测样本大小设置为 0
input_format="text", # 输入文件格式为纯文本
character_coverage=1.0, # 覆盖所有字符(包括非常见字符)
num_threads=os.cpu_count(), # 使用 CPU 的线程数
split_digits=True, # 拆分数字
allow_whitespace_only_pieces=True, # 允许仅由空格组成的词元
byte_fallback=True, # 启用字节级回退
unk_surface=r" \342\201\207 ", # UNK token 表示未知字符的方式
normalization_rule_name="identity" # 使用“identity”归一化规则
)
# 3) 可选的清理操作,询问用户是否删除临时文件 tiny.txt
dec = input(f"Delete the temporary file {tiny_file}? [y/N] ")
if dec.lower() == "y":
os.remove(tiny_file) # 删除临时文件
print(f"Deleted {tiny_file}")
# 输出模型保存的路径
print(f"Trained tokenizer is in {prefix}.model")
print("Done.")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--download", type=bool, default=True, help="download the dataset")
parser.add_argument("--vocab_size", type=int, default=4096, help="vocab size")
args = parser.parse_args()
if args.download:
download()
train_vocab(args.vocab_size)