Made batched inference compatible with flash_attention, and fixed several bugs: GPT_SoVITS/AR/models/t2s_model.py
Backup of the batch-only inference implementation: GPT_SoVITS/AR/models/t2s_model_batch_only.py
@@ -99,7 +99,8 @@ class T2SBlock:
 
         attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
 
-        attn = attn.permute(2, 0, 1, 3).reshape(batch_size, -1, self.hidden_dim)
+        attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
+        attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
         attn = F.linear(attn, self.out_w, self.out_b)
 
         x = F.layer_norm(
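Note on the reshape fix above: after permute(2, 0, 1, 3) the attention output is laid out (q_len, batch, heads, head_dim), so the old single reshape to (batch_size, -1, hidden_dim) interleaved tokens from different batch items whenever batch_size > 1. The two-step version flattens in the correct order and then restores (batch, q_len, hidden). A minimal self-contained check (toy shapes, names illustrative, not repo code) that the new path matches the conventional head merge:

import torch

batch_size, num_heads, q_len, head_dim = 2, 4, 3, 8
hidden_dim = num_heads * head_dim
attn = torch.randn(batch_size, num_heads, q_len, head_dim)  # SDPA output layout

# Reference: merge heads the conventional way -> (batch, q_len, hidden)
ref = attn.transpose(1, 2).reshape(batch_size, q_len, hidden_dim)

# New two-step path from the diff
out = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, hidden_dim)
out = out.view(q_len, batch_size, hidden_dim).transpose(1, 0)

assert torch.equal(out, ref)  # identical result, shape (batch, q_len, hidden)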
@@ -114,15 +115,15 @@ class T2SBlock:
         )
         return x, k_cache, v_cache
 
-    def decode_next_token(self, x, k_cache, v_cache):
+    def decode_next_token(self, x, k_cache, v_cache, attn_mask : torch.Tensor):
         q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
 
         k_cache = torch.cat([k_cache, k], dim=1)
         v_cache = torch.cat([v_cache, v], dim=1)
-        kv_len = k_cache.shape[1]
 
         batch_size = q.shape[0]
         q_len = q.shape[1]
+        kv_len = k_cache.shape[1]
 
         q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
         k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
@@ -131,7 +132,8 @@ class T2SBlock:
 
         attn = F.scaled_dot_product_attention(q, k, v)
 
-        attn = attn.permute(2, 0, 1, 3).reshape(batch_size, -1, self.hidden_dim)
+        attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
+        attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
         attn = F.linear(attn, self.out_w, self.out_b)
 
         x = F.layer_norm(
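For context on the decode-path changes (the new attn_mask parameter and moving kv_len below the torch.cat), here is a toy single-token decode step under the same layout assumptions; toy_decode_step and all sizes are illustrative stand-ins, not the repo's T2SBlock. The key point is that kv_len must be read after the cache concatenation so the freshly appended token is attended to:

import torch
import torch.nn.functional as F

def toy_decode_step(q, k, v, k_cache, v_cache, num_heads):
    """One autoregressive step: append the new K/V (q_len == 1) to the
    cache, then attend over the full history. Illustrative stand-in for
    T2SBlock.decode_next_token; tensors are (batch, seq, hidden)."""
    k_cache = torch.cat([k_cache, k], dim=1)
    v_cache = torch.cat([v_cache, v], dim=1)
    batch_size, q_len = q.shape[0], q.shape[1]
    kv_len = k_cache.shape[1]  # read AFTER the cat: includes the new token

    q = q.view(batch_size, q_len, num_heads, -1).transpose(1, 2)
    k = k_cache.view(batch_size, kv_len, num_heads, -1).transpose(1, 2)
    v = v_cache.view(batch_size, kv_len, num_heads, -1).transpose(1, 2)

    # No causal mask needed here: the single query token may attend to all history.
    attn = F.scaled_dot_product_attention(q, k, v)
    hidden = num_heads * attn.shape[-1]
    attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, hidden)
    attn = attn.view(q_len, batch_size, hidden).transpose(1, 0)
    return attn, k_cache, v_cache

b, h, d, past = 2, 4, 32, 10
x1 = torch.randn(b, 1, h * d)
out, kc, vc = toy_decode_step(x1, x1, x1, torch.randn(b, past, h * d), torch.randn(b, past, h * d), h)
assert out.shape == (b, 1, h * d) and kc.shape[1] == past + 1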
@@ -164,10 +166,10 @@ class T2STransformer:
         return x, k_cache, v_cache
 
     def decode_next_token(
-        self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]
+        self, x, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor], attn_mask : torch.Tensor
     ):
         for i in range(self.num_blocks):
-            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
+            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask)
         return x, k_cache, v_cache
 
 
@@ -543,12 +545,16 @@ class Text2SemanticDecoder(nn.Module):
         xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
             x.device
         )
 
+        y_list = [None]*y.shape[0]
+        batch_idx_map = list(range(y.shape[0]))
+        idx_list = [None]*y.shape[0]
+        cache_y_emb = y_emb
         for idx in tqdm(range(1500)):
-            if xy_attn_mask is not None:
+            if idx == 0:
                 xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask)
             else:
-                xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
+                xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, xy_attn_mask)
 
             logits = self.ar_predict_layer(
                 xy_dec[:, -1]
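The lines added at the top of this hunk set up per-sample bookkeeping for batched decoding: y_list and idx_list are indexed by the original batch position and filled in as samples finish, while batch_idx_map translates a row of the current, shrinking batch back to its original position (cache_y_emb keeps the cached audio embedding aligned with the surviving rows). A small illustration of the mapping, with toy values rather than repo code:

# Hypothetical illustration of the bookkeeping above (not repo code):
# batch_idx_map maps the current (shrinking) batch row back to the
# original sample index, so finished outputs land in the right slot.
y_list = [None, None, None]      # final tokens per ORIGINAL sample
idx_list = [None, None, None]    # stop step per ORIGINAL sample
batch_idx_map = [0, 1, 2]        # live row -> original index

# Suppose live row 1 (original sample 1) hits EOS at step 7:
finished_rows = [1]
for row in finished_rows:
    original = batch_idx_map[row]
    idx_list[original] = 7
    y_list[original] = f"tokens of sample {original}"  # stands in for y[row, :-1]
batch_idx_map = [batch_idx_map[r] for r in range(len(batch_idx_map)) if r not in finished_rows]
assert batch_idx_map == [0, 2]   # rows 0 and 2 keep generating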
@@ -557,18 +563,51 @@ class Text2SemanticDecoder(nn.Module):
             if idx == 0:
                 xy_attn_mask = None
                 logits = logits[:, :-1]
 
             samples = sample(
-                logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
-            )[0].unsqueeze(0)
+                logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
+            )[0]
 
             y = torch.concat([y, samples], dim=1)
 
-            if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
+            ####### Drop sequences that have finished generating from the batch, to further reduce computation
+            reserved_idx_of_batch_for_y = None
+            if (self.EOS in samples[:, 0]) or \
+                (self.EOS in torch.argmax(logits, dim=-1)): ### stop once EOS has been generated
+                l = samples[:, 0]==self.EOS
+                removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
+                reserved_idx_of_batch_for_y = torch.where(l==False)[0]
+                # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
+                for i in removed_idx_of_batch_for_y:
+                    batch_index = batch_idx_map[i]
+                    idx_list[batch_index] = idx - 1
+                    y_list[batch_index] = y[i, :-1]
+
+                batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
+
+            # Keep only the sequences in the batch that are still generating
+            if reserved_idx_of_batch_for_y is not None:
+                # index = torch.LongTensor(batch_idx_map).to(y.device)
+                y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
+                if cache_y_emb is not None:
+                    cache_y_emb = torch.index_select(cache_y_emb, dim=0, index=reserved_idx_of_batch_for_y)
+                if k_cache is not None :
+                    for i in range(len(k_cache)):
+                        k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
+                        v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
+
+            if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx==1499:
                 print("use early stop num:", early_stop_num)
                 stop = True
-
-            if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
+                for i, batch_index in enumerate(batch_idx_map):
+                    batch_index = batch_idx_map[i]
+                    idx_list[batch_index] = idx
+                    y_list[batch_index] = y[i, :-1]
+
+            if not (None in idx_list):
                 stop = True
 
             if stop:
                 if y.shape[1]==0:
                     y = torch.concat([y, torch.zeros_like(samples)], dim=1)
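The pruning block above is the heart of the batch optimization: sample now operates on the full (batch, vocab) logits and returns one token per row, and once a row emits EOS its tokens are parked in y_list while torch.index_select with the surviving row indices shrinks y, the cached embedding, and every layer's K/V cache along dim 0, so subsequent steps compute less. A minimal sketch of the index_select step with assumed toy sizes (not repo code):

import torch

# index_select with the kept row indices shrinks the running output and
# every layer's KV cache along the batch dimension.
batch, seq, hidden, layers = 4, 12, 8, 2
y = torch.randn(batch, seq)
k_cache = [torch.randn(batch, seq, hidden) for _ in range(layers)]
v_cache = [torch.randn(batch, seq, hidden) for _ in range(layers)]

reserved = torch.tensor([0, 2, 3])  # rows still generating (row 1 hit EOS)
y = torch.index_select(y, dim=0, index=reserved)
for i in range(layers):
    k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved)
    v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved)

assert y.shape[0] == 3 and k_cache[0].shape == (3, seq, hidden)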
@@ -580,6 +619,11 @@ class Text2SemanticDecoder(nn.Module):
             y_emb = self.ar_audio_embedding(y[:, -1:])
             xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx]
 
+        if (None in idx_list):
+            for i in range(x.shape[0]):
+                if idx_list[i] is None:
+                    idx_list[i] = 1500-1 ### if a sequence never generated EOS, substitute the maximum length
+
         if ref_free:
-            return y[:, :-1], 0
-        return y[:, :-1], idx - 1
+            return y_list, [0]*x.shape[0]
+        return y_list, idx_list
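With the new return shape, callers receive a list of per-sample token tensors of potentially different lengths plus a per-sample stop index, instead of a single padded tensor and one scalar. A hypothetical caller-side sketch (names and values are stand-ins, not from the repo):

import torch

y_list = [torch.arange(5), torch.arange(9), torch.arange(3)]  # stand-in outputs
idx_list = [4, 8, 2]                                          # stand-in stop steps

for sample_tokens, stop_idx in zip(y_list, idx_list):
    # each entry is handled on its own; no padding or truncation needed
    print(f"stopped at step {stop_idx}, {sample_tokens.shape[0]} tokens")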