80 lines
2.7 KiB
Python
80 lines
2.7 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from typing import List
|
|
|
|
import numpy as np
|
|
|
|
from mmcls.core import average_performance, mAP
|
|
from .base_dataset import BaseDataset
|
|
|
|
|
|
class MultiLabelDataset(BaseDataset):
    """Dataset base class for multi-label classification.

    Category membership of a sample is read from its ``gt_label`` entry:
    positions whose value equals 1 are treated as the sample's categories.
    """

    def get_cat_ids(self, idx: int) -> List[int]:
        """Look up the category ids of one sample.

        Args:
            idx (int): Index of data.

        Returns:
            cat_ids (List[int]): Image categories of specified index, i.e.
                the positions in the sample's ``gt_label`` that equal 1.
        """
        label_vector = self.data_infos[idx]['gt_label']
        # ``np.where`` with a single argument yields one index array per
        # axis; only the first axis is used here.
        positive_positions = np.where(label_vector == 1)[0]
        return positive_positions.tolist()

    def evaluate(self,
                 results,
                 metric='mAP',
                 metric_options=None,
                 indices=None,
                 logger=None):
        """Compute the requested multi-label metrics for ``results``.

        Args:
            results (list): Testing results of the dataset. They are stacked
                row-wise with ``np.vstack`` before evaluation.
            metric (str | list[str]): Metrics to be evaluated.
                Default value is 'mAP'. Options are 'mAP', 'CP', 'CR', 'CF1',
                'OP', 'OR' and 'OF1'.
            metric_options (dict, optional): Options for calculating metrics,
                forwarded as keyword arguments to ``average_performance``.
                Allowed keys are 'k' and 'thr'. ``None`` or an empty dict is
                replaced by ``{'thr': 0.5}``. Defaults to None.
            indices (optional): If not None, used to index the ground-truth
                labels before they are compared with ``results``.
                Defaults to None.
            logger (logging.Logger | str, optional): Logger used for printing
                related information during evaluation. It is not read by
                this implementation. Defaults to None.

        Returns:
            dict: evaluation results, keyed by metric name.

        Raises:
            AssertionError: If the number of stacked results differs from
                the number of ground-truth labels.
            ValueError: If ``metric`` contains an unsupported metric name.
        """
        # A missing or empty options dict falls back to the default
        # threshold.
        use_default_options = metric_options is None or metric_options == {}
        options = {'thr': 0.5} if use_default_options else metric_options

        requested = [metric] if isinstance(metric, str) else metric
        supported = ['mAP', 'CP', 'CR', 'CF1', 'OP', 'OR', 'OF1']
        eval_results = {}

        scores = np.vstack(results)
        targets = self.get_gt_labels()
        if indices is not None:
            targets = targets[indices]
        assert len(targets) == len(scores), 'dataset testing results should '\
            'be of the same length as gt_labels.'

        unsupported = set(requested) - set(supported)
        if unsupported:
            raise ValueError(f'metric {unsupported} is not supported.')

        if 'mAP' in requested:
            eval_results['mAP'] = mAP(scores, targets)

        # The averaged precision/recall/F1 values are only computed when at
        # least one metric other than mAP was requested.
        if set(requested) - {'mAP'}:
            performance_keys = ['CP', 'CR', 'CF1', 'OP', 'OR', 'OF1']
            performance_values = average_performance(scores, targets,
                                                     **options)
            eval_results.update(
                (key, value)
                for key, value in zip(performance_keys, performance_values)
                if key in requested)

        return eval_results
|