Merge remote-tracking branch 'origin/main' into levit_efficientformer_redux

commit 9d03c6f526
timm/layers/__init__.py

@@ -3,7 +3,7 @@ from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
 from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
 from .blur_pool import BlurPool2d
-from .classifier import ClassifierHead, create_classifier
+from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead
 from .cond_conv2d import CondConv2d, get_condconv_initializer
 from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
     set_layer_config
timm/layers/classifier.py

@@ -2,10 +2,17 @@
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
-from torch import nn as nn
+from collections import OrderedDict
+from functools import partial
+from typing import Optional, Union, Callable
+
+import torch
+import torch.nn as nn
 from torch.nn import functional as F
 
 from .adaptive_avgmax_pool import SelectAdaptivePool2d
+from .create_act import get_act_layer
+from .create_norm import get_norm_layer
 
 
 def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
@@ -38,7 +45,21 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
 class ClassifierHead(nn.Module):
     """Classifier head w/ configurable global pooling and dropout."""
 
-    def __init__(self, in_features, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
+    def __init__(
+            self,
+            in_features: int,
+            num_classes: int,
+            pool_type: str = 'avg',
+            drop_rate: float = 0.,
+            use_conv: bool = False,
+    ):
+        """
+        Args:
+            in_features: The number of input features.
+            num_classes: The number of classes for the final classifier layer (output).
+            pool_type: Global pooling type, pooling disabled if empty string ('').
+            drop_rate: Pre-classifier dropout rate.
+        """
         super(ClassifierHead, self).__init__()
         self.drop_rate = drop_rate
         self.in_features = in_features
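Note (illustration, not part of the commit): a minimal usage sketch of the re-typed ClassifierHead, assuming a 768-channel NCHW feature map and an installed timm with this change.

import torch
from timm.layers import ClassifierHead

head = ClassifierHead(in_features=768, num_classes=10, pool_type='avg', drop_rate=0.1)
feats = torch.randn(2, 768, 7, 7)   # backbone feature map, NCHW
logits = head(feats)                # global pool -> dropout -> fc, shape (2, 10)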
@@ -65,3 +86,76 @@ class ClassifierHead(nn.Module):
         else:
             x = self.fc(x)
             return self.flatten(x)
+
+
+class NormMlpClassifierHead(nn.Module):
+
+    def __init__(
+            self,
+            in_features: int,
+            num_classes: int,
+            hidden_size: Optional[int] = None,
+            pool_type: str = 'avg',
+            drop_rate: float = 0.,
+            norm_layer: Union[str, Callable] = 'layernorm2d',
+            act_layer: Union[str, Callable] = 'tanh',
+    ):
+        """
+        Args:
+            in_features: The number of input features.
+            num_classes: The number of classes for the final classifier layer (output).
+            hidden_size: The hidden size of the MLP (pre-logits FC layer) if not None.
+            pool_type: Global pooling type, pooling disabled if empty string ('').
+            drop_rate: Pre-classifier dropout rate.
+            norm_layer: Normalization layer type.
+            act_layer: MLP activation layer type (only used if hidden_size is not None).
+        """
+        super().__init__()
+        self.drop_rate = drop_rate
+        self.in_features = in_features
+        self.hidden_size = hidden_size
+        self.num_features = in_features
+        self.use_conv = not pool_type
+        norm_layer = get_norm_layer(norm_layer)
+        act_layer = get_act_layer(act_layer)
+        linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
+
+        self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
+        self.norm = norm_layer(in_features)
+        self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
+        if hidden_size:
+            self.pre_logits = nn.Sequential(OrderedDict([
+                ('fc', linear_layer(in_features, hidden_size)),
+                ('act', act_layer()),
+            ]))
+            self.num_features = hidden_size
+        else:
+            self.pre_logits = nn.Identity()
+        self.drop = nn.Dropout(self.drop_rate)
+        self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    def reset(self, num_classes, global_pool=None):
+        if global_pool is not None:
+            self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+            self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
+            self.use_conv = self.global_pool.is_identity()
+        linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
+        if self.hidden_size:
+            if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or
+                    (isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)):
+                with torch.no_grad():
+                    new_fc = linear_layer(self.in_features, self.hidden_size)
+                    new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape))
+                    new_fc.bias.copy_(self.pre_logits.fc.bias)
+                    self.pre_logits.fc = new_fc
+        self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward(self, x, pre_logits: bool = False):
+        x = self.global_pool(x)
+        x = self.norm(x)
+        x = self.flatten(x)
+        x = self.pre_logits(x)
+        if pre_logits:
+            return x
+        x = self.fc(x)
+        return x
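Note (illustration, not part of the commit): a small sketch of the new NormMlpClassifierHead (pool -> norm -> optional pre-logits MLP -> classifier fc); the feature sizes are chosen only for illustration.

import torch
from timm.layers import NormMlpClassifierHead

head = NormMlpClassifierHead(in_features=1536, num_classes=1000, hidden_size=1536, pool_type='avg')
feats = torch.randn(2, 1536, 8, 8)
logits = head(feats)                    # (2, 1000)
embeds = head(feats, pre_logits=True)   # (2, 1536): pooled, normed, MLP features without the classifier
head.reset(0)                           # drop the classifier (fc becomes nn.Identity), keep the MLP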
timm/models/convnext.py

@@ -39,6 +39,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
 
 from collections import OrderedDict
 from functools import partial
+from typing import Callable, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -46,6 +47,7 @@ import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \
     LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
+from timm.layers import NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq
 from ._pretrained import generate_default_cfgs
@@ -188,48 +190,50 @@ class ConvNeXt(nn.Module):
 
     def __init__(
             self,
-            in_chans=3,
-            num_classes=1000,
-            global_pool='avg',
-            output_stride=32,
-            depths=(3, 3, 9, 3),
-            dims=(96, 192, 384, 768),
-            kernel_sizes=7,
-            ls_init_value=1e-6,
-            stem_type='patch',
-            patch_size=4,
-            head_init_scale=1.,
-            head_norm_first=False,
-            conv_mlp=False,
-            conv_bias=True,
-            use_grn=False,
-            act_layer='gelu',
-            norm_layer=None,
-            norm_eps=None,
-            drop_rate=0.,
-            drop_path_rate=0.,
+            in_chans: int = 3,
+            num_classes: int = 1000,
+            global_pool: str = 'avg',
+            output_stride: int = 32,
+            depths: Tuple[int, ...] = (3, 3, 9, 3),
+            dims: Tuple[int, ...] = (96, 192, 384, 768),
+            kernel_sizes: Union[int, Tuple[int, ...]] = 7,
+            ls_init_value: Optional[float] = 1e-6,
+            stem_type: str = 'patch',
+            patch_size: int = 4,
+            head_init_scale: float = 1.,
+            head_norm_first: bool = False,
+            head_hidden_size: Optional[int] = None,
+            conv_mlp: bool = False,
+            conv_bias: bool = True,
+            use_grn: bool = False,
+            act_layer: Union[str, Callable] = 'gelu',
+            norm_layer: Optional[Union[str, Callable]] = None,
+            norm_eps: Optional[float] = None,
+            drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
     ):
         """
         Args:
-            in_chans (int): Number of input image channels (default: 3)
-            num_classes (int): Number of classes for classification head (default: 1000)
-            global_pool (str): Global pooling type (default: 'avg')
-            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
-            depths (tuple(int)): Number of blocks at each stage. (default: [3, 3, 9, 3])
-            dims (tuple(int)): Feature dimension at each stage. (default: [96, 192, 384, 768])
-            kernel_sizes (Union[int, List[int]]: Depthwise convolution kernel-sizes for each stage (default: 7)
-            ls_init_value (float): Init value for Layer Scale (default: 1e-6)
-            stem_type (str): Type of stem (default: 'patch')
-            patch_size (int): Stem patch size for patch stem (default: 4)
-            head_init_scale (float): Init scaling value for classifier weights and biases (default: 1)
-            head_norm_first (bool): Apply normalization before global pool + head (default: False)
-            conv_mlp (bool): Use 1x1 conv in MLP, improves speed for small networks w/ chan last (default: False)
-            conv_bias (bool): Use bias layers w/ all convolutions (default: True)
-            use_grn (bool): Use Global Response Norm (ConvNeXt-V2) in MLP (default: False)
-            act_layer (Union[str, nn.Module]): Activation Layer
-            norm_layer (Union[str, nn.Module]): Normalization Layer
-            drop_rate (float): Head dropout rate (default: 0.)
-            drop_path_rate (float): Stochastic depth rate (default: 0.)
+            in_chans: Number of input image channels.
+            num_classes: Number of classes for classification head.
+            global_pool: Global pooling type.
+            output_stride: Output stride of network, one of (8, 16, 32).
+            depths: Number of blocks at each stage.
+            dims: Feature dimension at each stage.
+            kernel_sizes: Depthwise convolution kernel-sizes for each stage.
+            ls_init_value: Init value for Layer Scale, disabled if None.
+            stem_type: Type of stem.
+            patch_size: Stem patch size for patch stem.
+            head_init_scale: Init scaling value for classifier weights and biases.
+            head_norm_first: Apply normalization before global pool + head.
+            head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
+            conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
+            conv_bias: Use bias layers w/ all convolutions.
+            use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
+            act_layer: Activation layer type.
+            norm_layer: Normalization layer type.
+            drop_rate: Head pre-classifier dropout rate.
+            drop_path_rate: Stochastic depth drop rate.
         """
         super().__init__()
         assert output_stride in (8, 16, 32)
@@ -307,14 +311,26 @@ class ConvNeXt(nn.Module):
 
         # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
         # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
-        self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
-        self.head = nn.Sequential(OrderedDict([
-            ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
-            ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
-            ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
-            ('drop', nn.Dropout(self.drop_rate)),
-            ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
-
+        if head_norm_first:
+            assert not head_hidden_size
+            self.norm_pre = norm_layer(self.num_features)
+            self.head = ClassifierHead(
+                self.num_features,
+                num_classes,
+                pool_type=global_pool,
+                drop_rate=self.drop_rate,
+            )
+        else:
+            self.norm_pre = nn.Identity()
+            self.head = NormMlpClassifierHead(
+                self.num_features,
+                num_classes,
+                hidden_size=head_hidden_size,
+                pool_type=global_pool,
+                drop_rate=self.drop_rate,
+                norm_layer=norm_layer,
+                act_layer='gelu',
+            )
         named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
 
     @torch.jit.ignore
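Note (illustration, not part of the commit): after this change the two head orderings map onto the shared head modules; a rough sketch of what a default (head_norm_first=False) model carries.

import timm

model = timm.create_model('convnext_tiny', pretrained=False)
# default ordering: stages -> head.global_pool -> head.norm -> head.fc
print(type(model.head).__name__)      # NormMlpClassifierHead (hidden_size is None for convnext_tiny)
print(type(model.norm_pre).__name__)  # Identity, since head_norm_first=False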
@@ -338,10 +354,7 @@ class ConvNeXt(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes=0, global_pool=None):
-        if global_pool is not None:
-            self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-            self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
-        self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.head.reset(num_classes, global_pool=global_pool)
 
     def forward_features(self, x):
         x = self.stem(x)
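Note (illustration, not part of the commit): reset_classifier now simply forwards to the head's own reset(); a brief sketch.

import timm

model = timm.create_model('convnext_tiny', pretrained=False, num_classes=1000)
model.reset_classifier(10)                  # swaps head.fc for a fresh 10-class layer
model.reset_classifier(0, global_pool='')   # no classifier, no pooling: head returns unpooled features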
@@ -350,12 +363,7 @@ class ConvNeXt(nn.Module):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
-        # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
-        x = self.head.global_pool(x)
-        x = self.head.norm(x)
-        x = self.head.flatten(x)
-        x = self.head.drop(x)
-        return x if pre_logits else self.head.fc(x)
+        return self.head(x, pre_logits=pre_logits)
 
     def forward(self, x):
         x = self.forward_features(x)
@@ -389,6 +397,11 @@ def checkpoint_filter_fn(state_dict, model):
         if 'visual.head.proj.weight' in state_dict:
             out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
             out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
+        elif 'visual.head.mlp.fc1.weight' in state_dict:
+            out_dict['head.pre_logits.fc.weight'] = state_dict['visual.head.mlp.fc1.weight']
+            out_dict['head.pre_logits.fc.bias'] = state_dict['visual.head.mlp.fc1.bias']
+            out_dict['head.fc.weight'] = state_dict['visual.head.mlp.fc2.weight']
+            out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.mlp.fc2.weight'].shape[0])
         return out_dict
 
     import re
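Note (illustration, not part of the commit): the new elif branch maps an open_clip-style MLP head onto the timm head; the key names below follow the filter above, the tensor shapes are hypothetical.

import torch

clip_sd = {   # hypothetical open_clip ConvNeXt tower keys
    'visual.head.mlp.fc1.weight': torch.randn(1536, 1536),
    'visual.head.mlp.fc1.bias': torch.randn(1536),
    'visual.head.mlp.fc2.weight': torch.randn(768, 1536),   # CLIP projection, no bias
}
out = {
    'head.pre_logits.fc.weight': clip_sd['visual.head.mlp.fc1.weight'],
    'head.pre_logits.fc.bias': clip_sd['visual.head.mlp.fc1.bias'],
    'head.fc.weight': clip_sd['visual.head.mlp.fc2.weight'],
    'head.fc.bias': torch.zeros(clip_sd['visual.head.mlp.fc2.weight'].shape[0]),  # zero-filled; CLIP proj has no bias
}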
@@ -708,6 +721,22 @@ default_cfgs = generate_default_cfgs({
 
     'convnextv2_small.untrained': _cfg(),
 
+    # CLIP weights, fine-tuned on in1k or in12k + in1k
+    'convnext_base.clip_laion2b_augreg_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
+    'convnext_base.clip_laiona_augreg_ft_in1k_384': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
+    'convnext_large_mlp.clip_laion2b_augreg_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0
+    ),
+
+
     # CLIP based weights, original image tower weights and fine-tunes
     'convnext_base.clip_laion2b': _cfg(
         hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K',
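Note (illustration, not part of the commit): the new tags can be selected directly through create_model; a sketch (pretrained=True would pull the weights from the HF hub, so network access is assumed for that case).

import timm
from timm.data import resolve_data_config

model = timm.create_model('convnext_base.clip_laion2b_augreg_ft_in1k', pretrained=False)
cfg = resolve_data_config({}, model=model)   # picks up the OPENAI_CLIP mean/std and 256x256 input size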
@@ -734,6 +763,11 @@ default_cfgs = generate_default_cfgs({
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
+    'convnext_large_mlp.clip_laion2b_augreg': _cfg(
+        hf_hub_id='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768),
 })
 
 
@@ -846,6 +880,13 @@ def convnext_large(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def convnext_large_mlp(pretrained=False, **kwargs):
+    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536, **kwargs)
+    model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **model_args)
+    return model
+
+
 @register_model
 def convnext_xlarge(pretrained=False, **kwargs):
     model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
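Note (illustration, not part of the commit): the new convnext_large_mlp variant is the large config plus head_hidden_size=1536, so its head carries a pre-logits MLP; a sketch.

import timm

m = timm.create_model('convnext_large_mlp', pretrained=False, num_classes=0)
print(m.head.pre_logits.fc)   # Linear(in_features=1536, out_features=1536, bias=True)
print(m.head.num_features)    # 1536; head.fc is Identity since num_classes=0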
timm/models/davit.py

@@ -23,6 +23,7 @@
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp, LayerNorm2d, get_norm_layer
+from timm.layers import NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._manipulate import checkpoint_seq
@@ -519,14 +520,23 @@ class DaViT(nn.Module):
         # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
         # otherwise pool -> norm -> fc, the default DaViT order, similar to ConvNeXt
         # FIXME generalize this structure to ClassifierHead
-        self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
-        self.head = nn.Sequential(OrderedDict([
-            ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
-            ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
-            ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
-            ('drop', nn.Dropout(self.drop_rate)),
-            ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
-
+        if head_norm_first:
+            self.norm_pre = norm_layer(self.num_features)
+            self.head = ClassifierHead(
+                self.num_features,
+                num_classes,
+                pool_type=global_pool,
+                drop_rate=self.drop_rate,
+            )
+        else:
+            self.norm_pre = nn.Identity()
+            self.head = NormMlpClassifierHead(
+                self.num_features,
+                num_classes,
+                pool_type=global_pool,
+                drop_rate=self.drop_rate,
+                norm_layer=norm_layer,
+            )
         self.apply(self._init_weights)
 
     def _init_weights(self, m):
@@ -546,10 +556,7 @@ class DaViT(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes, global_pool=None):
-        if global_pool is not None:
-            self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-            self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
-        self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.head.reset(num_classes, global_pool=global_pool)
 
     def forward_features(self, x):
         x = self.stem(x)
timm/models/maxxvit.py

@@ -36,7 +36,7 @@ Hacked together by / Copyright 2022, Ross Wightman
 
 import math
 from collections import OrderedDict
-from dataclasses import dataclass, replace
+from dataclasses import dataclass, replace, field
 from functools import partial
 from typing import Callable, Optional, Union, Tuple, List
 
@@ -44,7 +44,7 @@ import torch
 from torch import nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import Mlp, ConvMlp, DropPath, ClassifierHead, LayerNorm, SelectAdaptivePool2d
+from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead
 from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d
 from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert
 from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf
@@ -133,8 +133,8 @@ class MaxxVitCfg:
     block_type: Tuple[Union[str, Tuple[str, ...]], ...] = ('C', 'C', 'T', 'T')
     stem_width: Union[int, Tuple[int, int]] = 64
     stem_bias: bool = False
-    conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg()
-    transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg()
+    conv_cfg: MaxxVitConvCfg = field(default_factory=MaxxVitConvCfg)
+    transformer_cfg: MaxxVitTransformerCfg = field(default_factory=MaxxVitTransformerCfg)
     head_hidden_size: int = None
     weight_init: str = 'vit_eff'
 
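Note (illustration, not part of the commit): a dataclass-instance default like MaxxVitConvCfg() is shared across instances and, from Python 3.11, is rejected at class-definition time ("mutable default ... use default_factory"), hence the switch to field(default_factory=...). A minimal sketch with hypothetical classes:

from dataclasses import dataclass, field

@dataclass
class Inner:
    value: int = 0

@dataclass
class Outer:
    # inner: Inner = Inner()  -> ValueError on Python 3.11+ (unhashable mutable default)
    inner: Inner = field(default_factory=Inner)   # each Outer() gets its own Inner()

a, b = Outer(), Outer()
assert a.inner is not b.inner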
@@ -1072,69 +1072,6 @@ def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]):
     return cfg
 
 
-class NormMlpHead(nn.Module):
-
-    def __init__(
-            self,
-            in_features,
-            num_classes,
-            hidden_size=None,
-            pool_type='avg',
-            drop_rate=0.,
-            norm_layer='layernorm2d',
-            act_layer='tanh',
-    ):
-        super().__init__()
-        self.drop_rate = drop_rate
-        self.in_features = in_features
-        self.hidden_size = hidden_size
-        self.num_features = in_features
-        self.use_conv = not pool_type
-        norm_layer = get_norm_layer(norm_layer)
-        act_layer = get_act_layer(act_layer)
-        linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
-
-        self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
-        self.norm = norm_layer(in_features)
-        self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
-        if hidden_size:
-            self.pre_logits = nn.Sequential(OrderedDict([
-                ('fc', linear_layer(in_features, hidden_size)),
-                ('act', act_layer()),
-            ]))
-            self.num_features = hidden_size
-        else:
-            self.pre_logits = nn.Identity()
-        self.drop = nn.Dropout(self.drop_rate)
-        self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
-
-    def reset(self, num_classes, global_pool=None):
-        if global_pool is not None:
-            self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-            self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
-            self.use_conv = self.global_pool.is_identity()
-        linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
-        if self.hidden_size:
-            if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or
-                    (isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)):
-                with torch.no_grad():
-                    new_fc = linear_layer(self.in_features, self.hidden_size)
-                    new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape))
-                    new_fc.bias.copy_(self.pre_logits.fc.bias)
-                    self.pre_logits.fc = new_fc
-        self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
-
-    def forward(self, x, pre_logits: bool = False):
-        x = self.global_pool(x)
-        x = self.norm(x)
-        x = self.flatten(x)
-        x = self.pre_logits(x)
-        if pre_logits:
-            return x
-        x = self.fc(x)
-        return x
-
-
 def _overlay_kwargs(cfg: MaxxVitCfg, **kwargs):
     transformer_kwargs = {}
     conv_kwargs = {}
@@ -1225,7 +1162,7 @@ class MaxxVit(nn.Module):
         self.head_hidden_size = cfg.head_hidden_size
         if self.head_hidden_size:
             self.norm = nn.Identity()
-            self.head = NormMlpHead(
+            self.head = NormMlpClassifierHead(
                 self.num_features,
                 num_classes,
                 hidden_size=self.head_hidden_size,
@@ -2342,4 +2279,4 @@ def maxvit_xlarge_tf_384(pretrained=False, **kwargs):
 
 @register_model
 def maxvit_xlarge_tf_512(pretrained=False, **kwargs):
-    return _create_maxxvit('maxvit_xlarge_tf_512', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs)
\ No newline at end of file
+    return _create_maxxvit('maxvit_xlarge_tf_512', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs)