support sovits v2Pro v2ProPlus
support sovits v2Pro v2ProPlus
This commit is contained in:
@@ -331,7 +331,7 @@ class VitsModel(nn.Module):
|
||||
def __init__(self, vits_path):
|
||||
super().__init__()
|
||||
# dict_s2 = torch.load(vits_path,map_location="cpu")
|
||||
dict_s2 = torch.load(vits_path)
|
||||
dict_s2 = torch.load(vits_path, weights_only=False)
|
||||
self.hps = dict_s2["config"]
|
||||
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
|
||||
self.hps["model"]["version"] = "v1"
|
||||
@@ -645,7 +645,7 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
|
||||
|
||||
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
|
||||
# dict_s1 = torch.load(gpt_path, map_location=device)
|
||||
dict_s1 = torch.load(gpt_path)
|
||||
dict_s1 = torch.load(gpt_path, weights_only=False)
|
||||
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
|
||||
print("#### get_raw_t2s_model ####")
|
||||
print(raw_t2s.config)
|
||||
|
||||
Reference in New Issue
Block a user