add get_class in base_dataset (#85)
* modify base_dataset * revise according to the commentspull/89/head
parent
909a6b9c3f
commit
7636409b3b
|
@ -1,6 +1,7 @@
|
|||
import copy
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
@ -23,7 +24,12 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
|
|||
|
||||
CLASSES = None
|
||||
|
||||
def __init__(self, data_prefix, pipeline, ann_file=None, test_mode=False):
|
||||
def __init__(self,
|
||||
data_prefix,
|
||||
pipeline,
|
||||
classes=None,
|
||||
ann_file=None,
|
||||
test_mode=False):
|
||||
super(BaseDataset, self).__init__()
|
||||
|
||||
self.ann_file = ann_file
|
||||
|
@ -31,6 +37,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
|
|||
self.test_mode = test_mode
|
||||
self.pipeline = Compose(pipeline)
|
||||
self.data_infos = self.load_annotations()
|
||||
self.CLASSES = self.get_classes(classes)
|
||||
|
||||
@abstractmethod
|
||||
def load_annotations(self):
|
||||
|
@ -78,6 +85,32 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
|
|||
def __getitem__(self, idx):
|
||||
return self.prepare_data(idx)
|
||||
|
||||
@classmethod
|
||||
def get_classes(cls, classes=None):
|
||||
"""Get class names of current dataset.
|
||||
Args:
|
||||
classes (Sequence[str] | str | None): If classes is None, use
|
||||
default CLASSES defined by builtin dataset. If classes is a
|
||||
string, take it as a file name. The file contains the name of
|
||||
classes where each line contains one class name. If classes is
|
||||
a tuple or list, override the CLASSES defined by the dataset.
|
||||
|
||||
Returns:
|
||||
tuple[str] or list[str]: Names of categories of the dataset.
|
||||
"""
|
||||
if classes is None:
|
||||
return cls.CLASSES
|
||||
|
||||
if isinstance(classes, str):
|
||||
# take it as a file path
|
||||
class_names = mmcv.list_from_file(classes)
|
||||
elif isinstance(classes, (tuple, list)):
|
||||
class_names = classes
|
||||
else:
|
||||
raise ValueError(f'Unsupported type {type(classes)} of classes.')
|
||||
|
||||
return class_names
|
||||
|
||||
def evaluate(self,
|
||||
results,
|
||||
metric='accuracy',
|
||||
|
|
|
@ -21,12 +21,44 @@ def test_datasets_override_default(dataset_name):
|
|||
dataset_class = DATASETS.get(dataset_name)
|
||||
dataset_class.load_annotations = MagicMock()
|
||||
|
||||
original_classes = dataset_class.CLASSES
|
||||
|
||||
# Test setting classes as a tuple
|
||||
dataset = dataset_class(
|
||||
data_prefix='', pipeline=[], classes=('bus', 'car'), test_mode=True)
|
||||
assert dataset.CLASSES != original_classes
|
||||
assert dataset.CLASSES == ('bus', 'car')
|
||||
|
||||
# Test setting classes as a list
|
||||
dataset = dataset_class(
|
||||
data_prefix='', pipeline=[], classes=['bus', 'car'], test_mode=True)
|
||||
assert dataset.CLASSES != original_classes
|
||||
assert dataset.CLASSES == ['bus', 'car']
|
||||
|
||||
# Test setting classes through a file
|
||||
tmp_file = tempfile.NamedTemporaryFile()
|
||||
with open(tmp_file.name, 'w') as f:
|
||||
f.write('bus\ncar\n')
|
||||
dataset = dataset_class(
|
||||
data_prefix='', pipeline=[], classes=tmp_file.name, test_mode=True)
|
||||
tmp_file.close()
|
||||
|
||||
assert dataset.CLASSES != original_classes
|
||||
assert dataset.CLASSES == ['bus', 'car']
|
||||
|
||||
# Test overriding not a subset
|
||||
dataset = dataset_class(
|
||||
data_prefix='', pipeline=[], classes=['foo'], test_mode=True)
|
||||
assert dataset.CLASSES != original_classes
|
||||
assert dataset.CLASSES == ['foo']
|
||||
|
||||
# Test default behavior
|
||||
dataset = dataset_class(data_prefix='', pipeline=[])
|
||||
|
||||
assert dataset.data_prefix == ''
|
||||
assert not dataset.test_mode
|
||||
assert dataset.ann_file is None
|
||||
assert dataset.CLASSES == original_classes
|
||||
|
||||
|
||||
@patch.multiple(BaseDataset, __abstractmethods__=set())
|
||||
|
|
Loading…
Reference in New Issue