[Enhance] Add get_cat_ids and get_gt_labels to KFoldDataset. (#721)

This commit is contained in:
huyu 2022-03-23 16:57:36 +09:00 committed by GitHub
parent 04cb42a768
commit a19c28fe95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 2 deletions

View File

@ -311,6 +311,14 @@ class KFoldDataset:
else:
self.indices = indices[:test_start] + indices[test_end:]
def get_cat_ids(self, idx):
return self.dataset.get_cat_ids(self.indices[idx])
def get_gt_labels(self):
dataset_gt_labels = self.dataset.get_gt_labels()
gt_labels = np.array([dataset_gt_labels[idx] for idx in self.indices])
return gt_labels
def __getitem__(self, idx):
return self.dataset[self.indices[idx]]

View File

@ -36,7 +36,8 @@ def construct_toy_multi_label_dataset(length):
dataset.data_infos = MagicMock()
dataset.data_infos.__len__.return_value = length
dataset.get_cat_ids = MagicMock(side_effect=lambda idx: cat_ids_list[idx])
dataset.get_gt_labels = \
MagicMock(side_effect=lambda: np.array(cat_ids_list))
dataset.evaluate = MagicMock(side_effect=mock_evaluate)
return dataset, cat_ids_list
@ -50,6 +51,8 @@ def construct_toy_single_label_dataset(length):
dataset.data_infos = MagicMock()
dataset.data_infos.__len__.return_value = length
dataset.get_cat_ids = MagicMock(side_effect=lambda idx: cat_ids_list[idx])
dataset.get_gt_labels = \
MagicMock(side_effect=lambda: np.array(cat_ids_list))
dataset.evaluate = MagicMock(side_effect=mock_evaluate)
return dataset, cat_ids_list
@ -130,7 +133,7 @@ def test_class_balanced_dataset(construct_dataset):
])
def test_kfold_dataset(construct_dataset):
construct_toy_dataset = eval(construct_dataset)
dataset, _ = construct_toy_dataset(10)
dataset, cat_ids_list = construct_toy_dataset(10)
# test without random seed
train_datasets = [
@ -165,6 +168,24 @@ def test_kfold_dataset(construct_dataset):
test_samples = [test_set[i] for i in range(len(test_set))]
assert set(train_samples + test_samples) == set(range(10))
# test behavior of get_cat_ids method
for train_set, test_set in zip(train_datasets, test_datasets):
for i in range(len(train_set)):
cat_ids = train_set.get_cat_ids(i)
assert cat_ids == cat_ids_list[train_set.indices[i]]
for i in range(len(test_set)):
cat_ids = test_set.get_cat_ids(i)
assert cat_ids == cat_ids_list[test_set.indices[i]]
# test behavior of get_gt_labels method
for train_set, test_set in zip(train_datasets, test_datasets):
for i in range(len(train_set)):
gt_label = train_set.get_gt_labels()[i]
assert gt_label == cat_ids_list[train_set.indices[i]]
for i in range(len(test_set)):
gt_label = test_set.get_gt_labels()[i]
assert gt_label == cat_ids_list[test_set.indices[i]]
# test evaluate
for test_set in test_datasets:
eval_inputs = test_set.evaluate(None)