add get_classes in base_dataset (#85)

* modify base_dataset

* revise according to the comments
pull/89/head
LXXXXR 2020-11-12 14:22:02 +08:00 committed by GitHub
parent 909a6b9c3f
commit 7636409b3b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 1 deletions

View File

@ -1,6 +1,7 @@
import copy import copy
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import mmcv
import numpy as np import numpy as np
from torch.utils.data import Dataset from torch.utils.data import Dataset
@ -23,7 +24,12 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
CLASSES = None CLASSES = None
def __init__(self, data_prefix, pipeline, ann_file=None, test_mode=False): def __init__(self,
data_prefix,
pipeline,
classes=None,
ann_file=None,
test_mode=False):
super(BaseDataset, self).__init__() super(BaseDataset, self).__init__()
self.ann_file = ann_file self.ann_file = ann_file
@ -31,6 +37,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
self.test_mode = test_mode self.test_mode = test_mode
self.pipeline = Compose(pipeline) self.pipeline = Compose(pipeline)
self.data_infos = self.load_annotations() self.data_infos = self.load_annotations()
self.CLASSES = self.get_classes(classes)
@abstractmethod @abstractmethod
def load_annotations(self): def load_annotations(self):
@ -78,6 +85,32 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
def __getitem__(self, idx): def __getitem__(self, idx):
return self.prepare_data(idx) return self.prepare_data(idx)
@classmethod
def get_classes(cls, classes=None):
    """Resolve the class names for this dataset.

    Args:
        classes (Sequence[str] | str | None): Three accepted forms:
            ``None`` falls back to the dataset's builtin ``CLASSES``;
            a ``str`` is treated as the path to a file containing one
            class name per line; a ``tuple`` or ``list`` is used as-is,
            overriding the builtin ``CLASSES``.

    Returns:
        tuple[str] or list[str]: Names of the categories of the dataset.

    Raises:
        ValueError: If ``classes`` is of any other type.
    """
    if classes is None:
        # No override requested: use the class-level default.
        return cls.CLASSES
    if isinstance(classes, (tuple, list)):
        # Explicit sequence of names overrides the default.
        return classes
    if isinstance(classes, str):
        # Interpret the string as a file path, one class name per line.
        return mmcv.list_from_file(classes)
    raise ValueError(f'Unsupported type {type(classes)} of classes.')
def evaluate(self, def evaluate(self,
results, results,
metric='accuracy', metric='accuracy',

View File

@ -21,12 +21,44 @@ def test_datasets_override_default(dataset_name):
dataset_class = DATASETS.get(dataset_name) dataset_class = DATASETS.get(dataset_name)
dataset_class.load_annotations = MagicMock() dataset_class.load_annotations = MagicMock()
original_classes = dataset_class.CLASSES
# Test setting classes as a tuple
dataset = dataset_class(
data_prefix='', pipeline=[], classes=('bus', 'car'), test_mode=True)
assert dataset.CLASSES != original_classes
assert dataset.CLASSES == ('bus', 'car')
# Test setting classes as a list
dataset = dataset_class(
data_prefix='', pipeline=[], classes=['bus', 'car'], test_mode=True)
assert dataset.CLASSES != original_classes
assert dataset.CLASSES == ['bus', 'car']
# Test setting classes through a file
tmp_file = tempfile.NamedTemporaryFile()
with open(tmp_file.name, 'w') as f:
f.write('bus\ncar\n')
dataset = dataset_class(
data_prefix='', pipeline=[], classes=tmp_file.name, test_mode=True)
tmp_file.close()
assert dataset.CLASSES != original_classes
assert dataset.CLASSES == ['bus', 'car']
# Test overriding not a subset
dataset = dataset_class(
data_prefix='', pipeline=[], classes=['foo'], test_mode=True)
assert dataset.CLASSES != original_classes
assert dataset.CLASSES == ['foo']
# Test default behavior # Test default behavior
dataset = dataset_class(data_prefix='', pipeline=[]) dataset = dataset_class(data_prefix='', pipeline=[])
assert dataset.data_prefix == '' assert dataset.data_prefix == ''
assert not dataset.test_mode assert not dataset.test_mode
assert dataset.ann_file is None assert dataset.ann_file is None
assert dataset.CLASSES == original_classes
@patch.multiple(BaseDataset, __abstractmethods__=set()) @patch.multiple(BaseDataset, __abstractmethods__=set())