mps support

This commit is contained in:
Wu Zichen
2024-01-24 19:37:47 +08:00
parent 8069264e64
commit 07a5339691
8 changed files with 70 additions and 33 deletions

View File

@@ -116,7 +116,7 @@ def main(args):
devices=-1,
benchmark=False,
fast_dev_run=False,
strategy = "auto" if torch.mps.is_available() else DDPStrategy(
strategy = "auto" if torch.backends.mps.is_available() else DDPStrategy(
process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
), # mps 不支持多节点训练
precision=config["train"]["precision"],