diff --git a/mmseg/apis/train.py b/mmseg/apis/train.py
index be84a89e6..bda7b213f 100644
--- a/mmseg/apis/train.py
+++ b/mmseg/apis/train.py
@@ -79,17 +79,25 @@ def train_segmentor(model,
 
     # prepare data loaders
     dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
-    data_loaders = [
-        build_dataloader(
-            ds,
-            cfg.data.samples_per_gpu,
-            cfg.data.workers_per_gpu,
-            # cfg.gpus will be ignored if distributed
-            len(cfg.gpu_ids),
-            dist=distributed,
-            seed=cfg.seed,
-            drop_last=True) for ds in dataset
-    ]
+    # The default loader config
+    loader_cfg = dict(
+        # cfg.gpus will be ignored if distributed
+        num_gpus=len(cfg.gpu_ids),
+        dist=distributed,
+        seed=cfg.seed,
+        drop_last=True)
+    # The overall dataloader settings
+    loader_cfg.update({
+        k: v
+        for k, v in cfg.data.items() if k not in [
+            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
+            'test_dataloader'
+        ]
+    })
+
+    # The specific dataloader settings
+    train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})}
+    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
 
     # put model on gpus
     if distributed:
@@ -142,12 +150,14 @@ def train_segmentor(model,
     # register eval hooks
     if validate:
         val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
-        val_dataloader = build_dataloader(
-            val_dataset,
-            samples_per_gpu=1,
-            workers_per_gpu=cfg.data.workers_per_gpu,
-            dist=distributed,
-            shuffle=False)
+        # The specific dataloader settings
+        val_loader_cfg = {
+            **loader_cfg,
+            'samples_per_gpu': 1,
+            'shuffle': False,  # Not shuffle by default
+            **cfg.data.get('val_dataloader', {}),
+        }
+        val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
         eval_cfg = cfg.get('evaluation', {})
         eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
         eval_hook = DistEvalHook if distributed else EvalHook
diff --git a/tools/test.py b/tools/test.py
index d5dc0d5f6..12892ec9b 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -191,12 +191,28 @@ def main():
     # build the dataloader
     # TODO: support multiple images per gpu (only minor changes are needed)
     dataset = build_dataset(cfg.data.test)
-    data_loader = build_dataloader(
-        dataset,
-        samples_per_gpu=1,
-        workers_per_gpu=cfg.data.workers_per_gpu,
+    # The default loader config
+    loader_cfg = dict(
+        # cfg.gpus will be ignored if distributed
+        num_gpus=len(cfg.gpu_ids),
         dist=distributed,
         shuffle=False)
+    # The overall dataloader settings
+    loader_cfg.update({
+        k: v
+        for k, v in cfg.data.items() if k not in [
+            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
+            'test_dataloader'
+        ]
+    })
+    test_loader_cfg = {
+        **loader_cfg,
+        'samples_per_gpu': 1,
+        'shuffle': False,  # Not shuffle by default
+        **cfg.data.get('test_dataloader', {})
+    }
+    # build the dataloader
+    data_loader = build_dataloader(dataset, **test_loader_cfg)
 
     # build the model and load checkpoint
     cfg.model.train_cfg = None