Add missing CLASSES argument to dataset wrappers (#66)

* Add missing classes in dataset wrappers

* Update tests
pull/69/head^2
David de la Iglesia Castro 2020-10-15 15:25:53 +02:00 committed by GitHub
parent 4f4f2957ef
commit df2189c6f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 0 deletions

View File

@ -21,6 +21,7 @@ class ConcatDataset(_ConcatDataset):
def __init__(self, datasets):
super(ConcatDataset, self).__init__(datasets)
self.CLASSES = datasets[0].CLASSES
def get_cat_ids(self, idx):
if idx < 0:
@ -53,6 +54,7 @@ class RepeatDataset(object):
def __init__(self, dataset, times):
self.dataset = dataset
self.times = times
self.CLASSES = dataset.CLASSES
self._ori_len = len(self.dataset)
@ -104,6 +106,7 @@ class ClassBalancedDataset(object):
def __init__(self, dataset, oversample_thr):
self.dataset = dataset
self.oversample_thr = oversample_thr
self.CLASSES = dataset.CLASSES
repeat_factors = self._get_repeat_factors(dataset, oversample_thr)
repeat_indices = []

View File

@ -31,6 +31,7 @@ def test_datasets_override_default(dataset_name):
@patch.multiple(BaseDataset, __abstractmethods__=set())
def test_dataset_wrapper():
BaseDataset.CLASSES = ('foo', 'bar')
BaseDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
dataset_a = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
len_a = 10
@ -59,6 +60,7 @@ def test_dataset_wrapper():
assert concat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
assert concat_dataset.get_cat_ids(25) == cat_ids_list_b[15]
assert len(concat_dataset) == len(dataset_a) + len(dataset_b)
assert concat_dataset.CLASSES == BaseDataset.CLASSES
repeat_dataset = RepeatDataset(dataset_a, 10)
assert repeat_dataset[5] == 5
@ -68,6 +70,7 @@ def test_dataset_wrapper():
assert repeat_dataset.get_cat_ids(15) == cat_ids_list_a[5]
assert repeat_dataset.get_cat_ids(27) == cat_ids_list_a[7]
assert len(repeat_dataset) == 10 * len(dataset_a)
assert repeat_dataset.CLASSES == BaseDataset.CLASSES
category_freq = defaultdict(int)
for cat_ids in cat_ids_list_a:
@ -92,6 +95,7 @@ def test_dataset_wrapper():
repeat_factors.append(math.ceil(repeat_factor))
repeat_factors_cumsum = np.cumsum(repeat_factors)
repeat_factor_dataset = ClassBalancedDataset(dataset_a, repeat_thr)
assert repeat_factor_dataset.CLASSES == BaseDataset.CLASSES
assert len(repeat_factor_dataset) == repeat_factors_cumsum[-1]
for idx in np.random.randint(0, len(repeat_factor_dataset), 3):
assert repeat_factor_dataset[idx] == bisect.bisect_right(