support sovits v2Pro v2ProPlus
support sovits v2Pro v2ProPlus
This commit is contained in:
@@ -21,7 +21,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
3) computes spectrograms from audio files.
|
||||
"""
|
||||
|
||||
def __init__(self, hparams, val=False):
|
||||
def __init__(self, hparams, version=None,val=False):
|
||||
exp_dir = hparams.exp_dir
|
||||
self.path2 = "%s/2-name2text.txt" % exp_dir
|
||||
self.path4 = "%s/4-cnhubert" % exp_dir
|
||||
@@ -29,8 +29,14 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
assert os.path.exists(self.path2)
|
||||
assert os.path.exists(self.path4)
|
||||
assert os.path.exists(self.path5)
|
||||
self.is_v2Pro=version in {"v2Pro","v2ProPlus"}
|
||||
if self.is_v2Pro:
|
||||
self.path7 = "%s/7-sv_cn" % exp_dir
|
||||
assert os.path.exists(self.path7)
|
||||
names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # 去除.pt后缀
|
||||
names5 = set(os.listdir(self.path5))
|
||||
if self.is_v2Pro:
|
||||
names6 = set([name[:-3] for name in list(os.listdir(self.path7))]) # 去除.pt后缀
|
||||
self.phoneme_data = {}
|
||||
with open(self.path2, "r", encoding="utf8") as f:
|
||||
lines = f.read().strip("\n").split("\n")
|
||||
@@ -40,8 +46,10 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
if len(tmp) != 4:
|
||||
continue
|
||||
self.phoneme_data[tmp[0]] = [tmp[1]]
|
||||
|
||||
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
|
||||
if self.is_v2Pro:
|
||||
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5 & names6)
|
||||
else:
|
||||
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
|
||||
tmp = self.audiopaths_sid_text
|
||||
leng = len(tmp)
|
||||
min_num = 100
|
||||
@@ -109,14 +117,21 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
typee = ssl.dtype
|
||||
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
|
||||
ssl.requires_grad = False
|
||||
if self.is_v2Pro:
|
||||
sv_emb=torch.load("%s/%s.pt" % (self.path7, audiopath), map_location="cpu")
|
||||
except:
|
||||
traceback.print_exc()
|
||||
spec = torch.zeros(1025, 100)
|
||||
wav = torch.zeros(1, 100 * self.hop_length)
|
||||
ssl = torch.zeros(1, 768, 100)
|
||||
text = text[-1:]
|
||||
if self.is_v2Pro:
|
||||
sv_emb=torch.zeros(1,20480)
|
||||
print("load audio or ssl error!!!!!!", audiopath)
|
||||
return (ssl, spec, wav, text)
|
||||
if self.is_v2Pro:
|
||||
return (ssl, spec, wav, text,sv_emb)
|
||||
else:
|
||||
return (ssl, spec, wav, text)
|
||||
|
||||
def get_audio(self, filename):
|
||||
audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768
|
||||
@@ -177,8 +192,9 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
class TextAudioSpeakerCollate:
|
||||
"""Zero-pads model inputs and targets"""
|
||||
|
||||
def __init__(self, return_ids=False):
|
||||
def __init__(self, return_ids=False,version=None):
|
||||
self.return_ids = return_ids
|
||||
self.is_v2Pro=version in {"v2Pro","v2ProPlus"}
|
||||
|
||||
def __call__(self, batch):
|
||||
"""Collate's training batch from normalized text, audio and speaker identities
|
||||
@@ -211,6 +227,9 @@ class TextAudioSpeakerCollate:
|
||||
ssl_padded.zero_()
|
||||
text_padded.zero_()
|
||||
|
||||
if self.is_v2Pro:
|
||||
sv_embs=torch.FloatTensor(len(batch),20480)
|
||||
|
||||
for i in range(len(ids_sorted_decreasing)):
|
||||
row = batch[ids_sorted_decreasing[i]]
|
||||
|
||||
@@ -230,7 +249,12 @@ class TextAudioSpeakerCollate:
|
||||
text_padded[i, : text.size(0)] = text
|
||||
text_lengths[i] = text.size(0)
|
||||
|
||||
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
|
||||
if self.is_v2Pro:
|
||||
sv_embs[i]=row[4]
|
||||
if self.is_v2Pro:
|
||||
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths,sv_embs
|
||||
else:
|
||||
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
|
||||
|
||||
|
||||
class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
|
||||
|
||||
Reference in New Issue
Block a user