增加flash attention 选项: GPT_SoVITS/AR/models/t2s_lightning_module.py
增加flash attention 选项: GPT_SoVITS/AR/models/t2s_model.py 增加flash attention 选项: GPT_SoVITS/TTS_infer_pack/TTS.py 增加flash attention 选项: GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py 增加flash attention 选项: GPT_SoVITS/configs/tts_infer.yaml 增加flash attention 选项: GPT_SoVITS/inference_webui.py
This commit is contained in:
@@ -17,8 +17,8 @@ from time import time as ttime
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from my_utils import load_audio
|
||||
from module.mel_processing import spectrogram_torch
|
||||
from .text_segmentation_method import splits
|
||||
from .TextPreprocessor import TextPreprocessor
|
||||
from TTS_infer_pack.text_segmentation_method import splits
|
||||
from TTS_infer_pack.TextPreprocessor import TextPreprocessor
|
||||
i18n = I18nAuto()
|
||||
|
||||
# configs/tts_infer.yaml
|
||||
@@ -30,6 +30,7 @@ default:
|
||||
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
|
||||
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
|
||||
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
|
||||
flash_attn_enabled: true
|
||||
|
||||
custom:
|
||||
device: cuda
|
||||
@@ -38,7 +39,7 @@ custom:
|
||||
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
|
||||
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
|
||||
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
|
||||
|
||||
flash_attn_enabled: true
|
||||
|
||||
|
||||
"""
|
||||
@@ -63,7 +64,8 @@ class TTS_Config:
|
||||
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
|
||||
"vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
|
||||
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
|
||||
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
|
||||
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||
"flash_attn_enabled": True
|
||||
}
|
||||
self.configs:dict = configs.get("custom", self.default_configs)
|
||||
|
||||
@@ -73,6 +75,7 @@ class TTS_Config:
|
||||
self.vits_weights_path = self.configs.get("vits_weights_path")
|
||||
self.bert_base_path = self.configs.get("bert_base_path")
|
||||
self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path")
|
||||
self.flash_attn_enabled = self.configs.get("flash_attn_enabled")
|
||||
|
||||
|
||||
self.max_sec = None
|
||||
@@ -103,7 +106,8 @@ class TTS_Config:
|
||||
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
|
||||
"vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth",
|
||||
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
|
||||
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
|
||||
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||
"flash_attn_enabled": True
|
||||
},
|
||||
"custom": {
|
||||
"device": str(self.device),
|
||||
@@ -111,7 +115,8 @@ class TTS_Config:
|
||||
"t2s_weights_path": self.t2s_weights_path,
|
||||
"vits_weights_path": self.vits_weights_path,
|
||||
"bert_base_path": self.bert_base_path,
|
||||
"cnhuhbert_base_path": self.cnhuhbert_base_path
|
||||
"cnhuhbert_base_path": self.cnhuhbert_base_path,
|
||||
"flash_attn_enabled": self.flash_attn_enabled
|
||||
}
|
||||
}
|
||||
if configs_path is None:
|
||||
@@ -128,6 +133,7 @@ class TTS_Config:
|
||||
string += "t2s_weights_path: {}\n".format(self.t2s_weights_path)
|
||||
string += "vits_weights_path: {}\n".format(self.vits_weights_path)
|
||||
string += "cnhuhbert_base_path: {}\n".format(self.cnhuhbert_base_path)
|
||||
string += "flash_attn_enabled: {}\n".format(self.flash_attn_enabled)
|
||||
string += "----------------------------------------\n"
|
||||
return string
|
||||
|
||||
@@ -231,7 +237,8 @@ class TTS:
|
||||
dict_s1 = torch.load(weights_path, map_location=self.configs.device)
|
||||
config = dict_s1["config"]
|
||||
self.configs.max_sec = config["data"]["max_sec"]
|
||||
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
|
||||
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False,
|
||||
flash_attn_enabled=self.configs.flash_attn_enabled)
|
||||
t2s_model.load_state_dict(dict_s1["weight"])
|
||||
if self.configs.is_half:
|
||||
t2s_model = t2s_model.half()
|
||||
|
||||
Reference in New Issue
Block a user