mirror of https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00

[Refactor] Refactor DatasetWrapper

This commit is contained in:
parent eef12a064b
commit f2bac79f03
@@ -26,14 +26,17 @@ train_dataloader = dict(
     num_workers=4,
     persistent_workers=True,
     sampler=dict(type='InfiniteSampler', shuffle=True),
-    type='RepeatDataset',
-    times=40000,
-    dataset=dict(
-        type=dataset_type,
-        data_root=data_root,
-        data_prefix=dict(
-            img_path='images/training', seg_map_path='annotations/training'),
-        pipeline=train_pipeline))
+    dataset=dict(
+        type='RepeatDataset',
+        times=40000,
+        dataset=dict(
+            type=dataset_type,
+            data_root=data_root,
+            data_prefix=dict(
+                img_path='images/training',
+                seg_map_path='annotations/training'),
+            pipeline=train_pipeline)))
+
 val_dataloader = dict(
     batch_size=1,
     num_workers=4,
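Note: the point of the hunk above is that the wrapper is no longer spelled out at the dataloader level; `RepeatDataset` now sits inside the `dataset` field, so the registry can build it recursively. A minimal sketch of what that nesting resolves to at runtime (assumed MMEngine registry behavior, hypothetical local path, and it needs the ADE20K data on disk):

    from mmengine.dataset import RepeatDataset
    from mmseg.registry import DATASETS
    from mmseg.utils import register_all_modules

    register_all_modules()  # register mmseg datasets/transforms (same helper used in the tests below)
    cfg = dict(
        type='RepeatDataset',
        times=40000,
        dataset=dict(
            type='ADE20KDataset',
            data_root='data/ade/ADEChallengeData2016',  # hypothetical path
            data_prefix=dict(
                img_path='images/training', seg_map_path='annotations/training'),
            pipeline=[]))
    dataset = DATASETS.build(cfg)  # wrapper and inner dataset built recursively
    assert isinstance(dataset, RepeatDataset)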
@@ -25,14 +25,16 @@ train_dataloader = dict(
     num_workers=4,
     persistent_workers=True,
     sampler=dict(type='InfiniteSampler', shuffle=True),
-    type='RepeatDataset',
-    times=40000,
-    dataset=dict(
-        type=dataset_type,
-        data_root=data_root,
-        data_prefix=dict(
-            img_path='images/training', seg_map_path='annotations/training'),
-        pipeline=train_pipeline))
+    dataset=dict(
+        type='RepeatDataset',
+        times=40000,
+        dataset=dict(
+            type=dataset_type,
+            data_root=data_root,
+            data_prefix=dict(
+                img_path='images/training',
+                seg_map_path='annotations/training'),
+            pipeline=train_pipeline)))
 val_dataloader = dict(
     batch_size=1,
     num_workers=4,
@@ -25,14 +25,16 @@ train_dataloader = dict(
     num_workers=4,
     persistent_workers=True,
     sampler=dict(type='InfiniteSampler', shuffle=True),
-    type='RepeatDataset',
-    times=40000,
-    dataset=dict(
-        type=dataset_type,
-        data_root=data_root,
-        data_prefix=dict(
-            img_path='images/training', seg_map_path='annotations/training'),
-        pipeline=train_pipeline))
+    dataset=dict(
+        type='RepeatDataset',
+        times=40000,
+        dataset=dict(
+            type=dataset_type,
+            data_root=data_root,
+            data_prefix=dict(
+                img_path='images/training',
+                seg_map_path='annotations/training'),
+            pipeline=train_pipeline)))
 val_dataloader = dict(
     batch_size=1,
     num_workers=4,
@@ -1,9 +1,62 @@
-_base_ = './pascal_voc12.py'
 # dataset settings
+dataset_type = 'PascalVOCDataset'
+data_root = 'data/VOCdevkit/VOC2012'
+crop_size = (512, 512)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations'),
+    dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0)),
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PhotoMetricDistortion'),
+    dict(type='Pad', size=crop_size),
+    dict(type='PackSegInputs')
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='Resize', scale=(2048, 512), keep_ratio=True),
+    # add loading annotation after ``Resize`` because ground truth
+    # does not need to do resize data transform
+    dict(type='LoadAnnotations'),
+    dict(type='PackSegInputs')
+]
+
+dataset_train = dict(
+    type=dataset_type,
+    data_root=data_root,
+    data_prefix=dict(img_path='JPEGImages', seg_map_path='SegmentationClass'),
+    ann_file='ImageSets/Segmentation/train.txt',
+    pipeline=train_pipeline)
+
+dataset_aug = dict(
+    type=dataset_type,
+    data_root=data_root,
+    data_prefix=dict(
+        img_path='JPEGImages', seg_map_path='SegmentationClassAug'),
+    ann_file='ImageSets/Segmentation/aug.txt',
+    pipeline=train_pipeline)
+
 train_dataloader = dict(
+    batch_size=4,
+    num_workers=4,
+    persistent_workers=True,
+    sampler=dict(type='InfiniteSampler', shuffle=True),
+    dataset=dict(type='ConcatDataset', datasets=[dataset_train, dataset_aug]))
+
+val_dataloader = dict(
+    batch_size=1,
+    num_workers=4,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=False),
     dataset=dict(
-        ann_dir=['SegmentationClass', 'SegmentationClassAug'],
-        ann_file=[
-            'ImageSets/Segmentation/train.txt',
-            'ImageSets/Segmentation/aug.txt'
-        ]))
+        type=dataset_type,
+        data_root=data_root,
+        data_prefix=dict(
+            img_path='JPEGImages', seg_map_path='SegmentationClass'),
+        ann_file='ImageSets/Segmentation/val.txt',
+        pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
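Note: in the rewritten config above, `dataset_train` and `dataset_aug` are plain dicts; the `type='ConcatDataset'` entry in `train_dataloader` is what builds them through the registry. A minimal sketch of the equivalent explicit construction (assumed MMEngine behavior; requires register_all_modules() and the VOC data under data_root):

    from mmengine.dataset import ConcatDataset

    # dataset_train / dataset_aug are the config dicts defined above
    concat = ConcatDataset(datasets=[dataset_train, dataset_aug])
    # each sub-config is built via the registry and the lengths are summed
    assert len(concat) == len(concat.datasets[0]) + len(concat.datasets[1])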
@@ -25,14 +25,16 @@ train_dataloader = dict(
     num_workers=4,
     persistent_workers=True,
     sampler=dict(type='InfiniteSampler', shuffle=True),
-    type='RepeatDataset',
-    times=40000,
-    dataset=dict(
-        type=dataset_type,
-        data_root=data_root,
-        data_prefix=dict(
-            img_path='images/training', seg_map_path='annotations/training'),
-        pipeline=train_pipeline))
+    dataset=dict(
+        type='RepeatDataset',
+        times=40000,
+        dataset=dict(
+            type=dataset_type,
+            data_root=data_root,
+            data_prefix=dict(
+                img_path='images/training',
+                seg_map_path='annotations/training'),
+            pipeline=train_pipeline)))
 val_dataloader = dict(
     batch_size=1,
     num_workers=4,
@@ -1,13 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.dataset import ConcatDataset, RepeatDataset
+
+from mmseg.registry import DATASETS, TRANSFORMS
 from .ade import ADE20KDataset
-from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
 from .chase_db1 import ChaseDB1Dataset
 from .cityscapes import CityscapesDataset
 from .coco_stuff import COCOStuffDataset
 from .custom import CustomDataset
 from .dark_zurich import DarkZurichDataset
-from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset,
-                               RepeatDataset)
+from .dataset_wrappers import MultiImageMixDataset
 from .drive import DRIVEDataset
 from .hrf import HRFDataset
 from .isaid import iSAIDDataset
@@ -20,11 +21,10 @@ from .stare import STAREDataset
 from .voc import PascalVOCDataset
 
 __all__ = [
-    'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
-    'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',
-    'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
-    'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
-    'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset',
-    'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset',
-    'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset'
+    'CustomDataset', 'ConcatDataset', 'RepeatDataset', 'DATASETS',
+    'TRANSFORMS', 'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset',
+    'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset',
+    'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset',
+    'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset',
+    'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset'
 ]
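Note: after this change the wrapper classes re-exported from `mmseg.datasets` are the MMEngine implementations, not local copies, so both import paths name the same class (sketch):

    from mmengine.dataset import RepeatDataset as EngineRepeatDataset
    from mmseg.datasets import RepeatDataset

    assert RepeatDataset is EngineRepeatDataset  # plain re-export, no mmseg-specific subclass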
@@ -1,191 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import copy
-import platform
-import random
-from functools import partial
-
-import numpy as np
-import torch
-from mmcv.parallel import collate
-from mmcv.runner import get_dist_info
-from mmcv.utils import digit_version
-from torch.utils.data import DataLoader
-
-from mmseg.registry import DATASETS, TRANSFORMS
-from .samplers import DistributedSampler
-
-if platform.system() != 'Windows':
-    # https://github.com/pytorch/pytorch/issues/973
-    import resource
-    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
-    base_soft_limit = rlimit[0]
-    hard_limit = rlimit[1]
-    soft_limit = min(max(4096, base_soft_limit), hard_limit)
-    resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
-
-PIPELINES = TRANSFORMS
-
-
-def _concat_dataset(cfg, default_args=None):
-    """Build :obj:`ConcatDataset by."""
-    from .dataset_wrappers import ConcatDataset
-    img_dir = cfg['img_dir']
-    ann_dir = cfg.get('ann_dir', None)
-    split = cfg.get('split', None)
-    # pop 'separate_eval' since it is not a valid key for common datasets.
-    separate_eval = cfg.pop('separate_eval', True)
-    num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1
-    if ann_dir is not None:
-        num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1
-    else:
-        num_ann_dir = 0
-    if split is not None:
-        num_split = len(split) if isinstance(split, (list, tuple)) else 1
-    else:
-        num_split = 0
-    if num_img_dir > 1:
-        assert num_img_dir == num_ann_dir or num_ann_dir == 0
-        assert num_img_dir == num_split or num_split == 0
-    else:
-        assert num_split == num_ann_dir or num_ann_dir <= 1
-    num_dset = max(num_split, num_img_dir)
-
-    datasets = []
-    for i in range(num_dset):
-        data_cfg = copy.deepcopy(cfg)
-        if isinstance(img_dir, (list, tuple)):
-            data_cfg['img_dir'] = img_dir[i]
-        if isinstance(ann_dir, (list, tuple)):
-            data_cfg['ann_dir'] = ann_dir[i]
-        if isinstance(split, (list, tuple)):
-            data_cfg['split'] = split[i]
-        datasets.append(build_dataset(data_cfg, default_args))
-
-    return ConcatDataset(datasets, separate_eval)
-
-
-def build_dataset(cfg, default_args=None):
-    """Build datasets."""
-    from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset,
-                                   RepeatDataset)
-    if isinstance(cfg, (list, tuple)):
-        dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
-    elif cfg['type'] == 'RepeatDataset':
-        dataset = RepeatDataset(
-            build_dataset(cfg['dataset'], default_args), cfg['times'])
-    elif cfg['type'] == 'MultiImageMixDataset':
-        cp_cfg = copy.deepcopy(cfg)
-        cp_cfg['dataset'] = build_dataset(cp_cfg['dataset'])
-        cp_cfg.pop('type')
-        dataset = MultiImageMixDataset(**cp_cfg)
-    elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance(
-            cfg.get('split', None), (list, tuple)):
-        dataset = _concat_dataset(cfg, default_args)
-    else:
-        dataset = DATASETS.build(cfg, default_args=default_args)
-
-    return dataset
-
-
-def build_dataloader(dataset,
-                     samples_per_gpu,
-                     workers_per_gpu,
-                     num_gpus=1,
-                     dist=True,
-                     shuffle=True,
-                     seed=None,
-                     drop_last=False,
-                     pin_memory=True,
-                     persistent_workers=True,
-                     **kwargs):
-    """Build PyTorch DataLoader.
-
-    In distributed training, each GPU/process has a dataloader.
-    In non-distributed training, there is only one dataloader for all GPUs.
-
-    Args:
-        dataset (Dataset): A PyTorch dataset.
-        samples_per_gpu (int): Number of training samples on each GPU, i.e.,
-            batch size of each GPU.
-        workers_per_gpu (int): How many subprocesses to use for data loading
-            for each GPU.
-        num_gpus (int): Number of GPUs. Only used in non-distributed training.
-        dist (bool): Distributed training/test or not. Default: True.
-        shuffle (bool): Whether to shuffle the data at every epoch.
-            Default: True.
-        seed (int | None): Seed to be used. Default: None.
-        drop_last (bool): Whether to drop the last incomplete batch in epoch.
-            Default: False
-        pin_memory (bool): Whether to use pin_memory in DataLoader.
-            Default: True
-        persistent_workers (bool): If True, the data loader will not shutdown
-            the worker processes after a dataset has been consumed once.
-            This allows to maintain the workers Dataset instances alive.
-            The argument also has effect in PyTorch>=1.7.0.
-            Default: True
-        kwargs: any keyword argument to be used to initialize DataLoader
-
-    Returns:
-        DataLoader: A PyTorch dataloader.
-    """
-    rank, world_size = get_dist_info()
-    if dist:
-        sampler = DistributedSampler(
-            dataset, world_size, rank, shuffle=shuffle, seed=seed)
-        shuffle = False
-        batch_size = samples_per_gpu
-        num_workers = workers_per_gpu
-    else:
-        sampler = None
-        batch_size = num_gpus * samples_per_gpu
-        num_workers = num_gpus * workers_per_gpu
-
-    init_fn = partial(
-        worker_init_fn, num_workers=num_workers, rank=rank,
-        seed=seed) if seed is not None else None
-
-    if digit_version(torch.__version__) >= digit_version('1.8.0'):
-        data_loader = DataLoader(
-            dataset,
-            batch_size=batch_size,
-            sampler=sampler,
-            num_workers=num_workers,
-            collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
-            pin_memory=pin_memory,
-            shuffle=shuffle,
-            worker_init_fn=init_fn,
-            drop_last=drop_last,
-            persistent_workers=persistent_workers,
-            **kwargs)
-    else:
-        data_loader = DataLoader(
-            dataset,
-            batch_size=batch_size,
-            sampler=sampler,
-            num_workers=num_workers,
-            collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
-            pin_memory=pin_memory,
-            shuffle=shuffle,
-            worker_init_fn=init_fn,
-            drop_last=drop_last,
-            **kwargs)
-
-    return data_loader
-
-
-def worker_init_fn(worker_id, num_workers, rank, seed):
-    """Worker init func for dataloader.
-
-    The seed of each worker equals to num_worker * rank + worker_id + user_seed
-
-    Args:
-        worker_id (int): Worker id.
-        num_workers (int): Number of workers.
-        rank (int): The rank of current process.
-        seed (int): The random seed to use.
-    """
-
-    worker_seed = num_workers * rank + worker_id + seed
-    np.random.seed(worker_seed)
-    random.seed(worker_seed)
-    torch.manual_seed(worker_seed)
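Note: with this file deleted there is no mmseg-local `build_dataset`/`build_dataloader` helper anymore. Datasets are built straight from the registry, and the dataloader is produced by the MMEngine runner from the `train_dataloader`/`val_dataloader` dicts. A rough sketch of the replacement workflow (assumed MMEngine API, in particular `Runner.build_dataloader`; `cfg` is a hypothetical loaded config):

    from mmengine.runner import Runner
    from mmseg.registry import DATASETS
    from mmseg.utils import register_all_modules

    register_all_modules()
    # replaces the old build_dataset(...)
    dataset = DATASETS.build(cfg.train_dataloader['dataset'])
    # assumed replacement for the old build_dataloader(...); normally the Runner
    # does this internally from cfg.train_dataloader
    train_loader = Runner.build_dataloader(cfg.train_dataloader)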
@@ -1,195 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import bisect
 import collections
 import copy
-from itertools import chain
+from typing import List, Optional, Sequence, Union
 
-import mmcv
-import numpy as np
-from mmcv.utils import print_log
-from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
+from mmengine.dataset import ConcatDataset, force_full_init
 
 from mmseg.registry import DATASETS, TRANSFORMS
-from .cityscapes import CityscapesDataset
-
-
-@DATASETS.register_module()
-class ConcatDataset(_ConcatDataset):
-    """A wrapper of concatenated dataset.
-
-    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
-    support evaluation and formatting results
-
-    Args:
-        datasets (list[:obj:`Dataset`]): A list of datasets.
-        separate_eval (bool): Whether to evaluate the concatenated
-            dataset results separately, Defaults to True.
-    """
-
-    def __init__(self, datasets, separate_eval=True):
-        super(ConcatDataset, self).__init__(datasets)
-        self.CLASSES = datasets[0].CLASSES
-        self.PALETTE = datasets[0].PALETTE
-        self.separate_eval = separate_eval
-        assert separate_eval in [True, False], \
-            f'separate_eval can only be True or False,' \
-            f'but get {separate_eval}'
-        if any([isinstance(ds, CityscapesDataset) for ds in datasets]):
-            raise NotImplementedError(
-                'Evaluating ConcatDataset containing CityscapesDataset'
-                'is not supported!')
-
-    def evaluate(self, results, logger=None, **kwargs):
-        """Evaluate the results.
-
-        Args:
-            results (list[tuple[torch.Tensor]] | list[str]]): per image
-                pre_eval results or predict segmentation map for
-                computing evaluation metric.
-            logger (logging.Logger | str | None): Logger used for printing
-                related information during evaluation. Default: None.
-
-        Returns:
-            dict[str: float]: evaluate results of the total dataset
-                or each separate
-            dataset if `self.separate_eval=True`.
-        """
-        assert len(results) == self.cumulative_sizes[-1], \
-            ('Dataset and results have different sizes: '
-             f'{self.cumulative_sizes[-1]} v.s. {len(results)}')
-
-        # Check whether all the datasets support evaluation
-        for dataset in self.datasets:
-            assert hasattr(dataset, 'evaluate'), \
-                f'{type(dataset)} does not implement evaluate function'
-
-        if self.separate_eval:
-            dataset_idx = -1
-            total_eval_results = dict()
-            for size, dataset in zip(self.cumulative_sizes, self.datasets):
-                start_idx = 0 if dataset_idx == -1 else \
-                    self.cumulative_sizes[dataset_idx]
-                end_idx = self.cumulative_sizes[dataset_idx + 1]
-
-                results_per_dataset = results[start_idx:end_idx]
-                print_log(
-                    f'\nEvaluateing {dataset.img_dir} with '
-                    f'{len(results_per_dataset)} images now',
-                    logger=logger)
-
-                eval_results_per_dataset = dataset.evaluate(
-                    results_per_dataset, logger=logger, **kwargs)
-                dataset_idx += 1
-                for k, v in eval_results_per_dataset.items():
-                    total_eval_results.update({f'{dataset_idx}_{k}': v})
-
-            return total_eval_results
-
-        if len(set([type(ds) for ds in self.datasets])) != 1:
-            raise NotImplementedError(
-                'All the datasets should have same types when '
-                'self.separate_eval=False')
-        else:
-            if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
-                    results, str):
-                # merge the generators of gt_seg_maps
-                gt_seg_maps = chain(
-                    *[dataset.get_gt_seg_maps() for dataset in self.datasets])
-            else:
-                # if the results are `pre_eval` results,
-                # we do not need gt_seg_maps to evaluate
-                gt_seg_maps = None
-            eval_results = self.datasets[0].evaluate(
-                results, gt_seg_maps=gt_seg_maps, logger=logger, **kwargs)
-            return eval_results
-
-    def get_dataset_idx_and_sample_idx(self, indice):
-        """Return dataset and sample index when given an indice of
-        ConcatDataset.
-
-        Args:
-            indice (int): indice of sample in ConcatDataset
-
-        Returns:
-            int: the index of sub dataset the sample belong to
-            int: the index of sample in its corresponding subset
-        """
-        if indice < 0:
-            if -indice > len(self):
-                raise ValueError(
-                    'absolute value of index should not exceed dataset length')
-            indice = len(self) + indice
-        dataset_idx = bisect.bisect_right(self.cumulative_sizes, indice)
-        if dataset_idx == 0:
-            sample_idx = indice
-        else:
-            sample_idx = indice - self.cumulative_sizes[dataset_idx - 1]
-        return dataset_idx, sample_idx
-
-    def format_results(self, results, imgfile_prefix, indices=None, **kwargs):
-        """format result for every sample of ConcatDataset."""
-        if indices is None:
-            indices = list(range(len(self)))
-
-        assert isinstance(results, list), 'results must be a list.'
-        assert isinstance(indices, list), 'indices must be a list.'
-
-        ret_res = []
-        for i, indice in enumerate(indices):
-            dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(
-                indice)
-            res = self.datasets[dataset_idx].format_results(
-                [results[i]],
-                imgfile_prefix + f'/{dataset_idx}',
-                indices=[sample_idx],
-                **kwargs)
-            ret_res.append(res)
-        return sum(ret_res, [])
-
-    def pre_eval(self, preds, indices):
-        """do pre eval for every sample of ConcatDataset."""
-        # In order to compat with batch inference
-        if not isinstance(indices, list):
-            indices = [indices]
-        if not isinstance(preds, list):
-            preds = [preds]
-        ret_res = []
-        for i, indice in enumerate(indices):
-            dataset_idx, sample_idx = self.get_dataset_idx_and_sample_idx(
-                indice)
-            res = self.datasets[dataset_idx].pre_eval(preds[i], sample_idx)
-            ret_res.append(res)
-        return sum(ret_res, [])
-
-
-@DATASETS.register_module()
-class RepeatDataset(object):
-    """A wrapper of repeated dataset.
-
-    The length of repeated dataset will be `times` larger than the original
-    dataset. This is useful when the data loading time is long but the dataset
-    is small. Using RepeatDataset can reduce the data loading time between
-    epochs.
-
-    Args:
-        dataset (:obj:`Dataset`): The dataset to be repeated.
-        times (int): Repeat times.
-    """
-
-    def __init__(self, dataset, times):
-        self.dataset = dataset
-        self.times = times
-        self.CLASSES = dataset.CLASSES
-        self.PALETTE = dataset.PALETTE
-        self._ori_len = len(self.dataset)
-
-    def __getitem__(self, idx):
-        """Get item from original dataset."""
-        return self.dataset[idx % self._ori_len]
-
-    def __len__(self):
-        """The length is multiplied by ``times``"""
-        return self.times * self._ori_len
-
-
 @DATASETS.register_module()
@@ -197,22 +13,32 @@ class MultiImageMixDataset:
     """A wrapper of multiple images mixed dataset.
 
     Suitable for training on multiple images mixed data augmentation like
-    mosaic and mixup. For the augmentation pipeline of mixed image data,
-    the `get_indexes` method needs to be provided to obtain the image
-    indexes, and you can set `skip_flags` to change the pipeline running
-    process.
+    mosaic and mixup.
 
     Args:
-        dataset (:obj:`CustomDataset`): The dataset to be mixed.
+        dataset (ConcatDataset or dict): The dataset to be mixed.
         pipeline (Sequence[dict]): Sequence of transform object or
             config dict to be composed.
         skip_type_keys (list[str], optional): Sequence of type string to
            be skip pipeline. Default to None.
     """
 
-    def __init__(self, dataset, pipeline, skip_type_keys=None):
+    def __init__(self,
+                 dataset: Union[ConcatDataset, dict],
+                 pipeline: Sequence[dict],
+                 skip_type_keys: Optional[List[str]] = None,
+                 lazy_init: bool = False) -> None:
         assert isinstance(pipeline, collections.abc.Sequence)
 
+        if isinstance(dataset, dict):
+            self.dataset = DATASETS.build(dataset)
+        elif isinstance(dataset, ConcatDataset):
+            self.dataset = dataset
+        else:
+            raise TypeError(
+                'elements in datasets sequence should be config or '
+                f'`ConcatDataset` instance, but got {type(dataset)}')
+
         if skip_type_keys is not None:
             assert all([
                 isinstance(skip_type_key, str)
@@ -230,11 +56,44 @@ class MultiImageMixDataset:
         else:
             raise TypeError('pipeline must be a dict')
 
-        self.dataset = dataset
-        self.CLASSES = dataset.CLASSES
-        self.PALETTE = dataset.PALETTE
-        self.num_samples = len(dataset)
+        self._metainfo = self.dataset.metainfo
+        self.num_samples = len(self.dataset)
 
+        self._fully_initialized = False
+        if not lazy_init:
+            self.full_init()
+
+    @property
+    def metainfo(self) -> dict:
+        """Get the meta information of the multi-image-mixed dataset.
+
+        Returns:
+            dict: The meta information of multi-image-mixed dataset.
+        """
+        return copy.deepcopy(self._metainfo)
+
+    def full_init(self):
+        """Loop to ``full_init`` each dataset."""
+        if self._fully_initialized:
+            return
+
+        self.dataset.full_init()
+        self._ori_len = len(self.dataset)
+        self._fully_initialized = True
+
+    @force_full_init
+    def get_data_info(self, idx: int) -> dict:
+        """Get annotation by index.
+
+        Args:
+            idx (int): Global index of ``ConcatDataset``.
+
+        Returns:
+            dict: The idx-th annotation of the datasets.
+        """
+        return self.dataset.get_data_info(idx)
+
+    @force_full_init
     def __len__(self):
         return self.num_samples
 
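Note: the refactored MultiImageMixDataset now follows the MMEngine wrapper protocol: it accepts either a built ConcatDataset or a config dict, exposes `metainfo`, and can defer initialization through `lazy_init`/`full_init()`. A small usage sketch (hypothetical `cfg1`/`cfg2` sub-dataset configs, like the ones in the test file below):

    from mmseg.datasets import MultiImageMixDataset

    mix = MultiImageMixDataset(
        dataset=dict(type='ConcatDataset', datasets=[cfg1, cfg2]),
        pipeline=[],
        lazy_init=True)   # defer loading of the underlying data list
    mix.full_init()       # explicit init; otherwise triggered on first access via force_full_init
    print(len(mix), mix.metainfo)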
tests/test_datasets/test_dataset_builder.py (new file, 151 lines)
@@ -0,0 +1,151 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+
+from mmengine.dataset import ConcatDataset, RepeatDataset
+
+from mmseg.datasets import DATASETS, MultiImageMixDataset
+from mmseg.utils import register_all_modules
+
+register_all_modules()
+
+
+@DATASETS.register_module()
+class ToyDataset(object):
+
+    def __init__(self, cnt=0):
+        self.cnt = cnt
+
+    def __item__(self, idx):
+        return idx
+
+    def __len__(self):
+        return 100
+
+
+def test_build_dataset():
+    cfg = dict(type='ToyDataset')
+    dataset = DATASETS.build(cfg)
+    assert isinstance(dataset, ToyDataset)
+    assert dataset.cnt == 0
+    dataset = DATASETS.build(cfg, default_args=dict(cnt=1))
+    assert isinstance(dataset, ToyDataset)
+    assert dataset.cnt == 1
+
+    data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
+    data_prefix = dict(img_path='imgs/', seg_map_path='gts/')
+
+    # test RepeatDataset
+    cfg = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=data_prefix,
+        serialize_data=False)
+    dataset = DATASETS.build(cfg)
+    dataset_repeat = RepeatDataset(dataset=dataset, times=5)
+    assert isinstance(dataset_repeat, RepeatDataset)
+    assert len(dataset_repeat) == 25
+
+    # test ConcatDataset
+    # We use same dir twice for simplicity
+    # with data_prefix.seg_map_path
+    cfg1 = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=data_prefix,
+        serialize_data=False)
+    cfg2 = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=data_prefix,
+        serialize_data=False)
+    dataset1 = DATASETS.build(cfg1)
+    dataset2 = DATASETS.build(cfg2)
+    dataset_concat = ConcatDataset(datasets=[dataset1, dataset2])
+    assert isinstance(dataset_concat, ConcatDataset)
+    assert len(dataset_concat) == 10
+
+    # test MultiImageMixDataset
+    dataset = MultiImageMixDataset(dataset=dataset_concat, pipeline=[])
+    assert isinstance(dataset, MultiImageMixDataset)
+    assert len(dataset) == 10
+
+    cfg = dict(type='ConcatDataset', datasets=[cfg1, cfg2])
+
+    dataset = MultiImageMixDataset(dataset=cfg, pipeline=[])
+    assert isinstance(dataset, MultiImageMixDataset)
+    assert len(dataset) == 10
+
+    # with data_prefix.seg_map_path, ann_file
+    cfg1 = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=data_prefix,
+        ann_file='splits/train.txt',
+        serialize_data=False)
+    cfg2 = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=data_prefix,
+        ann_file='splits/val.txt',
+        serialize_data=False)
+
+    dataset1 = DATASETS.build(cfg1)
+    dataset2 = DATASETS.build(cfg2)
+    dataset_concat = ConcatDataset(datasets=[dataset1, dataset2])
+    assert isinstance(dataset_concat, ConcatDataset)
+    assert len(dataset_concat) == 5
+
+    # test mode
+    cfg1 = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=dict(img_path='imgs/'),
+        test_mode=True,
+        metainfo=dict(classes=('pseudo_class', )),
+        serialize_data=False)
+    cfg2 = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=dict(img_path='imgs/'),
+        test_mode=True,
+        metainfo=dict(classes=('pseudo_class', )),
+        serialize_data=False)
+
+    dataset1 = DATASETS.build(cfg1)
+    dataset2 = DATASETS.build(cfg2)
+    dataset_concat = ConcatDataset(datasets=[dataset1, dataset2])
+    assert isinstance(dataset_concat, ConcatDataset)
+    assert len(dataset_concat) == 10
+
+    # test mode with ann_files
+    cfg1 = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=dict(img_path='imgs/'),
+        ann_file='splits/val.txt',
+        test_mode=True,
+        metainfo=dict(classes=('pseudo_class', )),
+        serialize_data=False)
+    cfg2 = dict(
+        type='CustomDataset',
+        pipeline=[],
+        data_root=data_root,
+        data_prefix=dict(img_path='imgs/'),
+        ann_file='splits/val.txt',
+        test_mode=True,
+        metainfo=dict(classes=('pseudo_class', )),
+        serialize_data=False)
+
+    dataset1 = DATASETS.build(cfg1)
+    dataset2 = DATASETS.build(cfg2)
+    dataset_concat = ConcatDataset(datasets=[dataset1, dataset2])
+    assert isinstance(dataset_concat, ConcatDataset)
+    assert len(dataset_concat) == 2
@@ -25,7 +25,7 @@ class TestLoading(object):
         assert results['img'].dtype == np.uint8
         assert results['ori_shape'] == results['img'].shape[:2]
         assert repr(transform) == transform.__class__.__name__ + \
-            "(to_float32=False, color_type='color'," + \
+            "(ignore_empty=False, to_float32=False, color_type='color'," + \
             " imdecode_backend='cv2', file_client_args={'backend': 'disk'})"
 
         # to_float32
@@ -8,7 +8,7 @@ import mmcv
 import numpy as np
 from mmcv import Config, DictAction
 
-from mmseg.datasets.builder import build_dataset
+from mmseg.datasets import DATASETS
 
 
 def parse_args():
@@ -159,7 +159,7 @@ def main():
     args = parse_args()
     cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options,
                             args.show_origin)
-    dataset = build_dataset(cfg.data.train)
+    dataset = DATASETS.build(cfg.data.train)
     progress_bar = mmcv.ProgressBar(len(dataset))
     for item in dataset:
        filename = os.path.join(args.output_dir,