mmpretrain/tests/test_data/test_datasets/test_dataset_wrapper.py

# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import math
from collections import defaultdict
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from mmcls.datasets import (BaseDataset, ClassBalancedDataset, ConcatDataset,
                            RepeatDataset)


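# The construct_toy_* helpers below build datasets entirely from mocks:
# BaseDataset's abstract methods are patched away, __getitem__ simply echoes
# the requested index, and get_cat_ids returns pre-generated category id
# lists, so the wrapper tests can check index bookkeeping without real data.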
@patch.multiple(BaseDataset, __abstractmethods__=set())
def construct_toy_multi_label_dataset(length):
    """Return a mocked multi-label dataset and its category id list."""
    BaseDataset.CLASSES = ('foo', 'bar')
    BaseDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
    dataset = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
    # Every sample gets a random-length list of random category ids.
    cat_ids_list = [
        np.random.randint(0, 80, num).tolist()
        for num in np.random.randint(1, 20, length)
    ]
    dataset.data_infos = MagicMock()
    dataset.data_infos.__len__.return_value = length
    dataset.get_cat_ids = MagicMock(side_effect=lambda idx: cat_ids_list[idx])
    return dataset, cat_ids_list


@patch.multiple(BaseDataset, __abstractmethods__=set())
def construct_toy_single_label_dataset(length):
    """Return a mocked single-label dataset and its category id list."""
    BaseDataset.CLASSES = ('foo', 'bar')
    BaseDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
    dataset = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
    # Every sample gets exactly one random category id.
    cat_ids_list = [[np.random.randint(0, 80)] for _ in range(length)]
    dataset.data_infos = MagicMock()
    dataset.data_infos.__len__.return_value = length
    dataset.get_cat_ids = MagicMock(side_effect=lambda idx: cat_ids_list[idx])
    return dataset, cat_ids_list


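# Each wrapper test below is parametrized over the two constructor names and
# resolves them via eval(), so every wrapper is exercised with both the
# multi-label and the single-label toy data.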
@pytest.mark.parametrize('construct_dataset', [
    'construct_toy_multi_label_dataset', 'construct_toy_single_label_dataset'
])
def test_concat_dataset(construct_dataset):
    construct_toy_dataset = eval(construct_dataset)
    dataset_a, cat_ids_list_a = construct_toy_dataset(10)
    dataset_b, cat_ids_list_b = construct_toy_dataset(20)
    concat_dataset = ConcatDataset([dataset_a, dataset_b])
    assert concat_dataset[5] == 5
    # Global index 25 falls past the 10 samples of dataset_a, so it maps to
    # local index 15 of dataset_b.
    assert concat_dataset[25] == 15
    assert concat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
    assert concat_dataset.get_cat_ids(25) == cat_ids_list_b[15]
    assert len(concat_dataset) == len(dataset_a) + len(dataset_b)
    assert concat_dataset.CLASSES == BaseDataset.CLASSES


@pytest.mark.parametrize('construct_dataset', [
    'construct_toy_multi_label_dataset', 'construct_toy_single_label_dataset'
])
def test_repeat_dataset(construct_dataset):
    construct_toy_dataset = eval(construct_dataset)
    dataset, cat_ids_list = construct_toy_dataset(10)
    repeat_dataset = RepeatDataset(dataset, 10)
    # Indices wrap around the original dataset, i.e. idx % len(dataset).
    assert repeat_dataset[5] == 5
    assert repeat_dataset[15] == 5
    assert repeat_dataset[27] == 7
    assert repeat_dataset.get_cat_ids(5) == cat_ids_list[5]
    assert repeat_dataset.get_cat_ids(15) == cat_ids_list[5]
    assert repeat_dataset.get_cat_ids(27) == cat_ids_list[7]
    assert len(repeat_dataset) == 10 * len(dataset)
    assert repeat_dataset.CLASSES == BaseDataset.CLASSES


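# ClassBalancedDataset oversamples rare categories: for a threshold thr, each
# category gets a repeat factor max(1, sqrt(thr / freq)), and each sample is
# repeated by the (ceiled) maximum factor over its categories. The test below
# re-derives those factors and checks the wrapper's flattened index mapping.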
@pytest.mark.parametrize('construct_dataset', [
    'construct_toy_multi_label_dataset', 'construct_toy_single_label_dataset'
])
def test_class_balanced_dataset(construct_dataset):
    construct_toy_dataset = eval(construct_dataset)
    dataset, cat_ids_list = construct_toy_dataset(10)

    # Re-derive the expected repeat factors: per-category frequency, then
    # per-category repeat factors max(1, sqrt(thr / freq)), then per-sample
    # factors as the ceiled maximum over each sample's categories.
    category_freq = defaultdict(int)
    for cat_ids in cat_ids_list:
        cat_ids = set(cat_ids)
        for cat_id in cat_ids:
            category_freq[cat_id] += 1
    for k, v in category_freq.items():
        category_freq[k] = v / len(cat_ids_list)

    # Use the mean category frequency as the oversample threshold.
    mean_freq = np.mean(list(category_freq.values()))
    repeat_thr = mean_freq

    category_repeat = {
        cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
        for cat_id, cat_freq in category_freq.items()
    }

    repeat_factors = []
    for cat_ids in cat_ids_list:
        cat_ids = set(cat_ids)
        repeat_factor = max({category_repeat[cat_id] for cat_id in cat_ids})
        repeat_factors.append(math.ceil(repeat_factor))
    repeat_factors_cumsum = np.cumsum(repeat_factors)

    repeat_factor_dataset = ClassBalancedDataset(dataset, repeat_thr)
    assert repeat_factor_dataset.CLASSES == BaseDataset.CLASSES
    assert len(repeat_factor_dataset) == repeat_factors_cumsum[-1]
    # Each flattened index should resolve back to the original sample index
    # via bisect over the cumulative repeat counts.
    for idx in np.random.randint(0, len(repeat_factor_dataset), 3):
        assert repeat_factor_dataset[idx] == bisect.bisect_right(
            repeat_factors_cumsum, idx)
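

# A minimal sketch of how these wrappers are typically plugged into a config
# (keys assumed from the usual OpenMMLab convention, not exercised here):
#
#     data = dict(
#         train=dict(
#             type='RepeatDataset',
#             times=10,
#             dataset=dict(type='ImageNet', data_prefix='...',
#                          pipeline=[...])))
#
# ClassBalancedDataset is wired up the same way, taking an oversample
# threshold instead of `times`.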