diff --git a/mmcv/runner/hooks/sampler_seed.py b/mmcv/runner/hooks/sampler_seed.py index e57f1e9e8..93793f7c2 100644 --- a/mmcv/runner/hooks/sampler_seed.py +++ b/mmcv/runner/hooks/sampler_seed.py @@ -6,4 +6,9 @@ from .hook import HOOKS, Hook class DistSamplerSeedHook(Hook): def before_epoch(self, runner): - runner.data_loader.sampler.set_epoch(runner.epoch) + if hasattr(runner.data_loader.sampler, 'set_epoch'): + # in case the data loader uses `SequentialSampler` in Pytorch + runner.data_loader.sampler.set_epoch(runner.epoch) + if hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'): + # batch sampler in pytorch warps a sampler as its attributes. + runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch)