Add flash attention option: GPT_SoVITS/AR/models/t2s_lightning_module.py

Add flash attention option: GPT_SoVITS/AR/models/t2s_model.py
Add flash attention option: GPT_SoVITS/TTS_infer_pack/TTS.py
Add flash attention option: GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
Add flash attention option: GPT_SoVITS/configs/tts_infer.yaml
Add flash attention option: GPT_SoVITS/inference_webui.py
Author: chasonjiang
Date:   2024-03-10 14:07:58 +08:00
parent 2155091950
commit 174c4bbab3
6 changed files with 225 additions and 44 deletions
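The new keyword argument defaults to False, so existing training entry points are unaffected and inference code opts in explicitly. Below is a minimal sketch of how the flag might be read from GPT_SoVITS/configs/tts_infer.yaml and threaded into the module; the YAML key name and the surrounding setup are assumptions, only the new constructor argument comes from this commit.

import yaml
from AR.models.t2s_lightning_module import Text2SemanticLightningModule

# Hypothetical wiring: the exact YAML key and the config handling in TTS.py
# are assumptions; only the new constructor argument is from this commit.
with open("GPT_SoVITS/configs/tts_infer.yaml") as f:
    tts_cfg = yaml.safe_load(f)

flash_attn_enabled = bool(tts_cfg.get("flash_attn_enabled", False))

t2s_module = Text2SemanticLightningModule(
    config=tts_cfg,   # stands in for the real t2s model config (assumption)
    output_dir=None,  # unused at inference time (assumption)
    is_train=False,
    flash_attn_enabled=flash_attn_enabled,  # new option added by this commit
)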

GPT_SoVITS/AR/models/t2s_lightning_module.py
@@ -13,11 +13,11 @@ from AR.modules.lr_schedulers import WarmupCosineLRSchedule
 from AR.modules.optim import ScaledAdam
 
 class Text2SemanticLightningModule(LightningModule):
-    def __init__(self, config, output_dir, is_train=True):
+    def __init__(self, config, output_dir, is_train=True, flash_attn_enabled:bool = False):
         super().__init__()
         self.config = config
         self.top_k = 3
-        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
+        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k,flash_attn_enabled=flash_attn_enabled)
         pretrained_s1 = config.get("pretrained_s1")
         if pretrained_s1 and is_train:
             # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
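Inside Text2SemanticDecoder, a flag like this typically selects between PyTorch's fused scaled-dot-product attention kernel (which can dispatch to FlashAttention on supported GPUs) and an explicit softmax(QK^T/sqrt(d))V computation. A minimal sketch of that gating, assuming PyTorch >= 2.0; this is not the actual body of t2s_model.py:

import torch
import torch.nn.functional as F

def attention(q, k, v, mask=None, flash_attn_enabled: bool = False):
    # Sketch only. When the flag is set, defer to the fused SDPA kernel;
    # mask is boolean, with True marking positions that may be attended to.
    if flash_attn_enabled:
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    # Fallback: explicit attention, equivalent up to kernel-level details.
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    if mask is not None:
        scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

Keeping the explicit path as the default matches the commit's conservative flash_attn_enabled=False default: older PyTorch builds and CPU-only setups continue to work unchanged.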