update ch05

This commit is contained in:
KMnO4-zx
2025-02-26 20:31:51 +08:00
parent ca3e727e1c
commit 3512f55993
9 changed files with 699 additions and 405 deletions

View File

@@ -13,7 +13,7 @@ from contextlib import nullcontext
from transformers import AutoTokenizer
from k_model import ModelConfig, Transformer
from dataset import PretrainDataset, SkyWorkPretrainDataset
from dataset import PretrainDataset
import swanlab
@@ -131,7 +131,7 @@ if __name__ == "__main__":
parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type")
parser.add_argument("--use_swanlab", type=bool, default=True, help="Use Weights & Biases")
parser.add_argument("--num_workers", type=int, default=8, help="Number of workers for data loading")
parser.add_argument("--data_path", type=str, default="/home/user/szx/dataset/seq-monkey/seq_monkey_datawhale.jsonl", help="Path to training data")
parser.add_argument("--data_path", type=str, default="", help="Path to training data")
parser.add_argument("--accumulation_steps", type=int, default=8, help="Gradient accumulation steps")
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping threshold")
parser.add_argument("--warmup_iters", type=int, default=0, help="Number of warmup iterations")
@@ -152,7 +152,7 @@ if __name__ == "__main__":
args.device = "cpu"
if args.use_swanlab:
swanlab.login(api_key='BIYVGq2rfWmD9sFMCehUG')
swanlab.login(api_key='your key')
run = swanlab.init(
project="Tiny-LLM",
experiment_name="Pretrain-215M",
@@ -174,7 +174,7 @@ if __name__ == "__main__":
model, tokenizer = init_model()
train_ds = SkyWorkPretrainDataset(args.data_path, tokenizer, max_length=max_seq_len)
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=max_seq_len)
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,