Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)
* ruff check --fix * ruff format --line-length 120 --target-version py39 * Change the link for G2PW Model * update pytorch version and colab
This commit is contained in:
@@ -1,17 +1,13 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from AR.modules.embedding_onnx import SinePositionalEmbedding
|
||||
from AR.modules.embedding_onnx import TokenEmbedding
|
||||
from AR.modules.transformer_onnx import LayerNorm
|
||||
from AR.modules.transformer_onnx import TransformerEncoder
|
||||
from AR.modules.transformer_onnx import TransformerEncoderLayer
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torchmetrics.classification import MulticlassAccuracy
|
||||
|
||||
from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
|
||||
from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
|
||||
|
||||
default_config = {
|
||||
"embedding_dim": 512,
|
||||
"hidden_dim": 512,
|
||||
@@ -26,12 +22,13 @@ default_config = {
|
||||
|
||||
inf_tensor_value = torch.FloatTensor([-float("Inf")]).float()
|
||||
|
||||
|
||||
def logits_to_probs(
|
||||
logits,
|
||||
previous_tokens = None,
|
||||
previous_tokens=None,
|
||||
temperature: float = 1.0,
|
||||
top_k = None,
|
||||
top_p = None,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
repetition_penalty: float = 1.0,
|
||||
):
|
||||
previous_tokens = previous_tokens.squeeze()
|
||||
@@ -39,19 +36,27 @@ def logits_to_probs(
|
||||
previous_tokens = previous_tokens.long()
|
||||
score = torch.gather(logits, dim=0, index=previous_tokens)
|
||||
score = torch.where(
|
||||
score < 0, score * repetition_penalty, score / repetition_penalty
|
||||
score < 0,
|
||||
score * repetition_penalty,
|
||||
score / repetition_penalty,
|
||||
)
|
||||
logits.scatter_(dim=0, index=previous_tokens, src=score)
|
||||
|
||||
if top_p is not None and top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cum_probs = torch.cumsum(
|
||||
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
|
||||
torch.nn.functional.softmax(
|
||||
sorted_logits,
|
||||
dim=-1,
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
sorted_indices_to_remove = cum_probs > top_p
|
||||
sorted_indices_to_remove[0] = False # keep at least one option
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
dim=0, index=sorted_indices, src=sorted_indices_to_remove
|
||||
dim=0,
|
||||
index=sorted_indices,
|
||||
src=sorted_indices_to_remove,
|
||||
)
|
||||
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||||
|
||||
@@ -67,7 +72,7 @@ def logits_to_probs(
|
||||
|
||||
|
||||
def multinomial_sample_one_no_sync(
|
||||
probs_sort
|
||||
probs_sort,
|
||||
): # Does multinomial sampling without a cuda synchronization
|
||||
q = torch.randn_like(probs_sort)
|
||||
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
||||
@@ -79,7 +84,9 @@ def sample(
|
||||
**sampling_kwargs,
|
||||
):
|
||||
probs = logits_to_probs(
|
||||
logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
|
||||
logits=logits,
|
||||
previous_tokens=previous_tokens,
|
||||
**sampling_kwargs,
|
||||
)
|
||||
idx_next = multinomial_sample_one_no_sync(probs)
|
||||
return idx_next, probs
|
||||
@@ -91,7 +98,7 @@ class OnnxEncoder(nn.Module):
|
||||
self.ar_text_embedding = ar_text_embedding
|
||||
self.bert_proj = bert_proj
|
||||
self.ar_text_position = ar_text_position
|
||||
|
||||
|
||||
def forward(self, x, bert_feature):
|
||||
x = self.ar_text_embedding(x)
|
||||
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
||||
@@ -99,8 +106,18 @@ class OnnxEncoder(nn.Module):
|
||||
|
||||
|
||||
class T2SFirstStageDecoder(nn.Module):
|
||||
def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
|
||||
top_k, early_stop_num, num_layers):
|
||||
def __init__(
|
||||
self,
|
||||
ar_audio_embedding,
|
||||
ar_audio_position,
|
||||
h,
|
||||
ar_predict_layer,
|
||||
loss_fct,
|
||||
ar_accuracy_metric,
|
||||
top_k,
|
||||
early_stop_num,
|
||||
num_layers,
|
||||
):
|
||||
super().__init__()
|
||||
self.ar_audio_embedding = ar_audio_embedding
|
||||
self.ar_audio_position = ar_audio_position
|
||||
@@ -111,11 +128,11 @@ class T2SFirstStageDecoder(nn.Module):
|
||||
self.top_k = top_k
|
||||
self.early_stop_num = early_stop_num
|
||||
self.num_layers = num_layers
|
||||
|
||||
|
||||
def forward(self, x, prompt):
|
||||
y = prompt
|
||||
x_example = x[:,:,0] * 0.0
|
||||
#N, 1, 512
|
||||
x_example = x[:, :, 0] * 0.0
|
||||
# N, 1, 512
|
||||
cache = {
|
||||
"all_stage": self.num_layers,
|
||||
"k": None,
|
||||
@@ -132,11 +149,15 @@ class T2SFirstStageDecoder(nn.Module):
|
||||
|
||||
xy_pos = torch.concat([x, y_pos], dim=1)
|
||||
|
||||
y_example = y_pos[:,:,0] * 0.0
|
||||
x_attn_mask = torch.matmul(x_example.transpose(0, 1) , x_example).bool()
|
||||
y_example = y_pos[:, :, 0] * 0.0
|
||||
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example).bool()
|
||||
y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
|
||||
y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
|
||||
torch.ones_like(y_example.transpose(0, 1), dtype=torch.int64), dim=0
|
||||
torch.ones_like(
|
||||
y_example.transpose(0, 1),
|
||||
dtype=torch.int64,
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
y_attn_mask = y_attn_mask > 0
|
||||
|
||||
@@ -145,10 +166,16 @@ class T2SFirstStageDecoder(nn.Module):
|
||||
x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
|
||||
y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
|
||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
||||
cache["k"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
|
||||
.unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
|
||||
cache["v"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
|
||||
.unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
|
||||
cache["k"] = (
|
||||
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
|
||||
.unsqueeze(1)
|
||||
.repeat(self.num_layers, 1, 1, 1)
|
||||
)
|
||||
cache["v"] = (
|
||||
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
|
||||
.unsqueeze(1)
|
||||
.repeat(self.num_layers, 1, 1, 1)
|
||||
)
|
||||
|
||||
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
|
||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||
@@ -160,8 +187,18 @@ class T2SFirstStageDecoder(nn.Module):
|
||||
|
||||
|
||||
class T2SStageDecoder(nn.Module):
|
||||
def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
|
||||
top_k, early_stop_num, num_layers):
|
||||
def __init__(
|
||||
self,
|
||||
ar_audio_embedding,
|
||||
ar_audio_position,
|
||||
h,
|
||||
ar_predict_layer,
|
||||
loss_fct,
|
||||
ar_accuracy_metric,
|
||||
top_k,
|
||||
early_stop_num,
|
||||
num_layers,
|
||||
):
|
||||
super().__init__()
|
||||
self.ar_audio_embedding = ar_audio_embedding
|
||||
self.ar_audio_position = ar_audio_position
|
||||
@@ -184,14 +221,18 @@ class T2SStageDecoder(nn.Module):
|
||||
}
|
||||
|
||||
y_emb = torch.cat(
|
||||
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
|
||||
[
|
||||
cache["y_emb"],
|
||||
self.ar_audio_embedding(y[:, -1:]),
|
||||
],
|
||||
1,
|
||||
)
|
||||
cache["y_emb"] = y_emb
|
||||
y_pos = self.ar_audio_position(y_emb)
|
||||
|
||||
xy_pos = y_pos[:, -1:]
|
||||
|
||||
y_example = y_pos[:,:,0] * 0.0
|
||||
|
||||
y_example = y_pos[:, :, 0] * 0.0
|
||||
|
||||
xy_attn_mask = torch.cat([x_example, y_example], dim=1)
|
||||
xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
|
||||
@@ -250,12 +291,28 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
def init_onnx(self):
|
||||
self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
|
||||
self.first_stage_decoder = T2SFirstStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
|
||||
self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
|
||||
self.num_layers)
|
||||
self.stage_decoder = T2SStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
|
||||
self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
|
||||
self.num_layers)
|
||||
self.first_stage_decoder = T2SFirstStageDecoder(
|
||||
self.ar_audio_embedding,
|
||||
self.ar_audio_position,
|
||||
self.h,
|
||||
self.ar_predict_layer,
|
||||
self.loss_fct,
|
||||
self.ar_accuracy_metric,
|
||||
self.top_k,
|
||||
self.early_stop_num,
|
||||
self.num_layers,
|
||||
)
|
||||
self.stage_decoder = T2SStageDecoder(
|
||||
self.ar_audio_embedding,
|
||||
self.ar_audio_position,
|
||||
self.h,
|
||||
self.ar_predict_layer,
|
||||
self.loss_fct,
|
||||
self.ar_accuracy_metric,
|
||||
self.top_k,
|
||||
self.early_stop_num,
|
||||
self.num_layers,
|
||||
)
|
||||
|
||||
def forward(self, x, prompts, bert_feature):
|
||||
early_stop_num = self.early_stop_num
|
||||
@@ -286,7 +343,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
y = prompts
|
||||
prefix_len = y.shape[1]
|
||||
x_len = x.shape[1]
|
||||
x_example = x[:,:,0] * 0.0
|
||||
x_example = x[:, :, 0] * 0.0
|
||||
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example)
|
||||
x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool)
|
||||
|
||||
@@ -303,9 +360,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
if cache["first_infer"] == 1:
|
||||
y_emb = self.ar_audio_embedding(y)
|
||||
else:
|
||||
y_emb = torch.cat(
|
||||
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
|
||||
)
|
||||
y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1)
|
||||
cache["y_emb"] = y_emb
|
||||
y_pos = self.ar_audio_position(y_emb)
|
||||
if cache["first_infer"] == 1:
|
||||
@@ -317,7 +372,8 @@ class Text2SemanticDecoder(nn.Module):
|
||||
x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
|
||||
y_attn_mask = F.pad(
|
||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
||||
(x_len, 0), value=False
|
||||
(x_len, 0),
|
||||
value=False,
|
||||
)
|
||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
||||
else:
|
||||
@@ -335,4 +391,4 @@ class Text2SemanticDecoder(nn.Module):
|
||||
break
|
||||
y = torch.concat([y, samples], dim=1)
|
||||
cache["first_infer"] = 0
|
||||
return y, idx
|
||||
return y, idx
|
||||
|
||||
Reference in New Issue
Block a user