mmclassification/tests/test_data/test_datasets/test_dataset_wrapper.py


# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import math
from collections import defaultdict
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from mmcls.datasets import (BaseDataset, ClassBalancedDataset, ConcatDataset,
                            KFoldDataset, RepeatDataset)


def mock_evaluate(results,
                  metric='accuracy',
                  metric_options=None,
                  indices=None,
                  logger=None):
    # Echo every argument back unchanged so the tests can inspect exactly
    # what a dataset wrapper forwarded to the wrapped dataset's evaluate().
    return dict(
        results=results,
        metric=metric,
        metric_options=metric_options,
        indices=indices,
        logger=logger)


# Clearing __abstractmethods__ lets the abstract BaseDataset be instantiated
# directly inside these toy constructors.
@patch.multiple(BaseDataset, __abstractmethods__=set())
def construct_toy_multi_label_dataset(length):
    # Each sample gets a variable-length list (1-19 entries) of random
    # category ids drawn from 80 classes.
    BaseDataset.CLASSES = ('foo', 'bar')
    BaseDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
    dataset = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
    cat_ids_list = [
        np.random.randint(0, 80, num).tolist()
        for num in np.random.randint(1, 20, length)
    ]
    dataset.data_infos = MagicMock()
    dataset.data_infos.__len__.return_value = length
    dataset.get_cat_ids = MagicMock(side_effect=lambda idx: cat_ids_list[idx])
    dataset.evaluate = MagicMock(side_effect=mock_evaluate)
    return dataset, cat_ids_list


@patch.multiple(BaseDataset, __abstractmethods__=set())
def construct_toy_single_label_dataset(length):
    # Same mocked dataset, but every sample carries exactly one category id.
    BaseDataset.CLASSES = ('foo', 'bar')
    BaseDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
    dataset = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
    cat_ids_list = [[np.random.randint(0, 80)] for _ in range(length)]
    dataset.data_infos = MagicMock()
    dataset.data_infos.__len__.return_value = length
    dataset.get_cat_ids = MagicMock(side_effect=lambda idx: cat_ids_list[idx])
    dataset.evaluate = MagicMock(side_effect=mock_evaluate)
    return dataset, cat_ids_list
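

# A quick self-check (a sketch added here, not part of the original test
# file) of the invariant the two constructors provide: len() comes from the
# mocked data_infos, and single-label samples carry exactly one category id.
# The _sketch_* names are local to this sketch.
_sketch_dataset, _sketch_cat_ids = construct_toy_single_label_dataset(4)
assert len(_sketch_dataset) == 4
assert all(len(ids) == 1 for ids in _sketch_cat_ids)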


@pytest.mark.parametrize('construct_dataset', [
    'construct_toy_multi_label_dataset', 'construct_toy_single_label_dataset'
])
def test_concat_dataset(construct_dataset):
    construct_toy_dataset = eval(construct_dataset)
    dataset_a, cat_ids_list_a = construct_toy_dataset(10)
    dataset_b, cat_ids_list_b = construct_toy_dataset(20)
    concat_dataset = ConcatDataset([dataset_a, dataset_b])

    # Global indices past len(dataset_a) fall through to dataset_b.
    assert concat_dataset[5] == 5
    assert concat_dataset[25] == 15
    assert concat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
    assert concat_dataset.get_cat_ids(25) == cat_ids_list_b[15]
    assert len(concat_dataset) == len(dataset_a) + len(dataset_b)
    assert concat_dataset.CLASSES == BaseDataset.CLASSES
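

# A minimal sketch (assuming ConcatDataset follows the usual cumulative-size
# scheme of torch.utils.data.ConcatDataset, which the assertions above
# encode) of how a global index is resolved: the first cumulative boundary
# strictly greater than idx picks the sub-dataset, and the sizes of all
# preceding datasets are subtracted off to get the local index.
def _concat_index_sketch(sizes, idx):
    cumsum = np.cumsum(sizes).tolist()  # e.g. [10, 30] for lengths 10 and 20
    dataset_idx = bisect.bisect_right(cumsum, idx)
    sample_idx = idx if dataset_idx == 0 else idx - cumsum[dataset_idx - 1]
    return dataset_idx, sample_idx


# Global index 25 lands in the second dataset at local index 15, matching
# the `concat_dataset[25] == 15` assertion above.
assert _concat_index_sketch([10, 20], 25) == (1, 15)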


@pytest.mark.parametrize('construct_dataset', [
    'construct_toy_multi_label_dataset', 'construct_toy_single_label_dataset'
])
def test_repeat_dataset(construct_dataset):
    construct_toy_dataset = eval(construct_dataset)
    dataset, cat_ids_list = construct_toy_dataset(10)
    repeat_dataset = RepeatDataset(dataset, 10)

    # Indices wrap around the original length on every repeat.
    assert repeat_dataset[5] == 5
    assert repeat_dataset[15] == 5
    assert repeat_dataset[27] == 7
    assert repeat_dataset.get_cat_ids(5) == cat_ids_list[5]
    assert repeat_dataset.get_cat_ids(15) == cat_ids_list[5]
    assert repeat_dataset.get_cat_ids(27) == cat_ids_list[7]
    assert len(repeat_dataset) == 10 * len(dataset)
    assert repeat_dataset.CLASSES == BaseDataset.CLASSES
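

# The assertions above encode a plain modulo mapping (an assumption about
# RepeatDataset's implementation that the test pins down): index idx in the
# repeated dataset resolves to idx % len(original).
assert [idx % 10 for idx in (5, 15, 27)] == [5, 5, 7]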


@pytest.mark.parametrize('construct_dataset', [
    'construct_toy_multi_label_dataset', 'construct_toy_single_label_dataset'
])
def test_class_balanced_dataset(construct_dataset):
    construct_toy_dataset = eval(construct_dataset)
    dataset, cat_ids_list = construct_toy_dataset(10)

    # Re-derive the expected repeat factors by hand: category frequency is
    # the fraction of images containing the category, and each image repeats
    # by the ceiling of max(1, sqrt(repeat_thr / freq)) over its categories.
    category_freq = defaultdict(int)
    for cat_ids in cat_ids_list:
        cat_ids = set(cat_ids)
        for cat_id in cat_ids:
            category_freq[cat_id] += 1
    for k, v in category_freq.items():
        category_freq[k] = v / len(cat_ids_list)

    mean_freq = np.mean(list(category_freq.values()))
    repeat_thr = mean_freq
    category_repeat = {
        cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
        for cat_id, cat_freq in category_freq.items()
    }

    repeat_factors = []
    for cat_ids in cat_ids_list:
        cat_ids = set(cat_ids)
        repeat_factor = max({category_repeat[cat_id] for cat_id in cat_ids})
        repeat_factors.append(math.ceil(repeat_factor))
    repeat_factors_cumsum = np.cumsum(repeat_factors)

    repeat_factor_dataset = ClassBalancedDataset(dataset, repeat_thr)
    assert repeat_factor_dataset.CLASSES == BaseDataset.CLASSES
    assert len(repeat_factor_dataset) == repeat_factors_cumsum[-1]
    # A repeated index maps back to the original sample whose cumulative
    # repeat count it falls under.
    for idx in np.random.randint(0, len(repeat_factor_dataset), 3):
        assert repeat_factor_dataset[idx] == bisect.bisect_right(
            repeat_factors_cumsum, idx)
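

# A worked instance (added as illustration, not part of the original file)
# of the repeat-factor formula the test re-derives: with threshold t = 0.5,
# a category appearing in 1 of 8 images (f = 0.125) gets factor
# sqrt(0.5 / 0.125) = 2.0, so its images repeat ceil(2.0) = 2 times.
assert math.ceil(max(1.0, math.sqrt(0.5 / 0.125))) == 2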


@pytest.mark.parametrize('construct_dataset', [
    'construct_toy_multi_label_dataset', 'construct_toy_single_label_dataset'
])
def test_kfold_dataset(construct_dataset):
    construct_toy_dataset = eval(construct_dataset)
    dataset, _ = construct_toy_dataset(10)

    # test without random seed
    # Five splits over 10 samples give five 2-sample test folds; `fold` must
    # stay below `num_splits`, so both are 5 here.
    train_datasets = [
        KFoldDataset(dataset, fold=i, num_splits=5, test_mode=False)
        for i in range(5)
    ]
    test_datasets = [
        KFoldDataset(dataset, fold=i, num_splits=5, test_mode=True)
        for i in range(5)
    ]
    # Unseeded folds are contiguous, so the test folds concatenate back to
    # the identity ordering, and each train/test pair partitions the data.
    assert sum([i.indices for i in test_datasets], []) == list(range(10))
    for train_set, test_set in zip(train_datasets, test_datasets):
        train_samples = [train_set[i] for i in range(len(train_set))]
        test_samples = [test_set[i] for i in range(len(test_set))]
        assert set(train_samples + test_samples) == set(range(10))

    # test with random seed
    train_datasets = [
        KFoldDataset(dataset, fold=i, num_splits=5, test_mode=False, seed=1)
        for i in range(5)
    ]
    test_datasets = [
        KFoldDataset(dataset, fold=i, num_splits=5, test_mode=True, seed=1)
        for i in range(5)
    ]
    # Seeding shuffles the split, so the concatenation is no longer the
    # identity ordering but still covers every sample exactly once.
    assert sum([i.indices for i in test_datasets], []) != list(range(10))
    assert set(sum([i.indices for i in test_datasets], [])) == set(range(10))
    for train_set, test_set in zip(train_datasets, test_datasets):
        train_samples = [train_set[i] for i in range(len(train_set))]
        test_samples = [test_set[i] for i in range(len(test_set))]
        assert set(train_samples + test_samples) == set(range(10))

    # test evaluate
    # The mocked evaluate() echoes its arguments, so the wrapper must have
    # forwarded its own fold indices.
    for test_set in test_datasets:
        eval_inputs = test_set.evaluate(None)
        assert eval_inputs['indices'] == test_set.indices
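

# A sketch (an assumption consistent with the unseeded assertions above) of
# how contiguous test folds partition the data: with 10 samples and 5
# splits, fold i takes samples [2 * i, 2 * i + 1], and the five folds
# concatenate back to range(10).
def _kfold_test_indices_sketch(length, num_splits, fold):
    fold_size = length // num_splits  # assumes num_splits divides length
    return list(range(fold * fold_size, (fold + 1) * fold_size))


assert sum((_kfold_test_indices_sketch(10, 5, i) for i in range(5)),
           []) == list(range(10))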