update: concatenate paddings when padding_size is over 0

pull/140/head
kozistr 2022-01-02 15:39:35 +09:00
parent 96ac034e60
commit 8f69a7ee85
1 changed files with 4 additions and 2 deletions

View File

@ -45,8 +45,10 @@ class RASampler(torch.utils.data.Sampler):
indices = torch.arange(start=0, end=len(self.dataset)) indices = torch.arange(start=0, end=len(self.dataset))
# add extra samples to make it evenly divisible # add extra samples to make it evenly divisible
indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist() indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0)
indices += indices[:(self.total_size - len(indices))] padding_size: int = self.total_size - len(indices)
if padding_size > 0:
indices += torch.cat([indices, indices[:padding_size]], dim=0)
assert len(indices) == self.total_size assert len(indices) == self.total_size
# subsample # subsample