添加了自定义修改随机数种子,方便复现结果。

This commit is contained in:
chasonjiang
2024-03-15 14:34:10 +08:00
parent b8ce03fd1b
commit a2f2a5f4a7
2 changed files with 25 additions and 11 deletions

View File

@@ -51,17 +51,23 @@ custom:
"""
# def set_seed(seed):
# random.seed(seed)
# os.environ['PYTHONHASHSEED'] = str(seed)
# np.random.seed(seed)
# torch.manual_seed(seed)
# torch.cuda.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.enabled = True
# set_seed(1234)
def set_seed(seed:int):
seed = int(seed)
seed = seed if seed != -1 else random.randrange(1 << 32)
print(f"Set seed to {seed}")
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
try:
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.enabled = True
except:
pass
return seed
class TTS_Config:
default_configs={
@@ -563,6 +569,7 @@ class TTS:
"return_fragment": False, # bool. step by step return the audio fragment.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
"seed": -1, # int. random seed for reproducibility.
}
returns:
tulpe[int, np.ndarray]: sampling rate and audio data.
@@ -584,6 +591,9 @@ class TTS:
split_bucket = inputs.get("split_bucket", True)
return_fragment = inputs.get("return_fragment", False)
fragment_interval = inputs.get("fragment_interval", 0.3)
seed = inputs.get("seed", -1)
set_seed(seed)
if return_fragment:
# split_bucket = False