diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index 8e3c7fc..23da380 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -166,10 +166,10 @@ class T2STransformer:
         return x, k_cache, v_cache
 
     def decode_next_token(
-        self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor], attn_mask : torch.Tensor
+        self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]
     ):
         for i in range(self.num_blocks):
-            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask)
+            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
         return x, k_cache, v_cache
 
@@ -554,7 +554,7 @@ class Text2SemanticDecoder(nn.Module):
             if idx == 0:
                 xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask)
             else:
-                xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, xy_attn_mask)
+                xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
             logits = self.ar_predict_layer(
                 xy_dec[:, -1]
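
A minimal sketch (not part of the patch) of why the attn_mask argument can be dropped from decode_next_token: once process_prompt has filled the KV cache, each incremental step feeds a single new query that is allowed to attend to every cached position, so the causal-mask row it would receive is all ones and masked attention equals unmasked attention. The tensor shapes and the use of torch.nn.functional.scaled_dot_product_attention below are illustrative assumptions, not the project's actual block implementation.

# Sketch only: single-token decode step against a KV cache.
import torch
import torch.nn.functional as F

head_dim = 8
q = torch.randn(1, 1, head_dim)   # query for the one new token
k = torch.randn(1, 6, head_dim)   # cached keys (prompt + previously decoded tokens)
v = torch.randn(1, 6, head_dim)   # cached values

# The causal-mask row for the newest token permits attention to every
# cached position, i.e. it is all True, so it changes nothing.
full_row_mask = torch.ones(1, 1, 6, dtype=torch.bool)

with_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=full_row_mask)
without_mask = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(with_mask, without_mask)

Passing the mask therefore only costs an extra tensor per step, which is why the parameter can be removed from both the wrapper and the call site in the idx > 0 branch.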