mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Enhance] Add extra dataloader settings in configs (#1435)
* [Enhance] Add extra dataloader settings in configs * val default samples * val default samples * del unuse * del unused
This commit is contained in:
parent
91b1bcb9d8
commit
3f797072d8
@ -79,17 +79,25 @@ def train_segmentor(model,
|
||||
|
||||
# prepare data loaders
|
||||
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
|
||||
data_loaders = [
|
||||
build_dataloader(
|
||||
ds,
|
||||
cfg.data.samples_per_gpu,
|
||||
cfg.data.workers_per_gpu,
|
||||
# cfg.gpus will be ignored if distributed
|
||||
len(cfg.gpu_ids),
|
||||
dist=distributed,
|
||||
seed=cfg.seed,
|
||||
drop_last=True) for ds in dataset
|
||||
]
|
||||
# The default loader config
|
||||
loader_cfg = dict(
|
||||
# cfg.gpus will be ignored if distributed
|
||||
num_gpus=len(cfg.gpu_ids),
|
||||
dist=distributed,
|
||||
seed=cfg.seed,
|
||||
drop_last=True)
|
||||
# The overall dataloader settings
|
||||
loader_cfg.update({
|
||||
k: v
|
||||
for k, v in cfg.data.items() if k not in [
|
||||
'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
|
||||
'test_dataloader'
|
||||
]
|
||||
})
|
||||
|
||||
# The specific dataloader settings
|
||||
train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})}
|
||||
data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
|
||||
|
||||
# put model on gpus
|
||||
if distributed:
|
||||
@ -142,12 +150,14 @@ def train_segmentor(model,
|
||||
# register eval hooks
|
||||
if validate:
|
||||
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
|
||||
val_dataloader = build_dataloader(
|
||||
val_dataset,
|
||||
samples_per_gpu=1,
|
||||
workers_per_gpu=cfg.data.workers_per_gpu,
|
||||
dist=distributed,
|
||||
shuffle=False)
|
||||
# The specific dataloader settings
|
||||
val_loader_cfg = {
|
||||
**loader_cfg,
|
||||
'samples_per_gpu': 1,
|
||||
'shuffle': False, # Not shuffle by default
|
||||
**cfg.data.get('val_dataloader', {}),
|
||||
}
|
||||
val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
|
||||
eval_cfg = cfg.get('evaluation', {})
|
||||
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
|
||||
eval_hook = DistEvalHook if distributed else EvalHook
|
||||
|
@ -191,12 +191,28 @@ def main():
|
||||
# build the dataloader
|
||||
# TODO: support multiple images per gpu (only minor changes are needed)
|
||||
dataset = build_dataset(cfg.data.test)
|
||||
data_loader = build_dataloader(
|
||||
dataset,
|
||||
samples_per_gpu=1,
|
||||
workers_per_gpu=cfg.data.workers_per_gpu,
|
||||
# The default loader config
|
||||
loader_cfg = dict(
|
||||
# cfg.gpus will be ignored if distributed
|
||||
num_gpus=len(cfg.gpu_ids),
|
||||
dist=distributed,
|
||||
shuffle=False)
|
||||
# The overall dataloader settings
|
||||
loader_cfg.update({
|
||||
k: v
|
||||
for k, v in cfg.data.items() if k not in [
|
||||
'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
|
||||
'test_dataloader'
|
||||
]
|
||||
})
|
||||
test_loader_cfg = {
|
||||
**loader_cfg,
|
||||
'samples_per_gpu': 1,
|
||||
'shuffle': False, # Not shuffle by default
|
||||
**cfg.data.get('test_dataloader', {})
|
||||
}
|
||||
# build the dataloader
|
||||
data_loader = build_dataloader(dataset, **test_loader_cfg)
|
||||
|
||||
# build the model and load checkpoint
|
||||
cfg.model.train_cfg = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user