From 3b9259b0a1564dc83e35ebf8e513740fc7594bb3 Mon Sep 17 00:00:00 2001 From: chasonjiang <1440499136@qq.com> Date: Sat, 9 Mar 2024 20:21:11 +0800 Subject: [PATCH] modified: GPT_SoVITS/AR/models/t2s_model.py --- GPT_SoVITS/AR/models/t2s_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index 8e3c7fc..23da380 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -166,10 +166,10 @@ class T2STransformer: return x, k_cache, v_cache def decode_next_token( - self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor], attn_mask : torch.Tensor + self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor] ): for i in range(self.num_blocks): - x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask) + x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i]) return x, k_cache, v_cache @@ -554,7 +554,7 @@ class Text2SemanticDecoder(nn.Module): if idx == 0: xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask) else: - xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, xy_attn_mask) + xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache) logits = self.ar_predict_layer( xy_dec[:, -1]