add get_classes in base_dataset (#85)

* modify base_dataset

* revise according to the comments
pull/89/head
LXXXXR 2020-11-12 14:22:02 +08:00 committed by GitHub
parent 909a6b9c3f
commit 7636409b3b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 1 deletions

View File

@ -1,6 +1,7 @@
import copy
from abc import ABCMeta, abstractmethod
import mmcv
import numpy as np
from torch.utils.data import Dataset
@ -23,7 +24,12 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
CLASSES = None
def __init__(self, data_prefix, pipeline, ann_file=None, test_mode=False):
def __init__(self,
data_prefix,
pipeline,
classes=None,
ann_file=None,
test_mode=False):
super(BaseDataset, self).__init__()
self.ann_file = ann_file
@ -31,6 +37,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
self.test_mode = test_mode
self.pipeline = Compose(pipeline)
self.data_infos = self.load_annotations()
self.CLASSES = self.get_classes(classes)
@abstractmethod
def load_annotations(self):
@ -78,6 +85,32 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
def __getitem__(self, idx):
    # Delegate to prepare_data, which builds the sample for index `idx`
    # (defined elsewhere in this class; presumably runs the pipeline —
    # TODO(review): confirm against prepare_data's implementation).
    return self.prepare_data(idx)
@classmethod
def get_classes(cls, classes=None):
"""Get class names of current dataset.
Args:
classes (Sequence[str] | str | None): If classes is None, use
default CLASSES defined by builtin dataset. If classes is a
string, take it as a file name. The file contains the name of
classes where each line contains one class name. If classes is
a tuple or list, override the CLASSES defined by the dataset.
Returns:
tuple[str] or list[str]: Names of categories of the dataset.
"""
if classes is None:
return cls.CLASSES
if isinstance(classes, str):
# take it as a file path
class_names = mmcv.list_from_file(classes)
elif isinstance(classes, (tuple, list)):
class_names = classes
else:
raise ValueError(f'Unsupported type {type(classes)} of classes.')
return class_names
def evaluate(self,
results,
metric='accuracy',

View File

def test_datasets_override_default(dataset_name):
    """Check that `classes` overrides (or keeps) a dataset's CLASSES."""
    dataset_class = DATASETS.get(dataset_name)
    dataset_class.load_annotations = MagicMock()
    default_classes = dataset_class.CLASSES

    # Override CLASSES with a tuple.
    ds = dataset_class(
        data_prefix='', pipeline=[], classes=('bus', 'car'), test_mode=True)
    assert ds.CLASSES != default_classes
    assert ds.CLASSES == ('bus', 'car')

    # Override CLASSES with a list.
    ds = dataset_class(
        data_prefix='', pipeline=[], classes=['bus', 'car'], test_mode=True)
    assert ds.CLASSES != default_classes
    assert ds.CLASSES == ['bus', 'car']

    # Override CLASSES with the contents of a text file, one name per line.
    tmp = tempfile.NamedTemporaryFile()
    with open(tmp.name, 'w') as f:
        f.write('bus\ncar\n')
    ds = dataset_class(
        data_prefix='', pipeline=[], classes=tmp.name, test_mode=True)
    tmp.close()
    assert ds.CLASSES != default_classes
    assert ds.CLASSES == ['bus', 'car']

    # The override does not have to be a subset of the defaults.
    ds = dataset_class(
        data_prefix='', pipeline=[], classes=['foo'], test_mode=True)
    assert ds.CLASSES != default_classes
    assert ds.CLASSES == ['foo']

    # With no `classes` argument the builtin defaults are kept.
    ds = dataset_class(data_prefix='', pipeline=[])
    assert ds.data_prefix == ''
    assert not ds.test_mode
    assert ds.ann_file is None
    assert ds.CLASSES == default_classes
@patch.multiple(BaseDataset, __abstractmethods__=set())