mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Support build batch sampler. (#173)
This commit is contained in:
parent
3d830a28b6
commit
798eab4825
@ -891,6 +891,21 @@ class Runner:
|
|||||||
# if `sampler_cfg` is not a valid type
|
# if `sampler_cfg` is not a valid type
|
||||||
sampler = sampler_cfg
|
sampler = sampler_cfg
|
||||||
|
|
||||||
|
# build batch sampler
|
||||||
|
batch_sampler_cfg = dataloader_cfg.pop('batch_sampler', None)
|
||||||
|
if batch_sampler_cfg is None:
|
||||||
|
batch_sampler = None
|
||||||
|
elif isinstance(batch_sampler_cfg, dict):
|
||||||
|
batch_sampler = DATA_SAMPLERS.build(
|
||||||
|
batch_sampler_cfg,
|
||||||
|
default_args=dict(
|
||||||
|
sampler=sampler,
|
||||||
|
batch_size=dataloader_cfg.pop('batch_size')))
|
||||||
|
else:
|
||||||
|
# fallback to raise error in dataloader
|
||||||
|
# if `batch_sampler_cfg` is not a valid type
|
||||||
|
batch_sampler = batch_sampler_cfg
|
||||||
|
|
||||||
# build dataloader
|
# build dataloader
|
||||||
init_fn: Optional[partial]
|
init_fn: Optional[partial]
|
||||||
if self.seed is not None:
|
if self.seed is not None:
|
||||||
@ -909,8 +924,8 @@ class Runner:
|
|||||||
# in model.
|
# in model.
|
||||||
data_loader = DataLoader(
|
data_loader = DataLoader(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
sampler=sampler,
|
sampler=sampler if batch_sampler is None else None,
|
||||||
batch_sampler=None,
|
batch_sampler=batch_sampler,
|
||||||
collate_fn=pseudo_collate,
|
collate_fn=pseudo_collate,
|
||||||
worker_init_fn=init_fn,
|
worker_init_fn=init_fn,
|
||||||
**dataloader_cfg)
|
**dataloader_cfg)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user