mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Enhance] Add extra dataloader settings in configs (#1435)
* [Enhance] Add extra dataloader settings in configs * val default samples * val default samples * del unuse * del unused
This commit is contained in:
parent
91b1bcb9d8
commit
3f797072d8
@ -79,17 +79,25 @@ def train_segmentor(model,
|
|||||||
|
|
||||||
# prepare data loaders
|
# prepare data loaders
|
||||||
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
|
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
|
||||||
data_loaders = [
|
# The default loader config
|
||||||
build_dataloader(
|
loader_cfg = dict(
|
||||||
ds,
|
# cfg.gpus will be ignored if distributed
|
||||||
cfg.data.samples_per_gpu,
|
num_gpus=len(cfg.gpu_ids),
|
||||||
cfg.data.workers_per_gpu,
|
dist=distributed,
|
||||||
# cfg.gpus will be ignored if distributed
|
seed=cfg.seed,
|
||||||
len(cfg.gpu_ids),
|
drop_last=True)
|
||||||
dist=distributed,
|
# The overall dataloader settings
|
||||||
seed=cfg.seed,
|
loader_cfg.update({
|
||||||
drop_last=True) for ds in dataset
|
k: v
|
||||||
]
|
for k, v in cfg.data.items() if k not in [
|
||||||
|
'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
|
||||||
|
'test_dataloader'
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
# The specific dataloader settings
|
||||||
|
train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})}
|
||||||
|
data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
|
||||||
|
|
||||||
# put model on gpus
|
# put model on gpus
|
||||||
if distributed:
|
if distributed:
|
||||||
@ -142,12 +150,14 @@ def train_segmentor(model,
|
|||||||
# register eval hooks
|
# register eval hooks
|
||||||
if validate:
|
if validate:
|
||||||
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
|
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
|
||||||
val_dataloader = build_dataloader(
|
# The specific dataloader settings
|
||||||
val_dataset,
|
val_loader_cfg = {
|
||||||
samples_per_gpu=1,
|
**loader_cfg,
|
||||||
workers_per_gpu=cfg.data.workers_per_gpu,
|
'samples_per_gpu': 1,
|
||||||
dist=distributed,
|
'shuffle': False, # Not shuffle by default
|
||||||
shuffle=False)
|
**cfg.data.get('val_dataloader', {}),
|
||||||
|
}
|
||||||
|
val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
|
||||||
eval_cfg = cfg.get('evaluation', {})
|
eval_cfg = cfg.get('evaluation', {})
|
||||||
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
|
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
|
||||||
eval_hook = DistEvalHook if distributed else EvalHook
|
eval_hook = DistEvalHook if distributed else EvalHook
|
||||||
|
@ -191,12 +191,28 @@ def main():
|
|||||||
# build the dataloader
|
# build the dataloader
|
||||||
# TODO: support multiple images per gpu (only minor changes are needed)
|
# TODO: support multiple images per gpu (only minor changes are needed)
|
||||||
dataset = build_dataset(cfg.data.test)
|
dataset = build_dataset(cfg.data.test)
|
||||||
data_loader = build_dataloader(
|
# The default loader config
|
||||||
dataset,
|
loader_cfg = dict(
|
||||||
samples_per_gpu=1,
|
# cfg.gpus will be ignored if distributed
|
||||||
workers_per_gpu=cfg.data.workers_per_gpu,
|
num_gpus=len(cfg.gpu_ids),
|
||||||
dist=distributed,
|
dist=distributed,
|
||||||
shuffle=False)
|
shuffle=False)
|
||||||
|
# The overall dataloader settings
|
||||||
|
loader_cfg.update({
|
||||||
|
k: v
|
||||||
|
for k, v in cfg.data.items() if k not in [
|
||||||
|
'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
|
||||||
|
'test_dataloader'
|
||||||
|
]
|
||||||
|
})
|
||||||
|
test_loader_cfg = {
|
||||||
|
**loader_cfg,
|
||||||
|
'samples_per_gpu': 1,
|
||||||
|
'shuffle': False, # Not shuffle by default
|
||||||
|
**cfg.data.get('test_dataloader', {})
|
||||||
|
}
|
||||||
|
# build the dataloader
|
||||||
|
data_loader = build_dataloader(dataset, **test_loader_cfg)
|
||||||
|
|
||||||
# build the model and load checkpoint
|
# build the model and load checkpoint
|
||||||
cfg.model.train_cfg = None
|
cfg.model.train_cfg = None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user