忽略ffmpeg .gitignore
使t2s模型支持批量推理: GPT_SoVITS/AR/models/t2s_model.py
修复batch bug GPT_SoVITS/AR/models/utils.py
重构的tts infer GPT_SoVITS/TTS_infer_pack/TTS.py
文本预处理模块 GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
new file GPT_SoVITS/TTS_infer_pack/__init__.py
文本拆分方法模块 GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py
tts infer配置文件 GPT_SoVITS/configs/tts_infer.yaml
modified GPT_SoVITS/feature_extractor/cnhubert.py
modified GPT_SoVITS/inference_gui.py
重构的webui GPT_SoVITS/inference_webui.py
new file GPT_SoVITS/inference_webui_old.py
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -386,7 +385,9 @@ class Text2SemanticDecoder(nn.Module):
|
||||
x.device
|
||||
)
|
||||
|
||||
|
||||
y_list = [None]*y.shape[0]
|
||||
batch_idx_map = list(range(y.shape[0]))
|
||||
idx_list = [None]*y.shape[0]
|
||||
for idx in tqdm(range(1500)):
|
||||
|
||||
xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
|
||||
@@ -397,17 +398,45 @@ class Text2SemanticDecoder(nn.Module):
|
||||
if(idx==0):###第一次跑不能EOS否则没有了
|
||||
logits = logits[:, :-1] ###刨除1024终止符号的概率
|
||||
samples = sample(
|
||||
logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
|
||||
)[0].unsqueeze(0)
|
||||
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
|
||||
)[0]
|
||||
# 本次生成的 semantic_ids 和之前的 y 构成新的 y
|
||||
# print(samples.shape)#[1,1]#第一个1是bs
|
||||
y = torch.concat([y, samples], dim=1)
|
||||
|
||||
# 移除已经生成完毕的序列
|
||||
reserved_idx_of_batch_for_y = None
|
||||
if (self.EOS in torch.argmax(logits, dim=-1)) or \
|
||||
(self.EOS in samples[:, 0]): ###如果生成到EOS,则停止
|
||||
l = samples[:, 0]==self.EOS
|
||||
removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
|
||||
reserved_idx_of_batch_for_y = torch.where(l==False)[0]
|
||||
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
|
||||
for i in removed_idx_of_batch_for_y:
|
||||
batch_index = batch_idx_map[i]
|
||||
idx_list[batch_index] = idx - 1
|
||||
y_list[batch_index] = y[i, :-1]
|
||||
|
||||
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
|
||||
|
||||
# 只保留未生成完毕的序列
|
||||
if reserved_idx_of_batch_for_y is not None:
|
||||
# index = torch.LongTensor(batch_idx_map).to(y.device)
|
||||
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
|
||||
if cache["y_emb"] is not None:
|
||||
cache["y_emb"] = torch.index_select(cache["y_emb"], dim=0, index=reserved_idx_of_batch_for_y)
|
||||
if cache["k"] is not None:
|
||||
for i in range(self.num_layers):
|
||||
# 因为kv转置了,所以batch dim是1
|
||||
cache["k"][i] = torch.index_select(cache["k"][i], dim=1, index=reserved_idx_of_batch_for_y)
|
||||
cache["v"][i] = torch.index_select(cache["v"][i], dim=1, index=reserved_idx_of_batch_for_y)
|
||||
|
||||
|
||||
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
||||
print("use early stop num:", early_stop_num)
|
||||
stop = True
|
||||
|
||||
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
|
||||
|
||||
if not (None in idx_list):
|
||||
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
|
||||
stop = True
|
||||
if stop:
|
||||
@@ -443,6 +472,12 @@ class Text2SemanticDecoder(nn.Module):
|
||||
xy_attn_mask = torch.zeros(
|
||||
(1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
|
||||
)
|
||||
|
||||
if (None in idx_list):
|
||||
for i in range(x.shape[0]):
|
||||
if idx_list[i] is None:
|
||||
idx_list[i] = 1500-1 ###如果没有生成到EOS,就用最大长度代替
|
||||
|
||||
if ref_free:
|
||||
return y[:, :-1], 0
|
||||
return y[:, :-1], idx-1
|
||||
return y_list, [0]*x.shape[0]
|
||||
return y_list, idx_list
|
||||
|
||||
@@ -115,17 +115,17 @@ def logits_to_probs(
|
||||
top_p: Optional[int] = None,
|
||||
repetition_penalty: float = 1.0,
|
||||
):
|
||||
if previous_tokens is not None:
|
||||
previous_tokens = previous_tokens.squeeze()
|
||||
# if previous_tokens is not None:
|
||||
# previous_tokens = previous_tokens.squeeze()
|
||||
# print(logits.shape,previous_tokens.shape)
|
||||
# pdb.set_trace()
|
||||
if previous_tokens is not None and repetition_penalty != 1.0:
|
||||
previous_tokens = previous_tokens.long()
|
||||
score = torch.gather(logits, dim=0, index=previous_tokens)
|
||||
score = torch.gather(logits, dim=1, index=previous_tokens)
|
||||
score = torch.where(
|
||||
score < 0, score * repetition_penalty, score / repetition_penalty
|
||||
)
|
||||
logits.scatter_(dim=0, index=previous_tokens, src=score)
|
||||
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
||||
|
||||
if top_p is not None and top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
@@ -133,9 +133,9 @@ def logits_to_probs(
|
||||
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
|
||||
)
|
||||
sorted_indices_to_remove = cum_probs > top_p
|
||||
sorted_indices_to_remove[0] = False # keep at least one option
|
||||
sorted_indices_to_remove[:, 0] = False # keep at least one option
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
dim=0, index=sorted_indices, src=sorted_indices_to_remove
|
||||
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
||||
)
|
||||
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user