Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* Update PyTorch version and Colab
This commit is contained in:
XXXXRT666
2025-04-07 09:42:47 +01:00
committed by GitHub
parent 9da7e17efe
commit 53cac93589
132 changed files with 8185 additions and 6648 deletions

View File

@@ -31,32 +31,15 @@ class MRTE(nn.Module):
text_enc = self.text_pre(text * text_mask)
if test != None:
if test == 0:
x = (
self.cross_attention(
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
)
+ ssl_enc
+ ge
)
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
elif test == 1:
x = ssl_enc + ge
elif test == 2:
x = (
self.cross_attention(
ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask
)
+ ge
)
x = self.cross_attention(ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask) + ge
else:
raise ValueError("test should be 0,1,2")
else:
x = (
self.cross_attention(
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
)
+ ssl_enc
+ ge
)
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
x = self.c_post(x * ssl_mask)
return x
@@ -70,9 +53,7 @@ class SpeakerEncoder(torch.nn.Module):
model_embedding_size=256,
):
super(SpeakerEncoder, self).__init__()
self.lstm = nn.LSTM(
mel_n_channels, model_hidden_size, model_num_layers, batch_first=True
)
self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
self.linear = nn.Linear(model_hidden_size, model_embedding_size)
self.relu = nn.ReLU()