mirror of https://github.com/facebookresearch/deit.git
synced 2025-06-03 14:52:20 +08:00
Merge pull request #140 from kozistr/feature/tune-rasampler
Use torch.repeat_interleave() to generate the repeated indices faster
This commit is contained in: commit 15fc7c4c96
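For context, a minimal standalone sketch of what this PR changes (illustrative only, not part of samplers.py; the tensor and sizes are made up): the old code repeated each shuffled index with a Python list comprehension, while torch.repeat_interleave() builds the same repeated sequence in a single vectorized call on the tensor:

import torch

# Hypothetical check: both forms repeat every index num_repeats
# times in place, preserving the original order.
num_repeats = 3
indices = torch.randperm(10)

# Old approach: Python-level loop over the permutation.
repeated_list = [ele for ele in indices.tolist() for _ in range(num_repeats)]

# New approach: one vectorized tensor op.
repeated_tensor = torch.repeat_interleave(indices, repeats=num_repeats, dim=0)

assert repeated_tensor.tolist() == repeated_list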
samplers.py (19 lines changed)
@@ -13,7 +13,7 @@ class RASampler(torch.utils.data.Sampler):
     Heavily based on torch.utils.data.DistributedSampler
     """
 
-    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
+    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3):
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
@@ -22,11 +22,14 @@ class RASampler(torch.utils.data.Sampler):
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
             rank = dist.get_rank()
+        if num_repeats < 1:
+            raise ValueError("num_repeats should be greater than 0")
         self.dataset = dataset
         self.num_replicas = num_replicas
         self.rank = rank
+        self.num_repeats = num_repeats
         self.epoch = 0
-        self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
+        self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas))
         self.total_size = self.num_samples * self.num_replicas
         # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
         self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
@@ -36,14 +39,16 @@ class RASampler(torch.utils.data.Sampler):
         if self.shuffle:
             # deterministically shuffle based on epoch
             g = torch.Generator()
-            g.manual_seed(self.epoch)
-            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+            g.manual_seed(self.epoch)
+            indices = torch.randperm(len(self.dataset), generator=g)
         else:
-            indices = list(range(len(self.dataset)))
+            indices = torch.arange(start=0, end=len(self.dataset))
 
         # add extra samples to make it evenly divisible
-        indices = [ele for ele in indices for i in range(3)]
-        indices += indices[:(self.total_size - len(indices))]
+        indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0)
+        padding_size: int = self.total_size - len(indices)
+        if padding_size > 0:
+            indices = torch.cat([indices, indices[:padding_size]], dim=0)
         assert len(indices) == self.total_size
 
         # subsample
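To see what the new padding branch computes, a small worked sketch (the sizes are illustrative, only the formulas come from the diff): the repeated index tensor can fall short of total_size = num_samples * num_replicas, so the tail is padded by wrapping around to the front:

import math
import torch

# Illustrative sizes, not from the commit: 10 samples, 3 repeats, 4 replicas.
dataset_len, num_repeats, num_replicas = 10, 3, 4

num_samples = int(math.ceil(dataset_len * num_repeats / num_replicas))  # 8
total_size = num_samples * num_replicas                                 # 32

indices = torch.arange(start=0, end=dataset_len)
indices = torch.repeat_interleave(indices, repeats=num_repeats, dim=0)  # 30 entries

padding_size = total_size - len(indices)  # 2
if padding_size > 0:
    # Wrap around, exactly as the diff does with torch.cat.
    indices = torch.cat([indices, indices[:padding_size]], dim=0)
assert len(indices) == total_size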