refactor: Optimize LLaMA2Model's configure_optimizers method
@@ -420,7 +420,7 @@
 ],
 "source": [
 "# Create an MLP instance\n",
-"mlp = LLaMAMLP(args.dim, args.hidden_dim, args.multiple_of, args.dropout)\n",
+"mlp = LLaMA2MLP(args.dim, args.hidden_dim, args.multiple_of, args.dropout)\n",
 "# Generate random input data\n",
 "x = torch.randn(1, 50, 288)\n",
 "# Run the MLP model\n",
@@ -528,7 +528,7 @@
 " # Decoder layers\n",
 " self.layers = torch.nn.ModuleList()\n",
 " for layer_id in range(args.n_layers):\n",
-" self.layers.append(LLaMADecoderLayer(layer_id, args))\n",
+" self.layers.append(LLaMA2DecoderLayer(layer_id, args))\n",
 " # Normalization layer\n",
 " self.norm = LLaMA2RMSNorm(args.dim, eps=args.norm_eps)\n",
 " # Output layer\n",
@@ -586,7 +586,81 @@
 " logits = self.output(h[:, [-1], :]) \n",
 " self.last_loss = None\n",
 "\n",
-" return logits"
+" return logits\n",
+" \n",
+" def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):\n",
+" # Collect all parameters that require gradients\n",
+" param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}\n",
+" \n",
+" # Split parameters into two groups: with and without weight decay\n",
+" decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]\n",
+" nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]\n",
+" optim_groups = [\n",
+" {'params': decay_params, 'weight_decay': weight_decay},\n",
+" {'params': nodecay_params, 'weight_decay': 0.0}\n",
+" ]\n",
+" \n",
+" # Print parameter count information\n",
+" num_decay_params = sum(p.numel() for p in decay_params)\n",
+" num_nodecay_params = sum(p.numel() for p in nodecay_params)\n",
+" print(f\"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters\")\n",
+" print(f\"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters\")\n",
+" \n",
+" # Use the fused AdamW implementation when available and on CUDA, otherwise the standard one\n",
+" fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters\n",
+" use_fused = fused_available and device_type == 'cuda'\n",
+" extra_args = dict(fused=True) if use_fused else dict()\n",
+" optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)\n",
+" print(f\"using fused AdamW: {use_fused}\")\n",
+"\n",
+" return optimizer\n",
+" \n",
+" def estimate_mfu(self, fwdbwd_per_iter, dt):\n",
+" \"\"\" Estimate model FLOPs utilization (MFU) as a fraction of A100 bfloat16 peak FLOPS \"\"\"\n",
+" # Count the FLOPs per iteration (see Appendix B of the PaLM paper)\n",
+" N = sum(p.numel() for p in self.parameters())\n",
+" cfg = self.args\n",
+" L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.dim//cfg.n_heads, cfg.max_seq_len\n",
+" flops_per_token = 6*N + 12*L*H*Q*T\n",
+" flops_per_fwdbwd = flops_per_token * T\n",
+" flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter\n",
+" \n",
+" # Express the achieved FLOPs throughput as a fraction of A100 bfloat16 peak FLOPS\n",
+" flops_achieved = flops_per_iter * (1.0/dt) # FLOPs per second\n",
+" flops_promised = 312e12 # A100 GPU bfloat16 peak is 312 TFLOPS\n",
+" mfu = flops_achieved / flops_promised\n",
+" return mfu\n",
+" \n",
+" @torch.inference_mode()\n",
+" def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):\n",
+" \"\"\"\n",
+" Given an input sequence idx (a LongTensor of shape (bz, seq_len)), complete the sequence by repeatedly generating new tokens.\n",
+" Runs in model.eval() mode. This is an inefficient sampling version that does not use a k/v cache.\n",
+" \"\"\"\n",
+" for _ in range(max_new_tokens):\n",
+" # If the sequence context grows too long, crop it to the maximum length\n",
+" idx_cond = idx if idx.size(1) <= self.args.max_seq_len else idx[:, -self.args.max_seq_len:]\n",
+" \n",
+" # Forward pass to get the logits at the final position in the sequence\n",
+" logits = self(idx_cond)\n",
+" logits = logits[:, -1, :] # keep only the output at the last time step\n",
+" \n",
+" if temperature == 0.0:\n",
+" # Pick the single most likely index\n",
+" _, idx_next = torch.topk(logits, k=1, dim=-1)\n",
+" else:\n",
+" # Scale the logits and apply softmax\n",
+" logits = logits / temperature\n",
+" if top_k is not None:\n",
+" v, _ = torch.topk(logits, min(top_k, logits.size(-1)))\n",
+" logits[logits < v[:, [-1]]] = -float('Inf')\n",
+" probs = F.softmax(logits, dim=-1)\n",
+" idx_next = torch.multinomial(probs, num_samples=1)\n",
+" \n",
+" # Append the sampled index to the sequence and continue\n",
+" idx = torch.cat((idx, idx_next), dim=1)\n",
+"\n",
+" return idx"
 ]
 },
 {
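The configure_optimizers added above follows a common recipe: tensors with two or more dimensions (matmul weights, embeddings) receive weight decay, 1-D tensors (biases, norm gains) do not, and the fused AdamW kernel is requested only when the installed PyTorch exposes it and the device is CUDA. A standalone sketch of that grouping; ToyModel and the hyperparameter values are hypothetical:

import inspect
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)  # 2-D weight decays, 1-D bias does not
        self.norm = nn.LayerNorm(16)     # norm gain/bias are 1-D: no decay

device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ToyModel().to(device_type)
param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
decay_params = [p for p in param_dict.values() if p.dim() >= 2]
nodecay_params = [p for p in param_dict.values() if p.dim() < 2]
optim_groups = [
    {'params': decay_params, 'weight_decay': 0.1},
    {'params': nodecay_params, 'weight_decay': 0.0},
]
# The fused kernel only exists in newer PyTorch builds and requires CUDA tensors
use_fused = 'fused' in inspect.signature(torch.optim.AdamW).parameters and device_type == 'cuda'
optimizer = torch.optim.AdamW(optim_groups, lr=3e-4, betas=(0.9, 0.95),
                              **({'fused': True} if use_fused else {}))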
@@ -352,6 +352,80 @@ class LLaMA2Model(nn.Module):
         self.last_loss = None
 
         return logits
+
+    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
+        # Collect all parameters that require gradients
+        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
+
+        # Split parameters into two groups: with and without weight decay
+        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
+        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
+        optim_groups = [
+            {'params': decay_params, 'weight_decay': weight_decay},
+            {'params': nodecay_params, 'weight_decay': 0.0}
+        ]
+
+        # Print parameter count information
+        num_decay_params = sum(p.numel() for p in decay_params)
+        num_nodecay_params = sum(p.numel() for p in nodecay_params)
+        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
+        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
+
+        # Use the fused AdamW implementation when available and on CUDA, otherwise the standard one
+        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
+        use_fused = fused_available and device_type == 'cuda'
+        extra_args = dict(fused=True) if use_fused else dict()
+        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
+        print(f"using fused AdamW: {use_fused}")
+
+        return optimizer
+
+    def estimate_mfu(self, fwdbwd_per_iter, dt):
+        """ Estimate model FLOPs utilization (MFU) as a fraction of A100 bfloat16 peak FLOPS """
+        # Count the FLOPs per iteration (see Appendix B of the PaLM paper)
+        N = sum(p.numel() for p in self.parameters())
+        cfg = self.args
+        L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.dim//cfg.n_heads, cfg.max_seq_len
+        flops_per_token = 6*N + 12*L*H*Q*T
+        flops_per_fwdbwd = flops_per_token * T
+        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
+
+        # Express the achieved FLOPs throughput as a fraction of A100 bfloat16 peak FLOPS
+        flops_achieved = flops_per_iter * (1.0/dt)  # FLOPs per second
+        flops_promised = 312e12  # A100 GPU bfloat16 peak is 312 TFLOPS
+        mfu = flops_achieved / flops_promised
+        return mfu
+
+    @torch.inference_mode()
+    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+        """
+        Given an input sequence idx (a LongTensor of shape (bz, seq_len)), complete the sequence by repeatedly generating new tokens.
+        Runs in model.eval() mode. This is an inefficient sampling version that does not use a k/v cache.
+        """
+        for _ in range(max_new_tokens):
+            # If the sequence context grows too long, crop it to the maximum length
+            idx_cond = idx if idx.size(1) <= self.args.max_seq_len else idx[:, -self.args.max_seq_len:]
+
+            # Forward pass to get the logits at the final position in the sequence
+            logits = self(idx_cond)
+            logits = logits[:, -1, :]  # keep only the output at the last time step
+
+            if temperature == 0.0:
+                # Pick the single most likely index
+                _, idx_next = torch.topk(logits, k=1, dim=-1)
+            else:
+                # Scale the logits and apply softmax
+                logits = logits / temperature
+                if top_k is not None:
+                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                    logits[logits < v[:, [-1]]] = -float('Inf')
+                probs = F.softmax(logits, dim=-1)
+                idx_next = torch.multinomial(probs, num_samples=1)
+
+            # Append the sampled index to the sequence and continue
+            idx = torch.cat((idx, idx_next), dim=1)
+
+        return idx
 
 if __name__ == '__main__':
     args = ModelArgs()
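The sampling step inside generate() can be exercised in isolation. A self-contained demo of the temperature scaling and top-k masking used above; the vocabulary size, seed, and hyperparameters are arbitrary:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 32000)  # stand-in for the model's (batch, vocab_size) output
temperature, top_k = 0.8, 10

logits = logits / temperature                       # <1.0 sharpens, >1.0 flattens the distribution
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')         # mask everything below the k-th best logit
probs = F.softmax(logits, dim=-1)                   # only the top-k entries stay non-zero
idx_next = torch.multinomial(probs, num_samples=1)  # sample one token id
print(idx_next.shape)  # torch.Size([1, 1])

Similarly, the MFU formula in estimate_mfu can be sanity-checked by hand; every number below is made up purely for illustration:

N = 15_000_000                   # total parameters (hypothetical)
L, H, Q, T = 6, 6, 48, 256       # layers, heads, head dim, max seq len (hypothetical)
flops_per_token = 6 * N + 12 * L * H * Q * T  # 6N + 12*L*H*Q*T per PaLM Appendix B
flops_per_iter = flops_per_token * T          # one fwd/bwd pass over T tokens
dt = 0.05                                     # seconds per iteration (hypothetical)
print(f"MFU: {flops_per_iter / dt / 312e12:.2%}")  # fraction of A100 bf16 peak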