Add files via upload
GPT_SoVITS/AR/data/__init__.py (new file, 0 lines)
GPT_SoVITS/AR/data/bucket_sampler.py (new file, 157 lines)
@@ -0,0 +1,157 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/bucketsampler.py
import itertools
import math
import random
from random import shuffle
from typing import Iterator, Optional, TypeVar

import torch
import torch.distributed as dist
from torch.utils.data import Dataset, Sampler

__all__ = [
    "DistributedBucketSampler",
]

T_co = TypeVar('T_co', covariant=True)


class DistributedBucketSampler(Sampler[T_co]):
    r"""
    Sort the dataset by input length, divide the samples into buckets,
    shuffle within each bucket, group the buckets into batches, and
    shuffle the batches.
    """

    def __init__(self,
                 dataset: Dataset,
                 num_replicas: Optional[int] = None,
                 rank: Optional[int] = None,
                 shuffle: bool = True,
                 seed: int = 0,
                 drop_last: bool = False,
                 batch_size: int = 32) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError(
                    "Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError(
                    "Requires distributed package to be available")
            rank = dist.get_rank()
            torch.cuda.set_device(rank)
        if rank >= num_replicas or rank < 0:
            raise ValueError("Invalid rank {}, rank should be in the interval"
                             " [0, {}]".format(rank, num_replicas - 1))
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by the number of replicas,
        # there is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to the nearest available length that is evenly divisible.
            # This ensures each rank receives the same amount of data when
            # using this sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(
                len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.batch_size = batch_size
        self.id_with_length = self._get_sample_lengths()
        self.id_buckets = self.make_buckets(bucket_width=2.0)

    def _get_sample_lengths(self):
        # (index, duration in seconds) pairs, sorted by duration
        id_with_lengths = []
        for i in range(len(self.dataset)):
            id_with_lengths.append((i, self.dataset.get_sample_length(i)))
        id_with_lengths.sort(key=lambda x: x[1])
        return id_with_lengths

    def make_buckets(self, bucket_width: float = 2.0):
        # group the already sorted sample ids into buckets of `bucket_width` seconds
        buckets = []
        cur = []
        max_sec = bucket_width
        for id, sec in self.id_with_length:
            if sec < max_sec:
                cur.append(id)
            else:
                buckets.append(cur)
                cur = [id]
                max_sec += bucket_width
        if len(cur) > 0:
            buckets.append(cur)
        return buckets

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed; the actual
            # shuffling below uses Python's `random`, seeded the same way
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            random.seed(self.epoch + self.seed)
            shuffled_bucket = []
            for buc in self.id_buckets:
                buc_copy = buc.copy()
                shuffle(buc_copy)
                shuffled_bucket.append(buc_copy)
            grouped_batch_size = self.batch_size * self.num_replicas
            shuffled_bucket = list(itertools.chain(*shuffled_bucket))
            n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
            batches = [
                shuffled_bucket[b * grouped_batch_size:(b + 1) * grouped_batch_size]
                for b in range(n_batch)
            ]
            shuffle(batches)
            indices = list(itertools.chain(*batches))
        else:
            # type: ignore[arg-type]
            indices = list(range(len(self.dataset)))

        if not self.drop_last:
            # add extra samples to make the total evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove the tail of the data to make it evenly divisible
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # subsample for this rank
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
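A minimal usage sketch, not part of this commit: it assumes only the interface above, namely that the wrapped dataset exposes get_sample_length(idx) in seconds, and it passes num_replicas and rank explicitly so torch.distributed does not have to be initialised. ToyDataset and its fake durations are placeholders for illustration.

import torch
from torch.utils.data import DataLoader, Dataset

from AR.data.bucket_sampler import DistributedBucketSampler


class ToyDataset(Dataset):
    """Hypothetical stand-in: 100 items whose 'duration' grows with the index."""

    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return torch.tensor([idx])

    def get_sample_length(self, idx):
        return 0.1 * idx  # seconds; this is what the sampler buckets on


dataset = ToyDataset()
sampler = DistributedBucketSampler(dataset, num_replicas=1, rank=0, batch_size=8)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)
for epoch in range(2):
    sampler.set_epoch(epoch)  # re-seeds the per-epoch shuffle
    for batch in loader:
        pass  # each batch draws from length-adjacent buckets

Because batches are cut from length-sorted buckets, samples inside a batch have similar durations, which keeps padding overhead in the collate step low.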
GPT_SoVITS/AR/data/data_module.py (new file, 66 lines)
@@ -0,0 +1,66 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/data_module.py
from pytorch_lightning import LightningDataModule
from AR.data.bucket_sampler import DistributedBucketSampler
from AR.data.dataset import Text2SemanticDataset
from torch.utils.data import DataLoader


class Text2SemanticDataModule(LightningDataModule):
    def __init__(self, config, train_semantic_path, train_phoneme_path,
                 dev_semantic_path=None, dev_phoneme_path=None):
        super().__init__()
        self.config = config
        self.train_semantic_path = train_semantic_path
        self.train_phoneme_path = train_phoneme_path
        self.dev_semantic_path = dev_semantic_path
        self.dev_phoneme_path = dev_phoneme_path
        self.num_workers = self.config['data']['num_workers']

    def prepare_data(self):
        pass

    def setup(self, stage=None, output_logs=False):
        self._train_dataset = Text2SemanticDataset(
            phoneme_path=self.train_phoneme_path,
            semantic_path=self.train_semantic_path,
            max_sec=self.config['data']['max_sec'],
            pad_val=self.config['data']['pad_val'])
        # the dev set currently reuses the training set
        self._dev_dataset = self._train_dataset
        # self._dev_dataset = Text2SemanticDataset(
        #     phoneme_path=self.dev_phoneme_path,
        #     semantic_path=self.dev_semantic_path,
        #     max_sample=self.config['data']['max_eval_sample'],
        #     max_sec=self.config['data']['max_sec'],
        #     pad_val=self.config['data']['pad_val'])

    def train_dataloader(self):
        batch_size = self.config['train']['batch_size']
        sampler = DistributedBucketSampler(
            self._train_dataset, batch_size=batch_size)
        return DataLoader(
            self._train_dataset,
            batch_size=batch_size,
            sampler=sampler,
            collate_fn=self._train_dataset.collate,
            num_workers=self.num_workers,
            persistent_workers=True,
            prefetch_factor=16
        )

    def val_dataloader(self):
        return DataLoader(
            self._dev_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self._train_dataset.collate,
            num_workers=max(self.num_workers, 12),
            persistent_workers=True,
            prefetch_factor=16
        )

    # Is this ever used?
    def test_dataloader(self):
        return DataLoader(
            self._dev_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self._train_dataset.collate)
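A hedged sketch of driving this data module by hand, using only the configuration keys the class reads above (data.num_workers, data.max_sec, data.pad_val, train.batch_size); the paths and values are placeholders, not the project's real config.

from AR.data.data_module import Text2SemanticDataModule

# placeholder config containing only the keys Text2SemanticDataModule reads
config = {
    "data": {"num_workers": 4, "max_sec": 54, "pad_val": 1024},
    "train": {"batch_size": 8},
}
dm = Text2SemanticDataModule(
    config,
    train_semantic_path="logs/demo_exp/6-name2semantic.tsv",  # placeholder paths
    train_phoneme_path="logs/demo_exp/2-name2text.txt",
)
dm.setup()  # builds the Text2SemanticDataset
# train_dataloader() creates a DistributedBucketSampler with default
# num_replicas/rank, so it expects torch.distributed to be initialised,
# for example when a pytorch_lightning Trainer runs fit() under a DDP strategy:
#     trainer.fit(model, datamodule=dm)
loader = dm.train_dataloader()

Note that val_dataloader() currently iterates the training set as well, since setup() assigns self._dev_dataset = self._train_dataset.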
GPT_SoVITS/AR/data/dataset.py (new file, 302 lines)
@@ -0,0 +1,302 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/t2s_dataset.py
import pdb
import sys
# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
import traceback, os
from typing import Dict
from typing import List

import numpy as np
import pandas as pd
import torch, json
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformers import AutoTokenizer

from text import cleaned_text_to_sequence
# from config import exp_dir


def batch_sequences(sequences: List[np.ndarray], axis: int = 0, pad_value: int = 0):
    """Right-pad every array in `sequences` along `axis` to the longest length
    and stack the result into a single batch array."""
    seq = sequences[0]
    ndim = seq.ndim
    if axis < 0:
        axis += ndim
    dtype = seq.dtype
    pad_value = dtype.type(pad_value)
    seq_lengths = [seq.shape[axis] for seq in sequences]
    max_length = np.max(seq_lengths)

    padded_sequences = []
    for seq, length in zip(sequences, seq_lengths):
        padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
        padded_seq = np.pad(seq, padding, mode='constant', constant_values=pad_value)
        padded_sequences.append(padded_seq)
    batch = np.stack(padded_sequences)
    return batch
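# --- Illustrative sketch, not part of the upstream file -----------------------
# batch_sequences right-pads every array along `axis` to the longest length in
# the list and stacks the result. For two 1-D token sequences and the semantic
# pad value used further down (1024), it behaves like this:
def _demo_batch_sequences() -> np.ndarray:
    a = np.array([1, 2, 3], dtype=np.int64)
    b = np.array([4, 5], dtype=np.int64)
    # returns [[1, 2, 3], [4, 5, 1024]], shape (2, 3)
    return batch_sequences([a, b], pad_value=1024)
# -------------------------------------------------------------------------------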

class Text2SemanticDataset(Dataset):
    """Dataset for training the text-token to semantic-token model."""

    def __init__(self,
                 phoneme_path: str,
                 semantic_path: str,
                 max_sample: int = None,
                 max_sec: int = 100,
                 pad_val: int = 1024,
                 # min value of phonemes per second
                 min_ps_ratio: int = 3,
                 # max value of phonemes per second
                 max_ps_ratio: int = 25) -> None:
        super().__init__()

        self.semantic_data = pd.read_csv(semantic_path, delimiter='\t', encoding="utf-8")
        # resolve the input paths
        self.path2 = phoneme_path  # "%s/2-name2text.txt" % exp_dir
        self.path3 = "%s/3-bert" % (os.path.dirname(phoneme_path))  # "%s/3-bert" % exp_dir (bert feature dir)
        self.path6 = semantic_path  # "%s/6-name2semantic.tsv" % exp_dir
        assert os.path.exists(self.path2)
        assert os.path.exists(self.path6)

        # name -> [phonemes, word2ph, text], parsed from the tab-separated phoneme file
        self.phoneme_data = {}
        with open(self.path2, "r", encoding="utf8") as f:
            lines = f.read().strip("\n").split("\n")

        for line in lines:
            tmp = line.split("\t")
            if len(tmp) != 4:
                continue
            self.phoneme_data[tmp[0]] = [tmp[1], tmp[2], tmp[3]]

        # self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item()
        # pad value for semantic tokens
        self.PAD: int = pad_val
        # self.hz = 25
        # with open("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert/configs/s2.json", "r") as f: data = f.read()
        # data = json.loads(data)["model"]["semantic_frame_rate"]  # e.g. "50hz"
        # self.hz = int(data[:-2])
        self.hz = int(os.environ.get("hz", "25hz")[:-2])

        # max duration (in seconds) of the semantic token sequence
        self.max_sec = max_sec
        self.min_ps_ratio = min_ps_ratio
        self.max_ps_ratio = max_ps_ratio

        if max_sample is not None:
            self.semantic_data = self.semantic_data[:max_sample]

        # parallel lists: (semantic ids, phoneme ids) pairs and their item names
        self.semantic_phoneme = []
        self.item_names = []

        self.inited = False

        if not self.inited:
            # build the index once
            self.init_batch()
            self.inited = True
            del self.semantic_data
            del self.phoneme_data
        # self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large")
        # self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large")

    def init_batch(self):
        semantic_data_len = len(self.semantic_data)
        phoneme_data_len = len(self.phoneme_data.keys())
        print("semantic_data_len:", semantic_data_len)
        print("phoneme_data_len:", phoneme_data_len)
        idx = 0
        num_not_in = 0
        num_deleted_bigger = 0
        num_deleted_ps = 0
        for i in range(semantic_data_len):
            # walk the semantic rows in order
            item_name = self.semantic_data['item_name'][i]
            # print(self.phoneme_data)
            try:
                phoneme, word2ph, text = self.phoneme_data[item_name]
            except Exception:
                traceback.print_exc()
                # print(f"{item_name} not in self.phoneme_data !")
                num_not_in += 1
                continue

            semantic_str = self.semantic_data['semantic_audio'][i]
            # space-separated token ids -> list of ints, shape (T,)
            semantic_ids = [int(idx) for idx in semantic_str.split(' ')]
            # 1) drop samples whose estimated duration (token count / hz) exceeds
            #    max_sec from the config, e.g. 40 s * 25 Hz = 1000 tokens
            if len(semantic_ids) > self.max_sec * self.hz:
                num_deleted_bigger += 1
                continue
            # (T,); cheap enough to do up front, no need to handle it per item in __getitem__
            phoneme = phoneme.split(' ')

            try:
                phoneme_ids = cleaned_text_to_sequence(phoneme)
            except Exception:
                traceback.print_exc()
                # print(f"{item_name} not in self.phoneme_data !")
                num_not_in += 1
                continue
            # 2) cap the phoneme length at semantic length / 2.5 instead of a fixed 400
            if len(phoneme_ids) > self.max_sec * self.hz / 2.5:
                num_deleted_ps += 1
                continue
            # if len(semantic_ids) > 1000:  # 3)
            #     num_deleted_bigger += 1
            #     continue

            ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)

            # 4) phonemes per second must fall within [min_ps_ratio, max_ps_ratio], i.e. 3-25
            if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio:
                num_deleted_ps += 1
                # print(item_name)
                continue

            self.semantic_phoneme.append((semantic_ids, phoneme_ids))
            idx += 1
            self.item_names.append(item_name)

        # Very small datasets break training (around 20 samples nothing gets duplicated,
        # around 30 no checkpoint is saved), so duplicate tiny datasets until they
        # reach at least min_num items.
        min_num = 100
        leng = len(self.semantic_phoneme)
        if leng < min_num:
            tmp1 = self.semantic_phoneme
            tmp2 = self.item_names
            self.semantic_phoneme = []
            self.item_names = []
            for _ in range(max(2, int(min_num / leng))):
                self.semantic_phoneme += tmp1
                self.item_names += tmp2

        if num_not_in > 0:
            print(f"there are {num_not_in} semantic datas not in phoneme datas")
        if num_deleted_bigger > 0:
            print(
                f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds"
            )
        if num_deleted_ps > 0:
            # 4702 for LibriTTS. LibriTTS is annotated data; does it still need filtering?
            # => yes, there are extreme values of 100 phonemes/sec.
            print(
                f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
            )
        # sample output from one run:
        '''
        there are 31 semantic datas not in phoneme datas
        deleted 34 audios who's duration are bigger than 54 seconds
        deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3
        dataset.__len__(): 366463
        '''
        # 345410 for LibriTTS
        print("dataset.__len__():", self.__len__())

    def __get_item_names__(self) -> List[str]:
        return self.item_names

    def __len__(self) -> int:
        return len(self.semantic_phoneme)

    def __getitem__(self, idx: int) -> Dict:
        semantic_ids, phoneme_ids = self.semantic_phoneme[idx]
        item_name = self.item_names[idx]
        phoneme_ids_len = len(phoneme_ids)
        # semantic tokens are the prediction target
        semantic_ids_len = len(semantic_ids)

        # load the pre-computed BERT feature for this item, if one exists
        path_bert = "%s/%s.pt" % (self.path3, item_name)
        if os.path.exists(path_bert):
            bert_feature = torch.load(path_bert, map_location="cpu")
            assert bert_feature.shape[-1] == len(phoneme_ids)
        else:
            # bert_feature = torch.zeros_like(phoneme_ids, dtype=torch.float32)
            bert_feature = None
        return {
            'idx': idx,
            'phoneme_ids': phoneme_ids,
            'phoneme_ids_len': phoneme_ids_len,
            'semantic_ids': semantic_ids,
            'semantic_ids_len': semantic_ids_len,
            'bert_feature': bert_feature,
        }

    def get_sample_length(self, idx: int):
        # duration in seconds, used by DistributedBucketSampler
        semantic_ids = self.semantic_phoneme[idx][0]
        sec = 1.0 * len(semantic_ids) / self.hz
        return sec

    def collate(self, examples: List[Dict]) -> Dict:
        sample_index: List[int] = []
        phoneme_ids: List[torch.Tensor] = []
        phoneme_ids_lens: List[int] = []
        semantic_ids: List[torch.Tensor] = []
        semantic_ids_lens: List[int] = []

        for item in examples:
            sample_index.append(item["idx"])
            phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
            semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64))
            phoneme_ids_lens.append(item["phoneme_ids_len"])
            semantic_ids_lens.append(item["semantic_ids_len"])

        # pad phoneme ids with 0 and semantic ids with self.PAD
        phoneme_ids = batch_sequences(phoneme_ids)
        semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD)

        # convert each batch to torch.Tensor
        phoneme_ids = torch.tensor(phoneme_ids)
        semantic_ids = torch.tensor(semantic_ids)
        phoneme_ids_lens = torch.tensor(phoneme_ids_lens)
        semantic_ids_lens = torch.tensor(semantic_ids_lens)

        # zero-padded BERT features, one 1024-dim vector per phoneme
        bert_padded = torch.FloatTensor(len(examples), 1024, max(phoneme_ids_lens))
        bert_padded.zero_()

        for idx, item in enumerate(examples):
            bert = item['bert_feature']
            if bert is not None:
                bert_padded[idx, :, :bert.shape[-1]] = bert

        return {
            # List[int]
            "ids": sample_index,
            # torch.Tensor (B, max_phoneme_length)
            "phoneme_ids": phoneme_ids,
            # torch.Tensor (B)
            "phoneme_ids_len": phoneme_ids_lens,
            # torch.Tensor (B, max_semantic_ids_length)
            "semantic_ids": semantic_ids,
            # torch.Tensor (B)
            "semantic_ids_len": semantic_ids_lens,
            # torch.Tensor (B, 1024, max_phoneme_length)
            "bert_feature": bert_padded,
        }


if __name__ == '__main__':
    root_dir = '/data/docker/liujing04/gpt-vits/prepare/dump_mix/'
    dataset = Text2SemanticDataset(
        phoneme_path=root_dir + 'phoneme_train.npy',
        semantic_path=root_dir + 'semantic_train.tsv')

    batch_size = 12
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=dataset.collate,
        shuffle=False)
    for i, batch in enumerate(dataloader):
        if i % 1000 == 0:
            print(i)
        # if i == 0:
        #     print('batch["ids"]:', batch["ids"])
        #     print('batch["phoneme_ids"]:', batch["phoneme_ids"],
        #           batch["phoneme_ids"].shape)
        #     print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"],
        #           batch["phoneme_ids_len"].shape)
        #     print('batch["semantic_ids"]:', batch["semantic_ids"],
        #           batch["semantic_ids"].shape)
        #     print('batch["semantic_ids_len"]:', batch["semantic_ids_len"],
        #           batch["semantic_ids_len"].shape)
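For orientation, a small sketch of the on-disk inputs this dataset parses, inferred from the parsing code above and illustrative only: the phoneme file is tab-separated with four fields per line (item name, phonemes, word2ph, text), the semantic file is a TSV with item_name and semantic_audio columns holding space-separated token ids, and optional BERT features are loaded from a 3-bert directory next to the phoneme file as <item name>.pt. The directory name, utterance id, phoneme string, and token values below are placeholders; real phoneme strings must come from the project's text frontend.

import os

exp_dir = "logs/demo_exp"  # placeholder experiment directory
os.makedirs(os.path.join(exp_dir, "3-bert"), exist_ok=True)  # optional <name>.pt features go here

# phoneme annotations: name \t phonemes \t word2ph \t text
with open(os.path.join(exp_dir, "2-name2text.txt"), "w", encoding="utf8") as f:
    f.write("utt_0001\tn i2 h ao3\t2 2\t你好\n")

# semantic annotations: a header row, then one row of space-separated token ids per item
with open(os.path.join(exp_dir, "6-name2semantic.tsv"), "w", encoding="utf8") as f:
    f.write("item_name\tsemantic_audio\n")
    f.write("utt_0001\t12 7 7 431 88 88 90\n")

The file names mirror the commented-out exp_dir paths in __init__ (2-name2text.txt, 6-name2semantic.tsv, 3-bert), but any paths passed explicitly to Text2SemanticDataset are handled the same way.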