mmclassification/mmcls/datasets/base_dataset.py
Latest commit: 9547e7b7a5 "Add model inference (#16)" by Lei Yang, 2020-09-30 19:00:20 +08:00

import copy
from abc import ABCMeta, abstractmethod

import numpy as np
from torch.utils.data import Dataset

from mmcls.models.losses import accuracy

from .pipelines import Compose


class BaseDataset(Dataset, metaclass=ABCMeta):
    """Base dataset.

    Args:
        data_prefix (str): the prefix of the data path.
        pipeline (list): a list of dicts, where each element represents
            an operation defined in `mmcls.datasets.pipelines`.
        ann_file (str | None): the annotation file. When ann_file is a str,
            the subclass is expected to read annotations from the ann_file;
            when ann_file is None, the subclass is expected to read them
            according to data_prefix.
        test_mode (bool): whether the dataset works in test mode rather
            than train mode.
    """

    CLASSES = None

    def __init__(self, data_prefix, pipeline, ann_file=None, test_mode=False):
        super(BaseDataset, self).__init__()
        self.ann_file = ann_file
        self.data_prefix = data_prefix
        self.test_mode = test_mode
        self.pipeline = Compose(pipeline)
        self.data_infos = self.load_annotations()

    @abstractmethod
    def load_annotations(self):
        pass
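
    # A subclass's ``load_annotations`` should return a list with one dict
    # per sample; ``gt_label`` is the only key this base class itself relies
    # on. An illustrative sketch of an entry (exact keys depend on the
    # pipeline in use):
    #
    #     dict(img_prefix=self.data_prefix,
    #          img_info=dict(filename='cat/0001.jpg'),
    #          gt_label=np.array(0, dtype=np.int64))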

    @property
    def class_to_idx(self):
        return {_class: i for i, _class in enumerate(self.CLASSES)}
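
    # For example, with ``CLASSES = ('cat', 'dog')`` the property above
    # returns ``{'cat': 0, 'dog': 1}``.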

    def get_gt_labels(self):
        gt_labels = np.array([data['gt_label'] for data in self.data_infos])
        return gt_labels

    def prepare_data(self, idx):
        results = copy.deepcopy(self.data_infos[idx])
        return self.pipeline(results)
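
    # ``prepare_data`` deep-copies the per-sample dict before handing it to
    # the pipeline, so transforms can add or overwrite keys (e.g. the decoded
    # image) without mutating the cached ``self.data_infos``.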

    def __len__(self):
        return len(self.data_infos)

    def __getitem__(self, idx):
        return self.prepare_data(idx)

    def evaluate(self,
                 results,
                 metric='accuracy',
                 metric_options={'topk': (1, 5)},
                 logger=None):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
                Default value is `accuracy`.
            metric_options (dict): Options for the metric; for `accuracy`,
                the `topk` entry lists the k values to report.
                Default: {'topk': (1, 5)}.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.

        Returns:
            dict: evaluation results
        """
        if not isinstance(metric, str):
            assert len(metric) == 1
            metric = metric[0]
        allowed_metrics = ['accuracy']
        if metric not in allowed_metrics:
            raise KeyError(f'metric {metric} is not supported')
        eval_results = {}
        if metric == 'accuracy':
            topk = metric_options.get('topk')
            results = np.vstack(results)
            gt_labels = self.get_gt_labels()
            num_imgs = len(results)
            assert len(gt_labels) == num_imgs
            acc = accuracy(results, gt_labels, topk)
            eval_results = {f'top-{k}': a.item() for k, a in zip(topk, acc)}
        return eval_results
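

# Usage sketch (hypothetical, for illustration only; ``ToyDataset`` and the
# random scores are not defined anywhere in mmcls): a concrete dataset only
# has to implement ``load_annotations``, after which ``evaluate`` consumes a
# list of per-image class-score arrays.
if __name__ == '__main__':

    class ToyDataset(BaseDataset):

        CLASSES = ('cat', 'dog')

        def load_annotations(self):
            # One dict per sample; ``gt_label`` is all ``evaluate`` needs.
            return [
                dict(gt_label=np.array(i % 2, dtype=np.int64))
                for i in range(10)
            ]

    dataset = ToyDataset(data_prefix='', pipeline=[])
    # Stand-in for model output: one (1, num_classes) score row per image.
    scores = [np.random.rand(1, len(ToyDataset.CLASSES)) for _ in range(10)]
    print(dataset.evaluate(scores, metric='accuracy',
                           metric_options={'topk': (1, )}))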