修复了一些bug,优化了一些代码

This commit is contained in:
chasonjiang
2024-03-11 17:16:04 +08:00
parent 3535cfe3b0
commit d23f3a62c4
5 changed files with 72 additions and 51 deletions

View File

@@ -143,7 +143,7 @@ def logits_to_probs(
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
pivot = v.select(-1, -1).unsqueeze(-1)
pivot = v[: , -1].unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = torch.nn.functional.softmax(logits, dim=-1)