more code refactor

Author: Blaise
Date: 2024-01-16 17:14:18 +01:00
parent 0d92575115
commit 0d3d47f3c3
44 changed files with 4516 additions and 2623 deletions


@@ -16,7 +16,7 @@ __all__ = [
     "DistributedBucketSampler",
 ]
 
-T_co = TypeVar('T_co', covariant=True)
+T_co = TypeVar("T_co", covariant=True)
 
 
 class DistributedBucketSampler(Sampler[T_co]):
@@ -28,28 +28,30 @@ class DistributedBucketSampler(Sampler[T_co]):
     sort batches
     """
 
-    def __init__(self,
-                 dataset: Dataset,
-                 num_replicas: Optional[int]=None,
-                 rank: Optional[int]=None,
-                 shuffle: bool=True,
-                 seed: int=0,
-                 drop_last: bool=False,
-                 batch_size: int=32) -> None:
+    def __init__(
+        self,
+        dataset: Dataset,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        shuffle: bool = True,
+        seed: int = 0,
+        drop_last: bool = False,
+        batch_size: int = 32,
+    ) -> None:
         if num_replicas is None:
             if not dist.is_available():
-                raise RuntimeError(
-                    "Requires distributed package to be available")
+                raise RuntimeError("Requires distributed package to be available")
             num_replicas = dist.get_world_size()
         if rank is None:
             if not dist.is_available():
-                raise RuntimeError(
-                    "Requires distributed package to be available")
+                raise RuntimeError("Requires distributed package to be available")
             rank = dist.get_rank()
             torch.cuda.set_device(rank)
         if rank >= num_replicas or rank < 0:
-            raise ValueError("Invalid rank {}, rank should be in the interval"
-                             " [0, {}]".format(rank, num_replicas - 1))
+            raise ValueError(
+                "Invalid rank {}, rank should be in the interval"
+                " [0, {}]".format(rank, num_replicas - 1)
+            )
         self.dataset = dataset
         self.num_replicas = num_replicas
         self.rank = rank
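
Note: the defaulting above follows the stock torch.distributed pattern; when num_replicas or rank is omitted, both are read from the initialized process group. A minimal sketch of that resolution (a process group must already be set up, e.g. via torchrun, otherwise the sampler raises RuntimeError):

import torch.distributed as dist

# Sketch only: these calls succeed only after init_process_group() has run
# (e.g. when launched via torchrun); the sampler errors out otherwise.
if dist.is_available() and dist.is_initialized():
    num_replicas = dist.get_world_size()  # total participating ranks
    rank = dist.get_rank()                # this process, in [0, num_replicas - 1]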
@@ -57,19 +59,20 @@ class DistributedBucketSampler(Sampler[T_co]):
         self.drop_last = drop_last
         # If the dataset length is evenly divisible by # of replicas, then there
         # is no need to drop any data, since the dataset will be split equally.
-        if self.drop_last and len(
-                self.
-                dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
+        if (
+            self.drop_last and len(self.dataset) % self.num_replicas != 0
+        ):  # type: ignore[arg-type]
             # Split to nearest available length that is evenly divisible.
             # This is to ensure each rank receives the same amount of data when
             # using this Sampler.
             self.num_samples = math.ceil(
-                (len(self.dataset) - self.num_replicas) /
-                self.num_replicas  # type: ignore[arg-type]
+                (len(self.dataset) - self.num_replicas)
+                / self.num_replicas  # type: ignore[arg-type]
             )
         else:
             self.num_samples = math.ceil(
-                len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
+                len(self.dataset) / self.num_replicas
+            )  # type: ignore[arg-type]
         self.total_size = self.num_samples * self.num_replicas
         self.shuffle = shuffle
         self.seed = seed
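
For reference, a worked example of the num_samples arithmetic being reformatted above, using illustrative values (10 samples across 3 replicas, not taken from this repo):

import math

dataset_len, num_replicas = 10, 3  # illustrative values only

# drop_last=True: shrink to the nearest length divisible by num_replicas,
# so every rank gets the same count and the remainder is dropped.
per_rank_dropped = math.ceil((dataset_len - num_replicas) / num_replicas)  # 3

# drop_last=False: round up instead; __iter__ later pads the index list to match.
per_rank_padded = math.ceil(dataset_len / num_replicas)  # 4

assert per_rank_dropped * num_replicas <= dataset_len <= per_rank_padded * num_replicas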
@@ -84,7 +87,7 @@ class DistributedBucketSampler(Sampler[T_co]):
         id_with_lengths.sort(key=lambda x: x[1])
         return id_with_lengths
 
-    def make_buckets(self, bucket_width: float=2.0):
+    def make_buckets(self, bucket_width: float = 2.0):
         buckets = []
         cur = []
         max_sec = bucket_width
@@ -114,8 +117,8 @@ class DistributedBucketSampler(Sampler[T_co]):
             shuffled_bucket = list(itertools.chain(*shuffled_bucket))
             n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
             batches = [
-                shuffled_bucket[b * grouped_batch_size:(b + 1) *
-                                grouped_batch_size] for b in range(n_batch)
+                shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
+                for b in range(n_batch)
             ]
             shuffle(batches)
             indices = list(itertools.chain(*batches))
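
The reflowed list comprehension above chunks the flattened bucket order into groups of grouped_batch_size, leaving a possibly short tail group. A toy run (the value of grouped_batch_size is made up here; in the sampler it is derived from batch_size and the replica count, which is not shown in this hunk):

import math

shuffled_bucket = list(range(10))  # stand-in for the flattened, shuffled indices
grouped_batch_size = 4             # made-up value for illustration

n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
batches = [
    shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
    for b in range(n_batch)
]
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] -- tail batch may be short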
@@ -129,15 +132,16 @@ class DistributedBucketSampler(Sampler[T_co]):
             if padding_size <= len(indices):
                 indices += indices[:padding_size]
             else:
-                indices += (indices * math.ceil(padding_size /
-                                                len(indices)))[:padding_size]
+                indices += (indices * math.ceil(padding_size / len(indices)))[
+                    :padding_size
+                ]
         else:
             # remove tail of data to make it evenly divisible.
-            indices = indices[:self.total_size]
+            indices = indices[: self.total_size]
         assert len(indices) == self.total_size
 
         # subsample
-        indices = indices[self.rank:self.total_size:self.num_replicas]
+        indices = indices[self.rank : self.total_size : self.num_replicas]
         assert len(indices) == self.num_samples
 
         return iter(indices)
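
A hedged usage sketch of the sampler as a whole; train_dataset and num_epochs are placeholders, and the set_epoch call is assumed to behave as in torch's DistributedSampler (it does not appear in this diff):

from torch.utils.data import DataLoader

# Placeholder dataset: anything the sampler can read sample lengths from
# when building its duration buckets.
sampler = DistributedBucketSampler(train_dataset, shuffle=True, batch_size=32)
loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # assumed: reseeds the per-epoch shuffle
    for batch in loader:
        ...  # training step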