update_infer

Watchtower-Liu
2024-02-16 16:53:57 +08:00
parent 41041715a4
commit 1803729360
6 changed files with 88 additions and 56 deletions


@@ -392,6 +392,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
     1, 2
 ) # .float()
 codes = vq_model.extract_latent(ssl_content)
 prompt_semantic = codes[0, 0]
 t1 = ttime()
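
Note: the unchanged lines above quantize the reference audio's SSL features into discrete semantic tokens, which become `prompt_semantic`. A minimal stand-in sketch of that step (codebook size, frame count, and tensor layout are assumptions, not taken from the repository):

import torch

# Stand-in for codes = vq_model.extract_latent(ssl_content); the real call maps
# SSL features to discrete codes. Layout assumed here: (n_codebooks, batch, frames).
codes = torch.randint(0, 1024, (1, 1, 120))
prompt_semantic = codes[0, 0]              # (frames,) token ids of the reference audio
prompt = prompt_semantic.unsqueeze(0)      # (1, frames); before this commit this tensor
                                           # was passed to infer_panel as the semantic prompt
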
@@ -423,9 +424,9 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
 print(i18n("实际输入的目标文本(每句):"), text)
 phones2, word2ph2, norm_text2 = get_cleaned_text_final(text, text_language)
 bert2 = get_bert_final(phones2, word2ph2, norm_text2, text_language, device).to(dtype)
-bert = torch.cat([bert1, bert2], 1)
+bert = torch.cat([bert2], 1)
-all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
+all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
 bert = bert.to(device).unsqueeze(0)
 all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
 prompt = prompt_semantic.unsqueeze(0).to(device)
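
Note: this hunk stops prepending the reference prompt's phonemes (`phones1`) and BERT features (`bert1`) to the target text, so only the target sentence feeds the text branch (the i18n log line above prints "actual target text fed in (per sentence)"). A minimal before/after sketch with dummy tensors; the 1024-dim BERT feature size and the lengths are assumptions:

import torch

bert1 = torch.zeros(1024, 7)    # reference-text BERT features (dropped by this commit)
bert2 = torch.zeros(1024, 11)   # target-text BERT features
phones1, phones2 = [3, 5, 7], [2, 4, 6, 8]

# before: reference features were prepended to the target's
old_bert = torch.cat([bert1, bert2], 1)                      # (1024, 18)
old_ids = torch.LongTensor(phones1 + phones2).unsqueeze(0)   # (1, 7)

# after: only the target sentence is encoded
new_bert = torch.cat([bert2], 1)                             # (1024, 11)
new_ids = torch.LongTensor(phones2).unsqueeze(0)             # (1, 4)
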
@@ -435,14 +436,14 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
 pred_semantic, idx = t2s_model.model.infer_panel(
     all_phoneme_ids,
     all_phoneme_len,
-    prompt,
+    None,
     bert,
     # prompt_phone_len=ph_offset,
     top_k=config["inference"]["top_k"],
     early_stop_num=hz * max_sec,
 )
 t3 = ttime()
-# print(pred_semantic.shape,idx)
+print(pred_semantic,idx)
 pred_semantic = pred_semantic[:, -idx:].unsqueeze(
     0
 ) # .unsqueeze(0)#mq要多unsqueeze一次
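
Note: with `None` in place of `prompt`, infer_panel runs without the reference semantic tokens, matching the dropped `bert1`/`phones1` above. The slice `pred_semantic[:, -idx:]` then keeps only the last `idx` generated tokens before adding an extra batch axis (the Chinese comment reads "mq needs one more unsqueeze"). A toy illustration of that post-processing; all sizes below are made up:

import torch

pred_semantic = torch.randint(0, 1024, (1, 50))   # (1, total_decoded_tokens)
idx = 30                                          # tokens generated after the prompt position

# keep only the newly generated tokens and add a batch axis
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
print(pred_semantic.shape)                        # torch.Size([1, 1, 30])
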
@@ -620,7 +621,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
 inference_button.click(
     get_tts_wav,
-    [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut],
+    [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature],
     [output],
 )
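
Note: the click handler now forwards three extra Gradio components, so `get_tts_wav` must accept matching `top_k`, `top_p`, and `temperature` parameters. A hedged, standalone sketch of that wiring; the slider ranges, defaults, labels, dropdown choices, and the stub function are assumptions, only the argument order mirrors the inputs list above:

import gradio as gr

def get_tts_wav_stub(inp_ref, prompt_text, prompt_language, text, text_language,
                     how_to_cut, top_k, top_p, temperature):
    return None   # placeholder standing in for the real get_tts_wav

with gr.Blocks(title="GPT-SoVITS WebUI") as app:
    inp_ref = gr.Audio(label="reference audio")
    prompt_text = gr.Textbox(label="prompt text")
    prompt_language = gr.Dropdown(["zh", "en", "ja"], value="zh", label="prompt language")
    text = gr.Textbox(label="target text")
    text_language = gr.Dropdown(["zh", "en", "ja"], value="zh", label="target language")
    how_to_cut = gr.Dropdown(["no cut", "cut every 4 sentences"], value="no cut", label="how to cut")
    top_k = gr.Slider(1, 100, value=5, step=1, label="top_k")
    top_p = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="top_p")
    temperature = gr.Slider(0.05, 2.0, value=1.0, step=0.05, label="temperature")
    output = gr.Audio(label="output")
    inference_button = gr.Button("Synthesize")

    inference_button.click(
        get_tts_wav_stub,
        [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut,
         top_k, top_p, temperature],
        [output],
    )
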