[Feature] Migrate CSRA head to 1.x. (#1177)
* [Feat] add csra to 1x * minor fix * add voc metrics * refine * add unittest * minor fix * add more comments * Fix docs and metafile. * Fix docs. Co-authored-by: mzr1996 <mzr1996@163.com>pull/1143/head
parent
0e8cfa6286
commit
629f6447ef
|
@ -62,7 +62,11 @@ test_dataloader = dict(
|
|||
)
|
||||
|
||||
# calculate precision_recall_f1 and mAP
|
||||
val_evaluator = [dict(type='MultiLabelMetric'), dict(type='AveragePrecision')]
|
||||
val_evaluator = [
|
||||
dict(type='VOCMultiLabelMetric'),
|
||||
dict(type='VOCMultiLabelMetric', average='micro'),
|
||||
dict(type='VOCAveragePrecision')
|
||||
]
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
# CSRA
|
||||
|
||||
> [Residual Attention: A Simple but Effective Method for Multi-Label Recognition](https://arxiv.org/abs/2108.02456)
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
|
||||
Multi-label image recognition is a challenging computer vision task of practical use. Progresses in this area, however, are often characterized by complicated methods, heavy computations, and lack of intuitive explanations. To effectively capture different spatial regions occupied by objects from different categories, we propose an embarrassingly simple module, named class-specific residual attention (CSRA). CSRA generates class-specific features for every category by proposing a simple spatial attention score, and then combines it with the class-agnostic average pooling feature. CSRA achieves state-of-the-art results on multilabel recognition, and at the same time is much simpler than them. Furthermore, with only 4 lines of code, CSRA also leads to consistent improvement across many diverse pretrained models and datasets without any extra training. CSRA is both easy to implement and light in computations, which also enjoys intuitive explanations and visualizations.
|
||||
|
||||
<div align=center>
|
||||
<img src="https://user-images.githubusercontent.com/84259897/176982245-3ffcff56-a4ea-4474-9967-bc2b612bbaa3.png" width="80%"/>
|
||||
</div>
|
||||
|
||||
## Results and models
|
||||
|
||||
### VOC2007
|
||||
|
||||
| Model | Pretrain | Params(M) | Flops(G) | mAP | OF1 (%) | CF1 (%) | Config | Download |
|
||||
| :------------: | :------------------------------------------------: | :-------: | :------: | :---: | :-----: | :-----: | :-----------------------------------------------: | :-------------------------------------------------: |
|
||||
| Resnet101-CSRA | [ImageNet-1k](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth) | 23.55 | 4.12 | 94.98 | 90.80 | 89.16 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/csra/resnet101-csra_1xb16_voc07-448px.py) | [model](https://download.openmmlab.com/mmclassification/v0/csra/resnet101-csra_1xb16_voc07-448px_20220722-29efb40a.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/csra/resnet101-csra_1xb16_voc07-448px_20220722-29efb40a.log.json) |
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{https://doi.org/10.48550/arxiv.2108.02456,
|
||||
doi = {10.48550/ARXIV.2108.02456},
|
||||
url = {https://arxiv.org/abs/2108.02456},
|
||||
author = {Zhu, Ke and Wu, Jianxin},
|
||||
keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences},
|
||||
title = {Residual Attention: A Simple but Effective Method for Multi-Label Recognition},
|
||||
publisher = {arXiv},
|
||||
year = {2021},
|
||||
copyright = {arXiv.org perpetual, non-exclusive license}
|
||||
}
|
||||
```
|
|
@ -0,0 +1,29 @@
|
|||
Collections:
|
||||
- Name: CSRA
|
||||
Metadata:
|
||||
Training Data: PASCAL VOC 2007
|
||||
Architecture:
|
||||
- Class-specific Residual Attention
|
||||
Paper:
|
||||
URL: https://arxiv.org/abs/1911.11929
|
||||
Title: 'Residual Attention: A Simple but Effective Method for Multi-Label Recognition'
|
||||
README: configs/csra/README.md
|
||||
Code:
|
||||
Version: v0.24.0
|
||||
URL: https://github.com/open-mmlab/mmclassification/blob/v0.24.0/mmcls/models/heads/multi_label_csra_head.py
|
||||
|
||||
Models:
|
||||
- Name: resnet101-csra_1xb16_voc07-448px
|
||||
Metadata:
|
||||
FLOPs: 4120000000
|
||||
Parameters: 23550000
|
||||
In Collection: CSRA
|
||||
Results:
|
||||
- Dataset: PASCAL VOC 2007
|
||||
Metrics:
|
||||
mAP: 94.98
|
||||
OF1: 90.80
|
||||
CF1: 89.16
|
||||
Task: Multi-Label Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/csra/resnet101-csra_1xb16_voc07-448px_20220722-29efb40a.pth
|
||||
Config: configs/csra/resnet101-csra_1xb16_voc07-448px.py
|
|
@ -0,0 +1,78 @@
|
|||
_base_ = ['../_base_/datasets/voc_bs16.py', '../_base_/default_runtime.py']
|
||||
|
||||
# Pre-trained Checkpoint Path
|
||||
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth' # noqa
|
||||
# If you want to use the pre-trained weight of ResNet101-CutMix from
|
||||
# the originary repo(https://github.com/Kevinz-code/CSRA). Script of
|
||||
# 'tools/convert_models/torchvision_to_mmcls.py' can help you convert weight
|
||||
# into mmcls format. The mAP result would hit 95.5 by using the weight.
|
||||
# checkpoint = 'PATH/TO/PRE-TRAINED_WEIGHT'
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=101,
|
||||
num_stages=4,
|
||||
out_indices=(3, ),
|
||||
style='pytorch',
|
||||
init_cfg=dict(
|
||||
type='Pretrained', checkpoint=checkpoint, prefix='backbone')),
|
||||
neck=None,
|
||||
head=dict(
|
||||
type='CSRAClsHead',
|
||||
num_classes=20,
|
||||
in_channels=2048,
|
||||
num_heads=1,
|
||||
lam=0.1,
|
||||
loss=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)))
|
||||
|
||||
# dataset setting
|
||||
data_preprocessor = dict(
|
||||
# RGB format normalization parameters
|
||||
mean=[0, 0, 0],
|
||||
std=[255, 255, 255])
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='RandomResizedCrop', scale=448, crop_ratio_range=(0.7, 1.0)),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(type='PackClsInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=448),
|
||||
dict(
|
||||
type='PackClsInputs',
|
||||
# `gt_label_difficult` is needed for VOC evaluation
|
||||
meta_keys=('sample_idx', 'img_path', 'ori_shape', 'img_shape',
|
||||
'scale_factor', 'flip', 'flip_direction',
|
||||
'gt_label_difficult')),
|
||||
]
|
||||
|
||||
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
|
||||
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
# optimizer
|
||||
# the lr of classifier.head is 10 * base_lr, which help convergence.
|
||||
optim_wrapper = dict(
|
||||
optimizer=dict(type='SGD', lr=0.0002, momentum=0.9, weight_decay=0.0001),
|
||||
paramwise_cfg=dict(custom_keys={'head': dict(lr_mult=10)}))
|
||||
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1e-7,
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=1,
|
||||
convert_to_iter_based=True),
|
||||
dict(type='StepLR', by_epoch=True, step_size=6, gamma=0.1)
|
||||
]
|
||||
|
||||
train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=1)
|
||||
val_cfg = dict()
|
||||
test_cfg = dict()
|
|
@ -30,5 +30,7 @@ Multi Label Metric
|
|||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
MultiLabelMetric
|
||||
AveragePrecision
|
||||
MultiLabelMetric
|
||||
VOCAveragePrecision
|
||||
VOCMultiLabelMetric
|
||||
|
|
|
@ -142,6 +142,7 @@ Heads
|
|||
ConformerHead
|
||||
MultiLabelClsHead
|
||||
MultiLabelLinearClsHead
|
||||
CSRAClsHead
|
||||
|
||||
.. module:: mmcls.models.losses
|
||||
|
||||
|
|
|
@ -29,6 +29,12 @@ class VOC(MultiLabelDataset):
|
|||
│ └── ...
|
||||
└── ImageSets (directory contains various imageset file)
|
||||
|
||||
Extra difficult label is in VOC annotations, we will use
|
||||
`gt_label_difficult` to record the difficult labels in each sample
|
||||
and corresponding evaluation should take care of this field
|
||||
to calculate metrics. Usually, difficult labels are reckoned as
|
||||
negative in defaults.
|
||||
|
||||
Args:
|
||||
data_root (str): The root directory for VOC dataset.
|
||||
image_set_path (str): The path of image set, The file which
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .multi_label import AveragePrecision, MultiLabelMetric
|
||||
from .single_label import Accuracy, SingleLabelMetric
|
||||
from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric
|
||||
|
||||
__all__ = [
|
||||
'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision'
|
||||
'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision',
|
||||
'VOCAveragePrecision', 'VOCMultiLabelMetric'
|
||||
]
|
||||
|
|
|
@ -400,6 +400,12 @@ def _average_precision(pred: torch.Tensor,
|
|||
# a small value for division by zero errors
|
||||
eps = torch.finfo(torch.float32).eps
|
||||
|
||||
# get rid of -1 target such as difficult sample
|
||||
# that is not wanted in evaluation results.
|
||||
valid_index = target > -1
|
||||
pred = pred[valid_index]
|
||||
target = target[valid_index]
|
||||
|
||||
# sort examples
|
||||
sorted_pred_inds = torch.argsort(pred, dim=0, descending=True)
|
||||
sorted_target = target[sorted_pred_inds]
|
||||
|
|
|
@ -28,6 +28,13 @@ def _precision_recall_f1_support(pred_positive, gt_positive, average):
|
|||
assert average in average_options, 'Invalid `average` argument, ' \
|
||||
f'please specicy from {average_options}.'
|
||||
|
||||
# ignore -1 target such as difficult sample that is not wanted
|
||||
# in evaluation results.
|
||||
# only for calculate multi-label without affecting single-label behavior
|
||||
ignored_index = gt_positive == -1
|
||||
pred_positive[ignored_index] = 0
|
||||
gt_positive[ignored_index] = 0
|
||||
|
||||
class_correct = (pred_positive & gt_positive)
|
||||
if average == 'micro':
|
||||
tp_sum = class_correct.sum()
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from mmengine.structures import LabelData
|
||||
|
||||
from mmcls.registry import METRICS
|
||||
from .multi_label import AveragePrecision, MultiLabelMetric
|
||||
|
||||
|
||||
class VOCMetricMixin:
|
||||
"""A mixin class for VOC dataset metrics, VOC annotations have extra
|
||||
`difficult` attribute for each object, therefore, extra option is needed
|
||||
for calculating VOC metrics.
|
||||
|
||||
Args:
|
||||
difficult_as_postive (Optional[bool]): Whether to map the difficult
|
||||
labels as positive in one-hot ground truth for evaluation. If it
|
||||
set to True, map difficult gt labels to positive ones(1), If it
|
||||
set to False, map difficult gt labels to negative ones(0).
|
||||
Defaults to None, the difficult labels will be set to '-1'.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
*arg,
|
||||
difficult_as_positive: Optional[bool] = None,
|
||||
**kwarg):
|
||||
self.difficult_as_positive = difficult_as_positive
|
||||
super().__init__(*arg, **kwarg)
|
||||
|
||||
def process(self, data_batch, data_samples: Sequence[dict]):
|
||||
"""Process one batch of data samples.
|
||||
|
||||
The processed results should be stored in ``self.results``, which will
|
||||
be used to computed the metrics when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch: A batch of data from the dataloader.
|
||||
data_samples (Sequence[dict]): A batch of outputs from the model.
|
||||
"""
|
||||
for data_sample in data_samples:
|
||||
result = dict()
|
||||
pred_label = data_sample['pred_label']
|
||||
gt_label = data_sample['gt_label']
|
||||
gt_label_difficult = data_sample['gt_label_difficult']
|
||||
|
||||
result['pred_score'] = pred_label['score'].clone()
|
||||
num_classes = result['pred_score'].size()[-1]
|
||||
|
||||
if 'score' in gt_label:
|
||||
result['gt_score'] = gt_label['score'].clone()
|
||||
else:
|
||||
result['gt_score'] = LabelData.label_to_onehot(
|
||||
gt_label['label'], num_classes)
|
||||
|
||||
# VOC annotation labels all the objects in a single image
|
||||
# therefore, some categories are appeared both in
|
||||
# difficult objects and non-difficult objects.
|
||||
# Here we reckon those labels which are only exists in difficult
|
||||
# objects as difficult labels.
|
||||
difficult_label = set(gt_label_difficult) - (
|
||||
set(gt_label_difficult) & set(gt_label['label'].tolist()))
|
||||
|
||||
# set difficult label for better eval
|
||||
if self.difficult_as_positive is None:
|
||||
result['gt_score'][[*difficult_label]] = -1
|
||||
elif self.difficult_as_positive:
|
||||
result['gt_score'][[*difficult_label]] = 1
|
||||
|
||||
# Save the result to `self.results`.
|
||||
self.results.append(result)
|
||||
|
||||
|
||||
@METRICS.register_module()
|
||||
class VOCMultiLabelMetric(VOCMetricMixin, MultiLabelMetric):
|
||||
"""A collection of metrics for multi-label multi-class classification task
|
||||
based on confusion matrix for VOC dataset.
|
||||
|
||||
It includes precision, recall, f1-score and support.
|
||||
|
||||
Args:
|
||||
difficult_as_postive (Optional[bool]): Whether to map the difficult
|
||||
labels as positive in one-hot ground truth for evaluation. If it
|
||||
set to True, map difficult gt labels to positive ones(1), If it
|
||||
set to False, map difficult gt labels to negative ones(0).
|
||||
Defaults to None, the difficult labels will be set to '-1'.
|
||||
**kwarg: Refers to `MultiLabelMetric` for detailed docstrings.
|
||||
"""
|
||||
|
||||
|
||||
@METRICS.register_module()
|
||||
class VOCAveragePrecision(VOCMetricMixin, AveragePrecision):
|
||||
"""Calculate the average precision with respect of classes for VOC dataset.
|
||||
|
||||
Args:
|
||||
difficult_as_postive (Optional[bool]): Whether to map the difficult
|
||||
labels as positive in one-hot ground truth for evaluation. If it
|
||||
set to True, map difficult gt labels to positive ones(1), If it
|
||||
set to False, map difficult gt labels to negative ones(0).
|
||||
Defaults to None, the difficult labels will be set to '-1'.
|
||||
**kwarg: Refers to `AveragePrecision` for detailed docstrings.
|
||||
"""
|
|
@ -6,6 +6,7 @@ from .deit_head import DeiTClsHead
|
|||
from .efficientformer_head import EfficientFormerClsHead
|
||||
from .linear_head import LinearClsHead
|
||||
from .multi_label_cls_head import MultiLabelClsHead
|
||||
from .multi_label_csra_head import CSRAClsHead
|
||||
from .multi_label_linear_head import MultiLabelLinearClsHead
|
||||
from .stacked_head import StackedLinearClsHead
|
||||
from .vision_transformer_head import VisionTransformerClsHead
|
||||
|
@ -13,5 +14,5 @@ from .vision_transformer_head import VisionTransformerClsHead
|
|||
__all__ = [
|
||||
'ClsHead', 'LinearClsHead', 'StackedLinearClsHead', 'MultiLabelClsHead',
|
||||
'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'DeiTClsHead',
|
||||
'ConformerHead', 'EfficientFormerClsHead', 'ArcFaceClsHead'
|
||||
'ConformerHead', 'EfficientFormerClsHead', 'ArcFaceClsHead', 'CSRAClsHead'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Modified from https://github.com/Kevinz-code/CSRA
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
|
||||
from mmcls.registry import MODELS
|
||||
from .multi_label_cls_head import MultiLabelClsHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CSRAClsHead(MultiLabelClsHead):
|
||||
"""Class-specific residual attention classifier head.
|
||||
|
||||
Please refer to the `Residual Attention: A Simple but Effective Method for
|
||||
Multi-Label Recognition (ICCV 2021) <https://arxiv.org/abs/2108.02456>`_
|
||||
for details.
|
||||
|
||||
Args:
|
||||
num_classes (int): Number of categories.
|
||||
in_channels (int): Number of channels in the input feature map.
|
||||
num_heads (int): Number of residual at tensor heads.
|
||||
loss (dict): Config of classification loss.
|
||||
lam (float): Lambda that combines global average and max pooling
|
||||
scores.
|
||||
init_cfg (dict, optional): The extra init config of layers.
|
||||
Defaults to use ``dict(type='Normal', layer='Linear', std=0.01)``.
|
||||
"""
|
||||
temperature_settings = { # softmax temperature settings
|
||||
1: [1],
|
||||
2: [1, 99],
|
||||
4: [1, 2, 4, 99],
|
||||
6: [1, 2, 3, 4, 5, 99],
|
||||
8: [1, 2, 3, 4, 5, 6, 7, 99]
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
num_classes: int,
|
||||
in_channels: int,
|
||||
num_heads: int,
|
||||
lam: float,
|
||||
init_cfg=dict(type='Normal', layer='Linear', std=0.01),
|
||||
**kwargs):
|
||||
assert num_heads in self.temperature_settings.keys(
|
||||
), 'The num of heads is not in temperature setting.'
|
||||
assert lam > 0, 'Lambda should be between 0 and 1.'
|
||||
super(CSRAClsHead, self).__init__(init_cfg=init_cfg, **kwargs)
|
||||
self.temp_list = self.temperature_settings[num_heads]
|
||||
self.csra_heads = ModuleList([
|
||||
CSRAModule(num_classes, in_channels, self.temp_list[i], lam)
|
||||
for i in range(num_heads)
|
||||
])
|
||||
|
||||
def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
|
||||
"""The process before the final classification head.
|
||||
|
||||
The input ``feats`` is a tuple of tensor, and each tensor is the
|
||||
feature of a backbone stage. In ``CSRAClsHead``, we just obtain the
|
||||
feature of the last stage.
|
||||
"""
|
||||
# The CSRAClsHead doesn't have other module, just return after
|
||||
# unpacking.
|
||||
return feats[-1]
|
||||
|
||||
def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
|
||||
"""The forward process."""
|
||||
pre_logits = self.pre_logits(feats)
|
||||
logit = sum([head(pre_logits) for head in self.csra_heads])
|
||||
return logit
|
||||
|
||||
|
||||
class CSRAModule(BaseModule):
|
||||
"""Basic module of CSRA with different temperature.
|
||||
|
||||
Args:
|
||||
num_classes (int): Number of categories.
|
||||
in_channels (int): Number of channels in the input feature map.
|
||||
T (int): Temperature setting.
|
||||
lam (float): Lambda that combines global average and max pooling
|
||||
scores.
|
||||
init_cfg (dict | optional): The extra init config of layers.
|
||||
Defaults to use dict(type='Normal', layer='Linear', std=0.01).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_classes: int,
|
||||
in_channels: int,
|
||||
T: int,
|
||||
lam: float,
|
||||
init_cfg=None):
|
||||
|
||||
super(CSRAModule, self).__init__(init_cfg=init_cfg)
|
||||
self.T = T # temperature
|
||||
self.lam = lam # Lambda
|
||||
self.head = nn.Conv2d(in_channels, num_classes, 1, bias=False)
|
||||
self.softmax = nn.Softmax(dim=2)
|
||||
|
||||
def forward(self, x):
|
||||
score = self.head(x) / torch.norm(
|
||||
self.head.weight, dim=1, keepdim=True).transpose(0, 1)
|
||||
score = score.flatten(2)
|
||||
base_logit = torch.mean(score, dim=2)
|
||||
|
||||
if self.T == 99: # max-pooling
|
||||
att_logit = torch.max(score, dim=2)[0]
|
||||
else:
|
||||
score_soft = self.softmax(score * self.T)
|
||||
att_logit = torch.sum(score * score_soft, dim=2)
|
||||
|
||||
return base_logit + self.lam * att_logit
|
|
@ -39,3 +39,4 @@ Import:
|
|||
- configs/mobilevit/metafile.yml
|
||||
- configs/davit/metafile.yml
|
||||
- configs/replknet/metafile.yml
|
||||
- configs/csra/metafile.yml
|
||||
|
|
|
@ -0,0 +1,228 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
import sklearn.metrics
|
||||
import torch
|
||||
from mmengine.evaluator import Evaluator
|
||||
|
||||
from mmcls.structures import ClsDataSample
|
||||
from mmcls.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
class TestVOCMultiLabel(TestCase):
|
||||
|
||||
def test_evaluate(self):
|
||||
# prepare input data
|
||||
y_true_label = [[0], [1, 3], [0, 1, 2], [3]]
|
||||
y_true_difficult = [[0], [2], [1], []]
|
||||
y_pred_score = torch.tensor([
|
||||
[0.8, 0, 0, 0.6],
|
||||
[0.2, 0, 0.6, 0],
|
||||
[0, 0.9, 0.6, 0],
|
||||
[0, 0, 0.2, 0.3],
|
||||
])
|
||||
|
||||
# generate data samples
|
||||
pred = [
|
||||
ClsDataSample(num_classes=4).set_pred_score(i).set_gt_label(j)
|
||||
for i, j in zip(y_pred_score, y_true_label)
|
||||
]
|
||||
for sample, difficult_label in zip(pred, y_true_difficult):
|
||||
sample.set_metainfo({'gt_label_difficult': difficult_label})
|
||||
|
||||
# 1. Test with default argument
|
||||
evaluator = Evaluator(dict(type='VOCMultiLabelMetric'))
|
||||
evaluator.process(pred)
|
||||
res = evaluator.evaluate(4)
|
||||
self.assertIsInstance(res, dict)
|
||||
|
||||
# generate sklearn input
|
||||
y_true = np.array([
|
||||
[1, 0, 0, 0],
|
||||
[0, 1, -1, 1],
|
||||
[1, 1, 1, 0],
|
||||
[0, 0, 0, 1],
|
||||
])
|
||||
ignored_index = y_true == -1
|
||||
y_true[ignored_index] = 0
|
||||
thr05_y_pred = np.array([
|
||||
[1, 0, 0, 1],
|
||||
[0, 0, 1, 0],
|
||||
[0, 1, 1, 0],
|
||||
[0, 0, 0, 0],
|
||||
])
|
||||
thr05_y_pred[ignored_index] = 0
|
||||
|
||||
expect_precision = sklearn.metrics.precision_score(
|
||||
y_true, thr05_y_pred, average='macro') * 100
|
||||
expect_recall = sklearn.metrics.recall_score(
|
||||
y_true, thr05_y_pred, average='macro') * 100
|
||||
expect_f1 = sklearn.metrics.f1_score(
|
||||
y_true, thr05_y_pred, average='macro') * 100
|
||||
self.assertEqual(res['multi-label/precision'], expect_precision)
|
||||
self.assertEqual(res['multi-label/recall'], expect_recall)
|
||||
# precision is different between torch and sklearn
|
||||
self.assertAlmostEqual(res['multi-label/f1-score'], expect_f1, 5)
|
||||
|
||||
# 2. Test with `difficult_as_positive`=False argument
|
||||
evaluator = Evaluator(
|
||||
dict(type='VOCMultiLabelMetric', difficult_as_positive=False))
|
||||
evaluator.process(pred)
|
||||
res = evaluator.evaluate(4)
|
||||
self.assertIsInstance(res, dict)
|
||||
|
||||
# generate sklearn input
|
||||
y_true = np.array([
|
||||
[1, 0, 0, 0],
|
||||
[0, 1, 0, 1],
|
||||
[1, 1, 1, 0],
|
||||
[0, 0, 0, 1],
|
||||
])
|
||||
thr05_y_pred = np.array([
|
||||
[1, 0, 0, 1],
|
||||
[0, 0, 1, 0],
|
||||
[0, 1, 1, 0],
|
||||
[0, 0, 0, 0],
|
||||
])
|
||||
|
||||
expect_precision = sklearn.metrics.precision_score(
|
||||
y_true, thr05_y_pred, average='macro') * 100
|
||||
expect_recall = sklearn.metrics.recall_score(
|
||||
y_true, thr05_y_pred, average='macro') * 100
|
||||
expect_f1 = sklearn.metrics.f1_score(
|
||||
y_true, thr05_y_pred, average='macro') * 100
|
||||
self.assertEqual(res['multi-label/precision'], expect_precision)
|
||||
self.assertEqual(res['multi-label/recall'], expect_recall)
|
||||
# precision is different between torch and sklearn
|
||||
self.assertAlmostEqual(res['multi-label/f1-score'], expect_f1, 5)
|
||||
|
||||
# 3. Test with `difficult_as_positive`=True argument
|
||||
evaluator = Evaluator(
|
||||
dict(type='VOCMultiLabelMetric', difficult_as_positive=True))
|
||||
evaluator.process(pred)
|
||||
res = evaluator.evaluate(4)
|
||||
self.assertIsInstance(res, dict)
|
||||
|
||||
# generate sklearn input
|
||||
y_true = np.array([
|
||||
[1, 0, 0, 0],
|
||||
[0, 1, 1, 1],
|
||||
[1, 1, 1, 0],
|
||||
[0, 0, 0, 1],
|
||||
])
|
||||
thr05_y_pred = np.array([
|
||||
[1, 0, 0, 1],
|
||||
[0, 0, 1, 0],
|
||||
[0, 1, 1, 0],
|
||||
[0, 0, 0, 0],
|
||||
])
|
||||
|
||||
expect_precision = sklearn.metrics.precision_score(
|
||||
y_true, thr05_y_pred, average='macro') * 100
|
||||
expect_recall = sklearn.metrics.recall_score(
|
||||
y_true, thr05_y_pred, average='macro') * 100
|
||||
expect_f1 = sklearn.metrics.f1_score(
|
||||
y_true, thr05_y_pred, average='macro') * 100
|
||||
self.assertEqual(res['multi-label/precision'], expect_precision)
|
||||
self.assertEqual(res['multi-label/recall'], expect_recall)
|
||||
# precision is different between torch and sklearn
|
||||
self.assertAlmostEqual(res['multi-label/f1-score'], expect_f1, 5)
|
||||
|
||||
|
||||
class TestVOCAveragePrecision(TestCase):
|
||||
|
||||
def test_evaluate(self):
|
||||
"""Test using the metric in the same way as Evalutor."""
|
||||
# prepare input data
|
||||
y_true_difficult = [[0], [2], [1], []]
|
||||
y_pred_score = torch.tensor([
|
||||
[0.8, 0.1, 0, 0.6],
|
||||
[0.2, 0.2, 0.7, 0],
|
||||
[0.1, 0.9, 0.6, 0.1],
|
||||
[0, 0, 0.2, 0.3],
|
||||
])
|
||||
y_true_label = [[0], [1, 3], [0, 1, 2], [3]]
|
||||
y_true = torch.tensor([
|
||||
[1, 0, 0, 0],
|
||||
[0, 1, 0, 1],
|
||||
[1, 1, 1, 0],
|
||||
[0, 0, 0, 1],
|
||||
])
|
||||
y_true_difficult = [[0], [2], [1], []]
|
||||
|
||||
# generate data samples
|
||||
pred = [
|
||||
ClsDataSample(num_classes=4).set_pred_score(i).set_gt_score(
|
||||
j).set_gt_label(k)
|
||||
for i, j, k in zip(y_pred_score, y_true, y_true_label)
|
||||
]
|
||||
for sample, difficult_label in zip(pred, y_true_difficult):
|
||||
sample.set_metainfo({'gt_label_difficult': difficult_label})
|
||||
|
||||
# 1. Test with default
|
||||
evaluator = Evaluator(dict(type='VOCAveragePrecision'))
|
||||
evaluator.process(pred)
|
||||
res = evaluator.evaluate(4)
|
||||
self.assertIsInstance(res, dict)
|
||||
|
||||
# prepare inputs for sklearn for this case
|
||||
y_pred_score = [[0.8, 0.2, 0.1, 0], [0.1, 0.2, 0.9, 0], [0, 0.6, 0.2],
|
||||
[0.6, 0, 0.1, 0.3]]
|
||||
y_true = [[1, 0, 1, 0], [0, 1, 1, 0], [0, 1, 0], [0, 1, 0, 1]]
|
||||
expected_res = []
|
||||
for pred_per_class, gt_per_class in zip(y_pred_score, y_true):
|
||||
expected_res.append(
|
||||
sklearn.metrics.average_precision_score(
|
||||
gt_per_class, pred_per_class))
|
||||
|
||||
self.assertAlmostEqual(
|
||||
res['multi-label/mAP'],
|
||||
sum(expected_res) * 100 / len(expected_res),
|
||||
places=4)
|
||||
|
||||
# 2. Test with `difficult_as_positive`=False argument
|
||||
evaluator = Evaluator(
|
||||
dict(type='VOCAveragePrecision', difficult_as_positive=False))
|
||||
evaluator.process(pred)
|
||||
res = evaluator.evaluate(4)
|
||||
self.assertIsInstance(res, dict)
|
||||
|
||||
# prepare inputs for sklearn for this case
|
||||
y_pred_score = [[0.8, 0.2, 0.1, 0], [0.1, 0.2, 0.9, 0],
|
||||
[0, 0.7, 0.6, 0.2], [0.6, 0, 0.1, 0.3]]
|
||||
y_true = [[1, 0, 1, 0], [0, 1, 1, 0], [0, 0, 1, 0], [0, 1, 0, 1]]
|
||||
expected_res = []
|
||||
for pred_per_class, gt_per_class in zip(y_pred_score, y_true):
|
||||
expected_res.append(
|
||||
sklearn.metrics.average_precision_score(
|
||||
gt_per_class, pred_per_class))
|
||||
|
||||
self.assertAlmostEqual(
|
||||
res['multi-label/mAP'],
|
||||
sum(expected_res) * 100 / len(expected_res),
|
||||
places=4)
|
||||
|
||||
# 3. Test with `difficult_as_positive`=True argument
|
||||
evaluator = Evaluator(
|
||||
dict(type='VOCAveragePrecision', difficult_as_positive=True))
|
||||
evaluator.process(pred)
|
||||
res = evaluator.evaluate(4)
|
||||
self.assertIsInstance(res, dict)
|
||||
|
||||
# prepare inputs for sklearn for this case
|
||||
y_pred_score = [[0.8, 0.2, 0.1, 0], [0.1, 0.2, 0.9, 0],
|
||||
[0, 0.7, 0.6, 0.2], [0.6, 0, 0.1, 0.3]]
|
||||
y_true = [[1, 0, 1, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 1, 0, 1]]
|
||||
expected_res = []
|
||||
for pred_per_class, gt_per_class in zip(y_pred_score, y_true):
|
||||
expected_res.append(
|
||||
sklearn.metrics.average_precision_score(
|
||||
gt_per_class, pred_per_class))
|
||||
|
||||
self.assertAlmostEqual(
|
||||
res['multi-label/mAP'],
|
||||
sum(expected_res) * 100 / len(expected_res),
|
||||
places=4)
|
Loading…
Reference in New Issue