Add flash attention option: GPT_SoVITS/AR/models/t2s_lightning_module.py

Add flash attention option: GPT_SoVITS/AR/models/t2s_model.py
Add flash attention option: GPT_SoVITS/TTS_infer_pack/TTS.py
Add flash attention option: GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
Add flash attention option: GPT_SoVITS/configs/tts_infer.yaml
Add flash attention option: GPT_SoVITS/inference_webui.py
chasonjiang
2024-03-10 14:07:58 +08:00
parent 2155091950
commit 174c4bbab3
6 changed files with 225 additions and 44 deletions
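The option named in the commit message typically toggles between a fused attention kernel and a plain softmax implementation. As a rough illustration only (not this commit's actual code; the function name and the flash_attn_enabled flag are assumptions), one common way to wire such a switch in PyTorch is:

import math
import torch
import torch.nn.functional as F

def attention(q, k, v, flash_attn_enabled: bool = True):
    # q, k, v: (batch, heads, seq_len, head_dim)
    if flash_attn_enabled and hasattr(F, "scaled_dot_product_attention"):
        # PyTorch >= 2.0 dispatches this to a fused (FlashAttention-style)
        # kernel when the device and dtype allow it.
        return F.scaled_dot_product_attention(q, k, v)
    # Fallback: explicit softmax attention, O(seq_len^2) memory.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return scores.softmax(dim=-1) @ v

q = k = v = torch.randn(1, 8, 128, 64)
out = attention(q, k, v, flash_attn_enabled=True)  # same shape as q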

GPT_SoVITS/inference_webui.py

@@ -20,7 +20,6 @@ logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
import pdb
import torch
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
@@ -33,8 +32,9 @@ is_half = eval(os.environ.get("is_half", "True")) and not torch.backends.mps.is_
import gradio as gr
from TTS_infer_pack.TTS import TTS, TTS_Config
from TTS_infer_pack.text_segmentation_method import cut1, cut2, cut3, cut4, cut5
from tools.i18n.i18n import I18nAuto
from TTS_infer_pack.text_segmentation_method import get_method
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # Make sure this is also set when the inference UI is launched directly.
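The changed-files list also covers GPT_SoVITS/configs/tts_infer.yaml and GPT_SoVITS/TTS_infer_pack/TTS.py, which suggests the switch is read from the inference config and forwarded when the text-to-semantic model is built. A minimal, hypothetical sketch of that plumbing (the key name flash_attn_enabled and this loader are assumptions, not the repo's TTS_Config API):

import yaml

class DemoTTSConfig:
    # Hypothetical loader that mirrors the idea, not the repo's TTS_Config.
    def __init__(self, path: str):
        with open(path, "r", encoding="utf-8") as f:
            cfg = yaml.safe_load(f) or {}
        # Default to True so config files written before the option keep working.
        self.flash_attn_enabled = bool(cfg.get("flash_attn_enabled", True))

# The flag would then be passed down when constructing the AR model, e.g.
# (names illustrative): Text2SemanticDecoder(config, flash_attn_enabled=...).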