增加 flash attention 选项: GPT_SoVITS/AR/models/t2s_lightning_module.py

增加 flash attention 选项: GPT_SoVITS/AR/models/t2s_model.py
增加 flash attention 选项: GPT_SoVITS/TTS_infer_pack/TTS.py
增加 flash attention 选项: GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
增加 flash attention 选项: GPT_SoVITS/configs/tts_infer.yaml
增加 flash attention 选项: GPT_SoVITS/inference_webui.py
This commit is contained in:
chasonjiang
2024-03-10 14:07:58 +08:00
parent 2155091950
commit 174c4bbab3
6 changed files with 225 additions and 44 deletions

View File

@@ -17,8 +17,8 @@ from time import time as ttime
from tools.i18n.i18n import I18nAuto
from my_utils import load_audio
from module.mel_processing import spectrogram_torch
-from .text_segmentation_method import splits
-from .TextPreprocessor import TextPreprocessor
+from TTS_infer_pack.text_segmentation_method import splits
+from TTS_infer_pack.TextPreprocessor import TextPreprocessor
i18n = I18nAuto()
# configs/tts_infer.yaml
@@ -30,6 +30,7 @@ default:
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
flash_attn_enabled: true
custom:
device: cuda
@@ -38,7 +39,7 @@ custom:
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
flash_attn_enabled: true
"""
@@ -63,7 +64,8 @@ class TTS_Config:
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
"vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
-"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
+"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
+"flash_attn_enabled": True
}
self.configs:dict = configs.get("custom", self.default_configs)
@@ -73,6 +75,7 @@ class TTS_Config:
self.vits_weights_path = self.configs.get("vits_weights_path")
self.bert_base_path = self.configs.get("bert_base_path")
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path")
self.flash_attn_enabled = self.configs.get("flash_attn_enabled")
self.max_sec = None
@@ -103,7 +106,8 @@ class TTS_Config:
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
"vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
-"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
+"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
+"flash_attn_enabled": True
},
"custom": {
"device": str(self.device),
@@ -111,7 +115,8 @@ class TTS_Config:
"t2s_weights_path": self.t2s_weights_path,
"vits_weights_path": self.vits_weights_path,
"bert_base_path": self.bert_base_path,
-"cnhuhbert_base_path": self.cnhuhbert_base_path
+"cnhuhbert_base_path": self.cnhuhbert_base_path,
+"flash_attn_enabled": self.flash_attn_enabled
}
}
if configs_path is None:
@@ -128,6 +133,7 @@ class TTS_Config:
string += "t2s_weights_path: {}\n".format(self.t2s_weights_path)
string += "vits_weights_path: {}\n".format(self.vits_weights_path)
string += "cnhuhbert_base_path: {}\n".format(self.cnhuhbert_base_path)
string += "flash_attn_enabled: {}\n".format(self.flash_attn_enabled)
string += "----------------------------------------\n"
return string
@@ -231,7 +237,8 @@ class TTS:
dict_s1 = torch.load(weights_path, map_location=self.configs.device)
config = dict_s1["config"]
self.configs.max_sec = config["data"]["max_sec"]
-t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
+t2s_model = Text2SemanticLightningModule(config, "****", is_train=False,
+flash_attn_enabled=self.configs.flash_attn_enabled)
t2s_model.load_state_dict(dict_s1["weight"])
if self.configs.is_half:
t2s_model = t2s_model.half()