Adapt api_v2 and inference_webui_fast to V3 (#2188)
* modified: GPT_SoVITS/TTS_infer_pack/TTS.py; GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py; GPT_SoVITS/inference_webui_fast.py
* Adapt to the V3 models
* V3 adaptation for api_v2.py and inference_webui_fast.py
* Fix a long-standing bug and add friendlier user-facing messages
* Improve the WebUI
* Change to the correct path
* Fix loading of V3 LoRA models
* Fix an encoding mismatch when reading tts_infer.yaml (see the sketch below)
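The last item, the tts_infer.yaml encoding mismatch, is the kind of problem that usually comes from opening the file with the platform-default encoding. A minimal sketch of that sort of fix, assuming the config is UTF-8 encoded (this is not the actual patch from the commit):

```python
# Hedged sketch, not the actual patch: read tts_infer.yaml with an explicit
# UTF-8 encoding instead of the platform default (e.g. GBK on a Chinese-locale
# Windows), which is one way the reported encoding mismatch can occur.
import yaml

def load_tts_infer_config(path: str = "GPT_SoVITS/configs/tts_infer.yaml") -> dict:
    with open(path, "r", encoding="utf-8") as f:  # explicit encoding avoids the mismatch
        return yaml.safe_load(f)
```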
api_v2.py (24 changed lines)
@@ -39,6 +39,8 @@ POST:
     "seed": -1,                   # int. random seed for reproducibility.
     "parallel_infer": True,       # bool. whether to use parallel inference.
     "repetition_penalty": 1.35    # float. repetition penalty for T2S model.
+    "sample_steps": 32,           # int. number of sampling steps for VITS model V3.
+    "super_sampling": False,      # bool. whether to use super-sampling for audio when using VITS model V3.
 }
 ```
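For orientation, a request exercising the two new fields might look like the following hedged client sketch. It is not part of the commit: the host/port assume a locally running api_v2.py on its default 9880, the reference-audio path and texts are placeholders, and the remaining fields follow the existing api_v2 request schema.

```python
# Hedged usage sketch (not part of the commit): POST /tts with the new V3 fields.
import requests

payload = {
    "text": "先帝创业未半而中道崩殂。",        # placeholder text to synthesize
    "text_lang": "zh",
    "ref_audio_path": "ref.wav",              # placeholder reference audio
    "prompt_text": "参考音频对应的文本。",      # placeholder prompt text
    "prompt_lang": "zh",
    "media_type": "wav",
    "parallel_infer": True,
    "repetition_penalty": 1.35,
    "sample_steps": 32,                       # new in this commit: V3 sampling steps
    "super_sampling": False,                  # new in this commit: V3 audio super-sampling
}

resp = requests.post("http://127.0.0.1:9880/tts", json=payload)
resp.raise_for_status()
with open("output.wav", "wb") as f:
    f.write(resp.content)
```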
@@ -164,6 +166,8 @@ class TTS_Request(BaseModel):
     streaming_mode:bool = False
     parallel_infer:bool = True
     repetition_penalty:float = 1.35
+    sample_steps:int = 32
+    super_sampling:bool = False

 ### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
 def pack_ogg(io_buffer:BytesIO, data:np.ndarray, rate:int):

@@ -294,7 +298,9 @@ async def tts_handle(req:dict):
             "media_type": "wav",                  # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
             "streaming_mode": False,              # bool. whether to return a streaming response.
             "parallel_infer": True,               # bool.(optional) whether to use parallel inference.
-            "repetition_penalty": 1.35            # float.(optional) repetition penalty for T2S model.
+            "repetition_penalty": 1.35,           # float.(optional) repetition penalty for T2S model.
+            "sample_steps": 32,                   # int. number of sampling steps for VITS model V3.
+            "super_sampling": False,              # bool. whether to use super-sampling for audio when using VITS model V3.
         }
     returns:
         StreamingResponse: audio stream response.
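The same docstring also covers streaming_mode, which is what the next hunk changes. A hedged sketch of a streaming client, under the same placeholder assumptions as the example above:

```python
# Hedged streaming sketch (not part of the commit): request a chunked WAV
# response with streaming_mode=True and write chunks as they arrive.
import requests

payload = {
    "text": "这是一个流式合成的例子。",         # placeholder text to synthesize
    "text_lang": "zh",
    "ref_audio_path": "ref.wav",              # placeholder reference audio
    "prompt_text": "参考音频对应的文本。",      # placeholder prompt text
    "prompt_lang": "zh",
    "media_type": "wav",
    "streaming_mode": True,
    "sample_steps": 32,
    "super_sampling": False,
}

with requests.post("http://127.0.0.1:9880/tts", json=payload, stream=True) as resp:
    resp.raise_for_status()
    with open("stream_output.wav", "wb") as f:
        for chunk in resp.iter_content(chunk_size=4096):
            if chunk:
                f.write(chunk)
```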
@@ -316,10 +322,12 @@ async def tts_handle(req:dict):

     if streaming_mode:
         def streaming_generator(tts_generator:Generator, media_type:str):
-            if media_type == "wav":
-                yield wave_header_chunk()
-                media_type = "raw"
+            if_frist_chunk = True
             for sr, chunk in tts_generator:
+                if if_frist_chunk and media_type == "wav":
+                    yield wave_header_chunk(sample_rate=sr)
+                    media_type = "raw"
+                    if_frist_chunk = False
                 yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
         # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
         return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}")
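The hunk above defers the WAV header until the first chunk arrives, so the header is built with wave_header_chunk(sample_rate=sr) from the generator's actual sample rate rather than a fixed default; presumably this matters because the V3 pipeline can emit audio at a different output sample rate than earlier model versions. A standalone, hedged sketch of the same header-on-first-chunk pattern, where make_wav_header stands in for api_v2's wave_header_chunk helper:

```python
# Hedged standalone sketch of the deferred-header pattern shown in the hunk above.
# make_wav_header is a stand-in for api_v2's wave_header_chunk helper.
import io
import wave
from typing import Generator, Iterator, Tuple

import numpy as np


def make_wav_header(sample_rate: int, channels: int = 1, sample_width: int = 2) -> bytes:
    """Build a minimal WAV header for a stream whose length is not yet known."""
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wf:
        wf.setnchannels(channels)
        wf.setsampwidth(sample_width)
        wf.setframerate(sample_rate)
        wf.writeframes(b"")  # header only; the audio frames are streamed afterwards
    return buf.getvalue()


def stream_wav(chunks: Iterator[Tuple[int, np.ndarray]]) -> Generator[bytes, None, None]:
    first_chunk = True
    for sr, chunk in chunks:
        if first_chunk:  # the header uses the sample rate of the first real chunk
            yield make_wav_header(sample_rate=sr)
            first_chunk = False
        yield chunk.astype(np.int16).tobytes()
```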
@@ -365,7 +373,9 @@ async def tts_get_endpoint(
                         media_type:str = "wav",
                         streaming_mode:bool = False,
                         parallel_infer:bool = True,
-                        repetition_penalty:float = 1.35
+                        repetition_penalty:float = 1.35,
+                        sample_steps:int =32,
+                        super_sampling:bool = False
                         ):
     req = {
         "text": text,

@@ -387,7 +397,9 @@ async def tts_get_endpoint(
         "media_type":media_type,
         "streaming_mode":streaming_mode,
         "parallel_infer":parallel_infer,
-        "repetition_penalty":float(repetition_penalty)
+        "repetition_penalty":float(repetition_penalty),
+        "sample_steps":int(sample_steps),
+        "super_sampling":super_sampling
     }
     return await tts_handle(req)
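The GET form of the endpoint accepts the same two parameters after this change. A hedged sketch of a query-string call, with the same placeholder host/port and paths as the earlier examples:

```python
# Hedged sketch (not part of the commit): GET /tts with the new query parameters.
import requests

resp = requests.get(
    "http://127.0.0.1:9880/tts",
    params={
        "text": "用GET方式调用的例子。",       # placeholder text to synthesize
        "text_lang": "zh",
        "ref_audio_path": "ref.wav",          # placeholder reference audio
        "prompt_text": "参考音频对应的文本。",  # placeholder prompt text
        "prompt_lang": "zh",
        "sample_steps": 32,                   # new in this commit
        "super_sampling": False,              # new in this commit
    },
)
resp.raise_for_status()
with open("output_get.wav", "wb") as f:
    f.write(resp.content)
```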