From a19c28fe9538fd1c06dbc08c4dc7ad59d67d4217 Mon Sep 17 00:00:00 2001 From: huyu Date: Wed, 23 Mar 2022 16:57:36 +0900 Subject: [PATCH] [Enhance] Add `get_cat_ids` and `get_gt_labels` to KFoldDataset. (#721) --- mmcls/datasets/dataset_wrappers.py | 8 ++++++ .../test_datasets/test_dataset_wrapper.py | 25 +++++++++++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/mmcls/datasets/dataset_wrappers.py b/mmcls/datasets/dataset_wrappers.py index 6aef65638..4b471963b 100644 --- a/mmcls/datasets/dataset_wrappers.py +++ b/mmcls/datasets/dataset_wrappers.py @@ -311,6 +311,14 @@ class KFoldDataset: else: self.indices = indices[:test_start] + indices[test_end:] + def get_cat_ids(self, idx): + return self.dataset.get_cat_ids(self.indices[idx]) + + def get_gt_labels(self): + dataset_gt_labels = self.dataset.get_gt_labels() + gt_labels = np.array([dataset_gt_labels[idx] for idx in self.indices]) + return gt_labels + def __getitem__(self, idx): return self.dataset[self.indices[idx]] diff --git a/tests/test_data/test_datasets/test_dataset_wrapper.py b/tests/test_data/test_datasets/test_dataset_wrapper.py index 2798e1fbb..b6430f41a 100644 --- a/tests/test_data/test_datasets/test_dataset_wrapper.py +++ b/tests/test_data/test_datasets/test_dataset_wrapper.py @@ -36,7 +36,8 @@ def construct_toy_multi_label_dataset(length): dataset.data_infos = MagicMock() dataset.data_infos.__len__.return_value = length dataset.get_cat_ids = MagicMock(side_effect=lambda idx: cat_ids_list[idx]) - + dataset.get_gt_labels = \ + MagicMock(side_effect=lambda: np.array(cat_ids_list)) dataset.evaluate = MagicMock(side_effect=mock_evaluate) return dataset, cat_ids_list @@ -50,6 +51,8 @@ def construct_toy_single_label_dataset(length): dataset.data_infos = MagicMock() dataset.data_infos.__len__.return_value = length dataset.get_cat_ids = MagicMock(side_effect=lambda idx: cat_ids_list[idx]) + dataset.get_gt_labels = \ + MagicMock(side_effect=lambda: np.array(cat_ids_list)) dataset.evaluate = MagicMock(side_effect=mock_evaluate) return dataset, cat_ids_list @@ -130,7 +133,7 @@ def test_class_balanced_dataset(construct_dataset): ]) def test_kfold_dataset(construct_dataset): construct_toy_dataset = eval(construct_dataset) - dataset, _ = construct_toy_dataset(10) + dataset, cat_ids_list = construct_toy_dataset(10) # test without random seed train_datasets = [ @@ -165,6 +168,24 @@ def test_kfold_dataset(construct_dataset): test_samples = [test_set[i] for i in range(len(test_set))] assert set(train_samples + test_samples) == set(range(10)) + # test behavior of get_cat_ids method + for train_set, test_set in zip(train_datasets, test_datasets): + for i in range(len(train_set)): + cat_ids = train_set.get_cat_ids(i) + assert cat_ids == cat_ids_list[train_set.indices[i]] + for i in range(len(test_set)): + cat_ids = test_set.get_cat_ids(i) + assert cat_ids == cat_ids_list[test_set.indices[i]] + + # test behavior of get_gt_labels method + for train_set, test_set in zip(train_datasets, test_datasets): + for i in range(len(train_set)): + gt_label = train_set.get_gt_labels()[i] + assert gt_label == cat_ids_list[train_set.indices[i]] + for i in range(len(test_set)): + gt_label = test_set.get_gt_labels()[i] + assert gt_label == cat_ids_list[test_set.indices[i]] + # test evaluate for test_set in test_datasets: eval_inputs = test_set.evaluate(None)