import os
import platform
import argparse
import time
import warnings
import math

import pandas as pd
import torch
from torch import optim
from torch.utils.data import DataLoader
from contextlib import nullcontext
from transformers import AutoTokenizer

from k_model import ModelConfig, Transformer
from dataset import PretrainDataset, SkyWorkPretrainDataset
import swanlab

warnings.filterwarnings('ignore')


def Logger(content):
    print(content)


def get_lr(it, total_iters):
    # Cosine learning-rate schedule with linear warmup; decays to learning_rate / 10.
    warmup_iters = args.warmup_iters
    lr_decay_iters = total_iters
    min_lr = args.learning_rate / 10

    # Linear warmup
    if it < warmup_iters:
        return args.learning_rate * it / warmup_iters
    # After decay finishes, hold at the minimum learning rate
    if it > lr_decay_iters:
        return min_lr
    # Cosine decay from learning_rate down to min_lr
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (args.learning_rate - min_lr)


def train_epoch(epoch):
    start_time = time.time()
    for step, (X, Y, loss_mask) in enumerate(train_loader):
        X = X.to(args.device)
        Y = Y.to(args.device)
        loss_mask = loss_mask.to(args.device)

        # Set the learning rate for this global step
        lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        with ctx:
            out = model(X, Y)
            loss = out.last_loss / args.accumulation_steps
            # Mask out padding positions before averaging the loss
            loss_mask = loss_mask.view(-1)
            loss = torch.sum(loss * loss_mask) / loss_mask.sum()

        scaler.scale(loss).backward()

        # Gradient accumulation: only step the optimizer every accumulation_steps mini-batches
        if (step + 1) % args.accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

            scaler.step(optimizer)
            scaler.update()

            optimizer.zero_grad(set_to_none=True)

        if step % args.log_interval == 0:
            spend_time = time.time() - start_time
            Logger(
                'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.7f} epoch_Time:{}min'.format(
                    epoch + 1,
                    args.epochs,
                    step,
                    iter_per_epoch,
                    loss.item() * args.accumulation_steps,
                    optimizer.param_groups[-1]['lr'],
                    spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))

            if args.use_swanlab:
                swanlab.log({
                    "loss": loss.item() * args.accumulation_steps,
                    "lr": optimizer.param_groups[-1]['lr']
                })

        if (step + 1) % args.save_interval == 0:
            model.eval()
            ckp = f'{args.save_dir}/pretrain_{lm_config.dim}_{lm_config.n_layers}_{lm_config.vocab_size}.pth'
            # Handle saving under multi-GPU (DataParallel) training
            state_dict = model.module.state_dict() if isinstance(model, torch.nn.DataParallel) else model.state_dict()
            torch.save(state_dict, ckp)
            model.train()

        # Additionally keep a step-tagged checkpoint every 20000 steps
        if (step + 1) % 20000 == 0:
            model.eval()
            ckp = f'{args.save_dir}/pretrain_{lm_config.dim}_{lm_config.n_layers}_{lm_config.vocab_size}_step{step + 1}.pth'
            state_dict = model.module.state_dict() if isinstance(model, torch.nn.DataParallel) else model.state_dict()
            torch.save(state_dict, ckp)
            model.train()


def init_model():
    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    tokenizer = AutoTokenizer.from_pretrained('./tokenizer_k/')

    model = Transformer(lm_config)
    # Multi-GPU initialization
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        Logger(f"Using {num_gpus} GPUs with DataParallel!")
        model = torch.nn.DataParallel(model)
    model = model.to(args.device)

    Logger(f'Total LLM parameters: {count_parameters(model) / 1e6:.3f} million')
    return model, tokenizer


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Tiny-LLM Pretraining")
    parser.add_argument("--out_dir", type=str, default="base_monkey_215M", help="Output directory")
    parser.add_argument("--epochs", type=int, default=1, help="Number of epochs")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    parser.add_argument("--learning_rate", type=float, default=2e-4, help="Learning rate")
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type")
    parser.add_argument("--use_swanlab", type=bool, default=True, help="Use SwanLab for experiment tracking")
    parser.add_argument("--num_workers", type=int, default=8, help="Number of workers for data loading")
    parser.add_argument("--data_path", type=str, default="/home/user/szx/dataset/seq-monkey/seq_monkey_datawhale.jsonl", help="Path to training data")
    parser.add_argument("--accumulation_steps", type=int, default=8, help="Gradient accumulation steps")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping threshold")
    parser.add_argument("--warmup_iters", type=int, default=0, help="Number of warmup iterations")
    parser.add_argument("--log_interval", type=int, default=100, help="Logging interval")
    parser.add_argument("--save_interval", type=int, default=1000, help="Model saving interval")
    # Multi-GPU argument
    parser.add_argument("--gpus", type=str, default='0,1', help="Comma-separated GPU IDs (e.g. '0,1,2')")

    args = parser.parse_args()

    # Set the visible GPUs
    if args.gpus is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        # Automatically set the primary device to the first visible GPU
        if torch.cuda.is_available():
            args.device = "cuda:0"
        else:
            args.device = "cpu"

    if args.use_swanlab:
        swanlab.login(api_key='BIYVGq2rfWmD9sFMCehUG')
        run = swanlab.init(
            project="Tiny-LLM",
            experiment_name="Pretrain-215M",
            config=args,
        )

    lm_config = ModelConfig(
        dim=1024,
        n_layers=18,
    )
    max_seq_len = lm_config.max_seq_len
    args.save_dir = os.path.join(args.out_dir)
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.out_dir, exist_ok=True)
    torch.manual_seed(42)
    device_type = "cuda" if "cuda" in args.device else "cpu"

    # Mixed-precision autocast context (no-op on CPU)
    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()

    model, tokenizer = init_model()

    train_ds = SkyWorkPretrainDataset(args.data_path, tokenizer, max_length=max_seq_len)
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        pin_memory=True,
        drop_last=False,
        shuffle=True,
        num_workers=args.num_workers
    )

    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    iter_per_epoch = len(train_loader)
    for epoch in range(args.epochs):
        train_epoch(epoch)