Merge pull request #140 from kozistr/feature/tune-rasampler

Use torch.repeat_interleave() to generate the repeated indices faster
This commit is contained in:
Hugo Touvron 2022-01-03 09:06:01 +01:00 committed by GitHub
commit 15fc7c4c96
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -13,7 +13,7 @@ class RASampler(torch.utils.data.Sampler):
Heavily based on torch.utils.data.DistributedSampler
"""
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
@ -22,11 +22,14 @@ class RASampler(torch.utils.data.Sampler):
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
if num_repeats < 1:
raise ValueError("num_repeats should be greater than 0")
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.num_repeats = num_repeats
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
# self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
@ -36,14 +39,16 @@ class RASampler(torch.utils.data.Sampler):
if self.shuffle:
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist()
g.manual_seed(self.epoch)
indices = torch.randperm(len(self.dataset), generator=g)
else:
indices = list(range(len(self.dataset)))
indices = torch.arange(start=0, end=len(self.dataset))
# add extra samples to make it evenly divisible
indices = [ele for ele in indices for i in range(3)]
indices += indices[:(self.total_size - len(indices))]
indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0)
padding_size: int = self.total_size - len(indices)
if padding_size > 0:
indices = torch.cat([indices, indices[:padding_size]], dim=0)
assert len(indices) == self.total_size
# subsample