update ch05

This commit is contained in:
KMnO4-zx
2025-02-26 11:24:19 +08:00
parent 2edfb76f7a
commit ca3e727e1c
21 changed files with 13737 additions and 1044 deletions

@@ -0,0 +1,138 @@
import json
import random
import re
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
from sklearn.model_selection import train_test_split
import os


class PretrainDataset(Dataset):
    """Next-token-prediction dataset built from a DataFrame with a 'text' column."""

    def __init__(self, df, tokenizer, max_length=512):
        super().__init__()
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = 0  # pad token id (assumed to be 0)

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index: int):
        sample = self.df.iloc[index]
        # Wrap the raw text with BOS/EOS so the model learns document boundaries.
        text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
        input_id = self.tokenizer(text).data['input_ids'][:self.max_length]
        text_len = len(input_id)
        # Pad whatever is left up to max_length.
        padding_len = self.max_length - text_len
        input_id = input_id + [self.padding] * padding_len
        # 0 means the position is excluded from the loss.
        loss_mask = [1] * text_len + [0] * padding_len
        input_id = np.array(input_id)
        # Shift by one token: X at each position predicts Y at the same position.
        X = np.array(input_id[:-1]).astype(np.int64)
        Y = np.array(input_id[1:]).astype(np.int64)
        loss_mask = np.array(loss_mask[1:]).astype(np.int64)
        return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)
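
# --- Usage sketch (illustration, not part of this commit): assuming a
# HuggingFace-style tokenizer with bos_token/eos_token set and a DataFrame
# holding a 'text' column, the dataset drops straight into a DataLoader:
#
#   ds = PretrainDataset(df, tokenizer, max_length=512)
#   X, Y, loss_mask = ds[0]          # Y is X shifted left by one token
#   loader = DataLoader(ds, batch_size=8, shuffle=True)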


class SkyWorkPretrainDataset(Dataset):
    """Pretraining dataset over a JSONL corpus (one JSON object per line)."""

    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = 0  # pad token id (assumed to be 0)
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = f.readlines()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index: int):
        sample = json.loads(self.data[index])
        # Only BOS is prepended here; corpus lines are treated as open-ended text.
        text = f"{self.tokenizer.bos_token}{sample['text']}"
        input_id = self.tokenizer(text).data['input_ids'][:self.max_length]
        text_len = len(input_id)
        # Pad whatever is left up to max_length.
        padding_len = self.max_length - text_len
        input_id = input_id + [self.padding] * padding_len
        # 0 means the position is excluded from the loss.
        loss_mask = [1] * text_len + [0] * padding_len
        input_id = np.array(input_id)
        X = np.array(input_id[:-1]).astype(np.int64)
        Y = np.array(input_id[1:]).astype(np.int64)
        loss_mask = np.array(loss_mask[1:]).astype(np.int64)
        return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)
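
# --- Data-format sketch (an assumption inferred from json.loads above): each
# line of data_path is a standalone JSON object carrying a 'text' field, e.g.
#
#   {"text": "a pretraining document ..."}
#   {"text": "another pretraining document ..."}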


class SFTDataset(Dataset):
    """Supervised fine-tuning dataset: loss is computed only on assistant replies."""

    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = 0  # pad token id (assumed to be 0)
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = f.readlines()

    def __len__(self):
        return len(self.data)

    def generate_loss_mask(self, input_ids):
        # Build the loss mask: 0 = skip the position, 1 = compute loss on it.
        mask = [0] * len(input_ids)
        a_sequence = [3, 1074, 537, 500, 203]  # token ids of "<|im_start|>assistant\n"
        a_length = len(a_sequence)
        n = len(input_ids)
        i = 0
        while i <= n - a_length:
            # Check whether the target subsequence starts at position i.
            match = True
            for k in range(a_length):
                if input_ids[i + k] != a_sequence[k]:
                    match = False
                    break
            if match:
                # Scan forward from the end of the marker for the first token 4,
                # which terminates the assistant reply.
                j = None
                for idx in range(i + a_length, n):
                    if input_ids[idx] == 4:
                        j = idx
                        break
                if j is not None:
                    start = i + a_length
                    end = j  # the span deliberately includes the closing token 4
                    # Mark the whole reply span, start through end inclusive, as 1.
                    if start <= end:
                        for pos in range(start, end + 1):
                            if pos < len(mask):
                                mask[pos] = 1
                # Skip past the matched marker to avoid overlapping matches.
                i += a_length
            else:
                i += 1
        return mask
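
    # Toy walk-through of generate_loss_mask (token ids invented for illustration):
    # for input_ids = [.., 3, 1074, 537, 500, 203, 21, 22, 4, ..] the marker
    # "<|im_start|>assistant\n" matches, the forward scan finds the closing 4,
    # and the reply span [21, 22, 4] gets mask 1 while everything else stays 0.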
    def __getitem__(self, index: int):
        sample = json.loads(self.data[index])
        # Render the conversation with the tokenizer's chat template.
        text = self.tokenizer.apply_chat_template(sample, tokenize=False, add_generation_prompt=False)
        input_id = self.tokenizer(text).data['input_ids'][:self.max_length]
        text_len = len(input_id)
        # Pad whatever is left up to max_length.
        padding_len = self.max_length - text_len
        input_id = input_id + [self.padding] * padding_len
        # Mask out everything except assistant replies (0 = no loss).
        loss_mask = self.generate_loss_mask(input_id)
        input_id = np.array(input_id)
        X = np.array(input_id[:-1]).astype(np.int64)
        Y = np.array(input_id[1:]).astype(np.int64)
        loss_mask = np.array(loss_mask[1:]).astype(np.int64)
        return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)
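

# --- Training-loop sketch (illustration, not from this commit): the returned
# loss_mask is typically folded into a per-token cross-entropy so that padding
# and non-assistant tokens contribute nothing to the gradient:
#
#   import torch.nn.functional as F
#   logits = model(X)                                    # (B, T, vocab_size)
#   loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
#                          Y.view(-1), reduction='none').view(Y.size())
#   loss = (loss * loss_mask).sum() / loss_mask.sum()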