[Feature] add base dataset (#32)

* basedataset first commit

* add base dataset

* add dataset

* add basedataset

* Fix test dataset

* Fix mypy and test

* Fix mypy and test

* remove unused code

* Update mmengine/dataset/base_dataset.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmengine/dataset/base_dataset.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* add more corner cases in unittest

* fix lint

* Fix as comment

* Fix lint

* update unittest

* Type hint Dict to dict

* rename max_refetch

* Fix as comment

* Fix typo

* Fix as comment

* BaseDataset is no longer an abstract class; change UT and docs

* Fix as comment

* Fix as comment and refactor type error

* Add comment for full init

* Fix as comment and modify dataset_wrapper

* Fix as comment and modify dataset_wrapper

* Fix as comment

* Fix as comment

* Fix as comment

* Fix as comment

* Fix as comment

* Fix as comment

* Fix as comment

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: Tao Gong <gongtao950513@gmail.com>
This commit is contained in:
Mashiro 2022-02-22 14:01:06 +08:00 committed by GitHub
parent adb2aee8c2
commit ada6660c65
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 1269 additions and 74 deletions

View File

@ -69,7 +69,7 @@ data
2. 构建数据流水线（data pipeline），用于数据预处理与数据准备；
3. 读取与解析满足 OpenMMLab 2.0 数据集格式规范的标注文件,该步骤中会有 `parse_annotations()` 抽象方法,该抽象方法负责解析标注文件里的每个原始数据;
3. 读取与解析满足 OpenMMLab 2.0 数据集格式规范的标注文件,该步骤中会有 `parse_annotations()` 方法,该方法负责解析标注文件里的每个原始数据;
4. 过滤无用数据,比如不包含标注的样本等;
@ -77,8 +77,6 @@ data
6. 序列化全部样本,以达到节省内存的效果,详情请参考[节省内存](#节省内存)。
数据集基类是一个抽象类,它有且只有一个抽象方法 `parse_annotations()` `parse_annotations()` 定义了将标注文件里的一个原始数据处理成一个或若干个训练/测试样本的方法。因此对于自定义数据集类,用户必须要实现 `parse_annotations()` 方法。
### 数据集基类提供的接口
与 `torch.utils.data.Dataset` 类似，数据集初始化后，支持 `__getitem__` 方法，用来索引数据，以及 `__len__` 操作获取数据集大小。除此之外，OpenMMLab 的数据集基类主要提供了以下接口来访问具体信息：
@ -93,7 +91,7 @@ data
## 使用数据集基类自定义数据集类
在了解了数据集基类的初始化流程与提供的接口之后,就可以基于数据集基类自定义数据集类,如上所述,数据集基类是一个抽象类,它有且只有一个抽象方法 `parse_annotations()`,因此用户必须在自定义数据集类中实现该方法。以下是一个使用数据集基类来实现某一具体数据集的例子。
在了解了数据集基类的初始化流程与提供的接口之后,就可以基于数据集基类自定义数据集类,如上所述,对于满足 OpenMMLab 2.0 数据集格式规范的标注文件,用户可以重载 `parse_annotations()`来加载标签。以下是一个使用数据集基类来实现某一具体数据集的例子。
```python
import os.path as osp

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa
from .config import *
from .dataset import *
from .fileio import *
from .registry import *
from .utils import *

View File

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa
from .base_dataset import BaseDataset, Compose, force_full_init
from .dataset_wrapper import ClassBalancedDataset, ConcatDataset, RepeatDataset

View File

@ -0,0 +1,553 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import functools
import gc
import os.path as osp
import pickle
import warnings
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
import numpy as np
from torch.utils.data import Dataset
from mmengine.fileio import list_from_file, load
from mmengine.registry import TRANSFORMS
from mmengine.utils import check_file_exist
class Compose:
    """Compose multiple transforms sequentially.

    Args:
        transforms (Sequence[dict, callable], optional): Sequence of transform
            objects or config dicts to be composed. ``None`` is treated as an
            empty pipeline.

    Raises:
        TypeError: If an item is neither a dict nor a callable, or if a dict
            is built into something that is not callable.
    """

    def __init__(self,
                 transforms: Optional[Sequence[Union[dict, Callable]]]):
        self.transforms: List[Callable] = []
        # Accept ``None`` so that "no pipeline" does not crash on iteration.
        if transforms is None:
            transforms = []
        for transform in transforms:
            if isinstance(transform, dict):
                # Config dicts are turned into objects by the registry.
                transform = TRANSFORMS.build(transform)
                if not callable(transform):
                    raise TypeError(f'transform should be a callable object, '
                                    f'but got {type(transform)}')
            elif not callable(transform):
                raise TypeError(
                    f'transform must be a callable object or dict, '
                    f'but got {type(transform)}')
            self.transforms.append(transform)

    def __call__(self, data: dict) -> Optional[dict]:
        """Call function to apply transforms sequentially.

        Args:
            data (dict): A result dict contains the data to transform.

        Returns:
            dict or None: Transformed data, or ``None`` as soon as any
            transform returns ``None`` (the remaining ones are skipped).
        """
        for t in self.transforms:
            data = t(data)
            if data is None:
                return None
        return data

    def __repr__(self):
        """Print ``self.transforms`` in sequence.

        Returns:
            str: Formatted string, one transform per line.
        """
        body = ''.join(f'\n {t}' for t in self.transforms)
        return f'{self.__class__.__name__}({body}\n)'
def force_full_init(old_func: Callable) -> Callable:
    """Decorator that forces ``full_init`` before the decorated method runs.

    Methods decorated by ``force_full_init`` call ``obj.full_init()`` first
    if the instance has not been fully initialized yet, i.e. if
    ``obj._fully_initialized`` is missing or falsy.

    Args:
        old_func (Callable): Decorated function, make sure the first arg is an
            instance with ``full_init`` method.

    Returns:
        Callable: Wrapper around ``old_func``; its return value is whatever
        ``old_func`` returns.

    Raises:
        AttributeError: At call time, if the instance has no ``full_init``
            attribute.
    """

    # TODO This decorator can also be used in runner.
    @functools.wraps(old_func)
    def wrapper(obj: object, *args, **kwargs):
        if not hasattr(obj, 'full_init'):
            raise AttributeError(f'{type(obj)} does not have full_init '
                                 'method.')
        if not getattr(obj, '_fully_initialized', False):
            warnings.warn('Attribute `_fully_initialized` is not defined in '
                          f'{type(obj)} or `type(obj)._fully_initialized` is '
                          'False, `full_init` will be called and '
                          f'{type(obj)}._fully_initialized will be set to '
                          'True')
            obj.full_init()  # type: ignore
            # Set the flag here too, in case ``full_init`` does not.
            obj._fully_initialized = True  # type: ignore
        return old_func(obj, *args, **kwargs)

    return wrapper
class BaseDataset(Dataset):
    r"""BaseDataset for open source projects in OpenMMLab.

    The annotation format is shown as follows.

    .. code-block:: none

        {
            "metadata":
            {
                "dataset_type": "test_dataset",
                "task_name": "test_task"
            },
            "data_infos":
            [
                {
                    "img_path": "test_img.jpg",
                    "height": 604,
                    "width": 640,
                    "instances":
                    [
                        {
                            "bbox": [0, 0, 10, 20],
                            "bbox_label": 1,
                            "mask": [[0,0],[0,10],[10,20],[20,0]],
                            "extra_anns": [1,2,3]
                        },
                        {
                            "bbox": [10, 10, 110, 120],
                            "bbox_label": 2,
                            "mask": [[10,10],[10,110],[110,120],[120,10]],
                            "extra_anns": [4,5,6]
                        }
                    ]
                },
            ]
        }

    Args:
        ann_file (str): Annotation file path.
        meta (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_root (str, optional): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to None.
        data_prefix (dict, optional): Prefix for training data. Defaults to
            dict(img=None, ann=None).
        filter_cfg (dict, optional): Config for filter data. Defaults to None.
        num_samples (int, optional): Support using first few data in
            annotation file to facilitate training/testing on a smaller
            dataset. Defaults to -1 which means using all ``data_infos``.
        serialize_data (bool, optional): Whether to hold memory using
            serialized objects, when enabled, data loader workers can use
            shared RAM from master process instead of making a copy. Defaults
            to True.
        pipeline (list, optional): Processing pipeline. Defaults to [].
        test_mode (bool, optional): ``test_mode=True`` means in test phase.
            Defaults to False.
        lazy_init (bool, optional): Whether to postpone loading annotations
            until they are needed. In some cases, such as visualization, only
            the meta information of the dataset is needed, which is not
            necessary to load annotation file. ``BaseDataset`` can skip
            loading annotations to save time by setting ``lazy_init=True``.
            Defaults to False.
        max_refetch (int, optional): The maximum number of cycles to get a
            valid image. Defaults to 1000.

    Note:
        BaseDataset collects meta information from `annotation file` (the
        lowest priority), ``BaseDataset.META``(medium) and `meta parameter`
        (highest) passed to constructors. The lower priority meta information
        will be overwritten by higher one.
    """

    META: dict = dict()
    _fully_initialized: bool = False

    # NOTE: the mutable defaults are kept for interface compatibility. They
    # are never mutated: ``data_prefix`` is copied and ``pipeline`` is only
    # iterated by ``Compose``.
    def __init__(self,
                 ann_file: str,
                 meta: Optional[dict] = None,
                 data_root: Optional[str] = None,
                 data_prefix: dict = dict(img=None, ann=None),
                 filter_cfg: Optional[dict] = None,
                 num_samples: int = -1,
                 serialize_data: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000):

        self.data_root = data_root
        self.data_prefix = copy.copy(data_prefix)
        self.ann_file = ann_file
        self.filter_cfg = copy.deepcopy(filter_cfg)
        self._num_samples = num_samples
        self.serialize_data = serialize_data
        self.test_mode = test_mode
        self.max_refetch = max_refetch
        self.data_infos: List[dict] = []
        self.data_infos_bytes: bytearray = bytearray()

        # set meta information
        self._meta = self._get_meta_data(copy.deepcopy(meta))

        # join paths
        if self.data_root is not None:
            self._join_prefix()

        # build pipeline
        self.pipeline = Compose(pipeline)

        if not lazy_init:
            self.full_init()

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index and automatically call ``full_init`` if the
        dataset has not been fully initialized.

        Args:
            idx (int): The index of data. Negative values index from the end.

        Returns:
            dict: The idx-th annotation of the dataset, with the key
            ``sample_idx`` set to the non-negative index of the sample. The
            returned dict is a private copy; modifying it does not affect
            the dataset.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        # Normalize negative indices first. Otherwise ``idx - 1`` below would
        # look up the wrong start address (or run out of bounds) when
        # ``idx == -len(self)``.
        if idx < 0:
            num_samples = len(self)
            if idx < -num_samples:
                raise IndexError(
                    f'index {idx} is out of range for dataset of '
                    f'length {num_samples}')
            idx += num_samples

        if self.serialize_data:
            # ``data_address`` holds cumulative end offsets, so sample ``idx``
            # occupies ``[data_address[idx - 1], data_address[idx])``.
            start_addr = 0 if idx == 0 else self.data_address[idx - 1].item()
            end_addr = self.data_address[idx].item()
            buffer = memoryview(self.data_infos_bytes[start_addr:end_addr])
            # Only bytes produced by ``_serialize_data`` are unpickled here.
            data_info = pickle.loads(buffer)
        else:
            # Copy so that writing ``sample_idx`` (or later pipeline changes)
            # does not leak into the stored annotation.
            data_info = copy.deepcopy(self.data_infos[idx])

        # To record the real positive index of data information.
        data_info['sample_idx'] = idx
        return data_info

    def full_init(self):
        """Load annotation file and set ``BaseDataset._fully_initialized`` to
        True.

        If ``lazy_init=False``, ``full_init`` will be called during the
        instantiation and ``self._fully_initialized`` will be set to True. If
        ``obj._fully_initialized=False``, the class method decorated by
        ``force_full_init`` will call ``full_init`` automatically.

        Several steps to initialize annotation:

            - load_annotations: Load annotations from annotation file.
            - filter data information: Filter annotations according to
              filter_cfg.
            - slice_data: Slice dataset according to ``self._num_samples``
            - serialize_data: Serialize ``self.data_infos`` if
              ``self.serialize_data`` is True.
        """
        if self._fully_initialized:
            return

        # load data information
        self.data_infos = self.load_annotations(self.ann_file)
        # filter illegal data, such as data that has no annotations.
        self.data_infos = self.filter_data()
        # if `num_samples > 0`, return the first `num_samples` data information
        self.data_infos = self._slice_data()
        # serialize data_infos
        if self.serialize_data:
            self.data_infos_bytes, self.data_address = self._serialize_data()
            # Empty cache for preventing making multiple copies of
            # `self.data_info` when loading data multi-processes.
            self.data_infos.clear()
            gc.collect()
        self._fully_initialized = True

    @property
    def meta(self) -> dict:
        """Get meta information of dataset.

        Returns:
            dict: A deep copy of the meta information collected from
            ``BaseDataset.META``, annotation file and meta parameter during
            instantiation.
        """
        return copy.deepcopy(self._meta)

    def parse_annotations(self,
                          raw_data_info: dict) -> Union[dict, List[dict]]:
        """Parse raw annotation to target format.

        ``parse_annotations`` should return ``dict`` or ``List[dict]``. Each
        dict contains the annotations of a training sample. If the protocol of
        the sample annotations is changed, this function can be overridden to
        update the parsing logic while keeping compatibility.

        Args:
            raw_data_info (dict): Raw annotation load from ``ann_file``

        Returns:
            Union[dict, List[dict]]: Parsed annotation. The base
            implementation returns ``raw_data_info`` unchanged.
        """
        return raw_data_info

    def filter_data(self) -> List[dict]:
        """Filter annotations according to filter_cfg. Defaults return all
        ``data_infos``.

        If some ``data_infos`` could be filtered according to specific logic,
        the subclass should override this method.

        Returns:
            List[dict]: Filtered results.
        """
        return self.data_infos

    def get_cat_ids(self, idx: int) -> List[int]:
        """Get category ids by index. Dataset wrapped by ClassBalancedDataset
        must implement this method.

        Args:
            idx (int): The index of data.

        Returns:
            List[int]: All categories in the image of specified index.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(f'{type(self)} must implement `get_cat_ids` '
                                  'method')

    def __getitem__(self, idx: int) -> dict:
        """Get the idx-th image of dataset after ``self.pipeline`` and
        ``full_init`` will be called if the dataset has not been fully
        initialized.

        During training phase, if ``self.pipeline`` returns ``None``,
        ``self._rand_another`` will be called until a valid image is fetched or
        the maximum limit of refetch is reached.

        Args:
            idx (int): The index of self.data_infos

        Returns:
            dict: The idx-th image of dataset after ``self.pipeline``.

        Raises:
            Exception: In training mode, if no valid sample is found within
                ``self.max_refetch`` attempts.
        """
        if not self._fully_initialized:
            warnings.warn(
                'Please call `full_init()` method manually to accelerate '
                'the speed.')
            self.full_init()

        # In test mode the pipeline result is returned as-is, without retry.
        if self.test_mode:
            return self._prepare_data(idx)

        for _ in range(self.max_refetch):
            data_sample = self._prepare_data(idx)
            if data_sample is None:
                idx = self._rand_another()
                continue
            return data_sample

        raise Exception(f'Cannot find valid image after {self.max_refetch}! '
                        'Please check your image path and pipelines')

    def load_annotations(self, ann_file: str) -> List[dict]:
        """Load annotations from an annotation file.

        If the annotation file does not follow `OpenMMLab 2.0 format dataset
        <https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md>`_ ,
        the subclass must override this method to load annotations.

        Keys of the file's ``metadata`` that are not already present in
        ``self._meta`` are merged into it.

        Args:
            ann_file (str): Path of the annotation file. ``full_init`` passes
                ``self.ann_file``, which has already been joined with
                ``data_root`` when ``data_root`` was given.

        Returns:
            List[dict]: A list of annotation.

        Raises:
            TypeError: If the loaded file is not a dict, or if
                ``parse_annotations`` returns something other than a dict or
                a list of dicts.
            ValueError: If ``data_infos`` or ``metadata`` is missing.
        """  # noqa: E501
        check_file_exist(ann_file)
        annotations = load(ann_file)
        if not isinstance(annotations, dict):
            raise TypeError(f'The annotations loaded from annotation file '
                            f'should be a dict, but got {type(annotations)}!')
        if 'data_infos' not in annotations or 'metadata' not in annotations:
            raise ValueError('Annotation must have data_infos and metadata '
                             'keys')
        meta_data = annotations['metadata']
        raw_data_infos = annotations['data_infos']

        # update self._meta
        for k, v in meta_data.items():
            # We only merge keys that are not contained in self._meta.
            self._meta.setdefault(k, v)

        # load and parse data_infos
        data_infos = []
        for raw_data_info in raw_data_infos:
            data_info = self.parse_annotations(raw_data_info)
            if isinstance(data_info, dict):
                data_infos.append(data_info)
            elif isinstance(data_info, list):
                # data_info can also be a list of dict, which means one
                # data_info contains multiple samples.
                for item in data_info:
                    if not isinstance(item, dict):
                        raise TypeError('data_info must be a list[dict], but '
                                        f'got {type(item)}')
                data_infos.extend(data_info)
            else:
                raise TypeError('data_info should be a dict or list[dict], '
                                f'but got {type(data_info)}')

        return data_infos

    @classmethod
    def _get_meta_data(cls, in_meta: Optional[dict] = None) -> dict:
        """Collect meta information from the dictionary of meta.

        Args:
            in_meta (dict, optional): Meta information dict. If a value of
                ``in_meta`` is the name of an existing file, it will be
                parsed by ``list_from_file``.

        Returns:
            dict: A copy of ``cls.META`` updated with the parsed ``in_meta``.

        Raises:
            TypeError: If ``in_meta`` is neither None nor a dict.
        """
        # cls.META will be overwritten by in_meta
        cls_meta = copy.deepcopy(cls.META)
        if in_meta is None:
            return cls_meta
        if not isinstance(in_meta, dict):
            raise TypeError(
                f'in_meta should be a dict, but got {type(in_meta)}')

        for k, v in in_meta.items():
            if isinstance(v, str) and osp.isfile(v):
                # if filename in in_meta, this key will be further parsed.
                # nested filename will be ignored.
                cls_meta[k] = list_from_file(v)
            else:
                cls_meta[k] = v

        return cls_meta

    def _join_prefix(self):
        """Join ``self.data_root`` with ``self.data_prefix`` and
        ``self.ann_file``.

        Relative paths are joined with ``self.data_root``; absolute paths are
        left untouched; a prefix of ``None`` is replaced by ``self.data_root``.

        Examples:
            >>> # self.data_prefix contains relative paths
            >>> self.data_root = 'a/b/c'
            >>> self.data_prefix = dict(img='d/e/')
            >>> self.ann_file = 'f'
            >>> self._join_prefix()
            >>> self.data_prefix
            {'img': 'a/b/c/d/e/'}
            >>> self.ann_file
            'a/b/c/f'
            >>> # self.data_prefix contains absolute paths
            >>> self.data_root = 'a/b/c'
            >>> self.data_prefix = dict(img='/d/e/')
            >>> self.ann_file = 'f'
            >>> self._join_prefix()
            >>> self.data_prefix
            {'img': '/d/e/'}
            >>> self.ann_file
            'a/b/c/f'

        Raises:
            TypeError: If a prefix is neither a string nor None.
        """
        if not osp.isabs(self.ann_file):
            self.ann_file = osp.join(self.data_root, self.ann_file)

        # Only values of existing keys are reassigned, so iterating over
        # ``items()`` while updating is safe.
        for data_key, prefix in self.data_prefix.items():
            if prefix is None:
                self.data_prefix[data_key] = self.data_root
            elif isinstance(prefix, str):
                if not osp.isabs(prefix):
                    self.data_prefix[data_key] = osp.join(
                        self.data_root, prefix)
            else:
                raise TypeError('prefix should be a string or None, but got '
                                f'{type(prefix)}')

    def _slice_data(self) -> List[dict]:
        """Slice ``self.data_infos``. BaseDataset supports only using the first
        few data.

        Returns:
            List[dict]: The first ``self._num_samples`` items of
            ``self.data_infos`` if ``self._num_samples > 0``, otherwise all of
            ``self.data_infos``.

        Raises:
            AssertionError: If ``self._num_samples`` exceeds the number of
                loaded samples.
        """
        num_total = len(self.data_infos)
        # Asking for exactly ``num_total`` samples is valid. Report sizes,
        # not the whole annotation list, in the message.
        assert self._num_samples <= num_total, \
            f'Slice size({self._num_samples}) is larger than dataset ' \
            f'size({num_total}), please keep `num_samples` no larger than ' \
            f'{num_total}'
        if self._num_samples > 0:
            return self.data_infos[:self._num_samples]
        return self.data_infos

    def _serialize_data(self) -> Tuple[np.ndarray, np.ndarray]:
        """Serialize ``self.data_infos`` to save memory when launching multiple
        workers in data loading. This function will be called in ``full_init``.

        Hold memory using serialized objects, and data loader workers can use
        shared RAM from master process instead of making a copy.

        Returns:
            Tuple[np.ndarray, np.ndarray]: The concatenated pickled bytes of
            all samples (``uint8``) and the cumulative end offset of each
            sample (``int64``).
        """

        def _serialize(data):
            buffer = pickle.dumps(data, protocol=4)
            return np.frombuffer(buffer, dtype=np.uint8)

        serialized_data_infos_list = [_serialize(x) for x in self.data_infos]
        # ``np.concatenate`` rejects an empty sequence; an empty dataset
        # (e.g. everything filtered out) serializes to two empty arrays.
        if not serialized_data_infos_list:
            return (np.empty(0, dtype=np.uint8),
                    np.empty(0, dtype=np.int64))

        address_list = np.asarray([len(x) for x in serialized_data_infos_list],
                                  dtype=np.int64)
        data_address: np.ndarray = np.cumsum(address_list)
        serialized_data_infos = np.concatenate(serialized_data_infos_list)

        return serialized_data_infos, data_address

    def _rand_another(self) -> int:
        """Get random index.

        Returns:
            int: Random index from 0 to ``len(self)-1``
        """
        return np.random.randint(0, len(self))

    def _prepare_data(self, idx) -> Any:
        """Get data processed by ``self.pipeline``.

        Args:
            idx (int): The index of ``data_info``.

        Returns:
            Any: Depends on ``self.pipeline``.
        """
        data_info = self.get_data_info(idx)
        return self.pipeline(data_info)

    @force_full_init
    def __len__(self) -> int:
        """Get the length of filtered dataset and automatically call
        ``full_init`` if the dataset has not been fully init.

        Returns:
            int: The length of filtered dataset.
        """
        if self.serialize_data:
            return len(self.data_address)
        return len(self.data_infos)

View File

@ -0,0 +1,364 @@
# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import copy
import math
import warnings
from collections import defaultdict
from typing import List, Sequence, Tuple
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
from .base_dataset import BaseDataset, force_full_init
class ConcatDataset(_ConcatDataset):
    """A wrapper of concatenated dataset.

    Same as ``torch.utils.data.dataset.ConcatDataset`` and support lazy_init.

    Args:
        datasets (Sequence[BaseDataset]): A list of datasets which will be
            concatenated.
        lazy_init (bool, optional): Whether to postpone ``full_init`` until
            it is needed. Defaults to False.
    """

    def __init__(self,
                 datasets: Sequence[BaseDataset],
                 lazy_init: bool = False):
        # Only use meta of first dataset.
        self._meta = datasets[0].meta
        self.datasets = datasets  # type: ignore
        for i, dataset in enumerate(datasets, 1):
            if self._meta != dataset.meta:
                warnings.warn(
                    f'The meta information of the {i}-th dataset does not '
                    'match meta information of the first dataset')

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    @property
    def meta(self) -> dict:
        """Get the meta information of the first dataset in ``self.datasets``.

        Returns:
            dict: Meta information of first dataset.
        """
        # Prevent `self._meta` from being modified by outside.
        return copy.deepcopy(self._meta)

    def full_init(self):
        """Loop to ``full_init`` each dataset."""
        if self._fully_initialized:
            return
        for d in self.datasets:
            d.full_init()
        # Get the cumulative sizes of `self.datasets`. For example, the length
        # of `self.datasets` is [2, 3, 4], the cumulative sizes is [2, 5, 9]
        super().__init__(self.datasets)
        self._fully_initialized = True

    @force_full_init
    def _get_ori_dataset_idx(self, idx: int) -> Tuple[int, int]:
        """Convert global idx to local index.

        Args:
            idx (int): Global index of ``ConcatDataset``.

        Returns:
            Tuple[int, int]: The index of ``self.datasets`` and the local
            index of data.

        Raises:
            ValueError: If ``idx`` is negative and its absolute value exceeds
                the total length.
        """
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    f'absolute value of index({idx}) should not exceed dataset '
                    f'length({len(self)}).')
            idx = len(self) + idx
        # Get the inner index of single dataset
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]

        return dataset_idx, sample_idx

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Global index of ``ConcatDataset``.

        Returns:
            dict: The idx-th annotation of the datasets.
        """
        dataset_idx, sample_idx = self._get_ori_dataset_idx(idx)
        return self.datasets[dataset_idx].get_data_info(sample_idx)

    @force_full_init
    def __len__(self):
        """Get the total length, initializing the datasets first if needed."""
        return super().__len__()

    def __getitem__(self, idx):
        """Get the sample at global index ``idx`` from the owning dataset."""
        if not self._fully_initialized:
            warnings.warn('Please call `full_init` method manually to '
                          'accelerate the speed.')
            self.full_init()
        dataset_idx, sample_idx = self._get_ori_dataset_idx(idx)
        return self.datasets[dataset_idx][sample_idx]
class RepeatDataset:
    """A wrapper of repeated dataset.

    The length of repeated dataset will be `times` larger than the original
    dataset. This is useful when the data loading time is long but the dataset
    is small. Using RepeatDataset can reduce the data loading time between
    epochs.

    Args:
        dataset (BaseDataset): The dataset to be repeated.
        times (int): Repeat times.
        lazy_init (bool, optional): Whether to postpone ``full_init`` until
            it is needed. Defaults to False.
    """

    def __init__(self,
                 dataset: BaseDataset,
                 times: int,
                 lazy_init: bool = False):
        self.dataset = dataset
        self.times = times
        # Snapshot of the wrapped dataset's meta taken at construction time.
        self._meta = dataset.meta

        self._fully_initialized = False
        if lazy_init:
            return
        self.full_init()

    @property
    def meta(self) -> dict:
        """Get the meta information of the repeated dataset.

        Returns:
            dict: A deep copy of the meta information of the wrapped dataset.
        """
        meta_copy = copy.deepcopy(self._meta)
        return meta_copy

    def full_init(self):
        """Fully initialize the wrapped dataset and cache its length."""
        if not self._fully_initialized:
            self.dataset.full_init()
            self._ori_len = len(self.dataset)
            self._fully_initialized = True

    @force_full_init
    def _get_ori_dataset_idx(self, idx: int) -> int:
        """Convert global index to local index.

        Args:
            idx (int): Global index of ``RepeatDataset``.

        Returns:
            int: Local index of data, i.e. ``idx`` modulo the length of the
            wrapped dataset.
        """
        local_idx = idx % self._ori_len
        return local_idx

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Global index of ``RepeatDataset``.

        Returns:
            dict: The annotation that the wrapped dataset returns for the
            corresponding local index.
        """
        return self.dataset.get_data_info(self._get_ori_dataset_idx(idx))

    def __getitem__(self, idx):
        """Get the sample of the wrapped dataset that ``idx`` maps to."""
        if not self._fully_initialized:
            warnings.warn('Please call `full_init` method manually to '
                          'accelerate the speed.')
            self.full_init()

        return self.dataset[self._get_ori_dataset_idx(idx)]

    @force_full_init
    def __len__(self):
        """Get the repeated length: ``times`` x the wrapped dataset length."""
        repeated_len = self.times * self._ori_len
        return repeated_len
class ClassBalancedDataset:
    """A wrapper of class balanced dataset.

    Suitable for training on class imbalanced datasets like LVIS. Following
    the sampling strategy in the `paper <https://arxiv.org/abs/1908.03195>`_,
    in each epoch, an image may appear multiple times based on its
    "repeat factor".
    The repeat factor for an image is a function of the frequency the rarest
    category labeled in that image. The "frequency of category c" in [0, 1]
    is defined by the fraction of images in the training set (without repeats)
    in which category c appears.
    The dataset needs to implement :meth:`get_cat_ids` to support
    ClassBalancedDataset.

    The repeat factor is computed as followed.

    1. For each category c, compute the fraction # of images
       that contain it: :math:`f(c)`
    2. For each category c, compute the category-level repeat factor:
       :math:`r(c) = max(1, sqrt(t/f(c)))`
    3. For each image I, compute the image-level repeat factor:
       :math:`r(I) = max_{c in I} r(c)`

    Args:
        dataset (BaseDataset): The dataset to be repeated.
        oversample_thr (float): frequency threshold below which data is
            repeated. For categories with ``f_c >= oversample_thr``, there is
            no oversampling. For categories with ``f_c < oversample_thr``, the
            degree of oversampling following the square-root inverse frequency
            heuristic above.
        lazy_init (bool, optional): Whether to postpone ``full_init`` until
            it is needed. Defaults to False
    """

    def __init__(self,
                 dataset: BaseDataset,
                 oversample_thr: float,
                 lazy_init: bool = False):
        self.dataset = dataset
        self.oversample_thr = oversample_thr
        self._meta = dataset.meta

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    @property
    def meta(self) -> dict:
        """Get the meta information of the wrapped dataset.

        Returns:
            dict: A deep copy of the meta information of the wrapped dataset.
        """
        return copy.deepcopy(self._meta)

    def full_init(self):
        """Fully initialize the wrapped dataset and build
        ``self.repeat_indices``, in which every original index appears
        ``ceil(repeat_factor)`` times."""
        if self._fully_initialized:
            return

        self.dataset.full_init()
        repeat_factors = self._get_repeat_factors(self.dataset,
                                                  self.oversample_thr)
        repeat_indices = []
        for dataset_index, repeat_factor in enumerate(repeat_factors):
            repeat_indices.extend([dataset_index] * math.ceil(repeat_factor))
        self.repeat_indices = repeat_indices

        self._fully_initialized = True

    def _get_repeat_factors(self, dataset: BaseDataset,
                            repeat_thr: float) -> List[float]:
        """Get repeat factor for each images in the dataset.

        Args:
            dataset (BaseDataset): The dataset.
            repeat_thr (float): The threshold of frequency. If an image
                contains the categories whose frequency below the threshold,
                it would be repeated.

        Returns:
            List[float]: The repeat factors for each images in the dataset.
            Images without any category get a factor of ``1.0``.
        """
        num_images = len(dataset)
        # Query the categories once per image and reuse them in both passes.
        cat_ids_per_image = [
            set(dataset.get_cat_ids(idx)) for idx in range(num_images)
        ]

        # 1. For each category c, compute the fraction # of images
        #    that contain it: f(c)
        category_count: defaultdict = defaultdict(float)
        for cat_ids in cat_ids_per_image:
            for cat_id in cat_ids:
                category_count[cat_id] += 1
        category_freq = {
            cat_id: count / num_images
            for cat_id, count in category_count.items()
        }

        # 2. For each category c, compute the category-level repeat factor:
        #    r(c) = max(1, sqrt(t/f(c)))
        category_repeat = {
            cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
            for cat_id, cat_freq in category_freq.items()
        }

        # 3. For each image I and its labels L(I), compute the image-level
        #    repeat factor:
        #    r(I) = max_{c in L(I)} r(c)
        repeat_factors = []
        for cat_ids in cat_ids_per_image:
            # ``max`` of an empty set raises; an image without labels is
            # simply kept once.
            repeat_factor = 1.0
            if cat_ids:
                repeat_factor = max(
                    category_repeat[cat_id] for cat_id in cat_ids)
            repeat_factors.append(repeat_factor)

        return repeat_factors

    @force_full_init
    def _get_ori_dataset_idx(self, idx: int) -> int:
        """Convert global index to local index.

        Args:
            idx (int): Global index of ``ClassBalancedDataset``.

        Returns:
            int: Local index of data.
        """
        return self.repeat_indices[idx]

    @force_full_init
    def get_cat_ids(self, idx: int) -> List[int]:
        """Get category ids of class balanced dataset by index.

        Args:
            idx (int): Index of data.

        Returns:
            List[int]: All categories in the image of specified index.
        """
        sample_idx = self._get_ori_dataset_idx(idx)
        return self.dataset.get_cat_ids(sample_idx)

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Global index of ``ClassBalancedDataset``.

        Returns:
            dict: The idx-th annotation of the dataset.
        """
        sample_idx = self._get_ori_dataset_idx(idx)
        return self.dataset.get_data_info(sample_idx)

    def __getitem__(self, idx):
        """Get the sample of the wrapped dataset that ``idx`` maps to."""
        # Warn only when initialization really has to happen here, not on
        # every access.
        if not self._fully_initialized:
            warnings.warn('Please call `full_init` method manually to '
                          'accelerate the speed.')
            self.full_init()

        ori_index = self._get_ori_dataset_idx(idx)
        return self.dataset[ori_index]

    @force_full_init
    def __len__(self):
        """Get the number of samples after oversampling."""
        return len(self.repeat_indices)

View File

@ -0,0 +1,5 @@
[
{
"img_path": "test_img.jpg"
}
]

View File

@ -2,7 +2,8 @@
"metadata":
{
"dataset_type": "test_dataset",
"task_name": "test_task"
"task_name": "test_task",
"empty_list": []
},
"data_infos":
[

View File

@ -0,0 +1 @@
dog

View File

@ -1,40 +1,66 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from unittest.mock import MagicMock
import pytest
import torch
from mmengine.data import (BaseDataset, ClassBalancedDataset, ConcatDataset,
RepeatDataset)
from mmengine.dataset import (BaseDataset, ClassBalancedDataset, Compose,
ConcatDataset, RepeatDataset, force_full_init)
from mmengine.registry import TRANSFORMS
def function_pipeline(data_info):
    """Return ``data_info`` unchanged."""
    result = data_info
    return result
@TRANSFORMS.register_module()
class CallableTransform:
    """Registered transform whose instances return their input unchanged."""

    def __call__(self, data_info):
        result = data_info
        return result
@TRANSFORMS.register_module()
class NotCallableTransform:
    """Registered class that defines no ``__call__`` method."""
class TestBaseDataset:
dataset_type = BaseDataset
data_info = dict(
filename='test_img.jpg', height=604, width=640, sample_idx=0)
imgs = torch.rand((2, 3, 32, 32))
pipeline = MagicMock(return_value=dict(imgs=imgs))
META: dict = dict()
parse_annotations = MagicMock(return_value=data_info)
def __init__(self):
self.base_dataset = BaseDataset
self.data_info = dict(filename='test_img.jpg', height=604, width=640)
self.base_dataset.parse_annotations = MagicMock(
return_value=self.data_info)
self.imgs = torch.rand((2, 3, 32, 32))
self.base_dataset.pipeline = MagicMock(
return_value=dict(imgs=self.imgs))
def _init_dataset(self):
self.dataset_type.META = self.META
self.dataset_type.parse_annotations = self.parse_annotations
def test_init(self):
self._init_dataset()
# test the instantiation of self.base_dataset
dataset = self.base_dataset(
dataset = self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json')
assert dataset._fully_initialized
assert hasattr(dataset, 'data_infos')
assert hasattr(dataset, 'data_address')
dataset = self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img=None),
ann_file='annotations/dummy_annotation.json')
assert dataset._fully_initialized
assert hasattr(dataset, 'data_infos')
assert hasattr(dataset, 'data_address')
# test the instantiation of self.base_dataset with
# `serialize_data=False`
dataset = self.base_dataset(
dataset = self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json',
@ -42,33 +68,54 @@ class TestBaseDataset:
assert dataset._fully_initialized
assert hasattr(dataset, 'data_infos')
assert not hasattr(dataset, 'data_address')
assert len(dataset) == 2
assert dataset.get_data_info(0) == self.data_info
# test the instantiation of self.base_dataset with lazy init
dataset = self.base_dataset(
dataset = self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json',
lazy_init=True)
assert not dataset._fully_initialized
assert not hasattr(dataset, 'data_infos')
assert not dataset.data_infos
# test the instantiation of self.base_dataset if ann_file is not
# existed.
with pytest.raises(FileNotFoundError):
self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/not_existed_annotation.json')
# test the instantiation of self.base_dataset when the ann_file is
# wrong
with pytest.raises(ValueError):
self.base_dataset(
self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/wrong_annotation.json')
ann_file='annotations/annotation_wrong_keys.json')
with pytest.raises(TypeError):
self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/annotation_wrong_format.json')
with pytest.raises(TypeError):
self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img=['img']),
ann_file='annotations/annotation_wrong_format.json')
# test the instantiation of self.base_dataset when `parse_annotations`
# return `list[dict]`
self.base_dataset.parse_annotations = MagicMock(
self.dataset_type.parse_annotations = MagicMock(
return_value=[self.data_info,
self.data_info.copy()])
dataset = self.base_dataset(
dataset = self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json')
dataset.pipeline = self.pipeline
assert dataset._fully_initialized
assert hasattr(dataset, 'data_infos')
assert hasattr(dataset, 'data_address')
@ -76,98 +123,139 @@ class TestBaseDataset:
assert dataset[0] == dict(imgs=self.imgs)
assert dataset.get_data_info(0) == self.data_info
# set self.base_dataset to initial state
self.__init__()
# test the instantiation of self.base_dataset when `parse_annotations`
# return unsupported data.
with pytest.raises(TypeError):
self.dataset_type.parse_annotations = MagicMock(return_value='xxx')
dataset = self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json')
with pytest.raises(TypeError):
self.dataset_type.parse_annotations = MagicMock(
return_value=[self.data_info, 'xxx'])
dataset = self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json')
def test_meta(self):
    """Check how ``dataset.meta`` is assembled from its possible sources.

    The cases covered, in order: meta read from the annotation file, meta
    set through the class-level ``META``, meta passed to the constructor
    (as a dict, as a dict pointing at a file, and as an unsupported
    type), meta under ``lazy_init=True``, and ``META`` overridden in a
    subclass.
    """
    self._init_dataset()
    data_root = osp.join(osp.dirname(__file__), '../data/')

    # test dataset.meta with setting the meta from annotation file as the
    # meta of self.dataset_type
    dataset = self.dataset_type(
        data_root=data_root,
        data_prefix=dict(img='imgs'),
        ann_file='annotations/dummy_annotation.json')
    assert dataset.meta == dict(
        dataset_type='test_dataset', task_name='test_task', empty_list=[])

    # test dataset.meta with setting META in self.dataset_type
    dataset_type = 'new_dataset'
    self.dataset_type.META = dict(
        dataset_type=dataset_type, classes=('dog', 'cat'))
    dataset = self.dataset_type(
        data_root=data_root,
        data_prefix=dict(img='imgs'),
        ann_file='annotations/dummy_annotation.json')
    assert dataset.meta == dict(
        dataset_type=dataset_type,
        task_name='test_task',
        classes=('dog', 'cat'),
        empty_list=[])

    # test dataset.meta with passing meta into self.dataset_type
    meta = dict(classes=('dog', ), task_name='new_task')
    dataset = self.dataset_type(
        data_root=data_root,
        data_prefix=dict(img='imgs'),
        ann_file='annotations/dummy_annotation.json',
        meta=meta)
    # the class-level META must be left untouched by the instance meta
    assert self.dataset_type.META == dict(
        dataset_type=dataset_type, classes=('dog', 'cat'))
    assert dataset.meta == dict(
        dataset_type=dataset_type,
        task_name='new_task',
        classes=('dog', ),
        empty_list=[])

    # reset `dataset_type.META`, the `dataset.meta` should not change
    self.dataset_type.META['classes'] = ('dog', 'cat', 'fish')
    assert self.dataset_type.META == dict(
        dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))
    assert dataset.meta == dict(
        dataset_type=dataset_type,
        task_name='new_task',
        classes=('dog', ),
        empty_list=[])

    # test dataset.meta with passing meta containing a file into
    # self.dataset_type
    meta = dict(
        classes=osp.join(
            osp.dirname(__file__), '../data/meta/classes.txt'))
    dataset = self.dataset_type(
        data_root=data_root,
        data_prefix=dict(img='imgs'),
        ann_file='annotations/dummy_annotation.json',
        meta=meta)
    assert dataset.meta == dict(
        dataset_type=dataset_type,
        task_name='test_task',
        classes=['dog'],
        empty_list=[])

    # test dataset.meta with passing unsupported meta into
    # self.dataset_type
    with pytest.raises(TypeError):
        meta = 'dog'
        dataset = self.dataset_type(
            data_root=data_root,
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            meta=meta)

    # test dataset.meta with passing meta into self.dataset_type and
    # lazy_init is True
    meta = dict(classes=('dog', ))
    dataset = self.dataset_type(
        data_root=data_root,
        data_prefix=dict(img='imgs'),
        ann_file='annotations/dummy_annotation.json',
        meta=meta,
        lazy_init=True)
    # 'task_name' and 'empty_list' not in dataset.meta
    assert dataset.meta == dict(
        dataset_type=dataset_type, classes=('dog', ))

    # test whether self.dataset_type.META is changed when a customize
    # dataset inherit self.dataset_type
    # test reset META in ToyDataset.
    class ToyDataset(self.dataset_type):
        META = dict(xxx='xxx')

    assert ToyDataset.META == dict(xxx='xxx')
    assert self.dataset_type.META == dict(
        dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))

    # test update META in ToyDataset.
    class ToyDataset(self.dataset_type):
        # deep-copy first so that editing the subclass META cannot write
        # through to the parent's dict
        META = copy.deepcopy(self.dataset_type.META)
        META['classes'] = ('bird', )

    assert ToyDataset.META == dict(
        dataset_type=dataset_type, classes=('bird', ))
    assert self.dataset_type.META == dict(
        dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))
@pytest.mark.parametrize('lazy_init', [True, False])
def test_length(self, lazy_init):
dataset = self.base_dataset(
dataset = self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json',
lazy_init=lazy_init)
if not lazy_init:
assert dataset._fully_initialized
assert hasattr(dataset, 'data_infos')
@ -175,20 +263,49 @@ class TestBaseDataset:
else:
# test `__len__()` when lazy_init is True
assert not dataset._fully_initialized
assert not hasattr(dataset, 'data_infos')
assert not dataset.data_infos
# call `full_init()` automatically
assert len(dataset) == 2
assert dataset._fully_initialized
assert hasattr(dataset, 'data_infos')
def test_compose(self):
    """Exercise ``Compose`` with callables, config dicts and bad input."""
    # a plain callable is accepted as a transform
    callable_pipeline = Compose(transforms=[function_pipeline])
    composed_out = callable_pipeline({'img': self.imgs})
    assert (self.imgs == composed_out['img']).all()

    # a transform can also be built from a config dict
    cfg_pipeline = Compose(transforms=[dict(type='CallableTransform')])
    composed_out = cfg_pipeline({'img': self.imgs})
    assert (self.imgs == composed_out['img']).all()

    # a transform returning None makes the whole pipeline return None
    none_func = MagicMock(return_value=None)
    compose = Compose(transforms=[none_func, function_pipeline])
    assert compose({'img': self.imgs}) is None

    # repr lists every transform on its own line
    # NOTE(review): the whitespace inside these literals must match
    # `Compose.__repr__` exactly -- confirm against the implementation.
    expected_repr = ('Compose(\n'
                     f' {none_func}\n'
                     f' {function_pipeline}\n'
                     ')')
    assert repr(compose) == expected_repr

    # a config that builds a non-callable transform is rejected
    bad_transforms = [dict(type='NotCallableTransform')]
    with pytest.raises(TypeError):
        Compose(bad_transforms)

    # a transform must be callable or a dict
    with pytest.raises(TypeError):
        Compose([1])
@pytest.mark.parametrize('lazy_init', [True, False])
def test_getitem(self, lazy_init):
dataset = self.base_dataset(
dataset = self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json',
lazy_init=lazy_init)
dataset.pipeline = self.pipeline
if not lazy_init:
assert dataset._fully_initialized
assert hasattr(dataset, 'data_infos')
@ -196,15 +313,26 @@ class TestBaseDataset:
else:
# test `__getitem__()` when lazy_init is True
assert not dataset._fully_initialized
assert not hasattr(dataset, 'data_infos')
assert not dataset.data_infos
# call `full_init()` automatically
assert dataset[0] == dict(imgs=self.imgs)
assert dataset._fully_initialized
assert hasattr(dataset, 'data_infos')
# test with test mode
dataset.test_mode = True
assert dataset[0] == dict(imgs=self.imgs)
pipeline = MagicMock(return_value=None)
dataset.pipeline = pipeline
# test cannot get a valid image.
dataset.test_mode = False
with pytest.raises(Exception):
dataset[0]
@pytest.mark.parametrize('lazy_init', [True, False])
def test_get_data_info(self, lazy_init):
dataset = self.base_dataset(
dataset = self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json',
@ -217,20 +345,32 @@ class TestBaseDataset:
else:
# test `get_data_info()` when lazy_init is True
assert not dataset._fully_initialized
assert not hasattr(dataset, 'data_infos')
assert not dataset.data_infos
# call `full_init()` automatically
assert dataset.get_data_info(0) == self.data_info
assert dataset._fully_initialized
assert hasattr(dataset, 'data_infos')
def test_force_full_init(self):
    """Expect ``AttributeError`` from a decorated method without ``full_init``.

    ``ClassWithoutFullInit`` defines only ``foo`` (wrapped with
    ``force_full_init``); defining the class and calling ``foo`` must raise
    ``AttributeError``.
    """
    with pytest.raises(AttributeError):

        class ClassWithoutFullInit:

            @force_full_init
            def foo(self):
                pass

        # instantiate and invoke in one step; both stay inside the
        # `raises` block together with the class definition
        ClassWithoutFullInit().foo()
@pytest.mark.parametrize('lazy_init', [True, False])
def test_full_init(self, lazy_init):
dataset = self.base_dataset(
dataset = self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json',
lazy_init=lazy_init)
dataset.pipeline = self.pipeline
if not lazy_init:
assert dataset._fully_initialized
assert hasattr(dataset, 'data_infos')
@ -240,7 +380,7 @@ class TestBaseDataset:
else:
# test `full_init()` when lazy_init is True
assert not dataset._fully_initialized
assert not hasattr(dataset, 'data_infos')
assert not dataset.data_infos
# call `full_init()` manually
dataset.full_init()
assert dataset._fully_initialized
@ -249,55 +389,116 @@ class TestBaseDataset:
assert dataset[0] == dict(imgs=self.imgs)
assert dataset.get_data_info(0) == self.data_info
def test_slice_data(self):
    """A dataset built with ``num_samples=1`` reports a length of 1."""
    # test the instantiation of self.dataset_type when passing num_samples
    data_root = osp.join(osp.dirname(__file__), '../data/')
    sliced = self.dataset_type(
        data_root=data_root,
        data_prefix=dict(img=None),
        ann_file='annotations/dummy_annotation.json',
        num_samples=1)
    assert len(sliced) == 1
def test_rand_another(self):
    """Check that ``_rand_another()`` returns an index inside the dataset."""
    dataset = self.dataset_type(
        data_root=osp.join(osp.dirname(__file__), '../data/'),
        data_prefix=dict(img=None),
        ann_file='annotations/dummy_annotation.json',
        num_samples=1)
    # Draw once and check both bounds on the same value; asserting each
    # bound on a separate draw would never validate a single index fully.
    rand_idx = dataset._rand_another()
    assert 0 <= rand_idx < len(dataset)
class TestConcatDataset:
def _init_dataset(self):
    """Build two ``BaseDataset`` instances and concatenate them.

    Sets ``self.dataset_a``, ``self.dataset_b`` and ``self.cat_datasets``.
    This is an ordinary helper rather than ``__init__`` because pytest
    does not collect test classes that define ``__init__``; each test
    calls it explicitly.
    """
    dataset = BaseDataset

    # create dataset_a
    data_info = dict(filename='test_img.jpg', height=604, width=640)
    dataset.parse_annotations = MagicMock(return_value=data_info)
    imgs = torch.rand((2, 3, 32, 32))
    self.dataset_a = dataset(
        data_root=osp.join(osp.dirname(__file__), '../data/'),
        data_prefix=dict(img='imgs'),
        ann_file='annotations/dummy_annotation.json')
    # The pipeline mock is attached to the instance, not to the
    # `BaseDataset` class, so dataset_a and dataset_b each keep their own
    # images.
    self.dataset_a.pipeline = MagicMock(return_value=dict(imgs=imgs))

    # create dataset_b
    data_info = dict(filename='gray.jpg', height=288, width=512)
    dataset.parse_annotations = MagicMock(return_value=data_info)
    imgs = torch.rand((2, 3, 32, 32))
    self.dataset_b = dataset(
        data_root=osp.join(osp.dirname(__file__), '../data/'),
        data_prefix=dict(img='imgs'),
        ann_file='annotations/dummy_annotation.json',
        meta=dict(classes=('dog', 'cat')))
    self.dataset_b.pipeline = MagicMock(return_value=dict(imgs=imgs))

    # test init
    self.cat_datasets = ConcatDataset(
        datasets=[self.dataset_a, self.dataset_b])
def test_full_init(self):
    """``ConcatDataset`` built lazily reports 4 samples after full init.

    The length is checked after an explicit ``full_init()`` and again
    after the ``_fully_initialized`` flag is cleared and an item is
    indexed.
    """
    base_cls = BaseDataset
    data_root = osp.join(osp.dirname(__file__), '../data/')

    # create dataset_a
    info_a = dict(filename='test_img.jpg', height=604, width=640)
    base_cls.parse_annotations = MagicMock(return_value=info_a)
    imgs_a = torch.rand((2, 3, 32, 32))
    dataset_a = base_cls(
        data_root=data_root,
        data_prefix=dict(img='imgs'),
        ann_file='annotations/dummy_annotation.json')
    dataset_a.pipeline = MagicMock(return_value={'imgs': imgs_a})

    # create dataset_b
    info_b = dict(filename='gray.jpg', height=288, width=512)
    base_cls.parse_annotations = MagicMock(return_value=info_b)
    imgs_b = torch.rand((2, 3, 32, 32))
    dataset_b = base_cls(
        data_root=data_root,
        data_prefix=dict(img='imgs'),
        ann_file='annotations/dummy_annotation.json',
        meta=dict(classes=('dog', 'cat')))
    dataset_b.pipeline = MagicMock(return_value={'imgs': imgs_b})

    # test init with lazy_init=True
    concat = ConcatDataset(
        datasets=[dataset_a, dataset_b], lazy_init=True)
    concat.full_init()
    assert len(concat) == 4

    # a second `full_init()` call, then clear the flag and index an item
    concat.full_init()
    concat._fully_initialized = False
    concat[1]
    assert len(concat) == 4
def test_meta(self):
    """The concatenated dataset's meta equals that of its first dataset."""
    self._init_dataset()
    concat_meta = self.cat_datasets.meta
    assert concat_meta == self.dataset_a.meta
    # When datasets with different metas are concatenated, the meta of
    # self.cat_datasets comes from the first one, so it differs from
    # dataset_b's.
    concat_meta = self.cat_datasets.meta
    assert concat_meta != self.dataset_b.meta
def test_length(self):
    """The concatenated length is the sum of both dataset lengths."""
    self._init_dataset()
    concat_len = len(self.cat_datasets)
    assert concat_len == len(self.dataset_a) + len(self.dataset_b)
def test_getitem(self):
    """Index 0 maps to ``dataset_a`` and index -1 maps to ``dataset_b``."""
    # The fixtures must exist before any indexing takes place.
    self._init_dataset()

    # Compare the image tensors element-wise and reduce with `.all()`;
    # comparing the whole sample dicts with `==`/`!=` would need the truth
    # value of a multi-element tensor, which is ambiguous.
    assert (
        self.cat_datasets[0]['imgs'] == self.dataset_a[0]['imgs']).all()
    assert (self.cat_datasets[0]['imgs'] !=
            self.dataset_b[0]['imgs']).all()

    assert (
        self.cat_datasets[-1]['imgs'] == self.dataset_b[-1]['imgs']).all()
    assert (self.cat_datasets[-1]['imgs'] !=
            self.dataset_a[-1]['imgs']).all()
def test_get_data_info(self):
self._init_dataset()
assert self.cat_datasets.get_data_info(
0) == self.dataset_a.get_data_info(0)
assert self.cat_datasets.get_data_info(
@ -306,40 +507,76 @@ class TestConcatDataset:
assert self.cat_datasets.get_data_info(
-1) == self.dataset_b.get_data_info(-1)
assert self.cat_datasets.get_data_info(
-1) != self.dataset_a[-1].get_data_info(-1)
-1) != self.dataset_a.get_data_info(-1)
def test_get_ori_dataset_idx(self):
    """Check the tuple returned by ``_get_ori_dataset_idx``.

    Index 3 and index -1 must both resolve to dataset ``1`` with the
    expected local offset; index -10 must raise ``ValueError``.
    """
    self._init_dataset()

    located = self.cat_datasets._get_ori_dataset_idx(3)
    assert located == (1, 3 - len(self.dataset_a))

    located = self.cat_datasets._get_ori_dataset_idx(-1)
    assert located == (1, len(self.dataset_b) - 1)

    with pytest.raises(ValueError):
        assert self.cat_datasets._get_ori_dataset_idx(-10)
class TestRepeatDataset:
    """Tests for ``RepeatDataset`` wrapping a mocked ``BaseDataset``."""

    @staticmethod
    def _build_base_dataset():
        """Return a ``BaseDataset`` instance with mocked parsing and pipeline."""
        dataset = BaseDataset
        data_info = dict(filename='test_img.jpg', height=604, width=640)
        dataset.parse_annotations = MagicMock(return_value=data_info)
        imgs = torch.rand((2, 3, 32, 32))
        dataset = dataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        # The pipeline mock is attached to the instance, not to the shared
        # `BaseDataset` class.
        dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
        return dataset

    def _init_dataset(self):
        """Create ``self.dataset`` and the ``RepeatDataset`` wrapping it.

        This is an ordinary helper rather than ``__init__`` because pytest
        does not collect test classes that define ``__init__``; each test
        calls it explicitly.
        """
        self.dataset = self._build_base_dataset()
        self.repeat_times = 5
        # test init
        self.repeat_datasets = RepeatDataset(
            dataset=self.dataset, times=self.repeat_times)

    def test_full_init(self):
        """A lazily built wrapper reports ``times * len(dataset)`` samples."""
        dataset = self._build_base_dataset()
        repeat_times = 5
        # test init
        repeat_datasets = RepeatDataset(
            dataset=dataset, times=repeat_times, lazy_init=True)

        repeat_datasets.full_init()
        assert len(repeat_datasets) == repeat_times * len(dataset)

        # a second `full_init()` call, then clear the flag and index an item
        repeat_datasets.full_init()
        repeat_datasets._fully_initialized = False
        repeat_datasets[1]
        assert len(repeat_datasets) == repeat_times * len(dataset)

    def test_meta(self):
        """The wrapper exposes the wrapped dataset's meta."""
        self._init_dataset()
        assert self.repeat_datasets.meta == self.dataset.meta

    def test_length(self):
        """The wrapper length is the dataset length times ``repeat_times``."""
        self._init_dataset()
        assert len(
            self.repeat_datasets) == len(self.dataset) * self.repeat_times

    def test_getitem(self):
        """The first item of every repetition equals ``dataset[0]``."""
        self._init_dataset()
        for i in range(self.repeat_times):
            assert self.repeat_datasets[len(self.dataset) *
                                        i] == self.dataset[0]

    def test_get_data_info(self):
        """The first data info of every repetition equals that of index 0."""
        self._init_dataset()
        for i in range(self.repeat_times):
            assert self.repeat_datasets.get_data_info(
                len(self.dataset) * i) == self.dataset.get_data_info(0)
@ -347,17 +584,17 @@ class TestRepeatDataset:
class TestClassBalancedDataset:
def __init__(self):
def _init_dataset(self):
dataset = BaseDataset
data_info = dict(filename='test_img.jpg', height=604, width=640)
dataset.parse_annotations = MagicMock(return_value=data_info)
imgs = torch.rand((2, 3, 32, 32))
dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
dataset.get_cat_ids = MagicMock(return_value=[0])
self.dataset = dataset(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json')
self.dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
self.repeat_indices = [0, 0, 1, 1, 1]
# test init
@ -365,23 +602,54 @@ class TestClassBalancedDataset:
dataset=self.dataset, oversample_thr=1e-3)
self.cls_banlanced_datasets.repeat_indices = self.repeat_indices
def test_full_init(self):
    """A lazily built ``ClassBalancedDataset`` follows ``repeat_indices``.

    ``repeat_indices`` is overwritten by hand after each initialisation,
    and the wrapper length must equal the number of those indices.
    """
    base_cls = BaseDataset
    sample_info = dict(filename='test_img.jpg', height=604, width=640)
    base_cls.parse_annotations = MagicMock(return_value=sample_info)
    imgs = torch.rand((2, 3, 32, 32))
    base_cls.get_cat_ids = MagicMock(return_value=[0])
    dataset = base_cls(
        data_root=osp.join(osp.dirname(__file__), '../data/'),
        data_prefix=dict(img='imgs'),
        ann_file='annotations/dummy_annotation.json')
    dataset.pipeline = MagicMock(return_value={'imgs': imgs})
    repeat_indices = [0, 0, 1, 1, 1]

    # test init
    balanced = ClassBalancedDataset(
        dataset=dataset, oversample_thr=1e-3, lazy_init=True)
    balanced.full_init()
    balanced.repeat_indices = repeat_indices
    assert len(balanced) == len(repeat_indices)

    # a second `full_init()` call, then clear the flag and index an item
    balanced.full_init()
    balanced._fully_initialized = False
    balanced[1]
    balanced.repeat_indices = repeat_indices
    assert len(balanced) == len(repeat_indices)
def test_meta(self):
    """The class-balanced wrapper exposes the wrapped dataset's meta."""
    self._init_dataset()
    wrapper_meta = self.cls_banlanced_datasets.meta
    assert wrapper_meta == self.dataset.meta
def test_length(self):
    """The wrapper length equals the number of repeat indices."""
    self._init_dataset()
    wrapper_len = len(self.cls_banlanced_datasets)
    assert wrapper_len == len(self.repeat_indices)
def test_getitem(self):
    """Each wrapper item equals the wrapped item at its repeat index."""
    self._init_dataset()
    for wrapper_idx, source_idx in enumerate(self.repeat_indices):
        sample = self.cls_banlanced_datasets[wrapper_idx]
        assert sample == self.dataset[source_idx]
def test_get_data_info(self):
    """Each wrapper data info equals the wrapped one at its repeat index."""
    self._init_dataset()
    for wrapper_idx, source_idx in enumerate(self.repeat_indices):
        info = self.cls_banlanced_datasets.get_data_info(wrapper_idx)
        assert info == self.dataset.get_data_info(source_idx)
def test_get_cat_ids(self):
    """Each wrapper ``get_cat_ids`` result equals the wrapped one."""
    self._init_dataset()
    for wrapper_idx, source_idx in enumerate(self.repeat_indices):
        cat_ids = self.cls_banlanced_datasets.get_cat_ids(wrapper_idx)
        assert cat_ids == self.dataset.get_cat_ids(source_idx)