diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 7912ddf..694d4a7 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -1,3 +1,4 @@ +from copy import deepcopy import math import os, sys import random @@ -50,18 +51,7 @@ custom: class TTS_Config: - def __init__(self, configs: Union[dict, str]): - configs_base_path:str = "GPT_SoVITS/configs/" - os.makedirs(configs_base_path, exist_ok=True) - self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml") - if isinstance(configs, str): - self.configs_path = configs - configs:dict = self._load_configs(configs) - - # assert isinstance(configs, dict) - self.default_configs:dict = configs.get("default", None) - if self.default_configs is None: - self.default_configs={ + default_configs={ "device": "cpu", "is_half": False, "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", @@ -70,15 +60,54 @@ class TTS_Config: "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", "flash_attn_enabled": True } - self.configs:dict = configs.get("custom", self.default_configs) + configs:dict = None + def __init__(self, configs: Union[dict, str]=None): - self.device = self.configs.get("device") - self.is_half = self.configs.get("is_half") - self.t2s_weights_path = self.configs.get("t2s_weights_path") - self.vits_weights_path = self.configs.get("vits_weights_path") - self.bert_base_path = self.configs.get("bert_base_path") - self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path") - self.flash_attn_enabled = self.configs.get("flash_attn_enabled") + # 设置默认配置文件路径 + configs_base_path:str = "GPT_SoVITS/configs/" + os.makedirs(configs_base_path, exist_ok=True) + self.configs_path:str = os.path.join(configs_base_path, "tts_infer.yaml") + + if configs in ["", None]: + if not os.path.exists(self.configs_path): + self.save_configs() + print(f"Create default config file at {self.configs_path}") + configs:dict = {"default": deepcopy(self.default_configs)} + + if isinstance(configs, str): + self.configs_path = configs + configs:dict = self._load_configs(self.configs_path) + + assert isinstance(configs, dict) + default_configs:dict = configs.get("default", None) + if default_configs is not None: + self.default_configs = default_configs + + self.configs:dict = configs.get("custom", deepcopy(self.default_configs)) + + + self.device = self.configs.get("device", torch.device("cpu")) + self.is_half = self.configs.get("is_half", False) + self.flash_attn_enabled = self.configs.get("flash_attn_enabled", True) + self.t2s_weights_path = self.configs.get("t2s_weights_path", None) + self.vits_weights_path = self.configs.get("vits_weights_path", None) + self.bert_base_path = self.configs.get("bert_base_path", None) + self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None) + + + if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)): + self.t2s_weights_path = self.default_configs['t2s_weights_path'] + print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}") + if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)): + self.vits_weights_path = self.default_configs['vits_weights_path'] + print(f"fall back to default vits_weights_path: {self.vits_weights_path}") + if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)): + self.bert_base_path = self.default_configs['bert_base_path'] + print(f"fall back to default bert_base_path: {self.bert_base_path}") + if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)): + self.cnhuhbert_base_path = self.default_configs['cnhuhbert_base_path'] + print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}") + self.update_configs() self.max_sec = None @@ -92,7 +121,7 @@ class TTS_Config: self.n_speakers:int = 300 self.langauges:list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"] - print(self) + # print(self) def _load_configs(self, configs_path: str)->dict: with open(configs_path, 'r') as f: @@ -102,24 +131,18 @@ class TTS_Config: def save_configs(self, configs_path:str=None)->None: configs={ - "default": { - "device": "cpu", - "is_half": False, - "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", - "vits_weights_path": "GPT_SoVITS/pretrained_models/s2G488k.pth", - "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base", - "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", - "flash_attn_enabled": True - }, - "custom": self.update_configs() + "default":self.default_configs, } + if self.configs is not None: + configs["custom"] = self.update_configs() + if configs_path is None: configs_path = self.configs_path with open(configs_path, 'w') as f: yaml.dump(configs, f) def update_configs(self): - config = { + self.config = { "device" : str(self.device), "is_half" : self.is_half, "t2s_weights_path" : self.t2s_weights_path, @@ -128,7 +151,7 @@ class TTS_Config: "cnhuhbert_base_path": self.cnhuhbert_base_path, "flash_attn_enabled" : self.flash_attn_enabled } - return config + return self.config def __str__(self): self.configs = self.update_configs() @@ -137,6 +160,9 @@ class TTS_Config: string += f"{str(k).ljust(20)}: {str(v)}\n" string += "-" * 100 + '\n' return string + + def __repr__(self): + return self.__str__() class TTS: @@ -253,7 +279,7 @@ class TTS: enable: bool, whether to enable half precision. ''' - if self.configs.device == "cpu": + if self.configs.device == "cpu" and enable: print("Half precision is not supported on CPU.") return diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 2308b38..bc68031 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -80,6 +80,7 @@ if cnhubert_base_path is not None: if bert_path is not None: tts_config.bert_base_path = bert_path +print(tts_config) tts_pipline = TTS(tts_config) gpt_path = tts_config.t2s_weights_path sovits_path = tts_config.vits_weights_path