[Feature] Support build batch sampler. (#173)

pull/174/head
RangiLyu 2022-04-12 09:54:30 +08:00 committed by GitHub
parent 3d830a28b6
commit 798eab4825
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 17 additions and 2 deletions

View File

@ -891,6 +891,21 @@ class Runner:
# if `sampler_cfg` is not a valid type
sampler = sampler_cfg
# build batch sampler
batch_sampler_cfg = dataloader_cfg.pop('batch_sampler', None)
if batch_sampler_cfg is None:
batch_sampler = None
elif isinstance(batch_sampler_cfg, dict):
batch_sampler = DATA_SAMPLERS.build(
batch_sampler_cfg,
default_args=dict(
sampler=sampler,
batch_size=dataloader_cfg.pop('batch_size')))
else:
# fallback to raise error in dataloader
# if `batch_sampler_cfg` is not a valid type
batch_sampler = batch_sampler_cfg
# build dataloader
init_fn: Optional[partial]
if self.seed is not None:
@ -909,8 +924,8 @@ class Runner:
# in model.
data_loader = DataLoader(
dataset=dataset,
sampler=sampler,
batch_sampler=None,
sampler=sampler if batch_sampler is None else None,
batch_sampler=batch_sampler,
collate_fn=pseudo_collate,
worker_init_fn=init_fn,
**dataloader_cfg)