Merge pull request #2196 from huggingface/mega_merge

Mega merge
This commit is contained in:
Ross Wightman 2024-06-07 13:12:36 -07:00 committed by GitHub
commit 52659842cc
19 changed files with 2711 additions and 314 deletions

View File

@ -52,7 +52,7 @@ FEAT_INTER_FILTERS = [
'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera',
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit',
]
# transformer / hybrid models don't support full set of spatial / feature APIs and/or have non-spatial output.
@ -60,7 +60,7 @@ NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*'
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*'
]
NUM_NON_STD = len(NON_STD_FILTERS)

View File

@ -1,9 +1,10 @@
from .activations import *
from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
from .attention_pool import AttentionPoolLatent
from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
from .blur_pool import BlurPool2d
from .blur_pool import BlurPool2d, create_aa
from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \

337
timm/layers/attention2d.py Normal file
View File

@ -0,0 +1,337 @@
from typing import List, Optional, Union
import torch
from torch import nn as nn
from torch.nn import functional as F
from .config import use_fused_attn
from .create_conv2d import create_conv2d
from .helpers import to_2tuple
from .pool2d_same import create_pool2d
class MultiQueryAttentionV2(nn.Module):
"""Multi Query Attention.
Fast Transformer Decoding: One Write-Head is All You Need
https://arxiv.org/pdf/1911.02150.pdf
This is an accelerator-optimized version - removing multiple unnecessary
tensor transposes by re-arranging indices according to the following rules: 1)
contracted indices are at the end, 2) other indices have the same order in the
input and output tensors.
Compared to V1, this gives a 3x speed up.
"""
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
num_heads: int = 8,
key_dim: int = 64,
value_dim: int = 64,
attn_drop: float = 0.,
proj_drop: float = 0.,
):
"""Initializer."""
super().__init__()
dim_out = dim_out or dim
self.num_heads = num_heads
self.key_dim = key_dim
self.value_dim = value_dim
self.scale = key_dim ** -0.5
self.query_proj = nn.Parameter(torch.randn([self.num_heads, self.key_dim, dim]))
self.key_proj = nn.Parameter(torch.randn([dim, self.key_dim]))
self.value_proj = nn.Parameter(torch.randn([dim, self.value_dim]))
self.attn_drop = nn.Dropout(attn_drop)
self.out_proj = nn.Parameter(torch.randn([dim_out, self.num_heads, self.value_dim]))
self.proj_drop = nn.Dropout(proj_drop)
def _reshape_input(self, t):
"""Reshapes a tensor to three dimensions, keeping the first and last."""
s = t.shape
# Propagate the shape statically where possible.
#num = t.shape[1:-1].numel()
#return t.reshape(s[0], num, s[-1])
return t.reshape(s[0], s[1], -1).transpose(1, 2)
def forward(self, x, m: Optional[torch.Tensor] = None):
"""Run layer computation."""
s = x.shape
m = m if m is not None else x
reshaped_x = self._reshape_input(x)
reshaped_m = self._reshape_input(m)
q = torch.einsum('bnd,hkd->bnhk', reshaped_x, self.query_proj)
k = torch.einsum('bmd,dk->bmk', reshaped_m, self.key_proj)
attn = torch.einsum('bnhk,bmk->bnhm', q, k)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
v = torch.einsum('bmd,dv->bmv', reshaped_m, self.value_proj)
o = torch.einsum('bnhm,bmv->bnhv', attn, v)
result = torch.einsum('bnhv,dhv->bnd', o, self.out_proj)
result = self.proj_drop(result)
return result.reshape(s)
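For orientation, a minimal usage sketch of the module above (sizes are illustrative; dim_out is left at its default so the result reshapes back to the input shape, and the import relies on the new timm.layers export added earlier in this diff):
import torch
from timm.layers import MultiQueryAttentionV2

mqa = MultiQueryAttentionV2(dim=64, num_heads=8, key_dim=16, value_dim=16)
x = torch.randn(2, 64, 14, 14)  # NCHW input, flattened internally to [B, N, C]
y = mqa(x)  # a single key/value projection is shared across all 8 query heads
assert y.shape == x.shape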
class MultiQueryAttention2d(nn.Module):
"""Multi Query Attention with spatial downsampling.
Two parameters are introduced for the spatial downsampling:
1. kv_stride: downsampling factor on Key and Values only.
2. query_strides: horizontal & vertical strides on Query only.
This is an optimized version.
1. Projections in Attention are explicitly written out as 1x1 Conv2D.
2. Additional reshapes are introduced to bring up to a 3x speed up.
"""
fused_attn: torch.jit.Final[bool]
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
num_heads: int = 8,
key_dim: Optional[int] = None,
value_dim: Optional[int] = None,
query_strides: int = 1,
kv_stride: int = 1,
dw_kernel_size: int = 3,
dilation: int = 1,
padding: Union[str, int, List[int]] = '',
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = nn.BatchNorm2d,
use_bias: bool = False,
):
"""Initializer.
Args:
num_heads: Number of attention heads.
key_dim: Size of the attention key dimension.
value_dim: Size of the attention value dimension.
query_strides: Vertical & horizontal stride size for query only.
kv_stride: Key and value stride size.
dw_kernel_size: Spatial dimension of the depthwise kernel.
"""
super().__init__()
dim_out = dim_out or dim
self.num_heads = num_heads
self.key_dim = key_dim or dim // num_heads
self.value_dim = value_dim or dim // num_heads
self.query_strides = to_2tuple(query_strides)
self.kv_stride = kv_stride
self.has_query_strides = any([s > 1 for s in self.query_strides])
self.scale = self.key_dim ** -0.5
self.fused_attn = use_fused_attn()
self.drop = attn_drop
self.query = nn.Sequential()
if self.has_query_strides:
# FIXME dilation
self.query.add_module('down_pool', create_pool2d(
'avg',
kernel_size=self.query_strides,
padding=padding,
))
self.query.add_module('norm', norm_layer(dim))
self.query.add_module('proj', create_conv2d(
dim,
self.num_heads * self.key_dim,
kernel_size=1,
bias=use_bias,
))
self.key = nn.Sequential()
if kv_stride > 1:
self.key.add_module('down_conv', create_conv2d(
dim,
dim,
kernel_size=dw_kernel_size,
stride=kv_stride,
dilation=dilation,
padding=padding,
depthwise=True,
))
self.key.add_module('norm', norm_layer(dim))
self.key.add_module('proj', create_conv2d(
dim,
self.key_dim,
kernel_size=1,
padding=padding,
bias=use_bias,
))
self.value = nn.Sequential()
if kv_stride > 1:
self.value.add_module('down_conv', create_conv2d(
dim,
dim,
kernel_size=dw_kernel_size,
stride=kv_stride,
dilation=dilation,
padding=padding,
depthwise=True,
))
self.value.add_module('norm', norm_layer(dim))
self.value.add_module('proj', create_conv2d(
dim,
self.value_dim,
kernel_size=1,
bias=use_bias,
))
self.attn_drop = nn.Dropout(attn_drop)
self.output = nn.Sequential()
if self.has_query_strides:
self.output.add_module('upsample', nn.Upsample(self.query_strides, mode='bilinear', align_corners=False))
self.output.add_module('proj', create_conv2d(
self.value_dim * self.num_heads,
dim_out,
kernel_size=1,
bias=use_bias,
))
self.output.add_module('drop', nn.Dropout(proj_drop))
self.einsum = False
def _reshape_input(self, t: torch.Tensor):
"""Reshapes a tensor to three dimensions, keeping the batch and channels."""
s = t.shape
t = t.reshape(s[0], s[1], -1).transpose(1, 2)
if self.einsum:
return t
else:
return t.unsqueeze(1).contiguous()
def _reshape_projected_query(self, t: torch.Tensor, num_heads: int, key_dim: int):
"""Reshapes projected query: [b, n, n, h x k] -> [b, n x n, h, k]."""
s = t.shape
t = t.reshape(s[0], num_heads, key_dim, -1)
if self.einsum:
return t.permute(0, 3, 1, 2).contiguous()
else:
return t.transpose(-1, -2).contiguous()
def _reshape_output(self, t: torch.Tensor, num_heads: int, h_px: int, w_px: int):
"""Reshape output:[b, n x n x h, k] -> [b, n, n, hk]."""
s = t.shape
feat_dim = s[-1] * num_heads
if not self.einsum:
t = t.transpose(1, 2)
return t.reshape(s[0], h_px, w_px, feat_dim).permute(0, 3, 1, 2).contiguous()
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
"""Run layer computation."""
B, C, H, W = s = x.shape
q = self.query(x)
# desired q shape: [b, h, k, n x n] - [b, l, h, k]
q = self._reshape_projected_query(q, self.num_heads, self.key_dim)
k = self.key(x)
# output shape of k: [b, k, p], p = m x m
k = self._reshape_input(k)
v = self.value(x)
# output shape of v: [ b, p, k], p = m x m
v = self._reshape_input(v)
# desired q shape: [b, n x n, h, k]
# desired k shape: [b, m x m, k]
# desired logits shape: [b, n x n, h, m x m]
if self.einsum:
attn = torch.einsum('blhk,bpk->blhp', q, k) * self.scale
if attn_mask is not None:
# NOTE: assumes mask is float and in correct shape
attn = attn + attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
o = torch.einsum('blhp,bpk->blhk', attn, v)
else:
if self.fused_attn:
o = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
dropout_p=self.attn_drop.p if self.training else 0.
)
else:
q = q * self.scale
attn = q @ k.transpose(-1, -2)
if attn_mask is not None:
# NOTE: assumes mask is float and in correct shape
attn = attn + attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
o = attn @ v
# reshape o into [b, hk, n, n]
o = self._reshape_output(o, self.num_heads, H // self.query_strides[0], W // self.query_strides[1])
x = self.output(o)
return x
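A small shape-check sketch for the block above (illustrative values): kv_stride only shrinks the internal key/value maps, while the output resolution follows query_strides, which defaults to 1.
import torch
from timm.layers import MultiQueryAttention2d

attn = MultiQueryAttention2d(dim=64, num_heads=4, kv_stride=2)
x = torch.randn(2, 64, 16, 16)
y = attn(x)  # (2, 64, 16, 16); keys/values were computed at 8x8 internally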
class Attention2d(nn.Module):
""" Multi-head attention for 2D NCHW tensors. """
fused_attn: torch.jit.Final[bool]
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
num_heads: int = 32,
bias: bool = True,
expand_first: bool = False,
head_first: bool = False,
attn_drop: float = 0.,
proj_drop: float = 0.
):
super().__init__()
dim_out = dim_out or dim
dim_attn = dim_out if expand_first else dim
self.num_heads = num_heads
self.dim_head = dim_attn // num_heads
self.head_first = head_first
self.scale = num_heads ** -0.5
self.fused_attn = use_fused_attn()
self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
B, C, H, W = x.shape
if self.head_first:
q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2)
else:
q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1)
if self.fused_attn:
x = torch.nn.functional.scaled_dot_product_attention(
q.transpose(-1, -2).contiguous(),
k.transpose(-1, -2).contiguous(),
v.transpose(-1, -2).contiguous(),
attn_mask=attn_mask,
dropout_p=self.attn_drop.p if self.training else 0.,
).transpose(-1, -2).reshape(B, -1, H, W)
else:
q = q * self.scale
attn = q.transpose(-2, -1) @ k
if attn_mask is not None:
# NOTE: assumes mask is float and in correct shape
attn = attn + attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
x = self.proj(x)
x = self.proj_drop(x)
return x
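And the plain multi-head counterpart, again a rough sketch with assumed sizes; qkv and the output projection are 1x1 convs, so tensors stay NCHW end to end.
import torch
from timm.layers import Attention2d

attn = Attention2d(dim=128, num_heads=4)
y = attn(torch.randn(2, 128, 8, 8))  # (2, 128, 8, 8)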

View File

@ -5,12 +5,16 @@ BlurPool layer inspired by
Hacked together by Chris Ha and Ross Wightman
"""
from functools import partial
from typing import Optional, Type
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .padding import get_padding
from .typing import LayerType
class BlurPool2d(nn.Module):
@ -26,17 +30,62 @@ class BlurPool2d(nn.Module):
Returns:
torch.Tensor: the transformed tensor.
"""
def __init__(self, channels, filt_size=3, stride=2) -> None:
def __init__(
self,
channels: Optional[int] = None,
filt_size: int = 3,
stride: int = 2,
pad_mode: str = 'reflect',
) -> None:
super(BlurPool2d, self).__init__()
assert filt_size > 1
self.channels = channels
self.filt_size = filt_size
self.stride = stride
self.pad_mode = pad_mode
self.padding = [get_padding(filt_size, stride, dilation=1)] * 4
coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32))
blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1)
blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :]
if channels is not None:
blur_filter = blur_filter.repeat(self.channels, 1, 1, 1)
self.register_buffer('filt', blur_filter, persistent=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.pad(x, self.padding, 'reflect')
return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels)
x = F.pad(x, self.padding, mode=self.pad_mode)
if self.channels is None:
channels = x.shape[1]
weight = self.filt.expand(channels, 1, self.filt_size, self.filt_size)
else:
channels = self.channels
weight = self.filt
return F.conv2d(x, weight, stride=self.stride, groups=channels)
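A brief sketch of the new channels-less mode (assumed sizes): with channels=None the (1, 1, k, k) blur kernel buffer is expanded to the runtime channel count, so a single module can be reused across layers of different width.
import torch
from timm.layers import BlurPool2d

pool = BlurPool2d(filt_size=3, stride=2, pad_mode='constant')
y = pool(torch.randn(2, 32, 16, 16))  # (2, 32, 8, 8)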
def create_aa(
aa_layer: LayerType,
channels: Optional[int] = None,
stride: int = 2,
enable: bool = True,
noop: Optional[Type[nn.Module]] = nn.Identity
) -> nn.Module:
""" Anti-aliasing """
if not aa_layer or not enable:
return noop() if noop is not None else None
if isinstance(aa_layer, str):
aa_layer = aa_layer.lower().replace('_', '').replace('-', '')
if aa_layer == 'avg' or aa_layer == 'avgpool':
aa_layer = nn.AvgPool2d
elif aa_layer == 'blur' or aa_layer == 'blurpool':
aa_layer = BlurPool2d
elif aa_layer == 'blurpc':
aa_layer = partial(BlurPool2d, pad_mode='constant')
else:
assert False, f"Unknown anti-aliasing layer ({aa_layer})."
try:
return aa_layer(channels=channels, stride=stride)
except TypeError as e:
return aa_layer(stride)
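Given the string resolution above, a rough sketch of what the helper hands back (arguments are illustrative):
from timm.layers import create_aa

aa = create_aa('blur', channels=64, stride=2)  # BlurPool2d(channels=64, stride=2)
aa_pc = create_aa('blurpc', channels=64, stride=2)  # BlurPool2d w/ constant (zero) padding
noop = create_aa(None, channels=64, stride=2)  # nn.Identity() when no AA layer is requested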

View File

@ -2,9 +2,12 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import functools
from typing import Any, Dict, Optional, Type
from torch import nn as nn
from .typing import LayerType, PadType
from .blur_pool import create_aa
from .create_conv2d import create_conv2d
from .create_norm_act import get_norm_act_layer
@ -12,41 +15,58 @@ from .create_norm_act import get_norm_act_layer
class ConvNormAct(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding='',
dilation=1,
groups=1,
bias=False,
apply_act=True,
norm_layer=nn.BatchNorm2d,
norm_kwargs=None,
act_layer=nn.ReLU,
act_kwargs=None,
drop_layer=None,
in_channels: int,
out_channels: int,
kernel_size: int = 1,
stride: int = 1,
padding: PadType = '',
dilation: int = 1,
groups: int = 1,
bias: bool = False,
apply_norm: bool = True,
apply_act: bool = True,
norm_layer: LayerType = nn.BatchNorm2d,
act_layer: LayerType = nn.ReLU,
drop_layer: Optional[Type[nn.Module]] = None,
conv_kwargs: Optional[Dict[str, Any]] = None,
norm_kwargs: Optional[Dict[str, Any]] = None,
act_kwargs: Optional[Dict[str, Any]] = None,
):
super(ConvNormAct, self).__init__()
conv_kwargs = conv_kwargs or {}
norm_kwargs = norm_kwargs or {}
act_kwargs = act_kwargs or {}
self.conv = create_conv2d(
in_channels, out_channels, kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=groups, bias=bias)
# NOTE for backwards compatibility with models that use separate norm and act layer definitions
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
# NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
if drop_layer:
norm_kwargs['drop_layer'] = drop_layer
self.bn = norm_act_layer(
in_channels,
out_channels,
apply_act=apply_act,
act_kwargs=act_kwargs,
**norm_kwargs,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
**conv_kwargs,
)
if apply_norm:
# NOTE for backwards compatibility with models that use separate norm and act layer definitions
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
# NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
if drop_layer:
norm_kwargs['drop_layer'] = drop_layer
self.bn = norm_act_layer(
out_channels,
apply_act=apply_act,
act_kwargs=act_kwargs,
**norm_kwargs,
)
else:
self.bn = nn.Sequential()
if drop_layer:
norm_kwargs['drop_layer'] = drop_layer
self.bn.add_module('drop', drop_layer())
@property
def in_channels(self):
return self.conv.in_channels
@ -64,54 +84,61 @@ class ConvNormAct(nn.Module):
ConvBnAct = ConvNormAct
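A short sketch of the reworked constructor (example values assumed): apply_norm=False leaves self.bn as an empty nn.Sequential, and conv_kwargs are forwarded to create_conv2d.
import torch
from timm.layers import ConvNormAct

block = ConvNormAct(16, 32, kernel_size=3, stride=2, apply_norm=False)
y = block(torch.randn(1, 16, 32, 32))  # (1, 32, 16, 16); conv only, no norm or act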
def create_aa(aa_layer, channels, stride=2, enable=True):
if not aa_layer or not enable:
return nn.Identity()
if isinstance(aa_layer, functools.partial):
if issubclass(aa_layer.func, nn.AvgPool2d):
return aa_layer()
else:
return aa_layer(channels)
elif issubclass(aa_layer, nn.AvgPool2d):
return aa_layer(stride)
else:
return aa_layer(channels=channels, stride=stride)
class ConvNormActAa(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding='',
dilation=1,
groups=1,
bias=False,
apply_act=True,
norm_layer=nn.BatchNorm2d,
norm_kwargs=None,
act_layer=nn.ReLU,
act_kwargs=None,
aa_layer=None,
drop_layer=None,
in_channels: int,
out_channels: int,
kernel_size: int = 1,
stride: int = 1,
padding: PadType = '',
dilation: int = 1,
groups: int = 1,
bias: bool = False,
apply_norm: bool = True,
apply_act: bool = True,
norm_layer: LayerType = nn.BatchNorm2d,
act_layer: LayerType = nn.ReLU,
aa_layer: Optional[LayerType] = None,
drop_layer: Optional[Type[nn.Module]] = None,
conv_kwargs: Optional[Dict[str, Any]] = None,
norm_kwargs: Optional[Dict[str, Any]] = None,
act_kwargs: Optional[Dict[str, Any]] = None,
):
super(ConvNormActAa, self).__init__()
use_aa = aa_layer is not None and stride == 2
conv_kwargs = conv_kwargs or {}
norm_kwargs = norm_kwargs or {}
act_kwargs = act_kwargs or {}
self.conv = create_conv2d(
in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
padding=padding, dilation=dilation, groups=groups, bias=bias)
in_channels, out_channels, kernel_size,
stride=1 if use_aa else stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
**conv_kwargs,
)
if apply_norm:
# NOTE for backwards compatibility with models that use separate norm and act layer definitions
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
# NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
if drop_layer:
norm_kwargs['drop_layer'] = drop_layer
self.bn = norm_act_layer(
out_channels,
apply_act=apply_act,
act_kwargs=act_kwargs,
**norm_kwargs,
)
else:
self.bn = nn.Sequential()
if drop_layer:
norm_kwargs['drop_layer'] = drop_layer
self.bn.add_module('drop', drop_layer())
# NOTE for backwards compatibility with models that use separate norm and act layer definitions
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
# NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
if drop_layer:
norm_kwargs['drop_layer'] = drop_layer
self.bn = norm_act_layer(out_channels, apply_act=apply_act, act_kwargs=act_kwargs, **norm_kwargs)
self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)
@property

View File

@ -19,21 +19,18 @@ from torch import nn as nn
from torch.nn import functional as F
from torchvision.ops.misc import FrozenBatchNorm2d
from .create_act import get_act_layer
from .create_act import create_act_layer
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
from .trace_utils import _assert
def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True):
act_layer = get_act_layer(act_layer) # string -> nn.Module
act_kwargs = act_kwargs or {}
if act_layer is not None and apply_act:
if inplace:
act_kwargs['inplace'] = inplace
act = act_layer(**act_kwargs)
else:
act = nn.Identity()
return act
act_kwargs.setdefault('inplace', inplace)
act = None
if apply_act:
act = create_act_layer(act_layer, **act_kwargs)
return nn.Identity() if act is None else act
class BatchNormAct2d(nn.BatchNorm2d):
@ -421,7 +418,6 @@ class LayerNormAct(nn.LayerNorm):
):
super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
self._fast_norm = is_fast_norm()

View File

@ -71,6 +71,7 @@ from .vision_transformer import *
from .vision_transformer_hybrid import *
from .vision_transformer_relpos import *
from .vision_transformer_sam import *
from .vitamin import *
from .volo import *
from .vovnet import *
from .xception import *

View File

@ -2,18 +2,24 @@
Hacked together by / Copyright 2019, Ross Wightman
"""
from typing import Callable, Dict, Optional, Type
import torch
import torch.nn as nn
from torch.nn import functional as F
from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer
from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, create_aa, to_2tuple, LayerType,\
ConvNormAct, ConvNormActAa, get_norm_act_layer, MultiQueryAttention2d, Attention2d
__all__ = [
'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual']
'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual',
'UniversalInvertedResidual', 'MobileAttention'
]
ModuleType = Type[nn.Module]
def num_groups(group_size, channels):
def num_groups(group_size: Optional[int], channels: int):
if not group_size: # 0 or None
return 1 # normal conv with 1 group
else:
@ -35,8 +41,15 @@ class SqueezeExcite(nn.Module):
"""
def __init__(
self, in_chs, rd_ratio=0.25, rd_channels=None, act_layer=nn.ReLU,
gate_layer=nn.Sigmoid, force_act_layer=None, rd_round_fn=None):
self,
in_chs: int,
rd_ratio: float = 0.25,
rd_channels: Optional[int] = None,
act_layer: LayerType = nn.ReLU,
gate_layer: LayerType = nn.Sigmoid,
force_act_layer: Optional[LayerType] = None,
rd_round_fn: Optional[Callable] = None,
):
super(SqueezeExcite, self).__init__()
if rd_channels is None:
rd_round_fn = rd_round_fn or round
@ -59,16 +72,32 @@ class ConvBnAct(nn.Module):
""" Conv + Norm Layer + Activation w/ optional skip connection
"""
def __init__(
self, in_chs, out_chs, kernel_size, stride=1, dilation=1, group_size=0, pad_type='',
skip=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_path_rate=0.):
self,
in_chs: int,
out_chs: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
group_size: int = 0,
pad_type: str = '',
skip: bool = False,
act_layer: LayerType = nn.ReLU,
norm_layer: LayerType = nn.BatchNorm2d,
aa_layer: Optional[LayerType] = None,
drop_path_rate: float = 0.,
):
super(ConvBnAct, self).__init__()
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
groups = num_groups(group_size, in_chs)
self.has_skip = skip and stride == 1 and in_chs == out_chs
use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation
self.conv = create_conv2d(
in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, groups=groups, padding=pad_type)
in_chs, out_chs, kernel_size,
stride=1 if use_aa else stride,
dilation=dilation, groups=groups, padding=pad_type)
self.bn1 = norm_act_layer(out_chs, inplace=True)
self.aa = create_aa(aa_layer, channels=out_chs, stride=stride, enable=use_aa)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
def feature_info(self, location):
@ -81,29 +110,64 @@ class ConvBnAct(nn.Module):
shortcut = x
x = self.conv(x)
x = self.bn1(x)
x = self.aa(x)
if self.has_skip:
x = self.drop_path(x) + shortcut
return x
class DepthwiseSeparableConv(nn.Module):
""" DepthwiseSeparable block
""" Depthwise-separable block
Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion
(factor of 1.0). This is an alternative to having an IR with an optional first pw conv.
"""
def __init__(
self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='',
noskip=False, pw_kernel_size=1, pw_act=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
se_layer=None, drop_path_rate=0.):
self,
in_chs: int,
out_chs: int,
dw_kernel_size: int = 3,
stride: int = 1,
dilation: int = 1,
group_size: int = 1,
pad_type: str = '',
noskip: bool = False,
pw_kernel_size: int = 1,
pw_act: bool = False,
s2d: int = 0,
act_layer: LayerType = nn.ReLU,
norm_layer: LayerType = nn.BatchNorm2d,
aa_layer: Optional[LayerType] = None,
se_layer: Optional[ModuleType] = None,
drop_path_rate: float = 0.,
):
super(DepthwiseSeparableConv, self).__init__()
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
groups = num_groups(group_size, in_chs)
self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip
self.has_pw_act = pw_act # activation after point-wise conv
use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation
# Space to depth
if s2d == 1:
sd_chs = int(in_chs * 4)
self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding='same')
self.bn_s2d = norm_act_layer(sd_chs, sd_chs)
dw_kernel_size = (dw_kernel_size + 1) // 2
dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type
in_chs = sd_chs
use_aa = False # disable AA
else:
self.conv_s2d = None
self.bn_s2d = None
dw_pad_type = pad_type
groups = num_groups(group_size, in_chs)
self.conv_dw = create_conv2d(
in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, groups=groups)
in_chs, in_chs, dw_kernel_size,
stride=1 if use_aa else stride,
dilation=dilation, padding=dw_pad_type, groups=groups)
self.bn1 = norm_act_layer(in_chs, inplace=True)
self.aa = create_aa(aa_layer, channels=out_chs, stride=stride, enable=use_aa)
# Squeeze-and-excitation
self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity()
@ -120,8 +184,12 @@ class DepthwiseSeparableConv(nn.Module):
def forward(self, x):
shortcut = x
if self.conv_s2d is not None:
x = self.conv_s2d(x)
x = self.bn_s2d(x)
x = self.conv_dw(x)
x = self.bn1(x)
x = self.aa(x)
x = self.se(x)
x = self.conv_pw(x)
x = self.bn2(x)
@ -141,15 +209,48 @@ class InvertedResidual(nn.Module):
"""
def __init__(
self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='',
noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
self,
in_chs: int,
out_chs: int,
dw_kernel_size: int = 3,
stride: int = 1,
dilation: int = 1,
group_size: int = 1,
pad_type: str = '',
noskip: bool = False,
exp_ratio: float = 1.0,
exp_kernel_size: int = 1,
pw_kernel_size: int = 1,
s2d: int = 0,
act_layer: LayerType = nn.ReLU,
norm_layer: LayerType = nn.BatchNorm2d,
aa_layer: Optional[LayerType] = None,
se_layer: Optional[ModuleType] = None,
conv_kwargs: Optional[Dict] = None,
drop_path_rate: float = 0.,
):
super(InvertedResidual, self).__init__()
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
conv_kwargs = conv_kwargs or {}
self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation
# Space to depth
if s2d == 1:
sd_chs = int(in_chs * 4)
self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding='same')
self.bn_s2d = norm_act_layer(sd_chs, sd_chs)
dw_kernel_size = (dw_kernel_size + 1) // 2
dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type
in_chs = sd_chs
use_aa = False # disable AA
else:
self.conv_s2d = None
self.bn_s2d = None
dw_pad_type = pad_type
mid_chs = make_divisible(in_chs * exp_ratio)
groups = num_groups(group_size, mid_chs)
self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
# Point-wise expansion
self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
@ -157,9 +258,11 @@ class InvertedResidual(nn.Module):
# Depth-wise convolution
self.conv_dw = create_conv2d(
mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
groups=groups, padding=pad_type, **conv_kwargs)
mid_chs, mid_chs, dw_kernel_size,
stride=1 if use_aa else stride,
dilation=dilation, groups=groups, padding=dw_pad_type, **conv_kwargs)
self.bn2 = norm_act_layer(mid_chs, inplace=True)
self.aa = create_aa(aa_layer, channels=mid_chs, stride=stride, enable=use_aa)
# Squeeze-and-excitation
self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
@ -177,10 +280,14 @@ class InvertedResidual(nn.Module):
def forward(self, x):
shortcut = x
if self.conv_s2d is not None:
x = self.conv_s2d(x)
x = self.bn_s2d(x)
x = self.conv_pw(x)
x = self.bn1(x)
x = self.conv_dw(x)
x = self.bn2(x)
x = self.aa(x)
x = self.se(x)
x = self.conv_pwl(x)
x = self.bn3(x)
@ -189,23 +296,317 @@ class InvertedResidual(nn.Module):
return x
class LayerScale2d(nn.Module):
def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
gamma = self.gamma.view(1, -1, 1, 1)
return x.mul_(gamma) if self.inplace else x * gamma
class UniversalInvertedResidual(nn.Module):
""" Universal Inverted Residual Block (aka Universal Inverted Bottleneck, UIB)
For MobileNetV4 - https://arxiv.org/abs/, referenced from
https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L778
"""
def __init__(
self,
in_chs: int,
out_chs: int,
dw_kernel_size_start: int = 0,
dw_kernel_size_mid: int = 3,
dw_kernel_size_end: int = 0,
stride: int = 1,
dilation: int = 1,
group_size: int = 1,
pad_type: str = '',
noskip: bool = False,
exp_ratio: float = 1.0,
act_layer: LayerType = nn.ReLU,
norm_layer: LayerType = nn.BatchNorm2d,
aa_layer: Optional[LayerType] = None,
se_layer: Optional[ModuleType] = None,
conv_kwargs: Optional[Dict] = None,
drop_path_rate: float = 0.,
layer_scale_init_value: Optional[float] = 1e-5,
):
super(UniversalInvertedResidual, self).__init__()
conv_kwargs = conv_kwargs or {}
self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
if stride > 1:
assert dw_kernel_size_start or dw_kernel_size_mid or dw_kernel_size_end
# FIXME dilation isn't right w/ extra ks > 1 convs
if dw_kernel_size_start:
dw_start_stride = stride if not dw_kernel_size_mid else 1
dw_start_groups = num_groups(group_size, in_chs)
self.dw_start = ConvNormActAa(
in_chs, in_chs, dw_kernel_size_start,
stride=dw_start_stride,
dilation=dilation, # FIXME
groups=dw_start_groups,
padding=pad_type,
apply_act=False,
act_layer=act_layer,
norm_layer=norm_layer,
aa_layer=aa_layer,
**conv_kwargs,
)
else:
self.dw_start = nn.Identity()
# Point-wise expansion
mid_chs = make_divisible(in_chs * exp_ratio)
self.pw_exp = ConvNormAct(
in_chs, mid_chs, 1,
padding=pad_type,
act_layer=act_layer,
norm_layer=norm_layer,
**conv_kwargs,
)
# Middle depth-wise convolution
if dw_kernel_size_mid:
groups = num_groups(group_size, mid_chs)
self.dw_mid = ConvNormActAa(
mid_chs, mid_chs, dw_kernel_size_mid,
stride=stride,
dilation=dilation, # FIXME
groups=groups,
padding=pad_type,
act_layer=act_layer,
norm_layer=norm_layer,
aa_layer=aa_layer,
**conv_kwargs,
)
else:
# keeping mid as identity so it can be hooked more easily for features
self.dw_mid = nn.Identity()
# Squeeze-and-excitation
self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
# Point-wise linear projection
self.pw_proj = ConvNormAct(
mid_chs, out_chs, 1,
padding=pad_type,
apply_act=False,
act_layer=act_layer,
norm_layer=norm_layer,
**conv_kwargs,
)
if dw_kernel_size_end:
dw_end_stride = stride if not dw_kernel_size_start and not dw_kernel_size_mid else 1
dw_end_groups = num_groups(group_size, out_chs)
if dw_end_stride > 1:
assert not aa_layer
self.dw_end = ConvNormAct(
out_chs, out_chs, dw_kernel_size_end,
stride=dw_end_stride,
dilation=dilation,
groups=dw_end_groups,
padding=pad_type,
apply_act=False,
act_layer=act_layer,
norm_layer=norm_layer,
**conv_kwargs,
)
else:
self.dw_end = nn.Identity()
if layer_scale_init_value is not None:
self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value)
else:
self.layer_scale = nn.Identity()
self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
def feature_info(self, location):
if location == 'expansion': # after SE, input to PWL
return dict(module='pw_proj.conv', hook_type='forward_pre', num_chs=self.pw_proj.conv.in_channels)
else: # location == 'bottleneck', block output
return dict(module='', num_chs=self.pw_proj.conv.out_channels)
def forward(self, x):
shortcut = x
x = self.dw_start(x)
x = self.pw_exp(x)
x = self.dw_mid(x)
x = self.se(x)
x = self.pw_proj(x)
x = self.dw_end(x)
x = self.layer_scale(x)
if self.has_skip:
x = self.drop_path(x) + shortcut
return x
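A rough instantiation sketch (illustrative sizes, assuming the block is imported from timm.models._efficientnet_blocks): with both a leading and a middle depthwise conv this is the 'ExtraDW' UIB configuration in MobileNetV4 terms, and setting a kernel size to 0 drops that depthwise stage.
import torch
from timm.models._efficientnet_blocks import UniversalInvertedResidual

uib = UniversalInvertedResidual(
    64, 96, dw_kernel_size_start=3, dw_kernel_size_mid=5, stride=2, exp_ratio=4.0)
y = uib(torch.randn(2, 64, 32, 32))  # (2, 96, 16, 16)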
class MobileAttention(nn.Module):
""" Mobile Attention Block
For MobileNetV4 - https://arxiv.org/abs/, referenced from
https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L1504
"""
def __init__(
self,
in_chs: int,
out_chs: int,
stride: int = 1,
dw_kernel_size: int = 3,
dilation: int = 1,
group_size: int = 1,
pad_type: str = '',
num_heads: int = 8,
key_dim: int = 64,
value_dim: int = 64,
use_multi_query: bool = False,
query_strides: int = (1, 1),
kv_stride: int = 1,
cpe_dw_kernel_size: int = 3,
noskip: bool = False,
act_layer: LayerType = nn.ReLU,
norm_layer: LayerType = nn.BatchNorm2d,
aa_layer: Optional[LayerType] = None,
drop_path_rate: float = 0.,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
layer_scale_init_value: Optional[float] = 1e-5,
use_bias: bool = False,
use_cpe: bool = False,
):
super(MobileAttention, self).__init__()
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip
self.query_strides = to_2tuple(query_strides)
self.kv_stride = kv_stride
self.has_query_stride = any([s > 1 for s in self.query_strides])
# This CPE is different from the one suggested in the original paper.
# https://arxiv.org/abs/2102.10882
# 1. Rather than adding one CPE before the attention blocks, we add a CPE
# into every attention block.
# 2. We replace the expensive Conv2D by a separable DW Conv.
if use_cpe:
self.conv_cpe_dw = create_conv2d(
in_chs, in_chs,
kernel_size=cpe_dw_kernel_size,
dilation=dilation,
depthwise=True,
bias=True,
)
else:
self.conv_cpe_dw = None
self.norm = norm_act_layer(in_chs, apply_act=False)
if num_heads is None:
assert in_chs % key_dim == 0
num_heads = in_chs // key_dim
if use_multi_query:
self.attn = MultiQueryAttention2d(
in_chs,
dim_out=out_chs,
num_heads=num_heads,
key_dim=key_dim,
value_dim=value_dim,
query_strides=query_strides,
kv_stride=kv_stride,
dilation=dilation,
padding=pad_type,
dw_kernel_size=dw_kernel_size,
attn_drop=attn_drop,
proj_drop=proj_drop,
#bias=use_bias, # why not here if used w/ mhsa?
)
else:
self.attn = Attention2d(
in_chs,
dim_out=out_chs,
num_heads=num_heads,
attn_drop=attn_drop,
proj_drop=proj_drop,
bias=use_bias,
)
if layer_scale_init_value is not None:
self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value)
else:
self.layer_scale = nn.Identity()
self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
def feature_info(self, location):
if location == 'expansion': # after SE, input to PW
return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
else: # location == 'bottleneck', block output
return dict(module='', num_chs=self.conv_pw.out_channels)
def forward(self, x):
if self.conv_cpe_dw is not None:
x_cpe = self.conv_cpe_dw(x)
x = x + x_cpe
shortcut = x
x = self.norm(x)
x = self.attn(x)
x = self.layer_scale(x)
if self.has_skip:
x = self.drop_path(x) + shortcut
return x
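A similar sketch for the attention block (illustrative sizes, same assumed import path): the multi-query variant shares one K/V head and can downsample K/V via kv_stride while keeping the output at the input resolution.
import torch
from timm.models._efficientnet_blocks import MobileAttention

mha = MobileAttention(
    96, 96, num_heads=4, key_dim=24, value_dim=24, use_multi_query=True, kv_stride=2)
y = mha(torch.randn(2, 96, 16, 16))  # (2, 96, 16, 16)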
class CondConvResidual(InvertedResidual):
""" Inverted residual block w/ CondConv routing"""
def __init__(
self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='',
noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):
self,
in_chs: int,
out_chs: int,
dw_kernel_size: int = 3,
stride: int = 1,
dilation: int = 1,
group_size: int = 1,
pad_type: str = '',
noskip: bool = False,
exp_ratio: float = 1.0,
exp_kernel_size: int = 1,
pw_kernel_size: int = 1,
act_layer: LayerType = nn.ReLU,
norm_layer: LayerType = nn.BatchNorm2d,
aa_layer: Optional[LayerType] = None,
se_layer: Optional[ModuleType] = None,
num_experts: int = 0,
drop_path_rate: float = 0.,
):
self.num_experts = num_experts
conv_kwargs = dict(num_experts=self.num_experts)
super(CondConvResidual, self).__init__(
in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, group_size=group_size,
pad_type=pad_type, act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
pw_kernel_size=pw_kernel_size, se_layer=se_layer, norm_layer=norm_layer, conv_kwargs=conv_kwargs,
drop_path_rate=drop_path_rate)
in_chs,
out_chs,
dw_kernel_size=dw_kernel_size,
stride=stride,
dilation=dilation,
group_size=group_size,
pad_type=pad_type,
noskip=noskip,
exp_ratio=exp_ratio,
exp_kernel_size=exp_kernel_size,
pw_kernel_size=pw_kernel_size,
act_layer=act_layer,
norm_layer=norm_layer,
aa_layer=aa_layer,
se_layer=se_layer,
conv_kwargs=conv_kwargs,
drop_path_rate=drop_path_rate,
)
self.routing_fn = nn.Linear(in_chs, self.num_experts)
def forward(self, x):
@ -237,9 +638,24 @@ class EdgeResidual(nn.Module):
"""
def __init__(
self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, group_size=0, pad_type='',
force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
self,
in_chs: int,
out_chs: int,
exp_kernel_size: int = 3,
stride: int = 1,
dilation: int = 1,
group_size: int = 0,
pad_type: str = '',
force_in_chs: int = 0,
noskip: bool = False,
exp_ratio: float = 1.0,
pw_kernel_size: int = 1,
act_layer: LayerType = nn.ReLU,
norm_layer: LayerType = nn.BatchNorm2d,
aa_layer: Optional[LayerType] = None,
se_layer: Optional[ModuleType] = None,
drop_path_rate: float = 0.,
):
super(EdgeResidual, self).__init__()
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
if force_in_chs > 0:
@ -248,12 +664,17 @@ class EdgeResidual(nn.Module):
mid_chs = make_divisible(in_chs * exp_ratio)
groups = num_groups(group_size, in_chs)
self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation
# Expansion convolution
self.conv_exp = create_conv2d(
in_chs, mid_chs, exp_kernel_size, stride=stride, dilation=dilation, groups=groups, padding=pad_type)
in_chs, mid_chs, exp_kernel_size,
stride=1 if use_aa else stride,
dilation=dilation, groups=groups, padding=pad_type)
self.bn1 = norm_act_layer(mid_chs, inplace=True)
self.aa = create_aa(aa_layer, channels=mid_chs, stride=stride, enable=use_aa)
# Squeeze-and-excitation
self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
@ -272,6 +693,7 @@ class EdgeResidual(nn.Module):
shortcut = x
x = self.conv_exp(x)
x = self.bn1(x)
x = self.aa(x)
x = self.se(x)
x = self.conv_pwl(x)
x = self.bn2(x)

View File

@ -5,6 +5,7 @@ Handles stride, dilation calculations, and selects feature extraction points.
Hacked together by / Copyright 2019, Ross Wightman
"""
from typing import Callable, Optional
import logging
import math
@ -16,7 +17,7 @@ from typing import Any, Dict, List
import torch.nn as nn
from ._efficientnet_blocks import *
from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible
from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible, LayerType
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
@ -139,8 +140,8 @@ def _decode_block_str(block_str):
# if act_layer is None, the model default (passed to model init) will be used
act_layer = options['n'] if 'n' in options else None
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
start_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
end_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
force_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
num_repeat = int(options['r'])
@ -154,29 +155,31 @@ def _decode_block_str(block_str):
if block_type == 'ir':
block_args.update(dict(
dw_kernel_size=_parse_ksize(options['k']),
exp_kernel_size=exp_kernel_size,
pw_kernel_size=pw_kernel_size,
exp_kernel_size=start_kernel_size,
pw_kernel_size=end_kernel_size,
exp_ratio=float(options['e']),
se_ratio=float(options['se']) if 'se' in options else 0.,
se_ratio=float(options.get('se', 0.)),
noskip=skip is False,
s2d=int(options.get('d', 0)) > 0,
))
if 'cc' in options:
block_args['num_experts'] = int(options['cc'])
elif block_type == 'ds' or block_type == 'dsa':
block_args.update(dict(
dw_kernel_size=_parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
se_ratio=float(options['se']) if 'se' in options else 0.,
pw_kernel_size=end_kernel_size,
se_ratio=float(options.get('se', 0.)),
pw_act=block_type == 'dsa',
noskip=block_type == 'dsa' or skip is False,
s2d=int(options.get('d', 0)) > 0,
))
elif block_type == 'er':
block_args.update(dict(
exp_kernel_size=_parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
pw_kernel_size=end_kernel_size,
exp_ratio=float(options['e']),
force_in_chs=force_in_chs,
se_ratio=float(options['se']) if 'se' in options else 0.,
se_ratio=float(options.get('se', 0.)),
noskip=skip is False,
))
elif block_type == 'cn':
@ -184,6 +187,38 @@ def _decode_block_str(block_str):
kernel_size=int(options['k']),
skip=skip is True,
))
elif block_type == 'uir':
# override exp / proj kernels for start/end in uir block
start_kernel_size = _parse_ksize(options['a']) if 'a' in options else 0
end_kernel_size = _parse_ksize(options['p']) if 'p' in options else 0
block_args.update(dict(
dw_kernel_size_start=start_kernel_size, # overload exp ks arg for dw start
dw_kernel_size_mid=_parse_ksize(options['k']),
dw_kernel_size_end=end_kernel_size, # overload pw ks arg for dw end
exp_ratio=float(options['e']),
se_ratio=float(options.get('se', 0.)),
noskip=skip is False,
))
elif block_type == 'mha':
kv_dim = int(options['d'])
block_args.update(dict(
dw_kernel_size=_parse_ksize(options['k']),
num_heads=int(options['h']),
key_dim=kv_dim,
value_dim=kv_dim,
kv_stride=int(options.get('v', 1)),
noskip=skip is False,
))
elif block_type == 'mqa':
kv_dim = int(options['d'])
block_args.update(dict(
dw_kernel_size=_parse_ksize(options['k']),
num_heads=int(options['h']),
key_dim=kv_dim,
value_dim=kv_dim,
kv_stride=int(options.get('v', 1)),
noskip=skip is False,
))
else:
assert False, 'Unknown block type (%s)' % block_type
if 'gs' in options:
@ -285,14 +320,27 @@ class EfficientNetBuilder:
https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
"""
def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=False,
act_layer=None, norm_layer=None, se_layer=None, drop_path_rate=0., feature_location=''):
def __init__(
self,
output_stride: int = 32,
pad_type: str = '',
round_chs_fn: Callable = round_channels,
se_from_exp: bool = False,
act_layer: Optional[LayerType] = None,
norm_layer: Optional[LayerType] = None,
aa_layer: Optional[LayerType] = None,
se_layer: Optional[LayerType] = None,
drop_path_rate: float = 0.,
layer_scale_init_value: Optional[float] = None,
feature_location: str = '',
):
self.output_stride = output_stride
self.pad_type = pad_type
self.round_chs_fn = round_chs_fn
self.se_from_exp = se_from_exp # calculate se channel reduction from expanded (mid) chs
self.act_layer = act_layer
self.norm_layer = norm_layer
self.aa_layer = aa_layer
self.se_layer = get_attn(se_layer)
try:
self.se_layer(8, rd_ratio=1.0) # test if attn layer accepts rd_ratio arg
@ -300,6 +348,7 @@ class EfficientNetBuilder:
except TypeError:
self.se_has_ratio = False
self.drop_path_rate = drop_path_rate
self.layer_scale_init_value = layer_scale_init_value
if feature_location == 'depthwise':
# old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
_logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
@ -317,6 +366,10 @@ class EfficientNetBuilder:
bt = ba.pop('block_type')
ba['in_chs'] = self.in_chs
ba['out_chs'] = self.round_chs_fn(ba['out_chs'])
s2d = ba.get('s2d', 0)
if s2d > 0:
# adjust while space2depth active
ba['out_chs'] *= 4
if 'force_in_chs' in ba and ba['force_in_chs']:
# NOTE this is a hack to work around mismatch in TF EdgeEffNet impl
ba['force_in_chs'] = self.round_chs_fn(ba['force_in_chs'])
@ -326,16 +379,22 @@ class EfficientNetBuilder:
assert ba['act_layer'] is not None
ba['norm_layer'] = self.norm_layer
ba['drop_path_rate'] = drop_path_rate
if bt != 'cn':
se_ratio = ba.pop('se_ratio')
if se_ratio and self.se_layer is not None:
if not self.se_from_exp:
# adjust se_ratio by expansion ratio if calculating se channels from block input
se_ratio /= ba.get('exp_ratio', 1.0)
if self.se_has_ratio:
ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio)
else:
ba['se_layer'] = self.se_layer
if self.aa_layer is not None:
ba['aa_layer'] = self.aa_layer
se_ratio = ba.pop('se_ratio', None)
if se_ratio and self.se_layer is not None:
if not self.se_from_exp:
# adjust se_ratio by expansion ratio if calculating se channels from block input
se_ratio /= ba.get('exp_ratio', 1.0)
if s2d == 1:
# adjust for start of space2depth
se_ratio /= 4
if self.se_has_ratio:
ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio)
else:
ba['se_layer'] = self.se_layer
if bt == 'ir':
_log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
@ -349,8 +408,17 @@ class EfficientNetBuilder:
elif bt == 'cn':
_log_info_if(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = ConvBnAct(**ba)
elif bt == 'uir':
_log_info_if(' UniversalInvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = UniversalInvertedResidual(**ba, layer_scale_init_value=self.layer_scale_init_value)
elif bt == 'mqa':
_log_info_if(' MobileMultiQueryAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = MobileAttention(**ba, use_multi_query=True, layer_scale_init_value=self.layer_scale_init_value)
elif bt == 'mha':
_log_info_if(' MobileMultiHeadAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = MobileAttention(**ba, layer_scale_init_value=self.layer_scale_init_value)
else:
assert False, 'Uknkown block type (%s) while building model.' % bt
assert False, 'Unknown block type (%s) while building model.' % bt
self.in_chs = ba['out_chs'] # update in_chs for arg of next block
return block
@ -377,6 +445,7 @@ class EfficientNetBuilder:
self.features.append(feature_info)
# outer list of block_args defines the stacks
space2depth = 0
for stack_idx, stack_args in enumerate(model_block_args):
last_stack = stack_idx + 1 == len(model_block_args)
_log_info_if('Stack: {}'.format(stack_idx), self.verbose)
@ -392,6 +461,20 @@ class EfficientNetBuilder:
if block_idx >= 1: # only the first block in any stack can have a stride > 1
block_args['stride'] = 1
if not space2depth and block_args.pop('s2d', False):
assert block_args['stride'] == 1
space2depth = 1
if space2depth > 0:
# FIXME s2d is a WIP
if space2depth == 2 and block_args['stride'] == 2:
block_args['stride'] = 1
# to end s2d region, need to correct expansion and se ratio relative to input
block_args['exp_ratio'] /= 4
space2depth = 0
else:
block_args['s2d'] = space2depth
extract_features = False
if last_block:
next_stack_idx = stack_idx + 1
@ -416,6 +499,9 @@ class EfficientNetBuilder:
block = self._make_block(block_args, total_block_idx, total_block_count)
blocks.append(block)
if space2depth == 1:
space2depth = 2
# stash feature module name and channel info for model feature extraction
if extract_features:
feature_info = dict(

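To tie the new block types to the arch string format, a brief decoding sketch: the first string is taken verbatim from the EfficientNet-X definition later in this diff (the trailing _d1 flags the space-to-depth region), while the uir string is a hypothetical example of the new syntax (a = leading DW kernel, k = middle DW kernel, p = trailing DW kernel).
from timm.models._efficientnet_builder import decode_arch_def

arch_def = [
    ['ds_r1_k3_s1_e1_c16_se0.25_d1'],  # verbatim from the EfficientNet-X def, s2d enabled
    ['uir_r2_a3_k5_s2_e4_c48'],  # hypothetical UniversalInvertedResidual stage string
]
block_args = decode_arch_def(arch_def)
print(block_args[1][0]['block_type'])  # 'uir'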
View File

@ -591,7 +591,7 @@ default_cfgs = generate_default_cfgs({
})
def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True):
def checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True):
state_dict = state_dict.get('model', state_dict)
state_dict = state_dict.get('module', state_dict)
# beit v2 didn't strip module
@ -637,7 +637,7 @@ def _create_beit(variant, pretrained=False, **kwargs):
out_indices = kwargs.pop('out_indices', 3)
model = build_model_with_cfg(
Beit, variant, pretrained,
pretrained_filter_fn=_beit_checkpoint_filter_fn,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)

View File

@ -556,7 +556,7 @@ class EfficientFormer(nn.Module):
return x
def _checkpoint_filter_fn(state_dict, model):
def checkpoint_filter_fn(state_dict, model):
""" Remap original checkpoints -> timm """
if 'stem.0.weight' in state_dict:
return state_dict # non-original checkpoint, no remapping needed
@ -611,7 +611,7 @@ def _create_efficientformer(variant, pretrained=False, **kwargs):
out_indices = kwargs.pop('out_indices', 4)
model = build_model_with_cfg(
EfficientFormer, variant, pretrained,
pretrained_filter_fn=_checkpoint_filter_fn,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)

View File

@ -36,7 +36,7 @@ the models and weights open source!
Hacked together by / Copyright 2019, Ross Wightman
"""
from functools import partial
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -44,10 +44,10 @@ import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct
from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct, LayerType
from ._builder import build_model_with_cfg, pretrained_cfg_for_features
from ._efficientnet_blocks import SqueezeExcite
from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from ._features import FeatureInfo, FeatureHooks, feature_take_indices
from ._manipulate import checkpoint_seq
@ -74,21 +74,22 @@ class EfficientNet(nn.Module):
def __init__(
self,
block_args,
num_classes=1000,
num_features=1280,
in_chans=3,
stem_size=32,
fix_stem=False,
output_stride=32,
pad_type='',
round_chs_fn=round_channels,
act_layer=None,
norm_layer=None,
se_layer=None,
drop_rate=0.,
drop_path_rate=0.,
global_pool='avg'
block_args: BlockArgs,
num_classes: int = 1000,
num_features: int = 1280,
in_chans: int = 3,
stem_size: int = 32,
fix_stem: bool = False,
output_stride: int = 32,
pad_type: str = '',
act_layer: Optional[LayerType] = None,
norm_layer: Optional[LayerType] = None,
aa_layer: Optional[LayerType] = None,
se_layer: Optional[LayerType] = None,
round_chs_fn: Callable = round_channels,
drop_rate: float = 0.,
drop_path_rate: float = 0.,
global_pool: str = 'avg'
):
super(EfficientNet, self).__init__()
act_layer = act_layer or nn.ReLU
@ -113,6 +114,7 @@ class EfficientNet(nn.Module):
round_chs_fn=round_chs_fn,
act_layer=act_layer,
norm_layer=norm_layer,
aa_layer=aa_layer,
se_layer=se_layer,
drop_path_rate=drop_path_rate,
)
@ -270,20 +272,21 @@ class EfficientNetFeatures(nn.Module):
def __init__(
self,
block_args,
out_indices=(0, 1, 2, 3, 4),
feature_location='bottleneck',
in_chans=3,
stem_size=32,
fix_stem=False,
output_stride=32,
pad_type='',
round_chs_fn=round_channels,
act_layer=None,
norm_layer=None,
se_layer=None,
drop_rate=0.,
drop_path_rate=0.
block_args: BlockArgs,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
feature_location: str = 'bottleneck',
in_chans: int = 3,
stem_size: int = 32,
fix_stem: bool = False,
output_stride: int = 32,
pad_type: str = '',
act_layer: Optional[LayerType] = None,
norm_layer: Optional[LayerType] = None,
aa_layer: Optional[LayerType] = None,
se_layer: Optional[LayerType] = None,
round_chs_fn: Callable = round_channels,
drop_rate: float = 0.,
drop_path_rate: float = 0.,
):
super(EfficientNetFeatures, self).__init__()
act_layer = act_layer or nn.ReLU
@ -306,6 +309,7 @@ class EfficientNetFeatures(nn.Module):
round_chs_fn=round_chs_fn,
act_layer=act_layer,
norm_layer=norm_layer,
aa_layer=aa_layer,
se_layer=se_layer,
drop_path_rate=drop_path_rate,
feature_location=feature_location,
@ -879,6 +883,88 @@ def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0
return model
def _gen_efficientnet_x(
variant, channel_multiplier=1.0, depth_multiplier=1.0, channel_divisor=8,
group_size=None, version=1, pretrained=False, **kwargs):
"""Creates an EfficientNet model.
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
Paper: https://arxiv.org/abs/1905.11946
EfficientNet params
name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
'efficientnet-x-b0': (1.0, 1.0, 224, 0.2),
'efficientnet-x-b1': (1.0, 1.1, 240, 0.2),
'efficientnet-x-b2': (1.1, 1.2, 260, 0.3),
'efficientnet-x-b3': (1.2, 1.4, 300, 0.3),
'efficientnet-x-b4': (1.4, 1.8, 380, 0.4),
'efficientnet-x-b5': (1.6, 2.2, 456, 0.4),
'efficientnet-x-b6': (1.8, 2.6, 528, 0.5),
'efficientnet-x-b7': (2.0, 3.1, 600, 0.5),
'efficientnet-x-b8': (2.2, 3.6, 672, 0.5),
'efficientnet-l2': (4.3, 5.3, 800, 0.5),
Args:
channel_multiplier: multiplier to number of channels per layer
depth_multiplier: multiplier to number of repeats per stage
"""
"""
if version == 1:
blocks_args = [
'r1_k3_s11_e1_i32_o16_se0.25_d1_a0',
'r2_k3_s22_e6_i16_o24_se0.25_f1_d2_a1',
'r2_k5_s22_e6_i24_o40_se0.25_f1_a1',
'r3_k3_s22_e6_i40_o80_se0.25_a0',
'r3_k5_s11_e6_i80_o112_se0.25_a0',
'r4_k5_s22_e6_i112_o192_se0.25_a0',
'r1_k3_s11_e6_i192_o320_se0.25_a0',
]
elif version == 2:
blocks_args = [
'r1_k3_s11_e1_i32_o16_se0.25_d1_a0',
'r2_k3_s22_e4_i16_o24_se0.25_f1_d2_a1',
'r2_k5_s22_e4_i24_o40_se0.25_f1_a1',
'r3_k3_s22_e4_i40_o80_se0.25_a0',
'r3_k5_s11_e6_i80_o112_se0.25_a0',
'r4_k5_s22_e6_i112_o192_se0.25_a0',
'r1_k3_s11_e6_i192_o320_se0.25_a0',
]
"""
if version == 1:
arch_def = [
['ds_r1_k3_s1_e1_c16_se0.25_d1'],
['er_r2_k3_s2_e6_c24_se0.25_nre'],
['er_r2_k5_s2_e6_c40_se0.25_nre'],
['ir_r3_k3_s2_e6_c80_se0.25'],
['ir_r3_k5_s1_e6_c112_se0.25'],
['ir_r4_k5_s2_e6_c192_se0.25'],
['ir_r1_k3_s1_e6_c320_se0.25'],
]
else:
arch_def = [
['ds_r1_k3_s1_e1_c16_se0.25_d1'],
['er_r2_k3_s2_e4_c24_se0.25_nre'],
['er_r2_k5_s2_e4_c40_se0.25_nre'],
['ir_r3_k3_s2_e4_c80_se0.25'],
['ir_r3_k5_s1_e6_c112_se0.25'],
['ir_r4_k5_s2_e6_c192_se0.25'],
['ir_r1_k3_s1_e6_c320_se0.25'],
]
round_chs_fn = partial(round_channels, multiplier=channel_multiplier, divisor=channel_divisor)
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
num_features=round_chs_fn(1280),
stem_size=32,
round_chs_fn=round_chs_fn,
act_layer=resolve_act_layer(kwargs, 'silu'),
norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
**kwargs,
)
model = _create_effnet(variant, pretrained, **model_kwargs)
return model
def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a MixNet Small model.
@ -1072,6 +1158,7 @@ default_cfgs = generate_default_cfgs({
input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0),
'efficientnet_b3_g8_gn.untrained': _cfg(
input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0),
'efficientnet_blur_b0.untrained': _cfg(),
'efficientnet_es.ra_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth',
@ -1768,6 +1855,17 @@ def efficientnet_b3_g8_gn(pretrained=False, **kwargs) -> EfficientNet:
return model
@register_model
def efficientnet_blur_b0(pretrained=False, **kwargs) -> EfficientNet:
""" EfficientNet-B0 w/ BlurPool """
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
model = _gen_efficientnet(
'efficientnet_blur_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained,
aa_layer='blurpc', **kwargs
)
return model
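Assuming the usual timm factory entrypoint, the new variant builds like any other registered model; the only difference from efficientnet_b0 is that aa_layer='blurpc' (BlurPool with constant padding) is threaded through the builder so strided blocks downsample with anti-aliasing.
import timm

model = timm.create_model('efficientnet_blur_b0', pretrained=False)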
@register_model
def efficientnet_es(pretrained=False, **kwargs) -> EfficientNet:
""" EfficientNet-Edge Small. """
@ -2277,6 +2375,31 @@ def tf_efficientnetv2_b3(pretrained=False, **kwargs) -> EfficientNet:
return model
@register_model
def efficientnet_x_b3(pretrained=False, **kwargs) -> EfficientNet:
""" EfficientNet-B3 """
# NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
model = _gen_efficientnet_x(
'efficientnet_x_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_x_b5(pretrained=False, **kwargs) -> EfficientNet:
""" EfficientNet-B5 """
model = _gen_efficientnet_x(
'efficientnet_x_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_h_b5(pretrained=False, **kwargs) -> EfficientNet:
""" EfficientNet-B5 """
model = _gen_efficientnet_x(
'efficientnet_h_b5', channel_multiplier=1.92, depth_multiplier=2.2, version=2, pretrained=pretrained, **kwargs)
return model
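A minimal usage sketch, assuming this branch of timm is installed; the X/H variants registered above build like any other timm model.
import timm
import torch

model = timm.create_model('efficientnet_x_b3', pretrained=False).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 300, 300))  # 300x300 is the B3 resolution in the table above
print(out.shape)  # torch.Size([1, 1000])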
@register_model
def mixnet_s(pretrained=False, **kwargs) -> EfficientNet:
"""Creates a MixNet Small model.

View File

@ -7,7 +7,7 @@
#
import os
from functools import partial
from typing import Tuple, Optional, Union
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn, \
ClassifierHead
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@ -40,19 +41,19 @@ class MobileOneBlock(nn.Module):
"""
def __init__(
self,
in_chs: int,
out_chs: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
group_size: int = 0,
inference_mode: bool = False,
use_se: bool = False,
use_act: bool = True,
use_scale_branch: bool = True,
num_conv_branches: int = 1,
act_layer: nn.Module = nn.GELU,
self,
in_chs: int,
out_chs: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
group_size: int = 0,
inference_mode: bool = False,
use_se: bool = False,
use_act: bool = True,
use_scale_branch: bool = True,
num_conv_branches: int = 1,
act_layer: nn.Module = nn.GELU,
) -> None:
"""Construct a MobileOneBlock module.
@ -280,15 +281,16 @@ class ReparamLargeKernelConv(nn.Module):
"""
def __init__(
self,
in_chs: int,
out_chs: int,
kernel_size: int,
stride: int,
group_size: int,
small_kernel: Optional[int] = None,
inference_mode: bool = False,
act_layer: Optional[nn.Module] = None,
self,
in_chs: int,
out_chs: int,
kernel_size: int,
stride: int,
group_size: int,
small_kernel: Optional[int] = None,
use_se: bool = False,
act_layer: Optional[nn.Module] = None,
inference_mode: bool = False,
) -> None:
"""Construct a ReparamLargeKernelConv module.
@ -299,8 +301,8 @@ class ReparamLargeKernelConv(nn.Module):
stride: Stride size. Default: 1
group_size: Group size. Default: 1
small_kernel: Kernel size of small kernel conv branch.
inference_mode: If True, instantiates model in inference mode. Default: ``False``
act_layer: Activation module. Default: ``nn.GELU``
inference_mode: If True, instantiates model in inference mode. Default: ``False``
"""
super(ReparamLargeKernelConv, self).__init__()
self.stride = stride
@ -342,6 +344,7 @@ class ReparamLargeKernelConv(nn.Module):
groups=self.groups,
apply_act=False,
)
self.se = SqueezeExcite(out_chs, rd_ratio=0.25) if use_se else nn.Identity()
# FIXME output of this act was not used in original impl, likely due to bug
self.act = act_layer() if act_layer is not None else nn.Identity()
@ -352,6 +355,7 @@ class ReparamLargeKernelConv(nn.Module):
out = self.large_conv(x)
if self.small_conv is not None:
out = out + self.small_conv(x)
out = self.se(out)
out = self.act(out)
return out
@ -472,12 +476,12 @@ class Attention(nn.Module):
fused_attn: torch.jit.Final[bool]
def __init__(
self,
dim: int,
head_dim: int = 32,
qkv_bias: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
self,
dim: int,
head_dim: int = 32,
qkv_bias: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
) -> None:
"""Build MHSA module that can handle 3D or 4D input tensors.
@ -535,14 +539,15 @@ class PatchEmbed(nn.Module):
"""Convolutional patch embedding layer."""
def __init__(
self,
patch_size: int,
stride: int,
in_chs: int,
embed_dim: int,
act_layer: nn.Module = nn.GELU,
lkc_use_act: bool = False,
inference_mode: bool = False,
self,
patch_size: int,
stride: int,
in_chs: int,
embed_dim: int,
act_layer: nn.Module = nn.GELU,
lkc_use_act: bool = False,
use_se: bool = False,
inference_mode: bool = False,
) -> None:
"""Build patch embedding layer.
@ -562,14 +567,16 @@ class PatchEmbed(nn.Module):
stride=stride,
group_size=1,
small_kernel=3,
inference_mode=inference_mode,
use_se=use_se,
act_layer=act_layer if lkc_use_act else None, # NOTE original weights didn't use this act
inference_mode=inference_mode,
),
MobileOneBlock(
in_chs=embed_dim,
out_chs=embed_dim,
kernel_size=1,
stride=1,
use_se=False,
act_layer=act_layer,
inference_mode=inference_mode,
)
@ -598,11 +605,11 @@ class RepMixer(nn.Module):
"""
def __init__(
self,
dim,
kernel_size=3,
layer_scale_init_value=1e-5,
inference_mode: bool = False,
self,
dim,
kernel_size=3,
layer_scale_init_value=1e-5,
inference_mode: bool = False,
):
"""Build RepMixer Module.
@ -648,7 +655,7 @@ class RepMixer(nn.Module):
if layer_scale_init_value is not None:
self.layer_scale = LayerScale2d(dim, layer_scale_init_value)
else:
self.layer_scale = nn.Identity
self.layer_scale = nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.reparam_conv is not None:
@ -706,12 +713,12 @@ class ConvMlp(nn.Module):
"""Convolutional FFN Module."""
def __init__(
self,
in_chs: int,
hidden_channels: Optional[int] = None,
out_chs: Optional[int] = None,
act_layer: nn.Module = nn.GELU,
drop: float = 0.0,
self,
in_chs: int,
hidden_channels: Optional[int] = None,
out_chs: Optional[int] = None,
act_layer: nn.Module = nn.GELU,
drop: float = 0.0,
) -> None:
"""Build convolutional FFN module.
@ -764,11 +771,11 @@ class RepConditionalPosEnc(nn.Module):
"""
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
inference_mode=False,
self,
dim: int,
dim_out: Optional[int] = None,
spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
inference_mode=False,
) -> None:
"""Build reparameterizable conditional positional encoding
@ -878,15 +885,15 @@ class RepMixerBlock(nn.Module):
"""
def __init__(
self,
dim: int,
kernel_size: int = 3,
mlp_ratio: float = 4.0,
act_layer: nn.Module = nn.GELU,
proj_drop: float = 0.0,
drop_path: float = 0.0,
layer_scale_init_value: float = 1e-5,
inference_mode: bool = False,
self,
dim: int,
kernel_size: int = 3,
mlp_ratio: float = 4.0,
act_layer: nn.Module = nn.GELU,
proj_drop: float = 0.0,
drop_path: float = 0.0,
layer_scale_init_value: float = 1e-5,
inference_mode: bool = False,
):
"""Build RepMixer Block.
@ -936,14 +943,14 @@ class AttentionBlock(nn.Module):
"""
def __init__(
self,
dim: int,
mlp_ratio: float = 4.0,
act_layer: nn.Module = nn.GELU,
norm_layer: nn.Module = nn.BatchNorm2d,
proj_drop: float = 0.0,
drop_path: float = 0.0,
layer_scale_init_value: float = 1e-5,
self,
dim: int,
mlp_ratio: float = 4.0,
act_layer: nn.Module = nn.GELU,
norm_layer: nn.Module = nn.BatchNorm2d,
proj_drop: float = 0.0,
drop_path: float = 0.0,
layer_scale_init_value: float = 1e-5,
):
"""Build Attention Block.
@ -993,6 +1000,7 @@ class FastVitStage(nn.Module):
depth: int,
token_mixer_type: str,
downsample: bool = True,
se_downsample: bool = False,
down_patch_size: int = 7,
down_stride: int = 2,
pos_emb_layer: Optional[nn.Module] = None,
@ -1030,6 +1038,7 @@ class FastVitStage(nn.Module):
stride=down_stride,
in_chs=dim,
embed_dim=dim_out,
use_se=se_downsample,
act_layer=act_layer,
lkc_use_act=lkc_use_act,
inference_mode=inference_mode,
@ -1090,29 +1099,30 @@ class FastVit(nn.Module):
"""
def __init__(
self,
in_chans: int = 3,
layers: Tuple[int, ...] = (2, 2, 6, 2),
token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"),
embed_dims: Tuple[int, ...] = (64, 128, 256, 512),
mlp_ratios: Tuple[float, ...] = (4,) * 4,
downsamples: Tuple[bool, ...] = (False, True, True, True),
repmixer_kernel_size: int = 3,
num_classes: int = 1000,
pos_embs: Tuple[Optional[nn.Module], ...] = (None,) * 4,
down_patch_size: int = 7,
down_stride: int = 2,
drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
layer_scale_init_value: float = 1e-5,
fork_feat: bool = False,
cls_ratio: float = 2.0,
global_pool: str = 'avg',
norm_layer: nn.Module = nn.BatchNorm2d,
act_layer: nn.Module = nn.GELU,
lkc_use_act: bool = False,
inference_mode: bool = False,
self,
in_chans: int = 3,
layers: Tuple[int, ...] = (2, 2, 6, 2),
token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"),
embed_dims: Tuple[int, ...] = (64, 128, 256, 512),
mlp_ratios: Tuple[float, ...] = (4,) * 4,
downsamples: Tuple[bool, ...] = (False, True, True, True),
se_downsamples: Tuple[bool, ...] = (False, False, False, False),
repmixer_kernel_size: int = 3,
num_classes: int = 1000,
pos_embs: Tuple[Optional[nn.Module], ...] = (None,) * 4,
down_patch_size: int = 7,
down_stride: int = 2,
drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
layer_scale_init_value: float = 1e-5,
lkc_use_act: bool = False,
fork_feat: bool = False,
cls_ratio: float = 2.0,
global_pool: str = 'avg',
norm_layer: nn.Module = nn.BatchNorm2d,
act_layer: nn.Module = nn.GELU,
inference_mode: bool = False,
) -> None:
super().__init__()
self.num_classes = 0 if fork_feat else num_classes
@ -1140,6 +1150,7 @@ class FastVit(nn.Module):
dim_out=embed_dims[i],
depth=layers[i],
downsample=downsample,
se_downsample=se_downsamples[i],
down_patch_size=down_patch_size,
down_stride=down_stride,
pos_emb_layer=pos_embs[i],
@ -1160,6 +1171,7 @@ class FastVit(nn.Module):
scale *= 2
self.feature_info += [dict(num_chs=prev_dim, reduction=4 * scale, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
self.num_stages = len(self.stages)
self.num_features = prev_dim
# For segmentation and detection, extract intermediate output
@ -1236,6 +1248,66 @@ class FastVit(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.
Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:
"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)
# forward pass
x = self.stem(x)
last_idx = self.num_stages - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
feat_idx = 0
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)
if intermediates_only:
return intermediates
if feat_idx == last_idx:
x = self.final_conv(x)
return x, intermediates
def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
# input embedding
x = self.stem(x)
@ -1297,8 +1369,7 @@ default_cfgs = generate_default_cfgs({
"fastvit_ma36.apple_in1k": _cfg(
hf_hub_id='timm/',
crop_pct=0.95
),
crop_pct=0.95),
"fastvit_t8.apple_dist_in1k": _cfg(
hf_hub_id='timm/'),
@ -1318,15 +1389,111 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='timm/',
crop_pct=0.95
),
"fastvit_mci0.apple_mclip": _cfg(
#hf_hub_id='timm/',
url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s0.pt',
crop_pct=0.95,
num_classes=512, # CLIP proj dim
mean=(0., 0., 0.), std=(1., 1., 1.)
),
"fastvit_mci1.apple_mclip": _cfg(
# hf_hub_id='timm/',
url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s1.pt',
crop_pct=0.95,
num_classes=512, # CLIP proj dim
mean=(0., 0., 0.), std=(1., 1., 1.)
),
"fastvit_mci2.apple_mclip": _cfg(
# hf_hub_id='timm/',
url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s2.pt',
crop_pct=0.95,
num_classes=512, # CLIP proj dim
mean=(0., 0., 0.), std=(1., 1., 1.)
),
})
def checkpoint_filter_fn(state_dict, model):
""" Remap original checkpoints -> timm """
if 'stem.0.conv_kxk.0.conv.weight' in state_dict:
return state_dict # non-original checkpoint, no remapping needed
state_dict = state_dict.get('state_dict', state_dict)
if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
# remap MobileCLIP checkpoints
prefix = 'image_encoder.model.'
else:
prefix = ''
import re
import bisect
# find stage ends by locating downsample layers
stage_ends = []
for k, v in state_dict.items():
match = re.match(r'^(.*?)network\.(\d+)\.proj.*', k)
if match:
stage_ends.append(int(match.group(2)))
stage_ends = list(sorted(set(stage_ends)))
out_dict = {}
for k, v in state_dict.items():
if prefix:
if prefix not in k:
continue
k = k.replace(prefix, '')
# remap renamed layers
k = k.replace('patch_embed', 'stem')
k = k.replace('rbr_conv', 'conv_kxk')
k = k.replace('rbr_scale', 'conv_scale')
k = k.replace('rbr_skip', 'identity')
k = k.replace('conv_exp', 'final_conv') # to match byobnet, regnet, nfnet
k = k.replace('lkb_origin', 'large_conv')
k = k.replace('convffn', 'mlp')
k = k.replace('se.reduce', 'se.fc1')
k = k.replace('se.expand', 'se.fc2')
k = re.sub(r'layer_scale_([0-9])', r'layer_scale_\1.gamma', k)
if k.endswith('layer_scale'):
k = k.replace('layer_scale', 'layer_scale.gamma')
k = k.replace('dist_head', 'head_dist')
if k.startswith('head.'):
if k == 'head.proj' and hasattr(model.head, 'fc') and isinstance(model.head.fc, nn.Linear):
# if CLIP projection, map to head.fc w/ bias = zeros
k = k.replace('head.proj', 'head.fc.weight')
v = v.T
out_dict['head.fc.bias'] = torch.zeros(v.shape[0])
else:
k = k.replace('head.', 'head.fc.')
# remap flat sequential network to stages
match = re.match(r'^network\.(\d+)', k)
stage_idx, net_idx = None, None
if match:
net_idx = int(match.group(1))
stage_idx = bisect.bisect_right(stage_ends, net_idx)
if stage_idx is not None:
net_prefix = f'network.{net_idx}'
stage_prefix = f'stages.{stage_idx}'
if net_prefix + '.proj' in k:
k = k.replace(net_prefix + '.proj', stage_prefix + '.downsample.proj')
elif net_prefix + '.pe' in k:
k = k.replace(net_prefix + '.pe', stage_prefix + '.pos_emb.pos_enc')
else:
k = k.replace(net_prefix, stage_prefix + '.blocks')
out_dict[k] = v
return out_dict
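A toy illustration of the stage remapping above, with hypothetical indices rather than a real checkpoint: the flat `network.N` modules are binned into `stages.S` using the downsample ('proj') positions as boundaries.
import bisect

stage_ends = [1, 3, 5]  # hypothetical: network indices holding downsample 'proj' layers
for net_idx in range(7):
    stage_idx = bisect.bisect_right(stage_ends, net_idx)
    print(f'network.{net_idx} -> stages.{stage_idx}')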
def _create_fastvit(variant, pretrained=False, **kwargs):
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
model = build_model_with_cfg(
FastVit,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs
)
@ -1419,3 +1586,48 @@ def fastvit_ma36(pretrained=False, **kwargs):
token_mixers=("repmixer", "repmixer", "repmixer", "attention")
)
return _create_fastvit('fastvit_ma36', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def fastvit_mci0(pretrained=False, **kwargs):
"""Instantiate MCi0 model variant."""
model_args = dict(
layers=(2, 6, 10, 2),
embed_dims=(64, 128, 256, 512),
mlp_ratios=(3, 3, 3, 3),
se_downsamples=(False, False, True, True),
pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
lkc_use_act=True,
)
return _create_fastvit('fastvit_mci0', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def fastvit_mci1(pretrained=False, **kwargs):
"""Instantiate MCi1 model variant."""
model_args = dict(
layers=(4, 12, 20, 4),
embed_dims=(64, 128, 256, 512),
mlp_ratios=(3, 3, 3, 3),
se_downsamples=(False, False, True, True),
pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
lkc_use_act=True,
)
return _create_fastvit('fastvit_mci1', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def fastvit_mci2(pretrained=False, **kwargs):
"""Instantiate MCi2 model variant."""
model_args = dict(
layers=(4, 12, 24, 4),
embed_dims=(80, 160, 320, 640),
mlp_ratios=(3, 3, 3, 3),
se_downsamples=(False, False, True, True),
pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
lkc_use_act=True,
)
return _create_fastvit('fastvit_mci2', pretrained=pretrained, **dict(model_args, **kwargs))
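A minimal usage sketch, assuming this branch of timm is installed, exercising the forward_intermediates() API added to FastVit above.
import timm
import torch

model = timm.create_model('fastvit_t8', pretrained=False).eval()
with torch.no_grad():
    final, intermediates = model.forward_intermediates(torch.randn(1, 3, 256, 256))
for i, feat in enumerate(intermediates):
    print(i, feat.shape)  # one NCHW feature map per selected stage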

View File

@ -40,6 +40,7 @@ class MobileNetV3(nn.Module):
* HardCoRe-NAS - https://arxiv.org/abs/2102.11646 (defn in hardcorenas.py uses this class)
* FBNet-V3 - https://arxiv.org/abs/2006.02049
* LCNet - https://arxiv.org/abs/2109.15099
* MobileNet-V4 - https://arxiv.org/abs/2404.10518
"""
def __init__(
@ -51,14 +52,17 @@ class MobileNetV3(nn.Module):
fix_stem: bool = False,
num_features: int = 1280,
head_bias: bool = True,
pad_type: PadType = '',
head_norm: bool = False,
pad_type: str = '',
act_layer: Optional[LayerType] = None,
norm_layer: Optional[LayerType] = None,
aa_layer: Optional[LayerType] = None,
se_layer: Optional[LayerType] = None,
se_from_exp: bool = True,
round_chs_fn: Callable = round_channels,
drop_rate: float = 0.,
drop_path_rate: float = 0.,
layer_scale_init_value: Optional[float] = None,
global_pool: str = 'avg',
):
"""
@ -73,11 +77,13 @@ class MobileNetV3(nn.Module):
pad_type: Type of padding to use for convolution layers.
act_layer: Type of activation layer.
norm_layer: Type of normalization layer.
aa_layer: Type of anti-aliasing layer.
se_layer: Type of Squeeze-and-Excite layer.
se_from_exp: If True, calculate SE channel reduction from expanded mid channels.
round_chs_fn: Callable to round number of filters based on depth multiplier.
drop_rate: Dropout rate.
drop_path_rate: Stochastic depth rate.
layer_scale_init_value: Enable layer scale on compatible blocks if not None.
global_pool: Type of pooling to use for global pooling features of the FC head.
"""
super(MobileNetV3, self).__init__()
@ -104,8 +110,10 @@ class MobileNetV3(nn.Module):
se_from_exp=se_from_exp,
act_layer=act_layer,
norm_layer=norm_layer,
aa_layer=aa_layer,
se_layer=se_layer,
drop_path_rate=drop_path_rate,
layer_scale_init_value=layer_scale_init_value,
)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = builder.features
@ -115,8 +123,16 @@ class MobileNetV3(nn.Module):
# Head + Pooling
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
num_pooled_chs = head_chs * self.global_pool.feat_mult()
self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
self.act2 = act_layer(inplace=True)
if head_norm:
# mobilenet-v4 post-pooling PW conv is followed by a norm+act layer
self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type) # never bias
self.norm_head = norm_act_layer(self.num_features)
self.act2 = nn.Identity()
else:
# mobilenet-v3 and others only have an activation after final PW conv
self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
self.norm_head = nn.Identity()
self.act2 = act_layer(inplace=True)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
@ -125,7 +141,7 @@ class MobileNetV3(nn.Module):
def as_sequential(self):
layers = [self.conv_stem, self.bn1]
layers.extend(self.blocks)
layers.extend([self.global_pool, self.conv_head, self.act2])
layers.extend([self.global_pool, self.conv_head, self.norm_head, self.act2])
layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
return nn.Sequential(*layers)
@ -224,8 +240,10 @@ class MobileNetV3(nn.Module):
self.blocks = self.blocks[:max_index] # truncate blocks w/ stem as idx 0
if max_index < len(self.blocks):
self.conv_head = nn.Identity()
self.norm_head = nn.Identity()
if prune_head:
self.conv_head = nn.Identity()
self.norm_head = nn.Identity()
self.reset_classifier(0, '')
return take_indices
@ -241,6 +259,7 @@ class MobileNetV3(nn.Module):
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
x = self.global_pool(x)
x = self.conv_head(x)
x = self.norm_head(x)
x = self.act2(x)
x = self.flatten(x)
if pre_logits:
@ -276,9 +295,11 @@ class MobileNetV3Features(nn.Module):
se_from_exp: bool = True,
act_layer: Optional[LayerType] = None,
norm_layer: Optional[LayerType] = None,
aa_layer: Optional[LayerType] = None,
se_layer: Optional[LayerType] = None,
drop_rate: float = 0.,
drop_path_rate: float = 0.,
layer_scale_init_value: Optional[float] = None,
):
"""
Args:
@ -297,6 +318,7 @@ class MobileNetV3Features(nn.Module):
se_layer: Type of Squeeze-and-Excite layer.
drop_rate: Dropout rate.
drop_path_rate: Stochastic depth rate.
layer_scale_init_value: Enable layer scale on compatible blocks if not None.
"""
super(MobileNetV3Features, self).__init__()
act_layer = act_layer or nn.ReLU
@ -320,8 +342,10 @@ class MobileNetV3Features(nn.Module):
se_from_exp=se_from_exp,
act_layer=act_layer,
norm_layer=norm_layer,
aa_layer=aa_layer,
se_layer=se_layer,
drop_path_rate=drop_path_rate,
layer_scale_init_value=layer_scale_init_value,
feature_location=feature_location,
)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
@ -370,7 +394,7 @@ def _create_mnv3(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV
if 'feature_cfg' in kwargs or 'feature_cls' in kwargs:
features_mode = 'cfg'
else:
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'head_norm', 'global_pool')
model_cls = MobileNetV3Features
features_mode = 'cls'
@ -622,6 +646,252 @@ def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool =
return model
def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
"""Creates a MobileNet-V4 model.
Ref impl: ?
Paper: https://arxiv.org/abs/2404.10518
Args:
channel_multiplier: multiplier to number of channels per layer.
"""
num_features = 1280
if 'hybrid' in variant:
layer_scale_init_value = 1e-5
if 'medium' in variant:
stem_size = 32
act_layer = resolve_act_layer(kwargs, 'relu')
arch_def = [
# stage 0, 112x112 in
[
'er_r1_k3_s2_e4_c48' # FusedIB (EdgeResidual)
],
# stage 1, 56x56 in
[
'uir_r1_a3_k5_s2_e4_c80', # ExtraDW
'uir_r1_a3_k3_s1_e2_c80', # ExtraDW
],
# stage 2, 28x28 in
[
'uir_r1_a3_k5_s2_e6_c160', # ExtraDW
'uir_r1_a0_k0_s1_e2_c160', # FFN
'uir_r1_a3_k3_s1_e4_c160', # ExtraDW
'uir_r1_a3_k5_s1_e4_c160', # ExtraDW
'mqa_r1_k3_h4_s1_v2_d64_c160', # MQA w/ KV downsample
'uir_r1_a3_k3_s1_e4_c160', # ExtraDW
'mqa_r1_k3_h4_s1_v2_d64_c160', # MQA w/ KV downsample
'uir_r1_a3_k0_s1_e4_c160', # ConvNeXt
'mqa_r1_k3_h4_s1_v2_d64_c160', # MQA w/ KV downsample
'uir_r1_a3_k3_s1_e4_c160', # ExtraDW
'mqa_r1_k3_h4_s1_v2_d64_c160', # MQA w/ KV downsample
'uir_r1_a3_k0_s1_e4_c160', # ConvNeXt
],
# stage 3, 14x14in
[
'uir_r1_a5_k5_s2_e6_c256', # ExtraDW
'uir_r1_a5_k5_s1_e4_c256', # ExtraDW
'uir_r2_a3_k5_s1_e4_c256', # ExtraDW
'uir_r1_a0_k0_s1_e2_c256', # FFN
'uir_r1_a3_k5_s1_e2_c256', # ExtraDW
'uir_r1_a0_k0_s1_e2_c256', # FFN
'uir_r1_a0_k0_s1_e4_c256', # FFN
'mqa_r1_k3_h4_s1_d64_c256', # MQA
'uir_r1_a3_k0_s1_e4_c256', # ConvNeXt
'mqa_r1_k3_h4_s1_d64_c256', # MQA
'uir_r1_a5_k5_s1_e4_c256', # ExtraDW
'mqa_r1_k3_h4_s1_d64_c256', # MQA
'uir_r1_a5_k0_s1_e4_c256', # ConvNeXt
'mqa_r1_k3_h4_s1_d64_c256', # MQA
'uir_r1_a5_k0_s1_e4_c256', # ConvNeXt
],
# stage 4, 7x7 in
[
'cn_r1_k1_s1_c960' # Conv
],
]
elif 'large' in variant:
stem_size = 24
act_layer = resolve_act_layer(kwargs, 'gelu')
arch_def = [
# stage 0, 112x112 in
[
'er_r1_k3_s2_e4_c48', # FusedIB (EdgeResidual)
],
# stage 1, 56x56 in
[
'uir_r1_a3_k5_s2_e4_c96', # ExtraDW
'uir_r1_a3_k3_s1_e4_c96', # ExtraDW
],
# stage 2, 28x28 in
[
'uir_r1_a3_k5_s2_e4_c192', # ExtraDW
'uir_r3_a3_k3_s1_e4_c192', # ExtraDW
'uir_r1_a3_k5_s1_e4_c192', # ExtraDW
'uir_r2_a5_k3_s1_e4_c192', # ExtraDW
'mqa_r1_k3_h8_s1_v2_d48_c192', # MQA w/ KV downsample
'uir_r1_a5_k3_s1_e4_c192', # ExtraDW
'mqa_r1_k3_h8_s1_v2_d48_c192', # MQA w/ KV downsample
'uir_r1_a5_k3_s1_e4_c192', # ExtraDW
'mqa_r1_k3_h8_s1_v2_d48_c192', # MQA w/ KV downsample
'uir_r1_a5_k3_s1_e4_c192', # ExtraDW
'mqa_r1_k3_h8_s1_v2_d48_c192', # MQA w/ KV downsample
'uir_r1_a3_k0_s1_e4_c192', # ConvNeXt
],
# stage 3, 14x14in
[
'uir_r4_a5_k5_s2_e4_c512', # ExtraDW
'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt
'uir_r1_a5_k3_s1_e4_c512', # ExtraDW
'uir_r2_a5_k0_s1_e4_c512', # ConvNeXt
'uir_r1_a5_k3_s1_e4_c512', # ExtraDW
'uir_r1_a5_k5_s1_e4_c512', # ExtraDW
'mqa_r1_k3_h8_s1_d64_c512', # MQA
'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt
'mqa_r1_k3_h8_s1_d64_c512', # MQA
'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt
'mqa_r1_k3_h8_s1_d64_c512', # MQA
'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt
'mqa_r1_k3_h8_s1_d64_c512', # MQA
'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt
],
# stage 4, 7x7 in
[
'cn_r1_k1_s1_c960', # Conv
],
]
else:
assert False, f'Unknown variant {variant}.'
else:
layer_scale_init_value = None
if 'small' in variant:
stem_size = 32
act_layer = resolve_act_layer(kwargs, 'relu')
arch_def = [
# stage 0, 112x112 in
[
'cn_r1_k3_s2_e1_c32', # Conv
'cn_r1_k1_s1_e1_c32', # Conv
],
# stage 1, 56x56 in
[
'cn_r1_k3_s2_e1_c96', # Conv
'cn_r1_k1_s1_e1_c64', # Conv
],
# stage 2, 28x28 in
[
'uir_r1_a5_k5_s2_e3_c96', # ExtraDW
'uir_r4_a0_k3_s1_e2_c96', # IR
'uir_r1_a3_k0_s1_e4_c96', # ConvNeXt
],
# stage 3, 14x14 in
[
'uir_r1_a3_k3_s2_e6_c128', # ExtraDW
'uir_r1_a5_k5_s1_e4_c128', # ExtraDW
'uir_r1_a0_k5_s1_e4_c128', # IR
'uir_r1_a0_k5_s1_e3_c128', # IR
'uir_r2_a0_k3_s1_e4_c128', # IR
],
# stage 4, 7x7 in
[
'cn_r1_k1_s1_c960', # Conv
],
]
elif 'medium' in variant:
stem_size = 32
act_layer = resolve_act_layer(kwargs, 'relu')
arch_def = [
# stage 0, 112x112 in
[
'er_r1_k3_s2_e4_c48', # FusedIB (EdgeResidual)
],
# stage 1, 56x56 in
[
'uir_r1_a3_k5_s2_e4_c80', # ExtraDW
'uir_r1_a3_k3_s1_e2_c80', # ExtraDW
],
# stage 2, 28x28 in
[
'uir_r1_a3_k5_s2_e6_c160', # ExtraDW
'uir_r2_a3_k3_s1_e4_c160', # ExtraDW
'uir_r1_a3_k5_s1_e4_c160', # ExtraDW
'uir_r1_a3_k3_s1_e4_c160', # ExtraDW
'uir_r1_a3_k0_s1_e4_c160', # ConvNeXt
'uir_r1_a0_k0_s1_e2_c160', # FFN
'uir_r1_a3_k0_s1_e4_c160', # ConvNeXt
],
# stage 3, 14x14in
[
'uir_r1_a5_k5_s2_e6_c256', # ExtraDW
'uir_r1_a5_k5_s1_e4_c256', # ExtraDW
'uir_r2_a3_k5_s1_e4_c256', # ExtraDW
'uir_r1_a0_k0_s1_e4_c256', # FFN
'uir_r1_a3_k0_s1_e4_c256', # ConvNeXt
'uir_r1_a3_k5_s1_e2_c256', # ExtraDW
'uir_r1_a5_k5_s1_e4_c256', # ExtraDW
'uir_r2_a0_k0_s1_e4_c256', # FFN
'uir_r1_a5_k0_s1_e2_c256', # ConvNeXt
],
# stage 4, 7x7 in
[
'cn_r1_k1_s1_c960', # Conv
],
]
elif 'large' in variant:
stem_size = 24
act_layer = resolve_act_layer(kwargs, 'relu')
arch_def = [
# stage 0, 112x112 in
[
'er_r1_k3_s2_e4_c48', # FusedIB (EdgeResidual)
],
# stage 1, 56x56 in
[
'uir_r1_a3_k5_s2_e4_c96', # ExtraDW
'uir_r1_a3_k3_s1_e4_c96', # ExtraDW
],
# stage 2, 28x28 in
[
'uir_r1_a3_k5_s2_e4_c192', # ExtraDW
'uir_r3_a3_k3_s1_e4_c192', # ExtraDW
'uir_r1_a3_k5_s1_e4_c192', # ExtraDW
'uir_r5_a5_k3_s1_e4_c192', # ExtraDW
'uir_r1_a3_k0_s1_e4_c192', # ConvNeXt
],
# stage 3, 14x14in
[
'uir_r4_a5_k5_s2_e4_c512', # ExtraDW
'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt
'uir_r1_a5_k3_s1_e4_c512', # ExtraDW
'uir_r2_a5_k0_s1_e4_c512', # ConvNeXt
'uir_r1_a5_k3_s1_e4_c512', # ExtraDW
'uir_r1_a5_k5_s1_e4_c512', # ExtraDW
'uir_r3_a5_k0_s1_e4_c512', # ConvNeXt
],
# stage 4, 7x7 in
[
'cn_r1_k1_s1_c960', # Conv
],
]
else:
assert False, f'Unknown variant {variant}.'
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
head_bias=False,
head_norm=True,
num_features=num_features,
stem_size=stem_size,
fix_stem=channel_multiplier < 1.0,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=act_layer,
layer_scale_init_value=layer_scale_init_value,
**kwargs,
)
model = _create_mnv3(variant, pretrained, **model_kwargs)
return model
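The stage strings above follow timm's compact block notation; 'uir' is the universal inverted residual (ExtraDW / ConvNeXt / FFN depending on which dw kernels are present) and 'mqa' is multi-query attention. A rough, illustrative parser (not the builder's real decoder; the token names here are assumptions based on the comments above):
def parse_block_str(block_str: str) -> dict:
    keys = {'r': 'repeats', 'a': 'dw_kernel_start', 'k': 'dw_kernel_mid', 's': 'stride',
            'e': 'expand_ratio', 'c': 'channels', 'h': 'num_heads', 'd': 'kv_dim', 'v': 'kv_stride'}
    parts = block_str.split('_')
    out = {'type': parts[0]}  # 'uir', 'mqa', 'er', 'cn', ...
    for tok in parts[1:]:
        key, val = tok[0], tok[1:]
        out[keys.get(key, key)] = float(val) if '.' in val else int(val)
    return out

print(parse_block_str('uir_r1_a3_k5_s2_e6_c160'))      # ExtraDW: dw 3x3 then dw 5x5, stride 2, expand 6
print(parse_block_str('mqa_r1_k3_h4_s1_v2_d64_c160'))  # MQA: 4 heads, 64-dim KV, KV downsample stride 2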
def _cfg(url: str = '', **kwargs):
return {
@ -725,6 +995,52 @@ default_cfgs = generate_default_cfgs({
interpolation='bicubic',
),
"lcnet_150.untrained": _cfg(),
'mobilenetv4_conv_small': _cfg(
# hf_hub_id='timm/',
interpolation='bicubic'),
'mobilenetv4_conv_medium.r224': _cfg(
# hf_hub_id='timm/',
crop_pct=0.95, interpolation='bicubic'),
'mobilenetv4_conv_medium.r256': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
'mobilenetv4_conv_large.r256': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
'mobilenetv4_conv_large.r384': _cfg(
# hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=0.95, interpolation='bicubic'),
'mobilenetv4_hybrid_small': _cfg(
# hf_hub_id='timm/',
interpolation='bicubic'),
'mobilenetv4_hybrid_medium.r224': _cfg(
# hf_hub_id='timm/',
crop_pct=0.95, interpolation='bicubic'),
'mobilenetv4_hybrid_medium.r256': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
'mobilenetv4_hybrid_large.r256': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
'mobilenetv4_hybrid_large.r384': _cfg(
# hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=0.95, interpolation='bicubic'),
# experimental
'mobilenetv4_conv_aa_medium.r256': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
'mobilenetv4_conv_blur_medium.r256': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
'mobilenetv4_hybrid_medium_075': _cfg(
# hf_hub_id='timm/',
crop_pct=0.95, interpolation='bicubic'),
'mobilenetv4_hybrid_large_075.r256': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
})
@ -881,6 +1197,69 @@ def lcnet_150(pretrained: bool = False, **kwargs) -> MobileNetV3:
return model
@register_model
def mobilenetv4_conv_small(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V4 """
model = _gen_mobilenet_v4('mobilenetv4_conv_small', 1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def mobilenetv4_conv_medium(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V4 """
model = _gen_mobilenet_v4('mobilenetv4_conv_medium', 1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def mobilenetv4_conv_large(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V4 """
model = _gen_mobilenet_v4('mobilenetv4_conv_large', 1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def mobilenetv4_hybrid_medium(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V4 Hybrid """
model = _gen_mobilenet_v4('mobilenetv4_hybrid_medium', 1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def mobilenetv4_hybrid_large(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V4 Hybrid"""
model = _gen_mobilenet_v4('mobilenetv4_hybrid_large', 1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def mobilenetv4_conv_aa_medium(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V4 w/ AvgPool AA """
model = _gen_mobilenet_v4('mobilenetv4_conv_aa_medium', 1.0, pretrained=pretrained, aa_layer='avg', **kwargs)
return model
@register_model
def mobilenetv4_conv_blur_medium(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V4 Conv w/ Blur AA """
model = _gen_mobilenet_v4('mobilenetv4_conv_blur_medium', 1.0, pretrained=pretrained, aa_layer='blurpc', **kwargs)
return model
@register_model
def mobilenetv4_hybrid_medium_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V4 Hybrid """
model = _gen_mobilenet_v4('mobilenetv4_hybrid_medium_075', 0.75, pretrained=pretrained, **kwargs)
return model
@register_model
def mobilenetv4_hybrid_large_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V4 Hybrid"""
model = _gen_mobilenet_v4('mobilenetv4_hybrid_large', 0.75, pretrained=pretrained, **kwargs)
return model
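A minimal usage sketch, assuming this branch of timm is installed; the MobileNet-V4 variants above support both the classification head and the standard features_only interface.
import timm
import torch

model = timm.create_model('mobilenetv4_conv_small', pretrained=False, features_only=True)
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])  # one feature map per extracted stage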
register_model_deprecations(__name__, {
'mobilenetv3_large_100_miil': 'mobilenetv3_large_100.miil_in21k_ft_in1k',
'mobilenetv3_large_100_miil_in21k': 'mobilenetv3_large_100.miil_in21k',

View File

@ -403,7 +403,7 @@ class PyramidVisionTransformerV2(nn.Module):
return x
def _checkpoint_filter_fn(state_dict, model):
def checkpoint_filter_fn(state_dict, model):
""" Remap original checkpoints -> timm """
if 'patch_embed.proj.weight' in state_dict:
return state_dict # non-original checkpoint, no remapping needed
@ -430,7 +430,7 @@ def _create_pvt2(variant, pretrained=False, **kwargs):
PyramidVisionTransformerV2,
variant,
pretrained,
pretrained_filter_fn=_checkpoint_filter_fn,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs,
)

View File

@ -17,7 +17,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \
get_attn, get_act_layer, get_norm_layer, create_classifier
get_attn, get_act_layer, get_norm_layer, create_classifier, create_aa
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
@ -31,15 +31,6 @@ def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int:
return padding
def create_aa(aa_layer: Type[nn.Module], channels: int, stride: int = 2, enable: bool = True) -> nn.Module:
if not aa_layer or not enable:
return nn.Identity()
if issubclass(aa_layer, nn.AvgPool2d):
return aa_layer(stride)
else:
return aa_layer(channels=channels, stride=stride)
class BasicBlock(nn.Module):
expansion = 1

View File

@ -409,6 +409,7 @@ class VisionTransformer(nn.Module):
qk_norm: bool = False,
init_values: Optional[float] = None,
class_token: bool = True,
pos_embed: str = 'learn',
no_embed_class: bool = False,
reg_tokens: int = 0,
pre_norm: bool = False,
@ -460,6 +461,7 @@ class VisionTransformer(nn.Module):
super().__init__()
assert global_pool in ('', 'avg', 'token', 'map')
assert class_token or global_pool != 'token'
assert pos_embed in ('', 'none', 'learn')
use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
act_layer = get_act_layer(act_layer) or nn.GELU
@ -494,7 +496,10 @@ class VisionTransformer(nn.Module):
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
if not pos_embed or pos_embed == 'none':
self.pos_embed = None
else:
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
self.pos_drop = nn.Dropout(p=pos_drop_rate)
if patch_drop_rate > 0:
self.patch_drop = PatchDropout(
@ -556,7 +561,8 @@ class VisionTransformer(nn.Module):
def init_weights(self, mode: str = '') -> None:
assert mode in ('jax', 'jax_nlhb', 'moco', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
trunc_normal_(self.pos_embed, std=.02)
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=.02)
if self.cls_token is not None:
nn.init.normal_(self.cls_token, std=1e-6)
named_apply(get_init_weights_vit(mode, head_bias), self)
@ -583,6 +589,8 @@ class VisionTransformer(nn.Module):
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True) -> None:
self.grad_checkpointing = enable
if hasattr(self.patch_embed, 'set_grad_checkpointing'):
self.patch_embed.set_grad_checkpointing(enable)
@torch.jit.ignore
def get_classifier(self) -> nn.Module:
@ -600,6 +608,9 @@ class VisionTransformer(nn.Module):
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
if self.pos_embed is None:
return x.view(x.shape[0], -1, x.shape[-1])
if self.dynamic_img_size:
B, H, W, C = x.shape
pos_embed = resample_abs_pos_embed(
@ -1066,10 +1077,13 @@ def checkpoint_filter_fn(
# IJEPA, vit in an 'encoder' submodule
state_dict = state_dict['encoder']
prefix = 'module.'
elif 'visual.trunk.pos_embed' in state_dict:
elif 'visual.trunk.pos_embed' in state_dict or 'visual.trunk.blocks.0.norm1.weight' in state_dict:
# OpenCLIP model with timm vision encoder
# FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
prefix = 'visual.trunk.'
if 'visual.head.proj.weight' in state_dict and isinstance(model.head, nn.Linear):
# remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
out_dict['head.weight'] = state_dict['visual.head.proj.weight']
out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
if prefix:
# filter on & remove prefix string from keys

View File

@ -15,18 +15,20 @@ Hacked together by / Copyright 2020, Ross Wightman
"""
import math
from functools import partial
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Type, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to
from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, Format, nchw_to
from ._builder import build_model_with_cfg
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
from .resnet import resnet26d, resnet50d
from .resnetv2 import ResNetV2, create_resnetv2_stem
from .vision_transformer import _create_vision_transformer, VisionTransformer
from .vision_transformer import VisionTransformer
class HybridEmbed(nn.Module):
@ -38,14 +40,15 @@ class HybridEmbed(nn.Module):
def __init__(
self,
backbone,
img_size=224,
patch_size=1,
feature_size=None,
feature_ratio=None,
in_chans=3,
embed_dim=768,
bias=True,
backbone: nn.Module,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 1,
feature_size: Optional[Union[int, Tuple[int, int]]] = None,
feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
in_chans: int = 3,
embed_dim: int = 768,
bias: bool = True,
proj: bool = True,
flatten: bool = True,
output_fmt: Optional[str] = None,
strict_img_size: bool = True,
@ -95,7 +98,18 @@ class HybridEmbed(nn.Module):
self.strict_img_size = strict_img_size
self.dynamic_img_pad = dynamic_img_pad
self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
if proj:
self.proj = nn.Conv2d(
feature_dim,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=bias,
)
else:
assert feature_dim == embed_dim,\
f'The feature dim ({feature_dim}) must match embed dim ({embed_dim}) when projection disabled.'
self.proj = nn.Identity()
def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
total_reduction = (
@ -116,6 +130,13 @@ class HybridEmbed(nn.Module):
else:
return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1]
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True):
if hasattr(self.backbone, 'set_grad_checkpointing'):
self.backbone.set_grad_checkpointing(enable=enable)
elif hasattr(self.backbone, 'grad_checkpointing'):
self.backbone.grad_checkpointing = enable
def forward(self, x):
x = self.backbone(x)
if isinstance(x, (list, tuple)):
@ -139,24 +160,35 @@ class HybridEmbedWithSize(nn.Module):
"""
def __init__(
self,
backbone,
img_size=224,
patch_size=1,
feature_size=None,
in_chans=3,
embed_dim=768,
backbone: nn.Module,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 1,
feature_size: Optional[Union[int, Tuple[int, int]]] = None,
feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
in_chans: int = 3,
embed_dim: int = 768,
bias=True,
proj=True,
):
super().__init__(
backbone=backbone,
img_size=img_size,
patch_size=patch_size,
feature_size=feature_size,
feature_ratio=feature_ratio,
in_chans=in_chans,
embed_dim=embed_dim,
bias=bias,
proj=proj,
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True):
if hasattr(self.backbone, 'set_grad_checkpointing'):
self.backbone.set_grad_checkpointing(enable=enable)
elif hasattr(self.backbone, 'grad_checkpointing'):
self.backbone.grad_checkpointing = enable
def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
x = self.backbone(x)
if isinstance(x, (list, tuple)):
@ -165,10 +197,43 @@ class HybridEmbedWithSize(nn.Module):
return x.flatten(2).transpose(1, 2), x.shape[-2:]
def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
embed_layer = partial(HybridEmbed, backbone=backbone)
kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs)
class ConvStem(nn.Sequential):
def __init__(
self,
in_chans: int = 3,
depth: int = 3,
channels: Union[int, Tuple[int, ...]] = 64,
kernel_size: Union[int, Tuple[int, ...]] = 3,
stride: Union[int, Tuple[int, ...]] = (2, 2, 2),
padding: Union[str, int, Tuple[int, ...]] = "",
norm_layer: Type[nn.Module] = nn.BatchNorm2d,
act_layer: Type[nn.Module] = nn.ReLU,
):
super().__init__()
if isinstance(channels, int):
# a default tiered channel strategy
channels = tuple([channels // 2**i for i in range(depth)][::-1])
kernel_size = to_ntuple(depth)(kernel_size)
padding = to_ntuple(depth)(padding)
assert depth == len(stride) == len(kernel_size) == len(channels)
in_chs = in_chans
for i in range(len(channels)):
last_conv = i == len(channels) - 1
self.add_module(f'{i}', ConvNormAct(
in_chs,
channels[i],
kernel_size=kernel_size[i],
stride=stride[i],
padding=padding[i],
bias=last_conv,
apply_norm=not last_conv,
apply_act=not last_conv,
norm_layer=norm_layer,
act_layer=act_layer,
))
in_chs = channels[i]
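A small sketch of the tiered-channel default above: an int `channels` is expanded into a doubling sequence that ends at the requested width, with norm+act on every conv except the last.
depth, channels = 3, 64
tiered = tuple([channels // 2 ** i for i in range(depth)][::-1])
print(tiered)  # (16, 32, 64)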
def _resnetv2(layers=(3, 4, 9), **kwargs):
@ -186,6 +251,66 @@ def _resnetv2(layers=(3, 4, 9), **kwargs):
return backbone
def _convert_mobileclip(state_dict, model, prefix='image_encoder.model.'):
out = {}
for k, v in state_dict.items():
if not k.startswith(prefix):
continue
k = k.replace(prefix, '')
k = k.replace('patch_emb.', 'patch_embed.backbone.')
k = k.replace('block.conv', 'conv')
k = k.replace('block.norm', 'bn')
k = k.replace('post_transformer_norm.', 'norm.')
k = k.replace('pre_norm_mha.0', 'norm1')
k = k.replace('pre_norm_mha.1', 'attn')
k = k.replace('pre_norm_ffn.0', 'norm2')
k = k.replace('pre_norm_ffn.1', 'mlp.fc1')
k = k.replace('pre_norm_ffn.4', 'mlp.fc2')
k = k.replace('qkv_proj.', 'qkv.')
k = k.replace('out_proj.', 'proj.')
k = k.replace('transformer.', 'blocks.')
if k == 'pos_embed.pos_embed.pos_embed':
k = 'pos_embed'
v = v.squeeze(0)
if 'classifier.proj' in k:
bias_k = k.replace('classifier.proj', 'head.bias')
k = k.replace('classifier.proj', 'head.weight')
v = v.T
out[bias_k] = torch.zeros(v.shape[0])
out[k] = v
return out
def checkpoint_filter_fn(
state_dict: Dict[str, torch.Tensor],
model: VisionTransformer,
interpolation: str = 'bicubic',
antialias: bool = True,
) -> Dict[str, torch.Tensor]:
from .vision_transformer import checkpoint_filter_fn as _filter_fn
if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:
state_dict = _convert_mobileclip(state_dict, model)
return _filter_fn(state_dict, model, interpolation=interpolation, antialias=antialias)
def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs):
out_indices = kwargs.pop('out_indices', 3)
embed_args = embed_args or {}
embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args)
kwargs.setdefault('embed_layer', embed_layer)
kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
return build_model_with_cfg(
VisionTransformer,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)
def _cfg(url='', **kwargs):
return {
'url': url,
@ -260,6 +385,17 @@ default_cfgs = generate_default_cfgs({
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_base_resnet50d_224.untrained': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_base_mci_224.apple_mclip': _cfg(
url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt',
num_classes=512,
mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv',
),
'vit_base_mci_224.apple_mclip_lt': _cfg(
url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt',
num_classes=512,
mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv',
),
})
@ -407,6 +543,26 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs) -> VisionTransformer:
return model
@register_model
def vit_base_mci_224(pretrained=False, **kwargs) -> VisionTransformer:
""" Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights.
"""
backbone = ConvStem(
channels=(768//4, 768//4, 768),
stride=(4, 2, 2),
kernel_size=(4, 2, 2),
padding=0,
in_chans=kwargs.get('in_chans', 3),
act_layer=nn.GELU,
)
model_args = dict(embed_dim=768, depth=12, num_heads=12, no_embed_class=True)
model = _create_vision_transformer_hybrid(
'vit_base_mci_224', backbone=backbone, embed_args=dict(proj=False),
pretrained=pretrained, **dict(model_args, **kwargs)
)
return model
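A minimal usage sketch, assuming this branch of timm is installed; with the apple_mclip cfgs above, the head is the 512-dim MobileCLIP-B projection rather than an ImageNet classifier.
import timm
import torch

model = timm.create_model('vit_base_mci_224', pretrained=False).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # output width follows num_classes in the selected cfg (512 for apple_mclip)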
register_model_deprecations(__name__, {
'vit_tiny_r_s16_p8_224_in21k': 'vit_tiny_r_s16_p8_224.augreg_in21k',
'vit_small_r26_s32_224_in21k': 'vit_small_r26_s32_224.augreg_in21k',

603
timm/models/vitamin.py Normal file
View File

@ -0,0 +1,603 @@
""" ViTamin
Paper: Designing Scalable Vision Models in the Vision-Language Era
A family of model weights on Huggingface: https://huggingface.co/collections/jienengchen/vitamin-family-661048126b72debdaca060bf
@inproceedings{chen2024vitamin,
title={ViTamin: Designing Scalable Vision Models in the Vision-language Era},
author={Chen, Jieneng and Yu, Qihang and Shen, Xiaohui and Yuille, Alan and Chen, Liang-Chieh},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2024}
}
Based on Apache 2.0 licensed code at https://github.com/ViTamin/ViTamin
Modifications and timm support by Jieneng Chen 2024
Reference:
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer_hybrid.py
"""
import math
from dataclasses import dataclass, field
from functools import partial
from typing import Optional, Union, Tuple
import torch
import torch.nn as nn
from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import create_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, \
make_divisible, DropPath
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq
from ._registry import register_model, generate_default_cfgs
from .vision_transformer import VisionTransformer, checkpoint_filter_fn
from .vision_transformer_hybrid import HybridEmbed
@dataclass
class VitConvCfg:
expand_ratio: float = 4.0
expand_output: bool = True # calculate expansion channels from output (vs input chs)
kernel_size: int = 3
group_size: int = 1 # 1 == depthwise
pre_norm_act: bool = False # activation after pre-norm
stride_mode: str = 'dw' # stride done via one of 'pool', '1x1', 'dw'
pool_type: str = 'avg2'
downsample_pool_type: str = 'avg2'
act_layer: str = 'gelu' # stem & stage 1234
norm_layer: str = ''
norm_eps: float = 1e-5
down_shortcut: Optional[bool] = True
mlp: str = 'mlp'
@dataclass
class VitCfg:
embed_dim: Tuple[Union[int, Tuple[int, ...]], ...] = (96, 192, 384, 768)
depths: Tuple[Union[int, Tuple[int, ...]], ...] = (2, 3, 5, 2)
stem_width: int = 64
conv_cfg: VitConvCfg = field(default_factory=VitConvCfg)
head_type: str = ""
def _init_conv(module, name, scheme=''):
if isinstance(module, nn.Conv2d):
fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
fan_out //= module.groups
nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))
if module.bias is not None:
nn.init.zeros_(module.bias)
class Stem(nn.Module):
def __init__(
self,
in_chs: int,
out_chs: int,
act_layer: str = 'gelu',
norm_layer: str = 'layernorm2d',
norm_eps: float = 1e-6,
bias: bool = True,
):
super().__init__()
norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
self.out_chs = out_chs
self.conv1 = create_conv2d(in_chs, out_chs, 3, stride=2, bias=bias)
self.norm1 = norm_act_layer(out_chs)
self.conv2 = create_conv2d(out_chs, out_chs, 3, stride=1, bias=bias)
named_apply(_init_conv, self)
def forward(self, x):
x = self.conv1(x)
x = self.norm1(x)
x = self.conv2(x)
return x
class Downsample2d(nn.Module):
def __init__(
self,
dim: int,
dim_out: int,
pool_type: str = 'avg2',
bias: bool = True,
):
super().__init__()
self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)
if dim != dim_out:
self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias) # 1x1 conv
else:
self.expand = nn.Identity()
def forward(self, x):
x = self.pool(x) # spatial downsample
x = self.expand(x) # expand chs
return x
class StridedConv(nn.Module):
""" downsample 2d as well
"""
def __init__(
self,
kernel_size=3,
stride=2,
padding=1,
in_chans=3,
embed_dim=768
):
super().__init__()
norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
self.norm = norm_layer(in_chans) # affine over C
def forward(self, x):
x = self.norm(x)
x = self.proj(x)
return x
class MbConvLNBlock(nn.Module):
""" Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand)
"""
def __init__(
self,
in_chs: int,
out_chs: int,
stride: int = 1,
drop_path: float = 0.,
kernel_size: int = 3,
norm_layer: str = 'layernorm2d',
norm_eps: float = 1e-6,
act_layer: str = 'gelu',
expand_ratio: float = 4.0,
):
super(MbConvLNBlock, self).__init__()
self.stride, self.in_chs, self.out_chs = stride, in_chs, out_chs
mid_chs = make_divisible(out_chs * expand_ratio)
prenorm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
if stride == 2:
self.shortcut = Downsample2d(in_chs, out_chs, pool_type='avg', bias=True)
elif in_chs != out_chs:
self.shortcut = nn.Conv2d(in_chs, out_chs, 1, bias=True)
else:
self.shortcut = nn.Identity()
self.pre_norm = prenorm_act_layer(in_chs, apply_act=False)
self.down = nn.Identity()
self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=1, bias=True)
self.act1 = create_act_layer(act_layer, inplace=True)
self.conv2_kxk = create_conv2d(
mid_chs, mid_chs, kernel_size, stride=stride, dilation=1, groups=mid_chs, bias=True)
self.act2 = create_act_layer(act_layer, inplace=True)
self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=True)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def init_weights(self, scheme=''):
named_apply(partial(_init_conv, scheme=scheme), self)
def forward(self, x):
shortcut = self.shortcut(x)
x = self.pre_norm(x)
x = self.down(x) # nn.Identity()
# 1x1 expansion conv & act
x = self.conv1_1x1(x)
x = self.act1(x)
# (strided) depthwise 3x3 conv & act
x = self.conv2_kxk(x)
x = self.act2(x)
# 1x1 linear projection to output width
x = self.conv3_1x1(x)
x = self.drop_path(x) + shortcut
return x
class MbConvStages(nn.Module):
""" MobileConv for stage 1 and stage 2 of ViTamin
"""
def __init__(
self,
cfg: VitCfg,
img_size: Union[int, Tuple[int, int]] = 224,  # placeholder
in_chans: int = 3,
):
super().__init__()
self.grad_checkpointing = False
self.stem = Stem(
in_chs=in_chans,
out_chs=cfg.stem_width,
)
stages = []
self.num_stages = len(cfg.embed_dim)
for s, dim in enumerate(cfg.embed_dim[:2]): # stage
stage_in_chs = cfg.embed_dim[s-1] if s>0 else cfg.stem_width
blocks = [
MbConvLNBlock(
in_chs = stage_in_chs if d==0 else dim,
out_chs = dim,
stride = 2 if d == 0 else 1,
)
for d in range(cfg.depths[s])
]
stages += [nn.Sequential(*blocks)]
self.stages = nn.Sequential(*stages)
self.pool = StridedConv(
stride=2,
in_chans=cfg.embed_dim[1],
embed_dim=cfg.embed_dim[2]
)
def forward(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stages, x)
else:
x = self.stages(x)
x = self.pool(x)
return x
class GeGluMlp(nn.Module):
def __init__(
self,
in_features,
hidden_features,
act_layer = 'gelu',
drop = 0.0,
):
super().__init__()
norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6)
self.norm = norm_layer(in_features)
self.w0 = nn.Linear(in_features, hidden_features)
self.act = create_act_layer(act_layer)
self.w1 = nn.Linear(in_features, hidden_features)
self.w2 = nn.Linear(hidden_features, in_features)
def forward(self, x):
x = self.norm(x)
x = self.act(self.w0(x)) * self.w1(x)
x = self.w2(x)
return x
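A standalone sketch of the gated-GELU MLP above in plain torch: one linear branch is GELU-activated and gates a second branch before the output projection.
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(2, 196, 384)  # (B, N, C) tokens, e.g. vitamin_small width w/ mlp_ratio=2 -> hidden 768
norm = nn.LayerNorm(384)
w0, w1, w2 = nn.Linear(384, 768), nn.Linear(384, 768), nn.Linear(768, 384)
h = norm(x)
out = w2(F.gelu(w0(h)) * w1(h))  # act(w0(x)) * w1(x) -> w2, mirroring GeGluMlp.forward
print(out.shape)  # torch.Size([2, 196, 384])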
def _create_vitamin(variant, pretrained=False, embed_cfg=None, **kwargs):
out_indices = kwargs.pop('out_indices', 3)
assert embed_cfg is not None
backbone = MbConvStages(cfg=embed_cfg)
kwargs['embed_layer'] = partial(HybridEmbed, backbone=backbone, proj=False)
kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
return build_model_with_cfg(
VisionTransformer,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': OPENAI_CLIP_MEAN, 'std': OPENAI_CLIP_STD,
'first_conv': 'patch_embed.backbone.stem.conv1',
'classifier': 'head',
**kwargs
}
default_cfgs = generate_default_cfgs({
'vitamin_small_224.datacomp1b_clip_ltt': _cfg(
hf_hub_id='jienengchen/ViTamin-S-LTT', num_classes=384),
'vitamin_small_224.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-S', num_classes=384),
'vitamin_base_224.datacomp1b_clip_ltt': _cfg(
hf_hub_id='jienengchen/ViTamin-B-LTT', num_classes=768),
'vitamin_base_224.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-B', num_classes=768),
'vitamin_large_224.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-L-224px', num_classes=768),
'vitamin_large_256.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-L-256px', num_classes=768,
input_size=(3, 256, 256), crop_pct=1.0),
'vitamin_large_336.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-L-336px', num_classes=768,
input_size=(3, 336, 336), crop_pct=1.0),
'vitamin_large_384.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-L-384px', num_classes=768,
input_size=(3, 384, 384), crop_pct=1.0),
'vitamin_large2_224.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-L2-224px', num_classes=1024),
'vitamin_large2_256.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-L2-256px', num_classes=1024,
input_size=(3, 256, 256), crop_pct=1.0),
'vitamin_large2_336.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-L2-336px', num_classes=1024,
input_size=(3, 336, 336), crop_pct=1.0),
'vitamin_large2_384.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-L2-384px', num_classes=1024,
input_size=(3, 384, 384), crop_pct=1.0),
'vitamin_xlarge_256.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-XL-256px', num_classes=1152,
input_size=(3, 256, 256), crop_pct=1.0),
'vitamin_xlarge_336.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-XL-336px', num_classes=1152,
input_size=(3, 336, 336), crop_pct=1.0),
'vitamin_xlarge_384.datacomp1b_clip': _cfg(
hf_hub_id='jienengchen/ViTamin-XL-384px', num_classes=1152,
input_size=(3, 384, 384), crop_pct=1.0),
})
@register_model
def vitamin_small_224(pretrained=False, **kwargs) -> VisionTransformer:
embed_cfg = VitCfg(
embed_dim=(64, 128, 384),
depths=(2, 4, 1),
stem_width=64,
conv_cfg=VitConvCfg(
norm_layer='layernorm2d',
norm_eps=1e-6,
),
head_type='1d',
)
model_args = dict(
embed_dim=384, depth=14, num_heads=6, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', embed_cfg=embed_cfg
)
model = _create_vitamin('vitamin_small_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_base_224(pretrained=False, **kwargs) -> VisionTransformer:
embed_cfg = VitCfg(
embed_dim=(128, 256, 768),
depths=(2, 4, 1),
stem_width=128,
conv_cfg=VitConvCfg(
norm_layer='layernorm2d',
norm_eps=1e-6,
),
head_type='1d',
)
model_args = dict(
embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', embed_cfg=embed_cfg)
model = _create_vitamin('vitamin_base_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_large_224(pretrained=False, **kwargs) -> VisionTransformer:
embed_cfg = VitCfg(
embed_dim=(160, 320, 1024),
depths=(2, 4, 1),
stem_width=160,
conv_cfg=VitConvCfg(
norm_layer='layernorm2d',
norm_eps=1e-6,
),
head_type='1d',
)
model_args = dict(
embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', embed_cfg=embed_cfg,
)
model = _create_vitamin('vitamin_large_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
embed_cfg = VitCfg(
embed_dim=(160, 320, 1024),
depths=(2, 4, 1),
stem_width=160,
conv_cfg=VitConvCfg(
norm_layer='layernorm2d',
norm_eps=1e-6,
),
head_type='1d',
)
model_args = dict(
img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', embed_cfg=embed_cfg)
model = _create_vitamin('vitamin_large_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
embed_cfg = VitCfg(
embed_dim=(160, 320, 1024),
depths=(2, 4, 1),
stem_width=160,
conv_cfg=VitConvCfg(
norm_layer='layernorm2d',
norm_eps=1e-6,
),
head_type='1d',
)
model_args = dict(
img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', embed_cfg=embed_cfg
)
model = _create_vitamin('vitamin_large_336', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
embed_cfg = VitCfg(
embed_dim=(160, 320, 1024),
depths=(2, 4, 1),
stem_width=160,
conv_cfg=VitConvCfg(
norm_layer='layernorm2d',
norm_eps=1e-6,
),
head_type='1d',
)
model_args = dict(
img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', embed_cfg=embed_cfg)
model = _create_vitamin('vitamin_large_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_large2_224(pretrained=False, **kwargs) -> VisionTransformer:
embed_cfg = VitCfg(
embed_dim=(160, 320, 1024),
depths=(2, 4, 1),
stem_width=160,
conv_cfg=VitConvCfg(
norm_layer='layernorm2d',
norm_eps=1e-6,
),
head_type='1d',
)
model_args = dict(
embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', embed_cfg=embed_cfg,
)
model = _create_vitamin('vitamin_large2_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_large2_256(pretrained=False, **kwargs) -> VisionTransformer:
embed_cfg = VitCfg(
embed_dim=(160, 320, 1024),
depths=(2, 4, 1),
stem_width=160,
conv_cfg=VitConvCfg(
norm_layer='layernorm2d',
norm_eps=1e-6,
),
head_type='1d',
)
model_args = dict(
img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', embed_cfg=embed_cfg)
model = _create_vitamin('vitamin_large2_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_large2_336(pretrained=False, **kwargs) -> VisionTransformer:
embed_cfg = VitCfg(
embed_dim=(160, 320, 1024),
depths=(2, 4, 1),
stem_width=160,
conv_cfg=VitConvCfg(
norm_layer='layernorm2d',
norm_eps=1e-6,
),
head_type='1d',
)
model_args = dict(
img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', embed_cfg=embed_cfg
)
model = _create_vitamin('vitamin_large2_336', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_large2_384(pretrained=False, **kwargs) -> VisionTransformer:
embed_cfg = VitCfg(
embed_dim=(160, 320, 1024),
depths=(2, 4, 1),
stem_width=160,
conv_cfg=VitConvCfg(
norm_layer='layernorm2d',
norm_eps=1e-6,
),
head_type='1d',
)
model_args = dict(
img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', embed_cfg=embed_cfg)
model = _create_vitamin('vitamin_large2_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
    embed_cfg = VitCfg(
embed_dim=(192, 384, 1152),
depths=(2, 4, 1),
stem_width=192,
conv_cfg=VitConvCfg(
norm_layer='layernorm2d',
norm_eps=1e-6,
),
head_type='1d',
)
model_args = dict(
img_size=256, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg)
model = _create_vitamin(
'vitamin_xlarge_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
embed_cfg = VitCfg(
embed_dim=(192, 384, 1152),
depths=(2, 4, 1),
stem_width=192,
conv_cfg=VitConvCfg(
norm_layer='layernorm2d',
norm_eps=1e-6,
),
head_type='1d',
)
model_args = dict(
img_size=336, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg)
    model = _create_vitamin('vitamin_xlarge_336', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
embed_cfg = VitCfg(
embed_dim=(192, 384, 1152),
depths=(2, 4, 1),
stem_width=192,
conv_cfg=VitConvCfg(
norm_layer='layernorm2d',
norm_eps=1e-6,
),
head_type='1d',
)
model_args = dict(
img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg)
model = _create_vitamin('vitamin_xlarge_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model