From b4a9a87eee05332e81e7b3eedf987da904ced11a Mon Sep 17 00:00:00 2001
From: Tong Gao
Date: Mon, 18 Apr 2022 09:24:07 +0800
Subject: [PATCH] [Enhancement] More customizable fields in dataloaders (#933)

* [Enhancement] More customizable fields in val and test dataloaders

* update default_loader_cfg
---
 mmocr/apis/train.py | 30 +++++++++++++++---------------
 tools/test.py       | 20 ++++++++++----------
 2 files changed, 25 insertions(+), 25 deletions(-)

diff --git a/mmocr/apis/train.py b/mmocr/apis/train.py
index 89ba3be6..e9178e00 100644
--- a/mmocr/apis/train.py
+++ b/mmocr/apis/train.py
@@ -30,30 +30,30 @@ def train_detector(model,
     # prepare data loaders
     dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
     # step 1: give default values and override (if exist) from cfg.data
-    loader_cfg = {
+    default_loader_cfg = {
         **dict(
+            num_gpus=len(cfg.gpu_ids),
+            dist=distributed,
             seed=cfg.get('seed'),
             drop_last=False,
-            dist=distributed,
-            num_gpus=len(cfg.gpu_ids)),
+            persistent_workers=False),
         **({} if torch.__version__ != 'parrots' else dict(
             prefetch_num=2,
             pin_memory=False,
         )),
-        **dict((k, cfg.data[k]) for k in [
-            'samples_per_gpu',
-            'workers_per_gpu',
-            'shuffle',
-            'seed',
-            'drop_last',
-            'prefetch_num',
-            'pin_memory',
-            'persistent_workers',
-        ] if k in cfg.data)
     }
+    # update overall dataloader(for train, val and test) setting
+    default_loader_cfg.update({
+        k: v
+        for k, v in cfg.data.items() if k not in [
+            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
+            'test_dataloader'
+        ]
+    })
 
     # step 2: cfg.data.train_dataloader has highest priority
-    train_loader_cfg = dict(loader_cfg, **cfg.data.get('train_dataloader', {}))
+    train_loader_cfg = dict(default_loader_cfg,
+                            **cfg.data.get('train_dataloader', {}))
 
     data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
 
@@ -135,7 +135,7 @@ def train_detector(model,
         val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
 
         val_loader_cfg = {
-            **loader_cfg,
+            **default_loader_cfg,
             **dict(shuffle=False, drop_last=False),
             **cfg.data.get('val_dataloader', {}),
             **dict(samples_per_gpu=val_samples_per_gpu)
diff --git a/tools/test.py b/tools/test.py
index 4df1b315..2774bf4c 100755
--- a/tools/test.py
+++ b/tools/test.py
@@ -164,22 +164,22 @@ def main():
     # build the dataloader
     dataset = build_dataset(cfg.data.test, dict(test_mode=True))
     # step 1: give default values and override (if exist) from cfg.data
-    loader_cfg = {
+    default_loader_cfg = {
         **dict(seed=cfg.get('seed'), drop_last=False, dist=distributed),
         **({} if torch.__version__ != 'parrots' else dict(
             prefetch_num=2,
             pin_memory=False,
-        )),
-        **dict((k, cfg.data[k]) for k in [
-            'workers_per_gpu',
-            'seed',
-            'prefetch_num',
-            'pin_memory',
-            'persistent_workers',
-        ] if k in cfg.data)
+        ))
     }
+    default_loader_cfg.update({
+        k: v
+        for k, v in cfg.data.items() if k not in [
+            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
+            'test_dataloader'
+        ]
+    })
     test_loader_cfg = {
-        **loader_cfg,
+        **default_loader_cfg,
         **dict(shuffle=False, drop_last=False),
         **cfg.data.get('test_dataloader', {}),
         **dict(samples_per_gpu=samples_per_gpu)
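
Usage note (illustration only, not part of the commit): after this change, any key placed directly under `data` other than the dataset entries (`train`, `val`, `test`) and the per-split `train_dataloader`/`val_dataloader`/`test_dataloader` dicts is forwarded to all three dataloaders as a shared default, and the per-split dicts still take highest priority. A minimal config sketch, assuming placeholder dataset configs:

# Config sketch (assumed, for illustration): shared dataloader fields live at
# the top level of `data`; per-split *_dataloader dicts override them.
data = dict(
    samples_per_gpu=8,        # shared default batch size for train/val/test
    workers_per_gpu=4,        # shared default number of workers
    persistent_workers=True,  # extra keys like this are now forwarded too
    val_dataloader=dict(samples_per_gpu=1),   # per-split override (highest priority)
    test_dataloader=dict(samples_per_gpu=1),  # per-split override (highest priority)
    train=dict(type='OCRDataset'),  # placeholder dataset config
    val=dict(type='OCRDataset'),    # placeholder dataset config
    test=dict(type='OCRDataset'))   # placeholder dataset config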