refactor: Optimize LLaMA2Model's configure_optimizers method

This commit is contained in:
KMnO4-zx
2024-08-24 18:59:48 +08:00
parent f63e6895fa
commit 629f1293ae
2 changed files with 151 additions and 3 deletions

View File

@@ -352,6 +352,80 @@ class LLaMA2Model(nn.Module):
self.last_loss = None
return logits
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
# 获取所有需要更新的参数
param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
# 将参数分为需要权重衰减和不需要权重衰减的两组
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
optim_groups = [
{'params': decay_params, 'weight_decay': weight_decay},
{'params': nodecay_params, 'weight_decay': 0.0}
]
# 打印参数数量信息
num_decay_params = sum(p.numel() for p in decay_params)
num_nodecay_params = sum(p.numel() for p in nodecay_params)
print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
# 根据设备类型选择使用标准 AdamW 或其融合版本
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and device_type == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
print(f"using fused AdamW: {use_fused}")
return optimizer
def estimate_mfu(self, fwdbwd_per_iter, dt):
""" 估计模型的 FLOPs 利用率 (MFU) 单位A100 bfloat16 的峰值 FLOPS """
# 计算每次迭代的 FLOPs 数量(参考 PaLM 论文的附录 B
N = sum(p.numel() for p in self.parameters())
cfg = self.args
L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.dim//cfg.n_heads, cfg.max_seq_len
flops_per_token = 6*N + 12*L*H*Q*T
flops_per_fwdbwd = flops_per_token * T
flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
# 将 FLOPs 吞吐量表示为 A100 bfloat16 峰值 FLOPS 的比例
flops_achieved = flops_per_iter * (1.0/dt) # 每秒计算的 FLOPs
flops_promised = 312e12 # A100 GPU bfloat16 的峰值 FLOPS 为 312 TFLOPS
mfu = flops_achieved / flops_promised
return mfu
@torch.inference_mode()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
"""
给定输入序列 idx形状为 (bz,seq_len) 的长整型张量),通过多次生成新 token 来完成序列。
在 model.eval() 模式下运行。效率较低的采样版本没有使用键k/v cache。
"""
for _ in range(max_new_tokens):
# 如果序列上下文过长,截断它到最大长度
idx_cond = idx if idx.size(1) <= self.args.max_seq_len else idx[:, -self.args.max_seq_len:]
# 前向传播获取序列中最后一个位置的 logits
logits = self(idx_cond)
logits = logits[:, -1, :] # 只保留最后一个时间步的输出
if temperature == 0.0:
# 选择最有可能的索引
_, idx_next = torch.topk(logits, k=1, dim=-1)
else:
# 缩放 logits 并应用 softmax
logits = logits / temperature
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
# 将采样的索引添加到序列中并继续
idx = torch.cat((idx, idx_next), dim=1)
return idx
if __name__ == '__main__':
args = ModelArgs()