refactor: Optimize LLaMA2Model's configure_optimizers method
This commit is contained in:
@@ -420,7 +420,7 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"# 创建MLP实例\n",
|
"# 创建MLP实例\n",
|
||||||
"mlp = LLaMAMLP(args.dim, args.hidden_dim, args.multiple_of, args.dropout)\n",
|
"mlp = LLaMA2MLP(args.dim, args.hidden_dim, args.multiple_of, args.dropout)\n",
|
||||||
"# 随机生成数据\n",
|
"# 随机生成数据\n",
|
||||||
"x = torch.randn(1, 50, 288)\n",
|
"x = torch.randn(1, 50, 288)\n",
|
||||||
"# 运行MLP模型\n",
|
"# 运行MLP模型\n",
|
||||||
@@ -528,7 +528,7 @@
|
|||||||
" # Decoder层\n",
|
" # Decoder层\n",
|
||||||
" self.layers = torch.nn.ModuleList()\n",
|
" self.layers = torch.nn.ModuleList()\n",
|
||||||
" for layer_id in range(args.n_layers):\n",
|
" for layer_id in range(args.n_layers):\n",
|
||||||
" self.layers.append(LLaMA2DecoderLayer(layer_id, args))\n",
|
" self.layers.append(LLaMADecoderLayer(layer_id, args))\n",
|
||||||
" # 归一化层\n",
|
" # 归一化层\n",
|
||||||
" self.norm = LLaMA2RMSNorm(args.dim, eps=args.norm_eps)\n",
|
" self.norm = LLaMA2RMSNorm(args.dim, eps=args.norm_eps)\n",
|
||||||
" # 输出层\n",
|
" # 输出层\n",
|
||||||
@@ -586,7 +586,81 @@
|
|||||||
" logits = self.output(h[:, [-1], :]) \n",
|
" logits = self.output(h[:, [-1], :]) \n",
|
||||||
" self.last_loss = None\n",
|
" self.last_loss = None\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return logits"
|
" return logits\n",
|
||||||
|
" \n",
|
||||||
|
" def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):\n",
|
||||||
|
" # 获取所有需要更新的参数\n",
|
||||||
|
" param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}\n",
|
||||||
|
" \n",
|
||||||
|
" # 将参数分为需要权重衰减和不需要权重衰减的两组\n",
|
||||||
|
" decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]\n",
|
||||||
|
" nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]\n",
|
||||||
|
" optim_groups = [\n",
|
||||||
|
" {'params': decay_params, 'weight_decay': weight_decay},\n",
|
||||||
|
" {'params': nodecay_params, 'weight_decay': 0.0}\n",
|
||||||
|
" ]\n",
|
||||||
|
" \n",
|
||||||
|
" # 打印参数数量信息\n",
|
||||||
|
" num_decay_params = sum(p.numel() for p in decay_params)\n",
|
||||||
|
" num_nodecay_params = sum(p.numel() for p in nodecay_params)\n",
|
||||||
|
" print(f\"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters\")\n",
|
||||||
|
" print(f\"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters\")\n",
|
||||||
|
" \n",
|
||||||
|
" # 根据设备类型选择使用标准 AdamW 或其融合版本\n",
|
||||||
|
" fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters\n",
|
||||||
|
" use_fused = fused_available and device_type == 'cuda'\n",
|
||||||
|
" extra_args = dict(fused=True) if use_fused else dict()\n",
|
||||||
|
" optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)\n",
|
||||||
|
" print(f\"using fused AdamW: {use_fused}\")\n",
|
||||||
|
"\n",
|
||||||
|
" return optimizer\n",
|
||||||
|
" \n",
|
||||||
|
" def estimate_mfu(self, fwdbwd_per_iter, dt):\n",
|
||||||
|
" \"\"\" 估计模型的 FLOPs 利用率 (MFU) 单位:A100 bfloat16 的峰值 FLOPS \"\"\"\n",
|
||||||
|
" # 计算每次迭代的 FLOPs 数量(参考 PaLM 论文的附录 B)\n",
|
||||||
|
" N = sum(p.numel() for p in self.parameters())\n",
|
||||||
|
" cfg = self.args\n",
|
||||||
|
" L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.dim//cfg.n_heads, cfg.max_seq_len\n",
|
||||||
|
" flops_per_token = 6*N + 12*L*H*Q*T\n",
|
||||||
|
" flops_per_fwdbwd = flops_per_token * T\n",
|
||||||
|
" flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter\n",
|
||||||
|
" \n",
|
||||||
|
" # 将 FLOPs 吞吐量表示为 A100 bfloat16 峰值 FLOPS 的比例\n",
|
||||||
|
" flops_achieved = flops_per_iter * (1.0/dt) # 每秒计算的 FLOPs\n",
|
||||||
|
" flops_promised = 312e12 # A100 GPU bfloat16 的峰值 FLOPS 为 312 TFLOPS\n",
|
||||||
|
" mfu = flops_achieved / flops_promised\n",
|
||||||
|
" return mfu\n",
|
||||||
|
" \n",
|
||||||
|
" @torch.inference_mode()\n",
|
||||||
|
" def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" 给定输入序列 idx(形状为 (bz,seq_len) 的长整型张量),通过多次生成新 token 来完成序列。\n",
|
||||||
|
" 在 model.eval() 模式下运行。效率较低的采样版本,没有使用键k/v cache。\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" for _ in range(max_new_tokens):\n",
|
||||||
|
" # 如果序列上下文过长,截断它到最大长度\n",
|
||||||
|
" idx_cond = idx if idx.size(1) <= self.args.max_seq_len else idx[:, -self.args.max_seq_len:]\n",
|
||||||
|
" \n",
|
||||||
|
" # 前向传播获取序列中最后一个位置的 logits\n",
|
||||||
|
" logits = self(idx_cond)\n",
|
||||||
|
" logits = logits[:, -1, :] # 只保留最后一个时间步的输出\n",
|
||||||
|
" \n",
|
||||||
|
" if temperature == 0.0:\n",
|
||||||
|
" # 选择最有可能的索引\n",
|
||||||
|
" _, idx_next = torch.topk(logits, k=1, dim=-1)\n",
|
||||||
|
" else:\n",
|
||||||
|
" # 缩放 logits 并应用 softmax\n",
|
||||||
|
" logits = logits / temperature\n",
|
||||||
|
" if top_k is not None:\n",
|
||||||
|
" v, _ = torch.topk(logits, min(top_k, logits.size(-1)))\n",
|
||||||
|
" logits[logits < v[:, [-1]]] = -float('Inf')\n",
|
||||||
|
" probs = F.softmax(logits, dim=-1)\n",
|
||||||
|
" idx_next = torch.multinomial(probs, num_samples=1)\n",
|
||||||
|
" \n",
|
||||||
|
" # 将采样的索引添加到序列中并继续\n",
|
||||||
|
" idx = torch.cat((idx, idx_next), dim=1)\n",
|
||||||
|
"\n",
|
||||||
|
" return idx"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -352,6 +352,80 @@ class LLaMA2Model(nn.Module):
|
|||||||
self.last_loss = None
|
self.last_loss = None
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
|
||||||
|
# 获取所有需要更新的参数
|
||||||
|
param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
|
||||||
|
|
||||||
|
# 将参数分为需要权重衰减和不需要权重衰减的两组
|
||||||
|
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
|
||||||
|
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
|
||||||
|
optim_groups = [
|
||||||
|
{'params': decay_params, 'weight_decay': weight_decay},
|
||||||
|
{'params': nodecay_params, 'weight_decay': 0.0}
|
||||||
|
]
|
||||||
|
|
||||||
|
# 打印参数数量信息
|
||||||
|
num_decay_params = sum(p.numel() for p in decay_params)
|
||||||
|
num_nodecay_params = sum(p.numel() for p in nodecay_params)
|
||||||
|
print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
|
||||||
|
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
|
||||||
|
|
||||||
|
# 根据设备类型选择使用标准 AdamW 或其融合版本
|
||||||
|
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
|
||||||
|
use_fused = fused_available and device_type == 'cuda'
|
||||||
|
extra_args = dict(fused=True) if use_fused else dict()
|
||||||
|
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
|
||||||
|
print(f"using fused AdamW: {use_fused}")
|
||||||
|
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
def estimate_mfu(self, fwdbwd_per_iter, dt):
|
||||||
|
""" 估计模型的 FLOPs 利用率 (MFU) 单位:A100 bfloat16 的峰值 FLOPS """
|
||||||
|
# 计算每次迭代的 FLOPs 数量(参考 PaLM 论文的附录 B)
|
||||||
|
N = sum(p.numel() for p in self.parameters())
|
||||||
|
cfg = self.args
|
||||||
|
L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.dim//cfg.n_heads, cfg.max_seq_len
|
||||||
|
flops_per_token = 6*N + 12*L*H*Q*T
|
||||||
|
flops_per_fwdbwd = flops_per_token * T
|
||||||
|
flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
|
||||||
|
|
||||||
|
# 将 FLOPs 吞吐量表示为 A100 bfloat16 峰值 FLOPS 的比例
|
||||||
|
flops_achieved = flops_per_iter * (1.0/dt) # 每秒计算的 FLOPs
|
||||||
|
flops_promised = 312e12 # A100 GPU bfloat16 的峰值 FLOPS 为 312 TFLOPS
|
||||||
|
mfu = flops_achieved / flops_promised
|
||||||
|
return mfu
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
|
||||||
|
"""
|
||||||
|
给定输入序列 idx(形状为 (bz,seq_len) 的长整型张量),通过多次生成新 token 来完成序列。
|
||||||
|
在 model.eval() 模式下运行。效率较低的采样版本,没有使用键k/v cache。
|
||||||
|
"""
|
||||||
|
for _ in range(max_new_tokens):
|
||||||
|
# 如果序列上下文过长,截断它到最大长度
|
||||||
|
idx_cond = idx if idx.size(1) <= self.args.max_seq_len else idx[:, -self.args.max_seq_len:]
|
||||||
|
|
||||||
|
# 前向传播获取序列中最后一个位置的 logits
|
||||||
|
logits = self(idx_cond)
|
||||||
|
logits = logits[:, -1, :] # 只保留最后一个时间步的输出
|
||||||
|
|
||||||
|
if temperature == 0.0:
|
||||||
|
# 选择最有可能的索引
|
||||||
|
_, idx_next = torch.topk(logits, k=1, dim=-1)
|
||||||
|
else:
|
||||||
|
# 缩放 logits 并应用 softmax
|
||||||
|
logits = logits / temperature
|
||||||
|
if top_k is not None:
|
||||||
|
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||||
|
logits[logits < v[:, [-1]]] = -float('Inf')
|
||||||
|
probs = F.softmax(logits, dim=-1)
|
||||||
|
idx_next = torch.multinomial(probs, num_samples=1)
|
||||||
|
|
||||||
|
# 将采样的索引添加到序列中并继续
|
||||||
|
idx = torch.cat((idx, idx_next), dim=1)
|
||||||
|
|
||||||
|
return idx
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
args = ModelArgs()
|
args = ModelArgs()
|
||||||
|
|||||||
Reference in New Issue
Block a user