Fixed some bugs and optimized some code.
@@ -1,5 +1,6 @@
 import math
 import os, sys
+import random
 now_dir = os.getcwd()
 sys.path.append(now_dir)
 import ffmpeg
@@ -7,6 +8,7 @@ import os
 from typing import Generator, List, Union
 import numpy as np
 import torch
 import torch.nn.functional as F
 import yaml
 from transformers import AutoModelForMaskedLM, AutoTokenizer
+
@@ -130,11 +132,11 @@ class TTS_Config:
         string = "----------------TTS Config--------------\n"
         string += "device: {}\n".format(self.device)
         string += "is_half: {}\n".format(self.is_half)
-        string += "flash_attn_enabled: {}\n".format(self.flash_attn_enabled)
         string += "bert_base_path: {}\n".format(self.bert_base_path)
         string += "t2s_weights_path: {}\n".format(self.t2s_weights_path)
         string += "vits_weights_path: {}\n".format(self.vits_weights_path)
         string += "cnhuhbert_base_path: {}\n".format(self.cnhuhbert_base_path)
+        string += "flash_attn_enabled: {}\n".format(self.flash_attn_enabled)
         string += "----------------------------------------\n"
         return string
@@ -184,7 +186,7 @@ class TTS:

     def init_cnhuhbert_weights(self, base_path: str):
         self.cnhuhbert_model = CNHubert(base_path)
-        self.cnhuhbert_model.eval()
+        self.cnhuhbert_model=self.cnhuhbert_model.eval()
         if self.configs.is_half == True:
             self.cnhuhbert_model = self.cnhuhbert_model.half()
         self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
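Note: in PyTorch, nn.Module.eval() switches the module to inference mode in place and returns the module itself, so the old call and the new assignment are behaviorally equivalent; the assignment form just reads consistently with the .half()/.to() chains around it. A minimal sketch (the Linear module is only an illustration):

    import torch

    model = torch.nn.Linear(4, 4)
    same = model.eval()        # eval() returns self, now in eval mode
    assert same is model       # the assignment binds the same object
    assert not model.training  # dropout/batch-norm run in inference mode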
@@ -194,6 +196,7 @@ class TTS:
     def init_bert_weights(self, base_path: str):
         self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
         self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
+        self.bert_model=self.bert_model.eval()
         if self.configs.is_half:
             self.bert_model = self.bert_model.half()
         self.bert_model = self.bert_model.to(self.configs.device)
@@ -226,7 +229,7 @@ class TTS:
         if self.configs.is_half:
             vits_model = vits_model.half()
         vits_model = vits_model.to(self.configs.device)
-        vits_model.eval()
+        vits_model = vits_model.eval()
         vits_model.load_state_dict(dict_s2["weight"], strict=False)
         self.vits_model = vits_model

@@ -244,7 +247,7 @@ class TTS:
         if self.configs.is_half:
             t2s_model = t2s_model.half()
         t2s_model = t2s_model.to(self.configs.device)
-        t2s_model.eval()
+        t2s_model = t2s_model.eval()
         self.t2s_model = t2s_model

     def set_ref_audio(self, ref_audio_path:str):
@@ -377,12 +380,14 @@ class TTS:
         phones_max_len = 0
         for item in item_list:
             if prompt_data is not None:
-                all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)
+                all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
+                                        .to(dtype=torch.float32 if not self.configs.is_half else torch.float16)
                 all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"])
                 phones = torch.LongTensor(item["phones"])
                 # norm_text = prompt_data["norm_text"]+item["norm_text"]
             else:
-                all_bert_features = item["bert_features"]
+                all_bert_features = item["bert_features"]\
+                                        .to(dtype=torch.float32 if not self.configs.is_half else torch.float16)
                 phones = torch.LongTensor(item["phones"])
                 all_phones = phones
                 # norm_text = item["norm_text"]
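Note: the added .to(dtype=...) pins the BERT features to the precision the rest of the pipeline runs in (float16 when is_half is set, float32 otherwise), avoiding dtype-mismatch errors when features cached in one precision meet a model in the other. A hedged sketch of the same pattern (is_half and the tensor shape are stand-ins, not the repo's values):

    import torch

    is_half = True                    # stand-in for self.configs.is_half
    features = torch.randn(1024, 32)  # stand-in for cached bert_features
    features = features.to(dtype=torch.float16 if is_half else torch.float32)
    assert features.dtype == torch.float16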
@@ -401,12 +406,10 @@ class TTS:
             max_len = max(bert_max_len, phones_max_len)
             # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
             all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
-            all_bert_features_batch = torch.FloatTensor(len(item_list), 1024, max_len)
-            all_bert_features_batch.zero_()
-
+            # all_bert_features_batch = all_bert_features_list
+            all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=torch.float32)
             for idx, item in enumerate(all_bert_features_list):
-                if item != None:
-                    all_bert_features_batch[idx, :, : item.shape[-1]] = item
+                all_bert_features_batch[idx, :, : item.shape[-1]] = item

             batch = {
                 "phones": phones_batch,
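Note: torch.FloatTensor(n, c, l) allocates uninitialized memory and needs the separate .zero_() call, while torch.zeros(n, c, l, dtype=...) allocates and clears in one step with an explicit dtype, which is why the replacement is both safer and shorter. A small sketch of the left-aligned zero-padding pattern used here (the feature list is illustrative):

    import torch

    feats = [torch.ones(1024, 10), torch.ones(1024, 7)]  # variable-length features
    max_len = max(f.shape[-1] for f in feats)
    batch = torch.zeros(len(feats), 1024, max_len, dtype=torch.float32)
    for idx, item in enumerate(feats):
        batch[idx, :, : item.shape[-1]] = item           # pad the tail with zeros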
@@ -458,8 +461,8 @@ class TTS:
             "prompt_text": "", # str. prompt text for the reference audio
             "prompt_lang": "", # str. language of the prompt text for the reference audio
             "top_k": 5, # int. top k sampling
-            "top_p": 0.9, # float. top p sampling
-            "temperature": 0.6, # float. temperature for sampling
+            "top_p": 1, # float. top p sampling
+            "temperature": 1, # float. temperature for sampling
             "text_split_method": "", # str. text split method, see text_segmentaion_method.py for details.
             "batch_size": 1, # int. batch size for inference
             "batch_threshold": 0.75, # float. threshold for batch splitting.
@@ -477,9 +480,9 @@ class TTS:
         ref_audio_path:str = inputs.get("ref_audio_path", "")
         prompt_text:str = inputs.get("prompt_text", "")
         prompt_lang:str = inputs.get("prompt_lang", "")
-        top_k:int = inputs.get("top_k", 20)
-        top_p:float = inputs.get("top_p", 0.9)
-        temperature:float = inputs.get("temperature", 0.6)
+        top_k:int = inputs.get("top_k", 5)
+        top_p:float = inputs.get("top_p", 1)
+        temperature:float = inputs.get("temperature", 1)
         text_split_method:str = inputs.get("text_split_method", "")
         batch_size = inputs.get("batch_size", 1)
         batch_threshold = inputs.get("batch_threshold", 0.75)
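Note: the new defaults (top_k=5, top_p=1, temperature=1) effectively make plain top-k sampling the default: with top_p at 1 no nucleus truncation happens, and with temperature at 1 the logits are unscaled, so only the top-k filter shapes the distribution. A hedged sketch of how these three knobs usually compose (a generic sampler, not the repo's own):

    import torch

    def sample(logits: torch.Tensor, top_k=5, top_p=1.0, temperature=1.0) -> int:
        logits = logits / temperature                    # temperature == 1 is a no-op
        topk_vals, topk_idx = torch.topk(logits, top_k)  # keep only the k best logits
        probs = torch.softmax(topk_vals, dim=-1)
        if top_p < 1.0:                                  # nucleus cut, skipped at 1.0
            sorted_p, order = torch.sort(probs, descending=True)
            keep = torch.cumsum(sorted_p, dim=-1) - sorted_p < top_p
            probs = torch.where(keep, sorted_p, torch.zeros_like(sorted_p))
            probs = probs / probs.sum()
            topk_idx = topk_idx[order]
        return topk_idx[torch.multinomial(probs, 1)].item()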
@@ -497,10 +500,6 @@ class TTS:
         if split_bucket:
             print(i18n("分桶处理模式已开启"))

-        # if vits_batched_inference:
-        #     print(i18n("VITS批量推理模式已开启"))
-        # else:
-        #     print(i18n("VITS单句推理模式已开启"))

         no_prompt_text = False
         if prompt_text in [None, ""]:
@@ -547,7 +546,7 @@ class TTS:
             )
         t2 = ttime()
-
+
         print("############ 推理 ############")
         ###### inference ######
         t_34 = 0.0
         t_45 = 0.0
@@ -601,6 +600,10 @@ class TTS:
             # pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
             # pred_semantic_len = torch.LongTensor([item.shape[0] for item in pred_semantic_list]).to(self.configs.device)
             # pred_semantic = self.batch_sequences(pred_semantic_list, axis=0, pad_value=0).unsqueeze(0)
+            # max_len = 0
+            # for i in range(0, len(batch_phones)):
+            #     max_len = max(max_len, batch_phones[i].shape[-1])
+            # batch_phones = self.batch_sequences(batch_phones, axis=0, pad_value=0, max_length=max_len)
             # batch_phones = batch_phones.to(self.configs.device)
             # batch_audio_fragment = (self.vits_model.batched_decode(
             #     pred_semantic, pred_semantic_len, batch_phones, batch_phones_len,refer_audio_spepc
@@ -654,7 +657,12 @@ class TTS:
                 self.configs.sampling_rate,
                 batch_index_list,
                 speed_factor,
-                split_bucket)
+                split_bucket)
+
+            try:
+                torch.cuda.empty_cache()
+            except:
+                pass


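Note: wrapping torch.cuda.empty_cache() in try/except makes the post-inference cache release a best-effort cleanup that cannot crash runs on CPU-only or otherwise CUDA-less setups; on recent PyTorch versions the call is already a no-op when CUDA was never initialized, so the guard is defensive. An equivalent hedged sketch:

    import torch

    def release_cuda_cache() -> None:
        # Best-effort: frees cached GPU blocks back to the driver when CUDA
        # is available; swallows errors on builds or machines without CUDA.
        try:
            torch.cuda.empty_cache()
        except Exception:
            pass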