# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import math
from collections import defaultdict
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from mmcls.datasets import (BaseDataset, ClassBalancedDataset, ConcatDataset,
                            KFoldDataset, RepeatDataset)


def mock_evaluate(results,
                  metric='accuracy',
                  metric_options=None,
                  indices=None,
                  logger=None):
    """Stand-in for ``BaseDataset.evaluate`` that echoes its arguments back,
    so the tests can assert on what a dataset wrapper forwarded."""
    return dict(
        results=results,
        metric=metric,
        metric_options=metric_options,
        indices=indices,
        logger=logger)
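

# Both toy-dataset constructors below return a ``BaseDataset`` whose pieces
# are mocked out: ``__getitem__`` returns the index itself, ``data_infos``
# only reports a length, ``get_cat_ids`` serves pre-generated random category
# ids and ``evaluate`` echoes its arguments. The multi-label variant draws
# 1-19 category ids (out of 80) per sample; the single-label variant draws
# exactly one.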


@patch.multiple(BaseDataset, __abstractmethods__=set())
def construct_toy_multi_label_dataset(length):
    BaseDataset.CLASSES = ('foo', 'bar')
    BaseDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
    dataset = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
    cat_ids_list = [
        np.random.randint(0, 80, num).tolist()
        for num in np.random.randint(1, 20, length)
    ]
    dataset.data_infos = MagicMock()
    dataset.data_infos.__len__.return_value = length
    dataset.get_cat_ids = MagicMock(side_effect=lambda idx: cat_ids_list[idx])

    dataset.evaluate = MagicMock(side_effect=mock_evaluate)
    return dataset, cat_ids_list


@patch.multiple(BaseDataset, __abstractmethods__=set())
def construct_toy_single_label_dataset(length):
    BaseDataset.CLASSES = ('foo', 'bar')
    BaseDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
    dataset = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
    cat_ids_list = [[np.random.randint(0, 80)] for _ in range(length)]
    dataset.data_infos = MagicMock()
    dataset.data_infos.__len__.return_value = length
    dataset.get_cat_ids = MagicMock(side_effect=lambda idx: cat_ids_list[idx])
    dataset.evaluate = MagicMock(side_effect=mock_evaluate)
    return dataset, cat_ids_list
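

# ConcatDataset chains the wrapped datasets back to back: global indices below
# len(dataset_a) stay in dataset_a, and a global index i >= len(dataset_a)
# maps to dataset_b[i - len(dataset_a)] (e.g. index 25 -> dataset_b[15]).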


@pytest.mark.parametrize('construct_dataset', [
    'construct_toy_multi_label_dataset', 'construct_toy_single_label_dataset'
])
def test_concat_dataset(construct_dataset):
    construct_toy_dataset = eval(construct_dataset)
    dataset_a, cat_ids_list_a = construct_toy_dataset(10)
    dataset_b, cat_ids_list_b = construct_toy_dataset(20)

    concat_dataset = ConcatDataset([dataset_a, dataset_b])
    assert concat_dataset[5] == 5
    assert concat_dataset[25] == 15
    assert concat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
    assert concat_dataset.get_cat_ids(25) == cat_ids_list_b[15]
    assert len(concat_dataset) == len(dataset_a) + len(dataset_b)
    assert concat_dataset.CLASSES == BaseDataset.CLASSES
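

# RepeatDataset cycles through the wrapped dataset ``times`` times: an index
# maps back to the original sample as idx % len(dataset), e.g. 27 % 10 == 7,
# and the total length is times * len(dataset).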


@pytest.mark.parametrize('construct_dataset', [
    'construct_toy_multi_label_dataset', 'construct_toy_single_label_dataset'
])
def test_repeat_dataset(construct_dataset):
    construct_toy_dataset = eval(construct_dataset)
    dataset, cat_ids_list = construct_toy_dataset(10)
    repeat_dataset = RepeatDataset(dataset, 10)
    assert repeat_dataset[5] == 5
    assert repeat_dataset[15] == 5
    assert repeat_dataset[27] == 7
    assert repeat_dataset.get_cat_ids(5) == cat_ids_list[5]
    assert repeat_dataset.get_cat_ids(15) == cat_ids_list[5]
    assert repeat_dataset.get_cat_ids(27) == cat_ids_list[7]
    assert len(repeat_dataset) == 10 * len(dataset)
    assert repeat_dataset.CLASSES == BaseDataset.CLASSES
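

# ClassBalancedDataset follows the repeat-factor re-sampling of the LVIS
# paper (https://arxiv.org/abs/1908.03195): with f(c) the fraction of samples
# containing category c and t a frequency threshold, each category gets a
# repeat factor r(c) = max(1, sqrt(t / f(c))), and each sample is repeated
# ceil(max(r(c) for its categories)) times. The test recomputes those factors
# by hand, using the mean category frequency as the threshold t.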


@pytest.mark.parametrize('construct_dataset', [
    'construct_toy_multi_label_dataset', 'construct_toy_single_label_dataset'
])
def test_class_balanced_dataset(construct_dataset):
    construct_toy_dataset = eval(construct_dataset)
    dataset, cat_ids_list = construct_toy_dataset(10)

    category_freq = defaultdict(int)
    for cat_ids in cat_ids_list:
        cat_ids = set(cat_ids)
        for cat_id in cat_ids:
            category_freq[cat_id] += 1
    for k, v in category_freq.items():
        category_freq[k] = v / len(cat_ids_list)

    mean_freq = np.mean(list(category_freq.values()))
    repeat_thr = mean_freq

    category_repeat = {
        cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
        for cat_id, cat_freq in category_freq.items()
    }

    repeat_factors = []
    for cat_ids in cat_ids_list:
        cat_ids = set(cat_ids)
        repeat_factor = max({category_repeat[cat_id] for cat_id in cat_ids})
        repeat_factors.append(math.ceil(repeat_factor))
    repeat_factors_cumsum = np.cumsum(repeat_factors)
    repeat_factor_dataset = ClassBalancedDataset(dataset, repeat_thr)
    assert repeat_factor_dataset.CLASSES == BaseDataset.CLASSES
    assert len(repeat_factor_dataset) == repeat_factors_cumsum[-1]
    # A repeated index maps back to its source sample by bisecting the
    # cumulative repeat counts.
    for idx in np.random.randint(0, len(repeat_factor_dataset), 3):
        assert repeat_factor_dataset[idx] == bisect.bisect_right(
            repeat_factors_cumsum, idx)
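

# KFoldDataset splits the wrapped dataset into ``num_splits`` folds, serving
# fold ``fold`` as the test split (test_mode=True) and the remaining folds as
# the training split; an integer seed shuffles the indices deterministically
# before splitting, so train and test splits built with the same seed still
# partition the dataset.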


@pytest.mark.parametrize('construct_dataset', [
    'construct_toy_multi_label_dataset', 'construct_toy_single_label_dataset'
])
def test_kfold_dataset(construct_dataset):
    construct_toy_dataset = eval(construct_dataset)
    dataset, _ = construct_toy_dataset(10)

    # test without random seed
    train_datasets = [
        KFoldDataset(dataset, fold=i, num_splits=5, test_mode=False)
        for i in range(5)
    ]
    test_datasets = [
        KFoldDataset(dataset, fold=i, num_splits=5, test_mode=True)
        for i in range(5)
    ]

    # Without shuffling, the test folds cover the indices in order.
    assert sum([i.indices for i in test_datasets], []) == list(range(10))
    for train_set, test_set in zip(train_datasets, test_datasets):
        train_samples = [train_set[i] for i in range(len(train_set))]
        test_samples = [test_set[i] for i in range(len(test_set))]
        assert set(train_samples + test_samples) == set(range(10))

    # test with random seed
    train_datasets = [
        KFoldDataset(dataset, fold=i, num_splits=5, test_mode=False, seed=1)
        for i in range(5)
    ]
    test_datasets = [
        KFoldDataset(dataset, fold=i, num_splits=5, test_mode=True, seed=1)
        for i in range(5)
    ]

    # With a seed, the folds are shuffled but still partition all indices.
    assert sum([i.indices for i in test_datasets], []) != list(range(10))
    assert set(sum([i.indices for i in test_datasets], [])) == set(range(10))
    for train_set, test_set in zip(train_datasets, test_datasets):
        train_samples = [train_set[i] for i in range(len(train_set))]
        test_samples = [test_set[i] for i in range(len(test_set))]
        assert set(train_samples + test_samples) == set(range(10))

    # test evaluate
    for test_set in test_datasets:
        eval_inputs = test_set.evaluate(None)
        assert eval_inputs['indices'] == test_set.indices