diff --git a/samplers.py b/samplers.py index d131867..dae91d9 100644 --- a/samplers.py +++ b/samplers.py @@ -22,6 +22,8 @@ class RASampler(torch.utils.data.Sampler): if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") rank = dist.get_rank() + if num_repeats < 1: + raise ValueError("num_repeats should be greater than 0") self.dataset = dataset self.num_replicas = num_replicas self.rank = rank