update ch05

This commit is contained in:
KMnO4-zx
2025-02-26 20:31:51 +08:00
parent ca3e727e1c
commit 3512f55993
9 changed files with 699 additions and 405 deletions

View File

@@ -8,39 +8,8 @@ from torch.utils.data import Dataset, DataLoader
import torch
from sklearn.model_selection import train_test_split
import os
class PretrainDataset(Dataset):
def __init__(self, df, tokenizer, max_length=512):
super().__init__()
self.df = df
self.tokenizer = tokenizer
self.max_length = max_length
self.padding = 0
def __len__(self):
return self.df.shape[0]
def __getitem__(self, index: int):
#
sample = self.df.iloc[index]
text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
input_id = self.tokenizer(text).data['input_ids'][:self.max_length]
text_len = len(input_id)
# 没满最大长度的剩余部分
padding_len = self.max_length - text_len
input_id = input_id + [self.padding] * padding_len
# 0表示不计算损失
loss_mask = [1] * text_len + [0] * padding_len
input_id = np.array(input_id)
X = np.array(input_id[:-1]).astype(np.int64)
Y = np.array(input_id[1:]).astype(np.int64)
loss_mask = np.array(loss_mask[1:]).astype(np.int64)
return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)
class SkyWorkPretrainDataset(Dataset):
class PretrainDataset(Dataset):
def __init__(self, data_path, tokenizer, max_length=512):
super().__init__()
self.data_path = data_path