mps support
This commit is contained in:
@@ -41,12 +41,13 @@ class DistributedBucketSampler(Sampler[T_co]):
|
||||
if num_replicas is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
num_replicas = dist.get_world_size()
|
||||
num_replicas = dist.get_world_size() if torch.cuda.is_available() else 1
|
||||
if rank is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
rank = dist.get_rank()
|
||||
torch.cuda.set_device(rank)
|
||||
rank = dist.get_rank() if torch.cuda.is_available() else 0
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.set_device(rank)
|
||||
if rank >= num_replicas or rank < 0:
|
||||
raise ValueError(
|
||||
"Invalid rank {}, rank should be in the interval"
|
||||
|
||||
Reference in New Issue
Block a user