docs: 更新README和文档内容,添加模型下载链接

- 在README中新增模型下载章节,包含ModelScope链接
- 更新模型示例代码中的默认检查点路径
- 优化训练脚本的注释和参数说明
- 添加中文文档的模型下载和体验地址
- 修复文档中的训练时长和设备信息
This commit is contained in:
KMnO4-zx
2025-06-22 10:05:36 +08:00
parent b421894dcc
commit 3b24a9fd1e
5 changed files with 333 additions and 86 deletions

View File

@@ -8,7 +8,7 @@ import argparse
class TextGenerator:
def __init__(self,
checkpoint='out/SkyWork_pretrain_768_12_6144.pth', # 模型检查点路径
checkpoint='./base_model_215M/pretrain_1024_18_6144.pth', # 模型检查点路径
tokenizer_model_path='./tokenizer_k/', # 分词器模型路径
seed=42, # 随机种子,确保可重复性
device=None, # 设备,优先使用 CUDA如果没有可用的 CUDA则使用 CPU
@@ -55,7 +55,7 @@ class TextGenerator:
def chat_template(self, prompt):
message = [
{"role": "system", "content": "你是一个AI助手。"},
{"role": "system", "content": "你是一个AI助手,你的名字叫小明"},
{"role": "user", "content": prompt}
]
return self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
@@ -126,18 +126,6 @@ class TextGenerator:
return generated_texts # 返回生成的文本样本
if __name__ == "__main__":
print("\n ------------------- SFT Sample ------------------- \n")
sft_prompt_datas = [
'你好呀',
"中国的首都是哪里?",
"1+1等于多少",
]
generator = TextGenerator(checkpoint='./BeelGroup_sft_model_215M/sft_dim1024_layers18_vocab_size6144.pth') # 初始化生成器
for i in range(len(sft_prompt_datas)):
samples = generator.sft_sample(start=sft_prompt_datas[i], num_samples=1, max_new_tokens=512, temperature=0.75)
print(f"\nSample {i+1}:\nQuestion: {sft_prompt_datas[i]} \nAI answer: {samples[0]}\n{'-'*20}") # 打印生成的样本并用分隔线分割
print("------------------- Pretrain Sample ------------------- \n")
pretrain_prompt_datas = [
@@ -145,7 +133,22 @@ if __name__ == "__main__":
'<|im_start|>中国矿业大学(北京)地球科学与测绘工程学院',
]
generator = TextGenerator(checkpoint='./base_monkey_215M/pretrain_1024_18_6144.pth') # 初始化生成器
generator = TextGenerator(checkpoint='./base_model_215M/pretrain_1024_18_6144.pth') # 初始化生成器
for i in range(len(pretrain_prompt_datas)):
samples = generator.pretrain_sample(start=pretrain_prompt_datas[i], num_samples=1, max_new_tokens=120, temperature=1.0)
print(f"\nSample {i+1}:\n{pretrain_prompt_datas[i]}{samples[0]}\n{'-'*20}") # 打印生成的样本并用分隔线分割
samples = generator.pretrain_sample(start=pretrain_prompt_datas[i], num_samples=1, max_new_tokens=120, temperature=0.75)
print(f"\nSample {i+1}:\n{pretrain_prompt_datas[i]}{samples[0]}\n{'-'*20}") # 打印生成的样本并用分隔线分割
print("\n ------------------- SFT Sample ------------------- \n")
sft_prompt_datas = [
'你好呀',
"中国的首都是哪里?",
"1+12等于多少",
"你是谁?"
]
generator = TextGenerator(checkpoint='./sft_model_215M/sft_dim1024_layers18_vocab_size6144.pth') # 初始化生成器
for i in range(len(sft_prompt_datas)):
samples = generator.sft_sample(start=sft_prompt_datas[i], num_samples=1, max_new_tokens=128, temperature=0.6)
print(f"\nSample {i+1}:\nQuestion: {sft_prompt_datas[i]} \nAI answer: {samples[0]}\n{'-'*20}") # 打印生成的样本并用分隔线分割