# mmclassification/mmcls/models/heads/deit_head.py
# (79 lines, 3.0 KiB, Python)
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from mmcls.utils import get_root_logger
from ..builder import HEADS
from .vision_transformer_head import VisionTransformerClsHead
@HEADS.register_module()
class DeiTClsHead(VisionTransformerClsHead):
    """Classification head for the distilled DeiT model.

    Extends :class:`VisionTransformerClsHead` with an extra linear head
    (``head_dist``) applied to the distillation token. At inference time
    the prediction is the average of the cls-token and dist-token scores.
    """

    def __init__(self, *args, **kwargs):
        super(DeiTClsHead, self).__init__(*args, **kwargs)
        # The distillation head mirrors the main head: its input is either
        # the raw token features or the hidden projection, depending on
        # whether the parent configured an intermediate layer.
        dist_in = (self.in_channels
                   if self.hidden_dim is None else self.hidden_dim)
        self.layers.add_module('head_dist',
                               nn.Linear(dist_in, self.num_classes))

    def pre_logits(self, x):
        """Pick the cls and dist tokens from the last stage, projecting
        them through the hidden layer when one is configured."""
        if isinstance(x, tuple):
            x = x[-1]
        _, cls_token, dist_token = x

        if self.hidden_dim is None:
            return cls_token, dist_token

        project = self.layers.pre_logits
        activate = self.layers.act
        return activate(project(cls_token)), activate(project(dist_token))

    def simple_test(self, x, softmax=True, post_process=True):
        """Inference without augmentation.

        Args:
            x (tuple[tuple[tensor, tensor, tensor]]): The input features.
                Multi-stage inputs are acceptable but only the last stage
                will be used to classify. Every item should be a tuple of
                (patch token, cls token, dist token); the cls and dist
                tokens, each of shape ``(num_samples, in_channels)``, are
                used for classification.
            softmax (bool): Whether to softmax the classification score.
            post_process (bool): Whether to post-process the inference
                results, converting the output to a list.

        Returns:
            Tensor | list: The inference results.

            - Without post processing: a tensor of shape
              ``(num_samples, num_classes)``.
            - With post processing: a nested list of floats with
              dimensions ``(num_samples, num_classes)``.
        """
        cls_token, dist_token = self.pre_logits(x)
        # Average the predictions of the two heads.
        cls_score = (self.layers.head(cls_token) +
                     self.layers.head_dist(dist_token)) / 2

        if softmax:
            pred = (F.softmax(cls_score, dim=1)
                    if cls_score is not None else None)
        else:
            pred = cls_score

        return self.post_process(pred) if post_process else pred

    def forward_train(self, x, gt_label):
        """Compute training loss; training the distilled DeiT head is not
        supported, so a warning is emitted on every call."""
        logger = get_root_logger()
        logger.warning("MMClassification doesn't support to train the "
                       'distilled version DeiT.')
        cls_token, dist_token = self.pre_logits(x)
        # Same averaged score as at inference time.
        cls_score = (self.layers.head(cls_token) +
                     self.layers.head_dist(dist_token)) / 2
        return self.loss(cls_score, gt_label)