support gpt-sovits v4
support gpt-sovits v4
This commit is contained in:
@@ -22,23 +22,24 @@ def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
|
||||
01:v2
|
||||
02:v3
|
||||
03:v3lora
|
||||
|
||||
04:v4lora
|
||||
|
||||
"""
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
def my_save2(fea, path):
|
||||
def my_save2(fea, path,cfm_version):
|
||||
bio = BytesIO()
|
||||
torch.save(fea, bio)
|
||||
bio.seek(0)
|
||||
data = bio.getvalue()
|
||||
data = b"03" + data[2:] ###temp for v3lora only, todo
|
||||
byte=b"03" if cfm_version=="v3"else b"04"
|
||||
data = byte + data[2:]
|
||||
with open(path, "wb") as f:
|
||||
f.write(data)
|
||||
|
||||
|
||||
def savee(ckpt, name, epoch, steps, hps, lora_rank=None):
|
||||
def savee(ckpt, name, epoch, steps, hps, cfm_version=None,lora_rank=None):
|
||||
try:
|
||||
opt = OrderedDict()
|
||||
opt["weight"] = {}
|
||||
@@ -50,7 +51,7 @@ def savee(ckpt, name, epoch, steps, hps, lora_rank=None):
|
||||
opt["info"] = "%sepoch_%siteration" % (epoch, steps)
|
||||
if lora_rank:
|
||||
opt["lora_rank"] = lora_rank
|
||||
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
|
||||
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name),cfm_version)
|
||||
else:
|
||||
my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
|
||||
return "Success."
|
||||
@@ -63,11 +64,13 @@ head2version = {
|
||||
b"01": ["v2", "v2", False],
|
||||
b"02": ["v2", "v3", False],
|
||||
b"03": ["v2", "v3", True],
|
||||
b"04": ["v2", "v4", True],
|
||||
}
|
||||
hash_pretrained_dict = {
|
||||
"dc3c97e17592963677a4a1681f30c653": ["v2", "v2", False], # s2G488k.pth#sovits_v1_pretrained
|
||||
"43797be674a37c1c83ee81081941ed0f": ["v2", "v3", False], # s2Gv3.pth#sovits_v3_pretrained
|
||||
"6642b37f3dbb1f76882b69937c95a5f3": ["v2", "v2", False], # s2G2333K.pth#sovits_v2_pretrained
|
||||
"4f26b9476d0c5033e04162c486074374": ["v2", "v4", False], # s2Gv4.pth#sovits_v4_pretrained
|
||||
}
|
||||
import hashlib
|
||||
|
||||
@@ -85,7 +88,7 @@ def get_sovits_version_from_path_fast(sovits_path):
|
||||
hash = get_hash_from_file(sovits_path)
|
||||
if hash in hash_pretrained_dict:
|
||||
return hash_pretrained_dict[hash]
|
||||
###2-new weights or old weights, by head
|
||||
###2-new weights, by head
|
||||
with open(sovits_path, "rb") as f:
|
||||
version = f.read(2)
|
||||
if version != b"PK":
|
||||
|
||||
Reference in New Issue
Block a user