[Feature] Support build batch sampler. (#173)
parent
3d830a28b6
commit
798eab4825
|
@ -891,6 +891,21 @@ class Runner:
|
|||
# if `sampler_cfg` is not a valid type
|
||||
sampler = sampler_cfg
|
||||
|
||||
# build batch sampler
|
||||
batch_sampler_cfg = dataloader_cfg.pop('batch_sampler', None)
|
||||
if batch_sampler_cfg is None:
|
||||
batch_sampler = None
|
||||
elif isinstance(batch_sampler_cfg, dict):
|
||||
batch_sampler = DATA_SAMPLERS.build(
|
||||
batch_sampler_cfg,
|
||||
default_args=dict(
|
||||
sampler=sampler,
|
||||
batch_size=dataloader_cfg.pop('batch_size')))
|
||||
else:
|
||||
# fallback to raise error in dataloader
|
||||
# if `batch_sampler_cfg` is not a valid type
|
||||
batch_sampler = batch_sampler_cfg
|
||||
|
||||
# build dataloader
|
||||
init_fn: Optional[partial]
|
||||
if self.seed is not None:
|
||||
|
@ -909,8 +924,8 @@ class Runner:
|
|||
# in model.
|
||||
data_loader = DataLoader(
|
||||
dataset=dataset,
|
||||
sampler=sampler,
|
||||
batch_sampler=None,
|
||||
sampler=sampler if batch_sampler is None else None,
|
||||
batch_sampler=batch_sampler,
|
||||
collate_fn=pseudo_collate,
|
||||
worker_init_fn=init_fn,
|
||||
**dataloader_cfg)
|
||||
|
|
Loading…
Reference in New Issue