mirror of https://github.com/open-mmlab/mmcv.git
* fix the bug (#472) * fix the bug (#472) * fix the bug (#472) Co-authored-by: hezijian <hezijian@dm-ai.cn>pull/491/head
parent
51c65c97ec
commit
e7e0c89f5c
|
@ -6,4 +6,9 @@ from .hook import HOOKS, Hook
|
|||
class DistSamplerSeedHook(Hook):
|
||||
|
||||
def before_epoch(self, runner):
|
||||
runner.data_loader.sampler.set_epoch(runner.epoch)
|
||||
if hasattr(runner.data_loader.sampler, 'set_epoch'):
|
||||
# in case the data loader uses `SequentialSampler` in Pytorch
|
||||
runner.data_loader.sampler.set_epoch(runner.epoch)
|
||||
if hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'):
|
||||
# batch sampler in pytorch warps a sampler as its attributes.
|
||||
runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch)
|
||||
|
|
Loading…
Reference in New Issue