mirror of https://github.com/open-mmlab/mmocr.git
[Enhancement] More customizable fields in dataloaders (#933)
* [Enhancement] More customizable fields in val and test dataloaders * update default_loader_cfgpull/1018/head
parent
20fc909fc4
commit
b4a9a87eee
|
@ -30,30 +30,30 @@ def train_detector(model,
|
|||
# prepare data loaders
|
||||
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
|
||||
# step 1: give default values and override (if exist) from cfg.data
|
||||
loader_cfg = {
|
||||
default_loader_cfg = {
|
||||
**dict(
|
||||
num_gpus=len(cfg.gpu_ids),
|
||||
dist=distributed,
|
||||
seed=cfg.get('seed'),
|
||||
drop_last=False,
|
||||
dist=distributed,
|
||||
num_gpus=len(cfg.gpu_ids)),
|
||||
persistent_workers=False),
|
||||
**({} if torch.__version__ != 'parrots' else dict(
|
||||
prefetch_num=2,
|
||||
pin_memory=False,
|
||||
)),
|
||||
**dict((k, cfg.data[k]) for k in [
|
||||
'samples_per_gpu',
|
||||
'workers_per_gpu',
|
||||
'shuffle',
|
||||
'seed',
|
||||
'drop_last',
|
||||
'prefetch_num',
|
||||
'pin_memory',
|
||||
'persistent_workers',
|
||||
] if k in cfg.data)
|
||||
}
|
||||
# update overall dataloader(for train, val and test) setting
|
||||
default_loader_cfg.update({
|
||||
k: v
|
||||
for k, v in cfg.data.items() if k not in [
|
||||
'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
|
||||
'test_dataloader'
|
||||
]
|
||||
})
|
||||
|
||||
# step 2: cfg.data.train_dataloader has highest priority
|
||||
train_loader_cfg = dict(loader_cfg, **cfg.data.get('train_dataloader', {}))
|
||||
train_loader_cfg = dict(default_loader_cfg,
|
||||
**cfg.data.get('train_dataloader', {}))
|
||||
|
||||
data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
|
||||
|
||||
|
@ -135,7 +135,7 @@ def train_detector(model,
|
|||
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
|
||||
|
||||
val_loader_cfg = {
|
||||
**loader_cfg,
|
||||
**default_loader_cfg,
|
||||
**dict(shuffle=False, drop_last=False),
|
||||
**cfg.data.get('val_dataloader', {}),
|
||||
**dict(samples_per_gpu=val_samples_per_gpu)
|
||||
|
|
|
@ -164,22 +164,22 @@ def main():
|
|||
# build the dataloader
|
||||
dataset = build_dataset(cfg.data.test, dict(test_mode=True))
|
||||
# step 1: give default values and override (if exist) from cfg.data
|
||||
loader_cfg = {
|
||||
default_loader_cfg = {
|
||||
**dict(seed=cfg.get('seed'), drop_last=False, dist=distributed),
|
||||
**({} if torch.__version__ != 'parrots' else dict(
|
||||
prefetch_num=2,
|
||||
pin_memory=False,
|
||||
)),
|
||||
**dict((k, cfg.data[k]) for k in [
|
||||
'workers_per_gpu',
|
||||
'seed',
|
||||
'prefetch_num',
|
||||
'pin_memory',
|
||||
'persistent_workers',
|
||||
] if k in cfg.data)
|
||||
))
|
||||
}
|
||||
default_loader_cfg.update({
|
||||
k: v
|
||||
for k, v in cfg.data.items() if k not in [
|
||||
'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
|
||||
'test_dataloader'
|
||||
]
|
||||
})
|
||||
test_loader_cfg = {
|
||||
**loader_cfg,
|
||||
**default_loader_cfg,
|
||||
**dict(shuffle=False, drop_last=False),
|
||||
**cfg.data.get('test_dataloader', {}),
|
||||
**dict(samples_per_gpu=samples_per_gpu)
|
||||
|
|
Loading…
Reference in New Issue