[Enhance] Enhance feature extraction function. (#593)
* Fix MobileNet V3 configs
* Refactor to support more powerful feature extraction.
* Add unit tests
* Fix unit test
* Improve according to comments
* Update checkpoints path
* Fix unit tests
* Add docstring of `simple_test`
* Add docstring of `extract_feat`
* Update model zoo
parent f9a2b04cee
commit 643fb192cd
@ -11,4 +11,6 @@ model = dict(
dropout_rate=0.2,
act_cfg=dict(type='HSwish'),
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
init_cfg=dict(
type='Normal', layer='Linear', mean=0., std=0.01, bias=0.),
topk=(1, 5)))

@ -11,4 +11,6 @@ model = dict(
dropout_rate=0.2,
act_cfg=dict(type='HSwish'),
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
init_cfg=dict(
type='Normal', layer='Linear', mean=0., std=0.01, bias=0.),
topk=(1, 5)))
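For reference, the `init_cfg` entry added to these MobileNet V3 heads hands weight initialization to MMCV's init utilities. A minimal standalone sketch of the same spec (assumes mmcv >= 1.3 exposing `mmcv.cnn.initialize`; the 576-dim layer is only illustrative):

import torch.nn as nn
from mmcv.cnn import initialize

# Toy linear classifier standing in for the MobileNet V3 head.
fc = nn.Linear(576, 1000)
# Same spec as the config: Normal(mean=0, std=0.01) weights, zero bias.
initialize(fc, dict(type='Normal', layer='Linear', mean=0., std=0.01, bias=0.))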
@ -34,11 +34,11 @@ The pre-trained models are converted from the [official repo](https://github.com
|
|||
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|
||||
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
|
||||
| DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) |
|
||||
| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth) |
|
||||
| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211216-c429839a.pth) |
|
||||
| DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) |
|
||||
| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth) |
|
||||
| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211216-4de1d725.pth) |
|
||||
| DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) |
|
||||
| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth) |
|
||||
| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211216-42891296.pth) |
|
||||
|
||||
*Models with \* are converted from other repos.*
|
||||
|
||||
|
@ -51,7 +51,7 @@ The fine-tuned models are converted from the [official repo](https://github.com/
|
|||
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|
||||
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
|
||||
| DeiT-base 384px\* | 86.86 | 49.37 | 83.04 | 96.31 | [config](configs/deit/deit-base_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth) |
|
||||
| DeiT-base distilled 384px\* | 86.86 | 49.37 | 85.55 | 97.35 | [config](configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211124-91e88933.pth) |
|
||||
| DeiT-base distilled 384px\* | 86.86 | 49.37 | 85.55 | 97.35 | [config](configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211216-e48d6000.pth) |
|
||||
|
||||
*Models with \* are converted from other repos.*
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ Models:
|
|||
Top 1 Accuracy: 74.51
|
||||
Top 5 Accuracy: 91.90
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211216-c429839a.pth
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth
|
||||
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L108
|
||||
|
@ -72,7 +72,7 @@ Models:
|
|||
Top 1 Accuracy: 81.17
|
||||
Top 5 Accuracy: 95.40
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211216-4de1d725.pth
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth
|
||||
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L123
|
||||
|
@ -104,7 +104,7 @@ Models:
|
|||
Top 1 Accuracy: 83.33
|
||||
Top 5 Accuracy: 96.49
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211216-42891296.pth
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth
|
||||
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L138
|
||||
|
@ -136,7 +136,7 @@ Models:
|
|||
Top 1 Accuracy: 85.55
|
||||
Top 5 Accuracy: 97.35
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211124-91e88933.pth
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211216-e48d6000.pth
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth
|
||||
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L168
|
||||
|
|
|
@ -17,7 +17,7 @@
# - modify: RandomErasing use RE-M instead of RE-0

_base_ = [
'../_base_/models/mobilenet-v3-small_8xb32_in1k.py',
'../_base_/models/mobilenet_v3_small_imagenet.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/default_runtime.py'
]
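As a reminder of how the `_base_` entries above compose, the merged config can be inspected with MMCV's `Config` API. A rough sketch (the config path is an assumption, adjust to the repository layout):

from mmcv import Config

# Files listed in `_base_` are merged into a single flat config object.
cfg = Config.fromfile('configs/mobilenet_v3/mobilenet-v3-small_8xb32_in1k.py')
print(cfg.model.head)  # inspect the merged head settings, including init_cfg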
|
|
@ -63,16 +63,18 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
|
|||
| T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth) | [log]()|
|
||||
| Mixer-B/16\* | 59.88 | 12.61 | 76.68 | 92.25 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth) | [log]()|
|
||||
| Mixer-L/16\* | 208.2 | 44.57 | 72.34 | 88.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth) | [log]()|
|
||||
| DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) | [log]()|
|
||||
| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth) | [log]()|
|
||||
| DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) | [log]()|
|
||||
| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth) | [log]()|
|
||||
| DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) | [log]()|
|
||||
| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth) | [log]()|
|
||||
| Conformer-tiny-p16\* | 23.52 | 4.90 | 81.31 | 95.60 | [config](configs/conformer/conformer-tiny-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-tiny-p16_3rdparty_8xb128_in1k_20211206-f6860372.pth) | [log]()|
|
||||
| Conformer-small-p32 | 38.85 | 7.09 | 81.96 | 96.02 | [config](configs/conformer/conformer-small-p32_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p32_8xb128_in1k_20211206-947a0816.pth) | [log]()|
|
||||
| Conformer-small-p16\* | 37.67 | 10.31 | 83.32 | 96.46 | [config](configs/conformer/conformer-small-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p16_3rdparty_8xb128_in1k_20211206-3065dcf5.pth) | [log]()|
|
||||
| Conformer-base-p16\* | 83.29 | 22.89 | 83.82 | 96.59 | [config](configs/conformer/conformer-base-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-base-p16_3rdparty_8xb128_in1k_20211206-bfdf8637.pth) | [log]()|
|
||||
| DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) | [log]()|
|
||||
| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211216-c429839a.pth) | [log]()|
|
||||
| DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) | [log]()|
|
||||
| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211216-4de1d725.pth) | [log]()|
|
||||
| DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) | [log]()|
|
||||
| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211216-42891296.pth) | [log]()|
|
||||
| DeiT-base 384px\* | 86.86 | 49.37 | 83.04 | 96.31 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth) | [log]()|
|
||||
| DeiT-base distilled 384px\* | 86.86 | 49.37 | 85.55 | 97.35 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211216-e48d6000.pth) | [log]()|
|
||||
| Conformer-tiny-p16\* | 23.52 | 4.90 | 81.31 | 95.60 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-tiny-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-tiny-p16_3rdparty_8xb128_in1k_20211206-f6860372.pth) | [log]()|
|
||||
| Conformer-small-p32 | 38.85 | 7.09 | 81.96 | 96.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-small-p32_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p32_8xb128_in1k_20211206-947a0816.pth) | [log]()|
|
||||
| Conformer-small-p16\* | 37.67 | 10.31 | 83.32 | 96.46 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-small-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p16_3rdparty_8xb128_in1k_20211206-3065dcf5.pth) | [log]()|
|
||||
| Conformer-base-p16\* | 83.29 | 22.89 | 83.82 | 96.59 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-base-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-base-p16_3rdparty_8xb128_in1k_20211206-bfdf8637.pth) | [log]()|
|
||||
|
||||
Models with * are converted from other repos, others are trained by ourselves.
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from typing import Sequence

import mmcv
import torch
@ -35,13 +36,14 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
return hasattr(self, 'head') and self.head is not None

@abstractmethod
def extract_feat(self, imgs):
def extract_feat(self, imgs, stage=None):
pass

def extract_feats(self, imgs):
assert isinstance(imgs, list)
def extract_feats(self, imgs, stage=None):
assert isinstance(imgs, Sequence)
kwargs = {} if stage is None else {'stage': stage}
for img in imgs:
yield self.extract_feat(img)
yield self.extract_feat(img, **kwargs)

@abstractmethod
def forward_train(self, imgs, **kwargs):
|
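Note that `extract_feats` is a generator that simply forwards the optional `stage` keyword to `extract_feat` for every image, so callers wrap it in `list()` (or iterate) to materialize the results. A hypothetical usage sketch (construction of `model` is elided; any built classifier works):

import torch

imgs = [torch.rand(1, 3, 224, 224) for _ in range(2)]
# `stage` is passed through to extract_feat only when it is not None.
feats_per_img = list(model.extract_feats(imgs, stage='backbone'))
assert len(feats_per_img) == len(imgs)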
|
@ -3,6 +3,7 @@ import copy
import warnings

from ..builder import CLASSIFIERS, build_backbone, build_head, build_neck
from ..heads import MultiLabelClsHead
from ..utils.augment import Augments
from .base import BaseClassifier
|
||||
|
||||
|
@ -70,8 +71,74 @@ class ImageClassifier(BaseClassifier):
|
|||
cfg['prob'] = cutmix_prob
|
||||
self.augments = Augments(cfg)
|
||||
|
||||
def extract_feat(self, img):
|
||||
"""Directly extract features from the backbone + neck."""
|
||||
def extract_feat(self, img, stage='neck'):
|
||||
"""Directly extract features from the specified stage.
|
||||
|
||||
Args:
|
||||
img (Tensor): The input images. The shape of it should be
|
||||
``(num_samples, num_channels, *img_shape)``.
|
||||
stage (str): Which stage to output the feature. Choose from
|
||||
"backbone", "neck" and "pre_logits". Defaults to "neck".
|
||||
|
||||
Returns:
|
||||
tuple | Tensor: The output of specified stage.
|
||||
The output depends on the specific implementation. In general, the
output of backbone and neck is a tuple and the output of
pre_logits is a tensor.
|
||||
|
||||
Examples:
|
||||
1. Backbone output
|
||||
|
||||
>>> import torch
|
||||
>>> from mmcv import Config
|
||||
>>> from mmcls.models import build_classifier
|
||||
>>>
|
||||
>>> cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model
|
||||
>>> cfg.backbone.out_indices = (0, 1, 2, 3) # Output multi-scale feature maps
|
||||
>>> model = build_classifier(cfg)
|
||||
>>> outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='backbone')
|
||||
>>> for out in outs:
|
||||
... print(out.shape)
|
||||
torch.Size([1, 64, 56, 56])
|
||||
torch.Size([1, 128, 28, 28])
|
||||
torch.Size([1, 256, 14, 14])
|
||||
torch.Size([1, 512, 7, 7])
|
||||
|
||||
2. Neck output
|
||||
|
||||
>>> import torch
|
||||
>>> from mmcv import Config
|
||||
>>> from mmcls.models import build_classifier
|
||||
>>>
|
||||
>>> cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model
|
||||
>>> cfg.backbone.out_indices = (0, 1, 2, 3) # Output multi-scale feature maps
|
||||
>>> model = build_classifier(cfg)
|
||||
>>>
|
||||
>>> outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='neck')
|
||||
>>> for out in outs:
|
||||
... print(out.shape)
|
||||
torch.Size([1, 64])
|
||||
torch.Size([1, 128])
|
||||
torch.Size([1, 256])
|
||||
torch.Size([1, 512])
|
||||
|
||||
3. Pre-logits output (without the final linear classifier head)
|
||||
|
||||
>>> import torch
|
||||
>>> from mmcv import Config
|
||||
>>> from mmcls.models import build_classifier
|
||||
>>>
|
||||
>>> cfg = Config.fromfile('configs/vision_transformer/vit-base-p16_pt-64xb64_in1k-224.py').model
|
||||
>>> model = build_classifier(cfg)
|
||||
>>>
|
||||
>>> out = model.extract_feat(torch.rand(1, 3, 224, 224), stage='pre_logits')
|
||||
>>> print(out.shape) # The hidden dims in head is 3072
|
||||
torch.Size([1, 3072])
|
||||
""" # noqa: E501
|
||||
assert stage in ['backbone', 'neck', 'pre_logits'], \
|
||||
(f'Invalid output stage "{stage}", please choose from "backbone", '
|
||||
'"neck" and "pre_logits"')
|
||||
|
||||
x = self.backbone(img)
|
||||
if self.return_tuple:
|
||||
if not isinstance(x, tuple):
|
||||
|
@ -83,8 +150,16 @@ class ImageClassifier(BaseClassifier):
|
|||
else:
|
||||
if isinstance(x, tuple):
|
||||
x = x[-1]
|
||||
if stage == 'backbone':
|
||||
return x
|
||||
|
||||
if self.with_neck:
|
||||
x = self.neck(x)
|
||||
if stage == 'neck':
|
||||
return x
|
||||
|
||||
if self.with_head and hasattr(self.head, 'pre_logits'):
|
||||
x = self.head.pre_logits(x)
|
||||
return x
|
||||
|
||||
def forward_train(self, img, gt_label, **kwargs):
|
||||
|
@ -122,12 +197,16 @@ class ImageClassifier(BaseClassifier):
|
|||
|
||||
return losses
|
||||
|
||||
def simple_test(self, img, img_metas):
|
||||
def simple_test(self, img, img_metas=None, **kwargs):
|
||||
"""Test without augmentation."""
|
||||
x = self.extract_feat(img)
|
||||
|
||||
try:
|
||||
res = self.head.simple_test(x)
|
||||
if isinstance(self.head, MultiLabelClsHead):
|
||||
assert 'softmax' not in kwargs, (
|
||||
'Please use `sigmoid` instead of `softmax` '
|
||||
'in multi-label tasks.')
|
||||
res = self.head.simple_test(x, **kwargs)
|
||||
except TypeError as e:
|
||||
if 'not tuple' in str(e) and self.return_tuple:
|
||||
return TypeError(
|
||||
|
|
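With the keyword arguments now routed into the head, raw scores can be requested from the classifier instead of post-processed lists. A small sketch of the intended call pattern, mirroring the unit tests added below (`model` is any built single-label ImageClassifier):

import torch

imgs = torch.rand(16, 3, 224, 224)
pred = model.simple_test(imgs)                                       # list of 16 per-sample results
logits = model.simple_test(imgs, softmax=False, post_process=False)  # raw logits tensor
probs = model.simple_test(imgs, softmax=True, post_process=False)    # softmaxed tensor
assert torch.allclose(probs, torch.softmax(logits, dim=1))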
|
@ -1,4 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
@ -62,14 +64,49 @@ class ClsHead(BaseHead):
|
|||
losses = self.loss(cls_score, gt_label, **kwargs)
|
||||
return losses
|
||||
|
||||
def simple_test(self, cls_score):
|
||||
"""Test without augmentation."""
|
||||
def pre_logits(self, x):
|
||||
if isinstance(x, tuple):
|
||||
x = x[-1]
|
||||
|
||||
warnings.warn(
|
||||
'The input of ClsHead should be already logits. '
|
||||
'Please modify the backbone if you want to get pre-logits feature.'
|
||||
)
|
||||
return x
|
||||
|
||||
def simple_test(self, cls_score, softmax=True, post_process=True):
|
||||
"""Inference without augmentation.
|
||||
|
||||
Args:
|
||||
cls_score (tuple[Tensor]): The input classification score logits.
|
||||
Multi-stage inputs are acceptable but only the last stage will
|
||||
be used to classify. The shape of every item should be
|
||||
``(num_samples, num_classes)``.
|
||||
softmax (bool): Whether to softmax the classification score.
|
||||
post_process (bool): Whether to do post processing on the
inference results. It will convert the output to a list.

Returns:
Tensor | list: The inference results.

- If no post processing, the output is a tensor with shape
``(num_samples, num_classes)``.
- If post processing, the output is a multi-dimensional list of
float and the dimensions are ``(num_samples, num_classes)``.
|
||||
"""
|
||||
if isinstance(cls_score, tuple):
|
||||
cls_score = cls_score[-1]
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||
return self.post_process(pred)
|
||||
|
||||
if softmax:
|
||||
pred = (
|
||||
F.softmax(cls_score, dim=1) if cls_score is not None else None)
|
||||
else:
|
||||
pred = cls_score
|
||||
|
||||
if post_process:
|
||||
return self.post_process(pred)
|
||||
else:
|
||||
return pred
|
||||
|
||||
def post_process(self, pred):
|
||||
on_trace = is_tracing()
|
||||
|
|
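For context, `post_process` (unchanged in this commit) is what turns a prediction tensor into a per-sample Python list during normal inference and leaves tensors untouched under ONNX export or tracing. A simplified sketch of that behaviour, not the verbatim implementation:

import torch

def post_process_sketch(pred, exporting=False):
    # Keep the tensor as-is when exporting or tracing.
    if exporting:
        return pred
    # Otherwise return one entry per sample, as the head docstrings describe.
    return list(pred.detach().cpu().numpy())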
|
@ -16,7 +16,7 @@ class ConformerHead(ClsHead):
|
|||
category.
|
||||
in_channels (int): Number of channels in the input feature map.
|
||||
init_cfg (dict | optional): The extra init config of layers.
|
||||
Defaults to use dict(type='Normal', layer='Linear', std=0.01).
|
||||
Defaults to use ``dict(type='Normal', layer='Linear', std=0.01)``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -55,25 +55,54 @@ class ConformerHead(ClsHead):
|
|||
else:
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def simple_test(self, x):
|
||||
"""Test without augmentation."""
|
||||
def pre_logits(self, x):
|
||||
if isinstance(x, tuple):
|
||||
x = x[-1]
|
||||
# There are two outputs in the Conformer model
assert isinstance(x, list)
|
||||
return x
|
||||
|
||||
def simple_test(self, x, softmax=True, post_process=True):
|
||||
"""Inference without augmentation.
|
||||
|
||||
Args:
|
||||
x (tuple[tuple[tensor, tensor]]): The input features.
|
||||
Multi-stage inputs are acceptable but only the last stage will
|
||||
be used to classify. Every item should be a tuple which
|
||||
includes convolution features and transformer features. The
shape of them should be ``(num_samples, in_channels[0])`` and
``(num_samples, in_channels[1])``.
softmax (bool): Whether to softmax the classification score.
post_process (bool): Whether to do post processing on the
inference results. It will convert the output to a list.

Returns:
Tensor | list: The inference results.

- If no post processing, the output is a tensor with shape
``(num_samples, num_classes)``.
- If post processing, the output is a multi-dimensional list of
float and the dimensions are ``(num_samples, num_classes)``.
|
||||
"""
|
||||
x = self.pre_logits(x)
|
||||
# There are two outputs in the Conformer model
|
||||
assert len(x) == 2
|
||||
|
||||
conv_cls_score = self.conv_cls_head(x[0])
|
||||
tran_cls_score = self.trans_cls_head(x[1])
|
||||
|
||||
cls_score = conv_cls_score + tran_cls_score
|
||||
|
||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||
|
||||
return self.post_process(pred)
|
||||
if softmax:
|
||||
cls_score = conv_cls_score + tran_cls_score
|
||||
pred = (
|
||||
F.softmax(cls_score, dim=1) if cls_score is not None else None)
|
||||
if post_process:
|
||||
pred = self.post_process(pred)
|
||||
else:
|
||||
pred = [conv_cls_score, tran_cls_score]
|
||||
if post_process:
|
||||
pred = list(map(self.post_process, pred))
|
||||
return pred
|
||||
|
||||
def forward_train(self, x, gt_label):
|
||||
if isinstance(x, tuple):
|
||||
x = x[-1]
|
||||
x = self.pre_logits(x)
|
||||
assert isinstance(x, list) and len(x) == 2, \
|
||||
'There should be two outputs in the Conformer model'
|
||||
|
||||
|
|
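The Conformer head keeps one classifier per branch; with `softmax=True` the two logits are summed before normalization, while `softmax=False` hands back both raw scores. A usage sketch based on the unit test added later in this commit:

import torch
from mmcls.models.heads import ConformerHead

head = ConformerHead(num_classes=10, in_channels=[64, 96])
feats = ([torch.rand(4, 64), torch.rand(4, 96)], )   # (conv branch, transformer branch)
pred = head.simple_test(feats)                        # list of 4 fused, softmaxed results
conv_logits, trans_logits = head.simple_test(feats, softmax=False, post_process=False)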
|
@ -12,25 +12,67 @@ class DeiTClsHead(VisionTransformerClsHead):
|
|||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DeiTClsHead, self).__init__(*args, **kwargs)
|
||||
self.head_dist = nn.Linear(self.in_channels, self.num_classes)
|
||||
if self.hidden_dim is None:
|
||||
head_dist = nn.Linear(self.in_channels, self.num_classes)
|
||||
else:
|
||||
head_dist = nn.Linear(self.hidden_dim, self.num_classes)
|
||||
self.layers.add_module('head_dist', head_dist)
|
||||
|
||||
def simple_test(self, x):
|
||||
"""Test without augmentation."""
|
||||
x = x[-1]
|
||||
assert isinstance(x, list) and len(x) == 3
|
||||
def pre_logits(self, x):
|
||||
if isinstance(x, tuple):
|
||||
x = x[-1]
|
||||
_, cls_token, dist_token = x
|
||||
cls_score = (self.layers(cls_token) + self.head_dist(dist_token)) / 2
|
||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||
|
||||
return self.post_process(pred)
|
||||
if self.hidden_dim is None:
|
||||
return cls_token, dist_token
|
||||
else:
|
||||
cls_token = self.layers.act(self.layers.pre_logits(cls_token))
|
||||
dist_token = self.layers.act(self.layers.pre_logits(dist_token))
|
||||
return cls_token, dist_token
|
||||
|
||||
def simple_test(self, x, softmax=True, post_process=True):
|
||||
"""Inference without augmentation.
|
||||
|
||||
Args:
|
||||
x (tuple[tuple[tensor, tensor, tensor]]): The input features.
|
||||
Multi-stage inputs are acceptable but only the last stage will
|
||||
be used to classify. Every item should be a tuple which
|
||||
includes patch token, cls token and dist token. The cls token
|
||||
and dist token will be used to classify and the shape of them
|
||||
should be ``(num_samples, in_channels)``.
|
||||
softmax (bool): Whether to softmax the classification score.
|
||||
post_process (bool): Whether to do post processing on the
inference results. It will convert the output to a list.

Returns:
Tensor | list: The inference results.

- If no post processing, the output is a tensor with shape
``(num_samples, num_classes)``.
- If post processing, the output is a multi-dimensional list of
float and the dimensions are ``(num_samples, num_classes)``.
|
||||
"""
|
||||
cls_token, dist_token = self.pre_logits(x)
|
||||
cls_score = (self.layers.head(cls_token) +
|
||||
self.layers.head_dist(dist_token)) / 2
|
||||
|
||||
if softmax:
|
||||
pred = F.softmax(
|
||||
cls_score, dim=1) if cls_score is not None else None
|
||||
else:
|
||||
pred = cls_score
|
||||
|
||||
if post_process:
|
||||
return self.post_process(pred)
|
||||
else:
|
||||
return pred
|
||||
|
||||
def forward_train(self, x, gt_label):
|
||||
logger = get_root_logger()
|
||||
logger.warning("MMClassification doesn't support to train the "
|
||||
'distilled version DeiT.')
|
||||
x = x[-1]
|
||||
assert isinstance(x, list) and len(x) == 3
|
||||
_, cls_token, dist_token = x
|
||||
cls_score = (self.layers(cls_token) + self.head_dist(dist_token)) / 2
|
||||
cls_token, dist_token = self.pre_logits(x)
|
||||
cls_score = (self.layers.head(cls_token) +
|
||||
self.layers.head_dist(dist_token)) / 2
|
||||
losses = self.loss(cls_score, gt_label)
|
||||
return losses
|
||||
|
|
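To make the reshuffled distilled head concrete, here is a sketch of the new inference path, mirroring the unit test added below (the patch-token shape is illustrative; only the cls and dist tokens are consumed):

import torch
from mmcls.models.heads import DeiTClsHead

head = DeiTClsHead(num_classes=10, in_channels=100, hidden_dim=20)
feats = ([torch.rand(4, 7, 7, 100), torch.rand(4, 100), torch.rand(4, 100)], )
cls_token, dist_token = head.pre_logits(feats)   # both projected to hidden_dim=20
pred = head.simple_test(feats)                    # averaged cls/dist prediction, list of 4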
|
@ -35,20 +35,47 @@ class LinearClsHead(ClsHead):
|
|||
|
||||
self.fc = nn.Linear(self.in_channels, self.num_classes)
|
||||
|
||||
def simple_test(self, x):
|
||||
"""Test without augmentation."""
|
||||
def pre_logits(self, x):
|
||||
if isinstance(x, tuple):
|
||||
x = x[-1]
|
||||
cls_score = self.fc(x)
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||
return x
|
||||
|
||||
return self.post_process(pred)
|
||||
def simple_test(self, x, softmax=True, post_process=True):
|
||||
"""Inference without augmentation.
|
||||
|
||||
Args:
|
||||
x (tuple[Tensor]): The input features.
|
||||
Multi-stage inputs are acceptable but only the last stage will
|
||||
be used to classify. The shape of every item should be
|
||||
``(num_samples, in_channels)``.
|
||||
softmax (bool): Whether to softmax the classification score.
|
||||
post_process (bool): Whether to do post processing on the
inference results. It will convert the output to a list.

Returns:
Tensor | list: The inference results.

- If no post processing, the output is a tensor with shape
``(num_samples, num_classes)``.
- If post processing, the output is a multi-dimensional list of
float and the dimensions are ``(num_samples, num_classes)``.
|
||||
"""
|
||||
x = self.pre_logits(x)
|
||||
cls_score = self.fc(x)
|
||||
|
||||
if softmax:
|
||||
pred = (
|
||||
F.softmax(cls_score, dim=1) if cls_score is not None else None)
|
||||
else:
|
||||
pred = cls_score
|
||||
|
||||
if post_process:
|
||||
return self.post_process(pred)
|
||||
else:
|
||||
return pred
|
||||
|
||||
def forward_train(self, x, gt_label, **kwargs):
|
||||
if isinstance(x, tuple):
|
||||
x = x[-1]
|
||||
x = self.pre_logits(x)
|
||||
cls_score = self.fc(x)
|
||||
losses = self.loss(cls_score, gt_label, **kwargs)
|
||||
return losses
|
||||
|
|
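A short sketch of the same pattern on the plain linear head, matching the updated unit tests (tuple inputs are accepted and only the last item is used):

import torch
from mmcls.models.heads import LinearClsHead

head = LinearClsHead(num_classes=10, in_channels=3)
feat = (torch.rand(4, 3), )
pred = head.simple_test(feat)                                        # post-processed list of 4
logits = head.simple_test(feat, softmax=False, post_process=False)   # raw (4, 10) tensor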
|
@ -1,6 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..builder import HEADS, build_loss
|
||||
from ..utils import is_tracing
|
||||
|
@ -47,14 +46,50 @@ class MultiLabelClsHead(BaseHead):
|
|||
losses = self.loss(cls_score, gt_label, **kwargs)
|
||||
return losses
|
||||
|
||||
def simple_test(self, x):
|
||||
def pre_logits(self, x):
|
||||
if isinstance(x, tuple):
|
||||
x = x[-1]
|
||||
if isinstance(x, list):
|
||||
x = sum(x) / float(len(x))
|
||||
pred = F.sigmoid(x) if x is not None else None
|
||||
|
||||
return self.post_process(pred)
|
||||
from mmcls.utils import get_root_logger
|
||||
logger = get_root_logger()
|
||||
logger.warning(
|
||||
'The input of MultiLabelClsHead should be already logits. '
|
||||
'Please modify the backbone if you want to get pre-logits feature.'
|
||||
)
|
||||
return x
|
||||
|
||||
def simple_test(self, x, sigmoid=True, post_process=True):
|
||||
"""Inference without augmentation.
|
||||
|
||||
Args:
|
||||
x (tuple[Tensor]): The input classification score logits.
|
||||
Multi-stage inputs are acceptable but only the last stage will
|
||||
be used to classify. The shape of every item should be
|
||||
``(num_samples, num_classes)``.
|
||||
sigmoid (bool): Whether to sigmoid the classification score.
|
||||
post_process (bool): Whether to do post processing on the
inference results. It will convert the output to a list.

Returns:
Tensor | list: The inference results.

- If no post processing, the output is a tensor with shape
``(num_samples, num_classes)``.
- If post processing, the output is a multi-dimensional list of
float and the dimensions are ``(num_samples, num_classes)``.
|
||||
"""
|
||||
if isinstance(x, tuple):
|
||||
x = x[-1]
|
||||
|
||||
if sigmoid:
|
||||
pred = torch.sigmoid(x) if x is not None else None
|
||||
else:
|
||||
pred = x
|
||||
|
||||
if post_process:
|
||||
return self.post_process(pred)
|
||||
else:
|
||||
return pred
|
||||
|
||||
def post_process(self, pred):
|
||||
on_trace = is_tracing()
|
||||
|
|
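Because multi-label heads score each class independently, the flag here is `sigmoid` rather than `softmax`. A sketch of the new call pattern, consistent with the unit tests added below:

import torch
from mmcls.models.heads import MultiLabelClsHead

head = MultiLabelClsHead()
feat = torch.rand(4, 10)                                   # already logits, one column per label
scores = head.simple_test(feat, post_process=False)        # sigmoid scores, (4, 10) tensor
logits = head.simple_test(feat, sigmoid=False, post_process=False)
assert torch.allclose(scores, torch.sigmoid(logits))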
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..builder import HEADS
|
||||
from .multi_label_head import MultiLabelClsHead
|
||||
|
@ -39,21 +39,47 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
|
|||
|
||||
self.fc = nn.Linear(self.in_channels, self.num_classes)
|
||||
|
||||
def forward_train(self, x, gt_label, **kwargs):
|
||||
def pre_logits(self, x):
|
||||
if isinstance(x, tuple):
|
||||
x = x[-1]
|
||||
return x
|
||||
|
||||
def forward_train(self, x, gt_label, **kwargs):
|
||||
x = self.pre_logits(x)
|
||||
gt_label = gt_label.type_as(x)
|
||||
cls_score = self.fc(x)
|
||||
losses = self.loss(cls_score, gt_label, **kwargs)
|
||||
return losses
|
||||
|
||||
def simple_test(self, x):
|
||||
"""Test without augmentation."""
|
||||
if isinstance(x, tuple):
|
||||
x = x[-1]
|
||||
cls_score = self.fc(x)
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.sigmoid(cls_score) if cls_score is not None else None
|
||||
def simple_test(self, x, sigmoid=True, post_process=True):
|
||||
"""Inference without augmentation.
|
||||
|
||||
return self.post_process(pred)
|
||||
Args:
|
||||
x (tuple[Tensor]): The input features.
|
||||
Multi-stage inputs are acceptable but only the last stage will
|
||||
be used to classify. The shape of every item should be
|
||||
``(num_samples, in_channels)``.
|
||||
sigmoid (bool): Whether to sigmoid the classification score.
|
||||
post_process (bool): Whether to do post processing on the
inference results. It will convert the output to a list.

Returns:
Tensor | list: The inference results.

- If no post processing, the output is a tensor with shape
``(num_samples, num_classes)``.
- If post processing, the output is a multi-dimensional list of
float and the dimensions are ``(num_samples, num_classes)``.
|
||||
"""
|
||||
x = self.pre_logits(x)
|
||||
cls_score = self.fc(x)
|
||||
|
||||
if sigmoid:
|
||||
pred = torch.sigmoid(cls_score) if cls_score is not None else None
|
||||
else:
|
||||
pred = cls_score
|
||||
|
||||
if post_process:
|
||||
return self.post_process(pred)
|
||||
else:
|
||||
return pred
|
||||
|
|
|
@ -49,8 +49,7 @@ class StackedLinearClsHead(ClsHead):
|
|||
"""Classifier head with several hidden fc layer and a output fc layer.
|
||||
|
||||
Args:
|
||||
num_classes (int): Number of categories excluding the background
|
||||
category.
|
||||
num_classes (int): Number of categories.
|
||||
in_channels (int): Number of channels in the input feature map.
|
||||
mid_channels (Sequence): Number of channels in the hidden fc layers.
|
||||
dropout_rate (float): Dropout rate after each hidden fc layer,
|
||||
|
@ -89,9 +88,7 @@ class StackedLinearClsHead(ClsHead):
|
|||
self._init_layers()
|
||||
|
||||
def _init_layers(self):
|
||||
self.layers = ModuleList(
|
||||
init_cfg=dict(
|
||||
type='Normal', layer='Linear', mean=0., std=0.01, bias=0.))
|
||||
self.layers = ModuleList()
|
||||
in_channels = self.in_channels
|
||||
for hidden_channels in self.mid_channels:
|
||||
self.layers.append(
|
||||
|
@ -114,24 +111,53 @@ class StackedLinearClsHead(ClsHead):
|
|||
def init_weights(self):
|
||||
self.layers.init_weights()
|
||||
|
||||
def simple_test(self, x):
|
||||
"""Test without augmentation."""
|
||||
def pre_logits(self, x):
|
||||
if isinstance(x, tuple):
|
||||
x = x[-1]
|
||||
cls_score = x
|
||||
for layer in self.layers:
|
||||
cls_score = layer(cls_score)
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||
for layer in self.layers[:-1]:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
return self.post_process(pred)
|
||||
@property
|
||||
def fc(self):
|
||||
return self.layers[-1]
|
||||
|
||||
def simple_test(self, x, softmax=True, post_process=True):
|
||||
"""Inference without augmentation.
|
||||
|
||||
Args:
|
||||
x (tuple[Tensor]): The input features.
|
||||
Multi-stage inputs are acceptable but only the last stage will
|
||||
be used to classify. The shape of every item should be
|
||||
``(num_samples, in_channels)``.
|
||||
softmax (bool): Whether to softmax the classification score.
|
||||
post_process (bool): Whether to do post processing on the
inference results. It will convert the output to a list.

Returns:
Tensor | list: The inference results.

- If no post processing, the output is a tensor with shape
``(num_samples, num_classes)``.
- If post processing, the output is a multi-dimensional list of
float and the dimensions are ``(num_samples, num_classes)``.
|
||||
"""
|
||||
x = self.pre_logits(x)
|
||||
cls_score = self.fc(x)
|
||||
|
||||
if softmax:
|
||||
pred = (
|
||||
F.softmax(cls_score, dim=1) if cls_score is not None else None)
|
||||
else:
|
||||
pred = cls_score
|
||||
|
||||
if post_process:
|
||||
return self.post_process(pred)
|
||||
else:
|
||||
return pred
|
||||
|
||||
def forward_train(self, x, gt_label, **kwargs):
|
||||
if isinstance(x, tuple):
|
||||
x = x[-1]
|
||||
cls_score = x
|
||||
for layer in self.layers:
|
||||
cls_score = layer(cls_score)
|
||||
x = self.pre_logits(x)
|
||||
cls_score = self.fc(x)
|
||||
losses = self.loss(cls_score, gt_label, **kwargs)
|
||||
return losses
|
||||
|
|
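In the stacked head, `pre_logits` now runs every layer except the last and the new `fc` property exposes that final layer, so training and inference share the same split. A sketch grounded in the updated unit test:

import torch
from mmcls.models.heads import StackedLinearClsHead

head = StackedLinearClsHead(num_classes=10, in_channels=5, mid_channels=[20])
feat = (torch.rand(4, 5), )
hidden = head.pre_logits(feat)     # all layers but the last -> shape (4, 20)
pred = head.simple_test(feat)      # final fc + softmax + post-processing, list of 4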
|
@ -68,20 +68,54 @@ class VisionTransformerClsHead(ClsHead):
|
|||
std=math.sqrt(1 / self.layers.pre_logits.in_features))
|
||||
nn.init.zeros_(self.layers.pre_logits.bias)
|
||||
|
||||
def simple_test(self, x):
|
||||
"""Test without augmentation."""
|
||||
x = x[-1]
|
||||
def pre_logits(self, x):
|
||||
if isinstance(x, tuple):
|
||||
x = x[-1]
|
||||
_, cls_token = x
|
||||
cls_score = self.layers(cls_token)
|
||||
if isinstance(cls_score, list):
|
||||
cls_score = sum(cls_score) / float(len(cls_score))
|
||||
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
|
||||
if self.hidden_dim is None:
|
||||
return cls_token
|
||||
else:
|
||||
x = self.layers.pre_logits(cls_token)
|
||||
return self.layers.act(x)
|
||||
|
||||
return self.post_process(pred)
|
||||
def simple_test(self, x, softmax=True, post_process=True):
|
||||
"""Inference without augmentation.
|
||||
|
||||
Args:
|
||||
x (tuple[tuple[tensor, tensor]]): The input features.
|
||||
Multi-stage inputs are acceptable but only the last stage will
|
||||
be used to classify. Every item should be a tuple which
|
||||
includes patch token and cls token. The cls token will be used
|
||||
to classify and the shape of it should be
|
||||
``(num_samples, in_channels)``.
|
||||
softmax (bool): Whether to softmax the classification score.
|
||||
post_process (bool): Whether to do post processing on the
inference results. It will convert the output to a list.

Returns:
Tensor | list: The inference results.

- If no post processing, the output is a tensor with shape
``(num_samples, num_classes)``.
- If post processing, the output is a multi-dimensional list of
float and the dimensions are ``(num_samples, num_classes)``.
|
||||
"""
|
||||
x = self.pre_logits(x)
|
||||
cls_score = self.layers.head(x)
|
||||
|
||||
if softmax:
|
||||
pred = (
|
||||
F.softmax(cls_score, dim=1) if cls_score is not None else None)
|
||||
else:
|
||||
pred = cls_score
|
||||
|
||||
if post_process:
|
||||
return self.post_process(pred)
|
||||
else:
|
||||
return pred
|
||||
|
||||
def forward_train(self, x, gt_label, **kwargs):
|
||||
x = x[-1]
|
||||
_, cls_token = x
|
||||
cls_score = self.layers(cls_token)
|
||||
x = self.pre_logits(x)
|
||||
cls_score = self.layers.head(x)
|
||||
losses = self.loss(cls_score, gt_label, **kwargs)
|
||||
return losses
|
||||
|
|
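And the ViT head: `pre_logits` picks the cls token and, when `hidden_dim` is set, applies the pre-logits projection and activation, leaving `simple_test` to run only the final head layer. A sketch mirroring the unit test (the patch-token shape is illustrative):

import torch
from mmcls.models.heads import VisionTransformerClsHead

head = VisionTransformerClsHead(num_classes=10, in_channels=100, hidden_dim=20)
feats = ([torch.rand(4, 7, 7, 100), torch.rand(4, 100)], )   # (patch token, cls token)
assert head.pre_logits(feats).shape == (4, 20)
pred = head.simple_test(feats)                                # list of 4 softmaxed predictions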
|
@ -73,6 +73,19 @@ def test_image_classifier():
|
|||
pred = model(single_img, return_loss=False, img_metas=None)
|
||||
assert isinstance(pred, list) and len(pred) == 1
|
||||
|
||||
pred = model.simple_test(imgs, softmax=False)
|
||||
assert isinstance(pred, list) and len(pred) == 16
|
||||
assert len(pred[0]) == 10
|
||||
|
||||
pred = model.simple_test(imgs, softmax=False, post_process=False)
|
||||
assert isinstance(pred, torch.Tensor)
|
||||
assert pred.shape == (16, 10)
|
||||
|
||||
soft_pred = model.simple_test(imgs, softmax=True, post_process=False)
|
||||
assert isinstance(soft_pred, torch.Tensor)
|
||||
assert soft_pred.shape == (16, 10)
|
||||
torch.testing.assert_allclose(soft_pred, torch.softmax(pred, dim=1))
|
||||
|
||||
# test pretrained
|
||||
# TODO remove deprecated pretrained
|
||||
with pytest.warns(UserWarning):
|
||||
|
@ -83,7 +96,7 @@ def test_image_classifier():
|
|||
type='Pretrained', checkpoint='checkpoint')
|
||||
|
||||
# test show_result
|
||||
img = np.random.random_integers(0, 255, (224, 224, 3)).astype(np.uint8)
|
||||
img = np.random.randint(0, 256, (224, 224, 3)).astype(np.uint8)
|
||||
result = dict(pred_class='cat', pred_label=0, pred_score=0.9)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
|
@ -304,3 +317,88 @@ def test_image_classifier_return_tuple():
|
|||
|
||||
with pytest.warns(DeprecationWarning):
|
||||
model.extract_feat(imgs)
|
||||
|
||||
|
||||
def test_classifier_extract_feat():
|
||||
model_cfg = ConfigDict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=18,
|
||||
num_stages=4,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
style='pytorch'),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=512,
|
||||
loss=dict(type='CrossEntropyLoss'),
|
||||
topk=(1, 5),
|
||||
))
|
||||
|
||||
model = CLASSIFIERS.build(model_cfg)
|
||||
|
||||
# test backbone output
|
||||
outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='backbone')
|
||||
assert outs[0].shape == (1, 64, 56, 56)
|
||||
assert outs[1].shape == (1, 128, 28, 28)
|
||||
assert outs[2].shape == (1, 256, 14, 14)
|
||||
assert outs[3].shape == (1, 512, 7, 7)
|
||||
|
||||
# test neck output
|
||||
outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='neck')
|
||||
assert outs[0].shape == (1, 64)
|
||||
assert outs[1].shape == (1, 128)
|
||||
assert outs[2].shape == (1, 256)
|
||||
assert outs[3].shape == (1, 512)
|
||||
|
||||
# test pre_logits output
|
||||
out = model.extract_feat(torch.rand(1, 3, 224, 224), stage='pre_logits')
|
||||
assert out.shape == (1, 512)
|
||||
|
||||
# test transformer style feature extraction
|
||||
model_cfg = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='VisionTransformer', arch='b', out_indices=[-3, -2, -1]),
|
||||
neck=None,
|
||||
head=dict(
|
||||
type='VisionTransformerClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=768,
|
||||
hidden_dim=1024,
|
||||
loss=dict(type='CrossEntropyLoss'),
|
||||
))
|
||||
model = CLASSIFIERS.build(model_cfg)
|
||||
|
||||
# test backbone output
|
||||
outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='backbone')
|
||||
for out in outs:
|
||||
patch_token, cls_token = out
|
||||
assert patch_token.shape == (1, 768, 14, 14)
|
||||
assert cls_token.shape == (1, 768)
|
||||
|
||||
# test neck output (the same with backbone)
|
||||
outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='neck')
|
||||
for out in outs:
|
||||
patch_token, cls_token = out
|
||||
assert patch_token.shape == (1, 768, 14, 14)
|
||||
assert cls_token.shape == (1, 768)
|
||||
|
||||
# test pre_logits output
|
||||
out = model.extract_feat(torch.rand(1, 3, 224, 224), stage='pre_logits')
|
||||
assert out.shape == (1, 1024)
|
||||
|
||||
# test extract_feats
|
||||
multi_imgs = [torch.rand(1, 3, 224, 224) for _ in range(3)]
|
||||
outs = model.extract_feats(multi_imgs)
|
||||
for outs_per_img in outs:
|
||||
for out in outs_per_img:
|
||||
patch_token, cls_token = out
|
||||
assert patch_token.shape == (1, 768, 14, 14)
|
||||
assert cls_token.shape == (1, 768)
|
||||
|
||||
outs = model.extract_feats(multi_imgs, stage='pre_logits')
|
||||
for out_per_img in outs:
|
||||
assert out_per_img.shape == (1, 1024)
|
||||
|
|
|
@ -4,36 +4,53 @@ from unittest.mock import patch
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcls.models.heads import (ClsHead, DeiTClsHead, LinearClsHead,
|
||||
MultiLabelClsHead, MultiLabelLinearClsHead,
|
||||
StackedLinearClsHead, VisionTransformerClsHead)
|
||||
from mmcls.models.heads import (ClsHead, ConformerHead, DeiTClsHead,
|
||||
LinearClsHead, MultiLabelClsHead,
|
||||
MultiLabelLinearClsHead, StackedLinearClsHead,
|
||||
VisionTransformerClsHead)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('feat', [torch.rand(4, 3), (torch.rand(4, 3), )])
|
||||
@pytest.mark.parametrize('feat', [torch.rand(4, 10), (torch.rand(4, 10), )])
|
||||
def test_cls_head(feat):
|
||||
fake_gt_label = torch.randint(0, 10, (4, ))
|
||||
|
||||
# test ClsHead with cal_acc=False
|
||||
head = ClsHead()
|
||||
fake_gt_label = torch.randint(0, 2, (4, ))
|
||||
|
||||
losses = head.forward_train(feat, fake_gt_label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
||||
# test ClsHead with cal_acc=True
|
||||
# test forward_train with cal_acc=True
|
||||
head = ClsHead(cal_acc=True)
|
||||
feat = torch.rand(4, 3)
|
||||
fake_gt_label = torch.randint(0, 2, (4, ))
|
||||
losses = head.forward_train(feat, fake_gt_label)
|
||||
assert losses['loss'].item() > 0
|
||||
assert 'accuracy' in losses
|
||||
|
||||
# test forward_train with cal_acc=False
|
||||
head = ClsHead()
|
||||
losses = head.forward_train(feat, fake_gt_label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
||||
# test ClsHead with weight
|
||||
# test forward_train with weight
|
||||
weight = torch.tensor([0.5, 0.5, 0.5, 0.5])
|
||||
|
||||
losses_ = head.forward_train(feat, fake_gt_label)
|
||||
losses = head.forward_train(feat, fake_gt_label, weight=weight)
|
||||
assert losses['loss'].item() == losses_['loss'].item() * 0.5
|
||||
|
||||
# test simple_test with post_process
|
||||
pred = head.simple_test(feat)
|
||||
assert isinstance(pred, list) and len(pred) == 4
|
||||
with patch('torch.onnx.is_in_onnx_export', return_value=True):
|
||||
pred = head.simple_test(feat)
|
||||
assert pred.shape == (4, 10)
|
||||
|
||||
# test simple_test without post_process
|
||||
pred = head.simple_test(feat, post_process=False)
|
||||
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
|
||||
logits = head.simple_test(feat, softmax=False, post_process=False)
|
||||
torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1))
|
||||
|
||||
# test pre_logits
|
||||
features = head.pre_logits(feat)
|
||||
if isinstance(feat, tuple):
|
||||
torch.testing.assert_allclose(features, feat[0])
|
||||
else:
|
||||
torch.testing.assert_allclose(features, feat)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('feat', [torch.rand(4, 3), (torch.rand(4, 3), )])
|
||||
def test_linear_head(feat):
|
||||
|
@ -50,35 +67,85 @@ def test_linear_head(feat):
|
|||
head.init_weights()
|
||||
assert abs(head.fc.weight).sum() > 0
|
||||
|
||||
# test simple_test
|
||||
head = LinearClsHead(10, 3)
|
||||
# test simple_test with post_process
|
||||
pred = head.simple_test(feat)
|
||||
assert isinstance(pred, list) and len(pred) == 4
|
||||
|
||||
with patch('torch.onnx.is_in_onnx_export', return_value=True):
|
||||
head = LinearClsHead(10, 3)
|
||||
pred = head.simple_test(feat)
|
||||
assert pred.shape == (4, 10)
|
||||
|
||||
# test simple_test without post_process
|
||||
pred = head.simple_test(feat, post_process=False)
|
||||
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
|
||||
logits = head.simple_test(feat, softmax=False, post_process=False)
|
||||
torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1))
|
||||
|
||||
@pytest.mark.parametrize('feat', [torch.rand(4, 3), (torch.rand(4, 3), )])
|
||||
# test pre_logits
|
||||
features = head.pre_logits(feat)
|
||||
if isinstance(feat, tuple):
|
||||
torch.testing.assert_allclose(features, feat[0])
|
||||
else:
|
||||
torch.testing.assert_allclose(features, feat)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('feat', [torch.rand(4, 10), (torch.rand(4, 10), )])
|
||||
def test_multilabel_head(feat):
|
||||
head = MultiLabelClsHead()
|
||||
fake_gt_label = torch.randint(0, 2, (4, 3))
|
||||
fake_gt_label = torch.randint(0, 2, (4, 10))
|
||||
|
||||
losses = head.forward_train(feat, fake_gt_label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
||||
# test simple_test with post_process
|
||||
pred = head.simple_test(feat)
|
||||
assert isinstance(pred, list) and len(pred) == 4
|
||||
with patch('torch.onnx.is_in_onnx_export', return_value=True):
|
||||
pred = head.simple_test(feat)
|
||||
assert pred.shape == (4, 10)
|
||||
|
||||
# test simple_test without post_process
|
||||
pred = head.simple_test(feat, post_process=False)
|
||||
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
|
||||
logits = head.simple_test(feat, sigmoid=False, post_process=False)
|
||||
torch.testing.assert_allclose(pred, torch.sigmoid(logits))
|
||||
|
||||
# test pre_logits
|
||||
features = head.pre_logits(feat)
|
||||
if isinstance(feat, tuple):
|
||||
torch.testing.assert_allclose(features, feat[0])
|
||||
else:
|
||||
torch.testing.assert_allclose(features, feat)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('feat', [torch.rand(4, 5), (torch.rand(4, 5), )])
|
||||
def test_multilabel_linear_head(feat):
|
||||
head = MultiLabelLinearClsHead(3, 5)
|
||||
fake_gt_label = torch.randint(0, 2, (4, 3))
|
||||
head = MultiLabelLinearClsHead(10, 5)
|
||||
fake_gt_label = torch.randint(0, 2, (4, 10))
|
||||
|
||||
head.init_weights()
|
||||
losses = head.forward_train(feat, fake_gt_label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
||||
# test simple_test with post_process
|
||||
pred = head.simple_test(feat)
|
||||
assert isinstance(pred, list) and len(pred) == 4
|
||||
with patch('torch.onnx.is_in_onnx_export', return_value=True):
|
||||
pred = head.simple_test(feat)
|
||||
assert pred.shape == (4, 10)
|
||||
|
||||
# test simple_test without post_process
|
||||
pred = head.simple_test(feat, post_process=False)
|
||||
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
|
||||
logits = head.simple_test(feat, sigmoid=False, post_process=False)
|
||||
torch.testing.assert_allclose(pred, torch.sigmoid(logits))
|
||||
|
||||
# test pre_logits
|
||||
features = head.pre_logits(feat)
|
||||
if isinstance(feat, tuple):
|
||||
torch.testing.assert_allclose(features, feat[0])
|
||||
else:
|
||||
torch.testing.assert_allclose(features, feat)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('feat', [torch.rand(4, 5), (torch.rand(4, 5), )])
|
||||
def test_stacked_linear_cls_head(feat):
|
||||
|
@ -93,20 +160,28 @@ def test_stacked_linear_cls_head(feat):
|
|||
|
||||
# test forward with default setting
|
||||
head = StackedLinearClsHead(
|
||||
num_classes=3, in_channels=5, mid_channels=[10])
|
||||
num_classes=10, in_channels=5, mid_channels=[20])
|
||||
head.init_weights()
|
||||
|
||||
losses = head.forward_train(feat, fake_gt_label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
||||
# test simple test
|
||||
# test simple_test with post_process
|
||||
pred = head.simple_test(feat)
|
||||
assert len(pred) == 4
|
||||
|
||||
# test simple test in tracing
|
||||
assert isinstance(pred, list) and len(pred) == 4
|
||||
with patch('torch.onnx.is_in_onnx_export', return_value=True):
|
||||
pred = head.simple_test(feat)
|
||||
assert pred.shape == torch.Size((4, 3))
|
||||
assert pred.shape == (4, 10)
|
||||
|
||||
# test simple_test without post_process
|
||||
pred = head.simple_test(feat, post_process=False)
|
||||
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
|
||||
logits = head.simple_test(feat, softmax=False, post_process=False)
|
||||
torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1))
|
||||
|
||||
# test pre_logits
|
||||
features = head.pre_logits(feat)
|
||||
assert features.shape == (4, 20)
|
||||
|
||||
# test forward with full function
|
||||
head = StackedLinearClsHead(
|
||||
|
@ -144,21 +219,56 @@ def test_vit_head():
|
|||
head.init_weights()
|
||||
assert abs(head.layers.pre_logits.weight).sum() > 0
|
||||
|
||||
# test simple_test
|
||||
head = VisionTransformerClsHead(10, 100, hidden_dim=20)
|
||||
# test simple_test with post_process
|
||||
pred = head.simple_test(fake_features)
|
||||
assert isinstance(pred, list) and len(pred) == 4
|
||||
|
||||
with patch('torch.onnx.is_in_onnx_export', return_value=True):
|
||||
head = VisionTransformerClsHead(10, 100, hidden_dim=20)
|
||||
pred = head.simple_test(fake_features)
|
||||
assert pred.shape == (4, 10)
|
||||
|
||||
# test simple_test without post_process
|
||||
pred = head.simple_test(fake_features, post_process=False)
|
||||
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
|
||||
logits = head.simple_test(fake_features, softmax=False, post_process=False)
|
||||
torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1))
|
||||
|
||||
# test pre_logits
|
||||
features = head.pre_logits(fake_features)
|
||||
assert features.shape == (4, 20)
|
||||
|
||||
# test assertion
|
||||
with pytest.raises(ValueError):
|
||||
VisionTransformerClsHead(-1, 100)
|
||||
|
||||
|
||||
def test_conformer_head():
|
||||
fake_features = ([torch.rand(4, 64), torch.rand(4, 96)], )
|
||||
fake_gt_label = torch.randint(0, 10, (4, ))
|
||||
|
||||
# test conformer head forward
|
||||
head = ConformerHead(num_classes=10, in_channels=[64, 96])
|
||||
losses = head.forward_train(fake_features, fake_gt_label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
||||
# test simple_test with post_process
|
||||
pred = head.simple_test(fake_features)
|
||||
assert isinstance(pred, list) and len(pred) == 4
|
||||
with patch('torch.onnx.is_in_onnx_export', return_value=True):
|
||||
pred = head.simple_test(fake_features)
|
||||
assert pred.shape == (4, 10)
|
||||
|
||||
# test simple_test without post_process
|
||||
pred = head.simple_test(fake_features, post_process=False)
|
||||
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
|
||||
logits = head.simple_test(fake_features, softmax=False, post_process=False)
|
||||
torch.testing.assert_allclose(pred, torch.softmax(sum(logits), dim=1))
|
||||
|
||||
# test pre_logits
|
||||
features = head.pre_logits(fake_features)
|
||||
assert features is fake_features[0]
|
||||
|
||||
|
||||
def test_deit_head():
|
||||
fake_features = ([
|
||||
torch.rand(4, 7, 7, 16),
|
||||
|
@ -185,16 +295,25 @@ def test_deit_head():
|
|||
head.init_weights()
|
||||
assert abs(head.layers.pre_logits.weight).sum() > 0
|
||||
|
||||
# test simple_test
|
||||
head = DeiTClsHead(10, 100, hidden_dim=20)
|
||||
# test simple_test with post_process
|
||||
pred = head.simple_test(fake_features)
|
||||
assert isinstance(pred, list) and len(pred) == 4
|
||||
|
||||
with patch('torch.onnx.is_in_onnx_export', return_value=True):
|
||||
head = DeiTClsHead(10, 100, hidden_dim=20)
|
||||
pred = head.simple_test(fake_features)
|
||||
assert pred.shape == (4, 10)
|
||||
|
||||
# test simple_test without post_process
|
||||
pred = head.simple_test(fake_features, post_process=False)
|
||||
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
|
||||
logits = head.simple_test(fake_features, softmax=False, post_process=False)
|
||||
torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1))
|
||||
|
||||
# test pre_logits
|
||||
cls_token, dist_token = head.pre_logits(fake_features)
|
||||
assert cls_token.shape == (4, 20)
|
||||
assert dist_token.shape == (4, 20)
|
||||
|
||||
# test assertion
|
||||
with pytest.raises(ValueError):
|
||||
DeiTClsHead(-1, 100)
|
||||
|
|