update ch05
docs/chapter5/code/dataset.py
@@ -0,0 +1,138 @@
import json
import random
import re

import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
from sklearn.model_selection import train_test_split
import os


class PretrainDataset(Dataset):
    def __init__(self, df, tokenizer, max_length=512):
        super().__init__()
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = 0

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index: int):
        sample = self.df.iloc[index]
        text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
        input_id = self.tokenizer(text).data['input_ids'][:self.max_length]
        text_len = len(input_id)
        # Pad whatever remains up to the maximum length
        padding_len = self.max_length - text_len
        input_id = input_id + [self.padding] * padding_len
        # 0 marks positions excluded from the loss
        loss_mask = [1] * text_len + [0] * padding_len

        input_id = np.array(input_id)
        X = np.array(input_id[:-1]).astype(np.int64)
        Y = np.array(input_id[1:]).astype(np.int64)
        loss_mask = np.array(loss_mask[1:]).astype(np.int64)
        return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)
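
# A minimal usage sketch (assumptions, not part of the original file: `df` is a
# pandas DataFrame with a "text" column and `tokenizer` is a Hugging Face
# tokenizer exposing bos_token/eos_token):
#   dataset = PretrainDataset(df, tokenizer, max_length=512)
#   loader = DataLoader(dataset, batch_size=8, shuffle=True)
#   X, Y, loss_mask = next(iter(loader))  # X and Y are shifted by one token for next-token prediction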


class SkyWorkPretrainDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = 0
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = f.readlines()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index: int):
        sample = json.loads(self.data[index])
        text = f"{self.tokenizer.bos_token}{sample['text']}"
        input_id = self.tokenizer(text).data['input_ids'][:self.max_length]
        text_len = len(input_id)
        # Pad whatever remains up to the maximum length
        padding_len = self.max_length - text_len
        input_id = input_id + [self.padding] * padding_len
        # 0 marks positions excluded from the loss
        loss_mask = [1] * text_len + [0] * padding_len

        input_id = np.array(input_id)
        X = np.array(input_id[:-1]).astype(np.int64)
        Y = np.array(input_id[1:]).astype(np.int64)
        loss_mask = np.array(loss_mask[1:]).astype(np.int64)
        return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)
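
# The dataset above reads a JSONL corpus line by line; a plausible layout
# (an assumption, inferred from how sample['text'] is used) is one JSON object per line:
#   {"text": "one pretraining document ..."}
# Usage then mirrors PretrainDataset, e.g.:
#   dataset = SkyWorkPretrainDataset("corpus.jsonl", tokenizer, max_length=512)  # hypothetical path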


class SFTDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = 0
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = f.readlines()

    def __len__(self):
        return len(self.data)

    def generate_loss_mask(self, input_ids):
        # Build the loss mask: 0 means the position is ignored, 1 means it contributes to the loss
        mask = [0] * len(input_ids)
        a_sequence = [3, 1074, 537, 500, 203]  # token ids of "<|im_start|>assistant\n"
        a_length = len(a_sequence)
        n = len(input_ids)
        i = 0

        while i <= n - a_length:
            # Check whether the target subsequence starts at the current position
            match = True
            for k in range(a_length):
                if input_ids[i + k] != a_sequence[k]:
                    match = False
                    break
            if match:
                # Starting right after the subsequence, look for the first token id 4
                j = None
                for idx in range(i + a_length, n):
                    if input_ids[idx] == 4:
                        j = idx
                        break
                if j is not None:
                    start = i + a_length
                    end = j  # the end position is j (token 4 included)
                    # Mark the interval from start to end (inclusive) with 1
                    if start <= end:
                        for pos in range(start, end + 1):
                            if pos < len(mask):
                                mask[pos] = 1
                # Skip past the matched subsequence to avoid overlapping matches
                i += a_length
            else:
                i += 1
        return mask
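
    # Worked example (illustrative token ids only; treating 4 as the end-of-turn
    # token id is an assumption, not stated in the file):
    #   input_ids = [3, 1074, 537, 500, 203, 9, 8, 4, 7]
    #   generate_loss_mask(input_ids) -> [0, 0, 0, 0, 0, 1, 1, 1, 0]
    # i.e. only the assistant reply (and its closing token 4) contributes to the loss.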

    def __getitem__(self, index: int):
        sample = json.loads(self.data[index])
        text = self.tokenizer.apply_chat_template(sample, tokenize=False, add_generation_prompt=False)
        input_id = self.tokenizer(text).data['input_ids'][:self.max_length]
        text_len = len(input_id)
        # Pad whatever remains up to the maximum length
        padding_len = self.max_length - text_len
        input_id = input_id + [self.padding] * padding_len
        # 0 marks positions excluded from the loss
        loss_mask = self.generate_loss_mask(input_id)

        input_id = np.array(input_id)
        X = np.array(input_id[:-1]).astype(np.int64)
        Y = np.array(input_id[1:]).astype(np.int64)
        loss_mask = np.array(loss_mask[1:]).astype(np.int64)
        return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)
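

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original file. The tokenizer path
    # and data file below are placeholders/assumptions; swap in the chapter's own
    # tokenizer and SFT data to run it.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("./tokenizer")  # hypothetical local tokenizer
    dataset = SFTDataset("./sft_data.jsonl", tokenizer, max_length=512)  # hypothetical data file
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    X, Y, loss_mask = next(iter(loader))
    print(X.shape, Y.shape, loss_mask.shape)  # each is (batch_size, max_length - 1)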