[Feature] Support CSRA head. (#881)
* Support CSRA head.
* Add CSRA config.
* Improve training scheduler and update cfg, ckpt, log.
* Update metafile.
* Rename config files and checkpoints.

Co-authored-by: Ezra-Yu <1105212286@qq.com>
Co-authored-by: mzr1996 <mzr1996@163.com>

Parent: b5bb86a357
Commit: 1a3d51acc2

configs/csra/README.md (new file, 36 lines)
@@ -0,0 +1,36 @@
# CSRA

> [Residual Attention: A Simple but Effective Method for Multi-Label Recognition](https://arxiv.org/abs/2108.02456)

<!-- [ALGORITHM] -->

## Abstract

Multi-label image recognition is a challenging computer vision task of practical use. Progress in this area, however, is often characterized by complicated methods, heavy computation, and a lack of intuitive explanation. To effectively capture the different spatial regions occupied by objects from different categories, we propose an embarrassingly simple module, named class-specific residual attention (CSRA). CSRA generates class-specific features for every category by proposing a simple spatial attention score, and then combines it with the class-agnostic average-pooling feature. CSRA achieves state-of-the-art results on multi-label recognition while being much simpler than existing methods. Furthermore, with only 4 lines of code, CSRA also leads to consistent improvement across many diverse pretrained models and datasets without any extra training. CSRA is easy to implement, light in computation, and enjoys intuitive explanations and visualizations.

<div align=center>
<img src="https://user-images.githubusercontent.com/84259897/176982245-3ffcff56-a4ea-4474-9967-bc2b612bbaa3.png" width="80%"/>
</div>
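
The residual attention itself boils down to a handful of tensor operations. The sketch below mirrors the `CSRAModule.forward` added in this commit; the standalone-function form and the name `csra_logits` are ours, for illustration only.

```python
import torch


def csra_logits(feat, fc_weight, lam=0.1, T=1):
    """Class-specific residual attention, sketched as a plain function.

    feat:      (B, C, H, W) backbone feature map
    fc_weight: (num_classes, C) classifier weight, applied as a 1x1 conv
    """
    # per-class spatial score map, normalized by the classifier weight norm
    weight = fc_weight / fc_weight.norm(dim=1, keepdim=True)
    score = torch.einsum('bchw,kc->bkhw', feat, weight).flatten(2)  # (B, K, H*W)

    base_logit = score.mean(dim=2)          # class-agnostic average pooling
    if T == 99:                             # T = 99 is treated as max pooling
        att_logit = score.max(dim=2).values
    else:                                   # temperature-scaled spatial softmax
        att_logit = (score * (score * T).softmax(dim=2)).sum(dim=2)

    return base_logit + lam * att_logit     # residual combination
```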
## Results and models

### VOC2007

| Model | Pretrain | Params(M) | Flops(G) | mAP (%) | OF1 (%) | CF1 (%) | Config | Download |
| :------------: | :--------: | :-------: | :------: | :-----: | :-----: | :-----: | :------: | :--------: |
| ResNet101-CSRA | [ImageNet-1k](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth) | 23.55 | 4.12 | 94.98 | 90.80 | 89.16 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/csra/resnet101-csra_1xb16_voc07-448px.py) | [model](https://download.openmmlab.com/mmclassification/v0/csra/resnet101-csra_1xb16_voc07-448px_20220722-29efb40a.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/csra/resnet101-csra_1xb16_voc07-448px_20220722-29efb40a.log.json) |

## Citation

```bibtex
@misc{https://doi.org/10.48550/arxiv.2108.02456,
  doi = {10.48550/ARXIV.2108.02456},
  url = {https://arxiv.org/abs/2108.02456},
  author = {Zhu, Ke and Wu, Jianxin},
  keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences},
  title = {Residual Attention: A Simple but Effective Method for Multi-Label Recognition},
  publisher = {arXiv},
  year = {2021},
  copyright = {arXiv.org perpetual, non-exclusive license}
}
```
configs/csra/metafile.yml (new file, 29 lines)

@@ -0,0 +1,29 @@
Collections:
  - Name: CSRA
    Metadata:
      Training Data: PASCAL VOC 2007
      Architecture:
        - Class-specific Residual Attention
    Paper:
      URL: https://arxiv.org/abs/2108.02456
      Title: 'Residual Attention: A Simple but Effective Method for Multi-Label Recognition'
    README: configs/csra/README.md
    Code:
      Version: v0.24.0
      URL: https://github.com/open-mmlab/mmclassification/blob/v0.24.0/mmcls/models/heads/multi_label_csra_head.py

Models:
  - Name: resnet101-csra_1xb16_voc07-448px
    Metadata:
      FLOPs: 4120000000
      Parameters: 23550000
    In Collections: CSRA
    Results:
      - Dataset: PASCAL VOC 2007
        Metrics:
          mAP: 94.98
          OF1: 90.80
          CF1: 89.16
        Task: Multi-Label Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/csra/resnet101-csra_1xb16_voc07-448px_20220722-29efb40a.pth
    Config: configs/csra/resnet101-csra_1xb16_voc07-448px.py
configs/csra/resnet101-csra_1xb16_voc07-448px.py (new file, 75 lines)

@@ -0,0 +1,75 @@
_base_ = ['../_base_/datasets/voc_bs16.py', '../_base_/default_runtime.py']

# Pre-trained checkpoint path
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth'  # noqa

# If you want to use the pre-trained ResNet101-CutMix weights from the
# original repo (https://github.com/Kevinz-code/CSRA), the script
# 'tools/convert_models/torchvision_to_mmcls.py' can help you convert the
# weights into mmcls format. The mAP reaches 95.5 with those weights.
# checkpoint = 'PATH/TO/PRE-TRAINED_WEIGHT'

# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=101,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch',
        init_cfg=dict(
            type='Pretrained', checkpoint=checkpoint, prefix='backbone')),
    neck=None,
    head=dict(
        type='CSRAClsHead',
        num_classes=20,
        in_channels=2048,
        num_heads=1,
        lam=0.1,
        loss=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)))

# dataset settings
img_norm_cfg = dict(mean=[0, 0, 0], std=[255, 255, 255], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=448, scale=(0.7, 1.0)),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=448),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    # map the difficult examples to negative ones (0)
    train=dict(pipeline=train_pipeline, difficult_as_postive=False),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))

# optimizer
# the lr of classifier.head is 10 * base_lr, which helps convergence.
optimizer = dict(
    type='SGD',
    lr=0.0002,
    momentum=0.9,
    weight_decay=0.0001,
    paramwise_cfg=dict(custom_keys={'head': dict(lr_mult=10)}))

optimizer_config = dict(grad_clip=None)

# learning policy
lr_config = dict(
    policy='step',
    step=6,
    gamma=0.1,
    warmup='linear',
    warmup_iters=1,
    warmup_ratio=1e-7,
    warmup_by_epoch=True)
runner = dict(type='EpochBasedRunner', max_epochs=20)
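For a quick smoke test of the config above outside the usual `tools/train.py` entry point, the model section can be built on its own. This is a minimal sketch, assuming mmcls v0.24.x with mmcv 1.x and a full mmclassification checkout so the `_base_` files resolve; the dummy 448x448 batch is ours.

```python
import torch
from mmcv import Config
from mmcls.models import build_classifier

cfg = Config.fromfile('configs/csra/resnet101-csra_1xb16_voc07-448px.py')
model = build_classifier(cfg.model)
model.init_weights()  # pulls the ImageNet-1k ResNet-101 checkpoint (needs network access)

img = torch.rand(2, 3, 448, 448)                 # two fake VOC-sized images
gt_label = torch.randint(0, 2, (2, 20)).float()  # multi-hot labels for 20 VOC classes
losses = model.forward_train(img, gt_label)
print(losses['loss'])
```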
@@ -11,14 +11,29 @@ from .multi_label import MultiLabelDataset

 @DATASETS.register_module()
 class VOC(MultiLabelDataset):
-    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset."""
+    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset.
+
+    Args:
+        data_prefix (str): The prefix of the data path.
+        pipeline (list): A list of dicts, where each element represents
+            an operation defined in `mmcls.datasets.pipelines`.
+        ann_file (str | None): The annotation file. When ann_file is str,
+            the subclass is expected to read from the ann_file. When ann_file
+            is None, the subclass is expected to read according to data_prefix.
+        difficult_as_postive (Optional[bool]): Whether to map the difficult
+            labels as positive. If set to True, map difficult examples to
+            positive ones (1); if set to False, map difficult examples to
+            negative ones (0). Defaults to None, in which case the difficult
+            labels are set to '-1'.
+    """

     CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
                'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
                'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
                'tvmonitor')

-    def __init__(self, **kwargs):
+    def __init__(self, difficult_as_postive=None, **kwargs):
+        self.difficult_as_postive = difficult_as_postive
         super(VOC, self).__init__(**kwargs)
         if 'VOC2007' in self.data_prefix:
             self.year = 2007
@@ -55,9 +70,19 @@ class VOC(MultiLabelDataset):
                 labels.append(label)

         gt_label = np.zeros(len(self.CLASSES))
-        gt_label[labels_difficult] = -1
+        # Set difficult examples first, then set positive examples.
+        # The order cannot be swapped for the case where multiple objects
+        # of the same kind exist and some of them are difficult.
+        if self.difficult_as_postive is None:
+            # map difficult examples to -1,
+            # it may be used in evaluation to ignore difficult targets.
+            gt_label[labels_difficult] = -1
+        elif self.difficult_as_postive:
+            # map difficult examples to positive ones (1).
+            gt_label[labels_difficult] = 1
+        else:
+            # map difficult examples to negative ones (0).
+            gt_label[labels_difficult] = 0
         gt_label[labels] = 1

         info = dict(
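The three `difficult_as_postive` modes in the hunk above only differ in what value the difficult indices receive before the positive labels are written on top, which is why the order of the two assignments matters. A small standalone illustration (the class count, positive index, and difficult index are made up):

```python
import numpy as np

labels, labels_difficult = [1], [3]      # hypothetical positive / difficult classes
for mode, fill in ((None, -1), (True, 1), (False, 0)):
    gt_label = np.zeros(5)
    gt_label[labels_difficult] = fill    # difficult examples first ...
    gt_label[labels] = 1                 # ... then positives overwrite them
    print(mode, gt_label)
# None  -> [ 0.  1.  0. -1.  0.]  difficult kept as -1 (ignored during evaluation)
# True  -> [ 0.  1.  0.  1.  0.]  difficult treated as positive
# False -> [ 0.  1.  0.  0.  0.]  difficult treated as negative
```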
@@ -3,6 +3,7 @@ from .cls_head import ClsHead
 from .conformer_head import ConformerHead
 from .deit_head import DeiTClsHead
 from .linear_head import LinearClsHead
+from .multi_label_csra_head import CSRAClsHead
 from .multi_label_head import MultiLabelClsHead
 from .multi_label_linear_head import MultiLabelLinearClsHead
 from .stacked_head import StackedLinearClsHead
@@ -11,5 +12,5 @@ from .vision_transformer_head import VisionTransformerClsHead
 __all__ = [
     'ClsHead', 'LinearClsHead', 'StackedLinearClsHead', 'MultiLabelClsHead',
     'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'DeiTClsHead',
-    'ConformerHead'
+    'ConformerHead', 'CSRAClsHead'
 ]
mmcls/models/heads/multi_label_csra_head.py (new executable file, 121 lines)

@@ -0,0 +1,121 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/Kevinz-code/CSRA
import torch
import torch.nn as nn
from mmcv.runner import BaseModule, ModuleList

from ..builder import HEADS
from .multi_label_head import MultiLabelClsHead


@HEADS.register_module()
class CSRAClsHead(MultiLabelClsHead):
    """Class-specific residual attention classifier head.

    Residual Attention: A Simple but Effective Method for Multi-Label
    Recognition (ICCV 2021).
    Please refer to the `paper <https://arxiv.org/abs/2108.02456>`__ for
    details.

    Args:
        num_classes (int): Number of categories.
        in_channels (int): Number of channels in the input feature map.
        num_heads (int): Number of residual attention heads.
        loss (dict): Config of the classification loss.
        lam (float): Lambda that combines the global average pooling and
            max pooling scores.
        init_cfg (dict, optional): The extra init config of layers.
            Defaults to dict(type='Normal', layer='Linear', std=0.01).
    """
    temperature_settings = {  # softmax temperature settings
        1: [1],
        2: [1, 99],
        4: [1, 2, 4, 99],
        6: [1, 2, 3, 4, 5, 99],
        8: [1, 2, 3, 4, 5, 6, 7, 99]
    }

    def __init__(self,
                 num_classes,
                 in_channels,
                 num_heads,
                 lam,
                 loss=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     reduction='mean',
                     loss_weight=1.0),
                 init_cfg=dict(type='Normal', layer='Linear', std=0.01),
                 *args,
                 **kwargs):
        assert num_heads in self.temperature_settings.keys(
        ), 'The number of heads is not in the temperature settings.'
        assert lam > 0, 'Lambda should be between 0 and 1.'
        super(CSRAClsHead, self).__init__(
            init_cfg=init_cfg, loss=loss, *args, **kwargs)
        self.temp_list = self.temperature_settings[num_heads]
        self.csra_heads = ModuleList([
            CSRAModule(num_classes, in_channels, self.temp_list[i], lam)
            for i in range(num_heads)
        ])

    def pre_logits(self, x):
        if isinstance(x, tuple):
            x = x[-1]
        return x

    def simple_test(self, x, post_process=True, **kwargs):
        logit = 0.
        x = self.pre_logits(x)
        for head in self.csra_heads:
            logit += head(x)
        if post_process:
            return self.post_process(logit)
        else:
            return logit

    def forward_train(self, x, gt_label, **kwargs):
        logit = 0.
        x = self.pre_logits(x)
        for head in self.csra_heads:
            logit += head(x)
        gt_label = gt_label.type_as(logit)
        _gt_label = torch.abs(gt_label)
        losses = self.loss(logit, _gt_label, **kwargs)
        return losses


class CSRAModule(BaseModule):
    """Basic module of CSRA with different temperatures.

    Args:
        num_classes (int): Number of categories.
        in_channels (int): Number of channels in the input feature map.
        T (int): Temperature setting. T = 99 means max pooling.
        lam (float): Lambda that combines the global average pooling and
            max pooling scores.
        init_cfg (dict, optional): The extra init config of layers.
            Defaults to None.
    """

    def __init__(self, num_classes, in_channels, T, lam, init_cfg=None):

        super(CSRAModule, self).__init__(init_cfg=init_cfg)
        self.T = T  # temperature
        self.lam = lam  # Lambda
        self.head = nn.Conv2d(in_channels, num_classes, 1, bias=False)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        score = self.head(x) / torch.norm(
            self.head.weight, dim=1, keepdim=True).transpose(0, 1)
        score = score.flatten(2)
        base_logit = torch.mean(score, dim=2)

        if self.T == 99:  # max-pooling
            att_logit = torch.max(score, dim=2)[0]
        else:
            score_soft = self.softmax(score * self.T)
            att_logit = torch.sum(score * score_soft, dim=2)

        return base_logit + self.lam * att_logit
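The new head can also be exercised on its own through the registry; the unit test added further down does essentially this. A short sketch, assuming mmcls v0.24.x; the 14x14 feature size is our assumption, matching a ResNet-101 stage-4 output at 448x448 input.

```python
import torch
from mmcls.models import build_head

head = build_head(
    dict(
        type='CSRAClsHead',
        num_classes=20,
        in_channels=2048,
        num_heads=1,
        lam=0.1,
        loss=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)))

feat = torch.rand(4, 2048, 14, 14)         # fake ResNet-101 stage-4 features
gt_label = torch.randint(0, 2, (4, 20))    # multi-hot VOC-style labels
print(head.forward_train(feat, gt_label))  # {'loss': tensor(...)}
print(len(head.simple_test(feat)))         # 4, one prediction per sample
```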
@@ -28,4 +28,5 @@ Import:
   - configs/convmixer/metafile.yml
   - configs/densenet/metafile.yml
   - configs/poolformer/metafile.yml
+  - configs/csra/metafile.yml
   - configs/mvit/metafile.yml
@@ -4,8 +4,8 @@ from unittest.mock import patch
 import pytest
 import torch

-from mmcls.models.heads import (ClsHead, ConformerHead, DeiTClsHead,
-                                LinearClsHead, MultiLabelClsHead,
+from mmcls.models.heads import (ClsHead, ConformerHead, CSRAClsHead,
+                                DeiTClsHead, LinearClsHead, MultiLabelClsHead,
                                 MultiLabelLinearClsHead, StackedLinearClsHead,
                                 VisionTransformerClsHead)

@@ -317,3 +317,27 @@ def test_deit_head():
     # test assertion
     with pytest.raises(ValueError):
         DeiTClsHead(-1, 100)
+
+
+@pytest.mark.parametrize(
+    'feat', [torch.rand(4, 20, 20, 30), (torch.rand(4, 20, 20, 30), )])
+def test_csra_head(feat):
+    head = CSRAClsHead(num_classes=10, in_channels=20, num_heads=1, lam=0.1)
+    fake_gt_label = torch.randint(0, 2, (4, 10))
+
+    losses = head.forward_train(feat, fake_gt_label)
+    assert losses['loss'].item() > 0
+
+    # test simple_test with post_process
+    pred = head.simple_test(feat)
+    assert isinstance(pred, list) and len(pred) == 4
+    with patch('torch.onnx.is_in_onnx_export', return_value=True):
+        pred = head.simple_test(feat)
+        assert pred.shape == (4, 10)
+
+    # test pre_logits
+    features = head.pre_logits(feat)
+    if isinstance(feat, tuple):
+        torch.testing.assert_allclose(features, feat[0])
+    else:
+        torch.testing.assert_allclose(features, feat)