mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
[Enhance] Add get_cat_ids
and get_gt_labels
to KFoldDataset. (#721)
This commit is contained in:
parent
04cb42a768
commit
a19c28fe95
@ -311,6 +311,14 @@ class KFoldDataset:
|
||||
else:
|
||||
self.indices = indices[:test_start] + indices[test_end:]
|
||||
|
||||
def get_cat_ids(self, idx):
|
||||
return self.dataset.get_cat_ids(self.indices[idx])
|
||||
|
||||
def get_gt_labels(self):
|
||||
dataset_gt_labels = self.dataset.get_gt_labels()
|
||||
gt_labels = np.array([dataset_gt_labels[idx] for idx in self.indices])
|
||||
return gt_labels
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.dataset[self.indices[idx]]
|
||||
|
||||
|
@ -36,7 +36,8 @@ def construct_toy_multi_label_dataset(length):
|
||||
dataset.data_infos = MagicMock()
|
||||
dataset.data_infos.__len__.return_value = length
|
||||
dataset.get_cat_ids = MagicMock(side_effect=lambda idx: cat_ids_list[idx])
|
||||
|
||||
dataset.get_gt_labels = \
|
||||
MagicMock(side_effect=lambda: np.array(cat_ids_list))
|
||||
dataset.evaluate = MagicMock(side_effect=mock_evaluate)
|
||||
return dataset, cat_ids_list
|
||||
|
||||
@ -50,6 +51,8 @@ def construct_toy_single_label_dataset(length):
|
||||
dataset.data_infos = MagicMock()
|
||||
dataset.data_infos.__len__.return_value = length
|
||||
dataset.get_cat_ids = MagicMock(side_effect=lambda idx: cat_ids_list[idx])
|
||||
dataset.get_gt_labels = \
|
||||
MagicMock(side_effect=lambda: np.array(cat_ids_list))
|
||||
dataset.evaluate = MagicMock(side_effect=mock_evaluate)
|
||||
return dataset, cat_ids_list
|
||||
|
||||
@ -130,7 +133,7 @@ def test_class_balanced_dataset(construct_dataset):
|
||||
])
|
||||
def test_kfold_dataset(construct_dataset):
|
||||
construct_toy_dataset = eval(construct_dataset)
|
||||
dataset, _ = construct_toy_dataset(10)
|
||||
dataset, cat_ids_list = construct_toy_dataset(10)
|
||||
|
||||
# test without random seed
|
||||
train_datasets = [
|
||||
@ -165,6 +168,24 @@ def test_kfold_dataset(construct_dataset):
|
||||
test_samples = [test_set[i] for i in range(len(test_set))]
|
||||
assert set(train_samples + test_samples) == set(range(10))
|
||||
|
||||
# test behavior of get_cat_ids method
|
||||
for train_set, test_set in zip(train_datasets, test_datasets):
|
||||
for i in range(len(train_set)):
|
||||
cat_ids = train_set.get_cat_ids(i)
|
||||
assert cat_ids == cat_ids_list[train_set.indices[i]]
|
||||
for i in range(len(test_set)):
|
||||
cat_ids = test_set.get_cat_ids(i)
|
||||
assert cat_ids == cat_ids_list[test_set.indices[i]]
|
||||
|
||||
# test behavior of get_gt_labels method
|
||||
for train_set, test_set in zip(train_datasets, test_datasets):
|
||||
for i in range(len(train_set)):
|
||||
gt_label = train_set.get_gt_labels()[i]
|
||||
assert gt_label == cat_ids_list[train_set.indices[i]]
|
||||
for i in range(len(test_set)):
|
||||
gt_label = test_set.get_gt_labels()[i]
|
||||
assert gt_label == cat_ids_list[test_set.indices[i]]
|
||||
|
||||
# test evaluate
|
||||
for test_set in test_datasets:
|
||||
eval_inputs = test_set.evaluate(None)
|
||||
|
Loading…
x
Reference in New Issue
Block a user