commit 4dcbd269aa
parent 3c8806e4c6
@@ -8,9 +8,11 @@ from typing import List, Sequence, Tuple, Union
 from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

+from mmengine.registry import DATASETS
 from .base_dataset import BaseDataset, force_full_init


+@DATASETS.register_module()
 class ConcatDataset(_ConcatDataset):
     """A wrapper of concatenated dataset.

@@ -24,19 +26,28 @@ class ConcatDataset(_ConcatDataset):
         arguments for wrapped dataset which inherit from ``BaseDataset``.

     Args:
-        datasets (Sequence[BaseDataset]): A list of datasets which will be
-            concatenated.
+        datasets (Sequence[BaseDataset] or Sequence[dict]): A list of datasets
+            which will be concatenated.
         lazy_init (bool, optional): Whether to load annotation during
             instantiation. Defaults to False.
     """

     def __init__(self,
-                 datasets: Sequence[BaseDataset],
+                 datasets: Sequence[Union[BaseDataset, dict]],
                  lazy_init: bool = False):
+        self.datasets: List[BaseDataset] = []
+        for i, dataset in enumerate(datasets):
+            if isinstance(dataset, dict):
+                self.datasets.append(DATASETS.build(dataset))
+            elif isinstance(dataset, BaseDataset):
+                self.datasets.append(dataset)
+            else:
+                raise TypeError(
+                    'elements in datasets sequence should be config or '
+                    f'`BaseDataset` instance, but got {type(dataset)}')
         # Only use metainfo of first dataset.
-        self._metainfo = datasets[0].metainfo
-        self.datasets = datasets  # type: ignore
-        for i, dataset in enumerate(datasets, 1):
+        self._metainfo = self.datasets[0].metainfo
+        for i, dataset in enumerate(self.datasets, 1):
             if self._metainfo != dataset.metainfo:
                 raise ValueError(
                     f'The meta information of the {i}-th dataset does not '
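The hunk above is the core of the change: ``ConcatDataset`` now accepts config dicts alongside ``BaseDataset`` instances and builds the dicts through ``DATASETS.build``; anything else raises ``TypeError``. A minimal usage sketch of the intended call pattern (``ToyDataset`` and the annotation path are hypothetical placeholders, not part of this commit):

from mmengine.dataset import BaseDataset, ConcatDataset
from mmengine.registry import DATASETS


@DATASETS.register_module()
class ToyDataset(BaseDataset):  # hypothetical dataset, for illustration only
    pass


toy_cfg = dict(
    type='ToyDataset',
    ann_file='annotations/toy.json',  # placeholder path
    lazy_init=True)

# Instances and config dicts can be mixed in the same sequence; the dicts
# are built by DATASETS.build inside ConcatDataset.__init__.
concat = ConcatDataset(
    datasets=[ToyDataset(ann_file='annotations/toy.json', lazy_init=True),
              toy_cfg],
    lazy_init=True)

# Anything that is neither a dict nor a BaseDataset raises TypeError:
# ConcatDataset(datasets=[0])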
@@ -140,6 +151,7 @@ class ConcatDataset(_ConcatDataset):
                 'dataset first and then use `ConcatDataset`.')


+@DATASETS.register_module()
 class RepeatDataset:
     """A wrapper of repeated dataset.

@@ -156,19 +168,27 @@ class RepeatDataset:
         arguments for wrapped dataset which inherit from ``BaseDataset``.

     Args:
-        dataset (BaseDataset): The dataset to be repeated.
+        dataset (BaseDataset or dict): The dataset to be repeated.
         times (int): Repeat times.
         lazy_init (bool): Whether to load annotation during
             instantiation. Defaults to False.
     """

     def __init__(self,
-                 dataset: BaseDataset,
+                 dataset: Union[BaseDataset, dict],
                  times: int,
                  lazy_init: bool = False):
-        self.dataset = dataset
+        self.dataset: BaseDataset
+        if isinstance(dataset, dict):
+            self.dataset = DATASETS.build(dataset)
+        elif isinstance(dataset, BaseDataset):
+            self.dataset = dataset
+        else:
+            raise TypeError(
+                'elements in datasets sequence should be config or '
+                f'`BaseDataset` instance, but got {type(dataset)}')
         self.times = times
-        self._metainfo = dataset.metainfo
+        self._metainfo = self.dataset.metainfo

         self._fully_initialized = False
         if not lazy_init:
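``RepeatDataset`` gets the same treatment: its ``dataset`` argument may now be a config dict that is built through ``DATASETS.build``. A short sketch under the same assumptions as above (hypothetical ``ToyDataset``, placeholder annotation path):

from mmengine.dataset import RepeatDataset

repeat = RepeatDataset(
    dataset=dict(
        type='ToyDataset',  # hypothetical registered dataset, see above
        ann_file='annotations/toy.json',
        lazy_init=True),
    times=3,
    lazy_init=True)

# With a real annotation file, full_init() makes the wrapper report
# times * len(dataset) items, exactly as when passing an instance.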
@@ -283,7 +303,7 @@ class ClassBalancedDataset:
         ``BaseDataset``.

     Args:
-        dataset (BaseDataset): The dataset to be repeated.
+        dataset (BaseDataset or dict): The dataset to be repeated.
         oversample_thr (float): frequency threshold below which data is
             repeated. For categories with ``f_c >= oversample_thr``, there is
             no oversampling. For categories with ``f_c < oversample_thr``, the
@@ -294,12 +314,19 @@
     """

     def __init__(self,
-                 dataset: BaseDataset,
+                 dataset: Union[BaseDataset, dict],
                  oversample_thr: float,
                  lazy_init: bool = False):
-        self.dataset = dataset
+        if isinstance(dataset, dict):
+            self.dataset = DATASETS.build(dataset)
+        elif isinstance(dataset, BaseDataset):
+            self.dataset = dataset
+        else:
+            raise TypeError(
+                'elements in datasets sequence should be config or '
+                f'`BaseDataset` instance, but got {type(dataset)}')
         self.oversample_thr = oversample_thr
-        self._metainfo = dataset.metainfo
+        self._metainfo = self.dataset.metainfo

         self._fully_initialized = False
         if not lazy_init:
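``ClassBalancedDataset`` follows the same pattern as the two wrappers above. Because the wrappers themselves are now registered in ``DATASETS`` (see the ``@DATASETS.register_module()`` lines added for ``ConcatDataset`` and ``RepeatDataset``), it should also become possible to describe a wrapped dataset entirely in config form and build it with ``DATASETS.build``; this nesting is an implication of the change, not something exercised by the tests in the remaining hunks. A hedged sketch, reusing the hypothetical ``ToyDataset``:

from mmengine.dataset import ClassBalancedDataset
from mmengine.registry import DATASETS

# A RepeatDataset wrapping a ToyDataset, expressed purely as nested configs.
wrapped_cfg = dict(
    type='RepeatDataset',
    dataset=dict(
        type='ToyDataset',
        ann_file='annotations/toy.json',
        lazy_init=True),
    times=2,
    lazy_init=True)
repeat = DATASETS.build(wrapped_cfg)

# ClassBalancedDataset accepts a config dict directly as well.
balanced = ClassBalancedDataset(
    dataset=dict(type='ToyDataset', ann_file='annotations/toy.json',
                 lazy_init=True),
    oversample_thr=1e-3,
    lazy_init=True)

The hunks below belong to the unit tests for the dataset wrappers.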
@@ -8,7 +8,7 @@ import torch

 from mmengine.dataset import (BaseDataset, ClassBalancedDataset, Compose,
                               ConcatDataset, RepeatDataset, force_full_init)
-from mmengine.registry import TRANSFORMS
+from mmengine.registry import DATASETS, TRANSFORMS


 def function_pipeline(data_info):
@@ -27,6 +27,11 @@ class NotCallableTransform:
     pass


+@DATASETS.register_module()
+class CustomDataset(BaseDataset):
+    pass
+
+
 class TestBaseDataset:
     dataset_type = BaseDataset
     data_info = dict(
@@ -566,7 +571,7 @@ class TestBaseDataset:


 class TestConcatDataset:

-    def _init_dataset(self):
+    def setup(self):
         dataset = BaseDataset

         # create dataset_a
@@ -593,8 +598,25 @@ class TestConcatDataset:
         self.cat_datasets = ConcatDataset(
             datasets=[self.dataset_a, self.dataset_b])

+    def test_init(self):
+        # Test build dataset from cfg.
+        dataset_cfg_b = dict(
+            type=CustomDataset,
+            data_root=osp.join(osp.dirname(__file__), '../data/'),
+            data_prefix=dict(img='imgs'),
+            ann_file='annotations/dummy_annotation.json')
+        cat_datasets = ConcatDataset(datasets=[self.dataset_a, dataset_cfg_b])
+        cat_datasets.datasets[1].pipeline = self.dataset_b.pipeline
+        assert len(cat_datasets) == len(self.cat_datasets)
+        for i in range(len(cat_datasets)):
+            assert (cat_datasets.get_data_info(i) ==
+                    self.cat_datasets.get_data_info(i))
+            assert (cat_datasets[i] == self.cat_datasets[i])
+
+        with pytest.raises(TypeError):
+            ConcatDataset(datasets=[0])
+
     def test_full_init(self):
-        self._init_dataset()
         # test init with lazy_init=True
         self.cat_datasets.full_init()
         assert len(self.cat_datasets) == 6
@@ -618,16 +640,13 @@ class TestConcatDataset:
             ConcatDataset(datasets=[self.dataset_a, dataset_b])

     def test_metainfo(self):
-        self._init_dataset()
         assert self.cat_datasets.metainfo == self.dataset_a.metainfo

     def test_length(self):
-        self._init_dataset()
         assert len(self.cat_datasets) == (
             len(self.dataset_a) + len(self.dataset_b))

     def test_getitem(self):
-        self._init_dataset()
         assert (
             self.cat_datasets[0]['imgs'] == self.dataset_a[0]['imgs']).all()
         assert (self.cat_datasets[0]['imgs'] !=
@@ -639,7 +658,6 @@ class TestConcatDataset:
                 self.dataset_a[-1]['imgs']).all()

     def test_get_data_info(self):
-        self._init_dataset()
         assert self.cat_datasets.get_data_info(
             0) == self.dataset_a.get_data_info(0)
         assert self.cat_datasets.get_data_info(
@@ -651,7 +669,6 @@ class TestConcatDataset:
             -1) != self.dataset_a.get_data_info(-1)

     def test_get_ori_dataset_idx(self):
-        self._init_dataset()
         assert self.cat_datasets._get_ori_dataset_idx(3) == (
             1, 3 - len(self.dataset_a))
         assert self.cat_datasets._get_ori_dataset_idx(-1) == (
@@ -662,7 +679,7 @@ class TestConcatDataset:


 class TestRepeatDataset:

-    def _init_dataset(self):
+    def setup(self):
         dataset = BaseDataset
         data_info = dict(filename='test_img.jpg', height=604, width=640)
         dataset.parse_data_info = MagicMock(return_value=data_info)
@@ -678,9 +695,26 @@ class TestRepeatDataset:
         self.repeat_datasets = RepeatDataset(
             dataset=self.dataset, times=self.repeat_times)

-    def test_full_init(self):
-        self._init_dataset()
+    def test_init(self):
+        # Test build dataset from cfg.
+        dataset_cfg = dict(
+            type=CustomDataset,
+            data_root=osp.join(osp.dirname(__file__), '../data/'),
+            data_prefix=dict(img='imgs'),
+            ann_file='annotations/dummy_annotation.json')
+        repeat_dataset = RepeatDataset(
+            dataset=dataset_cfg, times=self.repeat_times)
+        repeat_dataset.dataset.pipeline = self.dataset.pipeline
+        assert len(repeat_dataset) == len(self.repeat_datasets)
+        for i in range(len(repeat_dataset)):
+            assert (repeat_dataset.get_data_info(i) ==
+                    self.repeat_datasets.get_data_info(i))
+            assert (repeat_dataset[i] == self.repeat_datasets[i])
+
+        with pytest.raises(TypeError):
+            RepeatDataset(dataset=[0], times=5)

+    def test_full_init(self):
         self.repeat_datasets.full_init()
         assert len(
             self.repeat_datasets) == self.repeat_times * len(self.dataset)
@@ -697,22 +731,18 @@ class TestRepeatDataset:
             self.repeat_datasets.get_subset(1)

     def test_metainfo(self):
-        self._init_dataset()
         assert self.repeat_datasets.metainfo == self.dataset.metainfo

     def test_length(self):
-        self._init_dataset()
         assert len(
             self.repeat_datasets) == len(self.dataset) * self.repeat_times

     def test_getitem(self):
-        self._init_dataset()
         for i in range(self.repeat_times):
             assert self.repeat_datasets[len(self.dataset) *
                                         i] == self.dataset[0]

     def test_get_data_info(self):
-        self._init_dataset()
         for i in range(self.repeat_times):
             assert self.repeat_datasets.get_data_info(
                 len(self.dataset) * i) == self.dataset.get_data_info(0)
@@ -720,7 +750,7 @@ class TestRepeatDataset:


 class TestClassBalancedDataset:

-    def _init_dataset(self):
+    def setup(self):
         dataset = BaseDataset
         data_info = dict(filename='test_img.jpg', height=604, width=640)
         dataset.parse_data_info = MagicMock(return_value=data_info)
@@ -738,17 +768,35 @@ class TestClassBalancedDataset:
             dataset=self.dataset, oversample_thr=1e-3)
         self.cls_banlanced_datasets.repeat_indices = self.repeat_indices

-    def test_full_init(self):
-        self._init_dataset()
+    def test_init(self):
+        # Test build dataset from cfg.
+        dataset_cfg = dict(
+            type=CustomDataset,
+            data_root=osp.join(osp.dirname(__file__), '../data/'),
+            data_prefix=dict(img='imgs'),
+            ann_file='annotations/dummy_annotation.json')
+        cls_banlanced_datasets = ClassBalancedDataset(
+            dataset=dataset_cfg, oversample_thr=1e-3)
+        cls_banlanced_datasets.repeat_indices = self.repeat_indices
+        cls_banlanced_datasets.dataset.pipeline = self.dataset.pipeline
+        assert len(cls_banlanced_datasets) == len(self.cls_banlanced_datasets)
+        for i in range(len(cls_banlanced_datasets)):
+            assert (cls_banlanced_datasets.get_data_info(i) ==
+                    self.cls_banlanced_datasets.get_data_info(i))
+            assert (
+                cls_banlanced_datasets[i] == self.cls_banlanced_datasets[i])
+
+        with pytest.raises(TypeError):
+            ClassBalancedDataset(dataset=[0], times=5)

+    def test_full_init(self):
         self.cls_banlanced_datasets.full_init()
         self.cls_banlanced_datasets.repeat_indices = self.repeat_indices
         assert len(self.cls_banlanced_datasets) == len(self.repeat_indices)
-        self.cls_banlanced_datasets.full_init()
+        # Reinit `repeat_indices`.
         self.cls_banlanced_datasets._fully_initialized = False
-        self.cls_banlanced_datasets[1]
         self.cls_banlanced_datasets.repeat_indices = self.repeat_indices
-        assert len(self.cls_banlanced_datasets) == len(self.repeat_indices)
+        assert len(self.cls_banlanced_datasets) != len(self.repeat_indices)

         with pytest.raises(NotImplementedError):
             self.cls_banlanced_datasets.get_subset_(1)
@@ -757,27 +805,22 @@ class TestClassBalancedDataset:
             self.cls_banlanced_datasets.get_subset(1)

     def test_metainfo(self):
-        self._init_dataset()
         assert self.cls_banlanced_datasets.metainfo == self.dataset.metainfo

     def test_length(self):
-        self._init_dataset()
         assert len(self.cls_banlanced_datasets) == len(self.repeat_indices)

     def test_getitem(self):
-        self._init_dataset()
         for i in range(len(self.repeat_indices)):
             assert self.cls_banlanced_datasets[i] == self.dataset[
                 self.repeat_indices[i]]

     def test_get_data_info(self):
-        self._init_dataset()
         for i in range(len(self.repeat_indices)):
             assert self.cls_banlanced_datasets.get_data_info(
                 i) == self.dataset.get_data_info(self.repeat_indices[i])

     def test_get_cat_ids(self):
-        self._init_dataset()
         for i in range(len(self.repeat_indices)):
             assert self.cls_banlanced_datasets.get_cat_ids(
                 i) == self.dataset.get_cat_ids(self.repeat_indices[i])