docs: update README and documentation, add model download links
- Add a model download section to the README, including ModelScope links
- Update the default checkpoint path in the model sample code
- Improve the comments and parameter descriptions in the training scripts
- Add model download and online demo links to the Chinese documentation
- Fix the training duration and device information in the documentation
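
Since the README now points to ModelScope for the released checkpoints, below is a minimal, hypothetical sketch of fetching them with the ModelScope SDK rather than the browser link; the model id is a placeholder and should be replaced with the id given in the README's download section.

```python
# Hypothetical sketch: download the released checkpoint via the ModelScope SDK.
# 'your-namespace/your-model' is a placeholder model id, not the real one.
from modelscope import snapshot_download

local_dir = snapshot_download('your-namespace/your-model')
print(local_dir)  # local directory expected to contain e.g. pretrain_1024_18_6144.pth
```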
@@ -8,7 +8,7 @@ import argparse
 
 class TextGenerator:
     def __init__(self,
-                 checkpoint='out/SkyWork_pretrain_768_12_6144.pth',  # model checkpoint path
+                 checkpoint='./base_model_215M/pretrain_1024_18_6144.pth',  # model checkpoint path
                  tokenizer_model_path='./tokenizer_k/',  # tokenizer model path
                  seed=42,  # random seed, for reproducibility
                  device=None,  # device: prefer CUDA, fall back to CPU if CUDA is unavailable
@@ -55,7 +55,7 @@ class TextGenerator:
 
     def chat_template(self, prompt):
         message = [
-            {"role": "system", "content": "你是一个AI助手。"},
+            {"role": "system", "content": "你是一个AI助手,你的名字叫小明。"},
            {"role": "user", "content": prompt}
        ]
        return self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
@@ -126,18 +126,6 @@ class TextGenerator:
        return generated_texts  # return the generated text samples
 
 if __name__ == "__main__":
-    print("\n ------------------- SFT Sample ------------------- \n")
-    sft_prompt_datas = [
-        '你好呀',
-        "中国的首都是哪里?",
-        "1+1等于多少?",
-    ]
-    generator = TextGenerator(checkpoint='./BeelGroup_sft_model_215M/sft_dim1024_layers18_vocab_size6144.pth')  # initialize the generator
-    for i in range(len(sft_prompt_datas)):
-        samples = generator.sft_sample(start=sft_prompt_datas[i], num_samples=1, max_new_tokens=512, temperature=0.75)
-        print(f"\nSample {i+1}:\nQuestion: {sft_prompt_datas[i]} \nAI answer: {samples[0]}\n{'-'*20}")  # print each sample, separated by a divider line
-
-
    print("------------------- Pretrain Sample ------------------- \n")
 
    pretrain_prompt_datas = [
@@ -145,7 +133,22 @@ if __name__ == "__main__":
        '<|im_start|>中国矿业大学(北京)地球科学与测绘工程学院',
    ]
 
-    generator = TextGenerator(checkpoint='./base_monkey_215M/pretrain_1024_18_6144.pth')  # initialize the generator
+    generator = TextGenerator(checkpoint='./base_model_215M/pretrain_1024_18_6144.pth')  # initialize the generator
    for i in range(len(pretrain_prompt_datas)):
-        samples = generator.pretrain_sample(start=pretrain_prompt_datas[i], num_samples=1, max_new_tokens=120, temperature=1.0)
-        print(f"\nSample {i+1}:\n{pretrain_prompt_datas[i]}{samples[0]}\n{'-'*20}")  # print each sample, separated by a divider line
+        samples = generator.pretrain_sample(start=pretrain_prompt_datas[i], num_samples=1, max_new_tokens=120, temperature=0.75)
+        print(f"\nSample {i+1}:\n{pretrain_prompt_datas[i]}{samples[0]}\n{'-'*20}")  # print each sample, separated by a divider line
+
+    print("\n ------------------- SFT Sample ------------------- \n")
+
+    sft_prompt_datas = [
+        '你好呀',
+        "中国的首都是哪里?",
+        "1+12等于多少?",
+        "你是谁?"
+    ]
+    generator = TextGenerator(checkpoint='./sft_model_215M/sft_dim1024_layers18_vocab_size6144.pth')  # initialize the generator
+    for i in range(len(sft_prompt_datas)):
+        samples = generator.sft_sample(start=sft_prompt_datas[i], num_samples=1, max_new_tokens=128, temperature=0.6)
+        print(f"\nSample {i+1}:\nQuestion: {sft_prompt_datas[i]} \nAI answer: {samples[0]}\n{'-'*20}")  # print each sample, separated by a divider line
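
For context on the chat_template() change above, here is a standalone sketch (not part of the diff) of what apply_chat_template roughly produces with the new system prompt. The exact special tokens depend on the chat template shipped in ./tokenizer_k/; the ChatML-style markers are an assumption based on the '<|im_start|>' prompts used elsewhere in the script.

```python
# Sketch only: render the updated message list with the tokenizer's chat template.
# Assumes ./tokenizer_k/ loads as a Hugging Face tokenizer with a ChatML-style template.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('./tokenizer_k/')
message = [
    {"role": "system", "content": "你是一个AI助手,你的名字叫小明。"},  # new system prompt
    {"role": "user", "content": "你是谁?"},
]
prompt = tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
print(prompt)  # e.g. "<|im_start|>system\n...<|im_end|>\n<|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n"
```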