import json

import numpy as np
import torch
from torch.utils.data import Dataset


class PretrainDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = 0  # pad token id; assumes the tokenizer's pad id is 0
        # each line of the file is one JSON object with a "text" field
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = f.readlines()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index: int):
        sample = json.loads(self.data[index])
        text = f"{self.tokenizer.bos_token}{sample['text']}"
        input_id = self.tokenizer(text)['input_ids'][:self.max_length]
        text_len = len(input_id)
        # pad whatever is left up to max_length
        padding_len = self.max_length - text_len
        input_id = input_id + [self.padding] * padding_len
        # 0 marks positions excluded from the loss (here: the padding)
        loss_mask = [1] * text_len + [0] * padding_len

        # next-token prediction: X is the sequence without its last token,
        # Y the sequence shifted left by one
        input_id = np.array(input_id)
        X = input_id[:-1].astype(np.int64)
        Y = input_id[1:].astype(np.int64)
        loss_mask = np.array(loss_mask[1:]).astype(np.int64)
        return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)

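# A minimal usage sketch (illustrative only; the corpus path and tokenizer
# directory below are hypothetical). It assumes a pretraining corpus with one
# {"text": ...} JSON object per line and a tokenizer whose pad token id is 0:
#
#     from transformers import AutoTokenizer
#     from torch.utils.data import DataLoader
#
#     tokenizer = AutoTokenizer.from_pretrained('./tokenizer')
#     dataset = PretrainDataset('pretrain_data.jsonl', tokenizer, max_length=512)
#     loader = DataLoader(dataset, batch_size=8, shuffle=True)
#     X, Y, loss_mask = next(iter(loader))  # each of shape (8, 511)

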
class SFTDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = 0  # pad token id; assumes the tokenizer's pad id is 0
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = f.readlines()

    def __len__(self):
        return len(self.data)

    def generate_loss_mask(self, input_ids):
        # Build the loss mask: 0 means the position is excluded from the loss,
        # 1 means it is included. Only assistant-response tokens are trained on.
        mask = [0] * len(input_ids)
        a_sequence = [3, 1074, 537, 500, 203]  # token ids of "<|im_start|>assistant\n"
        a_length = len(a_sequence)
        n = len(input_ids)
        i = 0

        while i <= n - a_length:
            # check whether the target subsequence starts at the current position
            if input_ids[i:i + a_length] == a_sequence:
                # find the first token with id 4 after the matched subsequence
                j = None
                for idx in range(i + a_length, n):
                    if input_ids[idx] == 4:
                        j = idx
                        break
                if j is not None:
                    start = i + a_length
                    end = j  # the end position is j (the token 4 is included)
                    # mark the span from start to end (inclusive) with 1
                    if start <= end:
                        for pos in range(start, end + 1):
                            if pos < len(mask):
                                mask[pos] = 1
                # skip past the matched subsequence to avoid overlapping matches
                i += a_length
            else:
                i += 1
        return mask

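    # Worked toy example (illustrative ids): for
    #     input_ids = [9, 3, 1074, 537, 500, 203, 7, 8, 4, 0]
    # the marker [3, 1074, 537, 500, 203] matches at i = 1 and the first 4 sits
    # at index 8, so positions 6..8 (the response tokens plus the closing 4)
    # are marked:
    #     mask = [0, 0, 0, 0, 0, 0, 1, 1, 1, 0]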
    def __getitem__(self, index: int):
        sample = json.loads(self.data[index])
        text = self.tokenizer.apply_chat_template(sample, tokenize=False, add_generation_prompt=False)
        input_id = self.tokenizer(text)['input_ids'][:self.max_length]
        text_len = len(input_id)
        # pad whatever is left up to max_length
        padding_len = self.max_length - text_len
        input_id = input_id + [self.padding] * padding_len
        # 0 marks positions excluded from the loss
        loss_mask = self.generate_loss_mask(input_id)

        # next-token prediction: X is the sequence without its last token,
        # Y the sequence shifted left by one; the mask is shifted with Y so the
        # loss covers exactly the response targets
        input_id = np.array(input_id)
        X = input_id[:-1].astype(np.int64)
        Y = input_id[1:].astype(np.int64)
        loss_mask = np.array(loss_mask[1:]).astype(np.int64)
        return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)
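

# A minimal usage sketch for the SFT path (illustrative only; the data path and
# tokenizer directory are hypothetical). It assumes each line of the file is a
# JSON chat sample accepted by the tokenizer's chat template, and that the
# rendered template contains the "<|im_start|>assistant\n" / end-token ids
# hardcoded in generate_loss_mask above:
#
#     from transformers import AutoTokenizer
#     from torch.utils.data import DataLoader
#
#     tokenizer = AutoTokenizer.from_pretrained('./tokenizer')
#     dataset = SFTDataset('sft_data.jsonl', tokenizer, max_length=512)
#     loader = DataLoader(dataset, batch_size=8, shuffle=True)
#     X, Y, loss_mask = next(iter(loader))  # loss_mask is 1 only on response targets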