import copy
from abc import ABCMeta, abstractmethod

import numpy as np
from torch.utils.data import Dataset

from .pipelines import Compose


class BaseDataset(Dataset, metaclass=ABCMeta):
    """Base dataset.

    Args:
        data_prefix (str): The prefix of the data path.
        pipeline (list): A list of dicts, where each element represents
            an operation defined in `mmcls.datasets.pipelines`.
        ann_file (str | None): The annotation file. When `ann_file` is a str,
            the subclass is expected to read annotations from `ann_file`;
            when `ann_file` is None, the subclass is expected to read
            annotations according to `data_prefix`.
        test_mode (bool): Whether the dataset is used in test mode rather
            than train mode.
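
    Example (an illustrative sketch of a minimal subclass; `TxtAnnDataset`,
    its annotation format and the file paths below are hypothetical, not
    part of mmcls):
        >>> class TxtAnnDataset(BaseDataset):
        ...     def load_annotations(self):
        ...         data_infos = []
        ...         with open(self.ann_file) as f:
        ...             # each line is assumed to be "<filename> <int label>"
        ...             for line in f:
        ...                 filename, gt_label = line.strip().rsplit(' ', 1)
        ...                 data_infos.append(
        ...                     dict(img_prefix=self.data_prefix,
        ...                          img_info=dict(filename=filename),
        ...                          gt_label=np.array(int(gt_label),
        ...                                            dtype=np.int64)))
        ...         return data_infos
        >>> dataset = TxtAnnDataset(
        ...     data_prefix='data/imgs',
        ...     pipeline=[dict(type='LoadImageFromFile')],
        ...     ann_file='data/ann.txt')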
    """

    def __init__(self, data_prefix, pipeline, ann_file=None, test_mode=False):
        super(BaseDataset, self).__init__()

        self.ann_file = ann_file
        self.data_prefix = data_prefix
        self.test_mode = test_mode
        self.pipeline = Compose(pipeline)
        # load_annotations() is implemented by each subclass and returns the
        # full list of annotation dicts for this dataset.
        self.data_infos = self.load_annotations()

    @abstractmethod
    def load_annotations(self):
        pass

    def prepare_data(self, idx):
        # Deep copy so that pipeline transforms cannot modify the cached
        # annotation dicts in self.data_infos in place.
        results = copy.deepcopy(self.data_infos[idx])
        return self.pipeline(results)

    def __len__(self):
        return len(self.data_infos)

    def __getitem__(self, idx):
        return self.prepare_data(idx)

    def evaluate(self, results, metric='accuracy', logger=None):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
                Default value is `accuracy`.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.

        Returns:
            dict: Evaluation results.
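
        Example (illustrative only; the result dicts below are fabricated to
        show the expected format, and `dataset` stands for an instance of a
        concrete subclass):
            >>> import torch
            >>> results = [
            ...     dict(num_samples=torch.tensor(64.),
            ...          accuracy={'top-1': torch.tensor(75.0)}),
            ...     dict(num_samples=torch.tensor(36.),
            ...          accuracy={'top-1': torch.tensor(50.0)}),
            ... ]
            >>> eval_results = dataset.evaluate(results)
            >>> # eval_results['top-1'] is the sample-weighted average:
            >>> # (64 * 75.0 + 36 * 50.0) / (64 + 36) = 66.0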
        """
        # Only a single metric is supported; unwrap a one-element list.
        if not isinstance(metric, str):
            assert len(metric) == 1
            metric = metric[0]
        allowed_metrics = ['accuracy']
        if metric not in allowed_metrics:
            raise KeyError(f'metric {metric} is not supported')
        eval_results = {}
        if metric == 'accuracy':
            nums = []
            for result in results:
                # Each result dict carries the batch size and a dict mapping
                # top-k names to per-batch accuracy tensors.
                nums.append(result['num_samples'].item())
                for topk, v in result['accuracy'].items():
                    if topk not in eval_results:
                        eval_results[topk] = []
                    eval_results[topk].append(v.item())
            # Average the per-batch accuracies weighted by batch size, so
            # batches of different sizes contribute proportionally.
            for topk, accs in eval_results.items():
                eval_results[topk] = np.average(accs, weights=nums)
        return eval_results