import copy
from abc import ABCMeta, abstractmethod

import mmcv
import numpy as np
from torch.utils.data import Dataset

from mmcls.core.evaluation import f1_score, precision, recall, support
from mmcls.models.losses import accuracy
from .pipelines import Compose


class BaseDataset(Dataset, metaclass=ABCMeta):
    """Base dataset.

    Args:
        data_prefix (str): The prefix of the data path.
        pipeline (list): A list of dicts, where each element represents
            an operation defined in `mmcls.datasets.pipelines`.
        classes (Sequence[str] | str | None): Class names. See
            ``get_classes`` for the accepted formats. Default: None.
        ann_file (str | None): The annotation file. When ann_file is str,
            the subclass is expected to read from the ann_file. When ann_file
            is None, the subclass is expected to read according to
            data_prefix.
        test_mode (bool): Whether the dataset is in test mode.
    """

    CLASSES = None

    def __init__(self,
                 data_prefix,
                 pipeline,
                 classes=None,
                 ann_file=None,
                 test_mode=False):
        super(BaseDataset, self).__init__()

        self.ann_file = ann_file
        self.data_prefix = data_prefix
        self.test_mode = test_mode
        self.pipeline = Compose(pipeline)
        self.CLASSES = self.get_classes(classes)
        self.data_infos = self.load_annotations()
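
    # The `pipeline` argument is a list of transform configs consumed by
    # `Compose`. A minimal sketch of a typical training pipeline follows;
    # the exact transform names and parameters are assumptions based on
    # `mmcls.datasets.pipelines` and should be adapted to the actual config:
    #
    #     pipeline = [
    #         dict(type='LoadImageFromFile'),
    #         dict(type='RandomResizedCrop', size=224),
    #         dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    #         dict(type='ImageToTensor', keys=['img']),
    #         dict(type='ToTensor', keys=['gt_label']),
    #         dict(type='Collect', keys=['img', 'gt_label'])
    #     ]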

    @abstractmethod
    def load_annotations(self):
        """Load annotations.

        Subclasses must return a list of per-sample info dicts, each
        containing at least a ``gt_label`` entry (as a numpy scalar).
        """
        pass
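
    # A sketch of a concrete `load_annotations`, assuming a hypothetical
    # annotation file with one "<filename> <label>" pair per line (the
    # 'img_prefix'/'img_info' keys follow the convention expected by the
    # loading transforms and are assumptions here):
    #
    #     def load_annotations(self):
    #         data_infos = []
    #         for line in mmcv.list_from_file(self.ann_file):
    #             filename, gt_label = line.rsplit(' ', 1)
    #             data_infos.append({
    #                 'img_prefix': self.data_prefix,
    #                 'img_info': {'filename': filename},
    #                 'gt_label': np.array(int(gt_label), dtype=np.int64)
    #             })
    #         return data_infos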

    @property
    def class_to_idx(self):
        """Map class names to class indices.

        Returns:
            dict: Mapping from class name to class index.
        """

        return {_class: i for i, _class in enumerate(self.CLASSES)}

    def get_gt_labels(self):
        """Get all ground-truth labels (categories).

        Returns:
            np.ndarray: Categories for all images.
        """

        gt_labels = np.array([data['gt_label'] for data in self.data_infos])
        return gt_labels

    def get_cat_ids(self, idx):
        """Get category id by index.

        Args:
            idx (int): Index of data.

        Returns:
            int: Image category of specified index.
        """

        return self.data_infos[idx]['gt_label'].astype(np.int64)

    def prepare_data(self, idx):
        """Apply the pipeline to a copy of the idx-th sample's info dict."""
        results = copy.deepcopy(self.data_infos[idx])
        return self.pipeline(results)

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.data_infos)

    def __getitem__(self, idx):
        return self.prepare_data(idx)

    @classmethod
    def get_classes(cls, classes=None):
        """Get class names of the current dataset.

        Args:
            classes (Sequence[str] | str | None): If classes is None, use
                the default CLASSES defined by the builtin dataset. If
                classes is a string, take it as a file name; the file is
                expected to contain one class name per line. If classes is
                a tuple or list, it overrides the CLASSES defined by the
                dataset.

        Returns:
            tuple[str] or list[str]: Names of categories of the dataset.
        """
        if classes is None:
            return cls.CLASSES

        if isinstance(classes, str):
            # take it as a file path
            class_names = mmcv.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        return class_names
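
    # A brief usage sketch of the three accepted forms of `classes`
    # (the 'imagenet_classes.txt' path is hypothetical):
    #
    #     BaseDataset.get_classes(None)                    # -> cls.CLASSES
    #     BaseDataset.get_classes(['cat', 'dog'])          # -> ['cat', 'dog']
    #     BaseDataset.get_classes('imagenet_classes.txt')  # read via mmcv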

    def evaluate(self,
                 results,
                 metric='accuracy',
                 metric_options=None,
                 logger=None):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
                Default value is `accuracy`.
            metric_options (dict): Options for calculating metrics. Allowed
                keys are 'topk' and 'average'. Defaults to
                ``{'topk': (1, 5)}``.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.

        Returns:
            dict: Evaluation results.
        """
        if metric_options is None:
            metric_options = {'topk': (1, 5)}
        if isinstance(metric, str):
            metrics = [metric]
        else:
            metrics = metric
        allowed_metrics = [
            'accuracy', 'precision', 'recall', 'f1_score', 'support'
        ]
        eval_results = {}
        # Stack per-batch prediction scores into one (num_imgs, C) array.
        results = np.vstack(results)
        gt_labels = self.get_gt_labels()
        num_imgs = len(results)
        assert len(gt_labels) == num_imgs
        for metric in metrics:
            if metric not in allowed_metrics:
                raise KeyError(f'metric {metric} is not supported.')
            if metric == 'accuracy':
                topk = metric_options.get('topk', (1, 5))
                acc = accuracy(results, gt_labels, topk)
                eval_result = {f'top-{k}': a.item() for k, a in zip(topk, acc)}
            elif metric == 'precision':
                precision_value = precision(
                    results,
                    gt_labels,
                    average=metric_options.get('average', 'macro'))
                eval_result = {'precision': precision_value}
            elif metric == 'recall':
                recall_value = recall(
                    results,
                    gt_labels,
                    average=metric_options.get('average', 'macro'))
                eval_result = {'recall': recall_value}
            elif metric == 'f1_score':
                f1_score_value = f1_score(
                    results,
                    gt_labels,
                    average=metric_options.get('average', 'macro'))
                eval_result = {'f1_score': f1_score_value}
            elif metric == 'support':
                support_value = support(
                    results,
                    gt_labels,
                    average=metric_options.get('average', 'macro'))
                eval_result = {'support': support_value}
            eval_results.update(eval_result)
        return eval_results
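
# A minimal end-to-end sketch of how a concrete subclass is used, assuming an
# in-memory dummy dataset; `_ToyDataset`, its samples and the random scores
# are hypothetical and purely illustrative:
#
#     class _ToyDataset(BaseDataset):
#
#         def load_annotations(self):
#             return [{'img': np.zeros((8, 8, 3), dtype=np.uint8),
#                      'gt_label': np.array(i % 2, dtype=np.int64)}
#                     for i in range(4)]
#
#     dataset = _ToyDataset(data_prefix='', pipeline=[],
#                           classes=['cat', 'dog'])
#     len(dataset)            # 4
#     dataset.class_to_idx    # {'cat': 0, 'dog': 1}
#     scores = np.random.rand(4, 2)  # one row of class scores per image
#     dataset.evaluate([scores],
#                      metric=['accuracy', 'precision'],
#                      metric_options={'topk': (1,), 'average': 'macro'})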