# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import math
from collections import defaultdict
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from mmcls.datasets import (BaseDataset, ClassBalancedDataset, ConcatDataset,
                            RepeatDataset)


@patch.multiple(BaseDataset, __abstractmethods__=set())
def construct_toy_multi_label_dataset(length):
    """Build a mocked multi-label ``BaseDataset`` with ``length`` samples.

    Every sample is assigned a random-length list (1-19 entries) of random
    category ids drawn from ``[0, 80)``. Returns the mocked dataset together
    with the per-sample category-id lists so tests can assert against them.
    """
    BaseDataset.CLASSES = ('foo', 'bar')
    # __getitem__ simply echoes the index so wrapper index math is testable.
    BaseDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
    dataset = BaseDataset(data_prefix='', pipeline=[], test_mode=True)

    # Draw all per-sample sizes first, then the ids for each sample — the
    # same RNG call order as a nested comprehension would produce.
    sample_sizes = np.random.randint(1, 20, length)
    cat_ids_list = []
    for size in sample_sizes:
        cat_ids_list.append(np.random.randint(0, 80, size).tolist())

    dataset.data_infos = MagicMock()
    dataset.data_infos.__len__.return_value = length
    dataset.get_cat_ids = MagicMock(side_effect=lambda idx: cat_ids_list[idx])
    return dataset, cat_ids_list


@patch.multiple(BaseDataset, __abstractmethods__=set())
def construct_toy_single_label_dataset(length):
    """Build a mocked single-label ``BaseDataset`` with ``length`` samples.

    Every sample gets exactly one random category id from ``[0, 80)``.
    Returns the mocked dataset together with the per-sample category-id
    lists so tests can assert against them.
    """
    BaseDataset.CLASSES = ('foo', 'bar')
    # __getitem__ simply echoes the index so wrapper index math is testable.
    BaseDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
    dataset = BaseDataset(data_prefix='', pipeline=[], test_mode=True)

    # One random id per sample; same RNG call order as the comprehension form.
    cat_ids_list = []
    for _ in range(length):
        cat_ids_list.append([np.random.randint(0, 80)])

    dataset.data_infos = MagicMock()
    dataset.data_infos.__len__.return_value = length
    dataset.get_cat_ids = MagicMock(side_effect=lambda idx: cat_ids_list[idx])
    return dataset, cat_ids_list


@pytest.mark.parametrize('construct_dataset', [
    'construct_toy_multi_label_dataset', 'construct_toy_single_label_dataset'
])
def test_concat_dataset(construct_dataset):
    """ConcatDataset: indexing, cat-id lookup, length and CLASSES."""
    # Look the helper up by name instead of eval() — equivalent for these
    # module-level functions and avoids evaluating arbitrary strings.
    construct_toy_dataset = globals()[construct_dataset]
    dataset_a, cat_ids_list_a = construct_toy_dataset(10)
    dataset_b, cat_ids_list_b = construct_toy_dataset(20)

    concat_dataset = ConcatDataset([dataset_a, dataset_b])
    # Global index 25 falls into dataset_b at local index 25 - 10 = 15.
    assert concat_dataset[5] == 5
    assert concat_dataset[25] == 15
    assert concat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
    assert concat_dataset.get_cat_ids(25) == cat_ids_list_b[15]
    assert len(concat_dataset) == len(dataset_a) + len(dataset_b)
    assert concat_dataset.CLASSES == BaseDataset.CLASSES


@pytest.mark.parametrize('construct_dataset', [
    'construct_toy_multi_label_dataset', 'construct_toy_single_label_dataset'
])
def test_repeat_dataset(construct_dataset):
    """RepeatDataset: indices wrap modulo the base dataset length."""
    # Look the helper up by name instead of eval() — equivalent for these
    # module-level functions and avoids evaluating arbitrary strings.
    construct_toy_dataset = globals()[construct_dataset]
    dataset, cat_ids_list = construct_toy_dataset(10)
    repeat_dataset = RepeatDataset(dataset, 10)
    # 15 % 10 == 5 and 27 % 10 == 7, so repeats map back to the base items.
    assert repeat_dataset[5] == 5
    assert repeat_dataset[15] == 5
    assert repeat_dataset[27] == 7
    assert repeat_dataset.get_cat_ids(5) == cat_ids_list[5]
    assert repeat_dataset.get_cat_ids(15) == cat_ids_list[5]
    assert repeat_dataset.get_cat_ids(27) == cat_ids_list[7]
    assert len(repeat_dataset) == 10 * len(dataset)
    assert repeat_dataset.CLASSES == BaseDataset.CLASSES


@pytest.mark.parametrize('construct_dataset', [
    'construct_toy_multi_label_dataset', 'construct_toy_single_label_dataset'
])
def test_class_balanced_dataset(construct_dataset):
    """ClassBalancedDataset: re-derive expected repeat factors and compare."""
    # Look the helper up by name instead of eval() — equivalent for these
    # module-level functions and avoids evaluating arbitrary strings.
    construct_toy_dataset = globals()[construct_dataset]
    dataset, cat_ids_list = construct_toy_dataset(10)

    # Category frequency: fraction of samples that contain each category.
    category_freq = defaultdict(int)
    for cat_ids in cat_ids_list:
        cat_ids = set(cat_ids)
        for cat_id in cat_ids:
            category_freq[cat_id] += 1
    for k, v in category_freq.items():
        category_freq[k] = v / len(cat_ids_list)

    # Use the mean frequency as the repeat threshold for this check.
    mean_freq = np.mean(list(category_freq.values()))
    repeat_thr = mean_freq

    # Per-category repeat factor: max(1, sqrt(thr / freq)).
    category_repeat = {
        cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
        for cat_id, cat_freq in category_freq.items()
    }

    # Per-sample factor: ceil of the max factor over the sample's categories.
    repeat_factors = []
    for cat_ids in cat_ids_list:
        cat_ids = set(cat_ids)
        repeat_factor = max({category_repeat[cat_id] for cat_id in cat_ids})
        repeat_factors.append(math.ceil(repeat_factor))
    repeat_factors_cumsum = np.cumsum(repeat_factors)
    repeat_factor_dataset = ClassBalancedDataset(dataset, repeat_thr)
    assert repeat_factor_dataset.CLASSES == BaseDataset.CLASSES
    assert len(repeat_factor_dataset) == repeat_factors_cumsum[-1]
    # Spot-check: repeated index idx maps to the bisect position in the
    # cumulative-sum table (i.e. the originating base-sample index).
    for idx in np.random.randint(0, len(repeat_factor_dataset), 3):
        assert repeat_factor_dataset[idx] == bisect.bisect_right(
            repeat_factors_cumsum, idx)