[Feature] Support CSRA head. (#881)

* Support CSRA head.

* Add CSRA config.

* Improve training scheduler and Update cfg, ckpt, log

* Update metafile

* Rename config files and checkpoints

Co-authored-by: Ezra-Yu <1105212286@qq.com>
Co-authored-by: mzr1996 <mzr1996@163.com>
JiayuXu 2022-08-04 18:15:51 +08:00 committed by GitHub
parent b5bb86a357
commit 1a3d51acc2
8 changed files with 318 additions and 6 deletions

configs/csra/README.md

@ -0,0 +1,36 @@
# CSRA
> [Residual Attention: A Simple but Effective Method for Multi-Label Recognition](https://arxiv.org/abs/2108.02456)
<!-- [ALGORITHM] -->
## Abstract
Multi-label image recognition is a challenging computer vision task of practical use. Progresses in this area, however, are often characterized by complicated methods, heavy computations, and lack of intuitive explanations. To effectively capture different spatial regions occupied by objects from different categories, we propose an embarrassingly simple module, named class-specific residual attention (CSRA). CSRA generates class-specific features for every category by proposing a simple spatial attention score, and then combines it with the class-agnostic average pooling feature. CSRA achieves state-of-the-art results on multilabel recognition, and at the same time is much simpler than them. Furthermore, with only 4 lines of code, CSRA also leads to consistent improvement across many diverse pretrained models and datasets without any extra training. CSRA is both easy to implement and light in computations, which also enjoys intuitive explanations and visualizations.
<div align=center>
<img src="https://user-images.githubusercontent.com/84259897/176982245-3ffcff56-a4ea-4474-9967-bc2b612bbaa3.png" width="80%"/>
</div>
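
The combination described in the abstract can be sketched for a single class as follows. This is a minimal pure-Python illustration, not the code from the paper or from this repository: the function name `csra_logit` and its argument names are assumptions made for the example, and `scores` stands for the per-location classifier scores of one class.

```python
import math


def csra_logit(scores, lam=0.1, temperature=1.0):
    """Combine average pooling with a spatial-attention term for one class.

    `scores` holds one score per spatial location (hypothetical input).
    """
    # Class-agnostic part: plain average over all locations.
    base = sum(scores) / len(scores)

    # Class-specific part: softmax over locations, then a weighted sum.
    # Subtracting the maximum keeps the exponentials numerically stable.
    peak = max(s * temperature for s in scores)
    weights = [math.exp(s * temperature - peak) for s in scores]
    total = sum(weights)
    attention = sum(s * w / total for s, w in zip(scores, weights))

    # Residual combination weighted by lambda.
    return base + lam * attention


print(csra_logit([0.0, 1.0, 3.0], lam=0.1))
```

When all locations have the same score, the attention term equals the average, so the result is simply `(1 + lam)` times that score.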
## Results and models
### VOC2007
| Model | Pretrain | Params(M) | Flops(G) | mAP | OF1 (%) | CF1 (%) | Config | Download |
| :------------: | :------------------------------------------------: | :-------: | :------: | :---: | :-----: | :-----: | :-----------------------------------------------: | :-------------------------------------------------: |
| Resnet101-CSRA | [ImageNet-1k](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth) | 23.55 | 4.12 | 94.98 | 90.80 | 89.16 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/csra/resnet101-csra_1xb16_voc07-448px.py) | [model](https://download.openmmlab.com/mmclassification/v0/csra/resnet101-csra_1xb16_voc07-448px_20220722-29efb40a.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/csra/resnet101-csra_1xb16_voc07-448px_20220722-29efb40a.log.json) |
## Citation
```bibtex
@misc{https://doi.org/10.48550/arxiv.2108.02456,
doi = {10.48550/ARXIV.2108.02456},
url = {https://arxiv.org/abs/2108.02456},
author = {Zhu, Ke and Wu, Jianxin},
keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences},
title = {Residual Attention: A Simple but Effective Method for Multi-Label Recognition},
publisher = {arXiv},
year = {2021},
copyright = {arXiv.org perpetual, non-exclusive license}
}
```

configs/csra/metafile.yml

@ -0,0 +1,29 @@
Collections:
  - Name: CSRA
    Metadata:
      Training Data: PASCAL VOC 2007
      Architecture:
        - Class-specific Residual Attention
    Paper:
      URL: https://arxiv.org/abs/2108.02456
      Title: 'Residual Attention: A Simple but Effective Method for Multi-Label Recognition'
    README: configs/csra/README.md
    Code:
      Version: v0.24.0
      URL: https://github.com/open-mmlab/mmclassification/blob/v0.24.0/mmcls/models/heads/multi_label_csra_head.py

Models:
  - Name: resnet101-csra_1xb16_voc07-448px
    Metadata:
      FLOPs: 4120000000
      Parameters: 23550000
    In Collections: CSRA
    Results:
      - Dataset: PASCAL VOC 2007
        Metrics:
          mAP: 94.98
          OF1: 90.80
          CF1: 89.16
        Task: Multi-Label Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/csra/resnet101-csra_1xb16_voc07-448px_20220722-29efb40a.pth
    Config: configs/csra/resnet101-csra_1xb16_voc07-448px.py


@ -0,0 +1,75 @@
_base_ = ['../_base_/datasets/voc_bs16.py', '../_base_/default_runtime.py']

# Pre-trained Checkpoint Path
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth'  # noqa
# To use the pre-trained weights of ResNet101-CutMix from the original
# repo (https://github.com/Kevinz-code/CSRA), convert them into mmcls format
# with the script 'tools/convert_models/torchvision_to_mmcls.py'.
# The mAP result would hit 95.5 by using those weights.
# checkpoint = 'PATH/TO/PRE-TRAINED_WEIGHT'

# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=101,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch',
        init_cfg=dict(
            type='Pretrained', checkpoint=checkpoint, prefix='backbone')),
    neck=None,
    head=dict(
        type='CSRAClsHead',
        num_classes=20,
        in_channels=2048,
        num_heads=1,
        lam=0.1,
        loss=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)))

# dataset setting
img_norm_cfg = dict(mean=[0, 0, 0], std=[255, 255, 255], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=448, scale=(0.7, 1.0)),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=448),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    # map the difficult examples to negative ones (0)
    train=dict(pipeline=train_pipeline, difficult_as_postive=False),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))

# optimizer
# the lr of the classifier head is 10 * base_lr, which helps convergence.
optimizer = dict(
    type='SGD',
    lr=0.0002,
    momentum=0.9,
    weight_decay=0.0001,
    paramwise_cfg=dict(custom_keys={'head': dict(lr_mult=10)}))
optimizer_config = dict(grad_clip=None)

# learning policy
lr_config = dict(
    policy='step',
    step=6,
    gamma=0.1,
    warmup='linear',
    warmup_iters=1,
    warmup_ratio=1e-7,
    warmup_by_epoch=True)
runner = dict(type='EpochBasedRunner', max_epochs=20)
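
The optimizer and learning policy above can be read as a step decay with a linear warmup and a 10x multiplier on the head. The following is a simplified sketch of that reading, not mmcv's scheduler implementation: the function `lr_at` and the `warmup_progress` argument (the fraction of the warmup epoch already completed, from 0 to 1) are hypothetical names introduced for illustration, and the warmup formula is assumed to be a linear ramp from `warmup_ratio * lr` to `lr`.

```python
def lr_at(epoch, warmup_progress=1.0, base_lr=0.0002, step=6, gamma=0.1,
          warmup_ratio=1e-7, lr_mult=1.0):
    """Approximate learning rate for a parameter group at a given epoch.

    Assumption: the rate is multiplied by `gamma` every `step` epochs, and
    during warmup it ramps linearly from `warmup_ratio * lr` up to `lr`.
    """
    # Step decay, scaled by the per-group multiplier (10 for the head).
    regular = base_lr * lr_mult * gamma ** (epoch // step)
    # Linear warmup factor: equals warmup_ratio at progress 0 and 1 at 1.
    k = (1 - warmup_progress) * (1 - warmup_ratio)
    return regular * (1 - k)


for epoch in (0, 6, 12, 18):
    print(epoch, lr_at(epoch), lr_at(epoch, lr_mult=10))
```

Under these assumptions the backbone rate drops by a factor of 10 at epochs 6, 12, and 18, and the head stays at ten times the backbone rate throughout.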


@ -11,14 +11,29 @@ from .multi_label import MultiLabelDataset
@DATASETS.register_module()
class VOC(MultiLabelDataset):
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset."""
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset.
Args:
data_prefix (str): The prefix of the data path.
pipeline (list): A list of dicts, where each element represents
an operation defined in `mmcls.datasets.pipelines`.
ann_file (str | None): The annotation file. When ann_file is a str,
the subclass is expected to read from the ann_file. When ann_file
is None, the subclass is expected to read according to data_prefix.
difficult_as_postive (Optional[bool]): Whether to map the difficult
labels as positive. If it is set to True, map difficult examples to
positive ones (1). If it is set to False, map difficult examples to
negative ones (0). Defaults to None, in which case the difficult
labels are set to -1.
"""
CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
'tvmonitor')
def __init__(self, **kwargs):
def __init__(self, difficult_as_postive=None, **kwargs):
self.difficult_as_postive = difficult_as_postive
super(VOC, self).__init__(**kwargs)
if 'VOC2007' in self.data_prefix:
self.year = 2007
@ -55,9 +70,19 @@ class VOC(MultiLabelDataset):
labels.append(label)
gt_label = np.zeros(len(self.CLASSES))
# set difficult examples first, then set positive examples.
# The order cannot be swapped for the case where multiple objects
# of the same kind exist and some are difficult.
gt_label[labels_difficult] = -1
if self.difficult_as_postive is None:
# map difficult examples to -1,
# it may be used in evaluation to ignore difficult targets.
gt_label[labels_difficult] = -1
elif self.difficult_as_postive:
# map difficult examples to positive ones(1).
gt_label[labels_difficult] = 1
else:
# map difficult examples to negative ones(0).
gt_label[labels_difficult] = 0
gt_label[labels] = 1
info = dict(
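
The three label-mapping modes in this hunk can be illustrated with a small standalone sketch. It uses plain lists instead of NumPy arrays and a hypothetical helper name, `build_gt_label`; the keyword keeps the spelling `difficult_as_postive` used in the source. It is an illustration of the logic above, not the dataset class itself.

```python
def build_gt_label(num_classes, labels, labels_difficult,
                   difficult_as_postive=None):
    """Build a multi-label target vector from class indices.

    `labels` are indices of non-difficult objects and `labels_difficult`
    are indices of difficult objects (both hypothetical inputs).
    """
    gt_label = [0] * num_classes

    # Pick the value assigned to difficult examples.
    if difficult_as_postive is None:
        fill = -1  # may be used in evaluation to ignore difficult targets
    elif difficult_as_postive:
        fill = 1   # treat difficult examples as positive
    else:
        fill = 0   # treat difficult examples as negative

    # Difficult examples are written first so that a non-difficult object
    # of the same class overrides them afterwards.
    for index in labels_difficult:
        gt_label[index] = fill
    for index in labels:
        gt_label[index] = 1
    return gt_label


print(build_gt_label(4, labels=[1], labels_difficult=[1, 2]))
```

Class 1 appears both as a difficult and as a regular object here, so it ends up positive in every mode, while class 2 takes the value selected by `difficult_as_postive`.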


@ -3,6 +3,7 @@ from .cls_head import ClsHead
from .conformer_head import ConformerHead
from .deit_head import DeiTClsHead
from .linear_head import LinearClsHead
from .multi_label_csra_head import CSRAClsHead
from .multi_label_head import MultiLabelClsHead
from .multi_label_linear_head import MultiLabelLinearClsHead
from .stacked_head import StackedLinearClsHead
@ -11,5 +12,5 @@ from .vision_transformer_head import VisionTransformerClsHead
__all__ = [
'ClsHead', 'LinearClsHead', 'StackedLinearClsHead', 'MultiLabelClsHead',
'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'DeiTClsHead',
'ConformerHead'
'ConformerHead', 'CSRAClsHead'
]


@ -0,0 +1,121 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/Kevinz-code/CSRA
import torch
import torch.nn as nn
from mmcv.runner import BaseModule, ModuleList

from ..builder import HEADS
from .multi_label_head import MultiLabelClsHead


@HEADS.register_module()
class CSRAClsHead(MultiLabelClsHead):
    """Class-specific residual attention classifier head.

    Residual Attention: A Simple but Effective Method for Multi-Label
    Recognition (ICCV 2021)
    Please refer to the `paper <https://arxiv.org/abs/2108.02456>`__ for
    details.

    Args:
        num_classes (int): Number of categories.
        in_channels (int): Number of channels in the input feature map.
        num_heads (int): Number of residual attention heads.
        loss (dict): Config of classification loss.
        lam (float): Lambda that combines global average and max pooling
            scores.
        init_cfg (dict | optional): The extra init config of layers.
            Defaults to use dict(type='Normal', layer='Linear', std=0.01).
    """
    temperature_settings = {  # softmax temperature settings
        1: [1],
        2: [1, 99],
        4: [1, 2, 4, 99],
        6: [1, 2, 3, 4, 5, 99],
        8: [1, 2, 3, 4, 5, 6, 7, 99]
    }

    def __init__(self,
                 num_classes,
                 in_channels,
                 num_heads,
                 lam,
                 loss=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     reduction='mean',
                     loss_weight=1.0),
                 init_cfg=dict(type='Normal', layer='Linear', std=0.01),
                 *args,
                 **kwargs):
        assert num_heads in self.temperature_settings.keys(
        ), 'The num of heads is not in temperature setting.'
        assert lam > 0, 'Lambda should be greater than 0.'
        super(CSRAClsHead, self).__init__(
            init_cfg=init_cfg, loss=loss, *args, **kwargs)
        self.temp_list = self.temperature_settings[num_heads]
        self.csra_heads = ModuleList([
            CSRAModule(num_classes, in_channels, self.temp_list[i], lam)
            for i in range(num_heads)
        ])

    def pre_logits(self, x):
        # use only the last feature map when the backbone returns a tuple
        if isinstance(x, tuple):
            x = x[-1]
        return x

    def simple_test(self, x, post_process=True, **kwargs):
        logit = 0.
        x = self.pre_logits(x)
        # sum the logits of all heads
        for head in self.csra_heads:
            logit += head(x)
        if post_process:
            return self.post_process(logit)
        else:
            return logit

    def forward_train(self, x, gt_label, **kwargs):
        logit = 0.
        x = self.pre_logits(x)
        # sum the logits of all heads
        for head in self.csra_heads:
            logit += head(x)
        gt_label = gt_label.type_as(logit)
        # labels of -1 (difficult examples) become 1
        _gt_label = torch.abs(gt_label)
        losses = self.loss(logit, _gt_label, **kwargs)
        return losses


class CSRAModule(BaseModule):
    """Basic module of CSRA with different temperature.

    Args:
        num_classes (int): Number of categories.
        in_channels (int): Number of channels in the input feature map.
        T (int): Temperature setting. A value of 99 selects max pooling.
        lam (float): Lambda that combines global average and max pooling
            scores.
        init_cfg (dict | optional): The extra init config of layers.
            Defaults to None.
    """

    def __init__(self, num_classes, in_channels, T, lam, init_cfg=None):
        super(CSRAModule, self).__init__(init_cfg=init_cfg)
        self.T = T  # temperature
        self.lam = lam  # Lambda
        self.head = nn.Conv2d(in_channels, num_classes, 1, bias=False)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        # divide each class score by the L2 norm of that class's weight
        score = self.head(x) / torch.norm(
            self.head.weight, dim=1, keepdim=True).transpose(0, 1)
        score = score.flatten(2)
        base_logit = torch.mean(score, dim=2)

        if self.T == 99:  # max-pooling
            att_logit = torch.max(score, dim=2)[0]
        else:
            score_soft = self.softmax(score * self.T)
            att_logit = torch.sum(score * score_soft, dim=2)

        return base_logit + self.lam * att_logit
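
The first step of `CSRAModule.forward` divides each class score by the L2 norm of that class's 1x1 convolution weight. A minimal pure-Python sketch of this step for a single spatial location is shown below; `normalized_scores` and its arguments are hypothetical names for illustration, and the sketch uses lists rather than tensors.

```python
import math


def normalized_scores(weights, feature):
    """Per-class scores at one location, divided by each weight's L2 norm.

    `weights` is a list with one weight vector per class and `feature` is
    the feature vector at a single spatial location (hypothetical inputs).
    """
    scores = []
    for w in weights:
        norm = math.sqrt(sum(v * v for v in w))
        dot = sum(a * b for a, b in zip(w, feature))
        scores.append(dot / norm)
    return scores


feature = [1.0, 1.0]
print(normalized_scores([[3.0, 4.0]], feature))
print(normalized_scores([[30.0, 40.0]], feature))
```

Both calls print the same score, which shows that with this normalization the magnitude of a class's weight vector does not change its score; only the direction matters.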


@ -28,4 +28,5 @@ Import:
- configs/convmixer/metafile.yml
- configs/densenet/metafile.yml
- configs/poolformer/metafile.yml
- configs/csra/metafile.yml
- configs/mvit/metafile.yml


@ -4,8 +4,8 @@ from unittest.mock import patch
import pytest
import torch
from mmcls.models.heads import (ClsHead, ConformerHead, DeiTClsHead,
LinearClsHead, MultiLabelClsHead,
from mmcls.models.heads import (ClsHead, ConformerHead, CSRAClsHead,
DeiTClsHead, LinearClsHead, MultiLabelClsHead,
MultiLabelLinearClsHead, StackedLinearClsHead,
VisionTransformerClsHead)
@ -317,3 +317,27 @@ def test_deit_head():
# test assertion
with pytest.raises(ValueError):
DeiTClsHead(-1, 100)


@pytest.mark.parametrize(
    'feat', [torch.rand(4, 20, 20, 30), (torch.rand(4, 20, 20, 30), )])
def test_csra_head(feat):
    head = CSRAClsHead(num_classes=10, in_channels=20, num_heads=1, lam=0.1)
    fake_gt_label = torch.randint(0, 2, (4, 10))

    losses = head.forward_train(feat, fake_gt_label)
    assert losses['loss'].item() > 0

    # test simple_test with post_process
    pred = head.simple_test(feat)
    assert isinstance(pred, list) and len(pred) == 4
    with patch('torch.onnx.is_in_onnx_export', return_value=True):
        pred = head.simple_test(feat)
        assert pred.shape == (4, 10)

    # test pre_logits
    features = head.pre_logits(feat)
    if isinstance(feat, tuple):
        torch.testing.assert_allclose(features, feat[0])
    else:
        torch.testing.assert_allclose(features, feat)