Add files via upload
This commit is contained in:
128
GPT_SoVITS/AR/models/t2s_lightning_module.py
Normal file
128
GPT_SoVITS/AR/models/t2s_lightning_module.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
|
||||
import os,sys
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from pytorch_lightning import LightningModule
|
||||
from AR.models.t2s_model import Text2SemanticDecoder
|
||||
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
|
||||
from AR.modules.optim import ScaledAdam
|
||||
|
||||
|
||||
class Text2SemanticLightningModule(LightningModule):
|
||||
def __init__(self, config, output_dir,is_train=True):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.top_k = 3
|
||||
self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
|
||||
pretrained_s1=config.get("pretrained_s1")
|
||||
if(pretrained_s1 and is_train):
|
||||
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
|
||||
print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["weight"]))
|
||||
if is_train:
|
||||
self.automatic_optimization = False
|
||||
self.save_hyperparameters()
|
||||
self.eval_dir = output_dir / 'eval'
|
||||
self.eval_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def training_step(self, batch: Dict, batch_idx: int):
|
||||
|
||||
opt = self.optimizers()
|
||||
scheduler = self.lr_schedulers()
|
||||
loss, acc = self.model.forward(
|
||||
batch['phoneme_ids'], batch['phoneme_ids_len'],
|
||||
batch['semantic_ids'], batch['semantic_ids_len'],
|
||||
batch['bert_feature'])
|
||||
self.manual_backward(loss)
|
||||
if batch_idx > 0 and batch_idx % 4 == 0:
|
||||
opt.step()
|
||||
opt.zero_grad()
|
||||
scheduler.step()
|
||||
|
||||
self.log(
|
||||
"total_loss",
|
||||
loss,
|
||||
on_step=True,
|
||||
on_epoch=True,
|
||||
prog_bar=True,
|
||||
sync_dist=True)
|
||||
self.log(
|
||||
"lr",
|
||||
scheduler.get_last_lr()[0],
|
||||
on_epoch=True,
|
||||
prog_bar=True,
|
||||
sync_dist=True)
|
||||
self.log(
|
||||
f"top_{self.top_k}_acc",
|
||||
acc,
|
||||
on_step=True,
|
||||
on_epoch=True,
|
||||
prog_bar=True,
|
||||
sync_dist=True)
|
||||
|
||||
def validation_step(self, batch: Dict, batch_idx: int):return
|
||||
# # get loss
|
||||
# loss, acc = self.model.forward(
|
||||
# batch['phoneme_ids'], batch['phoneme_ids_len'],
|
||||
# batch['semantic_ids'], batch['semantic_ids_len'],
|
||||
# batch['bert_feature']
|
||||
# )
|
||||
#
|
||||
# self.log(
|
||||
# "val_total_loss",
|
||||
# loss,
|
||||
# on_step=True,
|
||||
# on_epoch=True,
|
||||
# prog_bar=True,
|
||||
# sync_dist=True)
|
||||
# self.log(
|
||||
# f"val_top_{self.top_k}_acc",
|
||||
# acc,
|
||||
# on_step=True,
|
||||
# on_epoch=True,
|
||||
# prog_bar=True,
|
||||
# sync_dist=True)
|
||||
#
|
||||
# # get infer output
|
||||
# semantic_len = batch['semantic_ids'].size(1)
|
||||
# prompt_len = min(int(semantic_len * 0.5), 150)
|
||||
# prompt = batch['semantic_ids'][:, :prompt_len]
|
||||
# pred_semantic = self.model.infer(batch['phoneme_ids'],
|
||||
# batch['phoneme_ids_len'], prompt,
|
||||
# batch['bert_feature']
|
||||
# )
|
||||
# save_name = f'semantic_toks_{batch_idx}.pt'
|
||||
# save_path = os.path.join(self.eval_dir, save_name)
|
||||
# torch.save(pred_semantic.detach().cpu(), save_path)
|
||||
|
||||
def configure_optimizers(self):
|
||||
model_parameters = self.model.parameters()
|
||||
parameters_names = []
|
||||
parameters_names.append([
|
||||
name_param_pair[0]
|
||||
for name_param_pair in self.model.named_parameters()
|
||||
])
|
||||
lm_opt = ScaledAdam(
|
||||
model_parameters,
|
||||
lr=0.01,
|
||||
betas=(0.9, 0.95),
|
||||
clipping_scale=2.0,
|
||||
parameters_names=parameters_names,
|
||||
show_dominant_parameters=False,
|
||||
clipping_update_period=1000, )
|
||||
|
||||
return {
|
||||
"optimizer": lm_opt,
|
||||
"lr_scheduler": {
|
||||
"scheduler":
|
||||
WarmupCosineLRSchedule(
|
||||
lm_opt,
|
||||
init_lr=self.config['optimizer']['lr_init'],
|
||||
peak_lr=self.config['optimizer']['lr'],
|
||||
end_lr=self.config['optimizer']['lr_end'],
|
||||
warmup_steps=self.config['optimizer']['warmup_steps'],
|
||||
total_steps=self.config['optimizer']['decay_steps'])
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user