support gpt-sovits v4

support gpt-sovits v4
This commit is contained in:
RVC-Boss
2025-04-20 14:53:07 +08:00
committed by GitHub
parent e0c452f007
commit c6cb6b45f3
3 changed files with 132 additions and 74 deletions

View File

@@ -22,23 +22,24 @@ def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
01:v2
02:v3
03:v3lora
04:v4lora
"""
from io import BytesIO
def my_save2(fea, path):
def my_save2(fea, path,cfm_version):
bio = BytesIO()
torch.save(fea, bio)
bio.seek(0)
data = bio.getvalue()
data = b"03" + data[2:] ###temp for v3lora only, todo
byte=b"03" if cfm_version=="v3"else b"04"
data = byte + data[2:]
with open(path, "wb") as f:
f.write(data)
def savee(ckpt, name, epoch, steps, hps, lora_rank=None):
def savee(ckpt, name, epoch, steps, hps, cfm_version=None,lora_rank=None):
try:
opt = OrderedDict()
opt["weight"] = {}
@@ -50,7 +51,7 @@ def savee(ckpt, name, epoch, steps, hps, lora_rank=None):
opt["info"] = "%sepoch_%siteration" % (epoch, steps)
if lora_rank:
opt["lora_rank"] = lora_rank
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name),cfm_version)
else:
my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
return "Success."
@@ -63,11 +64,13 @@ head2version = {
b"01": ["v2", "v2", False],
b"02": ["v2", "v3", False],
b"03": ["v2", "v3", True],
b"04": ["v2", "v4", True],
}
hash_pretrained_dict = {
"dc3c97e17592963677a4a1681f30c653": ["v2", "v2", False], # s2G488k.pth#sovits_v1_pretrained
"43797be674a37c1c83ee81081941ed0f": ["v2", "v3", False], # s2Gv3.pth#sovits_v3_pretrained
"6642b37f3dbb1f76882b69937c95a5f3": ["v2", "v2", False], # s2G2333K.pth#sovits_v2_pretrained
"4f26b9476d0c5033e04162c486074374": ["v2", "v4", False], # s2Gv4.pth#sovits_v4_pretrained
}
import hashlib
@@ -85,7 +88,7 @@ def get_sovits_version_from_path_fast(sovits_path):
hash = get_hash_from_file(sovits_path)
if hash in hash_pretrained_dict:
return hash_pretrained_dict[hash]
###2-new weights or old weights, by head
###2-new weights, by head
with open(sovits_path, "rb") as f:
version = f.read(2)
if version != b"PK":