# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F

from mmcls.utils import get_root_logger
from ..builder import HEADS
from .vision_transformer_head import VisionTransformerClsHead

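# A hedged configuration sketch: in the mmcls DeiT configs this head is
# declared roughly as below. The field values follow the DeiT-base settings
# and are illustrative, not authoritative:
#
#     head=dict(
#         type='DeiTClsHead',
#         num_classes=1000,
#         in_channels=768,
#         loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
#     )
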
@HEADS.register_module()
class DeiTClsHead(VisionTransformerClsHead):
    """Distilled Vision Transformer classifier head.

    Compared with ``VisionTransformerClsHead``, this head adds an extra
    linear head ``head_dist`` for the distillation token introduced by
    DeiT, and classifies with the average of the two heads' scores.
    """

    def __init__(self, *args, **kwargs):
        super(DeiTClsHead, self).__init__(*args, **kwargs)
        # Build the distillation head. Like the main ``head`` created by
        # the parent class, it maps directly from ``in_channels`` when no
        # hidden projection is used, and from ``hidden_dim`` otherwise.
        if self.hidden_dim is None:
            head_dist = nn.Linear(self.in_channels, self.num_classes)
        else:
            head_dist = nn.Linear(self.hidden_dim, self.num_classes)
        self.layers.add_module('head_dist', head_dist)

    def pre_logits(self, x):
        # Multi-stage backbones return a tuple of stage outputs; only the
        # last stage is used for classification.
        if isinstance(x, tuple):
            x = x[-1]
        _, cls_token, dist_token = x

        if self.hidden_dim is None:
            return cls_token, dist_token
        else:
            # Project both tokens through the shared pre-logits layer
            # before the classification heads.
            cls_token = self.layers.act(self.layers.pre_logits(cls_token))
            dist_token = self.layers.act(self.layers.pre_logits(dist_token))
            return cls_token, dist_token

    def simple_test(self, x, softmax=True, post_process=True):
        """Inference without augmentation.

        Args:
            x (tuple[tuple[tensor, tensor, tensor]]): The input features.
                Multi-stage inputs are acceptable, but only the last stage
                will be used to classify. Every item should be a tuple
                which includes the patch token, cls token and dist token.
                The cls token and dist token will be used to classify, and
                the shape of both should be ``(num_samples, in_channels)``.
            softmax (bool): Whether to softmax the classification score.
            post_process (bool): Whether to do post processing on the
                inference results. It will convert the output to a list.

        Returns:
            Tensor | list: The inference results.

            - If no post processing, the output is a tensor with shape
              ``(num_samples, num_classes)``.
            - If post processing, the output is a multi-dimensional list
              of float, and the dimensions are
              ``(num_samples, num_classes)``.
        """
        cls_token, dist_token = self.pre_logits(x)
        # The classification score is the average of the cls head and the
        # distillation head.
        cls_score = (self.layers.head(cls_token) +
                     self.layers.head_dist(dist_token)) / 2

        if softmax:
            pred = F.softmax(
                cls_score, dim=1) if cls_score is not None else None
        else:
            pred = cls_score

        if post_process:
            return self.post_process(pred)
        else:
            return pred

    def forward_train(self, x, gt_label):
        logger = get_root_logger()
        logger.warning("MMClassification doesn't support training the "
                       'distilled version of DeiT.')
        cls_token, dist_token = self.pre_logits(x)
        # Training reuses the averaged score; the distillation loss of the
        # original DeiT recipe is not supported (see the warning above).
        cls_score = (self.layers.head(cls_token) +
                     self.layers.head_dist(dist_token)) / 2
        losses = self.loss(cls_score, gt_label)
        return losses
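

# A minimal usage sketch, assuming DeiT-base style sizes; it is an
# illustration, not part of the upstream module. Run it as a module inside
# the package (e.g. ``python -m mmcls.models.heads.deit_head``, assuming
# that is this file's path) so the relative imports above resolve. All
# field values and tensor shapes below are assumptions.
if __name__ == '__main__':
    import torch

    head = DeiTClsHead(
        num_classes=1000,
        in_channels=768,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0))
    # ``simple_test`` expects a tuple of stage outputs, each stage being a
    # (patch_token, cls_token, dist_token) tuple.
    patch_token = torch.rand(1, 196, 768)
    cls_token = torch.rand(1, 768)
    dist_token = torch.rand(1, 768)
    pred = head.simple_test(((patch_token, cls_token, dist_token), ))
    print(len(pred), pred[0].shape)  # 1 sample with 1000 class scores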