Files
happy-llm/docs/chapter6/code/pretrain.ipynb
2025-04-25 10:04:43 +08:00

883 lines
50 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "markdown",
"id": "bb9102c3-5b8d-4295-8f29-113b35ec5679",
"metadata": {},
"source": [
"# 一、LLM 预训练"
]
},
{
"cell_type": "markdown",
"id": "8557a6a6-294a-49c3-a8f6-e58bc3bf443d",
"metadata": {},
"source": [
"1.1 初始化 LLM"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "25f1fad8-772c-474e-a43e-77623106485d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Qwen2Config {\n",
" \"_name_or_path\": \"autodl-tmp/qwen-1.5b\",\n",
" \"architectures\": [\n",
" \"Qwen2ForCausalLM\"\n",
" ],\n",
" \"attention_dropout\": 0.0,\n",
" \"bos_token_id\": 151643,\n",
" \"eos_token_id\": 151643,\n",
" \"hidden_act\": \"silu\",\n",
" \"hidden_size\": 1536,\n",
" \"initializer_range\": 0.02,\n",
" \"intermediate_size\": 8960,\n",
" \"max_position_embeddings\": 131072,\n",
" \"max_window_layers\": 28,\n",
" \"model_type\": \"qwen2\",\n",
" \"num_attention_heads\": 12,\n",
" \"num_hidden_layers\": 28,\n",
" \"num_key_value_heads\": 2,\n",
" \"rms_norm_eps\": 1e-06,\n",
" \"rope_theta\": 1000000.0,\n",
" \"sliding_window\": null,\n",
" \"tie_word_embeddings\": true,\n",
" \"torch_dtype\": \"bfloat16\",\n",
" \"transformers_version\": \"4.44.2\",\n",
" \"use_cache\": true,\n",
" \"use_mrope\": false,\n",
" \"use_sliding_window\": false,\n",
" \"vocab_size\": 151936\n",
"}"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 加载定义好的模型参数-此处以 Qwen-2.5-1.5B 为例\n",
"# 使用 transforemrs 的 Config 类进行加载\n",
"from transformers import AutoConfig\n",
"\n",
"model_path = \"autodl-tmp/qwen-1.5b\"\n",
"config = AutoConfig.from_pretrained(model_path)\n",
"config"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "82b075a1-4fe9-4abb-b5b4-769d1c1a7156",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training new model from scratch - Total size=1472.20M params\n"
]
}
],
"source": [
"# 使用该配置生成一个定义好的模型\n",
"from transformers import AutoModelForCausalLM\n",
"\n",
"model = AutoModelForCausalLM.from_config(config,trust_remote_code=True)\n",
"model.to(\"cuda\")\n",
"n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())\n",
"print(f\"Training new model from scratch - Total size={n_params/2**20:.2f}M params\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e05ea707-23db-4e67-8b7d-e57d019887dd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Qwen2ForCausalLM(\n",
" (model): Qwen2Model(\n",
" (embed_tokens): Embedding(151936, 1536)\n",
" (layers): ModuleList(\n",
" (0-27): 28 x Qwen2DecoderLayer(\n",
" (self_attn): Qwen2SdpaAttention(\n",
" (q_proj): Linear(in_features=1536, out_features=1536, bias=True)\n",
" (k_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
" (v_proj): Linear(in_features=1536, out_features=256, bias=True)\n",
" (o_proj): Linear(in_features=1536, out_features=1536, bias=False)\n",
" (rotary_emb): Qwen2RotaryEmbedding()\n",
" )\n",
" (mlp): Qwen2MLP(\n",
" (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
" (up_proj): Linear(in_features=1536, out_features=8960, bias=False)\n",
" (down_proj): Linear(in_features=8960, out_features=1536, bias=False)\n",
" (act_fn): SiLU()\n",
" )\n",
" (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
" (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
" )\n",
" )\n",
" (norm): Qwen2RMSNorm((1536,), eps=1e-06)\n",
" )\n",
" (lm_head): Linear(in_features=1536, out_features=151936, bias=False)\n",
")"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 看一下模型\n",
"model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3408137b-eb50-4119-be1c-7a4ff951ab24",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Qwen2TokenizerFast(name_or_path='autodl-tmp/qwen-1.5b', vocab_size=151643, model_max_length=131072, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '<|endoftext|>', 'pad_token': '<|endoftext|>', 'additional_special_tokens': ['<|im_start|>', '<|im_end|>', '<|object_ref_start|>', '<|object_ref_end|>', '<|box_start|>', '<|box_end|>', '<|quad_start|>', '<|quad_end|>', '<|vision_start|>', '<|vision_end|>', '<|vision_pad|>', '<|image_pad|>', '<|video_pad|>']}, clean_up_tokenization_spaces=False), added_tokens_decoder={\n",
"\t151643: AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151644: AddedToken(\"<|im_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151645: AddedToken(\"<|im_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151646: AddedToken(\"<|object_ref_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151647: AddedToken(\"<|object_ref_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151648: AddedToken(\"<|box_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151649: AddedToken(\"<|box_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151650: AddedToken(\"<|quad_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151651: AddedToken(\"<|quad_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151652: AddedToken(\"<|vision_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151653: AddedToken(\"<|vision_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151654: AddedToken(\"<|vision_pad|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151655: AddedToken(\"<|image_pad|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151656: AddedToken(\"<|video_pad|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151657: AddedToken(\"<tool_call>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"\t151658: AddedToken(\"</tool_call>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"\t151659: AddedToken(\"<|fim_prefix|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"\t151660: AddedToken(\"<|fim_middle|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"\t151661: AddedToken(\"<|fim_suffix|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"\t151662: AddedToken(\"<|fim_pad|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"\t151663: AddedToken(\"<|repo_name|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"\t151664: AddedToken(\"<|file_sep|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 加载一个预训练好的 tokenizer\n",
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
"tokenizer"
]
},
{
"cell_type": "markdown",
"id": "221a0fe2-a244-4e73-b82c-6da255d710dd",
"metadata": {},
"source": [
"1.2 预训练数据准备"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "936261a6-94cf-4cf3-842c-d3f1fde47a71",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "66ae9baa159b424ea5f5bc8d05b9b567",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating train split: 0 examples [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# 加载预训练数据\n",
"from datasets import load_dataset\n",
"\n",
"ds = load_dataset('json', data_files='autodl-tmp/dataset/pretrain_data/mobvoi_seq_monkey_general_open_corpus_small.jsonl')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "068edbb9-cb3c-49b1-aaf9-67b97ddfc58c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'text': '在查处虚开增值税专用发票案件中,常常涉及进项留抵税额和税款损失的认定和处理。在计算税款损失时,要不要将进项留抵税额包括在内?\\n对此实务中存在意见分歧。\\n有人主张归并即计算税款损失时包括进项留抵税额\\n有人主张剥离即计算税款损失时剔除进项留抵税额。分析这个问题需要确定进项留抵税额与税款损失之间是什么关系。\\n理清这二者之间的关系首先需要了解增值税的概念和其抵扣机制。增值税是以商品货物、服务等在流转过程中产生的增值额作为计税依据而征收的一种流转税。为避免重复征税在增值税中存在抵扣链条机制。\\n一般而言交易上游企业缴纳的税额交易下游企业可以对相应的税额进行抵扣。\\n对增值税一般纳税人来说其购进货物、服务等取得增值税专用发票发票上的税额是进项税额。\\n其出售货物、服务等向购买方开具增值税专用发票发票的税额是销项税额。\\n一般情况下销项税额减去进项税额的金额是应纳税额企业根据应纳税额按期申报纳税。\\n其次需要了解进项留抵税额的概念及产生原因。\\n在计算销项税额和进项税额的差额时有时会出现负数即当期进项税额大于当期销项税额。这个差额在当期未实现抵扣为进项留抵税额在以后纳税人有销项税额时再进行抵扣。\\n企业产生进项留抵税额的主要原因是其进项税额和销项税额时间上的不一致。\\n例如企业前期集中采购货物和服务投资大销项税率低于进项税率等。\\n从税款抵扣的角度看进项留抵税额只是购进的这部分进项税额参与到增值税应纳税额的计算过程中但是其对应的进项税额抵扣还未真正实现一般要等到其未来有相应的销项税额时才能真正实现进项税额抵扣。\\n可见进项留抵税额处于不确定状态能否抵扣受到很多因素影响例如企业经营中断没有销项税额这时进项留抵税额就无法实现抵扣。但如果企业按照税收政策规定申请进项留抵退税进项税额抵扣就随之实现。\\n最后需要了解税款损失的概念。\\n税款损失通常是指因虚开增值税专用发票导致国家税款被骗或者流失的金额。关于税款损失实务中有多种表述。\\n例如北京大学法学院教授陈兴良曾谈到虚开行为本身不会造成国家税款损失只有利用发票抵扣时才会造成国家税款损失。刘兵等编著的《虚开增值税专用发票案例司法观点和案例解析》一书中提到“给国家税款造成损失的数额实际上就是被骗取的国家税款在侦查终结以前无法追回的部分。”\\n赵清海与王家欣合著的《增值税专用发票虚开的判定与预防》一书中提到“司法实践中受票方用虚开的增值税专用发票予以抵扣的税款从而导致受票方应纳税额的减少是法院所认定的国家税款流失的金额。”\\n从这些表述可见税款损失应该是实际造成的损失不应包括不确定的部分——进项留抵税额进项留抵税额与税款损失之间不能直接画等号。\\n综上分析进项留抵税额只是使国家税款处于可能被抵扣的状态还没有真正造成国家税款流失一般情况下应将其从税款损失中剥离特殊条件下将其归并入税款损失。\\n例如当纳税人造假按照税收政策规定申请进项留抵税额退税后有关税款损失将会从危险状态转化成危害结果这时候要将有关进项留抵税额并入税款损失。\\n所以在虚开增值税专用发票案件中一般情况下如果以纳税人的进项税额作为税款损失的计算基数在对其进行行政处罚或刑事处罚时应把进项留抵税额从税款损失中剔除但纳税人申请进项留抵退税的除外。这样处理把处罚与危害结果相对应体现行政处罚法的过罚相当原则和刑法的罚当其罪原则。'}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds[\"train\"][0]\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ef372a1f-e82f-4f5d-8495-f21f06b35635",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['text']\n"
]
}
],
"source": [
"# 查看特征\n",
"column_names = list(ds[\"train\"].features)\n",
"print(column_names)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1fa637f5-3b23-4a33-b19b-4c90d1815c39",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "316489431b9e494eb8358a0d0048096f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Running tokenizer on dataset (num_proc=10): 0%| | 0/100001 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# 对数据集进行 tokenize\n",
"\n",
"def tokenize_function(examples):\n",
" # 使用预先加载的 tokenizer 进行分词\n",
" output = tokenizer([item for item in examples[\"text\"]])\n",
" return output\n",
"\n",
"# 批量处理\n",
"tokenized_datasets = ds.map(\n",
" tokenize_function,\n",
" batched=True,\n",
" num_proc=10,\n",
" remove_columns=column_names,\n",
" load_from_cache_file=True,\n",
" desc=\"Running tokenizer on dataset\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "ec30197e-fe7f-4f0d-903c-663146421f58",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['input_ids', 'attention_mask'],\n",
" num_rows: 100001\n",
" })\n",
"})"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenized_datasets"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "9ec5431b-e3cf-44e0-9260-479d984253e4",
"metadata": {},
"outputs": [],
"source": [
"# 预训练一般将文本拼接成固定长度的文本段\n",
"from itertools import chain\n",
"\n",
"# 这里我们取块长为 2048\n",
"block_size = 2048\n",
"\n",
"def group_texts(examples):\n",
" # 将文本段拼接起来\n",
" concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}\n",
" # 计算拼起来的整体长度\n",
" total_length = len(concatenated_examples[list(examples.keys())[0]])\n",
" # 如果长度太长,进行分块\n",
" if total_length >= block_size:\n",
" total_length = (total_length // block_size) * block_size\n",
" # Split by chunks of max_len.\n",
" result = {\n",
" k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n",
" for k, t in concatenated_examples.items()\n",
" }\n",
" # print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')) \n",
" print(\"group texts input examples length%d after_group size%d\"%(len(examples['input_ids']),len(result[\"input_ids\"])))\n",
" result[\"labels\"] = result[\"input_ids\"].copy()\n",
" return result"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "38428a53-6ba6-429f-8c4b-0985579e726b",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ae53ab8aaa0043418c2b7eb86f3d462b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Grouping texts in chunks of 2048 (num_proc=10): 0%| | 0/100001 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"group texts input examples length10001 after_group size2752\n",
"group texts input examples length10000 after_group size2817\n",
"group texts input examples length10000 after_group size2820\n",
"group texts input examples length10000 after_group size2817\n",
"group texts input examples length10000 after_group size2787\n",
"group texts input examples length10000 after_group size2797\n",
"group texts input examples length10000 after_group size2800\n",
"group texts input examples length10000 after_group size2835\n",
"group texts input examples length10000 after_group size2778\n",
"group texts input examples length10000 after_group size2825\n"
]
}
],
"source": [
"# 批量处理\n",
"lm_datasets = tokenized_datasets.map(\n",
" group_texts,\n",
" batched=True,\n",
" num_proc=10,\n",
" load_from_cache_file=True,\n",
" desc=f\"Grouping texts in chunks of {block_size}\",\n",
" batch_size = 40000,\n",
")\n",
"train_dataset = lm_datasets[\"train\"]"
]
},
{
"cell_type": "markdown",
"id": "9b0dd8c8-fb1f-4af9-8af4-21285ba389c0",
"metadata": {},
"source": [
"1.3 使用 Trainer"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "e3e1a85e-fc28-4154-870e-f6a09f108059",
"metadata": {},
"outputs": [],
"source": [
"from transformers import TrainingArguments\n",
"# 配置训练参数\n",
"\n",
"training_args = TrainingArguments(\n",
" output_dir=\"autodl-tmp/output/pretrain\",\n",
" per_device_train_batch_size=4,\n",
" gradient_accumulation_steps=4,\n",
" logging_steps=10,\n",
" num_train_epochs=1,\n",
" save_steps=100, \n",
" learning_rate=1e-4,\n",
" save_on_each_node=True,\n",
" gradient_checkpointing=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "62a97e46-ff06-4278-b318-e3e4da1b93d7",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/miniconda3/lib/python3.10/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n",
"################################################################################\n",
"WARNING!\n",
"The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n",
"future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n",
"to learn more and leave feedback.\n",
"################################################################################\n",
"\n",
" deprecation_warning()\n"
]
}
],
"source": [
"from transformers import Trainer, default_data_collator\n",
"from torchdata.datapipes.iter import IterableWrapper\n",
"\n",
"# 训练器\n",
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset= IterableWrapper(train_dataset),\n",
" eval_dataset= None,\n",
" tokenizer=tokenizer,\n",
" # 默认为 MLM 的 collator使用 CLM 的 collater\n",
" data_collator=default_data_collator\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "a929b11a-99f5-45fc-9f9a-05c0163204c3",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"start train\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\n",
"/root/miniconda3/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:600: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
" return fn(*args, **kwargs)\n",
"/root/miniconda3/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.\n",
" with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='101' max='1751' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [ 101/1751 29:31 < 8:12:11, 0.06 it/s, Epoch 0.06/1]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>10.987700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>9.160700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>30</td>\n",
" <td>8.352700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>40</td>\n",
" <td>8.159800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>50</td>\n",
" <td>8.042500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>60</td>\n",
" <td>8.014400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>70</td>\n",
" <td>7.986700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>80</td>\n",
" <td>7.951800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>90</td>\n",
" <td>7.875500</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"ename": "RuntimeError",
"evalue": "[enforce fail at inline_container.cc:603] . unexpected pos 6546708864 vs 6546708760",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/serialization.py:652\u001b[0m, in \u001b[0;36msave\u001b[0;34m(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization, _disable_byteorder_record)\u001b[0m\n\u001b[1;32m 651\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m _open_zipfile_writer(f) \u001b[38;5;28;01mas\u001b[39;00m opened_zipfile:\n\u001b[0;32m--> 652\u001b[0m \u001b[43m_save\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopened_zipfile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpickle_module\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpickle_protocol\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_disable_byteorder_record\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 653\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n",
"File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/serialization.py:886\u001b[0m, in \u001b[0;36m_save\u001b[0;34m(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record)\u001b[0m\n\u001b[1;32m 885\u001b[0m num_bytes \u001b[38;5;241m=\u001b[39m storage\u001b[38;5;241m.\u001b[39mnbytes()\n\u001b[0;32m--> 886\u001b[0m \u001b[43mzip_file\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrite_record\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstorage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_bytes\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mRuntimeError\u001b[0m: [enforce fail at inline_container.cc:778] . PytorchStreamWriter failed writing file data/401: file write failed",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[15], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mstart train\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m train_result \u001b[38;5;241m=\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:1938\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1936\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m 1937\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1938\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1939\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1940\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1941\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1942\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1943\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:2356\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2353\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mepoch \u001b[38;5;241m=\u001b[39m epoch \u001b[38;5;241m+\u001b[39m (step \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m steps_skipped) \u001b[38;5;241m/\u001b[39m steps_in_epoch\n\u001b[1;32m 2354\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_end(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[0;32m-> 2356\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_maybe_log_save_evaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtr_loss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_norm\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepoch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2357\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 2358\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_substep_end(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n",
"File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:2807\u001b[0m, in \u001b[0;36mTrainer._maybe_log_save_evaluate\u001b[0;34m(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2804\u001b[0m metrics \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_evaluate(trial, ignore_keys_for_eval)\n\u001b[1;32m 2806\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol\u001b[38;5;241m.\u001b[39mshould_save:\n\u001b[0;32m-> 2807\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_save_checkpoint\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetrics\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetrics\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2808\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_save(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n",
"File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:2890\u001b[0m, in \u001b[0;36mTrainer._save_checkpoint\u001b[0;34m(self, model, trial, metrics)\u001b[0m\n\u001b[1;32m 2886\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msave_model(output_dir, _internal_call\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 2888\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39msave_only_model:\n\u001b[1;32m 2889\u001b[0m \u001b[38;5;66;03m# Save optimizer and scheduler\u001b[39;00m\n\u001b[0;32m-> 2890\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_save_optimizer_and_scheduler\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutput_dir\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2891\u001b[0m \u001b[38;5;66;03m# Save RNG state\u001b[39;00m\n\u001b[1;32m 2892\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_save_rng_state(output_dir)\n",
"File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:3006\u001b[0m, in \u001b[0;36mTrainer._save_optimizer_and_scheduler\u001b[0;34m(self, output_dir)\u001b[0m\n\u001b[1;32m 3001\u001b[0m save_fsdp_optimizer(\n\u001b[1;32m 3002\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfsdp_plugin, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptimizer, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, output_dir\n\u001b[1;32m 3003\u001b[0m )\n\u001b[1;32m 3004\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mshould_save:\n\u001b[1;32m 3005\u001b[0m \u001b[38;5;66;03m# deepspeed.save_checkpoint above saves model/optim/sched\u001b[39;00m\n\u001b[0;32m-> 3006\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstate_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutput_dir\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mOPTIMIZER_NAME\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3008\u001b[0m \u001b[38;5;66;03m# Save SCHEDULER & SCALER\u001b[39;00m\n\u001b[1;32m 3009\u001b[0m is_deepspeed_custom_scheduler \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mis_deepspeed_enabled \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\n\u001b[1;32m 3010\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlr_scheduler, DeepSpeedSchedulerWrapper\n\u001b[1;32m 3011\u001b[0m )\n",
"File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/serialization.py:651\u001b[0m, in \u001b[0;36msave\u001b[0;34m(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization, _disable_byteorder_record)\u001b[0m\n\u001b[1;32m 648\u001b[0m _check_save_filelike(f)\n\u001b[1;32m 650\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _use_new_zipfile_serialization:\n\u001b[0;32m--> 651\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m _open_zipfile_writer(f) \u001b[38;5;28;01mas\u001b[39;00m opened_zipfile:\n\u001b[1;32m 652\u001b[0m _save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record)\n\u001b[1;32m 653\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n",
"File \u001b[0;32m~/miniconda3/lib/python3.10/site-packages/torch/serialization.py:499\u001b[0m, in \u001b[0;36m_open_zipfile_writer_file.__exit__\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m 498\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__exit__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 499\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfile_like\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrite_end_of_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 500\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfile_stream \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 501\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfile_stream\u001b[38;5;241m.\u001b[39mclose()\n",
"\u001b[0;31mRuntimeError\u001b[0m: [enforce fail at inline_container.cc:603] . unexpected pos 6546708864 vs 6546708760"
]
}
],
"source": [
"print('start train')\n",
"train_result = trainer.train()"
]
},
{
"cell_type": "markdown",
"id": "a1ed2cd9-7169-4376-a26c-053918074761",
"metadata": {},
"source": [
"# 二、模型 SFT"
]
},
{
"cell_type": "markdown",
"id": "1bb6e02b-c04c-45a4-b36c-904f9fedf61e",
"metadata": {},
"source": [
"2.1 处理指令数据"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0d7cd012-fa2d-4c21-b6a5-c3830d12f59b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'conversations': [{'from': 'human',\n",
" 'value': '针对健身房的新手,设计一套适合他们的健身器械使用指南,包括安全应用、正确姿势等方面。'},\n",
" {'from': 'assistant',\n",
" 'value': '健身器械使用指南\\n1. 开始前,请先进行热身运动。这会帮助你的身体适应运动,并减少受伤的风险。\\n2. 在使用健身器械前,确保你已经了解了其使用方法。请阅读说明书或咨询教练以获得正确的使用技巧。\\n3. 谨防过度使用或过度挑战你的身体。 如果你觉得有些动作太难或太重,请添加锻炼计划,以逐步提高动作难度。\\n4. 使用合适的装备。 确保你拥有合适的运动鞋和舒适的运动服。 不要在裸露的脚或短裤上进行重量训练。\\n5. 在健身器械上使用安全装置。 这些通常用于保护你的身体免受不当操作造成的损伤。 例如,重量训练中,你需要使用杠铃和负重时,一定要使用卡子来防止重量滑落。\\n6. 注意正确的姿势。 如果你的姿势是错误的,那么你的身体很容易被伤害到,你也可能无法获得最佳的锻炼效果。 至关重要的是,保持直立的身体,保持头部和颈部的稳定,并使用合适的重量。\\n7. 保持合理的呼吸方式。 无论何时进行训练,都必须保持正常呼吸。 当你需要用力时,呼气; 当你放松时,吸气。\\n8. 安全存放器械。 在使用健身器械后,你需要把它们归还给适当的位置,以便其他人可以使用它们。\\n总之健身器械的正确使用是关键之一如果不健康和不安全它们将无法帮助您达到您所需的健康成果。 选择适当的训练计划,并为训练提供足够的时间,以备逐渐适应新方法。 对于任何问题,请向教练咨询。'}],\n",
" 'id': '66182880'}"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import json\n",
"\n",
"with open(\"autodl-tmp/dataset/sft_data/BelleGroup/train_3.5M_CN.json\") as f:\n",
" lst = [json.loads(line) for line in f.readlines()]\n",
"\n",
"lst[0]"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2fc8c599-89e9-4c35-a011-d2e52a1a4d9c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Qwen2TokenizerFast(name_or_path='autodl-tmp/qwen-1.5b', vocab_size=151643, model_max_length=131072, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '<|endoftext|>', 'pad_token': '<|endoftext|>', 'additional_special_tokens': ['<|im_start|>', '<|im_end|>', '<|object_ref_start|>', '<|object_ref_end|>', '<|box_start|>', '<|box_end|>', '<|quad_start|>', '<|quad_end|>', '<|vision_start|>', '<|vision_end|>', '<|vision_pad|>', '<|image_pad|>', '<|video_pad|>']}, clean_up_tokenization_spaces=False), added_tokens_decoder={\n",
"\t151643: AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151644: AddedToken(\"<|im_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151645: AddedToken(\"<|im_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151646: AddedToken(\"<|object_ref_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151647: AddedToken(\"<|object_ref_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151648: AddedToken(\"<|box_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151649: AddedToken(\"<|box_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151650: AddedToken(\"<|quad_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151651: AddedToken(\"<|quad_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151652: AddedToken(\"<|vision_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151653: AddedToken(\"<|vision_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151654: AddedToken(\"<|vision_pad|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151655: AddedToken(\"<|image_pad|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151656: AddedToken(\"<|video_pad|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t151657: AddedToken(\"<tool_call>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"\t151658: AddedToken(\"</tool_call>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"\t151659: AddedToken(\"<|fim_prefix|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"\t151660: AddedToken(\"<|fim_middle|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"\t151661: AddedToken(\"<|fim_suffix|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"\t151662: AddedToken(\"<|fim_pad|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"\t151663: AddedToken(\"<|repo_name|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"\t151664: AddedToken(\"<|file_sep|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
"}"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 加载一个预训练好的 tokenizer\n",
"from transformers import AutoTokenizer\n",
"\n",
"model_path = \"autodl-tmp/qwen-1.5b\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
"tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "46730b29-41c0-4295-81f2-913d069b4669",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from tqdm import tqdm\n",
"\n",
"# 指令文本处理\n",
"# 参考https://github.com/QwenLM/Qwen/blob/main/finetune.py\n",
"def preprocess(sources, tokenizer, max_len, system_message: str = \"You are a helpful assistant.\"):\n",
" # prompt 模板\n",
" roles = {\"human\": \"<|im_start|>human\", \"assistant\": \"<|im_start|>assistant\"}\n",
"\n",
" # 不同的 tokenizer 需要特别定义\n",
" # BOS\n",
" im_start = tokenizer(\"<|im_start|>\").input_ids\n",
" # EOS\n",
" im_end = tokenizer(\"<|im_end|>\").input_ids\n",
" # PAD\n",
" IGNORE_TOKEN_ID = tokenizer.pad_token_id\n",
" # 换行符\n",
" nl_tokens = tokenizer('\\n').input_ids\n",
" # 角色标识符\n",
" _system = tokenizer('system').input_ids + nl_tokens\n",
" _user = tokenizer('human').input_ids + nl_tokens\n",
" _assistant = tokenizer('assistant').input_ids + nl_tokens\n",
"\n",
" # 拼接多轮对话\n",
" input_ids, targets = [], []\n",
" for i in tqdm(range(len(sources))):\n",
" source = sources[i]\n",
" # 从 user 开始\n",
" if source[0][\"from\"] != \"human\":\n",
" source = source[1:]\n",
" # 分别是输入和输出\n",
" input_id, target = [], []\n",
" # system: 【BOS】system\\nYou are a helpful assistant.【EOS】\\n\n",
" system = im_start + _system + tokenizer(system_message).input_ids + im_end + nl_tokens\n",
" input_id += system\n",
" # system 不需要拟合\n",
" target += im_start + [IGNORE_TOKEN_ID] * (len(system)-3) + im_end + nl_tokens\n",
" assert len(input_id) == len(target)\n",
" # 依次拼接\n",
" for j, sentence in enumerate(source):\n",
" role = roles[sentence[\"from\"]]\n",
" # user<|im_start|>human\\ninstruction【EOS】\\n\n",
" # assistant<|im_start|>assistant\\nresponse【EOS】\\n\n",
" _input_id = tokenizer(role).input_ids + nl_tokens + \\\n",
" tokenizer(sentence[\"value\"]).input_ids + im_end + nl_tokens\n",
" input_id += _input_id\n",
" if role == '<|im_start|>human':\n",
" # user 不需要拟合\n",
" _target = im_start + [IGNORE_TOKEN_ID] * (len(_input_id)-3) + im_end + nl_tokens\n",
" elif role == '<|im_start|>assistant':\n",
" # assistant 需要拟合\n",
" _target = im_start + [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids) + \\\n",
" _input_id[len(tokenizer(role).input_ids)+1:-2] + im_end + nl_tokens\n",
" else:\n",
" print(role)\n",
" raise NotImplementedError\n",
" target += _target\n",
" assert len(input_id) == len(target)\n",
" # 最后进行 PAD\n",
" input_id += [tokenizer.pad_token_id] * (max_len - len(input_id))\n",
" target += [IGNORE_TOKEN_ID] * (max_len - len(target))\n",
" input_ids.append(input_id[:max_len])\n",
" targets.append(target[:max_len])\n",
" # print(input_ids)\n",
" input_ids = torch.tensor(input_ids)\n",
" targets = torch.tensor(targets)\n",
"\n",
" return dict(\n",
" input_ids=input_ids,\n",
" labels=targets,\n",
" attention_mask=input_ids.ne(tokenizer.pad_token_id),\n",
" )\n"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "7b3576cb-04d7-448a-9bd1-07cb7b344e6d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': tensor([[151644, 8948, 198, ..., 151643, 151643, 151643],\n",
" [151644, 8948, 198, ..., 151643, 151643, 151643]]),\n",
" 'labels': tensor([[151644, 151643, 151643, ..., 151643, 151643, 151643],\n",
" [151644, 151643, 151643, ..., 151643, 151643, 151643]]),\n",
" 'attention_mask': tensor([[ True, True, True, ..., False, False, False],\n",
" [ True, True, True, ..., False, False, False]])}"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 测试一下\n",
"preprocess([lst[0][\"conversations\"],lst[1][\"conversations\"]], tokenizer, 1024)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "63e01dcf-4de4-4470-97dd-3317ef1aa00b",
"metadata": {},
"outputs": [],
"source": [
"# 自定义一个 Dataset\n",
"from torch.utils.data import Dataset\n",
"from typing import Dict\n",
"\n",
"class SupervisedDataset(Dataset):\n",
"\n",
" def __init__(self, raw_data, tokenizer, max_len: int):\n",
" super(SupervisedDataset, self).__init__()\n",
" # 加载并预处理数据\n",
" sources = [example[\"conversations\"] for example in raw_data[:10000]]\n",
" data_dict = preprocess(sources, tokenizer, max_len)\n",
"\n",
" self.input_ids = data_dict[\"input_ids\"]\n",
" self.labels = data_dict[\"labels\"]\n",
" self.attention_mask = data_dict[\"attention_mask\"]\n",
"\n",
" def __len__(self):\n",
" return len(self.input_ids)\n",
"\n",
" def __getitem__(self, i) -> Dict[str, torch.Tensor]:\n",
" return dict(\n",
" input_ids=self.input_ids[i],\n",
" labels=self.labels[i],\n",
" attention_mask=self.attention_mask[i],\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "934316d3-098f-4889-9cb0-d234a630b194",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 10000/10000 [00:08<00:00, 1235.98it/s]\n"
]
}
],
"source": [
"train_ds = SupervisedDataset(lst, tokenizer=tokenizer, max_len=2048)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}