[Refactor] refactor ViTHead, DeiTHead, ConformerHead, StackedHead

yingfhu 2022-06-21 09:01:19 +00:00 committed by mzr1996
parent a82de04b67
commit 62b046521e
7 changed files with 357 additions and 372 deletions
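
Taken together, the changes replace the old per-head `forward_train` / `simple_test` pair with the unified `forward`, `loss` and `predict` interface driven by `ClsDataSample` objects. A rough usage sketch of the refactored heads follows; it is not part of the diff and assumes an mmcls 1.x environment where these heads are registered and inherit `loss()` / `predict()` from the base `ClsHead`. Shapes and the 4-sample batch are illustrative only.

# Usage sketch, not part of the diff. Assumes mmcls 1.x with these heads registered.
import torch
from mmcls.core import ClsDataSample
from mmcls.registry import MODELS

head = MODELS.build(
    dict(type='VisionTransformerClsHead', in_channels=10, num_classes=5))
feats = ([torch.rand(4, 7, 7, 16), torch.rand(4, 10)], )  # backbone output tuple
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]

cls_score = head(feats)                    # forward(): raw logits, shape (4, 5)
losses = head.loss(feats, data_samples)    # replaces the old forward_train()
preds = head.predict(feats, data_samples)  # replaces the old simple_test()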

View File

@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -27,10 +27,10 @@ class ClsHead(BaseHead):
"""
def __init__(self,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, ),
cal_acc=False,
init_cfg=None):
loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0),
topk: Union[int, Tuple[int]] = (1, ),
cal_acc: bool = False,
init_cfg: Optional[dict] = None):
super(ClsHead, self).__init__(init_cfg=init_cfg)
self.topk = topk
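
A note on the typed arguments above: `topk` may now be a single int or a tuple of ints, and `cal_acc=True` makes the head report top-k accuracy alongside the loss. Below is a small sketch of that bookkeeping, mirroring the `Accuracy.calculate` call used by `ConformerHead` further down in this diff; the logits and labels are made up.

# Sketch of the top-k accuracy reporting enabled by cal_acc=True; values are fake.
import torch
from mmcls.metrics import Accuracy

topk = (1, 5)
cls_score = torch.rand(8, 10)          # (num_samples, num_classes) logits
target = torch.randint(0, 10, (8, ))   # hard integer labels; cal_acc asserts 1-D targets
acc = Accuracy.calculate(cls_score, target, topk=topk)
accuracies = {f'accuracy_top-{k}': a for k, a in zip(topk, acc)}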

View File

@@ -1,8 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.utils.weight_init import trunc_normal_
from typing import List, Sequence, Tuple
import torch
import torch.nn as nn
from mmcls.core import ClsDataSample
from mmcls.metrics import Accuracy
from mmcls.registry import MODELS
from .cls_head import ClsHead
@@ -14,19 +17,19 @@ class ConformerHead(ClsHead):
Args:
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
in_channels (Sequence[int]): Number of channels in the input
feature map.
init_cfg (dict, optional): The extra init config of layers.
Defaults to use ``dict(type='Normal', layer='Linear', std=0.01)``.
"""
def __init__(
self,
num_classes,
in_channels, # [conv_dim, trans_dim]
init_cfg=dict(type='Normal', layer='Linear', std=0.01),
*args,
num_classes: int,
in_channels: Sequence[int], # [conv_dim, trans_dim]
init_cfg: dict = dict(type='TruncNormal', layer='Linear', std=.02),
**kwargs):
super(ConformerHead, self).__init__(init_cfg=None, *args, **kwargs)
super(ConformerHead, self).__init__(init_cfg=init_cfg, **kwargs)
self.in_channels = in_channels
self.num_classes = num_classes
@@ -39,94 +42,82 @@ class ConformerHead(ClsHead):
self.conv_cls_head = nn.Linear(self.in_channels[0], num_classes)
self.trans_cls_head = nn.Linear(self.in_channels[1], num_classes)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
def pre_logits(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor:
"""The process before the final classification head.
def init_weights(self):
super(ConformerHead, self).init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress default init if use pretrained model.
return
else:
self.apply(self._init_weights)
def pre_logits(self, x):
if isinstance(x, tuple):
x = x[-1]
return x
def simple_test(self, x, softmax=True, post_process=True):
"""Inference without augmentation.
Args:
x (tuple[tuple[tensor, tensor]]): The input features.
Multi-stage inputs are acceptable but only the last stage will
be used to classify. Every item should be a tuple which
includes convolution features and transformer features. The
shape of them should be ``(num_samples, in_channels[0])`` and
``(num_samples, in_channels[1])``.
softmax (bool): Whether to softmax the classification score.
post_process (bool): Whether to do post processing the
inference results. It will convert the output to a list.
Returns:
Tensor | list: The inference results.
- If no post processing, the output is a tensor with shape
``(num_samples, num_classes)``.
- If post processing, the output is a multi-dimensional list of
float and the dimensions are ``(num_samples, num_classes)``.
The input ``feats`` is a tuple of tensors, and each tensor is the
feature of a backbone stage. In ``ConformerHead``, we just obtain the
feature of the last stage.
"""
x = self.pre_logits(x)
# The ConformerHead doesn't have any other module,
# so just return after unpacking.
return feats[-1]
def forward(self, feats: Tuple[List[torch.Tensor]]) -> Tuple[torch.Tensor]:
"""The forward process."""
x = self.pre_logits(feats)
# There are two outputs in the Conformer model
assert len(x) == 2
conv_cls_score = self.conv_cls_head(x[0])
tran_cls_score = self.trans_cls_head(x[1])
if softmax:
cls_score = conv_cls_score + tran_cls_score
pred = (
F.softmax(cls_score, dim=1) if cls_score is not None else None)
if post_process:
pred = self.post_process(pred)
return conv_cls_score, tran_cls_score
def predict(
self,
feats: Tuple[List[torch.Tensor]],
data_samples: List[ClsDataSample] = None) -> List[ClsDataSample]:
"""Inference without augmentation.
Args:
feats (tuple[Tensor]): The features extracted from the backbone.
Multiple stage inputs are acceptable but only the last stage
will be used to classify. The shape of every item should be
``(num_samples, num_classes)``.
data_samples (List[ClsDataSample], optional): The annotation
data of every samples. If not None, set ``pred_label`` of
the input data samples. Defaults to None.
Returns:
List[ClsDataSample]: A list of data samples which contains the
predicted results.
"""
# This part can be traced by torch.fx
conv_cls_score, tran_cls_score = self(feats)
cls_score = conv_cls_score + tran_cls_score
# This part can not be traced by torch.fx
predictions = self._get_predictions(cls_score, data_samples)
return predictions
def _get_loss(self, cls_score: Tuple[torch.Tensor],
data_samples: List[ClsDataSample], **kwargs) -> dict:
"""Unpack data samples and compute loss."""
# Unpack data samples and pack targets
if 'score' in data_samples[0].gt_label:
# Batch augmentation may convert labels to one-hot format scores.
target = torch.stack([i.gt_label.score for i in data_samples])
else:
pred = [conv_cls_score, tran_cls_score]
if post_process:
pred = list(map(self.post_process, pred))
return pred
target = torch.hstack([i.gt_label.label for i in data_samples])
def forward_train(self, x, gt_label):
x = self.pre_logits(x)
assert isinstance(x, list) and len(x) == 2, \
'There should be two outputs in the Conformer model'
conv_cls_score = self.conv_cls_head(x[0])
tran_cls_score = self.trans_cls_head(x[1])
losses = self.loss([conv_cls_score, tran_cls_score], gt_label)
return losses
def loss(self, cls_score, gt_label):
num_samples = len(cls_score[0])
losses = dict()
# compute loss
losses = dict()
loss = sum([
self.compute_loss(score, gt_label, avg_factor=num_samples) /
len(cls_score) for score in cls_score
self.loss_module(
score, target, avg_factor=score.size(0), **kwargs)
for score in cls_score
])
if self.cal_acc:
# compute accuracy
acc = self.compute_accuracy(cls_score[0] + cls_score[1], gt_label)
assert len(acc) == len(self.topk)
losses['accuracy'] = {
f'top-{k}': a
for k, a in zip(self.topk, acc)
}
losses['loss'] = loss
# compute accuracy
if self.cal_acc:
assert target.ndim == 1, 'If you enable batch augmentation ' \
'like mixup during training, `cal_acc` is pointless.'
acc = Accuracy.calculate(
cls_score[0] + cls_score[1], target, topk=self.topk)
losses.update(
{f'accuracy_top-{k}': a
for k, a in zip(self.topk, acc)})
return losses
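
In short, the new `forward` returns the raw scores of both Conformer branches, `predict` classifies with their sum, and `_get_loss` adds up one loss term per branch (each averaged over the batch via `avg_factor`). The following plain-PyTorch sketch only illustrates that combination arithmetic, with made-up shapes and a vanilla cross-entropy standing in for the configured `loss_module`.

# Plain-PyTorch sketch of how the two Conformer branch scores are combined.
import torch
import torch.nn.functional as F

conv_cls_score = torch.rand(4, 5)   # from conv_cls_head: (num_samples, num_classes)
tran_cls_score = torch.rand(4, 5)   # from trans_cls_head

# Prediction: the branch logits are summed before post-processing
# (the removed simple_test above applied a softmax to this sum).
pred_scores = F.softmax(conv_cls_score + tran_cls_score, dim=1)

# Loss: every branch is scored against the same target and the terms are summed.
target = torch.randint(0, 5, (4, ))
loss = sum(
    F.cross_entropy(score, target)
    for score in (conv_cls_score, tran_cls_score))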

View File

@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcls.registry import MODELS
from mmcls.utils import get_root_logger
@@ -20,7 +22,7 @@ class DeiTClsHead(VisionTransformerClsHead):
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
hidden_dim (int): Number of the dimensions for hidden layer.
hidden_dim (int, optional): Number of the dimensions for hidden layer.
Defaults to None, which means no extra hidden layer.
act_cfg (dict): The activation config. Only available during
pre-training. Defaults to ``dict(type='Tanh')``.
@@ -28,19 +30,24 @@ class DeiTClsHead(VisionTransformerClsHead):
``dict(type='Constant', layer='Linear', val=0)``.
"""
def __init__(self, *args, **kwargs):
super(DeiTClsHead, self).__init__(*args, **kwargs)
def _init_layers(self):
""""Init extra hidden linear layer to handle dist token if exists."""
super(DeiTClsHead, self)._init_layers()
if self.hidden_dim is None:
head_dist = nn.Linear(self.in_channels, self.num_classes)
else:
head_dist = nn.Linear(self.hidden_dim, self.num_classes)
self.layers.add_module('head_dist', head_dist)
def pre_logits(self, x):
if isinstance(x, tuple):
x = x[-1]
_, cls_token, dist_token = x
def pre_logits(self,
feats: Tuple[List[torch.Tensor]]) -> Tuple[torch.Tensor]:
"""The process before the final classification head.
The input ``feats`` is a tuple of lists of tensors, and each tensor is
the feature of a backbone stage. In ``DeiTClsHead``, we obtain the
feature of the last stage and forward it through the hidden layer if one exists.
"""
_, cls_token, dist_token = feats[-1]
if self.hidden_dim is None:
return cls_token, dist_token
else:
@@ -48,49 +55,13 @@ class DeiTClsHead(VisionTransformerClsHead):
dist_token = self.layers.act(self.layers.pre_logits(dist_token))
return cls_token, dist_token
def simple_test(self, x, softmax=True, post_process=True):
"""Inference without augmentation.
Args:
x (tuple[tuple[tensor, tensor, tensor]]): The input features.
Multi-stage inputs are acceptable but only the last stage will
be used to classify. Every item should be a tuple which
includes patch token, cls token and dist token. The cls token
and dist token will be used to classify and the shape of them
should be ``(num_samples, in_channels)``.
softmax (bool): Whether to softmax the classification score.
post_process (bool): Whether to do post processing the
inference results. It will convert the output to a list.
Returns:
Tensor | list: The inference results.
- If no post processing, the output is a tensor with shape
``(num_samples, num_classes)``.
- If post processing, the output is a multi-dimensional list of
float and the dimensions are ``(num_samples, num_classes)``.
"""
cls_token, dist_token = self.pre_logits(x)
cls_score = (self.layers.head(cls_token) +
self.layers.head_dist(dist_token)) / 2
if softmax:
pred = F.softmax(
cls_score, dim=1) if cls_score is not None else None
else:
pred = cls_score
if post_process:
return self.post_process(pred)
else:
return pred
def forward_train(self, x, gt_label):
def forward(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor:
"""The forward process."""
logger = get_root_logger()
logger.warning("MMClassification doesn't support to train the "
'distilled version DeiT.')
cls_token, dist_token = self.pre_logits(x)
cls_token, dist_token = self.pre_logits(feats)
# The final classification head.
cls_score = (self.layers.head(cls_token) +
self.layers.head_dist(dist_token)) / 2
losses = self.loss(cls_score, gt_label)
return losses
return cls_score
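
The distilled head therefore still only emits the warning above instead of supporting training, and at inference it averages the classification head applied to the cls token with `head_dist` applied to the distillation token. A standalone plain-PyTorch sketch of that averaging, with made-up sizes:

# Plain-PyTorch sketch of the DeiT inference path: average of two linear heads.
import torch
import torch.nn as nn

in_channels, num_classes = 10, 5
head = nn.Linear(in_channels, num_classes)       # applied to the cls token
head_dist = nn.Linear(in_channels, num_classes)  # applied to the dist token

cls_token = torch.rand(4, in_channels)
dist_token = torch.rand(4, in_channels)
cls_score = (head(cls_token) + head_dist(dist_token)) / 2  # shape (4, num_classes)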

View File

@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
from typing import Optional, Tuple
import torch
import torch.nn as nn
@@ -29,9 +29,10 @@ class LinearClsHead(ClsHead):
"""
def __init__(self,
num_classes,
in_channels,
init_cfg=dict(type='Normal', layer='Linear', std=0.01),
num_classes: int,
in_channels: int,
init_cfg: Optional[dict] = dict(
type='Normal', layer='Linear', std=0.01),
**kwargs):
super(LinearClsHead, self).__init__(init_cfg=init_cfg, **kwargs)
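
The visible change to LinearClsHead is limited to type annotations on the constructor, so its behaviour should match the updated `TestLinearClsHead` expectations in the test file below. A minimal build-and-forward sketch, assuming the mmcls registry is available; the feature tensor is illustrative.

# Sketch: building the annotated LinearClsHead and running the new forward().
import torch
from mmcls.registry import MODELS

head = MODELS.build(dict(type='LinearClsHead', num_classes=5, in_channels=10))
feats = (torch.rand(4, 10), )   # tuple of backbone stage features
cls_score = head(feats)         # expected shape: (4, 5)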

View File

@@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Sequence
from typing import Dict, Optional, Sequence, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmcv.runner import BaseModule, ModuleList
@@ -11,6 +11,7 @@ from .cls_head import ClsHead
class LinearBlock(BaseModule):
"""Linear block for StackedLinearClsHead."""
def __init__(self,
in_channels,
@@ -34,6 +35,7 @@ class LinearBlock(BaseModule):
self.dropout = nn.Dropout(p=dropout_rate)
def forward(self, x):
"""The forward process."""
x = self.fc(x)
if self.norm is not None:
x = self.norm(x)
@@ -51,7 +53,8 @@ class StackedLinearClsHead(ClsHead):
Args:
num_classes (int): Number of categories.
in_channels (int): Number of channels in the input feature map.
mid_channels (Sequence): Number of channels in the hidden fc layers.
mid_channels (Sequence[int]): Number of channels in the hidden fc
layers.
dropout_rate (float): Dropout rate after each hidden fc layer,
except the last layer. Defaults to 0.
norm_cfg (dict, optional): Config dict of normalization layer after
@@ -63,18 +66,17 @@ class StackedLinearClsHead(ClsHead):
def __init__(self,
num_classes: int,
in_channels: int,
mid_channels: Sequence,
mid_channels: Sequence[int],
dropout_rate: float = 0.,
norm_cfg: Dict = None,
act_cfg: Dict = dict(type='ReLU'),
norm_cfg: Optional[Dict] = None,
act_cfg: Optional[Dict] = dict(type='ReLU'),
**kwargs):
super(StackedLinearClsHead, self).__init__(**kwargs)
assert num_classes > 0, \
f'`num_classes` of StackedLinearClsHead must be a positive ' \
f'integer, got {num_classes} instead.'
self.num_classes = num_classes
self.in_channels = in_channels
if self.num_classes <= 0:
raise ValueError(
f'num_classes={num_classes} must be a positive integer')
assert isinstance(mid_channels, Sequence), \
f'`mid_channels` of StackedLinearClsHead should be a sequence, ' \
@@ -88,6 +90,7 @@ class StackedLinearClsHead(ClsHead):
self._init_layers()
def _init_layers(self):
""""Init layers."""
self.layers = ModuleList()
in_channels = self.in_channels
for hidden_channels in self.mid_channels:
@@ -108,56 +111,25 @@ class StackedLinearClsHead(ClsHead):
norm_cfg=None,
act_cfg=None))
def init_weights(self):
self.layers.init_weights()
def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
"""The process before the final classification head.
def pre_logits(self, x):
if isinstance(x, tuple):
x = x[-1]
The input ``feats`` is a tuple of tensors, and each tensor is the
feature of a backbone stage.
"""
x = feats[-1]
for layer in self.layers[:-1]:
x = layer(x)
return x
@property
def fc(self):
"""Full connected layer."""
return self.layers[-1]
def simple_test(self, x, softmax=True, post_process=True):
"""Inference without augmentation.
Args:
x (tuple[Tensor]): The input features.
Multi-stage inputs are acceptable but only the last stage will
be used to classify. The shape of every item should be
``(num_samples, in_channels)``.
softmax (bool): Whether to softmax the classification score.
post_process (bool): Whether to do post processing the
inference results. It will convert the output to a list.
Returns:
Tensor | list: The inference results.
- If no post processing, the output is a tensor with shape
``(num_samples, num_classes)``.
- If post processing, the output is a multi-dimensional list of
float and the dimensions are ``(num_samples, num_classes)``.
"""
x = self.pre_logits(x)
cls_score = self.fc(x)
if softmax:
pred = (
F.softmax(cls_score, dim=1) if cls_score is not None else None)
else:
pred = cls_score
if post_process:
return self.post_process(pred)
else:
return pred
def forward_train(self, x, gt_label, **kwargs):
x = self.pre_logits(x)
cls_score = self.fc(x)
losses = self.loss(cls_score, gt_label, **kwargs)
return losses
def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
"""The forward process."""
pre_logits = self.pre_logits(feats)
# The final classification head.
cls_score = self.fc(pre_logits)
return cls_score
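
For orientation, the stacked head is a small MLP: one `LinearBlock` per entry of `mid_channels` (fc, then optional norm, activation and dropout) followed by a plain linear classifier exposed through the `fc` property. Below is a rough plain-PyTorch equivalent of that layer layout under the defaults (no norm, ReLU activation); it only illustrates the structure, not the real `LinearBlock` module.

# Rough plain-PyTorch equivalent of StackedLinearClsHead's layer layout.
import torch
import torch.nn as nn

in_channels, mid_channels, num_classes, dropout_rate = 10, [20, 30], 5, 0.2

blocks, channels = [], in_channels
for hidden in mid_channels:           # one hidden block per mid_channels entry
    blocks += [nn.Linear(channels, hidden), nn.ReLU(), nn.Dropout(dropout_rate)]
    channels = hidden
blocks.append(nn.Linear(channels, num_classes))  # final classifier (the `fc` property)
layers = nn.Sequential(*blocks)

cls_score = layers(torch.rand(4, in_channels))   # expected shape: (4, 5)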

View File

@@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from collections import OrderedDict
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_activation_layer
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner import Sequential
@@ -20,7 +21,7 @@ class VisionTransformerClsHead(ClsHead):
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
hidden_dim (int): Number of the dimensions for hidden layer.
hidden_dim (int, optional): Number of the dimensions for hidden layer.
Defaults to None, which means no extra hidden layer.
act_cfg (dict): The activation config. Only available during
pre-training. Defaults to ``dict(type='Tanh')``.
@@ -29,15 +30,14 @@
"""
def __init__(self,
num_classes,
in_channels,
hidden_dim=None,
act_cfg=dict(type='Tanh'),
init_cfg=dict(type='Constant', layer='Linear', val=0),
*args,
num_classes: int,
in_channels: int,
hidden_dim: Optional[int] = None,
act_cfg: dict = dict(type='Tanh'),
init_cfg: dict = dict(type='Constant', layer='Linear', val=0),
**kwargs):
super(VisionTransformerClsHead, self).__init__(
init_cfg=init_cfg, *args, **kwargs)
init_cfg=init_cfg, **kwargs)
self.in_channels = in_channels
self.num_classes = num_classes
self.hidden_dim = hidden_dim
@@ -50,6 +50,7 @@ class VisionTransformerClsHead(ClsHead):
self._init_layers()
def _init_layers(self):
""""Init hidden layer if exists."""
if self.hidden_dim is None:
layers = [('head', nn.Linear(self.in_channels, self.num_classes))]
else:
@@ -61,6 +62,7 @@ class VisionTransformerClsHead(ClsHead):
self.layers = Sequential(OrderedDict(layers))
def init_weights(self):
""""Init weights of hidden layer if exists."""
super(VisionTransformerClsHead, self).init_weights()
# Modified from ClassyVision
if hasattr(self.layers, 'pre_logits'):
@@ -70,54 +72,24 @@
std=math.sqrt(1 / self.layers.pre_logits.in_features))
nn.init.zeros_(self.layers.pre_logits.bias)
def pre_logits(self, x):
if isinstance(x, tuple):
x = x[-1]
_, cls_token = x
def pre_logits(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor:
"""The process before the final classification head.
The input ``feats`` is a tuple of lists of tensors, and each tensor is
the feature of a backbone stage. In ``VisionTransformerClsHead``, we
obtain the feature of the last stage and forward it through the hidden
layer if one exists.
"""
_, cls_token = feats[-1]
if self.hidden_dim is None:
return cls_token
else:
x = self.layers.pre_logits(cls_token)
return self.layers.act(x)
def simple_test(self, x, softmax=True, post_process=True):
"""Inference without augmentation.
Args:
x (tuple[tuple[tensor, tensor]]): The input features.
Multi-stage inputs are acceptable but only the last stage will
be used to classify. Every item should be a tuple which
includes patch token and cls token. The cls token will be used
to classify and the shape of it should be
``(num_samples, in_channels)``.
softmax (bool): Whether to softmax the classification score.
post_process (bool): Whether to do post processing the
inference results. It will convert the output to a list.
Returns:
Tensor | list: The inference results.
- If no post processing, the output is a tensor with shape
``(num_samples, num_classes)``.
- If post processing, the output is a multi-dimensional list of
float and the dimensions are ``(num_samples, num_classes)``.
"""
x = self.pre_logits(x)
cls_score = self.layers.head(x)
if softmax:
pred = (
F.softmax(cls_score, dim=1) if cls_score is not None else None)
else:
pred = cls_score
if post_process:
return self.post_process(pred)
else:
return pred
def forward_train(self, x, gt_label, **kwargs):
x = self.pre_logits(x)
cls_score = self.layers.head(x)
losses = self.loss(cls_score, gt_label, **kwargs)
return losses
def forward(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor:
"""The forward process."""
pre_logits = self.pre_logits(feats)
# The final classification head.
cls_score = self.layers.head(pre_logits)
return cls_score
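
Summing up the new flow: `pre_logits` picks the cls token of the last backbone stage and, when `hidden_dim` is set, passes it through the `pre_logits` linear layer and the Tanh activation; `forward` then applies the final `head` linear layer. A plain-PyTorch sketch of that path with made-up sizes (the real layers live in the mmcv `Sequential` built in `_init_layers`):

# Plain-PyTorch sketch of the ViT head path:
# cls token -> optional pre_logits + tanh -> final linear head.
import torch
import torch.nn as nn

in_channels, hidden_dim, num_classes = 10, 30, 5
pre_logits = nn.Linear(in_channels, hidden_dim)  # only built when hidden_dim is set
act = nn.Tanh()                                  # default act_cfg=dict(type='Tanh')
head = nn.Linear(hidden_dim, num_classes)

feats = ([torch.rand(4, 7, 7, in_channels), torch.rand(4, in_channels)], )
_, cls_token = feats[-1]                         # only the cls token is classified
cls_score = head(act(pre_logits(cls_token)))     # expected shape: (4, num_classes)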

View File

@@ -106,6 +106,205 @@ class TestLinearClsHead(TestCase):
self.assertEqual(outs.shape, (4, 5))
class TestVisionTransformerClsHead(TestCase):
DEFAULT_ARGS = dict(
type='VisionTransformerClsHead', in_channels=10, num_classes=5)
fake_feats = ([torch.rand(4, 7, 7, 16), torch.rand(4, 10)], )
def test_initialize(self):
with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
MODELS.build({**self.DEFAULT_ARGS, 'num_classes': -5})
# test vit head default
head = MODELS.build(self.DEFAULT_ARGS)
assert not hasattr(head.layers, 'pre_logits')
assert not hasattr(head.layers, 'act')
# test vit head hidden_dim
head = MODELS.build({**self.DEFAULT_ARGS, 'hidden_dim': 30})
assert hasattr(head.layers, 'pre_logits')
assert hasattr(head.layers, 'act')
# test vit head init_weights
head = MODELS.build(self.DEFAULT_ARGS)
head.init_weights()
# test vit head init_weights with hidden_dim
head = MODELS.build({**self.DEFAULT_ARGS, 'hidden_dim': 30})
head.init_weights()
assert abs(head.layers.pre_logits.weight).sum() > 0
def test_pre_logits(self):
# test default
head = MODELS.build(self.DEFAULT_ARGS)
pre_logits = head.pre_logits(self.fake_feats)
self.assertIs(pre_logits, self.fake_feats[-1][1])
# test hidden_dim
head = MODELS.build({**self.DEFAULT_ARGS, 'hidden_dim': 30})
pre_logits = head.pre_logits(self.fake_feats)
self.assertEqual(pre_logits.shape, (4, 30))
def test_forward(self):
# test default
head = MODELS.build(self.DEFAULT_ARGS)
outs = head(self.fake_feats)
self.assertEqual(outs.shape, (4, 5))
# test hidden_dim
head = MODELS.build({**self.DEFAULT_ARGS, 'hidden_dim': 30})
outs = head(self.fake_feats)
self.assertEqual(outs.shape, (4, 5))
class TestDeiTClsHead(TestVisionTransformerClsHead):
DEFAULT_ARGS = dict(type='DeiTClsHead', in_channels=10, num_classes=5)
fake_feats = ([
torch.rand(4, 7, 7, 16),
torch.rand(4, 10),
torch.rand(4, 10)
], )
def test_pre_logits(self):
# test default
head = MODELS.build(self.DEFAULT_ARGS)
cls_token, dist_token = head.pre_logits(self.fake_feats)
self.assertIs(cls_token, self.fake_feats[-1][1])
self.assertIs(dist_token, self.fake_feats[-1][2])
# test hidden_dim
head = MODELS.build({**self.DEFAULT_ARGS, 'hidden_dim': 30})
cls_token, dist_token = head.pre_logits(self.fake_feats)
self.assertEqual(cls_token.shape, (4, 30))
self.assertEqual(dist_token.shape, (4, 30))
class TestConformerHead(TestCase):
DEFAULT_ARGS = dict(
type='ConformerHead', in_channels=[64, 96], num_classes=5)
fake_feats = ([torch.rand(4, 64), torch.rand(4, 96)], )
def test_initialize(self):
with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
MODELS.build({**self.DEFAULT_ARGS, 'num_classes': -5})
# test default
head = MODELS.build(self.DEFAULT_ARGS)
assert hasattr(head, 'conv_cls_head')
assert hasattr(head, 'trans_cls_head')
# test init_weights
head = MODELS.build(self.DEFAULT_ARGS)
head.init_weights()
assert abs(head.conv_cls_head.weight).sum() > 0
assert abs(head.trans_cls_head.weight).sum() > 0
def test_pre_logits(self):
# test default
head = MODELS.build(self.DEFAULT_ARGS)
pre_logits = head.pre_logits(self.fake_feats)
self.assertIs(pre_logits, self.fake_feats[-1])
def test_forward(self):
head = MODELS.build(self.DEFAULT_ARGS)
outs = head(self.fake_feats)
self.assertEqual(outs[0].shape, (4, 5))
self.assertEqual(outs[1].shape, (4, 5))
def test_loss(self):
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]
# with cal_acc = False
head = MODELS.build(self.DEFAULT_ARGS)
losses = head.loss(self.fake_feats, data_samples)
self.assertEqual(losses.keys(), {'loss'})
self.assertGreater(losses['loss'].item(), 0)
# with cal_acc = True
cfg = {**self.DEFAULT_ARGS, 'topk': (1, 2), 'cal_acc': True}
head = MODELS.build(cfg)
losses = head.loss(self.fake_feats, data_samples)
self.assertEqual(losses.keys(),
{'loss', 'accuracy_top-1', 'accuracy_top-2'})
self.assertGreater(losses['loss'].item(), 0)
# test assertion when cal_acc is set but data is batch augmented.
data_samples = [
sample.set_gt_score(torch.rand(5)) for sample in data_samples
]
cfg = {
**self.DEFAULT_ARGS, 'cal_acc': True,
'loss': dict(type='CrossEntropyLoss', use_soft=True)
}
head = MODELS.build(cfg)
with self.assertRaisesRegex(AssertionError, 'batch augmentation'):
head.loss(self.fake_feats, data_samples)
def test_predict(self):
data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]
head = MODELS.build(self.DEFAULT_ARGS)
# without data_samples
predictions = head.predict(self.fake_feats)
self.assertTrue(is_seq_of(predictions, ClsDataSample))
for pred in predictions:
self.assertIn('label', pred.pred_label)
self.assertIn('score', pred.pred_label)
# with data_samples
predictions = head.predict(self.fake_feats, data_samples)
self.assertTrue(is_seq_of(predictions, ClsDataSample))
for sample, pred in zip(data_samples, predictions):
self.assertIs(sample, pred)
self.assertIn('label', pred.pred_label)
self.assertIn('score', pred.pred_label)
class TestStackedLinearClsHead(TestCase):
DEFAULT_ARGS = dict(
type='StackedLinearClsHead', in_channels=10, num_classes=5)
fake_feats = (torch.rand(4, 10), )
def test_initialize(self):
with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
MODELS.build({
**self.DEFAULT_ARGS, 'num_classes': -5,
'mid_channels': 10
})
# test mid_channels
with self.assertRaisesRegex(AssertionError, 'should be a sequence'):
MODELS.build({**self.DEFAULT_ARGS, 'mid_channels': 10})
# test default
head = MODELS.build({**self.DEFAULT_ARGS, 'mid_channels': [20]})
assert len(head.layers) == 2
head.init_weights()
def test_pre_logits(self):
# test default
head = MODELS.build({**self.DEFAULT_ARGS, 'mid_channels': [20, 30]})
pre_logits = head.pre_logits(self.fake_feats)
self.assertEqual(pre_logits.shape, (4, 30))
def test_forward(self):
# test default
head = MODELS.build({**self.DEFAULT_ARGS, 'mid_channels': [20, 30]})
outs = head(self.fake_feats)
self.assertEqual(outs.shape, (4, 5))
head = MODELS.build({
**self.DEFAULT_ARGS, 'mid_channels': [8, 10],
'dropout_rate': 0.2,
'norm_cfg': dict(type='BN1d'),
'act_cfg': dict(type='HSwish')
})
outs = head(self.fake_feats)
self.assertEqual(outs.shape, (4, 5))
"""Temporarily disabled.
@pytest.mark.parametrize('feat', [torch.rand(4, 10), (torch.rand(4, 10), )])
def test_multilabel_head(feat):
@@ -215,125 +414,4 @@ def test_stacked_linear_cls_head(feat):
losses = head.forward_train(feat, fake_gt_label)
assert losses['loss'].item() > 0
def test_vit_head():
fake_features = ([torch.rand(4, 7, 7, 16), torch.rand(4, 100)], )
fake_gt_label = torch.randint(0, 10, (4, ))
# test vit head forward
head = VisionTransformerClsHead(10, 100)
losses = head.forward_train(fake_features, fake_gt_label)
assert not hasattr(head.layers, 'pre_logits')
assert not hasattr(head.layers, 'act')
assert losses['loss'].item() > 0
# test vit head forward with hidden layer
head = VisionTransformerClsHead(10, 100, hidden_dim=20)
losses = head.forward_train(fake_features, fake_gt_label)
assert hasattr(head.layers, 'pre_logits') and hasattr(head.layers, 'act')
assert losses['loss'].item() > 0
# test vit head init_weights
head = VisionTransformerClsHead(10, 100, hidden_dim=20)
head.init_weights()
assert abs(head.layers.pre_logits.weight).sum() > 0
head = VisionTransformerClsHead(10, 100, hidden_dim=20)
# test simple_test with post_process
pred = head.simple_test(fake_features)
assert isinstance(pred, list) and len(pred) == 4
with patch('torch.onnx.is_in_onnx_export', return_value=True):
pred = head.simple_test(fake_features)
assert pred.shape == (4, 10)
# test simple_test without post_process
pred = head.simple_test(fake_features, post_process=False)
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
logits = head.simple_test(fake_features, softmax=False, post_process=False)
torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1))
# test pre_logits
features = head.pre_logits(fake_features)
assert features.shape == (4, 20)
# test assertion
with pytest.raises(ValueError):
VisionTransformerClsHead(-1, 100)
def test_conformer_head():
fake_features = ([torch.rand(4, 64), torch.rand(4, 96)], )
fake_gt_label = torch.randint(0, 10, (4, ))
# test conformer head forward
head = ConformerHead(num_classes=10, in_channels=[64, 96])
losses = head.forward_train(fake_features, fake_gt_label)
assert losses['loss'].item() > 0
# test simple_test with post_process
pred = head.simple_test(fake_features)
assert isinstance(pred, list) and len(pred) == 4
with patch('torch.onnx.is_in_onnx_export', return_value=True):
pred = head.simple_test(fake_features)
assert pred.shape == (4, 10)
# test simple_test without post_process
pred = head.simple_test(fake_features, post_process=False)
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
logits = head.simple_test(fake_features, softmax=False, post_process=False)
torch.testing.assert_allclose(pred, torch.softmax(sum(logits), dim=1))
# test pre_logits
features = head.pre_logits(fake_features)
assert features is fake_features[0]
def test_deit_head():
fake_features = ([
torch.rand(4, 7, 7, 16),
torch.rand(4, 100),
torch.rand(4, 100)
], )
fake_gt_label = torch.randint(0, 10, (4, ))
# test deit head forward
head = DeiTClsHead(num_classes=10, in_channels=100)
losses = head.forward_train(fake_features, fake_gt_label)
assert not hasattr(head.layers, 'pre_logits')
assert not hasattr(head.layers, 'act')
assert losses['loss'].item() > 0
# test deit head forward with hidden layer
head = DeiTClsHead(num_classes=10, in_channels=100, hidden_dim=20)
losses = head.forward_train(fake_features, fake_gt_label)
assert hasattr(head.layers, 'pre_logits') and hasattr(head.layers, 'act')
assert losses['loss'].item() > 0
# test deit head init_weights
head = DeiTClsHead(10, 100, hidden_dim=20)
head.init_weights()
assert abs(head.layers.pre_logits.weight).sum() > 0
head = DeiTClsHead(10, 100, hidden_dim=20)
# test simple_test with post_process
pred = head.simple_test(fake_features)
assert isinstance(pred, list) and len(pred) == 4
with patch('torch.onnx.is_in_onnx_export', return_value=True):
pred = head.simple_test(fake_features)
assert pred.shape == (4, 10)
# test simple_test without post_process
pred = head.simple_test(fake_features, post_process=False)
assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
logits = head.simple_test(fake_features, softmax=False, post_process=False)
torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1))
# test pre_logits
cls_token, dist_token = head.pre_logits(fake_features)
assert cls_token.shape == (4, 20)
assert dist_token.shape == (4, 20)
# test assertion
with pytest.raises(ValueError):
DeiTClsHead(-1, 100)
"""