Mirror of https://github.com/open-mmlab/mmpretrain.git (synced 2025-06-03 14:59:18 +08:00)

[Refactor] refactor ViTHead, DeiTHead, ConformerHead, StackedHead

commit 62b046521e
parent a82de04b67
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import List, Tuple
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -27,10 +27,10 @@ class ClsHead(BaseHead):
     """

     def __init__(self,
-                 loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
-                 topk=(1, ),
-                 cal_acc=False,
-                 init_cfg=None):
+                 loss: dict = dict(type='CrossEntropyLoss', loss_weight=1.0),
+                 topk: Union[int, Tuple[int]] = (1, ),
+                 cal_acc: bool = False,
+                 init_cfg: Optional[dict] = None):
         super(ClsHead, self).__init__(init_cfg=init_cfg)

         self.topk = topk
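For reference, the newly annotated ``ClsHead`` arguments map onto configs exactly as before; a minimal sketch (the ``LinearClsHead`` type and the channel/class counts are illustrative, the remaining values are the defaults shown above):

head = dict(
    type='LinearClsHead',   # any ClsHead subclass
    num_classes=1000,       # illustrative
    in_channels=2048,       # illustrative
    loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
    topk=(1, 5),            # Union[int, Tuple[int]]
    cal_acc=False)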
@@ -1,8 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import torch.nn as nn
-import torch.nn.functional as F
-from mmcv.cnn.utils.weight_init import trunc_normal_
+from typing import List, Sequence, Tuple
+
+import torch
+import torch.nn as nn

+from mmcls.core import ClsDataSample
+from mmcls.metrics import Accuracy
 from mmcls.registry import MODELS
 from .cls_head import ClsHead

@@ -14,19 +17,19 @@ class ConformerHead(ClsHead):
     Args:
         num_classes (int): Number of categories excluding the background
             category.
-        in_channels (int): Number of channels in the input feature map.
+        in_channels (Sequence[int]): Number of channels in the input
+            feature map.
         init_cfg (dict | optional): The extra init config of layers.
             Defaults to use ``dict(type='Normal', layer='Linear', std=0.01)``.
     """

     def __init__(
             self,
-            num_classes,
-            in_channels,  # [conv_dim, trans_dim]
-            init_cfg=dict(type='Normal', layer='Linear', std=0.01),
-            *args,
+            num_classes: int,
+            in_channels: Sequence[int],  # [conv_dim, trans_dim]
+            init_cfg: dict = dict(type='TruncNormal', layer='Linear', std=.02),
             **kwargs):
-        super(ConformerHead, self).__init__(init_cfg=None, *args, **kwargs)
+        super(ConformerHead, self).__init__(init_cfg=init_cfg, **kwargs)

         self.in_channels = in_channels
         self.num_classes = num_classes
@@ -39,94 +42,82 @@ class ConformerHead(ClsHead):
         self.conv_cls_head = nn.Linear(self.in_channels[0], num_classes)
         self.trans_cls_head = nn.Linear(self.in_channels[1], num_classes)

-    def _init_weights(self, m):
-        if isinstance(m, nn.Linear):
-            trunc_normal_(m.weight, std=.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
-                nn.init.constant_(m.bias, 0)
+    def pre_logits(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor:
+        """The process before the final classification head.

-    def init_weights(self):
-        super(ConformerHead, self).init_weights()
-
-        if (isinstance(self.init_cfg, dict)
-                and self.init_cfg['type'] == 'Pretrained'):
-            # Suppress default init if use pretrained model.
-            return
-        else:
-            self.apply(self._init_weights)
-
-    def pre_logits(self, x):
-        if isinstance(x, tuple):
-            x = x[-1]
-        return x
-
-    def simple_test(self, x, softmax=True, post_process=True):
-        """Inference without augmentation.
-
-        Args:
-            x (tuple[tuple[tensor, tensor]]): The input features.
-                Multi-stage inputs are acceptable but only the last stage will
-                be used to classify. Every item should be a tuple which
-                includes convluation features and transformer features. The
-                shape of them should be ``(num_samples, in_channels[0])`` and
-                ``(num_samples, in_channels[1])``.
-            softmax (bool): Whether to softmax the classification score.
-            post_process (bool): Whether to do post processing the
-                inference results. It will convert the output to a list.
-
-        Returns:
-            Tensor | list: The inference results.
-
-            - If no post processing, the output is a tensor with shape
-              ``(num_samples, num_classes)``.
-            - If post processing, the output is a multi-dimentional list of
-              float and the dimensions are ``(num_samples, num_classes)``.
+        The input ``feats`` is a tuple of tensor, and each tensor is the
+        feature of a backbone stage. In ``ConformerHead``, we just obtain the
+        feature of the last stage.
         """
-        x = self.pre_logits(x)
+        # The ConformerHead doesn't have other module,
+        # just return after unpacking.
+        return feats[-1]

+    def forward(self, feats: Tuple[List[torch.Tensor]]) -> Tuple[torch.Tensor]:
+        """The forward process."""
+        x = self.pre_logits(feats)
+        # There are two outputs in the Conformer model
+        assert len(x) == 2
+
         conv_cls_score = self.conv_cls_head(x[0])
         tran_cls_score = self.trans_cls_head(x[1])

-        if softmax:
-            cls_score = conv_cls_score + tran_cls_score
-            pred = (
-                F.softmax(cls_score, dim=1) if cls_score is not None else None)
-            if post_process:
-                pred = self.post_process(pred)
+        return conv_cls_score, tran_cls_score

+    def predict(
+            self,
+            feats: Tuple[List[torch.Tensor]],
+            data_samples: List[ClsDataSample] = None) -> List[ClsDataSample]:
+        """Inference without augmentation.
+
+        Args:
+            feats (tuple[Tensor]): The features extracted from the backbone.
+                Multiple stage inputs are acceptable but only the last stage
+                will be used to classify. The shape of every item should be
+                ``(num_samples, num_classes)``.
+            data_samples (List[ClsDataSample], optional): The annotation
+                data of every samples. If not None, set ``pred_label`` of
+                the input data samples. Defaults to None.
+
+        Returns:
+            List[ClsDataSample]: A list of data samples which contains the
+            predicted results.
+        """
+        # The part can be traced by torch.fx
+        conv_cls_score, tran_cls_score = self(feats)
+        cls_score = conv_cls_score + tran_cls_score
+
+        # The part can not be traced by torch.fx
+        predictions = self._get_predictions(cls_score, data_samples)
+        return predictions
+
+    def _get_loss(self, cls_score: Tuple[torch.Tensor],
+                  data_samples: List[ClsDataSample], **kwargs) -> dict:
+        """Unpack data samples and compute loss."""
+        # Unpack data samples and pack targets
+        if 'score' in data_samples[0].gt_label:
+            # Batch augmentation may convert labels to one-hot format scores.
+            target = torch.stack([i.gt_label.score for i in data_samples])
         else:
-            pred = [conv_cls_score, tran_cls_score]
-            if post_process:
-                pred = list(map(self.post_process, pred))
-        return pred
+            target = torch.hstack([i.gt_label.label for i in data_samples])

-    def forward_train(self, x, gt_label):
-        x = self.pre_logits(x)
-        assert isinstance(x, list) and len(x) == 2, \
-            'There should be two outputs in the Conformer model'
-
-        conv_cls_score = self.conv_cls_head(x[0])
-        tran_cls_score = self.trans_cls_head(x[1])
-
-        losses = self.loss([conv_cls_score, tran_cls_score], gt_label)
-        return losses
-
-    def loss(self, cls_score, gt_label):
-        num_samples = len(cls_score[0])
-        losses = dict()
         # compute loss
+        losses = dict()
         loss = sum([
-            self.compute_loss(score, gt_label, avg_factor=num_samples) /
-            len(cls_score) for score in cls_score
+            self.loss_module(
+                score, target, avg_factor=score.size(0), **kwargs)
+            for score in cls_score
         ])
-        if self.cal_acc:
-            # compute accuracy
-            acc = self.compute_accuracy(cls_score[0] + cls_score[1], gt_label)
-            assert len(acc) == len(self.topk)
-            losses['accuracy'] = {
-                f'top-{k}': a
-                for k, a in zip(self.topk, acc)
-            }
         losses['loss'] = loss

+        # compute accuracy
+        if self.cal_acc:
+            assert target.ndim == 1, 'If you enable batch augmentation ' \
+                'like mixup during training, `cal_acc` is pointless.'
+            acc = Accuracy.calculate(
+                cls_score[0] + cls_score[1], target, topk=self.topk)
+            losses.update(
+                {f'accuracy_top-{k}': a
+                 for k, a in zip(self.topk, acc)})

         return losses
@@ -1,6 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Tuple
+
+import torch
 import torch.nn as nn
-import torch.nn.functional as F

 from mmcls.registry import MODELS
 from mmcls.utils import get_root_logger
@@ -20,7 +22,7 @@ class DeiTClsHead(VisionTransformerClsHead):
         num_classes (int): Number of categories excluding the background
             category.
         in_channels (int): Number of channels in the input feature map.
-        hidden_dim (int): Number of the dimensions for hidden layer.
+        hidden_dim (int, optional): Number of the dimensions for hidden layer.
             Defaults to None, which means no extra hidden layer.
         act_cfg (dict): The activation config. Only available during
             pre-training. Defaults to ``dict(type='Tanh')``.
@@ -28,19 +30,24 @@ class DeiTClsHead(VisionTransformerClsHead):
         ``dict(type='Constant', layer='Linear', val=0)``.
     """

-    def __init__(self, *args, **kwargs):
-        super(DeiTClsHead, self).__init__(*args, **kwargs)
+    def _init_layers(self):
+        """"Init extra hidden linear layer to handle dist token if exists."""
+        super(DeiTClsHead, self)._init_layers()
         if self.hidden_dim is None:
             head_dist = nn.Linear(self.in_channels, self.num_classes)
         else:
             head_dist = nn.Linear(self.hidden_dim, self.num_classes)
         self.layers.add_module('head_dist', head_dist)

-    def pre_logits(self, x):
-        if isinstance(x, tuple):
-            x = x[-1]
-        _, cls_token, dist_token = x
+    def pre_logits(self,
+                   feats: Tuple[List[torch.Tensor]]) -> Tuple[torch.Tensor]:
+        """The process before the final classification head.
+
+        The input ``feats`` is a tuple of list of tensor, and each tensor is
+        the feature of a backbone stage. In ``DeiTClsHead``, we obtain the
+        feature of the last stage and forward in hidden layer if exists.
+        """
+        _, cls_token, dist_token = feats[-1]
         if self.hidden_dim is None:
             return cls_token, dist_token
         else:
@@ -48,49 +55,13 @@ class DeiTClsHead(VisionTransformerClsHead):
             dist_token = self.layers.act(self.layers.pre_logits(dist_token))
             return cls_token, dist_token

-    def simple_test(self, x, softmax=True, post_process=True):
-        """Inference without augmentation.
-
-        Args:
-            x (tuple[tuple[tensor, tensor, tensor]]): The input features.
-                Multi-stage inputs are acceptable but only the last stage will
-                be used to classify. Every item should be a tuple which
-                includes patch token, cls token and dist token. The cls token
-                and dist token will be used to classify and the shape of them
-                should be ``(num_samples, in_channels)``.
-            softmax (bool): Whether to softmax the classification score.
-            post_process (bool): Whether to do post processing the
-                inference results. It will convert the output to a list.
-
-        Returns:
-            Tensor | list: The inference results.
-
-            - If no post processing, the output is a tensor with shape
-              ``(num_samples, num_classes)``.
-            - If post processing, the output is a multi-dimentional list of
-              float and the dimensions are ``(num_samples, num_classes)``.
-        """
-        cls_token, dist_token = self.pre_logits(x)
-        cls_score = (self.layers.head(cls_token) +
-                     self.layers.head_dist(dist_token)) / 2
-
-        if softmax:
-            pred = F.softmax(
-                cls_score, dim=1) if cls_score is not None else None
-        else:
-            pred = cls_score
-
-        if post_process:
-            return self.post_process(pred)
-        else:
-            return pred
-
-    def forward_train(self, x, gt_label):
+    def forward(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor:
+        """The forward process."""
-        logger = get_root_logger()
-        logger.warning("MMClassification doesn't support to train the "
-                       'distilled version DeiT.')
-        cls_token, dist_token = self.pre_logits(x)
+        cls_token, dist_token = self.pre_logits(feats)
+        # The final classification head.
         cls_score = (self.layers.head(cls_token) +
                      self.layers.head_dist(dist_token)) / 2
-        losses = self.loss(cls_score, gt_label)
-        return losses
+        return cls_score
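A minimal sketch of the resulting DeiTClsHead call path (taken from the new TestDeiTClsHead unit test below; tensor shapes are illustrative). The head now just returns the averaged cls/dist logits; the old ``forward_train`` with its "distilled DeiT" warning is gone:

import torch
from mmcls.registry import MODELS

head = MODELS.build(dict(type='DeiTClsHead', in_channels=10, num_classes=5))

# The last stage yields (patch tokens, cls token, dist token).
feats = ([torch.rand(4, 7, 7, 16), torch.rand(4, 10), torch.rand(4, 10)], )

cls_token, dist_token = head.pre_logits(feats)
cls_score = head(feats)  # averaged cls/dist logits, shape (4, 5)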
@@ -1,6 +1,6 @@
-# Copyright (c) OpenMMLab. All rights reserved.
+# Copyrigforward_trainht (c) OpenMMLab. All rights reserved.
-from typing import Tuple
+from typing import Optional, Tuple

 import torch
 import torch.nn as nn
@@ -29,9 +29,10 @@ class LinearClsHead(ClsHead):
     """

     def __init__(self,
-                 num_classes,
-                 in_channels,
-                 init_cfg=dict(type='Normal', layer='Linear', std=0.01),
+                 num_classes: int,
+                 in_channels: int,
+                 init_cfg: Optional[dict] = dict(
+                     type='Normal', layer='Linear', std=0.01),
                  **kwargs):
         super(LinearClsHead, self).__init__(init_cfg=init_cfg, **kwargs)

@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, Sequence
+from typing import Dict, Optional, Sequence, Tuple

+import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from mmcv.cnn import build_activation_layer, build_norm_layer
 from mmcv.runner import BaseModule, ModuleList

@@ -11,6 +11,7 @@ from .cls_head import ClsHead


 class LinearBlock(BaseModule):
+    """Linear block for StackedLinearClsHead."""

     def __init__(self,
                  in_channels,
@@ -34,6 +35,7 @@ class LinearBlock(BaseModule):
         self.dropout = nn.Dropout(p=dropout_rate)

     def forward(self, x):
+        """The forward process."""
         x = self.fc(x)
         if self.norm is not None:
             x = self.norm(x)
@@ -51,7 +53,8 @@ class StackedLinearClsHead(ClsHead):
     Args:
         num_classes (int): Number of categories.
         in_channels (int): Number of channels in the input feature map.
-        mid_channels (Sequence): Number of channels in the hidden fc layers.
+        mid_channels (Sequence[int]): Number of channels in the hidden fc
+            layers.
         dropout_rate (float): Dropout rate after each hidden fc layer,
             except the last layer. Defaults to 0.
         norm_cfg (dict, optional): Config dict of normalization layer after
@@ -63,18 +66,17 @@ class StackedLinearClsHead(ClsHead):
     def __init__(self,
                  num_classes: int,
                  in_channels: int,
-                 mid_channels: Sequence,
+                 mid_channels: Sequence[int],
                  dropout_rate: float = 0.,
-                 norm_cfg: Dict = None,
-                 act_cfg: Dict = dict(type='ReLU'),
+                 norm_cfg: Optional[Dict] = None,
+                 act_cfg: Optional[Dict] = dict(type='ReLU'),
                  **kwargs):
         super(StackedLinearClsHead, self).__init__(**kwargs)
-        assert num_classes > 0, \
-            f'`num_classes` of StackedLinearClsHead must be a positive ' \
-            f'integer, got {num_classes} instead.'
         self.num_classes = num_classes

         self.in_channels = in_channels
+        if self.num_classes <= 0:
+            raise ValueError(
+                f'num_classes={num_classes} must be a positive integer')
+
         assert isinstance(mid_channels, Sequence), \
             f'`mid_channels` of StackedLinearClsHead should be a sequence, ' \
@@ -88,6 +90,7 @@ class StackedLinearClsHead(ClsHead):
         self._init_layers()

     def _init_layers(self):
+        """"Init layers."""
         self.layers = ModuleList()
         in_channels = self.in_channels
         for hidden_channels in self.mid_channels:
@@ -108,56 +111,25 @@ class StackedLinearClsHead(ClsHead):
                     norm_cfg=None,
                     act_cfg=None))

-    def init_weights(self):
-        self.layers.init_weights()
+    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
+        """The process before the final classification head.

-    def pre_logits(self, x):
-        if isinstance(x, tuple):
-            x = x[-1]
+        The input ``feats`` is a tuple of tensor, and each tensor is the
+        feature of a backbone stage.
+        """
+        x = feats[-1]
         for layer in self.layers[:-1]:
             x = layer(x)
         return x

     @property
     def fc(self):
         """Full connected layer."""
         return self.layers[-1]

-    def simple_test(self, x, softmax=True, post_process=True):
-        """Inference without augmentation.
-
-        Args:
-            x (tuple[Tensor]): The input features.
-                Multi-stage inputs are acceptable but only the last stage will
-                be used to classify. The shape of every item should be
-                ``(num_samples, in_channels)``.
-            softmax (bool): Whether to softmax the classification score.
-            post_process (bool): Whether to do post processing the
-                inference results. It will convert the output to a list.
-
-        Returns:
-            Tensor | list: The inference results.
-
-            - If no post processing, the output is a tensor with shape
-              ``(num_samples, num_classes)``.
-            - If post processing, the output is a multi-dimentional list of
-              float and the dimensions are ``(num_samples, num_classes)``.
-        """
-        x = self.pre_logits(x)
-        cls_score = self.fc(x)
-
-        if softmax:
-            pred = (
-                F.softmax(cls_score, dim=1) if cls_score is not None else None)
-        else:
-            pred = cls_score
-
-        if post_process:
-            return self.post_process(pred)
-        else:
-            return pred
-
-    def forward_train(self, x, gt_label, **kwargs):
-        x = self.pre_logits(x)
-        cls_score = self.fc(x)
-        losses = self.loss(cls_score, gt_label, **kwargs)
-        return losses
+    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
+        """The forward process."""
+        pre_logits = self.pre_logits(feats)
+        # The final classification head.
+        cls_score = self.fc(pre_logits)
+        return cls_score
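A minimal sketch of the refactored StackedLinearClsHead (mirroring the new TestStackedLinearClsHead test below; the channel numbers are illustrative). ``forward`` now returns raw logits; the softmax and post-processing that lived in the removed ``simple_test`` are no longer part of this head:

import torch
from mmcls.registry import MODELS

head = MODELS.build(
    dict(
        type='StackedLinearClsHead',
        in_channels=10,
        num_classes=5,
        mid_channels=[20, 30]))

feats = (torch.rand(4, 10), )        # single backbone stage
pre_logits = head.pre_logits(feats)  # shape (4, 30), output of the last hidden block
cls_score = head(feats)              # raw logits, shape (4, 5)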
@@ -1,9 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
 from collections import OrderedDict
+from typing import List, Optional, Tuple

+import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from mmcv.cnn import build_activation_layer
 from mmcv.cnn.utils.weight_init import trunc_normal_
 from mmcv.runner import Sequential
@@ -20,7 +21,7 @@ class VisionTransformerClsHead(ClsHead):
         num_classes (int): Number of categories excluding the background
             category.
         in_channels (int): Number of channels in the input feature map.
-        hidden_dim (int): Number of the dimensions for hidden layer.
+        hidden_dim (int, optional): Number of the dimensions for hidden layer.
             Defaults to None, which means no extra hidden layer.
         act_cfg (dict): The activation config. Only available during
             pre-training. Defaults to ``dict(type='Tanh')``.
@@ -29,15 +30,14 @@ class VisionTransformerClsHead(ClsHead):
     """

     def __init__(self,
-                 num_classes,
-                 in_channels,
-                 hidden_dim=None,
-                 act_cfg=dict(type='Tanh'),
-                 init_cfg=dict(type='Constant', layer='Linear', val=0),
-                 *args,
+                 num_classes: int,
+                 in_channels: int,
+                 hidden_dim: Optional[int] = None,
+                 act_cfg: dict = dict(type='Tanh'),
+                 init_cfg: dict = dict(type='Constant', layer='Linear', val=0),
                  **kwargs):
         super(VisionTransformerClsHead, self).__init__(
-            init_cfg=init_cfg, *args, **kwargs)
+            init_cfg=init_cfg, **kwargs)
         self.in_channels = in_channels
         self.num_classes = num_classes
         self.hidden_dim = hidden_dim
@@ -50,6 +50,7 @@ class VisionTransformerClsHead(ClsHead):
         self._init_layers()

     def _init_layers(self):
+        """"Init hidden layer if exists."""
         if self.hidden_dim is None:
             layers = [('head', nn.Linear(self.in_channels, self.num_classes))]
         else:
@@ -61,6 +62,7 @@ class VisionTransformerClsHead(ClsHead):
         self.layers = Sequential(OrderedDict(layers))

     def init_weights(self):
+        """"Init weights of hidden layer if exists."""
         super(VisionTransformerClsHead, self).init_weights()
         # Modified from ClassyVision
         if hasattr(self.layers, 'pre_logits'):
@@ -70,54 +72,24 @@ class VisionTransformerClsHead(ClsHead):
                 std=math.sqrt(1 / self.layers.pre_logits.in_features))
             nn.init.zeros_(self.layers.pre_logits.bias)

-    def pre_logits(self, x):
-        if isinstance(x, tuple):
-            x = x[-1]
-        _, cls_token = x
+    def pre_logits(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor:
+        """The process before the final classification head.
+
+        The input ``feats`` is a tuple of list of tensor, and each tensor is
+        the feature of a backbone stage. In ``VisionTransformerClsHead``, we
+        obtain the feature of the last stage and forward in hidden layer if
+        exists.
+        """
+        _, cls_token = feats[-1]
         if self.hidden_dim is None:
             return cls_token
         else:
             x = self.layers.pre_logits(cls_token)
             return self.layers.act(x)

-    def simple_test(self, x, softmax=True, post_process=True):
-        """Inference without augmentation.
-
-        Args:
-            x (tuple[tuple[tensor, tensor]]): The input features.
-                Multi-stage inputs are acceptable but only the last stage will
-                be used to classify. Every item should be a tuple which
-                includes patch token and cls token. The cls token will be used
-                to classify and the shape of it should be
-                ``(num_samples, in_channels)``.
-            softmax (bool): Whether to softmax the classification score.
-            post_process (bool): Whether to do post processing the
-                inference results. It will convert the output to a list.
-
-        Returns:
-            Tensor | list: The inference results.
-
-            - If no post processing, the output is a tensor with shape
-              ``(num_samples, num_classes)``.
-            - If post processing, the output is a multi-dimentional list of
-              float and the dimensions are ``(num_samples, num_classes)``.
-        """
-        x = self.pre_logits(x)
-        cls_score = self.layers.head(x)
-
-        if softmax:
-            pred = (
-                F.softmax(cls_score, dim=1) if cls_score is not None else None)
-        else:
-            pred = cls_score
-
-        if post_process:
-            return self.post_process(pred)
-        else:
-            return pred
-
-    def forward_train(self, x, gt_label, **kwargs):
-        x = self.pre_logits(x)
-        cls_score = self.layers.head(x)
-        losses = self.loss(cls_score, gt_label, **kwargs)
-        return losses
+    def forward(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor:
+        """The forward process."""
+        pre_logits = self.pre_logits(feats)
+        # The final classification head.
+        cls_score = self.layers.head(pre_logits)
+        return cls_score
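For completeness, the same pattern for VisionTransformerClsHead (identical to the new TestVisionTransformerClsHead test that follows; the 7x7x16 patch-token shape is illustrative):

import torch
from mmcls.registry import MODELS

head = MODELS.build(
    dict(type='VisionTransformerClsHead', in_channels=10, num_classes=5,
         hidden_dim=30))

# The last stage yields (patch tokens, cls token); only the cls token is used.
feats = ([torch.rand(4, 7, 7, 16), torch.rand(4, 10)], )
pre_logits = head.pre_logits(feats)  # (4, 30) after the tanh hidden layer
cls_score = head(feats)              # (4, 5) raw logits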
@@ -106,6 +106,205 @@ class TestLinearClsHead(TestCase):
         self.assertEqual(outs.shape, (4, 5))


+class TestVisionTransformerClsHead(TestCase):
+    DEFAULT_ARGS = dict(
+        type='VisionTransformerClsHead', in_channels=10, num_classes=5)
+    fake_feats = ([torch.rand(4, 7, 7, 16), torch.rand(4, 10)], )
+
+    def test_initialize(self):
+        with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
+            MODELS.build({**self.DEFAULT_ARGS, 'num_classes': -5})
+
+        # test vit head default
+        head = MODELS.build(self.DEFAULT_ARGS)
+        assert not hasattr(head.layers, 'pre_logits')
+        assert not hasattr(head.layers, 'act')
+
+        # test vit head hidden_dim
+        head = MODELS.build({**self.DEFAULT_ARGS, 'hidden_dim': 30})
+        assert hasattr(head.layers, 'pre_logits')
+        assert hasattr(head.layers, 'act')
+
+        # test vit head init_weights
+        head = MODELS.build(self.DEFAULT_ARGS)
+        head.init_weights()
+
+        # test vit head init_weights with hidden_dim
+        head = MODELS.build({**self.DEFAULT_ARGS, 'hidden_dim': 30})
+        head.init_weights()
+        assert abs(head.layers.pre_logits.weight).sum() > 0
+
+    def test_pre_logits(self):
+        # test default
+        head = MODELS.build(self.DEFAULT_ARGS)
+        pre_logits = head.pre_logits(self.fake_feats)
+        self.assertIs(pre_logits, self.fake_feats[-1][1])
+
+        # test hidden_dim
+        head = MODELS.build({**self.DEFAULT_ARGS, 'hidden_dim': 30})
+        pre_logits = head.pre_logits(self.fake_feats)
+        self.assertEqual(pre_logits.shape, (4, 30))
+
+    def test_forward(self):
+        # test default
+        head = MODELS.build(self.DEFAULT_ARGS)
+        outs = head(self.fake_feats)
+        self.assertEqual(outs.shape, (4, 5))
+
+        # test hidden_dim
+        head = MODELS.build({**self.DEFAULT_ARGS, 'hidden_dim': 30})
+        outs = head(self.fake_feats)
+        self.assertEqual(outs.shape, (4, 5))
+
+
+class TestDeiTClsHead(TestVisionTransformerClsHead):
+    DEFAULT_ARGS = dict(type='DeiTClsHead', in_channels=10, num_classes=5)
+    fake_feats = ([
+        torch.rand(4, 7, 7, 16),
+        torch.rand(4, 10),
+        torch.rand(4, 10)
+    ], )
+
+    def test_pre_logits(self):
+        # test default
+        head = MODELS.build(self.DEFAULT_ARGS)
+        cls_token, dist_token = head.pre_logits(self.fake_feats)
+        self.assertIs(cls_token, self.fake_feats[-1][1])
+        self.assertIs(dist_token, self.fake_feats[-1][2])
+
+        # test hidden_dim
+        head = MODELS.build({**self.DEFAULT_ARGS, 'hidden_dim': 30})
+        cls_token, dist_token = head.pre_logits(self.fake_feats)
+        self.assertEqual(cls_token.shape, (4, 30))
+        self.assertEqual(dist_token.shape, (4, 30))
+
+
+class TestConformerHead(TestCase):
+    DEFAULT_ARGS = dict(
+        type='ConformerHead', in_channels=[64, 96], num_classes=5)
+    fake_feats = ([torch.rand(4, 64), torch.rand(4, 96)], )
+
+    def test_initialize(self):
+        with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
+            MODELS.build({**self.DEFAULT_ARGS, 'num_classes': -5})
+
+        # test default
+        head = MODELS.build(self.DEFAULT_ARGS)
+        assert hasattr(head, 'conv_cls_head')
+        assert hasattr(head, 'trans_cls_head')
+
+        # test init_weights
+        head = MODELS.build(self.DEFAULT_ARGS)
+        head.init_weights()
+        assert abs(head.conv_cls_head.weight).sum() > 0
+        assert abs(head.trans_cls_head.weight).sum() > 0
+
+    def test_pre_logits(self):
+        # test default
+        head = MODELS.build(self.DEFAULT_ARGS)
+        pre_logits = head.pre_logits(self.fake_feats)
+        self.assertIs(pre_logits, self.fake_feats[-1])
+
+    def test_forward(self):
+        head = MODELS.build(self.DEFAULT_ARGS)
+        outs = head(self.fake_feats)
+        self.assertEqual(outs[0].shape, (4, 5))
+        self.assertEqual(outs[1].shape, (4, 5))
+
+    def test_loss(self):
+        data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]
+
+        # with cal_acc = False
+        head = MODELS.build(self.DEFAULT_ARGS)
+
+        losses = head.loss(self.fake_feats, data_samples)
+        self.assertEqual(losses.keys(), {'loss'})
+        self.assertGreater(losses['loss'].item(), 0)
+
+        # with cal_acc = True
+        cfg = {**self.DEFAULT_ARGS, 'topk': (1, 2), 'cal_acc': True}
+        head = MODELS.build(cfg)
+
+        losses = head.loss(self.fake_feats, data_samples)
+        self.assertEqual(losses.keys(),
+                         {'loss', 'accuracy_top-1', 'accuracy_top-2'})
+        self.assertGreater(losses['loss'].item(), 0)
+
+        # test assertion when cal_acc but data is batch agumented.
+        data_samples = [
+            sample.set_gt_score(torch.rand(5)) for sample in data_samples
+        ]
+        cfg = {
+            **self.DEFAULT_ARGS, 'cal_acc': True,
+            'loss': dict(type='CrossEntropyLoss', use_soft=True)
+        }
+        head = MODELS.build(cfg)
+        with self.assertRaisesRegex(AssertionError, 'batch augmentation'):
+            head.loss(self.fake_feats, data_samples)
+
+    def test_predict(self):
+        data_samples = [ClsDataSample().set_gt_label(1) for _ in range(4)]
+        head = MODELS.build(self.DEFAULT_ARGS)
+
+        # with without data_samples
+        predictions = head.predict(self.fake_feats)
+        self.assertTrue(is_seq_of(predictions, ClsDataSample))
+        for pred in predictions:
+            self.assertIn('label', pred.pred_label)
+            self.assertIn('score', pred.pred_label)
+
+        # with with data_samples
+        predictions = head.predict(self.fake_feats, data_samples)
+        self.assertTrue(is_seq_of(predictions, ClsDataSample))
+        for sample, pred in zip(data_samples, predictions):
+            self.assertIs(sample, pred)
+            self.assertIn('label', pred.pred_label)
+            self.assertIn('score', pred.pred_label)
+
+
+class TestStackedLinearClsHead(TestCase):
+    DEFAULT_ARGS = dict(
+        type='StackedLinearClsHead', in_channels=10, num_classes=5)
+    fake_feats = (torch.rand(4, 10), )
+
+    def test_initialize(self):
+        with self.assertRaisesRegex(ValueError, 'num_classes=-5 must be'):
+            MODELS.build({
+                **self.DEFAULT_ARGS, 'num_classes': -5,
+                'mid_channels': 10
+            })
+
+        # test mid_channels
+        with self.assertRaisesRegex(AssertionError, 'should be a sequence'):
+            MODELS.build({**self.DEFAULT_ARGS, 'mid_channels': 10})
+
+        # test default
+        head = MODELS.build({**self.DEFAULT_ARGS, 'mid_channels': [20]})
+        assert len(head.layers) == 2
+        head.init_weights()
+
+    def test_pre_logits(self):
+        # test default
+        head = MODELS.build({**self.DEFAULT_ARGS, 'mid_channels': [20, 30]})
+        pre_logits = head.pre_logits(self.fake_feats)
+        self.assertEqual(pre_logits.shape, (4, 30))
+
+    def test_forward(self):
+        # test default
+        head = MODELS.build({**self.DEFAULT_ARGS, 'mid_channels': [20, 30]})
+        outs = head(self.fake_feats)
+        self.assertEqual(outs.shape, (4, 5))
+
+        head = MODELS.build({
+            **self.DEFAULT_ARGS, 'mid_channels': [8, 10],
+            'dropout_rate': 0.2,
+            'norm_cfg': dict(type='BN1d'),
+            'act_cfg': dict(type='HSwish')
+        })
+        outs = head(self.fake_feats)
+        self.assertEqual(outs.shape, (4, 5))
+
+
 """Temporarily disabled.
 @pytest.mark.parametrize('feat', [torch.rand(4, 10), (torch.rand(4, 10), )])
 def test_multilabel_head(feat):
@@ -215,125 +414,4 @@ def test_stacked_linear_cls_head(feat):
     losses = head.forward_train(feat, fake_gt_label)
     assert losses['loss'].item() > 0

-
-def test_vit_head():
-    fake_features = ([torch.rand(4, 7, 7, 16), torch.rand(4, 100)], )
-    fake_gt_label = torch.randint(0, 10, (4, ))
-
-    # test vit head forward
-    head = VisionTransformerClsHead(10, 100)
-    losses = head.forward_train(fake_features, fake_gt_label)
-    assert not hasattr(head.layers, 'pre_logits')
-    assert not hasattr(head.layers, 'act')
-    assert losses['loss'].item() > 0
-
-    # test vit head forward with hidden layer
-    head = VisionTransformerClsHead(10, 100, hidden_dim=20)
-    losses = head.forward_train(fake_features, fake_gt_label)
-    assert hasattr(head.layers, 'pre_logits') and hasattr(head.layers, 'act')
-    assert losses['loss'].item() > 0
-
-    # test vit head init_weights
-    head = VisionTransformerClsHead(10, 100, hidden_dim=20)
-    head.init_weights()
-    assert abs(head.layers.pre_logits.weight).sum() > 0
-
-    head = VisionTransformerClsHead(10, 100, hidden_dim=20)
-    # test simple_test with post_process
-    pred = head.simple_test(fake_features)
-    assert isinstance(pred, list) and len(pred) == 4
-    with patch('torch.onnx.is_in_onnx_export', return_value=True):
-        pred = head.simple_test(fake_features)
-        assert pred.shape == (4, 10)
-
-    # test simple_test without post_process
-    pred = head.simple_test(fake_features, post_process=False)
-    assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
-    logits = head.simple_test(fake_features, softmax=False, post_process=False)
-    torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1))
-
-    # test pre_logits
-    features = head.pre_logits(fake_features)
-    assert features.shape == (4, 20)
-
-    # test assertion
-    with pytest.raises(ValueError):
-        VisionTransformerClsHead(-1, 100)
-
-
-def test_conformer_head():
-    fake_features = ([torch.rand(4, 64), torch.rand(4, 96)], )
-    fake_gt_label = torch.randint(0, 10, (4, ))
-
-    # test conformer head forward
-    head = ConformerHead(num_classes=10, in_channels=[64, 96])
-    losses = head.forward_train(fake_features, fake_gt_label)
-    assert losses['loss'].item() > 0
-
-    # test simple_test with post_process
-    pred = head.simple_test(fake_features)
-    assert isinstance(pred, list) and len(pred) == 4
-    with patch('torch.onnx.is_in_onnx_export', return_value=True):
-        pred = head.simple_test(fake_features)
-        assert pred.shape == (4, 10)
-
-    # test simple_test without post_process
-    pred = head.simple_test(fake_features, post_process=False)
-    assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
-    logits = head.simple_test(fake_features, softmax=False, post_process=False)
-    torch.testing.assert_allclose(pred, torch.softmax(sum(logits), dim=1))
-
-    # test pre_logits
-    features = head.pre_logits(fake_features)
-    assert features is fake_features[0]
-
-
-def test_deit_head():
-    fake_features = ([
-        torch.rand(4, 7, 7, 16),
-        torch.rand(4, 100),
-        torch.rand(4, 100)
-    ], )
-    fake_gt_label = torch.randint(0, 10, (4, ))
-
-    # test deit head forward
-    head = DeiTClsHead(num_classes=10, in_channels=100)
-    losses = head.forward_train(fake_features, fake_gt_label)
-    assert not hasattr(head.layers, 'pre_logits')
-    assert not hasattr(head.layers, 'act')
-    assert losses['loss'].item() > 0
-
-    # test deit head forward with hidden layer
-    head = DeiTClsHead(num_classes=10, in_channels=100, hidden_dim=20)
-    losses = head.forward_train(fake_features, fake_gt_label)
-    assert hasattr(head.layers, 'pre_logits') and hasattr(head.layers, 'act')
-    assert losses['loss'].item() > 0
-
-    # test deit head init_weights
-    head = DeiTClsHead(10, 100, hidden_dim=20)
-    head.init_weights()
-    assert abs(head.layers.pre_logits.weight).sum() > 0
-
-    head = DeiTClsHead(10, 100, hidden_dim=20)
-    # test simple_test with post_process
-    pred = head.simple_test(fake_features)
-    assert isinstance(pred, list) and len(pred) == 4
-    with patch('torch.onnx.is_in_onnx_export', return_value=True):
-        pred = head.simple_test(fake_features)
-        assert pred.shape == (4, 10)
-
-    # test simple_test without post_process
-    pred = head.simple_test(fake_features, post_process=False)
-    assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
-    logits = head.simple_test(fake_features, softmax=False, post_process=False)
-    torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1))
-
-    # test pre_logits
-    cls_token, dist_token = head.pre_logits(fake_features)
-    assert cls_token.shape == (4, 20)
-    assert dist_token.shape == (4, 20)
-
-    # test assertion
-    with pytest.raises(ValueError):
-        DeiTClsHead(-1, 100)
 """