support gpt-sovits v4
support gpt-sovits v4
This commit is contained in:
@@ -27,12 +27,11 @@ from random import randint
|
||||
from module import commons
|
||||
from module.data_utils import (
|
||||
DistributedBucketSampler,
|
||||
)
|
||||
from module.data_utils import (
|
||||
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
|
||||
)
|
||||
from module.data_utils import (
|
||||
TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
|
||||
TextAudioSpeakerCollateV3,
|
||||
TextAudioSpeakerLoaderV3,
|
||||
TextAudioSpeakerCollateV4,
|
||||
TextAudioSpeakerLoaderV4,
|
||||
|
||||
)
|
||||
from module.models import (
|
||||
SynthesizerTrnV3 as SynthesizerTrn,
|
||||
@@ -89,6 +88,8 @@ def run(rank, n_gpus, hps):
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
TextAudioSpeakerLoader=TextAudioSpeakerLoaderV3 if hps.model.version=="v3"else TextAudioSpeakerLoaderV4
|
||||
TextAudioSpeakerCollate=TextAudioSpeakerCollateV3 if hps.model.version=="v3"else TextAudioSpeakerCollateV4
|
||||
train_dataset = TextAudioSpeakerLoader(hps.data) ########
|
||||
train_sampler = DistributedBucketSampler(
|
||||
train_dataset,
|
||||
@@ -364,7 +365,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
||||
hps.name + "_e%s_s%s_l%s" % (epoch, global_step, lora_rank),
|
||||
epoch,
|
||||
global_step,
|
||||
hps,
|
||||
hps,cfm_version=hps.model.version,
|
||||
lora_rank=lora_rank,
|
||||
),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user