优化tts_config代码逻辑 (#1538)

* 优化tts_config

* fix

* 优化报错提示

* 优化报错提示
This commit is contained in:
ChasonJiang
2024-08-29 00:33:07 +08:00
committed by GitHub
parent 7dac47ca95
commit f35f6e9b5e
2 changed files with 11 additions and 6 deletions

View File

@@ -213,6 +213,10 @@ class TTS_Config:
"cnhuhbert_base_path": self.cnhuhbert_base_path,
}
return self.config
def update_version(self, version:str)->None:
self.version = version
self.languages = self.v2_languages if self.version=="v2" else self.v1_languages
def __str__(self):
self.configs = self.update_configs()
@@ -300,13 +304,14 @@ class TTS:
def init_vits_weights(self, weights_path: str):
print(f"Loading VITS weights from {weights_path}")
self.configs.vits_weights_path = weights_path
self.configs.save_configs()
dict_s2 = torch.load(weights_path, map_location=self.configs.device)
hps = dict_s2["config"]
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
self.configs.version = "v1"
self.configs.update_version("v1")
else:
self.configs.version = "v2"
self.configs.update_version("v2")
self.configs.save_configs()
hps["model"]["version"] = self.configs.version
self.configs.filter_length = hps["data"]["filter_length"]
self.configs.segment_size = hps["train"]["segment_size"]