diff --git a/mmcls/datasets/base_dataset.py b/mmcls/datasets/base_dataset.py
index 3d4212d2e..9fe3bf03f 100644
--- a/mmcls/datasets/base_dataset.py
+++ b/mmcls/datasets/base_dataset.py
@@ -1,6 +1,7 @@
 import copy
 from abc import ABCMeta, abstractmethod
 
+import mmcv
 import numpy as np
 from torch.utils.data import Dataset
 
@@ -23,7 +24,12 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
 
     CLASSES = None
 
-    def __init__(self, data_prefix, pipeline, ann_file=None, test_mode=False):
+    def __init__(self,
+                 data_prefix,
+                 pipeline,
+                 classes=None,
+                 ann_file=None,
+                 test_mode=False):
         super(BaseDataset, self).__init__()
 
         self.ann_file = ann_file
@@ -31,6 +37,10 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
         self.test_mode = test_mode
         self.pipeline = Compose(pipeline)
+        # Resolve the class names before loading annotations so that
+        # ``load_annotations`` can see the override, and so that anything a
+        # subclass assigns to ``self.CLASSES`` there is not overwritten.
+        self.CLASSES = self.get_classes(classes)
         self.data_infos = self.load_annotations()
 
     @abstractmethod
     def load_annotations(self):
@@ -78,6 +88,36 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
     def __getitem__(self, idx):
         return self.prepare_data(idx)
 
+    @classmethod
+    def get_classes(cls, classes=None):
+        """Get class names of current dataset.
+
+        Args:
+            classes (Sequence[str] | str | None): If classes is None, use
+                default CLASSES defined by builtin dataset. If classes is a
+                string, take it as a file name. The file contains the name of
+                classes where each line contains one class name. If classes is
+                a tuple or list, override the CLASSES defined by the dataset.
+
+        Returns:
+            tuple[str] or list[str]: Names of categories of the dataset.
+
+        Raises:
+            ValueError: If ``classes`` is not None, a str, a tuple or a list.
+        """
+        if classes is None:
+            return cls.CLASSES
+
+        if isinstance(classes, str):
+            # take it as a file path
+            class_names = mmcv.list_from_file(classes)
+        elif isinstance(classes, (tuple, list)):
+            class_names = classes
+        else:
+            raise ValueError(f'Unsupported type {type(classes)} of classes.')
+
+        return class_names
+
     def evaluate(self,
                  results,
                  metric='accuracy',
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index d4d930d23..17fed0eb3 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -21,12 +21,44 @@ def test_datasets_override_default(dataset_name):
     dataset_class = DATASETS.get(dataset_name)
     dataset_class.load_annotations = MagicMock()
 
+    original_classes = dataset_class.CLASSES
+
+    # Test setting classes as a tuple
+    dataset = dataset_class(
+        data_prefix='', pipeline=[], classes=('bus', 'car'), test_mode=True)
+    assert dataset.CLASSES != original_classes
+    assert dataset.CLASSES == ('bus', 'car')
+
+    # Test setting classes as a list
+    dataset = dataset_class(
+        data_prefix='', pipeline=[], classes=['bus', 'car'], test_mode=True)
+    assert dataset.CLASSES != original_classes
+    assert dataset.CLASSES == ['bus', 'car']
+
+    # Test setting classes through a file
+    tmp_file = tempfile.NamedTemporaryFile()
+    with open(tmp_file.name, 'w') as f:
+        f.write('bus\ncar\n')
+    dataset = dataset_class(
+        data_prefix='', pipeline=[], classes=tmp_file.name, test_mode=True)
+    tmp_file.close()
+
+    assert dataset.CLASSES != original_classes
+    assert dataset.CLASSES == ['bus', 'car']
+
+    # Test overriding not a subset
+    dataset = dataset_class(
+        data_prefix='', pipeline=[], classes=['foo'], test_mode=True)
+    assert dataset.CLASSES != original_classes
+    assert dataset.CLASSES == ['foo']
+
     # Test default behavior
     dataset = dataset_class(data_prefix='', pipeline=[])
 
     assert dataset.data_prefix == ''
     assert not dataset.test_mode
     assert dataset.ann_file is None
+    assert dataset.CLASSES == original_classes
 
 
 @patch.multiple(BaseDataset, __abstractmethods__=set())