mps support
This commit is contained in:
@@ -46,7 +46,7 @@ if os.path.exists(txt_path) == False:
|
||||
bert_dir = "%s/3-bert" % (opt_dir)
|
||||
os.makedirs(opt_dir, exist_ok=True)
|
||||
os.makedirs(bert_dir, exist_ok=True)
|
||||
device = "cuda:0"
|
||||
device = "cuda:0" if torch.cuda.is_available() else "mps"
|
||||
tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
|
||||
bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
|
||||
if is_half == True:
|
||||
|
||||
@@ -47,7 +47,7 @@ os.makedirs(wav32dir,exist_ok=True)
|
||||
|
||||
maxx=0.95
|
||||
alpha=0.5
|
||||
device="cuda:0"
|
||||
device="cuda:0" if torch.cuda.is_available() else "mps"
|
||||
model=cnhubert.get_model()
|
||||
# is_half=False
|
||||
if(is_half==True):
|
||||
|
||||
@@ -38,7 +38,7 @@ semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part)
|
||||
if os.path.exists(semantic_path) == False:
|
||||
os.makedirs(opt_dir, exist_ok=True)
|
||||
|
||||
device = "cuda:0"
|
||||
device = "cuda:0" if torch.cuda.is_available() else "mps"
|
||||
hps = utils.get_hparams_from_file(s2config_path)
|
||||
vq_model = SynthesizerTrn(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
|
||||
Reference in New Issue
Block a user