feat: 添加导出 v3 的 script (#2208)
* feat: 添加导出 v3 的 script * Fix: 由于 export_torch_script_v3 的改动,v2 现在需要传入 top_k
This commit is contained in:
@@ -427,7 +427,7 @@ class T2SModel(nn.Module):
|
||||
self.top_k = int(raw_t2s.config["inference"]["top_k"])
|
||||
self.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
|
||||
|
||||
def forward(self,prompts:LongTensor, ref_seq:LongTensor, text_seq:LongTensor, ref_bert:torch.Tensor, text_bert:torch.Tensor):
|
||||
def forward(self,prompts:LongTensor, ref_seq:LongTensor, text_seq:LongTensor, ref_bert:torch.Tensor, text_bert:torch.Tensor,top_k:LongTensor):
|
||||
bert = torch.cat([ref_bert.T, text_bert.T], 1)
|
||||
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
|
||||
bert = bert.unsqueeze(0)
|
||||
@@ -472,12 +472,13 @@ class T2SModel(nn.Module):
|
||||
.to(device=x.device, dtype=torch.bool)
|
||||
|
||||
idx = 0
|
||||
top_k = int(top_k)
|
||||
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
|
||||
|
||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||
logits = logits[:, :-1]
|
||||
samples = sample(logits, y, top_k=self.top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
|
||||
samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
|
||||
y = torch.concat([y, samples], dim=1)
|
||||
y_emb = self.ar_audio_embedding(y[:, -1:])
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
|
||||
@@ -493,7 +494,7 @@ class T2SModel(nn.Module):
|
||||
if(idx<11):###至少预测出10个token不然不给停止(0.4s)
|
||||
logits = logits[:, :-1]
|
||||
|
||||
samples = sample(logits, y, top_k=self.top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
|
||||
samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
|
||||
|
||||
y = torch.concat([y, samples], dim=1)
|
||||
|
||||
@@ -653,6 +654,8 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
|
||||
torch._dynamo.mark_dynamic(ref_bert, 0)
|
||||
torch._dynamo.mark_dynamic(text_bert, 0)
|
||||
|
||||
top_k = torch.LongTensor([5]).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
gpt_sovits_export = torch.jit.trace(
|
||||
gpt_sovits,
|
||||
@@ -662,7 +665,8 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
|
||||
ref_seq,
|
||||
text_seq,
|
||||
ref_bert,
|
||||
text_bert))
|
||||
text_bert,
|
||||
top_k))
|
||||
|
||||
gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
|
||||
gpt_sovits_export.save(gpt_sovits_path)
|
||||
@@ -684,15 +688,26 @@ class GPT_SoVITS(nn.Module):
|
||||
self.t2s = t2s
|
||||
self.vits = vits
|
||||
|
||||
def forward(self, ssl_content:torch.Tensor, ref_audio_sr:torch.Tensor, ref_seq:Tensor, text_seq:Tensor, ref_bert:Tensor, text_bert:Tensor, speed=1.0):
|
||||
def forward(
|
||||
self,
|
||||
ssl_content: torch.Tensor,
|
||||
ref_audio_sr: torch.Tensor,
|
||||
ref_seq: Tensor,
|
||||
text_seq: Tensor,
|
||||
ref_bert: Tensor,
|
||||
text_bert: Tensor,
|
||||
top_k: LongTensor,
|
||||
speed=1.0,
|
||||
):
|
||||
codes = self.vits.vq_model.extract_latent(ssl_content)
|
||||
prompt_semantic = codes[0, 0]
|
||||
prompts = prompt_semantic.unsqueeze(0)
|
||||
|
||||
pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert)
|
||||
pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
|
||||
audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed)
|
||||
return audio
|
||||
|
||||
|
||||
def test():
|
||||
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
||||
parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
|
||||
@@ -784,8 +799,10 @@ def test():
|
||||
print('text_bert:',text_bert.shape)
|
||||
text_bert=text_bert.to('cuda')
|
||||
|
||||
top_k = torch.LongTensor([5]).to('cuda')
|
||||
|
||||
with torch.no_grad():
|
||||
audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert)
|
||||
audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert, top_k)
|
||||
print('start write wav')
|
||||
soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user