Mirror of https://github.com/open-mmlab/mmengine.git (synced 2025-06-03 21:54:44 +08:00)
[Feature] add base dataset (#32)
* basedataset first commit
* add base dataset
* add dataset
* add basedataset
* Fix test dataset
* Fix mypy and test
* Fix mypy and test
* remove unused code
* Update mmengine/dataset/base_dataset.py (Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>)
* Update mmengine/dataset/base_dataset.py (Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>)
* add more corner cases in unittest
* fix lint
* Fix as comment
* Fix lint
* update unittest
* Type hint Dict to dict
* rename max_refetch
* Fix as comment
* Fix typo
* Fix as comment
* BaseDataset is no longer an abstract class, change UT and docs
* Fix as comment
* Fix as comment and refactor type error
* Add comment for full init
* Fix as comment and modify dataset_wrapper
* Fix as comment and modify dataset_wrapper
* Fix as comment
* Fix as comment
* Fix as comment
* Fix as comment
* Fix as comment
* Fix as comment
* Fix as comment

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: Tao Gong <gongtao950513@gmail.com>
parent adb2aee8c2
commit ada6660c65
docs/zh_cn/tutorials/basedataset.md

@ -69,7 +69,7 @@ data
2. Build the data pipeline for data preprocessing and preparation;
- 3. Read and parse the annotation file that follows the OpenMMLab 2.0 dataset format specification; this step involves the abstract method `parse_annotations()`, which parses each raw data item in the annotation file;
+ 3. Read and parse the annotation file that follows the OpenMMLab 2.0 dataset format specification; this step involves the `parse_annotations()` method, which parses each raw data item in the annotation file;
4. Filter out useless data, such as samples without annotations;
@ -77,8 +77,6 @@ data
6. Serialize all samples to save memory; for details, see [节省内存](#节省内存).
- The dataset base class is an abstract class with one and only one abstract method, `parse_annotations()`, which defines how one raw data item from the annotation file is turned into one or more training/testing samples. Custom dataset classes must therefore implement `parse_annotations()`.
### Interfaces provided by the dataset base class
Similar to `torch.utils.data.Dataset`, once initialized the dataset supports `__getitem__` for indexing data and `__len__` for querying the dataset size. Beyond that, the OpenMMLab dataset base class mainly provides the following interfaces to access specific information:
@ -93,7 +91,7 @@ data
## Customizing a dataset class based on the dataset base class
- After understanding the initialization flow and the interfaces provided by the dataset base class, you can build a custom dataset class on top of it. As mentioned above, the base class is an abstract class with one and only one abstract method, `parse_annotations()`, so users must implement this method in their custom dataset class. Below is an example of implementing a concrete dataset with the dataset base class.
+ After understanding the initialization flow and the interfaces provided by the dataset base class, you can build a custom dataset class on top of it. As mentioned above, for annotation files that follow the OpenMMLab 2.0 dataset format specification, users can override `parse_annotations()` to load the labels. Below is an example of implementing a concrete dataset with the dataset base class.

```python
import os.path as osp
```
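The tutorial's example is cut off by the diff hunk after its first import. A minimal sketch of such a subclass, assuming the `BaseDataset` API added in `mmengine/dataset/base_dataset.py` below; the `ToyDataset` name and the exact annotation fields are illustrative only:

```python
import os.path as osp

from mmengine.dataset import BaseDataset


class ToyDataset(BaseDataset):
    # Class-level meta information; it can be overridden by the `meta`
    # argument and is merged with the metadata stored in the annotation file.
    META = dict(classes=('cat', 'dog'))

    def parse_annotations(self, raw_data_info):
        # Turn one raw record from `ann_file` into a training sample, e.g.
        # join the image prefix onto the relative image path.
        img_prefix = self.data_prefix.get('img')
        if img_prefix is not None:
            raw_data_info['img_path'] = osp.join(img_prefix,
                                                 raw_data_info['img_path'])
        return raw_data_info
```

Since `parse_annotations()` may also return a list of dicts, a single raw record can be expanded into several samples.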
mmengine/__init__.py

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa
from .config import *
+ from .dataset import *
from .fileio import *
from .registry import *
from .utils import *
mmengine/dataset/__init__.py (new file, 4 lines)

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa
from .base_dataset import BaseDataset, Compose, force_full_init
from .dataset_wrapper import ClassBalancedDataset, ConcatDataset, RepeatDataset
mmengine/dataset/base_dataset.py (new file, 553 lines)

@ -0,0 +1,553 @@
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import functools
|
||||
import gc
|
||||
import os.path as osp
|
||||
import pickle
|
||||
import warnings
|
||||
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from mmengine.fileio import list_from_file, load
|
||||
from mmengine.registry import TRANSFORMS
|
||||
from mmengine.utils import check_file_exist
|
||||
|
||||
|
||||
class Compose:
|
||||
"""Compose multiple transforms sequentially.
|
||||
|
||||
Args:
|
||||
transforms (Sequence[dict, callable]): Sequence of transform object or
|
||||
config dict to be composed.
|
||||
"""
|
||||
|
||||
def __init__(self, transforms: Sequence[Union[dict, Callable]]):
|
||||
self.transforms: List[Callable] = []
|
||||
for transform in transforms:
|
||||
if isinstance(transform, dict):
|
||||
transform = TRANSFORMS.build(transform)
|
||||
if not callable(transform):
|
||||
raise TypeError(f'transform should be a callable object, '
|
||||
f'but got {type(transform)}')
|
||||
self.transforms.append(transform)
|
||||
elif callable(transform):
|
||||
self.transforms.append(transform)
|
||||
else:
|
||||
raise TypeError(
|
||||
f'transform must be a callable object or dict, '
|
||||
f'but got {type(transform)}')
|
||||
|
||||
def __call__(self, data: dict) -> Optional[dict]:
|
||||
"""Call function to apply transforms sequentially.
|
||||
|
||||
Args:
|
||||
data (dict): A result dict contains the data to transform.
|
||||
|
||||
Returns:
|
||||
dict: Transformed data.
|
||||
"""
|
||||
for t in self.transforms:
|
||||
data = t(data)
|
||||
if data is None:
|
||||
return None
|
||||
return data
|
||||
|
||||
def __repr__(self):
|
||||
"""Print ``self.transforms`` in sequence.
|
||||
|
||||
Returns:
|
||||
str: Formatted string.
|
||||
"""
|
||||
format_string = self.__class__.__name__ + '('
|
||||
for t in self.transforms:
|
||||
format_string += '\n'
|
||||
format_string += f' {t}'
|
||||
format_string += '\n)'
|
||||
return format_string
|
||||
|
||||
|
||||
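``Compose`` accepts a mix of config dicts (built through ``TRANSFORMS``) and plain callables, and short-circuits to ``None`` as soon as any transform returns ``None``. A minimal usage sketch with a hypothetical callable transform, not part of this file:

```python
from mmengine.dataset import Compose


def add_flag(results):
    # A toy callable transform; returning None drops the sample.
    if results.get('drop', False):
        return None
    results['flag'] = True
    return results


pipeline = Compose([add_flag])
print(pipeline(dict(img='demo.jpg')))  # {'img': 'demo.jpg', 'flag': True}
print(pipeline(dict(drop=True)))       # None, the sample is skipped
```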
def force_full_init(old_func: Callable) -> Any:
|
||||
"""Those methods decorated by ``force_full_init`` will be forced to call
|
||||
``full_init`` if the instance has not been fully initiated.
|
||||
|
||||
Args:
|
||||
old_func (Callable): Decorated function, make sure the first arg is an
|
||||
instance with ``full_init`` method.
|
||||
|
||||
Returns:
|
||||
Any: Depends on old_func.
|
||||
"""
|
||||
# TODO This decorator can also be used in runner.
|
||||
@functools.wraps(old_func)
|
||||
def wrapper(obj: object, *args, **kwargs):
|
||||
if not hasattr(obj, 'full_init'):
|
||||
raise AttributeError(f'{type(obj)} does not have full_init '
|
||||
'method.')
|
||||
if not getattr(obj, '_fully_initialized', False):
|
||||
warnings.warn('Attribute `_fully_initialized` is not defined in '
|
||||
f'{type(obj)} or `type(obj)._fully_initialized is '
|
||||
'False, `full_init` will be called and '
|
||||
f'{type(obj)}._fully_initialized will be set to '
|
||||
'True')
|
||||
obj.full_init() # type: ignore
|
||||
obj._fully_initialized = True # type: ignore
|
||||
|
||||
return old_func(obj, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
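A small sketch of how ``force_full_init`` interacts with lazy initialization, using a hypothetical ``LazyThing`` class rather than ``BaseDataset`` itself:

```python
from mmengine.dataset import force_full_init


class LazyThing:
    _fully_initialized = False

    def full_init(self):
        self._data = [1, 2, 3]
        self._fully_initialized = True

    @force_full_init
    def size(self):
        # Safe to touch self._data here: full_init has already run.
        return len(self._data)


thing = LazyThing()
print(thing.size())  # warns, implicitly calls full_init(), then prints 3
```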
class BaseDataset(Dataset):
|
||||
r"""BaseDataset for open source projects in OpenMMLab.
|
||||
|
||||
The annotation format is shown as follows.
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
{
|
||||
"metadata":
|
||||
{
|
||||
"dataset_type": "test_dataset",
|
||||
"task_name": "test_task"
|
||||
},
|
||||
"data_infos":
|
||||
[
|
||||
{
|
||||
"img_path": "test_img.jpg",
|
||||
"height": 604,
|
||||
"width": 640,
|
||||
"instances":
|
||||
[
|
||||
{
|
||||
"bbox": [0, 0, 10, 20],
|
||||
"bbox_label": 1,
|
||||
"mask": [[0,0],[0,10],[10,20],[20,0]],
|
||||
"extra_anns": [1,2,3]
|
||||
},
|
||||
{
|
||||
"bbox": [10, 10, 110, 120],
|
||||
"bbox_label": 2,
|
||||
"mask": [[10,10],[10,110],[110,120],[120,10]],
|
||||
"extra_anns": [4,5,6]
|
||||
}
|
||||
]
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
Args:
|
||||
ann_file (str): Annotation file path.
|
||||
meta (dict, optional): Meta information for dataset, such as class
|
||||
information. Defaults to None.
|
||||
data_root (str, optional): The root directory for ``data_prefix`` and
|
||||
``ann_file``. Defaults to None.
|
||||
data_prefix (dict, optional): Prefix for training data. Defaults to
|
||||
dict(img=None, ann=None).
|
||||
filter_cfg (dict, optional): Config for filter data. Defaults to None.
|
||||
        num_samples (int, optional): Support using only the first few samples
            in the annotation file, to facilitate training/testing on a
            smaller dataset. Defaults to -1, which means using all
            ``data_infos``.
|
||||
        serialize_data (bool, optional): Whether to hold memory using
            serialized objects; when enabled, data loader workers can use
            shared RAM from the master process instead of making a copy.
            Defaults to True.
|
||||
pipeline (list, optional): Processing pipeline. Defaults to [].
|
||||
test_mode (bool, optional): ``test_mode=True`` means in test phase.
|
||||
Defaults to False.
|
||||
        lazy_init (bool, optional): Whether to load annotations during
            instantiation. In some cases, such as visualization, only the
            meta information of the dataset is needed and it is not
            necessary to load the annotation file. ``BaseDataset`` can skip
            loading annotations to save time by setting ``lazy_init=True``.
            Defaults to False.
|
||||
max_refetch (int, optional): The maximum number of cycles to get a
|
||||
valid image. Defaults to 1000.
|
||||
|
||||
    Note:
        ``BaseDataset`` collects meta information from the annotation file
        (the lowest priority), ``BaseDataset.META`` (medium priority) and
        the ``meta`` argument passed to the constructor (the highest
        priority). Lower-priority meta information is overwritten by
        higher-priority meta information.
|
||||
"""
|
||||
|
||||
META: dict = dict()
|
||||
_fully_initialized: bool = False
|
||||
|
||||
def __init__(self,
|
||||
ann_file: str,
|
||||
meta: Optional[dict] = None,
|
||||
data_root: Optional[str] = None,
|
||||
data_prefix: dict = dict(img=None, ann=None),
|
||||
filter_cfg: Optional[dict] = None,
|
||||
num_samples: int = -1,
|
||||
serialize_data: bool = True,
|
||||
pipeline: List[Union[dict, Callable]] = [],
|
||||
test_mode: bool = False,
|
||||
lazy_init: bool = False,
|
||||
max_refetch: int = 1000):
|
||||
|
||||
self.data_root = data_root
|
||||
self.data_prefix = copy.copy(data_prefix)
|
||||
self.ann_file = ann_file
|
||||
self.filter_cfg = copy.deepcopy(filter_cfg)
|
||||
self._num_samples = num_samples
|
||||
self.serialize_data = serialize_data
|
||||
self.test_mode = test_mode
|
||||
self.max_refetch = max_refetch
|
||||
self.data_infos: List[dict] = []
|
||||
self.data_infos_bytes: bytearray = bytearray()
|
||||
|
||||
# set meta information
|
||||
self._meta = self._get_meta_data(copy.deepcopy(meta))
|
||||
|
||||
# join paths
|
||||
if self.data_root is not None:
|
||||
self._join_prefix()
|
||||
|
||||
# build pipeline
|
||||
self.pipeline = Compose(pipeline)
|
||||
|
||||
if not lazy_init:
|
||||
self.full_init()
|
||||
|
||||
@force_full_init
|
||||
def get_data_info(self, idx: int) -> dict:
|
||||
"""Get annotation by index and automatically call ``full_init`` if the
|
||||
dataset has not been fully initialized.
|
||||
|
||||
Args:
|
||||
idx (int): The index of data.
|
||||
|
||||
Returns:
|
||||
dict: The idx-th annotation of the dataset.
|
||||
"""
|
||||
if self.serialize_data:
|
||||
start_addr = 0 if idx == 0 else self.data_address[idx - 1].item()
|
||||
end_addr = self.data_address[idx].item()
|
||||
bytes = memoryview(self.data_infos_bytes[start_addr:end_addr])
|
||||
data_info = pickle.loads(bytes)
|
||||
else:
|
||||
data_info = self.data_infos[idx]
|
||||
# To record the real positive index of data information.
|
||||
if idx >= 0:
|
||||
data_info['sample_idx'] = idx
|
||||
else:
|
||||
data_info['sample_idx'] = len(self) + idx
|
||||
|
||||
return data_info
|
||||
|
||||
def full_init(self):
|
||||
"""Load annotation file and set ``BaseDataset._fully_initialized`` to
|
||||
True.
|
||||
|
||||
If ``lazy_init=False``, ``full_init`` will be called during the
|
||||
instantiation and ``self._fully_initialized`` will be set to True. If
|
||||
``obj._fully_initialized=False``, the class method decorated by
|
||||
``force_full_init`` will call ``full_init`` automatically.
|
||||
|
||||
        Several steps to initialize annotations:

            - ``load_annotations``: Load annotations from the annotation
              file.
            - ``filter_data``: Filter annotations according to
              ``filter_cfg``.
            - ``_slice_data``: Slice the dataset according to
              ``self._num_samples``.
            - ``_serialize_data``: Serialize ``self.data_infos`` if
              ``self.serialize_data`` is True.
|
||||
"""
|
||||
if self._fully_initialized:
|
||||
return
|
||||
|
||||
# load data information
|
||||
self.data_infos = self.load_annotations(self.ann_file)
|
||||
# filter illegal data, such as data that has no annotations.
|
||||
self.data_infos = self.filter_data()
|
||||
# if `num_samples > 0`, return the first `num_samples` data information
|
||||
self.data_infos = self._slice_data()
|
||||
# serialize data_infos
|
||||
if self.serialize_data:
|
||||
self.data_infos_bytes, self.data_address = self._serialize_data()
|
||||
# Empty cache for preventing making multiple copies of
|
||||
# `self.data_info` when loading data multi-processes.
|
||||
self.data_infos.clear()
|
||||
gc.collect()
|
||||
self._fully_initialized = True
|
||||
|
||||
@property
|
||||
def meta(self) -> dict:
|
||||
"""Get meta information of dataset.
|
||||
|
||||
Returns:
|
||||
dict: meta information collected from ``BaseDataset.META``,
|
||||
annotation file and meta parameter during instantiation.
|
||||
"""
|
||||
return copy.deepcopy(self._meta)
|
||||
|
||||
def parse_annotations(self,
|
||||
raw_data_info: dict) -> Union[dict, List[dict]]:
|
||||
"""Parse raw annotation to target format.
|
||||
|
||||
``parse_annotations`` should return ``dict`` or ``List[dict]``. Each
|
||||
dict contains the annotations of a training sample. If the protocol of
|
||||
the sample annotations is changed, this function can be overridden to
|
||||
update the parsing logic while keeping compatibility.
|
||||
|
||||
Args:
|
||||
raw_data_info (dict): Raw annotation load from ``ann_file``
|
||||
|
||||
Returns:
|
||||
Union[dict, List[dict]]: Parsed annotation.
|
||||
"""
|
||||
return raw_data_info
|
||||
|
||||
def filter_data(self) -> List[dict]:
|
||||
"""Filter annotations according to filter_cfg. Defaults return all
|
||||
``data_infos``.
|
||||
|
||||
If some ``data_infos`` could be filtered according to specific logic,
|
||||
the subclass should override this method.
|
||||
|
||||
Returns:
|
||||
List[dict]: Filtered results.
|
||||
"""
|
||||
return self.data_infos
|
||||
|
||||
def get_cat_ids(self, idx: int) -> List[int]:
|
||||
"""Get category ids by index. Dataset wrapped by ClassBalancedDataset
|
||||
must implement this method.
|
||||
|
||||
The ``ClassBalancedDataset`` requires a subclass which implements this
|
||||
method.
|
||||
|
||||
Args:
|
||||
idx (int): The index of data.
|
||||
|
||||
Returns:
|
||||
List[int]: All categories in the image of specified index.
|
||||
"""
|
||||
raise NotImplementedError(f'{type(self)} must implement `get_cat_ids` '
|
||||
'method')
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
"""Get the idx-th image of dataset after ``self.pipelines`` and
|
||||
``full_init`` will be called if the dataset has not been fully
|
||||
initialized.
|
||||
|
||||
During training phase, if ``self.pipelines`` get ``None``,
|
||||
``self._rand_another`` will be called until a valid image is fetched or
|
||||
the maximum limit of refetech is reached.
|
||||
|
||||
Args:
|
||||
idx (int): The index of self.data_infos
|
||||
|
||||
Returns:
|
||||
            dict: The idx-th image of the dataset after ``self.pipeline``.
|
||||
"""
|
||||
if not self._fully_initialized:
|
||||
warnings.warn(
|
||||
'Please call `full_init()` method manually to accelerate '
|
||||
'the speed.')
|
||||
self.full_init()
|
||||
|
||||
if self.test_mode:
|
||||
return self._prepare_data(idx)
|
||||
|
||||
for _ in range(self.max_refetch):
|
||||
data_sample = self._prepare_data(idx)
|
||||
if data_sample is None:
|
||||
idx = self._rand_another()
|
||||
continue
|
||||
return data_sample
|
||||
|
||||
raise Exception(f'Cannot find valid image after {self.max_refetch}! '
|
||||
'Please check your image path and pipelines')
|
||||
|
||||
def load_annotations(self, ann_file: str) -> List[dict]:
|
||||
"""Load annotations from an annotation file.
|
||||
|
||||
        If the annotation file does not follow the `OpenMMLab 2.0 dataset
        format <https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md>`_,
        the subclass must override this method to load annotations.
|
||||
|
||||
Args:
|
||||
            ann_file (str): Absolute annotation file path if
                ``self.data_root=None``, or a path relative to
                ``self.data_root`` otherwise.
|
||||
|
||||
Returns:
|
||||
List[dict]: A list of annotation.
|
||||
""" # noqa: E501
|
||||
check_file_exist(ann_file)
|
||||
annotations = load(ann_file)
|
||||
if not isinstance(annotations, dict):
|
||||
raise TypeError(f'The annotations loaded from annotation file '
|
||||
f'should be a dict, but got {type(annotations)}!')
|
||||
if 'data_infos' not in annotations or 'metadata' not in annotations:
|
||||
raise ValueError('Annotation must have data_infos and metadata '
|
||||
'keys')
|
||||
meta_data = annotations['metadata']
|
||||
raw_data_infos = annotations['data_infos']
|
||||
|
||||
# update self._meta
|
||||
for k, v in meta_data.items():
|
||||
# We only merge keys that are not contained in self._meta.
|
||||
self._meta.setdefault(k, v)
|
||||
|
||||
# load and parse data_infos
|
||||
data_infos = []
|
||||
for raw_data_info in raw_data_infos:
|
||||
data_info = self.parse_annotations(raw_data_info)
|
||||
if isinstance(data_info, dict):
|
||||
data_infos.append(data_info)
|
||||
elif isinstance(data_info, list):
|
||||
# data_info can also be a list of dict, which means one
|
||||
# data_info contains multiple samples.
|
||||
for item in data_info:
|
||||
if not isinstance(item, dict):
|
||||
raise TypeError('data_info must be a list[dict], but '
|
||||
f'got {type(item)}')
|
||||
data_infos.extend(data_info)
|
||||
else:
|
||||
raise TypeError('data_info should be a dict or list[dict], '
|
||||
f'but got {type(data_info)}')
|
||||
|
||||
return data_infos
|
||||
|
||||
@classmethod
|
||||
def _get_meta_data(cls, in_meta: dict = None) -> dict:
|
||||
"""Collect meta information from the dictionary of meta.
|
||||
|
||||
Args:
|
||||
            in_meta (dict): Meta information dict. If a value in ``in_meta``
                is an existing file path, it will be parsed by
                ``list_from_file``.
|
||||
|
||||
Returns:
|
||||
dict: Parsed meta information.
|
||||
"""
|
||||
# cls.META will be overwritten by in_meta
|
||||
cls_meta = copy.deepcopy(cls.META)
|
||||
if in_meta is None:
|
||||
return cls_meta
|
||||
if not isinstance(in_meta, dict):
|
||||
raise TypeError(
|
||||
f'in_meta should be a dict, but got {type(in_meta)}')
|
||||
|
||||
for k, v in in_meta.items():
|
||||
if isinstance(v, str) and osp.isfile(v):
|
||||
# if filename in in_meta, this key will be further parsed.
|
||||
# nested filename will be ignored.:
|
||||
cls_meta[k] = list_from_file(v)
|
||||
else:
|
||||
cls_meta[k] = v
|
||||
|
||||
return cls_meta
|
||||
|
||||
def _join_prefix(self):
|
||||
"""Join ``self.data_root`` with ``self.data_prefix`` and
|
||||
``self.ann_file``.
|
||||
|
||||
Examples:
|
||||
>>> # self.data_prefix contains relative paths
|
||||
>>> self.data_root = 'a/b/c'
|
||||
>>> self.data_prefix = dict(img='d/e/')
|
||||
>>> self.ann_file = 'f'
|
||||
>>> self._join_prefix()
|
||||
>>> self.data_prefix
|
||||
dict(img='a/b/c/d/e')
|
||||
>>> self.ann_file
|
||||
'a/b/c/f'
|
||||
>>> # self.data_prefix contains absolute paths
|
||||
>>> self.data_root = 'a/b/c'
|
||||
>>> self.data_prefix = dict(img='/d/e/')
|
||||
>>> self.ann_file = 'f'
|
||||
>>> self._join_prefix()
|
||||
>>> self.data_prefix
|
||||
dict(img='/d/e')
|
||||
>>> self.ann_file
|
||||
'a/b/c/f'
|
||||
"""
|
||||
if not osp.isabs(self.ann_file):
|
||||
self.ann_file = osp.join(self.data_root, self.ann_file)
|
||||
|
||||
for data_key, prefix in self.data_prefix.items():
|
||||
if prefix is None:
|
||||
self.data_prefix[data_key] = self.data_root
|
||||
elif isinstance(prefix, str):
|
||||
if not osp.isabs(prefix):
|
||||
self.data_prefix[data_key] = osp.join(
|
||||
self.data_root, prefix)
|
||||
else:
|
||||
raise TypeError('prefix should be a string or None, but got '
|
||||
f'{type(prefix)}')
|
||||
|
||||
def _slice_data(self) -> List[dict]:
|
||||
"""Slice ``self.data_infos``. BaseDataset supports only using the first
|
||||
few data.
|
||||
|
||||
Returns:
|
||||
List[dict]: A slice of ``self.data_infos``
|
||||
"""
|
||||
        assert self._num_samples < len(self.data_infos), \
            f'Slice size ({self._num_samples}) is larger than dataset ' \
            f'size ({len(self.data_infos)}). Please keep `num_samples` ' \
            f'smaller than {len(self.data_infos)}.'
|
||||
if self._num_samples > 0:
|
||||
return self.data_infos[:self._num_samples]
|
||||
else:
|
||||
return self.data_infos
|
||||
|
||||
def _serialize_data(self) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Serialize ``self.data_infos`` to save memory when launching multiple
|
||||
workers in data loading. This function will be called in ``full_init``.
|
||||
|
||||
Hold memory using serialized objects, and data loader workers can use
|
||||
shared RAM from master process instead of making a copy.
|
||||
|
||||
Returns:
|
||||
Tuple[np.ndarray, np.ndarray]: serialize result and corresponding
|
||||
address.
|
||||
"""
|
||||
|
||||
def _serialize(data):
|
||||
buffer = pickle.dumps(data, protocol=4)
|
||||
return np.frombuffer(buffer, dtype=np.uint8)
|
||||
|
||||
serialized_data_infos_list = [_serialize(x) for x in self.data_infos]
|
||||
address_list = np.asarray([len(x) for x in serialized_data_infos_list],
|
||||
dtype=np.int64)
|
||||
data_address: np.ndarray = np.cumsum(address_list)
|
||||
serialized_data_infos = np.concatenate(serialized_data_infos_list)
|
||||
|
||||
return serialized_data_infos, data_address
|
||||
|
||||
def _rand_another(self) -> int:
|
||||
"""Get random index.
|
||||
|
||||
Returns:
|
||||
int: Random index from 0 to ``len(self)-1``
|
||||
"""
|
||||
return np.random.randint(0, len(self))
|
||||
|
||||
def _prepare_data(self, idx) -> Any:
|
||||
"""Get data processed by ``self.pipeline``.
|
||||
|
||||
Args:
|
||||
idx (int): The index of ``data_info``.
|
||||
|
||||
Returns:
|
||||
Any: Depends on ``self.pipeline``.
|
||||
"""
|
||||
data_info = self.get_data_info(idx)
|
||||
return self.pipeline(data_info)
|
||||
|
||||
@force_full_init
|
||||
def __len__(self) -> int:
|
||||
"""Get the length of filtered dataset and automatically call
|
||||
``full_init`` if the dataset has not been fully init.
|
||||
|
||||
Returns:
|
||||
int: The length of filtered dataset.
|
||||
"""
|
||||
if self.serialize_data:
|
||||
return len(self.data_address)
|
||||
else:
|
||||
return len(self.data_infos)
|
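A usage sketch of ``BaseDataset`` as defined above, assuming an annotation file in the OpenMMLab 2.0 format; the data root and annotation paths are placeholders, not files shipped with this commit:

```python
from mmengine.dataset import BaseDataset

# Eager init (the default): annotations are loaded, filtered, sliced and
# serialized into `data_infos_bytes` + `data_address` right away.
dataset = BaseDataset(
    data_root='data/toy/',
    data_prefix=dict(img='imgs'),
    ann_file='annotations/train.json')
print(len(dataset))              # number of (filtered) samples
print(dataset.get_data_info(0))  # decoded copy of the first record

# Lazy init: only `meta` is available until full_init() -- or any method
# decorated with `force_full_init` -- is called.
lazy_dataset = BaseDataset(
    data_root='data/toy/',
    data_prefix=dict(img='imgs'),
    ann_file='annotations/train.json',
    lazy_init=True)
print(lazy_dataset.meta)
lazy_dataset.full_init()
```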
mmengine/dataset/dataset_wrapper.py (new file, 364 lines)

@ -0,0 +1,364 @@
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import bisect
|
||||
import copy
|
||||
import math
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from typing import List, Sequence, Tuple
|
||||
|
||||
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
|
||||
|
||||
from .base_dataset import BaseDataset, force_full_init
|
||||
|
||||
|
||||
class ConcatDataset(_ConcatDataset):
|
||||
"""A wrapper of concatenated dataset.
|
||||
|
||||
Same as ``torch.utils.data.dataset.ConcatDataset`` and support lazy_init.
|
||||
|
||||
Args:
|
||||
datasets (Sequence[BaseDataset]): A list of datasets which will be
|
||||
concatenated.
|
||||
lazy_init (bool, optional): Whether to load annotation during
|
||||
instantiation. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
datasets: Sequence[BaseDataset],
|
||||
lazy_init: bool = False):
|
||||
# Only use meta of first dataset.
|
||||
self._meta = datasets[0].meta
|
||||
self.datasets = datasets # type: ignore
|
||||
for i, dataset in enumerate(datasets, 1):
|
||||
if self._meta != dataset.meta:
|
||||
warnings.warn(
|
||||
f'The meta information of the {i}-th dataset does not '
|
||||
'match meta information of the first dataset')
|
||||
|
||||
self._fully_initialized = False
|
||||
if not lazy_init:
|
||||
self.full_init()
|
||||
|
||||
@property
|
||||
def meta(self) -> dict:
|
||||
"""Get the meta information of the first dataset in ``self.datasets``.
|
||||
|
||||
Returns:
|
||||
dict: Meta information of first dataset.
|
||||
"""
|
||||
# Prevent `self._meta` from being modified by outside.
|
||||
return copy.deepcopy(self._meta)
|
||||
|
||||
def full_init(self):
|
||||
"""Loop to ``full_init`` each dataset."""
|
||||
if self._fully_initialized:
|
||||
return
|
||||
for d in self.datasets:
|
||||
d.full_init()
|
||||
# Get the cumulative sizes of `self.datasets`. For example, the length
|
||||
# of `self.datasets` is [2, 3, 4], the cumulative sizes is [2, 5, 9]
|
||||
super().__init__(self.datasets)
|
||||
self._fully_initialized = True
|
||||
|
||||
@force_full_init
|
||||
def _get_ori_dataset_idx(self, idx: int) -> Tuple[int, int]:
|
||||
"""Convert global idx to local index.
|
||||
|
||||
Args:
|
||||
            idx (int): Global index of ``ConcatDataset``.
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: The index of ``self.datasets`` and the local
|
||||
index of data.
|
||||
"""
|
||||
if idx < 0:
|
||||
if -idx > len(self):
|
||||
raise ValueError(
|
||||
                    f'absolute value of index({idx}) should not exceed '
                    f'dataset length({len(self)}).')
|
||||
idx = len(self) + idx
|
||||
# Get the inner index of single dataset
|
||||
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
|
||||
if dataset_idx == 0:
|
||||
sample_idx = idx
|
||||
else:
|
||||
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
||||
|
||||
return dataset_idx, sample_idx
|
||||
|
||||
@force_full_init
|
||||
def get_data_info(self, idx: int) -> dict:
|
||||
"""Get annotation by index.
|
||||
|
||||
Args:
|
||||
idx (int): Global index of ``ConcatDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The idx-th annotation of the datasets.
|
||||
"""
|
||||
dataset_idx, sample_idx = self._get_ori_dataset_idx(idx)
|
||||
return self.datasets[dataset_idx].get_data_info(sample_idx)
|
||||
|
||||
@force_full_init
|
||||
def __len__(self):
|
||||
return super().__len__()
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if not self._fully_initialized:
|
||||
warnings.warn('Please call `full_init` method manually to '
|
||||
'accelerate the speed.')
|
||||
self.full_init()
|
||||
dataset_idx, sample_idx = self._get_ori_dataset_idx(idx)
|
||||
return self.datasets[dataset_idx][sample_idx]
|
||||
|
||||
|
||||
class RepeatDataset:
|
||||
"""A wrapper of repeated dataset.
|
||||
|
||||
The length of repeated dataset will be `times` larger than the original
|
||||
dataset. This is useful when the data loading time is long but the dataset
|
||||
is small. Using RepeatDataset can reduce the data loading time between
|
||||
epochs.
|
||||
|
||||
Args:
|
||||
dataset (BaseDataset): The dataset to be repeated.
|
||||
times (int): Repeat times.
|
||||
lazy_init (bool, optional): Whether to load annotation during
|
||||
instantiation. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dataset: BaseDataset,
|
||||
times: int,
|
||||
lazy_init: bool = False):
|
||||
self.dataset = dataset
|
||||
self.times = times
|
||||
self._meta = dataset.meta
|
||||
|
||||
self._fully_initialized = False
|
||||
if not lazy_init:
|
||||
self.full_init()
|
||||
|
||||
@property
|
||||
def meta(self) -> dict:
|
||||
"""Get the meta information of the repeated dataset.
|
||||
|
||||
Returns:
|
||||
dict: The meta information of repeated dataset.
|
||||
"""
|
||||
return copy.deepcopy(self._meta)
|
||||
|
||||
def full_init(self):
|
||||
"""Loop to ``full_init`` each dataset."""
|
||||
if self._fully_initialized:
|
||||
return
|
||||
|
||||
self.dataset.full_init()
|
||||
self._ori_len = len(self.dataset)
|
||||
self._fully_initialized = True
|
||||
|
||||
@force_full_init
|
||||
def _get_ori_dataset_idx(self, idx: int) -> int:
|
||||
"""Convert global index to local index.
|
||||
|
||||
Args:
|
||||
idx: Global index of ``RepeatDataset``.
|
||||
|
||||
Returns:
|
||||
            int: Local index of data.
|
||||
"""
|
||||
return idx % self._ori_len
|
||||
|
||||
@force_full_init
|
||||
def get_data_info(self, idx: int) -> dict:
|
||||
"""Get annotation by index.
|
||||
|
||||
Args:
|
||||
            idx (int): Global index of ``RepeatDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The idx-th annotation of the datasets.
|
||||
"""
|
||||
sample_idx = self._get_ori_dataset_idx(idx)
|
||||
return self.dataset.get_data_info(sample_idx)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if not self._fully_initialized:
|
||||
warnings.warn('Please call `full_init` method manually to '
|
||||
'accelerate the speed.')
|
||||
self.full_init()
|
||||
|
||||
sample_idx = self._get_ori_dataset_idx(idx)
|
||||
return self.dataset[sample_idx]
|
||||
|
||||
@force_full_init
|
||||
def __len__(self):
|
||||
return self.times * self._ori_len
|
||||
|
||||
|
||||
class ClassBalancedDataset:
|
||||
"""A wrapper of class balanced dataset.
|
||||
|
||||
Suitable for training on class imbalanced datasets like LVIS. Following
|
||||
the sampling strategy in the `paper <https://arxiv.org/abs/1908.03195>`_,
|
||||
in each epoch, an image may appear multiple times based on its
|
||||
"repeat factor".
|
||||
    The repeat factor for an image is a function of the frequency of the
    rarest category labeled in that image. The "frequency of category c" in
    [0, 1] is defined as the fraction of images in the training set (without
    repeats) in which category c appears.
    The wrapped dataset needs to implement :meth:`get_cat_ids` to support
    ``ClassBalancedDataset``.
|
||||
|
||||
The repeat factor is computed as followed.
|
||||
|
||||
1. For each category c, compute the fraction # of images
|
||||
that contain it: :math:`f(c)`
|
||||
2. For each category c, compute the category-level repeat factor:
|
||||
:math:`r(c) = max(1, sqrt(t/f(c)))`
|
||||
3. For each image I, compute the image-level repeat factor:
|
||||
:math:`r(I) = max_{c in I} r(c)`
|
||||
|
||||
Args:
|
||||
dataset (BaseDataset): The dataset to be repeated.
|
||||
oversample_thr (float): frequency threshold below which data is
|
||||
repeated. For categories with ``f_c >= oversample_thr``, there is
|
||||
no oversampling. For categories with ``f_c < oversample_thr``, the
|
||||
            degree of oversampling follows the square-root inverse frequency
|
||||
heuristic above.
|
||||
lazy_init (bool, optional): whether to load annotation during
|
||||
instantiation. Defaults to False
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dataset: BaseDataset,
|
||||
oversample_thr: float,
|
||||
lazy_init: bool = False):
|
||||
self.dataset = dataset
|
||||
self.oversample_thr = oversample_thr
|
||||
self._meta = dataset.meta
|
||||
|
||||
self._fully_initialized = False
|
||||
if not lazy_init:
|
||||
self.full_init()
|
||||
|
||||
@property
|
||||
def meta(self) -> dict:
|
||||
"""Get the meta information of the repeated dataset.
|
||||
|
||||
Returns:
|
||||
dict: The meta information of repeated dataset.
|
||||
"""
|
||||
return copy.deepcopy(self._meta)
|
||||
|
||||
def full_init(self):
|
||||
"""Loop to ``full_init`` each dataset."""
|
||||
if self._fully_initialized:
|
||||
return
|
||||
|
||||
self.dataset.full_init()
|
||||
|
||||
repeat_factors = self._get_repeat_factors(self.dataset,
|
||||
self.oversample_thr)
|
||||
repeat_indices = []
|
||||
for dataset_index, repeat_factor in enumerate(repeat_factors):
|
||||
repeat_indices.extend([dataset_index] * math.ceil(repeat_factor))
|
||||
self.repeat_indices = repeat_indices
|
||||
|
||||
self._fully_initialized = True
|
||||
|
||||
def _get_repeat_factors(self, dataset: BaseDataset,
|
||||
repeat_thr: float) -> List[float]:
|
||||
"""Get repeat factor for each images in the dataset.
|
||||
|
||||
Args:
|
||||
dataset (BaseDataset): The dataset.
|
||||
repeat_thr (float): The threshold of frequency. If an image
|
||||
contains the categories whose frequency below the threshold,
|
||||
it would be repeated.
|
||||
|
||||
Returns:
|
||||
            List[float]: The repeat factor for each image in the dataset.
|
||||
"""
|
||||
# 1. For each category c, compute the fraction # of images
|
||||
# that contain it: f(c)
|
||||
category_freq: defaultdict = defaultdict(float)
|
||||
num_images = len(dataset)
|
||||
for idx in range(num_images):
|
||||
cat_ids = set(self.dataset.get_cat_ids(idx))
|
||||
for cat_id in cat_ids:
|
||||
category_freq[cat_id] += 1
|
||||
for k, v in category_freq.items():
|
||||
                assert v > 0, f'category {k} does not contain any images'
|
||||
category_freq[k] = v / num_images
|
||||
|
||||
# 2. For each category c, compute the category-level repeat factor:
|
||||
# r(c) = max(1, sqrt(t/f(c)))
|
||||
category_repeat = {
|
||||
cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
|
||||
for cat_id, cat_freq in category_freq.items()
|
||||
}
|
||||
|
||||
# 3. For each image I and its labels L(I), compute the image-level
|
||||
# repeat factor:
|
||||
# r(I) = max_{c in L(I)} r(c)
|
||||
repeat_factors = []
|
||||
for idx in range(num_images):
|
||||
cat_ids = set(self.dataset.get_cat_ids(idx))
|
||||
repeat_factor = max(
|
||||
{category_repeat[cat_id]
|
||||
for cat_id in cat_ids})
|
||||
repeat_factors.append(repeat_factor)
|
||||
|
||||
return repeat_factors
|
||||
|
||||
@force_full_init
|
||||
def _get_ori_dataset_idx(self, idx: int) -> int:
|
||||
"""Convert global index to local index.
|
||||
|
||||
Args:
|
||||
            idx (int): Global index of ``ClassBalancedDataset``.
|
||||
|
||||
Returns:
|
||||
int: Local index of data.
|
||||
"""
|
||||
return self.repeat_indices[idx]
|
||||
|
||||
@force_full_init
|
||||
def get_cat_ids(self, idx: int) -> List[int]:
|
||||
"""Get category ids of class balanced dataset by index.
|
||||
|
||||
Args:
|
||||
idx (int): Index of data.
|
||||
|
||||
Returns:
|
||||
List[int]: All categories in the image of specified index.
|
||||
"""
|
||||
sample_idx = self._get_ori_dataset_idx(idx)
|
||||
return self.dataset.get_cat_ids(sample_idx)
|
||||
|
||||
@force_full_init
|
||||
def get_data_info(self, idx: int) -> dict:
|
||||
"""Get annotation by index.
|
||||
|
||||
Args:
|
||||
            idx (int): Global index of ``ClassBalancedDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The idx-th annotation of the dataset.
|
||||
"""
|
||||
sample_idx = self._get_ori_dataset_idx(idx)
|
||||
return self.dataset.get_data_info(sample_idx)
|
||||
|
||||
    def __getitem__(self, idx):
        if not self._fully_initialized:
            warnings.warn('Please call `full_init` method manually to '
                          'accelerate the speed.')
            self.full_init()
|
||||
|
||||
ori_index = self._get_ori_dataset_idx(idx)
|
||||
return self.dataset[ori_index]
|
||||
|
||||
@force_full_init
|
||||
def __len__(self):
|
||||
return len(self.repeat_indices)
|
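A usage sketch of the three wrappers above; ``dataset_a`` and ``dataset_b`` stand for already constructed ``BaseDataset`` instances whose class implements ``get_cat_ids`` (required by ``ClassBalancedDataset``), and the comment walks through the repeat-factor rule for a category frequency of 0.0004 with ``oversample_thr=1e-3``:

```python
import math

from mmengine.dataset import (ClassBalancedDataset, ConcatDataset,
                              RepeatDataset)

concat = ConcatDataset(datasets=[dataset_a, dataset_b])  # len(a) + len(b)
repeat = RepeatDataset(dataset=dataset_a, times=3)       # 3 * len(a)

# ClassBalancedDataset: r(c) = max(1, sqrt(t / f(c))).  For f(c) = 0.0004 and
# t = 1e-3 this gives max(1, sqrt(2.5)) ~= 1.58, so an image whose rarest
# category has that frequency appears ceil(1.58) = 2 times per epoch.
balanced = ClassBalancedDataset(dataset=dataset_a, oversample_thr=1e-3)
print(math.ceil(max(1.0, math.sqrt(1e-3 / 0.0004))))     # 2
```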
tests/data/annotations/annotation_wrong_format.json (new file, 5 lines)

@ -0,0 +1,5 @@
[
    {
        "img_path": "test_img.jpg"
    }
]
tests/data/annotations/dummy_annotation.json

@ -2,7 +2,8 @@
    "metadata":
    {
        "dataset_type": "test_dataset",
-       "task_name": "test_task"
+       "task_name": "test_task",
+       "empty_list": []
    },
    "data_infos":
    [
tests/data/meta/classes.txt (new file, 1 line)

@ -0,0 +1 @@
dog
@ -1,40 +1,66 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import os.path as osp
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmengine.data import (BaseDataset, ClassBalancedDataset, ConcatDataset,
|
||||
RepeatDataset)
|
||||
from mmengine.dataset import (BaseDataset, ClassBalancedDataset, Compose,
|
||||
ConcatDataset, RepeatDataset, force_full_init)
|
||||
from mmengine.registry import TRANSFORMS
|
||||
|
||||
|
||||
def function_pipeline(data_info):
|
||||
return data_info
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class CallableTransform:
|
||||
|
||||
def __call__(self, data_info):
|
||||
return data_info
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class NotCallableTransform:
|
||||
pass
|
||||
|
||||
|
||||
class TestBaseDataset:
|
||||
dataset_type = BaseDataset
|
||||
data_info = dict(
|
||||
filename='test_img.jpg', height=604, width=640, sample_idx=0)
|
||||
imgs = torch.rand((2, 3, 32, 32))
|
||||
pipeline = MagicMock(return_value=dict(imgs=imgs))
|
||||
META: dict = dict()
|
||||
parse_annotations = MagicMock(return_value=data_info)
|
||||
|
||||
def __init__(self):
|
||||
self.base_dataset = BaseDataset
|
||||
|
||||
self.data_info = dict(filename='test_img.jpg', height=604, width=640)
|
||||
self.base_dataset.parse_annotations = MagicMock(
|
||||
return_value=self.data_info)
|
||||
|
||||
self.imgs = torch.rand((2, 3, 32, 32))
|
||||
self.base_dataset.pipeline = MagicMock(
|
||||
return_value=dict(imgs=self.imgs))
|
||||
def _init_dataset(self):
|
||||
self.dataset_type.META = self.META
|
||||
self.dataset_type.parse_annotations = self.parse_annotations
|
||||
|
||||
def test_init(self):
|
||||
self._init_dataset()
|
||||
# test the instantiation of self.base_dataset
|
||||
dataset = self.base_dataset(
|
||||
dataset = self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json')
|
||||
assert dataset._fully_initialized
|
||||
assert hasattr(dataset, 'data_infos')
|
||||
assert hasattr(dataset, 'data_address')
|
||||
dataset = self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img=None),
|
||||
ann_file='annotations/dummy_annotation.json')
|
||||
assert dataset._fully_initialized
|
||||
assert hasattr(dataset, 'data_infos')
|
||||
assert hasattr(dataset, 'data_address')
|
||||
|
||||
# test the instantiation of self.base_dataset with
|
||||
# `serialize_data=False`
|
||||
dataset = self.base_dataset(
|
||||
dataset = self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json',
|
||||
@ -42,33 +68,54 @@ class TestBaseDataset:
|
||||
assert dataset._fully_initialized
|
||||
assert hasattr(dataset, 'data_infos')
|
||||
assert not hasattr(dataset, 'data_address')
|
||||
assert len(dataset) == 2
|
||||
assert dataset.get_data_info(0) == self.data_info
|
||||
|
||||
# test the instantiation of self.base_dataset with lazy init
|
||||
dataset = self.base_dataset(
|
||||
dataset = self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json',
|
||||
lazy_init=True)
|
||||
assert not dataset._fully_initialized
|
||||
assert not hasattr(dataset, 'data_infos')
|
||||
assert not dataset.data_infos
|
||||
|
||||
# test the instantiation of self.base_dataset if ann_file is not
|
||||
# existed.
|
||||
with pytest.raises(FileNotFoundError):
|
||||
self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/not_existed_annotation.json')
|
||||
|
||||
# test the instantiation of self.base_dataset when the ann_file is
|
||||
# wrong
|
||||
with pytest.raises(ValueError):
|
||||
self.base_dataset(
|
||||
self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/wrong_annotation.json')
|
||||
ann_file='annotations/annotation_wrong_keys.json')
|
||||
with pytest.raises(TypeError):
|
||||
self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/annotation_wrong_format.json')
|
||||
with pytest.raises(TypeError):
|
||||
self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img=['img']),
|
||||
ann_file='annotations/annotation_wrong_format.json')
|
||||
|
||||
# test the instantiation of self.base_dataset when `parse_annotations`
|
||||
# return `list[dict]`
|
||||
self.base_dataset.parse_annotations = MagicMock(
|
||||
self.dataset_type.parse_annotations = MagicMock(
|
||||
return_value=[self.data_info,
|
||||
self.data_info.copy()])
|
||||
dataset = self.base_dataset(
|
||||
dataset = self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json')
|
||||
dataset.pipeline = self.pipeline
|
||||
assert dataset._fully_initialized
|
||||
assert hasattr(dataset, 'data_infos')
|
||||
assert hasattr(dataset, 'data_address')
|
||||
@ -76,98 +123,139 @@ class TestBaseDataset:
|
||||
assert dataset[0] == dict(imgs=self.imgs)
|
||||
assert dataset.get_data_info(0) == self.data_info
|
||||
|
||||
# set self.base_dataset to initial state
|
||||
self.__init__()
|
||||
# test the instantiation of self.base_dataset when `parse_annotations`
|
||||
# return unsupported data.
|
||||
with pytest.raises(TypeError):
|
||||
self.dataset_type.parse_annotations = MagicMock(return_value='xxx')
|
||||
dataset = self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json')
|
||||
with pytest.raises(TypeError):
|
||||
self.dataset_type.parse_annotations = MagicMock(
|
||||
return_value=[self.data_info, 'xxx'])
|
||||
dataset = self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json')
|
||||
|
||||
def test_meta(self):
|
||||
self._init_dataset()
|
||||
# test dataset.meta with setting the meta from annotation file as the
|
||||
# meta of self.base_dataset
|
||||
dataset = self.base_dataset(
|
||||
dataset = self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json')
|
||||
|
||||
assert dataset.meta == dict(
|
||||
dataset_type='test_dataset', task_name='test_task')
|
||||
dataset_type='test_dataset', task_name='test_task', empty_list=[])
|
||||
|
||||
# test dataset.meta with setting META in self.base_dataset
|
||||
dataset_type = 'new_dataset'
|
||||
self.base_dataset.META = dict(
|
||||
self.dataset_type.META = dict(
|
||||
dataset_type=dataset_type, classes=('dog', 'cat'))
|
||||
|
||||
dataset = self.base_dataset(
|
||||
dataset = self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json')
|
||||
assert dataset.meta == dict(
|
||||
dataset_type=dataset_type,
|
||||
task_name='test_task',
|
||||
classes=('dog', 'cat'))
|
||||
classes=('dog', 'cat'),
|
||||
empty_list=[])
|
||||
|
||||
# test dataset.meta with passing meta into self.base_dataset
|
||||
meta = dict(classes=('dog', ))
|
||||
dataset = self.base_dataset(
|
||||
meta = dict(classes=('dog', ), task_name='new_task')
|
||||
dataset = self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json',
|
||||
meta=meta)
|
||||
assert self.base_dataset.META == dict(
|
||||
assert self.dataset_type.META == dict(
|
||||
dataset_type=dataset_type, classes=('dog', 'cat'))
|
||||
assert dataset.meta == dict(
|
||||
dataset_type=dataset_type,
|
||||
task_name='test_task',
|
||||
classes=('dog', ))
|
||||
task_name='new_task',
|
||||
classes=('dog', ),
|
||||
empty_list=[])
|
||||
# reset `base_dataset.META`, the `dataset.meta` should not change
|
||||
self.base_dataset.META['classes'] = ('dog', 'cat', 'fish')
|
||||
assert self.base_dataset.META == dict(
|
||||
self.dataset_type.META['classes'] = ('dog', 'cat', 'fish')
|
||||
assert self.dataset_type.META == dict(
|
||||
dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))
|
||||
assert dataset.meta == dict(
|
||||
dataset_type=dataset_type,
|
||||
task_name='new_task',
|
||||
classes=('dog', ),
|
||||
empty_list=[])
|
||||
|
||||
# test dataset.meta with passing meta containing a file into
|
||||
# self.base_dataset
|
||||
meta = dict(
|
||||
classes=osp.join(
|
||||
osp.dirname(__file__), '../data/meta/classes.txt'))
|
||||
dataset = self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json',
|
||||
meta=meta)
|
||||
assert dataset.meta == dict(
|
||||
dataset_type=dataset_type,
|
||||
task_name='test_task',
|
||||
classes=('dog', ))
|
||||
classes=['dog'],
|
||||
empty_list=[])
|
||||
|
||||
# test dataset.meta with passing unsupported meta into
|
||||
# self.base_dataset
|
||||
with pytest.raises(TypeError):
|
||||
meta = 'dog'
|
||||
dataset = self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json',
|
||||
meta=meta)
|
||||
|
||||
# test dataset.meta with passing meta into self.base_dataset and
|
||||
# lazy_init is True
|
||||
meta = dict(classes=('dog', ))
|
||||
dataset = self.base_dataset(
|
||||
dataset = self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json',
|
||||
meta=meta,
|
||||
lazy_init=True)
|
||||
# 'task_name' not in dataset.meta
|
||||
# 'task_name' and 'empty_list' not in dataset.meta
|
||||
assert dataset.meta == dict(
|
||||
dataset_type=dataset_type, classes=('dog', ))
|
||||
|
||||
# test whether self.base_dataset.META is changed when a customize
|
||||
# dataset inherit self.base_dataset
|
||||
# test reset META in ToyDataset.
|
||||
class ToyDataset(self.base_dataset):
|
||||
class ToyDataset(self.dataset_type):
|
||||
META = dict(xxx='xxx')
|
||||
|
||||
assert ToyDataset.META == dict(xxx='xxx')
|
||||
assert self.base_dataset.META == dict(
|
||||
assert self.dataset_type.META == dict(
|
||||
dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))
|
||||
|
||||
# test update META in ToyDataset.
|
||||
class ToyDataset(self.base_dataset):
|
||||
self.base_dataset.META['classes'] = ('bird', )
|
||||
class ToyDataset(self.dataset_type):
|
||||
META = copy.deepcopy(self.dataset_type.META)
|
||||
META['classes'] = ('bird', )
|
||||
|
||||
assert ToyDataset.META == dict(
|
||||
dataset_type=dataset_type, classes=('bird', ))
|
||||
assert self.base_dataset.META == dict(
|
||||
assert self.dataset_type.META == dict(
|
||||
dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))
|
||||
|
||||
# set self.base_dataset to initial state
|
||||
self.__init__()
|
||||
|
||||
@pytest.mark.parametrize('lazy_init', [True, False])
|
||||
def test_length(self, lazy_init):
|
||||
dataset = self.base_dataset(
|
||||
dataset = self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json',
|
||||
lazy_init=lazy_init)
|
||||
|
||||
if not lazy_init:
|
||||
assert dataset._fully_initialized
|
||||
assert hasattr(dataset, 'data_infos')
|
||||
@ -175,20 +263,49 @@ class TestBaseDataset:
|
||||
else:
|
||||
# test `__len__()` when lazy_init is True
|
||||
assert not dataset._fully_initialized
|
||||
assert not hasattr(dataset, 'data_infos')
|
||||
assert not dataset.data_infos
|
||||
# call `full_init()` automatically
|
||||
assert len(dataset) == 2
|
||||
assert dataset._fully_initialized
|
||||
assert hasattr(dataset, 'data_infos')
|
||||
|
||||
def test_compose(self):
|
||||
# test callable transform
|
||||
transforms = [function_pipeline]
|
||||
compose = Compose(transforms=transforms)
|
||||
assert (self.imgs == compose(dict(img=self.imgs))['img']).all()
|
||||
# test transform build from cfg_dict
|
||||
transforms = [dict(type='CallableTransform')]
|
||||
compose = Compose(transforms=transforms)
|
||||
assert (self.imgs == compose(dict(img=self.imgs))['img']).all()
|
||||
# test return None in advance
|
||||
none_func = MagicMock(return_value=None)
|
||||
transforms = [none_func, function_pipeline]
|
||||
compose = Compose(transforms=transforms)
|
||||
assert compose(dict(img=self.imgs)) is None
|
||||
# test repr
|
||||
repr_str = f'Compose(\n' \
|
||||
f' {none_func}\n' \
|
||||
f' {function_pipeline}\n' \
|
||||
f')'
|
||||
assert repr(compose) == repr_str
|
||||
# non-callable transform will raise error
|
||||
with pytest.raises(TypeError):
|
||||
transforms = [dict(type='NotCallableTransform')]
|
||||
Compose(transforms)
|
||||
|
||||
# transform must be callable or dict
|
||||
with pytest.raises(TypeError):
|
||||
Compose([1])
|
||||
|
||||
@pytest.mark.parametrize('lazy_init', [True, False])
|
||||
def test_getitem(self, lazy_init):
|
||||
dataset = self.base_dataset(
|
||||
dataset = self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json',
|
||||
lazy_init=lazy_init)
|
||||
|
||||
dataset.pipeline = self.pipeline
|
||||
if not lazy_init:
|
||||
assert dataset._fully_initialized
|
||||
assert hasattr(dataset, 'data_infos')
|
||||
@ -196,15 +313,26 @@ class TestBaseDataset:
|
||||
else:
|
||||
# test `__getitem__()` when lazy_init is True
|
||||
assert not dataset._fully_initialized
|
||||
assert not hasattr(dataset, 'data_infos')
|
||||
assert not dataset.data_infos
|
||||
# call `full_init()` automatically
|
||||
assert dataset[0] == dict(imgs=self.imgs)
|
||||
assert dataset._fully_initialized
|
||||
assert hasattr(dataset, 'data_infos')
|
||||
|
||||
# test with test mode
|
||||
dataset.test_mode = True
|
||||
assert dataset[0] == dict(imgs=self.imgs)
|
||||
|
||||
pipeline = MagicMock(return_value=None)
|
||||
dataset.pipeline = pipeline
|
||||
# test cannot get a valid image.
|
||||
dataset.test_mode = False
|
||||
with pytest.raises(Exception):
|
||||
dataset[0]
|
||||
|
||||
@pytest.mark.parametrize('lazy_init', [True, False])
|
||||
def test_get_data_info(self, lazy_init):
|
||||
dataset = self.base_dataset(
|
||||
dataset = self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json',
|
||||
@ -217,20 +345,32 @@ class TestBaseDataset:
|
||||
else:
|
||||
# test `get_data_info()` when lazy_init is True
|
||||
assert not dataset._fully_initialized
|
||||
assert not hasattr(dataset, 'data_infos')
|
||||
assert not dataset.data_infos
|
||||
# call `full_init()` automatically
|
||||
assert dataset.get_data_info(0) == self.data_info
|
||||
assert dataset._fully_initialized
|
||||
assert hasattr(dataset, 'data_infos')
|
||||
|
||||
def test_force_full_init(self):
|
||||
with pytest.raises(AttributeError):
|
||||
|
||||
class ClassWithoutFullInit:
|
||||
|
||||
@force_full_init
|
||||
def foo(self):
|
||||
pass
|
||||
|
||||
class_without_full_init = ClassWithoutFullInit()
|
||||
class_without_full_init.foo()
|
||||
|
||||
@pytest.mark.parametrize('lazy_init', [True, False])
|
||||
def test_full_init(self, lazy_init):
|
||||
dataset = self.base_dataset(
|
||||
dataset = self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json',
|
||||
lazy_init=lazy_init)
|
||||
|
||||
dataset.pipeline = self.pipeline
|
||||
if not lazy_init:
|
||||
assert dataset._fully_initialized
|
||||
assert hasattr(dataset, 'data_infos')
|
||||
@ -240,7 +380,7 @@ class TestBaseDataset:
|
||||
else:
|
||||
# test `full_init()` when lazy_init is True
|
||||
assert not dataset._fully_initialized
|
||||
assert not hasattr(dataset, 'data_infos')
|
||||
assert not dataset.data_infos
|
||||
# call `full_init()` manually
|
||||
dataset.full_init()
|
||||
assert dataset._fully_initialized
|
||||
@ -249,55 +389,116 @@ class TestBaseDataset:
|
||||
assert dataset[0] == dict(imgs=self.imgs)
|
||||
assert dataset.get_data_info(0) == self.data_info
|
||||
|
||||
def test_slice_data(self):
|
||||
# test the instantiation of self.base_dataset when passing num_samples
|
||||
dataset = self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img=None),
|
||||
ann_file='annotations/dummy_annotation.json',
|
||||
num_samples=1)
|
||||
assert len(dataset) == 1
|
||||
|
||||
def test_rand_another(self):
|
||||
# test the instantiation of self.base_dataset when passing num_samples
|
||||
dataset = self.dataset_type(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img=None),
|
||||
ann_file='annotations/dummy_annotation.json',
|
||||
num_samples=1)
|
||||
assert dataset._rand_another() >= 0
|
||||
assert dataset._rand_another() < len(dataset)
|
||||
|
||||
|
||||
class TestConcatDataset:
|
||||
|
||||
def __init__(self):
|
||||
def _init_dataset(self):
|
||||
dataset = BaseDataset
|
||||
|
||||
# create dataset_a
|
||||
data_info = dict(filename='test_img.jpg', height=604, width=640)
|
||||
dataset.parse_annotations = MagicMock(return_value=data_info)
|
||||
imgs = torch.rand((2, 3, 32, 32))
|
||||
dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
|
||||
|
||||
self.dataset_a = dataset(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json')
|
||||
self.dataset_a.pipeline = MagicMock(return_value=dict(imgs=imgs))
|
||||
|
||||
# create dataset_b
|
||||
data_info = dict(filename='gray.jpg', height=288, width=512)
|
||||
dataset.parse_annotations = MagicMock(return_value=data_info)
|
||||
imgs = torch.rand((2, 3, 32, 32))
|
||||
dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
|
||||
self.dataset_b = dataset(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json',
|
||||
meta=dict(classes=('dog', 'cat')))
|
||||
|
||||
self.dataset_b.pipeline = MagicMock(return_value=dict(imgs=imgs))
|
||||
# test init
|
||||
self.cat_datasets = ConcatDataset(
|
||||
datasets=[self.dataset_a, self.dataset_b])
|
||||
|
||||
def test_full_init(self):
|
||||
dataset = BaseDataset
|
||||
|
||||
# create dataset_a
|
||||
data_info = dict(filename='test_img.jpg', height=604, width=640)
|
||||
dataset.parse_annotations = MagicMock(return_value=data_info)
|
||||
imgs = torch.rand((2, 3, 32, 32))
|
||||
|
||||
dataset_a = dataset(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json')
|
||||
dataset_a.pipeline = MagicMock(return_value=dict(imgs=imgs))
|
||||
|
||||
# create dataset_b
|
||||
data_info = dict(filename='gray.jpg', height=288, width=512)
|
||||
dataset.parse_annotations = MagicMock(return_value=data_info)
|
||||
imgs = torch.rand((2, 3, 32, 32))
|
||||
dataset_b = dataset(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json',
|
||||
meta=dict(classes=('dog', 'cat')))
|
||||
dataset_b.pipeline = MagicMock(return_value=dict(imgs=imgs))
|
||||
# test init with lazy_init=True
|
||||
cat_datasets = ConcatDataset(
|
||||
datasets=[dataset_a, dataset_b], lazy_init=True)
|
||||
cat_datasets.full_init()
|
||||
assert len(cat_datasets) == 4
|
||||
cat_datasets.full_init()
|
||||
cat_datasets._fully_initialized = False
|
||||
cat_datasets[1]
|
||||
assert len(cat_datasets) == 4
|
||||
|
||||
def test_meta(self):
|
||||
self._init_dataset()
|
||||
assert self.cat_datasets.meta == self.dataset_a.meta
|
||||
# meta of self.cat_datasets is from the first dataset when
|
||||
        # concatenating datasets with different metas.
|
||||
assert self.cat_datasets.meta != self.dataset_b.meta
|
||||
|
||||
def test_length(self):
|
||||
self._init_dataset()
|
||||
assert len(self.cat_datasets) == (
|
||||
len(self.dataset_a) + len(self.dataset_b))
|
||||
|
||||
def test_getitem(self):
|
||||
assert self.cat_datasets[0] == self.dataset_a[0]
|
||||
assert self.cat_datasets[0] != self.dataset_b[0]
|
||||
self._init_dataset()
|
||||
assert (
|
||||
self.cat_datasets[0]['imgs'] == self.dataset_a[0]['imgs']).all()
|
||||
assert (self.cat_datasets[0]['imgs'] !=
|
||||
self.dataset_b[0]['imgs']).all()
|
||||
|
||||
assert self.cat_datasets[-1] == self.dataset_b[-1]
|
||||
assert self.cat_datasets[-1] != self.dataset_a[-1]
|
||||
assert (
|
||||
self.cat_datasets[-1]['imgs'] == self.dataset_b[-1]['imgs']).all()
|
||||
assert (self.cat_datasets[-1]['imgs'] !=
|
||||
self.dataset_a[-1]['imgs']).all()
|
||||
|
||||
def test_get_data_info(self):
|
||||
self._init_dataset()
|
||||
assert self.cat_datasets.get_data_info(
|
||||
0) == self.dataset_a.get_data_info(0)
|
||||
assert self.cat_datasets.get_data_info(
|
||||
@ -306,40 +507,76 @@ class TestConcatDataset:
|
||||
assert self.cat_datasets.get_data_info(
|
||||
-1) == self.dataset_b.get_data_info(-1)
|
||||
assert self.cat_datasets.get_data_info(
|
||||
-1) != self.dataset_a[-1].get_data_info(-1)
|
||||
-1) != self.dataset_a.get_data_info(-1)
|
||||
|
||||
def test_get_ori_dataset_idx(self):
|
||||
self._init_dataset()
|
||||
assert self.cat_datasets._get_ori_dataset_idx(3) == (
|
||||
1, 3 - len(self.dataset_a))
|
||||
assert self.cat_datasets._get_ori_dataset_idx(-1) == (
|
||||
1, len(self.dataset_b) - 1)
|
||||
with pytest.raises(ValueError):
|
||||
assert self.cat_datasets._get_ori_dataset_idx(-10)
|
||||
|
||||
|
||||
class TestRepeatDataset:
|
||||
|
||||
def __init__(self):
|
||||
def _init_dataset(self):
|
||||
dataset = BaseDataset
|
||||
data_info = dict(filename='test_img.jpg', height=604, width=640)
|
||||
dataset.parse_annotations = MagicMock(return_value=data_info)
|
||||
imgs = torch.rand((2, 3, 32, 32))
|
||||
dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
|
||||
self.dataset = dataset(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json')
|
||||
self.dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
|
||||
|
||||
self.repeat_times = 5
|
||||
# test init
|
||||
self.repeat_datasets = RepeatDataset(
|
||||
dataset=self.dataset, times=self.repeat_times)
|
||||
|
||||
def test_full_init(self):
|
||||
dataset = BaseDataset
|
||||
data_info = dict(filename='test_img.jpg', height=604, width=640)
|
||||
dataset.parse_annotations = MagicMock(return_value=data_info)
|
||||
imgs = torch.rand((2, 3, 32, 32))
|
||||
dataset = dataset(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json')
|
||||
dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
|
||||
|
||||
repeat_times = 5
|
||||
# test init
|
||||
repeat_datasets = RepeatDataset(
|
||||
dataset=dataset, times=repeat_times, lazy_init=True)
|
||||
|
||||
repeat_datasets.full_init()
|
||||
assert len(repeat_datasets) == repeat_times * len(dataset)
|
||||
repeat_datasets.full_init()
|
||||
repeat_datasets._fully_initialized = False
|
||||
repeat_datasets[1]
|
||||
assert len(repeat_datasets) == repeat_times * len(dataset)
|
||||
|
||||
def test_meta(self):
|
||||
self._init_dataset()
|
||||
assert self.repeat_datasets.meta == self.dataset.meta
|
||||
|
||||
def test_length(self):
|
||||
self._init_dataset()
|
||||
assert len(
|
||||
self.repeat_datasets) == len(self.dataset) * self.repeat_times
|
||||
|
||||
def test_getitem(self):
|
||||
self._init_dataset()
|
||||
for i in range(self.repeat_times):
|
||||
assert self.repeat_datasets[len(self.dataset) *
|
||||
i] == self.dataset[0]
|
||||
|
||||
def test_get_data_info(self):
|
||||
self._init_dataset()
|
||||
for i in range(self.repeat_times):
|
||||
assert self.repeat_datasets.get_data_info(
|
||||
len(self.dataset) * i) == self.dataset.get_data_info(0)
|
||||
@ -347,17 +584,17 @@ class TestRepeatDataset:
|
||||
|
||||
class TestClassBalancedDataset:
|
||||
|
||||
def __init__(self):
|
||||
def _init_dataset(self):
|
||||
dataset = BaseDataset
|
||||
data_info = dict(filename='test_img.jpg', height=604, width=640)
|
||||
dataset.parse_annotations = MagicMock(return_value=data_info)
|
||||
imgs = torch.rand((2, 3, 32, 32))
|
||||
dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
|
||||
dataset.get_cat_ids = MagicMock(return_value=[0])
|
||||
self.dataset = dataset(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json')
|
||||
self.dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
|
||||
|
||||
self.repeat_indices = [0, 0, 1, 1, 1]
|
||||
# test init
|
||||
@ -365,23 +602,54 @@ class TestClassBalancedDataset:
|
||||
dataset=self.dataset, oversample_thr=1e-3)
|
||||
self.cls_banlanced_datasets.repeat_indices = self.repeat_indices
|
||||
|
||||
def test_full_init(self):
|
||||
dataset = BaseDataset
|
||||
data_info = dict(filename='test_img.jpg', height=604, width=640)
|
||||
dataset.parse_annotations = MagicMock(return_value=data_info)
|
||||
imgs = torch.rand((2, 3, 32, 32))
|
||||
dataset.get_cat_ids = MagicMock(return_value=[0])
|
||||
dataset = dataset(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json')
|
||||
dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
|
||||
|
||||
repeat_indices = [0, 0, 1, 1, 1]
|
||||
# test init
|
||||
cls_banlanced_datasets = ClassBalancedDataset(
|
||||
dataset=dataset, oversample_thr=1e-3, lazy_init=True)
|
||||
|
||||
cls_banlanced_datasets.full_init()
|
||||
cls_banlanced_datasets.repeat_indices = repeat_indices
|
||||
assert len(cls_banlanced_datasets) == len(repeat_indices)
|
||||
cls_banlanced_datasets.full_init()
|
||||
cls_banlanced_datasets._fully_initialized = False
|
||||
cls_banlanced_datasets[1]
|
||||
cls_banlanced_datasets.repeat_indices = repeat_indices
|
||||
assert len(cls_banlanced_datasets) == len(repeat_indices)
|
||||
|
||||
def test_meta(self):
|
||||
self._init_dataset()
|
||||
assert self.cls_banlanced_datasets.meta == self.dataset.meta
|
||||
|
||||
def test_length(self):
|
||||
self._init_dataset()
|
||||
assert len(self.cls_banlanced_datasets) == len(self.repeat_indices)
|
||||
|
||||
def test_getitem(self):
|
||||
self._init_dataset()
|
||||
for i in range(len(self.repeat_indices)):
|
||||
assert self.cls_banlanced_datasets[i] == self.dataset[
|
||||
self.repeat_indices[i]]
|
||||
|
||||
def test_get_data_info(self):
|
||||
self._init_dataset()
|
||||
for i in range(len(self.repeat_indices)):
|
||||
assert self.cls_banlanced_datasets.get_data_info(
|
||||
i) == self.dataset.get_data_info(self.repeat_indices[i])
|
||||
|
||||
def test_get_cat_ids(self):
|
||||
self._init_dataset()
|
||||
for i in range(len(self.repeat_indices)):
|
||||
assert self.cls_banlanced_datasets.get_cat_ids(
|
||||
i) == self.dataset.get_cat_ids(self.repeat_indices[i])
|
||||
|