[Enhance] Enhance feature extraction function. (#593)

* Fix MobileNet V3 configs

* Refactor to support more powerful feature extraction.

* Add unit tests

* Fix unit test

* Improve according to comments

* Update checkpoints path

* Fix unit tests

* Add docstring of `simple_test`

* Add docstring of `extract_feat`

* Update model zoo
Ma Zerun 2021-12-17 15:55:02 +08:00 committed by GitHub
parent f9a2b04cee
commit 643fb192cd
18 changed files with 710 additions and 150 deletions

View File

@ -11,4 +11,6 @@ model = dict(
dropout_rate=0.2,
act_cfg=dict(type='HSwish'),
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
init_cfg=dict(
type='Normal', layer='Linear', mean=0., std=0.01, bias=0.),
topk=(1, 5)))

View File

@ -11,4 +11,6 @@ model = dict(
dropout_rate=0.2,
act_cfg=dict(type='HSwish'),
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
init_cfg=dict(
type='Normal', layer='Linear', mean=0., std=0.01, bias=0.),
topk=(1, 5)))

View File

@ -34,11 +34,11 @@ The pre-trained models are converted from the [official repo](https://github.com
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
| DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) |
| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth) |
| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211216-c429839a.pth) |
| DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) |
| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth) |
| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211216-4de1d725.pth) |
| DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) |
| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth) |
| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211216-42891296.pth) |
*Models with \* are converted from other repos.*
@ -51,7 +51,7 @@ The fine-tuned models are converted from the [official repo](https://github.com/
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
| DeiT-base 384px\* | 86.86 | 49.37 | 83.04 | 96.31 | [config](configs/deit/deit-base_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth) |
| DeiT-base distilled 384px\* | 86.86 | 49.37 | 85.55 | 97.35 | [config](configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211124-91e88933.pth) |
| DeiT-base distilled 384px\* | 86.86 | 49.37 | 85.55 | 97.35 | [config](configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211216-e48d6000.pth) |
*Models with \* are converted from other repos.*

View File

@ -40,7 +40,7 @@ Models:
Top 1 Accuracy: 74.51
Top 5 Accuracy: 91.90
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211216-c429839a.pth
Converted From:
Weights: https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L108
@ -72,7 +72,7 @@ Models:
Top 1 Accuracy: 81.17
Top 5 Accuracy: 95.40
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211216-4de1d725.pth
Converted From:
Weights: https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L123
@ -104,7 +104,7 @@ Models:
Top 1 Accuracy: 83.33
Top 5 Accuracy: 96.49
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211216-42891296.pth
Converted From:
Weights: https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L138
@ -136,7 +136,7 @@ Models:
Top 1 Accuracy: 85.55
Top 5 Accuracy: 97.35
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211124-91e88933.pth
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211216-e48d6000.pth
Converted From:
Weights: https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L168

View File

@ -17,7 +17,7 @@
# - modify: RandomErasing use RE-M instead of RE-0
_base_ = [
'../_base_/models/mobilenet-v3-small_8xb32_in1k.py',
'../_base_/models/mobilenet_v3_small_imagenet.py',
'../_base_/datasets/imagenet_bs32_pil_resize.py',
'../_base_/default_runtime.py'
]

View File

@ -63,16 +63,18 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
| T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth) | [log]()|
| Mixer-B/16\* | 59.88 | 12.61 | 76.68 | 92.25 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth) | [log]()|
| Mixer-L/16\* | 208.2 | 44.57 | 72.34 | 88.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth) | [log]()|
| DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) | [log]()|
| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth) | [log]()|
| DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) | [log]()|
| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth) | [log]()|
| DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) | [log]()|
| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth) | [log]()|
| Conformer-tiny-p16\* | 23.52 | 4.90 | 81.31 | 95.60 | [config](configs/conformer/conformer-tiny-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-tiny-p16_3rdparty_8xb128_in1k_20211206-f6860372.pth) | [log]()|
| Conformer-small-p32 | 38.85 | 7.09 | 81.96 | 96.02 | [config](configs/conformer/conformer-small-p32_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p32_8xb128_in1k_20211206-947a0816.pth) | [log]()|
| Conformer-small-p16\* | 37.67 | 10.31 | 83.32 | 96.46 | [config](configs/conformer/conformer-small-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p16_3rdparty_8xb128_in1k_20211206-3065dcf5.pth) | [log]()|
| Conformer-base-p16\* | 83.29 | 22.89 | 83.82 | 96.59 | [config](configs/conformer/conformer-base-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-base-p16_3rdparty_8xb128_in1k_20211206-bfdf8637.pth) | [log]()|
| DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) | [log]()|
| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211216-c429839a.pth) | [log]()|
| DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) | [log]()|
| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211216-4de1d725.pth) | [log]()|
| DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) | [log]()|
| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211216-42891296.pth) | [log]()|
| DeiT-base 384px\* | 86.86 | 49.37 | 83.04 | 96.31 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth) | [log]()|
| DeiT-base distilled 384px\* | 86.86 | 49.37 | 85.55 | 97.35 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211216-e48d6000.pth) | [log]()|
| Conformer-tiny-p16\* | 23.52 | 4.90 | 81.31 | 95.60 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-tiny-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-tiny-p16_3rdparty_8xb128_in1k_20211206-f6860372.pth) | [log]()|
| Conformer-small-p32 | 38.85 | 7.09 | 81.96 | 96.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-small-p32_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p32_8xb128_in1k_20211206-947a0816.pth) | [log]()|
| Conformer-small-p16\* | 37.67 | 10.31 | 83.32 | 96.46 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-small-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p16_3rdparty_8xb128_in1k_20211206-3065dcf5.pth) | [log]()|
| Conformer-base-p16\* | 83.29 | 22.89 | 83.82 | 96.59 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-base-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-base-p16_3rdparty_8xb128_in1k_20211206-bfdf8637.pth) | [log]()|
Models with * are converted from other repos, others are trained by ourselves.

View File

@ -2,6 +2,7 @@
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from typing import Sequence
import mmcv
import torch
@ -35,13 +36,14 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
return hasattr(self, 'head') and self.head is not None
@abstractmethod
def extract_feat(self, imgs):
def extract_feat(self, imgs, stage=None):
pass
def extract_feats(self, imgs):
assert isinstance(imgs, list)
def extract_feats(self, imgs, stage=None):
assert isinstance(imgs, Sequence)
kwargs = {} if stage is None else {'stage': stage}
for img in imgs:
yield self.extract_feat(img)
yield self.extract_feat(img, **kwargs)
@abstractmethod
def forward_train(self, imgs, **kwargs):
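
A minimal sketch of driving the updated `extract_feats`, assuming a classifier built from the stock ResNet-18 config (the same config used in the `extract_feat` docstring examples later in this commit). The method now accepts any `Sequence` of image batches and forwards the optional `stage` to `extract_feat`, yielding one output per batch.

    import torch
    from mmcv import Config
    from mmcls.models import build_classifier

    # Assumed setup: the stock ResNet-18 ImageNet config; any classifier works.
    cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model
    model = build_classifier(cfg)

    # `extract_feats` is a generator: one result per input batch.
    batches = [torch.rand(1, 3, 224, 224) for _ in range(3)]
    for feat in model.extract_feats(batches, stage='pre_logits'):
        print(feat.shape)  # torch.Size([1, 512]) for ResNet-18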

View File

@ -3,6 +3,7 @@ import copy
import warnings
from ..builder import CLASSIFIERS, build_backbone, build_head, build_neck
from ..heads import MultiLabelClsHead
from ..utils.augment import Augments
from .base import BaseClassifier
@ -70,8 +71,74 @@ class ImageClassifier(BaseClassifier):
cfg['prob'] = cutmix_prob
self.augments = Augments(cfg)
def extract_feat(self, img):
"""Directly extract features from the backbone + neck."""
def extract_feat(self, img, stage='neck'):
"""Directly extract features from the specified stage.
Args:
img (Tensor): The input images. The shape of it should be
``(num_samples, num_channels, *img_shape)``.
stage (str): Which stage to output the feature. Choose from
"backbone", "neck" and "pre_logits". Defaults to "neck".
Returns:
tuple | Tensor: The output of specified stage.
The output depends on detailed implementation. In general, the
output of backbone and neck is a tuple and the output of
pre_logits is a tensor.
Examples:
1. Backbone output
>>> import torch
>>> from mmcv import Config
>>> from mmcls.models import build_classifier
>>>
>>> cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model
>>> cfg.backbone.out_indices = (0, 1, 2, 3) # Output multi-scale feature maps
>>> model = build_classifier(cfg)
>>> outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='backbone')
>>> for out in outs:
... print(out.shape)
torch.Size([1, 64, 56, 56])
torch.Size([1, 128, 28, 28])
torch.Size([1, 256, 14, 14])
torch.Size([1, 512, 7, 7])
2. Neck output
>>> import torch
>>> from mmcv import Config
>>> from mmcls.models import build_classifier
>>>
>>> cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model
>>> cfg.backbone.out_indices = (0, 1, 2, 3) # Output multi-scale feature maps
>>> model = build_classifier(cfg)
>>>
>>> outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='neck')
>>> for out in outs:
... print(out.shape)
torch.Size([1, 64])
torch.Size([1, 128])
torch.Size([1, 256])
torch.Size([1, 512])
3. Pre-logits output (without the final linear classifier head)
>>> import torch
>>> from mmcv import Config
>>> from mmcls.models import build_classifier
>>>
>>> cfg = Config.fromfile('configs/vision_transformer/vit-base-p16_pt-64xb64_in1k-224.py').model
>>> model = build_classifier(cfg)
>>>
>>> out = model.extract_feat(torch.rand(1, 3, 224, 224), stage='pre_logits')
>>> print(out.shape) # The hidden dims in head is 3072
torch.Size([1, 3072])
""" # noqa: E501
assert stage in ['backbone', 'neck', 'pre_logits'], \
(f'Invalid output stage "{stage}", please choose from "backbone", '
'"neck" and "pre_logits"')
x = self.backbone(img)
if self.return_tuple:
if not isinstance(x, tuple):
@ -83,8 +150,16 @@ class ImageClassifier(BaseClassifier):
else:
if isinstance(x, tuple):
x = x[-1]
if stage == 'backbone':
return x
if self.with_neck:
x = self.neck(x)
if stage == 'neck':
return x
if self.with_head and hasattr(self.head, 'pre_logits'):
x = self.head.pre_logits(x)
return x
def forward_train(self, img, gt_label, **kwargs):
@ -122,12 +197,16 @@ class ImageClassifier(BaseClassifier):
return losses
def simple_test(self, img, img_metas):
def simple_test(self, img, img_metas=None, **kwargs):
"""Test without augmentation."""
x = self.extract_feat(img)
try:
res = self.head.simple_test(x)
if isinstance(self.head, MultiLabelClsHead):
assert 'softmax' not in kwargs, (
'Please use `sigmoid` instead of `softmax` '
'in multi-label tasks.')
res = self.head.simple_test(x, **kwargs)
except TypeError as e:
if 'not tuple' in str(e) and self.return_tuple:
return TypeError(
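
The extra keyword arguments accepted by `simple_test` are forwarded straight to the head, which is what the new unit tests below rely on. A hedged sketch, again assuming the stock ResNet-18 config:

    import torch
    from mmcv import Config
    from mmcls.models import build_classifier

    cfg = Config.fromfile('configs/resnet/resnet18_8xb32_in1k.py').model
    model = build_classifier(cfg)
    imgs = torch.rand(16, 3, 224, 224)  # placeholder batch

    # Default behaviour: softmax scores post-processed into one list per sample.
    probs = model.simple_test(imgs)

    # Raw logits as a single (16, 1000) tensor: skip both softmax and the list
    # conversion done by post-processing.
    logits = model.simple_test(imgs, softmax=False, post_process=False)

    # Multi-label heads reject `softmax`; pass `sigmoid` instead.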

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch
import torch.nn.functional as F
@ -62,14 +64,49 @@ class ClsHead(BaseHead):
losses = self.loss(cls_score, gt_label, **kwargs)
return losses
def simple_test(self, cls_score):
"""Test without augmentation."""
def pre_logits(self, x):
if isinstance(x, tuple):
x = x[-1]
warnings.warn(
'The input of ClsHead should be already logits. '
'Please modify the backbone if you want to get pre-logits feature.'
)
return x
def simple_test(self, cls_score, softmax=True, post_process=True):
"""Inference without augmentation.
Args:
cls_score (tuple[Tensor]): The input classification score logits.
Multi-stage inputs are acceptable but only the last stage will
be used to classify. The shape of every item should be
``(num_samples, num_classes)``.
softmax (bool): Whether to softmax the classification score.
post_process (bool): Whether to do post processing of the
inference results. It will convert the output to a list.
Returns:
Tensor | list: The inference results.
- If no post processing, the output is a tensor with shape
``(num_samples, num_classes)``.
- If post processing, the output is a multi-dimensional list of
float and the dimensions are ``(num_samples, num_classes)``.
"""
if isinstance(cls_score, tuple):
cls_score = cls_score[-1]
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
return self.post_process(pred)
if softmax:
pred = (
F.softmax(cls_score, dim=1) if cls_score is not None else None)
else:
pred = cls_score
if post_process:
return self.post_process(pred)
else:
return pred
def post_process(self, pred):
on_trace = is_tracing()
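
A short sketch of the new `ClsHead` inference flags, mirroring the unit tests added later in this commit; `ClsHead` has no layers of its own, so the input is expected to be logits already:

    import torch
    from mmcls.models.heads import ClsHead

    head = ClsHead()
    logits = torch.rand(4, 10)  # ClsHead expects already-computed logits

    pred = head.simple_test(logits)                        # list of 4 score lists
    scores = head.simple_test(logits, post_process=False)  # (4, 10) softmax tensor
    raw = head.simple_test(logits, softmax=False, post_process=False)
    assert torch.allclose(scores, torch.softmax(raw, dim=1))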

View File

@ -16,7 +16,7 @@ class ConformerHead(ClsHead):
category.
in_channels (int): Number of channels in the input feature map.
init_cfg (dict | optional): The extra init config of layers.
Defaults to use dict(type='Normal', layer='Linear', std=0.01).
Defaults to use ``dict(type='Normal', layer='Linear', std=0.01)``.
"""
def __init__(
@ -55,25 +55,54 @@ class ConformerHead(ClsHead):
else:
self.apply(self._init_weights)
def simple_test(self, x):
"""Test without augmentation."""
def pre_logits(self, x):
if isinstance(x, tuple):
x = x[-1]
assert isinstance(x,
list) # There are two outputs in the Conformer model
return x
def simple_test(self, x, softmax=True, post_process=True):
"""Inference without augmentation.
Args:
x (tuple[tuple[tensor, tensor]]): The input features.
Multi-stage inputs are acceptable but only the last stage will
be used to classify. Every item should be a tuple which
includes convolution features and transformer features. The
shape of them should be ``(num_samples, in_channels[0])`` and
``(num_samples, in_channels[1])``.
softmax (bool): Whether to softmax the classification score.
post_process (bool): Whether to do post processing of the
inference results. It will convert the output to a list.
Returns:
Tensor | list: The inference results.
- If no post processing, the output is a tensor with shape
``(num_samples, num_classes)``.
- If post processing, the output is a multi-dimensional list of
float and the dimensions are ``(num_samples, num_classes)``.
"""
x = self.pre_logits(x)
# There are two outputs in the Conformer model
assert len(x) == 2
conv_cls_score = self.conv_cls_head(x[0])
tran_cls_score = self.trans_cls_head(x[1])
cls_score = conv_cls_score + tran_cls_score
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
return self.post_process(pred)
if softmax:
cls_score = conv_cls_score + tran_cls_score
pred = (
F.softmax(cls_score, dim=1) if cls_score is not None else None)
if post_process:
pred = self.post_process(pred)
else:
pred = [conv_cls_score, tran_cls_score]
if post_process:
pred = list(map(self.post_process, pred))
return pred
def forward_train(self, x, gt_label):
if isinstance(x, tuple):
x = x[-1]
x = self.pre_logits(x)
assert isinstance(x, list) and len(x) == 2, \
'There should be two outputs in the Conformer model'
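
With softmax disabled, `ConformerHead.simple_test` now returns the two branch scores separately instead of their softmaxed sum. A minimal sketch using the same fake features as the new `test_conformer_head` below:

    import torch
    from mmcls.models.heads import ConformerHead

    head = ConformerHead(num_classes=10, in_channels=[64, 96])
    # Conformer backbones yield a conv-branch and a transformer-branch feature.
    feats = ([torch.rand(4, 64), torch.rand(4, 96)], )

    probs = head.simple_test(feats, post_process=False)  # softmax of the summed scores
    conv_score, trans_score = head.simple_test(
        feats, softmax=False, post_process=False)
    assert torch.allclose(probs, torch.softmax(conv_score + trans_score, dim=1))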

View File

@ -12,25 +12,67 @@ class DeiTClsHead(VisionTransformerClsHead):
def __init__(self, *args, **kwargs):
super(DeiTClsHead, self).__init__(*args, **kwargs)
self.head_dist = nn.Linear(self.in_channels, self.num_classes)
if self.hidden_dim is None:
head_dist = nn.Linear(self.in_channels, self.num_classes)
else:
head_dist = nn.Linear(self.hidden_dim, self.num_classes)
self.layers.add_module('head_dist', head_dist)
def simple_test(self, x):
"""Test without augmentation."""
x = x[-1]
assert isinstance(x, list) and len(x) == 3
def pre_logits(self, x):
if isinstance(x, tuple):
x = x[-1]
_, cls_token, dist_token = x
cls_score = (self.layers(cls_token) + self.head_dist(dist_token)) / 2
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
return self.post_process(pred)
if self.hidden_dim is None:
return cls_token, dist_token
else:
cls_token = self.layers.act(self.layers.pre_logits(cls_token))
dist_token = self.layers.act(self.layers.pre_logits(dist_token))
return cls_token, dist_token
def simple_test(self, x, softmax=True, post_process=True):
"""Inference without augmentation.
Args:
x (tuple[tuple[tensor, tensor, tensor]]): The input features.
Multi-stage inputs are acceptable but only the last stage will
be used to classify. Every item should be a tuple which
includes patch token, cls token and dist token. The cls token
and dist token will be used to classify and the shape of them
should be ``(num_samples, in_channels)``.
softmax (bool): Whether to softmax the classification score.
post_process (bool): Whether to do post processing of the
inference results. It will convert the output to a list.
Returns:
Tensor | list: The inference results.
- If no post processing, the output is a tensor with shape
``(num_samples, num_classes)``.
- If post processing, the output is a multi-dimensional list of
float and the dimensions are ``(num_samples, num_classes)``.
"""
cls_token, dist_token = self.pre_logits(x)
cls_score = (self.layers.head(cls_token) +
self.layers.head_dist(dist_token)) / 2
if softmax:
pred = F.softmax(
cls_score, dim=1) if cls_score is not None else None
else:
pred = cls_score
if post_process:
return self.post_process(pred)
else:
return pred
def forward_train(self, x, gt_label):
logger = get_root_logger()
logger.warning("MMClassification doesn't support to train the "
'distilled version DeiT.')
x = x[-1]
assert isinstance(x, list) and len(x) == 3
_, cls_token, dist_token = x
cls_score = (self.layers(cls_token) + self.head_dist(dist_token)) / 2
cls_token, dist_token = self.pre_logits(x)
cls_score = (self.layers.head(cls_token) +
self.layers.head_dist(dist_token)) / 2
losses = self.loss(cls_score, gt_label)
return losses
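
The DeiT head now routes both tokens through `pre_logits` before the two classifiers. A minimal sketch; the patch-token shape is an arbitrary placeholder, and the `hidden_dim` of 20 matches the unit test below:

    import torch
    from mmcls.models.heads import DeiTClsHead

    head = DeiTClsHead(num_classes=10, in_channels=100, hidden_dim=20)
    # DeiT backbones output (patch_token, cls_token, dist_token).
    feats = ([torch.rand(4, 7, 7, 16), torch.rand(4, 100), torch.rand(4, 100)], )

    cls_token, dist_token = head.pre_logits(feats)  # both (4, 20) after the hidden layer
    pred = head.simple_test(feats)                  # averaged cls/dist scores, list of 4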

View File

@ -35,20 +35,47 @@ class LinearClsHead(ClsHead):
self.fc = nn.Linear(self.in_channels, self.num_classes)
def simple_test(self, x):
"""Test without augmentation."""
def pre_logits(self, x):
if isinstance(x, tuple):
x = x[-1]
cls_score = self.fc(x)
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
return x
return self.post_process(pred)
def simple_test(self, x, softmax=True, post_process=True):
"""Inference without augmentation.
Args:
x (tuple[Tensor]): The input features.
Multi-stage inputs are acceptable but only the last stage will
be used to classify. The shape of every item should be
``(num_samples, in_channels)``.
softmax (bool): Whether to softmax the classification score.
post_process (bool): Whether to do post processing of the
inference results. It will convert the output to a list.
Returns:
Tensor | list: The inference results.
- If no post processing, the output is a tensor with shape
``(num_samples, num_classes)``.
- If post processing, the output is a multi-dimensional list of
float and the dimensions are ``(num_samples, num_classes)``.
"""
x = self.pre_logits(x)
cls_score = self.fc(x)
if softmax:
pred = (
F.softmax(cls_score, dim=1) if cls_score is not None else None)
else:
pred = cls_score
if post_process:
return self.post_process(pred)
else:
return pred
def forward_train(self, x, gt_label, **kwargs):
if isinstance(x, tuple):
x = x[-1]
x = self.pre_logits(x)
cls_score = self.fc(x)
losses = self.loss(cls_score, gt_label, **kwargs)
return losses
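
For `LinearClsHead`, `pre_logits` simply unpacks the last stage of the input, and the new flags behave the same way as in `ClsHead`. A brief sketch using the constructor arguments from the tests:

    import torch
    from mmcls.models.heads import LinearClsHead

    head = LinearClsHead(num_classes=10, in_channels=3)
    feats = (torch.rand(4, 3), )            # tuple input: pre_logits picks feats[-1]

    assert head.pre_logits(feats) is feats[-1]
    logits = head.simple_test(feats, softmax=False, post_process=False)  # (4, 10)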

View File

@ -1,6 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from ..builder import HEADS, build_loss
from ..utils import is_tracing
@ -47,14 +46,50 @@ class MultiLabelClsHead(BaseHead):
losses = self.loss(cls_score, gt_label, **kwargs)
return losses
def simple_test(self, x):
def pre_logits(self, x):
if isinstance(x, tuple):
x = x[-1]
if isinstance(x, list):
x = sum(x) / float(len(x))
pred = F.sigmoid(x) if x is not None else None
return self.post_process(pred)
from mmcls.utils import get_root_logger
logger = get_root_logger()
logger.warning(
'The input of MultiLabelClsHead should be already logits. '
'Please modify the backbone if you want to get pre-logits feature.'
)
return x
def simple_test(self, x, sigmoid=True, post_process=True):
"""Inference without augmentation.
Args:
x (tuple[Tensor]): The input classification score logits.
Multi-stage inputs are acceptable but only the last stage will
be used to classify. The shape of every item should be
``(num_samples, num_classes)``.
sigmoid (bool): Whether to sigmoid the classification score.
post_process (bool): Whether to do post processing of the
inference results. It will convert the output to a list.
Returns:
Tensor | list: The inference results.
- If no post processing, the output is a tensor with shape
``(num_samples, num_classes)``.
- If post processing, the output is a multi-dimensional list of
float and the dimensions are ``(num_samples, num_classes)``.
"""
if isinstance(x, tuple):
x = x[-1]
if sigmoid:
pred = torch.sigmoid(x) if x is not None else None
else:
pred = x
if post_process:
return self.post_process(pred)
else:
return pred
def post_process(self, pred):
on_trace = is_tracing()
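
A short sketch of the `sigmoid` flag on `MultiLabelClsHead`, again assuming the input is already logits, as the new warning in `pre_logits` states:

    import torch
    from mmcls.models.heads import MultiLabelClsHead

    head = MultiLabelClsHead()
    logits = torch.rand(4, 10)

    scores = head.simple_test(logits, post_process=False)              # sigmoid scores
    raw = head.simple_test(logits, sigmoid=False, post_process=False)  # raw logits
    assert torch.allclose(scores, torch.sigmoid(raw))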

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import HEADS
from .multi_label_head import MultiLabelClsHead
@ -39,21 +39,47 @@ class MultiLabelLinearClsHead(MultiLabelClsHead):
self.fc = nn.Linear(self.in_channels, self.num_classes)
def forward_train(self, x, gt_label, **kwargs):
def pre_logits(self, x):
if isinstance(x, tuple):
x = x[-1]
return x
def forward_train(self, x, gt_label, **kwargs):
x = self.pre_logits(x)
gt_label = gt_label.type_as(x)
cls_score = self.fc(x)
losses = self.loss(cls_score, gt_label, **kwargs)
return losses
def simple_test(self, x):
"""Test without augmentation."""
if isinstance(x, tuple):
x = x[-1]
cls_score = self.fc(x)
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.sigmoid(cls_score) if cls_score is not None else None
def simple_test(self, x, sigmoid=True, post_process=True):
"""Inference without augmentation.
return self.post_process(pred)
Args:
x (tuple[Tensor]): The input features.
Multi-stage inputs are acceptable but only the last stage will
be used to classify. The shape of every item should be
``(num_samples, in_channels)``.
sigmoid (bool): Whether to sigmoid the classification score.
post_process (bool): Whether to do post processing of the
inference results. It will convert the output to a list.
Returns:
Tensor | list: The inference results.
- If no post processing, the output is a tensor with shape
``(num_samples, num_classes)``.
- If post processing, the output is a multi-dimensional list of
float and the dimensions are ``(num_samples, num_classes)``.
"""
x = self.pre_logits(x)
cls_score = self.fc(x)
if sigmoid:
pred = torch.sigmoid(cls_score) if cls_score is not None else None
else:
pred = cls_score
if post_process:
return self.post_process(pred)
else:
return pred
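
The linear multi-label head follows the same pattern, with `pre_logits` feeding the `fc` layer; sketched here with the shapes used in the updated tests:

    import torch
    from mmcls.models.heads import MultiLabelLinearClsHead

    head = MultiLabelLinearClsHead(num_classes=10, in_channels=5)
    feats = (torch.rand(4, 5), )

    scores = head.simple_test(feats, post_process=False)              # (4, 10) sigmoid scores
    raw = head.simple_test(feats, sigmoid=False, post_process=False)  # (4, 10) logits
    assert torch.allclose(scores, torch.sigmoid(raw))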

View File

@ -49,8 +49,7 @@ class StackedLinearClsHead(ClsHead):
"""Classifier head with several hidden fc layer and a output fc layer.
Args:
num_classes (int): Number of categories excluding the background
category.
num_classes (int): Number of categories.
in_channels (int): Number of channels in the input feature map.
mid_channels (Sequence): Number of channels in the hidden fc layers.
dropout_rate (float): Dropout rate after each hidden fc layer,
@ -89,9 +88,7 @@ class StackedLinearClsHead(ClsHead):
self._init_layers()
def _init_layers(self):
self.layers = ModuleList(
init_cfg=dict(
type='Normal', layer='Linear', mean=0., std=0.01, bias=0.))
self.layers = ModuleList()
in_channels = self.in_channels
for hidden_channels in self.mid_channels:
self.layers.append(
@ -114,24 +111,53 @@ class StackedLinearClsHead(ClsHead):
def init_weights(self):
self.layers.init_weights()
def simple_test(self, x):
"""Test without augmentation."""
def pre_logits(self, x):
if isinstance(x, tuple):
x = x[-1]
cls_score = x
for layer in self.layers:
cls_score = layer(cls_score)
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
for layer in self.layers[:-1]:
x = layer(x)
return x
return self.post_process(pred)
@property
def fc(self):
return self.layers[-1]
def simple_test(self, x, softmax=True, post_process=True):
"""Inference without augmentation.
Args:
x (tuple[Tensor]): The input features.
Multi-stage inputs are acceptable but only the last stage will
be used to classify. The shape of every item should be
``(num_samples, in_channels)``.
softmax (bool): Whether to softmax the classification score.
post_process (bool): Whether to do post processing of the
inference results. It will convert the output to a list.
Returns:
Tensor | list: The inference results.
- If no post processing, the output is a tensor with shape
``(num_samples, num_classes)``.
- If post processing, the output is a multi-dimensional list of
float and the dimensions are ``(num_samples, num_classes)``.
"""
x = self.pre_logits(x)
cls_score = self.fc(x)
if softmax:
pred = (
F.softmax(cls_score, dim=1) if cls_score is not None else None)
else:
pred = cls_score
if post_process:
return self.post_process(pred)
else:
return pred
def forward_train(self, x, gt_label, **kwargs):
if isinstance(x, tuple):
x = x[-1]
cls_score = x
for layer in self.layers:
cls_score = layer(cls_score)
x = self.pre_logits(x)
cls_score = self.fc(x)
losses = self.loss(cls_score, gt_label, **kwargs)
return losses
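
After the refactor, `pre_logits` runs every hidden block and the new `fc` property exposes the final linear block, so `simple_test` and `forward_train` share one path. A minimal sketch with the dimensions from the updated test:

    import torch
    from mmcls.models.heads import StackedLinearClsHead

    head = StackedLinearClsHead(num_classes=10, in_channels=5, mid_channels=[20])
    feats = (torch.rand(4, 5), )

    hidden = head.pre_logits(feats)   # all hidden blocks applied, shape (4, 20)
    scores = head.fc(hidden)          # `fc` is the final block, shape (4, 10)
    pred = head.simple_test(feats)    # softmax + post-processing, list of length 4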

View File

@ -68,20 +68,54 @@ class VisionTransformerClsHead(ClsHead):
std=math.sqrt(1 / self.layers.pre_logits.in_features))
nn.init.zeros_(self.layers.pre_logits.bias)
def simple_test(self, x):
"""Test without augmentation."""
x = x[-1]
def pre_logits(self, x):
if isinstance(x, tuple):
x = x[-1]
_, cls_token = x
cls_score = self.layers(cls_token)
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
if self.hidden_dim is None:
return cls_token
else:
x = self.layers.pre_logits(cls_token)
return self.layers.act(x)
return self.post_process(pred)
def simple_test(self, x, softmax=True, post_process=True):
"""Inference without augmentation.
Args:
x (tuple[tuple[tensor, tensor]]): The input features.
Multi-stage inputs are acceptable but only the last stage will
be used to classify. Every item should be a tuple which
includes patch token and cls token. The cls token will be used
to classify and the shape of it should be
``(num_samples, in_channels)``.
softmax (bool): Whether to softmax the classification score.
post_process (bool): Whether to do post processing of the
inference results. It will convert the output to a list.
Returns:
Tensor | list: The inference results.
- If no post processing, the output is a tensor with shape
``(num_samples, num_classes)``.
- If post processing, the output is a multi-dimensional list of
float and the dimensions are ``(num_samples, num_classes)``.
"""
x = self.pre_logits(x)
cls_score = self.layers.head(x)
if softmax:
pred = (
F.softmax(cls_score, dim=1) if cls_score is not None else None)
else:
pred = cls_score
if post_process:
return self.post_process(pred)
else:
return pred
def forward_train(self, x, gt_label, **kwargs):
x = x[-1]
_, cls_token = x
cls_score = self.layers(cls_token)
x = self.pre_logits(x)
cls_score = self.layers.head(x)
losses = self.loss(cls_score, gt_label, **kwargs)
return losses
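
For the ViT head, `pre_logits` returns either the raw cls token or the activated hidden projection, depending on `hidden_dim`. Sketched with the shapes from the unit test; the patch-token shape is a placeholder:

    import torch
    from mmcls.models.heads import VisionTransformerClsHead

    head = VisionTransformerClsHead(num_classes=10, in_channels=100, hidden_dim=20)
    # ViT backbones output (patch_token, cls_token).
    feats = ([torch.rand(4, 7, 7, 16), torch.rand(4, 100)], )

    hidden = head.pre_logits(feats)   # activated hidden projection, shape (4, 20)
    pred = head.simple_test(feats)    # class scores as a list of length 4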

View File

@ -73,6 +73,19 @@ def test_image_classifier():
pred = model(single_img, return_loss=False, img_metas=None)
assert isinstance(pred, list) and len(pred) == 1
pred = model.simple_test(imgs, softmax=False)
assert isinstance(pred, list) and len(pred) == 16
assert len(pred[0]) == 10
pred = model.simple_test(imgs, softmax=False, post_process=False)
assert isinstance(pred, torch.Tensor)
assert pred.shape == (16, 10)
soft_pred = model.simple_test(imgs, softmax=True, post_process=False)
assert isinstance(soft_pred, torch.Tensor)
assert soft_pred.shape == (16, 10)
torch.testing.assert_allclose(soft_pred, torch.softmax(pred, dim=1))
# test pretrained
# TODO remove deprecated pretrained
with pytest.warns(UserWarning):
@ -83,7 +96,7 @@ def test_image_classifier():
type='Pretrained', checkpoint='checkpoint')
# test show_result
img = np.random.random_integers(0, 255, (224, 224, 3)).astype(np.uint8)
img = np.random.randint(0, 256, (224, 224, 3)).astype(np.uint8)
result = dict(pred_class='cat', pred_label=0, pred_score=0.9)
with tempfile.TemporaryDirectory() as tmpdir:
@ -304,3 +317,88 @@ def test_image_classifier_return_tuple():
with pytest.warns(DeprecationWarning):
model.extract_feat(imgs)
def test_classifier_extract_feat():
model_cfg = ConfigDict(
type='ImageClassifier',
backbone=dict(
type='ResNet',
depth=18,
num_stages=4,
out_indices=(0, 1, 2, 3),
style='pytorch'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=512,
loss=dict(type='CrossEntropyLoss'),
topk=(1, 5),
))
model = CLASSIFIERS.build(model_cfg)
# test backbone output
outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='backbone')
assert outs[0].shape == (1, 64, 56, 56)
assert outs[1].shape == (1, 128, 28, 28)
assert outs[2].shape == (1, 256, 14, 14)
assert outs[3].shape == (1, 512, 7, 7)
# test neck output
outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='neck')
assert outs[0].shape == (1, 64)
assert outs[1].shape == (1, 128)
assert outs[2].shape == (1, 256)
assert outs[3].shape == (1, 512)
# test pre_logits output
out = model.extract_feat(torch.rand(1, 3, 224, 224), stage='pre_logits')
assert out.shape == (1, 512)
# test transformer style feature extraction
model_cfg = dict(
type='ImageClassifier',
backbone=dict(
type='VisionTransformer', arch='b', out_indices=[-3, -2, -1]),
neck=None,
head=dict(
type='VisionTransformerClsHead',
num_classes=1000,
in_channels=768,
hidden_dim=1024,
loss=dict(type='CrossEntropyLoss'),
))
model = CLASSIFIERS.build(model_cfg)
# test backbone output
outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='backbone')
for out in outs:
patch_token, cls_token = out
assert patch_token.shape == (1, 768, 14, 14)
assert cls_token.shape == (1, 768)
# test neck output (the same with backbone)
outs = model.extract_feat(torch.rand(1, 3, 224, 224), stage='neck')
for out in outs:
patch_token, cls_token = out
assert patch_token.shape == (1, 768, 14, 14)
assert cls_token.shape == (1, 768)
# test pre_logits output
out = model.extract_feat(torch.rand(1, 3, 224, 224), stage='pre_logits')
assert out.shape == (1, 1024)
# test extract_feats
multi_imgs = [torch.rand(1, 3, 224, 224) for _ in range(3)]
outs = model.extract_feats(multi_imgs)
for outs_per_img in outs:
for out in outs_per_img:
patch_token, cls_token = out
assert patch_token.shape == (1, 768, 14, 14)
assert cls_token.shape == (1, 768)
outs = model.extract_feats(multi_imgs, stage='pre_logits')
for out_per_img in outs:
assert out_per_img.shape == (1, 1024)

View File

@ -4,36 +4,53 @@ from unittest.mock import patch
import pytest
import torch
from mmcls.models.heads import (ClsHead, DeiTClsHead, LinearClsHead,
MultiLabelClsHead, MultiLabelLinearClsHead,
StackedLinearClsHead, VisionTransformerClsHead)
from mmcls.models.heads import (ClsHead, ConformerHead, DeiTClsHead,
LinearClsHead, MultiLabelClsHead,
MultiLabelLinearClsHead, StackedLinearClsHead,
VisionTransformerClsHead)
@pytest.mark.parametrize('feat', [torch.rand(4, 3), (torch.rand(4, 3), )])
@pytest.mark.parametrize('feat', [torch.rand(4, 10), (torch.rand(4, 10), )])
def test_cls_head(feat):
fake_gt_label = torch.randint(0, 10, (4, ))
# test ClsHead with cal_acc=False
head = ClsHead()
fake_gt_label = torch.randint(0, 2, (4, ))
losses = head.forward_train(feat, fake_gt_label)
assert losses['loss'].item() > 0
# test ClsHead with cal_acc=True
# test forward_train with cal_acc=True
head = ClsHead(cal_acc=True)
feat = torch.rand(4, 3)
fake_gt_label = torch.randint(0, 2, (4, ))
losses = head.forward_train(feat, fake_gt_label)
assert losses['loss'].item() > 0
assert 'accuracy' in losses
# test forward_train with cal_acc=False
head = ClsHead()
losses = head.forward_train(feat, fake_gt_label)
assert losses['loss'].item() > 0
# test ClsHead with weight
# test forward_train with weight
weight = torch.tensor([0.5, 0.5, 0.5, 0.5])
losses_ = head.forward_train(feat, fake_gt_label)
losses = head.forward_train(feat, fake_gt_label, weight=weight)
assert losses['loss'].item() == losses_['loss'].item() * 0.5
# test simple_test with post_process
pred = head.simple_test(feat)
assert isinstance(pred, list) and len(pred) == 4
with patch('torch.onnx.is_in_onnx_export', return_value=True):
pred = head.simple_test(feat)
assert pred.shape == (4, 10)
# test simple_test without post_process
pred = head.simple_test(feat, post_process=False)
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
logits = head.simple_test(feat, softmax=False, post_process=False)
torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1))
# test pre_logits
features = head.pre_logits(feat)
if isinstance(feat, tuple):
torch.testing.assert_allclose(features, feat[0])
else:
torch.testing.assert_allclose(features, feat)
@pytest.mark.parametrize('feat', [torch.rand(4, 3), (torch.rand(4, 3), )])
def test_linear_head(feat):
@ -50,35 +67,85 @@ def test_linear_head(feat):
head.init_weights()
assert abs(head.fc.weight).sum() > 0
# test simple_test
head = LinearClsHead(10, 3)
# test simple_test with post_process
pred = head.simple_test(feat)
assert isinstance(pred, list) and len(pred) == 4
with patch('torch.onnx.is_in_onnx_export', return_value=True):
head = LinearClsHead(10, 3)
pred = head.simple_test(feat)
assert pred.shape == (4, 10)
# test simple_test without post_process
pred = head.simple_test(feat, post_process=False)
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
logits = head.simple_test(feat, softmax=False, post_process=False)
torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1))
@pytest.mark.parametrize('feat', [torch.rand(4, 3), (torch.rand(4, 3), )])
# test pre_logits
features = head.pre_logits(feat)
if isinstance(feat, tuple):
torch.testing.assert_allclose(features, feat[0])
else:
torch.testing.assert_allclose(features, feat)
@pytest.mark.parametrize('feat', [torch.rand(4, 10), (torch.rand(4, 10), )])
def test_multilabel_head(feat):
head = MultiLabelClsHead()
fake_gt_label = torch.randint(0, 2, (4, 3))
fake_gt_label = torch.randint(0, 2, (4, 10))
losses = head.forward_train(feat, fake_gt_label)
assert losses['loss'].item() > 0
# test simple_test with post_process
pred = head.simple_test(feat)
assert isinstance(pred, list) and len(pred) == 4
with patch('torch.onnx.is_in_onnx_export', return_value=True):
pred = head.simple_test(feat)
assert pred.shape == (4, 10)
# test simple_test without post_process
pred = head.simple_test(feat, post_process=False)
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
logits = head.simple_test(feat, sigmoid=False, post_process=False)
torch.testing.assert_allclose(pred, torch.sigmoid(logits))
# test pre_logits
features = head.pre_logits(feat)
if isinstance(feat, tuple):
torch.testing.assert_allclose(features, feat[0])
else:
torch.testing.assert_allclose(features, feat)
@pytest.mark.parametrize('feat', [torch.rand(4, 5), (torch.rand(4, 5), )])
def test_multilabel_linear_head(feat):
head = MultiLabelLinearClsHead(3, 5)
fake_gt_label = torch.randint(0, 2, (4, 3))
head = MultiLabelLinearClsHead(10, 5)
fake_gt_label = torch.randint(0, 2, (4, 10))
head.init_weights()
losses = head.forward_train(feat, fake_gt_label)
assert losses['loss'].item() > 0
# test simple_test with post_process
pred = head.simple_test(feat)
assert isinstance(pred, list) and len(pred) == 4
with patch('torch.onnx.is_in_onnx_export', return_value=True):
pred = head.simple_test(feat)
assert pred.shape == (4, 10)
# test simple_test without post_process
pred = head.simple_test(feat, post_process=False)
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
logits = head.simple_test(feat, sigmoid=False, post_process=False)
torch.testing.assert_allclose(pred, torch.sigmoid(logits))
# test pre_logits
features = head.pre_logits(feat)
if isinstance(feat, tuple):
torch.testing.assert_allclose(features, feat[0])
else:
torch.testing.assert_allclose(features, feat)
@pytest.mark.parametrize('feat', [torch.rand(4, 5), (torch.rand(4, 5), )])
def test_stacked_linear_cls_head(feat):
@ -93,20 +160,28 @@ def test_stacked_linear_cls_head(feat):
# test forward with default setting
head = StackedLinearClsHead(
num_classes=3, in_channels=5, mid_channels=[10])
num_classes=10, in_channels=5, mid_channels=[20])
head.init_weights()
losses = head.forward_train(feat, fake_gt_label)
assert losses['loss'].item() > 0
# test simple test
# test simple_test with post_process
pred = head.simple_test(feat)
assert len(pred) == 4
# test simple test in tracing
assert isinstance(pred, list) and len(pred) == 4
with patch('torch.onnx.is_in_onnx_export', return_value=True):
pred = head.simple_test(feat)
assert pred.shape == torch.Size((4, 3))
assert pred.shape == (4, 10)
# test simple_test without post_process
pred = head.simple_test(feat, post_process=False)
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
logits = head.simple_test(feat, softmax=False, post_process=False)
torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1))
# test pre_logits
features = head.pre_logits(feat)
assert features.shape == (4, 20)
# test forward with full function
head = StackedLinearClsHead(
@ -144,21 +219,56 @@ def test_vit_head():
head.init_weights()
assert abs(head.layers.pre_logits.weight).sum() > 0
# test simple_test
head = VisionTransformerClsHead(10, 100, hidden_dim=20)
# test simple_test with post_process
pred = head.simple_test(fake_features)
assert isinstance(pred, list) and len(pred) == 4
with patch('torch.onnx.is_in_onnx_export', return_value=True):
head = VisionTransformerClsHead(10, 100, hidden_dim=20)
pred = head.simple_test(fake_features)
assert pred.shape == (4, 10)
# test simple_test without post_process
pred = head.simple_test(fake_features, post_process=False)
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
logits = head.simple_test(fake_features, softmax=False, post_process=False)
torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1))
# test pre_logits
features = head.pre_logits(fake_features)
assert features.shape == (4, 20)
# test assertion
with pytest.raises(ValueError):
VisionTransformerClsHead(-1, 100)
def test_conformer_head():
fake_features = ([torch.rand(4, 64), torch.rand(4, 96)], )
fake_gt_label = torch.randint(0, 10, (4, ))
# test conformer head forward
head = ConformerHead(num_classes=10, in_channels=[64, 96])
losses = head.forward_train(fake_features, fake_gt_label)
assert losses['loss'].item() > 0
# test simple_test with post_process
pred = head.simple_test(fake_features)
assert isinstance(pred, list) and len(pred) == 4
with patch('torch.onnx.is_in_onnx_export', return_value=True):
pred = head.simple_test(fake_features)
assert pred.shape == (4, 10)
# test simple_test without post_process
pred = head.simple_test(fake_features, post_process=False)
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
logits = head.simple_test(fake_features, softmax=False, post_process=False)
torch.testing.assert_allclose(pred, torch.softmax(sum(logits), dim=1))
# test pre_logits
features = head.pre_logits(fake_features)
assert features is fake_features[0]
def test_deit_head():
fake_features = ([
torch.rand(4, 7, 7, 16),
@ -185,16 +295,25 @@ def test_deit_head():
head.init_weights()
assert abs(head.layers.pre_logits.weight).sum() > 0
# test simple_test
head = DeiTClsHead(10, 100, hidden_dim=20)
# test simple_test with post_process
pred = head.simple_test(fake_features)
assert isinstance(pred, list) and len(pred) == 4
with patch('torch.onnx.is_in_onnx_export', return_value=True):
head = DeiTClsHead(10, 100, hidden_dim=20)
pred = head.simple_test(fake_features)
assert pred.shape == (4, 10)
# test simple_test without post_process
pred = head.simple_test(fake_features, post_process=False)
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
logits = head.simple_test(fake_features, softmax=False, post_process=False)
torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1))
# test pre_logits
cls_token, dist_token = head.pre_logits(fake_features)
assert cls_token.shape == (4, 20)
assert dist_token.shape == (4, 20)
# test assertion
with pytest.raises(ValueError):
DeiTClsHead(-1, 100)