update_infer

This commit is contained in:
Watchtower-Liu
2024-02-16 16:53:57 +08:00
parent 41041715a4
commit 1803729360
6 changed files with 88 additions and 56 deletions

View File

@@ -114,7 +114,8 @@ def logits_to_probs(
top_p: Optional[int] = None,
repetition_penalty: float = 1.0,
):
previous_tokens = previous_tokens.squeeze()
if previous_tokens is not None:
previous_tokens = previous_tokens.squeeze()
# print(logits.shape,previous_tokens.shape)
# pdb.set_trace()
if previous_tokens is not None and repetition_penalty != 1.0: