mps support

This commit is contained in:
Wu Zichen
2024-01-24 19:37:47 +08:00
parent 8069264e64
commit 07a5339691
8 changed files with 70 additions and 33 deletions

View File

@@ -46,7 +46,7 @@ if os.path.exists(txt_path) == False:
bert_dir = "%s/3-bert" % (opt_dir)
os.makedirs(opt_dir, exist_ok=True)
os.makedirs(bert_dir, exist_ok=True)
device = "cuda:0"
device = "cuda:0" if torch.cuda.is_available() else "mps"
tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
if is_half == True:

View File

@@ -47,7 +47,7 @@ os.makedirs(wav32dir,exist_ok=True)
maxx=0.95
alpha=0.5
device="cuda:0"
device="cuda:0" if torch.cuda.is_available() else "mps"
model=cnhubert.get_model()
# is_half=False
if(is_half==True):

View File

@@ -38,7 +38,7 @@ semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part)
if os.path.exists(semantic_path) == False:
os.makedirs(opt_dir, exist_ok=True)
device = "cuda:0"
device = "cuda:0" if torch.cuda.is_available() else "mps"
hps = utils.get_hparams_from_file(s2config_path)
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,