refactor: Optimize LLaMA2Model's configure_optimizers method

This commit is contained in:
KMnO4-zx
2024-08-24 18:59:48 +08:00
parent f63e6895fa
commit 629f1293ae
2 changed files with 151 additions and 3 deletions

View File

@@ -420,7 +420,7 @@
], ],
"source": [ "source": [
"# 创建MLP实例\n", "# 创建MLP实例\n",
"mlp = LLaMAMLP(args.dim, args.hidden_dim, args.multiple_of, args.dropout)\n", "mlp = LLaMA2MLP(args.dim, args.hidden_dim, args.multiple_of, args.dropout)\n",
"# 随机生成数据\n", "# 随机生成数据\n",
"x = torch.randn(1, 50, 288)\n", "x = torch.randn(1, 50, 288)\n",
"# 运行MLP模型\n", "# 运行MLP模型\n",
@@ -528,7 +528,7 @@
" # Decoder层\n", " # Decoder层\n",
" self.layers = torch.nn.ModuleList()\n", " self.layers = torch.nn.ModuleList()\n",
" for layer_id in range(args.n_layers):\n", " for layer_id in range(args.n_layers):\n",
" self.layers.append(LLaMA2DecoderLayer(layer_id, args))\n", " self.layers.append(LLaMADecoderLayer(layer_id, args))\n",
" # 归一化层\n", " # 归一化层\n",
" self.norm = LLaMA2RMSNorm(args.dim, eps=args.norm_eps)\n", " self.norm = LLaMA2RMSNorm(args.dim, eps=args.norm_eps)\n",
" # 输出层\n", " # 输出层\n",
@@ -586,7 +586,81 @@
" logits = self.output(h[:, [-1], :]) \n", " logits = self.output(h[:, [-1], :]) \n",
" self.last_loss = None\n", " self.last_loss = None\n",
"\n", "\n",
" return logits" " return logits\n",
" \n",
" def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):\n",
" # 获取所有需要更新的参数\n",
" param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}\n",
" \n",
" # 将参数分为需要权重衰减和不需要权重衰减的两组\n",
" decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]\n",
" nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]\n",
" optim_groups = [\n",
" {'params': decay_params, 'weight_decay': weight_decay},\n",
" {'params': nodecay_params, 'weight_decay': 0.0}\n",
" ]\n",
" \n",
" # 打印参数数量信息\n",
" num_decay_params = sum(p.numel() for p in decay_params)\n",
" num_nodecay_params = sum(p.numel() for p in nodecay_params)\n",
" print(f\"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters\")\n",
" print(f\"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters\")\n",
" \n",
" # 根据设备类型选择使用标准 AdamW 或其融合版本\n",
" fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters\n",
" use_fused = fused_available and device_type == 'cuda'\n",
" extra_args = dict(fused=True) if use_fused else dict()\n",
" optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)\n",
" print(f\"using fused AdamW: {use_fused}\")\n",
"\n",
" return optimizer\n",
" \n",
" def estimate_mfu(self, fwdbwd_per_iter, dt):\n",
" \"\"\" 估计模型的 FLOPs 利用率 (MFU) 单位A100 bfloat16 的峰值 FLOPS \"\"\"\n",
" # 计算每次迭代的 FLOPs 数量(参考 PaLM 论文的附录 B\n",
" N = sum(p.numel() for p in self.parameters())\n",
" cfg = self.args\n",
" L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.dim//cfg.n_heads, cfg.max_seq_len\n",
" flops_per_token = 6*N + 12*L*H*Q*T\n",
" flops_per_fwdbwd = flops_per_token * T\n",
" flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter\n",
" \n",
" # 将 FLOPs 吞吐量表示为 A100 bfloat16 峰值 FLOPS 的比例\n",
" flops_achieved = flops_per_iter * (1.0/dt) # 每秒计算的 FLOPs\n",
" flops_promised = 312e12 # A100 GPU bfloat16 的峰值 FLOPS 为 312 TFLOPS\n",
" mfu = flops_achieved / flops_promised\n",
" return mfu\n",
" \n",
" @torch.inference_mode()\n",
" def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):\n",
" \"\"\"\n",
" 给定输入序列 idx形状为 (bz,seq_len) 的长整型张量),通过多次生成新 token 来完成序列。\n",
" 在 model.eval() 模式下运行。效率较低的采样版本没有使用键k/v cache。\n",
" \"\"\"\n",
" for _ in range(max_new_tokens):\n",
" # 如果序列上下文过长,截断它到最大长度\n",
" idx_cond = idx if idx.size(1) <= self.args.max_seq_len else idx[:, -self.args.max_seq_len:]\n",
" \n",
" # 前向传播获取序列中最后一个位置的 logits\n",
" logits = self(idx_cond)\n",
" logits = logits[:, -1, :] # 只保留最后一个时间步的输出\n",
" \n",
" if temperature == 0.0:\n",
" # 选择最有可能的索引\n",
" _, idx_next = torch.topk(logits, k=1, dim=-1)\n",
" else:\n",
" # 缩放 logits 并应用 softmax\n",
" logits = logits / temperature\n",
" if top_k is not None:\n",
" v, _ = torch.topk(logits, min(top_k, logits.size(-1)))\n",
" logits[logits < v[:, [-1]]] = -float('Inf')\n",
" probs = F.softmax(logits, dim=-1)\n",
" idx_next = torch.multinomial(probs, num_samples=1)\n",
" \n",
" # 将采样的索引添加到序列中并继续\n",
" idx = torch.cat((idx, idx_next), dim=1)\n",
"\n",
" return idx"
] ]
}, },
{ {

View File

@@ -352,6 +352,80 @@ class LLaMA2Model(nn.Module):
self.last_loss = None self.last_loss = None
return logits return logits
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
# 获取所有需要更新的参数
param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
# 将参数分为需要权重衰减和不需要权重衰减的两组
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
optim_groups = [
{'params': decay_params, 'weight_decay': weight_decay},
{'params': nodecay_params, 'weight_decay': 0.0}
]
# 打印参数数量信息
num_decay_params = sum(p.numel() for p in decay_params)
num_nodecay_params = sum(p.numel() for p in nodecay_params)
print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
# 根据设备类型选择使用标准 AdamW 或其融合版本
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and device_type == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
print(f"using fused AdamW: {use_fused}")
return optimizer
def estimate_mfu(self, fwdbwd_per_iter, dt):
""" 估计模型的 FLOPs 利用率 (MFU) 单位A100 bfloat16 的峰值 FLOPS """
# 计算每次迭代的 FLOPs 数量(参考 PaLM 论文的附录 B
N = sum(p.numel() for p in self.parameters())
cfg = self.args
L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.dim//cfg.n_heads, cfg.max_seq_len
flops_per_token = 6*N + 12*L*H*Q*T
flops_per_fwdbwd = flops_per_token * T
flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
# 将 FLOPs 吞吐量表示为 A100 bfloat16 峰值 FLOPS 的比例
flops_achieved = flops_per_iter * (1.0/dt) # 每秒计算的 FLOPs
flops_promised = 312e12 # A100 GPU bfloat16 的峰值 FLOPS 为 312 TFLOPS
mfu = flops_achieved / flops_promised
return mfu
@torch.inference_mode()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
"""
给定输入序列 idx形状为 (bz,seq_len) 的长整型张量),通过多次生成新 token 来完成序列。
在 model.eval() 模式下运行。效率较低的采样版本没有使用键k/v cache。
"""
for _ in range(max_new_tokens):
# 如果序列上下文过长,截断它到最大长度
idx_cond = idx if idx.size(1) <= self.args.max_seq_len else idx[:, -self.args.max_seq_len:]
# 前向传播获取序列中最后一个位置的 logits
logits = self(idx_cond)
logits = logits[:, -1, :] # 只保留最后一个时间步的输出
if temperature == 0.0:
# 选择最有可能的索引
_, idx_next = torch.topk(logits, k=1, dim=-1)
else:
# 缩放 logits 并应用 softmax
logits = logits / temperature
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
# 将采样的索引添加到序列中并继续
idx = torch.cat((idx, idx_next), dim=1)
return idx
if __name__ == '__main__': if __name__ == '__main__':
args = ModelArgs() args = ModelArgs()