# happy-llm/docs/chapter5/code/dataset.py
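"""Datasets for pretraining and supervised fine-tuning (SFT).

PretrainDataset reads a JSONL corpus of raw text for next-token prediction.
SFTDataset reads chat-format JSONL data and builds a loss mask so that only
assistant replies (up to and including the eos token) are trained on.
"""
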
import json
import random
import re
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
import os


class PretrainDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = 0
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = f.readlines()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index: int):
        sample = json.loads(self.data[index])
        text = f"{self.tokenizer.bos_token}{sample['text']}"
        input_id = self.tokenizer(text).data['input_ids'][:self.max_length]
        text_len = len(input_id)
        # Pad out the remainder when the sequence is shorter than max_length
        padding_len = self.max_length - text_len
        input_id = input_id + [self.padding] * padding_len
        # 0 means the position is excluded from the loss
        loss_mask = [1] * text_len + [0] * padding_len

        input_id = np.array(input_id)
        X = np.array(input_id[:-1]).astype(np.int64)
        Y = np.array(input_id[1:]).astype(np.int64)
        loss_mask = np.array(loss_mask[1:]).astype(np.int64)
        return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)
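
# Usage sketch for PretrainDataset (the tokenizer and paths below are
# hypothetical; assumes a JSONL corpus where each line is {"text": "..."}):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("path/to/tokenizer")
#   dataset = PretrainDataset("pretrain.jsonl", tokenizer, max_length=512)
#   loader = DataLoader(dataset, batch_size=8, shuffle=True)
#   X, Y, loss_mask = next(iter(loader))  # each tensor has shape (8, 511)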


class SFTDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = 0
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = f.readlines()

    def __len__(self):
        return len(self.data)

    def generate_loss_mask(self, input_ids):
        # Build the loss mask: 0 = exclude from the loss, 1 = include in the loss
        mask = [0] * len(input_ids)
        # Token ids that mark the start of an assistant turn: "<|im_start|>assistant\n"
        a_sequence = self.tokenizer("<|im_start|>assistant\n")['input_ids']
        a_length = len(a_sequence)
        n = len(input_ids)
        i = 0
        while i <= n - a_length:
            # Check whether the target subsequence starts at the current position
            match = True
            for k in range(a_length):
                if input_ids[i + k] != a_sequence[k]:
                    match = False
                    break
            if match:
                # From the end of the subsequence, find the first eos_token_id
                j = None
                for idx in range(i + a_length, n):
                    if input_ids[idx] == self.tokenizer.eos_token_id:
                        j = idx
                        break
                if j is not None:
                    start = i + a_length
                    end = j  # end at j so the eos token itself is trained on
                    # Mark the interval [start, end] (inclusive) with 1
                    if start <= end:
                        for pos in range(start, end + 1):
                            if pos < len(mask):
                                mask[pos] = 1
                # Skip past the matched subsequence to avoid overlapping matches
                i += a_length
            else:
                i += 1
        return mask
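
    # Example: for input_ids tokenizing "...<|im_start|>assistant\n t1 t2 <eos>...",
    # the returned mask is 1 over t1, t2 and the eos position and 0 elsewhere,
    # so only assistant replies (plus their eos) contribute to the loss.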

    def __getitem__(self, index: int):
        sample = json.loads(self.data[index])
        text = self.tokenizer.apply_chat_template(sample, tokenize=False, add_generation_prompt=False)
        input_id = self.tokenizer(text).data['input_ids'][:self.max_length]
        text_len = len(input_id)
        # Pad out the remainder when the sequence is shorter than max_length
        padding_len = self.max_length - text_len
        input_id = input_id + [self.padding] * padding_len
        # Mask so that only assistant replies contribute to the loss
        loss_mask = self.generate_loss_mask(input_id)

        input_id = np.array(input_id)
        X = np.array(input_id[:-1]).astype(np.int64)
        Y = np.array(input_id[1:]).astype(np.int64)
        loss_mask = np.array(loss_mask[1:]).astype(np.int64)
        return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)
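

if __name__ == "__main__":
    # Minimal smoke test; a sketch only. Assumptions: a Hugging Face tokenizer
    # with a chat template (AutoTokenizer from transformers), and a hypothetical
    # file sft_data.jsonl where each line is a JSON list of chat messages, e.g.
    # [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}].
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("path/to/tokenizer")
    dataset = SFTDataset("sft_data.jsonl", tokenizer, max_length=512)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    X, Y, loss_mask = next(iter(loader))
    # X and Y are shifted by one token; loss_mask is 1 only on assistant replies
    print(X.shape, Y.shape, loss_mask.shape)  # torch.Size([4, 511]) each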