mirror of https://github.com/YifanXu74/MQ-Det.git
32 lines
1.2 KiB
Python
32 lines
1.2 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
|
from torch.utils.data.sampler import BatchSampler
|
|
|
|
|
|
class IterationBasedBatchSampler(BatchSampler):
    """
    Wraps a BatchSampler, resampling from it until a specified number of
    iterations have been sampled.

    The wrapped ``batch_sampler`` is iterated repeatedly (one "pass" at a
    time) until the iteration counter reaches ``num_iterations``. Starting
    from ``start_iter``, exactly ``num_iterations - start_iter`` batches are
    yielded (none if ``start_iter >= num_iterations``).

    Args:
        batch_sampler: a re-iterable of batches, typically a
            ``torch.utils.data.BatchSampler``. If it exposes a ``sampler``
            attribute that has a ``set_epoch`` method, that method is called
            at the start of every pass.
        num_iterations: iteration count at which sampling stops.
        start_iter: iteration count to resume from (default 0).

    Raises:
        RuntimeError: from ``__iter__`` if a full pass over ``batch_sampler``
            produces no batches, since ``num_iterations`` could then never
            be reached.
    """

    def __init__(self, batch_sampler, num_iterations, start_iter=0):
        # BatchSampler.__init__ is intentionally not called: all sampling is
        # delegated to the wrapped batch_sampler.
        self.batch_sampler = batch_sampler
        self.num_iterations = num_iterations
        self.start_iter = start_iter

    def __iter__(self):
        iteration = self.start_iter
        # getattr with a default so plain iterables of batches (no
        # ``sampler`` attribute) can be wrapped too.
        sampler = getattr(self.batch_sampler, "sampler", None)
        while iteration < self.num_iterations:
            # If the underlying sampler has a set_epoch method, like
            # DistributedSampler, used for making each process see a
            # different split of the dataset, then set it. The current
            # iteration count serves as the epoch value for this pass.
            if hasattr(sampler, "set_epoch"):
                sampler.set_epoch(iteration)
            produced_any = False
            for batch in self.batch_sampler:
                produced_any = True
                iteration += 1
                yield batch
                # Stop right after the last requested batch instead of
                # drawing (and discarding) one more from the sampler.
                if iteration >= self.num_iterations:
                    return
            if not produced_any:
                # An empty pass would otherwise make this loop spin forever.
                raise RuntimeError(
                    "batch_sampler yielded no batches; cannot reach "
                    "num_iterations={}".format(self.num_iterations)
                )

    def __len__(self):
        # Total iteration count, NOT the number of batches still to be
        # yielded when start_iter > 0.
        # NOTE(review): callers may use this as the overall max iteration;
        # confirm before changing it to num_iterations - start_iter.
        return self.num_iterations
|