修复了t2s模型无prompt输入时的bug GPT_SoVITS/AR/models/t2s_model.py
增加一些新特性,并修复了一些bug GPT_SoVITS/TTS_infer_pack/TTS.py 优化网页布局 GPT_SoVITS/inference_webui.py
This commit is contained in:
@@ -1,8 +1,7 @@
|
||||
import os, sys
|
||||
|
||||
import ffmpeg
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
import ffmpeg
|
||||
import os
|
||||
from typing import Generator, List, Union
|
||||
import numpy as np
|
||||
@@ -164,6 +163,9 @@ class TTS:
|
||||
"bert_features":None,
|
||||
"norm_text":None,
|
||||
}
|
||||
|
||||
|
||||
self.stop_flag:bool = False
|
||||
|
||||
def _init_models(self,):
|
||||
self.init_t2s_weights(self.configs.t2s_weights_path)
|
||||
@@ -310,7 +312,7 @@ class TTS:
|
||||
batch = torch.stack(padded_sequences)
|
||||
return batch
|
||||
|
||||
def to_batch(self, data:list, prompt_data:dict=None, batch_size:int=5, threshold:float=0.75):
|
||||
def to_batch(self, data:list, prompt_data:dict=None, batch_size:int=5, threshold:float=0.75, split_bucket:bool=True):
|
||||
|
||||
_data:list = []
|
||||
index_and_len_list = []
|
||||
@@ -318,30 +320,35 @@ class TTS:
|
||||
norm_text_len = len(item["norm_text"])
|
||||
index_and_len_list.append([idx, norm_text_len])
|
||||
|
||||
index_and_len_list.sort(key=lambda x: x[1])
|
||||
# index_and_len_batch_list = [index_and_len_list[idx:min(idx+batch_size,len(index_and_len_list))] for idx in range(0,len(index_and_len_list),batch_size)]
|
||||
index_and_len_list = np.array(index_and_len_list, dtype=np.int64)
|
||||
|
||||
# for batch_idx, index_and_len_batch in enumerate(index_and_len_batch_list):
|
||||
|
||||
batch_index_list = []
|
||||
batch_index_list_len = 0
|
||||
pos = 0
|
||||
while pos <index_and_len_list.shape[0]:
|
||||
# batch_index_list.append(index_and_len_list[pos:min(pos+batch_size,len(index_and_len_list))])
|
||||
pos_end = min(pos+batch_size,index_and_len_list.shape[0])
|
||||
while pos < pos_end:
|
||||
batch=index_and_len_list[pos:pos_end, 1].astype(np.float32)
|
||||
score=batch[(pos_end-pos)//2]/batch.mean()
|
||||
if (score>=threshold) or (pos_end-pos==1):
|
||||
batch_index=index_and_len_list[pos:pos_end, 0].tolist()
|
||||
batch_index_list_len += len(batch_index)
|
||||
batch_index_list.append(batch_index)
|
||||
pos = pos_end
|
||||
break
|
||||
pos_end=pos_end-1
|
||||
|
||||
assert batch_index_list_len == len(data)
|
||||
if split_bucket:
|
||||
index_and_len_list.sort(key=lambda x: x[1])
|
||||
index_and_len_list = np.array(index_and_len_list, dtype=np.int64)
|
||||
|
||||
batch_index_list_len = 0
|
||||
pos = 0
|
||||
while pos <index_and_len_list.shape[0]:
|
||||
# batch_index_list.append(index_and_len_list[pos:min(pos+batch_size,len(index_and_len_list))])
|
||||
pos_end = min(pos+batch_size,index_and_len_list.shape[0])
|
||||
while pos < pos_end:
|
||||
batch=index_and_len_list[pos:pos_end, 1].astype(np.float32)
|
||||
score=batch[(pos_end-pos)//2]/batch.mean()
|
||||
if (score>=threshold) or (pos_end-pos==1):
|
||||
batch_index=index_and_len_list[pos:pos_end, 0].tolist()
|
||||
batch_index_list_len += len(batch_index)
|
||||
batch_index_list.append(batch_index)
|
||||
pos = pos_end
|
||||
break
|
||||
pos_end=pos_end-1
|
||||
|
||||
assert batch_index_list_len == len(data)
|
||||
|
||||
else:
|
||||
for i in range(len(data)):
|
||||
if i%batch_size == 0:
|
||||
batch_index_list.append([])
|
||||
batch_index_list[-1].append(i)
|
||||
|
||||
|
||||
for batch_idx, index_list in enumerate(batch_index_list):
|
||||
item_list = [data[idx] for idx in index_list]
|
||||
@@ -399,7 +406,8 @@ class TTS:
|
||||
_data[index] = data[i][j]
|
||||
return _data
|
||||
|
||||
|
||||
def stop(self,):
|
||||
self.stop_flag = True
|
||||
|
||||
|
||||
def run(self, inputs:dict):
|
||||
@@ -409,22 +417,26 @@ class TTS:
|
||||
Args:
|
||||
inputs (dict):
|
||||
{
|
||||
"text": "",
|
||||
"text_lang: "",
|
||||
"ref_audio_path": "",
|
||||
"prompt_text": "",
|
||||
"prompt_lang": "",
|
||||
"top_k": 5,
|
||||
"top_p": 0.9,
|
||||
"temperature": 0.6,
|
||||
"text_split_method": "",
|
||||
"batch_size": 1,
|
||||
"batch_threshold": 0.75,
|
||||
"speed_factor":1.0,
|
||||
"text": "", # str. text to be synthesized
|
||||
"text_lang: "", # str. language of the text to be synthesized
|
||||
"ref_audio_path": "", # str. reference audio path
|
||||
"prompt_text": "", # str. prompt text for the reference audio
|
||||
"prompt_lang": "", # str. language of the prompt text for the reference audio
|
||||
"top_k": 5, # int. top k sampling
|
||||
"top_p": 0.9, # float. top p sampling
|
||||
"temperature": 0.6, # float. temperature for sampling
|
||||
"text_split_method": "", # str. text split method, see text_segmentaion_method.py for details.
|
||||
"batch_size": 1, # int. batch size for inference
|
||||
"batch_threshold": 0.75, # float. threshold for batch splitting.
|
||||
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
|
||||
"return_fragment": False, # bool. step by step return the audio fragment.
|
||||
"speed_factor":1.0, # float. control the speed of the synthesized audio.
|
||||
}
|
||||
returns:
|
||||
tulpe[int, np.ndarray]: sampling rate and audio data.
|
||||
"""
|
||||
self.stop_flag:bool = False
|
||||
|
||||
text:str = inputs.get("text", "")
|
||||
text_lang:str = inputs.get("text_lang", "")
|
||||
ref_audio_path:str = inputs.get("ref_audio_path", "")
|
||||
@@ -437,7 +449,20 @@ class TTS:
|
||||
batch_size = inputs.get("batch_size", 1)
|
||||
batch_threshold = inputs.get("batch_threshold", 0.75)
|
||||
speed_factor = inputs.get("speed_factor", 1.0)
|
||||
split_bucket = inputs.get("split_bucket", True)
|
||||
return_fragment = inputs.get("return_fragment", False)
|
||||
|
||||
if return_fragment:
|
||||
split_bucket = False
|
||||
print(i18n("分段返回模式已开启"))
|
||||
if split_bucket:
|
||||
split_bucket = False
|
||||
print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))
|
||||
|
||||
if split_bucket:
|
||||
print(i18n("分桶处理模式已开启"))
|
||||
|
||||
|
||||
no_prompt_text = False
|
||||
if prompt_text in [None, ""]:
|
||||
no_prompt_text = True
|
||||
@@ -481,7 +506,9 @@ class TTS:
|
||||
data, batch_index_list = self.to_batch(data,
|
||||
prompt_data=self.prompt_cache if not no_prompt_text else None,
|
||||
batch_size=batch_size,
|
||||
threshold=batch_threshold)
|
||||
threshold=batch_threshold,
|
||||
split_bucket=split_bucket
|
||||
)
|
||||
t2 = ttime()
|
||||
zero_wav = torch.zeros(
|
||||
int(self.configs.sampling_rate * 0.3),
|
||||
@@ -557,27 +584,57 @@ class TTS:
|
||||
audio_fragment.cpu().numpy()
|
||||
) ###试试重建不带上prompt部分
|
||||
|
||||
audio.append(batch_audio_fragment)
|
||||
# audio.append(zero_wav)
|
||||
t5 = ttime()
|
||||
t_45 += t5 - t4
|
||||
if return_fragment:
|
||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
|
||||
yield self.audio_postprocess(batch_audio_fragment,
|
||||
self.configs.sampling_rate,
|
||||
batch_index_list,
|
||||
speed_factor,
|
||||
split_bucket)
|
||||
else:
|
||||
audio.append(batch_audio_fragment)
|
||||
|
||||
if self.stop_flag:
|
||||
yield self.configs.sampling_rate, (zero_wav.cpu().numpy()).astype(np.int16)
|
||||
return
|
||||
|
||||
if not return_fragment:
|
||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
|
||||
yield self.audio_postprocess(audio,
|
||||
self.configs.sampling_rate,
|
||||
batch_index_list,
|
||||
speed_factor,
|
||||
split_bucket)
|
||||
|
||||
|
||||
|
||||
def audio_postprocess(self,
|
||||
audio:np.ndarray,
|
||||
sr:int,
|
||||
batch_index_list:list=None,
|
||||
speed_factor:float=1.0,
|
||||
split_bucket:bool=True)->tuple[int, np.ndarray]:
|
||||
if split_bucket:
|
||||
audio = self.recovery_order(audio, batch_index_list)
|
||||
else:
|
||||
audio = [item for batch in audio for item in batch]
|
||||
|
||||
|
||||
audio = self.recovery_order(audio, batch_index_list)
|
||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
|
||||
|
||||
audio = np.concatenate(audio, 0)
|
||||
audio = (audio * 32768).astype(np.int16)
|
||||
|
||||
try:
|
||||
if speed_factor != 1.0:
|
||||
audio = speed_change(audio, speed=speed_factor, sr=int(self.configs.sampling_rate))
|
||||
audio = speed_change(audio, speed=speed_factor, sr=int(sr))
|
||||
except Exception as e:
|
||||
print(f"Failed to change speed of audio: \n{e}")
|
||||
|
||||
yield self.configs.sampling_rate, audio
|
||||
|
||||
|
||||
|
||||
return sr, audio
|
||||
|
||||
|
||||
|
||||
|
||||
def speed_change(input_audio:np.ndarray, speed:float, sr:int):
|
||||
# 将 NumPy 数组转换为原始 PCM 流
|
||||
|
||||
Reference in New Issue
Block a user