284 lines
10 KiB
Python
284 lines
10 KiB
Python
""" Classifier head and layer factory

Hacked together by / Copyright 2020 Ross Wightman
"""
|
|
from collections import OrderedDict
|
|
from functools import partial
|
|
from typing import Optional, Union, Callable
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import functional as F
|
|
|
|
from .adaptive_avgmax_pool import SelectAdaptivePool2d
|
|
from .create_act import get_act_layer
|
|
from .create_norm import get_norm_layer
|
|
|
|
|
|
def _create_pool(
        num_features: int,
        num_classes: int,
        pool_type: str = 'avg',
        use_conv: bool = False,
        input_fmt: Optional[str] = None,
):
    """Build the global pooling layer used in front of a classifier.

    Args:
        num_features: Number of channels entering the pooling layer.
        num_classes: Number of classifier outputs. Accepted for signature symmetry
            with the other factory helpers; not used by this function.
        pool_type: Global pooling type; a falsy value (e.g. '') disables pooling.
        use_conv: True when a 1x1 conv classifier follows the pool, in which case
            the pooled output is not flattened.
        input_fmt: Input layout forwarded to ``SelectAdaptivePool2d``.

    Returns:
        Tuple ``(pool_module, num_pooled_features)``.
    """
    # Only a Linear classifier after a real pooling op needs a flat (N, C) tensor;
    # pass-through pooling and conv classifiers keep the spatial dims.
    should_flatten = bool(pool_type) and not use_conv
    pool = SelectAdaptivePool2d(
        pool_type=pool_type,
        flatten=should_flatten,
        input_fmt=input_fmt,
    )
    # feat_mult() reports the channel multiplier of the selected pool type
    # (presumably > 1 for pool types that concatenate outputs — see SelectAdaptivePool2d).
    return pool, pool.feat_mult() * num_features
|
|
|
|
|
|
def _create_fc(num_features: int, num_classes: int, use_conv: bool = False):
    """Build the final classification layer.

    Args:
        num_features: Number of input features/channels.
        num_classes: Number of outputs; a value ``<= 0`` yields ``nn.Identity``
            (classifier removed, features passed through).
        use_conv: Build a 1x1 ``nn.Conv2d`` instead of an ``nn.Linear``.

    Returns:
        The classifier module.
    """
    if num_classes > 0:
        if use_conv:
            return nn.Conv2d(num_features, num_classes, 1, bias=True)
        return nn.Linear(num_features, num_classes, bias=True)
    # No classes requested: pass-through (no classifier).
    return nn.Identity()
|
|
|
|
|
|
def create_classifier(
        num_features: int,
        num_classes: int,
        pool_type: str = 'avg',
        use_conv: bool = False,
        input_fmt: str = 'NCHW',
        drop_rate: Optional[float] = None,
):
    """Create the global pool, optional dropout, and classifier layers of a head.

    Args:
        num_features: Number of channels entering the head.
        num_classes: Number of classifier outputs; ``<= 0`` gives an identity classifier.
        pool_type: Global pooling type; a falsy value disables pooling.
        use_conv: Use a 1x1 conv classifier instead of a Linear layer.
        input_fmt: Input layout forwarded to the pooling layer.
        drop_rate: If not None, an ``nn.Dropout`` with this rate is also created.

    Returns:
        ``(global_pool, fc)`` when ``drop_rate`` is None, otherwise
        ``(global_pool, dropout, fc)``.
    """
    pool, pooled_features = _create_pool(
        num_features,
        num_classes,
        pool_type,
        use_conv=use_conv,
        input_fmt=input_fmt,
    )
    classifier = _create_fc(pooled_features, num_classes, use_conv=use_conv)
    if drop_rate is None:
        return pool, classifier
    return pool, nn.Dropout(drop_rate), classifier
|
|
|
|
|
|
class ClassifierHead(nn.Module):
    """Classifier head w/ configurable global pooling and dropout."""

    def __init__(
            self,
            in_features: int,
            num_classes: int,
            pool_type: str = 'avg',
            drop_rate: float = 0.,
            use_conv: bool = False,
            input_fmt: str = 'NCHW',
    ):
        """
        Args:
            in_features: The number of input features.
            num_classes: The number of classes for the final classifier layer (output).
            pool_type: Global pooling type, pooling disabled if empty string ('').
            drop_rate: Pre-classifier dropout rate.
            use_conv: Use a 1x1 conv classifier instead of a Linear layer.
            input_fmt: Input tensor layout passed to the pooling layer.
        """
        super().__init__()
        self.in_features = in_features
        self.use_conv = use_conv
        self.input_fmt = input_fmt

        pool, classifier = create_classifier(
            in_features,
            num_classes,
            pool_type,
            use_conv=use_conv,
            input_fmt=input_fmt,
        )
        # Submodules are registered in the order: global_pool, drop, fc, flatten.
        self.global_pool = pool
        self.drop = nn.Dropout(drop_rate)
        self.fc = classifier
        # A conv classifier after real pooling leaves (N, C, 1, 1); flatten it at the end.
        self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()

    def reset(self, num_classes: int, pool_type: Optional[str] = None):
        """Rebuild the classifier, and the pooling layer if ``pool_type`` changes.

        Args:
            num_classes: New number of classes (``<= 0`` removes the classifier).
            pool_type: New pooling type, or None to keep the current pooling.
        """
        pool_changed = pool_type is not None and pool_type != self.global_pool.pool_type
        if not pool_changed:
            # Keep the existing pool; only the classifier width changes.
            pooled_features = self.in_features * self.global_pool.feat_mult()
            self.fc = _create_fc(pooled_features, num_classes, use_conv=self.use_conv)
            return
        self.global_pool, self.fc = create_classifier(
            self.in_features,
            num_classes,
            pool_type=pool_type,
            use_conv=self.use_conv,
            input_fmt=self.input_fmt,
        )
        self.flatten = nn.Flatten(1) if self.use_conv and pool_type else nn.Identity()

    def forward(self, x, pre_logits: bool = False):
        x = self.drop(self.global_pool(x))
        if not pre_logits:
            x = self.fc(x)
        return self.flatten(x)
|
|
|
|
|
|
class NormMlpClassifierHead(nn.Module):
    """ A Pool -> Norm -> Mlp Classifier Head for '2D' NCHW tensors
    """
    def __init__(
            self,
            in_features: int,
            num_classes: int,
            hidden_size: Optional[int] = None,
            pool_type: str = 'avg',
            drop_rate: float = 0.,
            norm_layer: Union[str, Callable] = 'layernorm2d',
            act_layer: Union[str, Callable] = 'tanh',
    ):
        """
        Args:
            in_features: The number of input features.
            num_classes: The number of classes for the final classifier layer (output).
            hidden_size: The hidden size of the MLP (pre-logits FC layer) if not None.
            pool_type: Global pooling type, pooling disabled if empty string ('').
                With pooling disabled, 1x1 convs are used in place of Linear layers.
            drop_rate: Pre-classifier dropout rate.
            norm_layer: Normalization layer type.
            act_layer: MLP activation layer type (only used if hidden_size is not None).
        """
        super().__init__()
        self.in_features = in_features
        self.hidden_size = hidden_size
        self.num_features = in_features
        # No pooling means spatial dims survive, so dense layers become 1x1 convs.
        self.use_conv = not pool_type
        norm_cls = get_norm_layer(norm_layer)
        act_cls = get_act_layer(act_layer)
        dense = self._dense_layer(self.use_conv)

        # Creation order (pool, norm, flatten, pre_logits, drop, fc) is kept fixed.
        self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
        self.norm = norm_cls(in_features)
        self.flatten = nn.Identity() if self.use_conv else nn.Flatten(1)
        if not hidden_size:
            self.pre_logits = nn.Identity()
        else:
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', dense(in_features, hidden_size)),
                ('act', act_cls()),
            ]))
            self.num_features = hidden_size
        self.drop = nn.Dropout(drop_rate)
        self.fc = dense(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    @staticmethod
    def _dense_layer(use_conv: bool) -> Callable:
        """Return the constructor for dense projections: 1x1 Conv2d or Linear."""
        return partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

    def reset(self, num_classes: int, pool_type: Optional[str] = None):
        """Rebuild the classifier, optionally switching the pooling type.

        When the pooling change flips between conv and linear mode, the pre-logits
        projection is converted to the other layer type with its weights preserved.

        Args:
            num_classes: New number of classes (``<= 0`` removes the classifier).
            pool_type: New pooling type, or None to keep the current pooling.
        """
        if pool_type is not None:
            self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
            self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
        self.use_conv = self.global_pool.is_identity()
        dense = self._dense_layer(self.use_conv)
        if self.hidden_size:
            old_fc = self.pre_logits.fc
            conv_but_needs_linear = isinstance(old_fc, nn.Conv2d) and not self.use_conv
            linear_but_needs_conv = isinstance(old_fc, nn.Linear) and self.use_conv
            if conv_but_needs_linear or linear_but_needs_conv:
                with torch.no_grad():
                    replacement = dense(self.in_features, self.hidden_size)
                    # Conv (out, in, 1, 1) <-> Linear (out, in) weights differ only by trailing unit dims.
                    replacement.weight.copy_(old_fc.weight.reshape(replacement.weight.shape))
                    replacement.bias.copy_(old_fc.bias)
                    self.pre_logits.fc = replacement
        self.fc = dense(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x, pre_logits: bool = False):
        x = self.global_pool(x)
        x = self.norm(x)
        x = self.flatten(x)
        x = self.pre_logits(x)
        x = self.drop(x)
        return x if pre_logits else self.fc(x)
|
|
|
|
|
|
class ClNormMlpClassifierHead(nn.Module):
    """ A Pool -> Norm -> Mlp Classifier Head for n-D NxxC tensors
    """
    # Pool types understood by _global_pool ('' disables pooling).
    _POOL_TYPES = ('', 'avg', 'max', 'avgmax')

    def __init__(
            self,
            in_features: int,
            num_classes: int,
            hidden_size: Optional[int] = None,
            pool_type: str = 'avg',
            drop_rate: float = 0.,
            norm_layer: Union[str, Callable] = 'layernorm',
            act_layer: Union[str, Callable] = 'gelu',
            input_fmt: str = 'NHWC',
    ):
        """
        Args:
            in_features: The number of input features.
            num_classes: The number of classes for the final classifier layer (output).
            hidden_size: The hidden size of the MLP (pre-logits FC layer) if not None.
            pool_type: Global pooling type, pooling disabled if empty string ('').
            drop_rate: Pre-classifier dropout rate.
            norm_layer: Normalization layer type.
            act_layer: MLP activation layer type (only used if hidden_size is not None).
            input_fmt: Channels-last input layout: 'NHWC' pools over dims (1, 2),
                'NLC' pools over dim 1.
        """
        super().__init__()
        self.in_features = in_features
        self.hidden_size = hidden_size
        self.num_features = in_features
        assert pool_type in self._POOL_TYPES, f'Unsupported pool_type {pool_type!r}'
        self.pool_type = pool_type
        assert input_fmt in ('NHWC', 'NLC'), f'Unsupported input_fmt {input_fmt!r}'
        self.pool_dim = 1 if input_fmt == 'NLC' else (1, 2)
        norm_layer = get_norm_layer(norm_layer)
        act_layer = get_act_layer(act_layer)

        self.norm = norm_layer(in_features)
        if hidden_size:
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(in_features, hidden_size)),
                ('act', act_layer()),
            ]))
            self.num_features = hidden_size
        else:
            self.pre_logits = nn.Identity()
        self.drop = nn.Dropout(drop_rate)
        self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
        """Re-create the classifier, optionally changing pooling and dropping norm/pre-logits.

        Args:
            num_classes: New number of classes; ``<= 0`` replaces the classifier with identity.
            pool_type: New global pooling type (validated like in ``__init__``), or None
                to keep the current one.
            reset_other: If True, replace ``norm`` and ``pre_logits`` with identity so the
                classifier consumes the pooled ``in_features`` directly.
        """
        if pool_type is not None:
            # Validate here too: an unknown type would otherwise silently skip pooling.
            assert pool_type in self._POOL_TYPES, f'Unsupported pool_type {pool_type!r}'
            self.pool_type = pool_type
        if reset_other:
            self.pre_logits = nn.Identity()
            self.norm = nn.Identity()
            # Without the pre-logits projection the classifier input width is the raw
            # feature count, not hidden_size; otherwise the new fc would mismatch.
            self.num_features = self.in_features
        self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def _global_pool(self, x):
        """Pool over the spatial/sequence dims given by ``self.pool_dim`` ('' = no pooling)."""
        if self.pool_type:
            if self.pool_type == 'avg':
                x = x.mean(dim=self.pool_dim)
            elif self.pool_type == 'max':
                x = x.amax(dim=self.pool_dim)
            elif self.pool_type == 'avgmax':
                x = 0.5 * (x.amax(dim=self.pool_dim) + x.mean(dim=self.pool_dim))
        return x

    def forward(self, x, pre_logits: bool = False):
        x = self._global_pool(x)
        x = self.norm(x)
        x = self.pre_logits(x)
        x = self.drop(x)
        if pre_logits:
            return x
        x = self.fc(x)
        return x
|