mps support

This commit is contained in:
Wu Zichen
2024-01-24 19:37:47 +08:00
parent 8069264e64
commit 07a5339691
8 changed files with 70 additions and 33 deletions

View File

@@ -17,7 +17,7 @@ exp_root = "logs"
python_exec = sys.executable or "python"
if torch.cuda.is_available():
infer_device = "cuda"
elif torch.mps.is_available():
elif torch.backends.mps.is_available():
infer_device = "mps"
else:
infer_device = "cpu"