update ch05
This commit is contained in:
@@ -13,7 +13,7 @@ from contextlib import nullcontext
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from k_model import ModelConfig, Transformer
|
||||
from dataset import PretrainDataset, SkyWorkPretrainDataset
|
||||
from dataset import PretrainDataset
|
||||
|
||||
import swanlab
|
||||
|
||||
@@ -131,7 +131,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type")
|
||||
parser.add_argument("--use_swanlab", type=bool, default=True, help="Use Weights & Biases")
|
||||
parser.add_argument("--num_workers", type=int, default=8, help="Number of workers for data loading")
|
||||
parser.add_argument("--data_path", type=str, default="/home/user/szx/dataset/seq-monkey/seq_monkey_datawhale.jsonl", help="Path to training data")
|
||||
parser.add_argument("--data_path", type=str, default="", help="Path to training data")
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8, help="Gradient accumulation steps")
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping threshold")
|
||||
parser.add_argument("--warmup_iters", type=int, default=0, help="Number of warmup iterations")
|
||||
@@ -152,7 +152,7 @@ if __name__ == "__main__":
|
||||
args.device = "cpu"
|
||||
|
||||
if args.use_swanlab:
|
||||
swanlab.login(api_key='BIYVGq2rfWmD9sFMCehUG')
|
||||
swanlab.login(api_key='your key')
|
||||
run = swanlab.init(
|
||||
project="Tiny-LLM",
|
||||
experiment_name="Pretrain-215M",
|
||||
@@ -174,7 +174,7 @@ if __name__ == "__main__":
|
||||
|
||||
model, tokenizer = init_model()
|
||||
|
||||
train_ds = SkyWorkPretrainDataset(args.data_path, tokenizer, max_length=max_seq_len)
|
||||
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=max_seq_len)
|
||||
train_loader = DataLoader(
|
||||
train_ds,
|
||||
batch_size=args.batch_size,
|
||||
|
||||
Reference in New Issue
Block a user