fix the bug (#472) (#484)

* fix the bug (#472)

* fix the bug (#472)

* fix the bug (#472)

Co-authored-by: hezijian <hezijian@dm-ai.cn>
pull/491/head
Zijian He 2020-08-13 19:30:28 +08:00 committed by GitHub
parent 51c65c97ec
commit e7e0c89f5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 1 deletions

View File

@ -6,4 +6,9 @@ from .hook import HOOKS, Hook
class DistSamplerSeedHook(Hook):
def before_epoch(self, runner):
runner.data_loader.sampler.set_epoch(runner.epoch)
if hasattr(runner.data_loader.sampler, 'set_epoch'):
# in case the data loader uses `SequentialSampler` in Pytorch
runner.data_loader.sampler.set_epoch(runner.epoch)
if hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'):
# batch sampler in pytorch warps a sampler as its attributes.
runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch)