update: valid the condition

pull/140/head
kozistr 2022-01-01 19:14:36 +09:00
parent 4866f4b6ac
commit 96ac034e60
1 changed files with 2 additions and 0 deletions

View File

@ -22,6 +22,8 @@ class RASampler(torch.utils.data.Sampler):
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
if num_repeats < 1:
raise ValueError("num_repeats should be greater than 0")
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank