more code refactor
This commit is contained in:
@@ -16,7 +16,7 @@ __all__ = [
|
||||
"DistributedBucketSampler",
|
||||
]
|
||||
|
||||
T_co = TypeVar('T_co', covariant=True)
|
||||
T_co = TypeVar("T_co", covariant=True)
|
||||
|
||||
|
||||
class DistributedBucketSampler(Sampler[T_co]):
|
||||
@@ -28,28 +28,30 @@ class DistributedBucketSampler(Sampler[T_co]):
|
||||
sort batches
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dataset: Dataset,
|
||||
num_replicas: Optional[int]=None,
|
||||
rank: Optional[int]=None,
|
||||
shuffle: bool=True,
|
||||
seed: int=0,
|
||||
drop_last: bool=False,
|
||||
batch_size: int=32) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
num_replicas: Optional[int] = None,
|
||||
rank: Optional[int] = None,
|
||||
shuffle: bool = True,
|
||||
seed: int = 0,
|
||||
drop_last: bool = False,
|
||||
batch_size: int = 32,
|
||||
) -> None:
|
||||
if num_replicas is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError(
|
||||
"Requires distributed package to be available")
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
num_replicas = dist.get_world_size()
|
||||
if rank is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError(
|
||||
"Requires distributed package to be available")
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
rank = dist.get_rank()
|
||||
torch.cuda.set_device(rank)
|
||||
if rank >= num_replicas or rank < 0:
|
||||
raise ValueError("Invalid rank {}, rank should be in the interval"
|
||||
" [0, {}]".format(rank, num_replicas - 1))
|
||||
raise ValueError(
|
||||
"Invalid rank {}, rank should be in the interval"
|
||||
" [0, {}]".format(rank, num_replicas - 1)
|
||||
)
|
||||
self.dataset = dataset
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
@@ -57,19 +59,20 @@ class DistributedBucketSampler(Sampler[T_co]):
|
||||
self.drop_last = drop_last
|
||||
# If the dataset length is evenly divisible by # of replicas, then there
|
||||
# is no need to drop any data, since the dataset will be split equally.
|
||||
if self.drop_last and len(
|
||||
self.
|
||||
dataset) % self.num_replicas != 0: # type: ignore[arg-type]
|
||||
if (
|
||||
self.drop_last and len(self.dataset) % self.num_replicas != 0
|
||||
): # type: ignore[arg-type]
|
||||
# Split to nearest available length that is evenly divisible.
|
||||
# This is to ensure each rank receives the same amount of data when
|
||||
# using this Sampler.
|
||||
self.num_samples = math.ceil(
|
||||
(len(self.dataset) - self.num_replicas) /
|
||||
self.num_replicas # type: ignore[arg-type]
|
||||
(len(self.dataset) - self.num_replicas)
|
||||
/ self.num_replicas # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
self.num_samples = math.ceil(
|
||||
len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
|
||||
len(self.dataset) / self.num_replicas
|
||||
) # type: ignore[arg-type]
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
@@ -84,7 +87,7 @@ class DistributedBucketSampler(Sampler[T_co]):
|
||||
id_with_lengths.sort(key=lambda x: x[1])
|
||||
return id_with_lengths
|
||||
|
||||
def make_buckets(self, bucket_width: float=2.0):
|
||||
def make_buckets(self, bucket_width: float = 2.0):
|
||||
buckets = []
|
||||
cur = []
|
||||
max_sec = bucket_width
|
||||
@@ -114,8 +117,8 @@ class DistributedBucketSampler(Sampler[T_co]):
|
||||
shuffled_bucket = list(itertools.chain(*shuffled_bucket))
|
||||
n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
|
||||
batches = [
|
||||
shuffled_bucket[b * grouped_batch_size:(b + 1) *
|
||||
grouped_batch_size] for b in range(n_batch)
|
||||
shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
|
||||
for b in range(n_batch)
|
||||
]
|
||||
shuffle(batches)
|
||||
indices = list(itertools.chain(*batches))
|
||||
@@ -129,15 +132,16 @@ class DistributedBucketSampler(Sampler[T_co]):
|
||||
if padding_size <= len(indices):
|
||||
indices += indices[:padding_size]
|
||||
else:
|
||||
indices += (indices * math.ceil(padding_size /
|
||||
len(indices)))[:padding_size]
|
||||
indices += (indices * math.ceil(padding_size / len(indices)))[
|
||||
:padding_size
|
||||
]
|
||||
else:
|
||||
# remove tail of data to make it evenly divisible.
|
||||
indices = indices[:self.total_size]
|
||||
indices = indices[: self.total_size]
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank:self.total_size:self.num_replicas]
|
||||
indices = indices[self.rank : self.total_size : self.num_replicas]
|
||||
assert len(indices) == self.num_samples
|
||||
|
||||
return iter(indices)
|
||||
|
||||
Reference in New Issue
Block a user