mmfewshot/tests/test_classification_data/test_classification_dataset...

75 lines
2.4 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import MagicMock, patch
import numpy as np
from mmfewshot.classification.datasets import (BaseFewShotDataset,
EpisodicDataset,
MetaTestDataset)
@patch.multiple(BaseFewShotDataset, __abstractmethods__=set())
def construct_toy_dataset():
    """Build a small in-memory ``BaseFewShotDataset`` for the tests.

    The decorator replaces ``BaseFewShotDataset.__abstractmethods__`` with an
    empty set for the duration of the call (presumably so the base class can
    be instantiated directly). The toy data has 7 classes with 20 samples
    each (140 samples total), and the label sequence 0..6 is repeated 20
    times.

    Note:
        ``CLASSES`` and ``load_annotations`` are assigned directly on the
        ``BaseFewShotDataset`` class and are not undone by ``patch.multiple``.

    Returns:
        tuple: ``(dataset, cat_ids_list)`` where ``dataset`` is the mocked
        dataset instance and ``cat_ids_list`` is the per-sample list of
        category ids used to build it.
    """
    num_classes, samples_per_class = 7, 20
    BaseFewShotDataset.CLASSES = ('a', 'b', 'c', 'd', 'e', 'f', 'g')

    # Per-sample labels: 0, 1, ..., 6, 0, 1, ..., 6, ...
    cat_ids_list = list(range(num_classes)) * samples_per_class
    data_infos = []
    for cat_id in cat_ids_list:
        data_infos.append(dict(gt_label=np.array(cat_id)))

    # Serve the annotations from memory instead of loading them from disk.
    BaseFewShotDataset.load_annotations = MagicMock(return_value=data_infos)
    dataset = BaseFewShotDataset(data_prefix='', pipeline=[])

    def _lookup_cat_id(idx):
        return cat_ids_list[idx]

    dataset.get_cat_ids = MagicMock(side_effect=_lookup_cat_id)
    return dataset, cat_ids_list
def test_episodic_dataset():
    """Check episode count, episode sizes and seed dependence.

    Two ``EpisodicDataset`` wrappers are built over the same toy dataset with
    identical settings except for ``episodes_seed``. Both must report the
    requested number of episodes, each episode must hold
    ``num_ways * num_shots`` support samples and ``num_ways * num_queries``
    query samples, and the same episode index must yield different query
    data under different seeds.
    """
    toy_dataset, _ = construct_toy_dataset()
    episodic_dataset_a = EpisodicDataset(
        toy_dataset,
        num_episodes=10,
        num_ways=5,
        num_shots=2,
        num_queries=3,
        episodes_seed=0)
    episodic_dataset_b = EpisodicDataset(
        toy_dataset,
        num_episodes=10,
        num_ways=5,
        num_shots=2,
        num_queries=3,
        episodes_seed=1)

    # Verify the length of BOTH datasets (the original checked `a` twice).
    assert len(episodic_dataset_a) == 10
    assert len(episodic_dataset_b) == 10

    # Each episode holds num_ways * num_shots support samples and
    # num_ways * num_queries query samples.
    assert len(episodic_dataset_a[5]['support_data']) == 5 * 2
    assert len(episodic_dataset_a[5]['query_data']) == 5 * 3

    # Different seeds should produce different episodes at the same index.
    assert episodic_dataset_a[5]['query_data'] != episodic_dataset_b[5][
        'query_data']
def test_meta_test_dataset():
    """Check the three views exposed by ``MetaTestDataset``.

    Covers the ``test_set``, ``support`` and ``query`` views: their ``_mode``
    tags, their lengths, per-task class id selection, and the default
    feature-cache state of the query view.
    """
    num_ways, num_shots, num_queries = 5, 2, 3
    base_dataset, _ = construct_toy_dataset()
    meta_dataset = MetaTestDataset(
        dataset=base_dataset,
        num_episodes=10,
        num_ways=num_ways,
        num_shots=num_shots,
        num_queries=num_queries,
        episodes_seed=0)

    # test_set view: spans all 7 * 20 = 140 toy samples.
    full_view = meta_dataset.test_set()
    assert full_view._mode == 'test_set'
    assert len(full_view) == 140

    # Switching the task id should change the sampled class ids.
    full_view.set_task_id(9)
    class_ids_task9 = full_view.get_task_class_ids()
    assert len(class_ids_task9) == num_ways
    full_view.set_task_id(5)
    class_ids_task5 = full_view.get_task_class_ids()
    assert class_ids_task5 != class_ids_task9

    # support view: num_ways * num_shots samples.
    support_view = meta_dataset.support()
    assert support_view._mode == 'support'
    assert len(support_view) == num_ways * num_shots

    # query view: num_ways * num_queries samples, with no cached features.
    query_view = meta_dataset.query()
    assert query_view._mode == 'query'
    assert len(query_view) == num_ways * num_queries
    assert query_view.with_cache_feats() is False