mps support

This commit is contained in:
Wu Zichen
2024-01-24 19:37:47 +08:00
parent 8069264e64
commit 07a5339691
8 changed files with 70 additions and 33 deletions

View File

@@ -41,12 +41,13 @@ class DistributedBucketSampler(Sampler[T_co]):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
num_replicas = dist.get_world_size() if torch.cuda.is_available() else 1
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
torch.cuda.set_device(rank)
rank = dist.get_rank() if torch.cuda.is_available() else 0
if torch.cuda.is_available():
torch.cuda.set_device(rank)
if rank >= num_replicas or rank < 0:
raise ValueError(
"Invalid rank {}, rank should be in the interval"