support sovits v2Pro v2ProPlus

support sovits v2Pro v2ProPlus
This commit is contained in:
RVC-Boss
2025-06-04 15:18:55 +08:00
committed by GitHub
parent 3f46359652
commit 0621259549
2 changed files with 57 additions and 18 deletions

View File

@@ -21,7 +21,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
3) computes spectrograms from audio files.
"""
def __init__(self, hparams, val=False):
def __init__(self, hparams, version=None,val=False):
exp_dir = hparams.exp_dir
self.path2 = "%s/2-name2text.txt" % exp_dir
self.path4 = "%s/4-cnhubert" % exp_dir
@@ -29,8 +29,14 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
assert os.path.exists(self.path2)
assert os.path.exists(self.path4)
assert os.path.exists(self.path5)
self.is_v2Pro=version in {"v2Pro","v2ProPlus"}
if self.is_v2Pro:
self.path7 = "%s/7-sv_cn" % exp_dir
assert os.path.exists(self.path7)
names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # 去除.pt后缀
names5 = set(os.listdir(self.path5))
if self.is_v2Pro:
names6 = set([name[:-3] for name in list(os.listdir(self.path7))]) # 去除.pt后缀
self.phoneme_data = {}
with open(self.path2, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
@@ -40,8 +46,10 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
if len(tmp) != 4:
continue
self.phoneme_data[tmp[0]] = [tmp[1]]
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
if self.is_v2Pro:
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5 & names6)
else:
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
tmp = self.audiopaths_sid_text
leng = len(tmp)
min_num = 100
@@ -109,14 +117,21 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
typee = ssl.dtype
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
ssl.requires_grad = False
if self.is_v2Pro:
sv_emb=torch.load("%s/%s.pt" % (self.path7, audiopath), map_location="cpu")
except:
traceback.print_exc()
spec = torch.zeros(1025, 100)
wav = torch.zeros(1, 100 * self.hop_length)
ssl = torch.zeros(1, 768, 100)
text = text[-1:]
if self.is_v2Pro:
sv_emb=torch.zeros(1,20480)
print("load audio or ssl error!!!!!!", audiopath)
return (ssl, spec, wav, text)
if self.is_v2Pro:
return (ssl, spec, wav, text,sv_emb)
else:
return (ssl, spec, wav, text)
def get_audio(self, filename):
audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的不用再/32768
@@ -177,8 +192,9 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
class TextAudioSpeakerCollate:
"""Zero-pads model inputs and targets"""
def __init__(self, return_ids=False):
def __init__(self, return_ids=False,version=None):
self.return_ids = return_ids
self.is_v2Pro=version in {"v2Pro","v2ProPlus"}
def __call__(self, batch):
"""Collate's training batch from normalized text, audio and speaker identities
@@ -211,6 +227,9 @@ class TextAudioSpeakerCollate:
ssl_padded.zero_()
text_padded.zero_()
if self.is_v2Pro:
sv_embs=torch.FloatTensor(len(batch),20480)
for i in range(len(ids_sorted_decreasing)):
row = batch[ids_sorted_decreasing[i]]
@@ -230,7 +249,12 @@ class TextAudioSpeakerCollate:
text_padded[i, : text.size(0)] = text
text_lengths[i] = text.size(0)
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
if self.is_v2Pro:
sv_embs[i]=row[4]
if self.is_v2Pro:
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths,sv_embs
else:
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):