Add missing CLASSES argument to dataset wrappers (#66)
* Add missing classes in dataset wrappers * Update testspull/69/head^2
parent
4f4f2957ef
commit
df2189c6f6
|
@ -21,6 +21,7 @@ class ConcatDataset(_ConcatDataset):
|
|||
|
||||
def __init__(self, datasets):
|
||||
super(ConcatDataset, self).__init__(datasets)
|
||||
self.CLASSES = datasets[0].CLASSES
|
||||
|
||||
def get_cat_ids(self, idx):
|
||||
if idx < 0:
|
||||
|
@ -53,6 +54,7 @@ class RepeatDataset(object):
|
|||
def __init__(self, dataset, times):
|
||||
self.dataset = dataset
|
||||
self.times = times
|
||||
self.CLASSES = dataset.CLASSES
|
||||
|
||||
self._ori_len = len(self.dataset)
|
||||
|
||||
|
@ -104,6 +106,7 @@ class ClassBalancedDataset(object):
|
|||
def __init__(self, dataset, oversample_thr):
|
||||
self.dataset = dataset
|
||||
self.oversample_thr = oversample_thr
|
||||
self.CLASSES = dataset.CLASSES
|
||||
|
||||
repeat_factors = self._get_repeat_factors(dataset, oversample_thr)
|
||||
repeat_indices = []
|
||||
|
|
|
@ -31,6 +31,7 @@ def test_datasets_override_default(dataset_name):
|
|||
|
||||
@patch.multiple(BaseDataset, __abstractmethods__=set())
|
||||
def test_dataset_wrapper():
|
||||
BaseDataset.CLASSES = ('foo', 'bar')
|
||||
BaseDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
|
||||
dataset_a = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
|
||||
len_a = 10
|
||||
|
@ -59,6 +60,7 @@ def test_dataset_wrapper():
|
|||
assert concat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
|
||||
assert concat_dataset.get_cat_ids(25) == cat_ids_list_b[15]
|
||||
assert len(concat_dataset) == len(dataset_a) + len(dataset_b)
|
||||
assert concat_dataset.CLASSES == BaseDataset.CLASSES
|
||||
|
||||
repeat_dataset = RepeatDataset(dataset_a, 10)
|
||||
assert repeat_dataset[5] == 5
|
||||
|
@ -68,6 +70,7 @@ def test_dataset_wrapper():
|
|||
assert repeat_dataset.get_cat_ids(15) == cat_ids_list_a[5]
|
||||
assert repeat_dataset.get_cat_ids(27) == cat_ids_list_a[7]
|
||||
assert len(repeat_dataset) == 10 * len(dataset_a)
|
||||
assert repeat_dataset.CLASSES == BaseDataset.CLASSES
|
||||
|
||||
category_freq = defaultdict(int)
|
||||
for cat_ids in cat_ids_list_a:
|
||||
|
@ -92,6 +95,7 @@ def test_dataset_wrapper():
|
|||
repeat_factors.append(math.ceil(repeat_factor))
|
||||
repeat_factors_cumsum = np.cumsum(repeat_factors)
|
||||
repeat_factor_dataset = ClassBalancedDataset(dataset_a, repeat_thr)
|
||||
assert repeat_factor_dataset.CLASSES == BaseDataset.CLASSES
|
||||
assert len(repeat_factor_dataset) == repeat_factors_cumsum[-1]
|
||||
for idx in np.random.randint(0, len(repeat_factor_dataset), 3):
|
||||
assert repeat_factor_dataset[idx] == bisect.bisect_right(
|
||||
|
|
Loading…
Reference in New Issue