mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
commit 52659842cc
@@ -52,7 +52,7 @@ FEAT_INTER_FILTERS = [
    'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
    'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
    'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera',
    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit',
]

# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
@@ -60,7 +60,7 @@ NON_STD_FILTERS = [
    'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
    'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
    'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*'
    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*'
]
NUM_NON_STD = len(NON_STD_FILTERS)

@@ -1,9 +1,10 @@
from .activations import *
from .adaptive_avgmax_pool import \
    adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
from .attention_pool import AttentionPoolLatent
from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
from .blur_pool import BlurPool2d
from .blur_pool import BlurPool2d, create_aa
from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
timm/layers/attention2d.py (new file, 337 lines)
@@ -0,0 +1,337 @@
from typing import List, Optional, Union

import torch
from torch import nn as nn
from torch.nn import functional as F

from .config import use_fused_attn
from .create_conv2d import create_conv2d
from .helpers import to_2tuple
from .pool2d_same import create_pool2d


class MultiQueryAttentionV2(nn.Module):
    """Multi Query Attention.

    Fast Transformer Decoding: One Write-Head is All You Need
    https://arxiv.org/pdf/1911.02150.pdf

    This is an accelerator-optimized version, removing multiple unnecessary
    tensor transposes by re-arranging indices according to the following rules: 1)
    contracted indices are at the end, 2) other indices keep the same order in the
    input and output tensors.

    Compared to V1, this gives a 3x speed up.
    """

    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            num_heads: int = 8,
            key_dim: int = 64,
            value_dim: int = 64,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
    ):
        """Initializer."""
        super().__init__()
        dim_out = dim_out or dim
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.scale = key_dim ** -0.5

        self.query_proj = nn.Parameter(torch.randn([self.num_heads, self.key_dim, dim]))
        self.key_proj = nn.Parameter(torch.randn([dim, self.key_dim]))
        self.value_proj = nn.Parameter(torch.randn([dim, self.value_dim]))
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = nn.Parameter(torch.randn([dim_out, self.num_heads, self.value_dim]))
        self.proj_drop = nn.Dropout(proj_drop)

    def _reshape_input(self, t):
        """Reshapes a tensor to three dimensions, keeping the first and last."""
        s = t.shape
        # Propagate the shape statically where possible.
        #num = t.shape[1:-1].numel()
        #return t.reshape(s[0], num, s[-1])
        return t.reshape(s[0], s[1], -1).transpose(1, 2)

    def forward(self, x, m: Optional[torch.Tensor] = None):
        """Run layer computation."""
        s = x.shape
        m = m if m is not None else x

        reshaped_x = self._reshape_input(x)
        reshaped_m = self._reshape_input(m)

        q = torch.einsum('bnd,hkd->bnhk', reshaped_x, self.query_proj)
        k = torch.einsum('bmd,dk->bmk', reshaped_m, self.key_proj)

        attn = torch.einsum('bnhk,bmk->bnhm', q, k)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        v = torch.einsum('bmd,dv->bmv', reshaped_m, self.value_proj)
        o = torch.einsum('bnhm,bmv->bnhv', attn, v)
        result = torch.einsum('bnhv,dhv->bnd', o, self.out_proj)
        result = self.proj_drop(result)
        return result.reshape(s)

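
# Usage sketch (illustrative, not part of the diff; assumes the default dim_out == dim
# so the trailing reshape back to the input shape is valid):
#
#   mqa = MultiQueryAttentionV2(dim=64, num_heads=8, key_dim=16, value_dim=16)
#   x = torch.randn(2, 64, 14, 14)   # NCHW input; memory `m` defaults to x (self-attention)
#   y = mqa(x)
#   assert y.shape == x.shape
#
# Note the single shared key/value projections (the 'dk'/'dv' einsum operands) vs. the
# per-head query projection ('hkd'); that asymmetry is the multi-query trick.
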
class MultiQueryAttention2d(nn.Module):
    """Multi Query Attention with spatial downsampling.

    Two parameters are introduced for the spatial downsampling:
    1. kv_stride: downsampling factor on Key and Values only.
    2. query_strides: horizontal & vertical strides on Query only.

    This is an optimized version.
    1. Projections in Attention are written out explicitly as 1x1 Conv2D.
    2. Additional reshapes are introduced to bring up to a 3x speed up.
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            num_heads: int = 8,
            key_dim: Optional[int] = None,
            value_dim: Optional[int] = None,
            query_strides: int = 1,
            kv_stride: int = 1,
            dw_kernel_size: int = 3,
            dilation: int = 1,
            padding: Union[str, int, List[int]] = '',
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.BatchNorm2d,
            use_bias: bool = False,
    ):
        """Initializer.

        Args:
            num_heads: Number of attention heads.
            key_dim: Size of the attention key dimension.
            value_dim: Size of the attention value dimension.
            query_strides: Vertical and horizontal stride size for query only.
            kv_stride: Key and value stride size.
            dw_kernel_size: Spatial dimension of the depthwise kernel.
        """
        super().__init__()
        dim_out = dim_out or dim
        self.num_heads = num_heads
        self.key_dim = key_dim or dim // num_heads
        self.value_dim = value_dim or dim // num_heads
        self.query_strides = to_2tuple(query_strides)
        self.kv_stride = kv_stride
        self.has_query_strides = any([s > 1 for s in self.query_strides])
        self.scale = self.key_dim ** -0.5
        self.fused_attn = use_fused_attn()
        self.drop = attn_drop

        self.query = nn.Sequential()
        if self.has_query_strides:
            # FIXME dilation
            self.query.add_module('down_pool', create_pool2d(
                'avg',
                kernel_size=self.query_strides,
                padding=padding,
            ))
            self.query.add_module('norm', norm_layer(dim))
        self.query.add_module('proj', create_conv2d(
            dim,
            self.num_heads * self.key_dim,
            kernel_size=1,
            bias=use_bias,
        ))

        self.key = nn.Sequential()
        if kv_stride > 1:
            self.key.add_module('down_conv', create_conv2d(
                dim,
                dim,
                kernel_size=dw_kernel_size,
                stride=kv_stride,
                dilation=dilation,
                padding=padding,
                depthwise=True,
            ))
            self.key.add_module('norm', norm_layer(dim))
        self.key.add_module('proj', create_conv2d(
            dim,
            self.key_dim,
            kernel_size=1,
            padding=padding,
            bias=use_bias,
        ))

        self.value = nn.Sequential()
        if kv_stride > 1:
            self.value.add_module('down_conv', create_conv2d(
                dim,
                dim,
                kernel_size=dw_kernel_size,
                stride=kv_stride,
                dilation=dilation,
                padding=padding,
                depthwise=True,
            ))
            self.value.add_module('norm', norm_layer(dim))
        self.value.add_module('proj', create_conv2d(
            dim,
            self.value_dim,
            kernel_size=1,
            bias=use_bias,
        ))

        self.attn_drop = nn.Dropout(attn_drop)

        self.output = nn.Sequential()
        if self.has_query_strides:
            self.output.add_module('upsample', nn.Upsample(scale_factor=self.query_strides, mode='bilinear', align_corners=False))
        self.output.add_module('proj', create_conv2d(
            self.value_dim * self.num_heads,
            dim_out,
            kernel_size=1,
            bias=use_bias,
        ))
        self.output.add_module('drop', nn.Dropout(proj_drop))

        self.einsum = False

    def _reshape_input(self, t: torch.Tensor):
        """Reshapes a tensor to three dimensions, keeping the batch and channels."""
        s = t.shape
        t = t.reshape(s[0], s[1], -1).transpose(1, 2)
        if self.einsum:
            return t
        else:
            return t.unsqueeze(1).contiguous()

    def _reshape_projected_query(self, t: torch.Tensor, num_heads: int, key_dim: int):
        """Reshapes projected query: [b, n, n, h x k] -> [b, n x n, h, k]."""
        s = t.shape
        t = t.reshape(s[0], num_heads, key_dim, -1)
        if self.einsum:
            return t.permute(0, 3, 1, 2).contiguous()
        else:
            return t.transpose(-1, -2).contiguous()

    def _reshape_output(self, t: torch.Tensor, num_heads: int, h_px: int, w_px: int):
        """Reshape output: [b, n x n x h, k] -> [b, n, n, hk]."""
        s = t.shape
        feat_dim = s[-1] * num_heads
        if not self.einsum:
            t = t.transpose(1, 2)
        return t.reshape(s[0], h_px, w_px, feat_dim).permute(0, 3, 1, 2).contiguous()

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        """Run layer computation."""
        B, C, H, W = s = x.shape

        q = self.query(x)
        # desired q shape: [b, h, k, n x n] - [b, l, h, k]
        q = self._reshape_projected_query(q, self.num_heads, self.key_dim)

        k = self.key(x)
        # output shape of k: [b, k, p], p = m x m
        k = self._reshape_input(k)

        v = self.value(x)
        # output shape of v: [b, p, k], p = m x m
        v = self._reshape_input(v)

        # desired q shape: [b, n x n, h, k]
        # desired k shape: [b, m x m, k]
        # desired logits shape: [b, n x n, h, m x m]
        if self.einsum:
            attn = torch.einsum('blhk,bpk->blhp', q, k) * self.scale
            if attn_mask is not None:
                # NOTE: assumes mask is float and in correct shape
                attn = attn + attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            o = torch.einsum('blhp,bpk->blhk', attn, v)
        else:
            if self.fused_attn:
                o = F.scaled_dot_product_attention(
                    q, k, v,
                    attn_mask=attn_mask,
                    dropout_p=self.attn_drop.p if self.training else 0.
                )
            else:
                q = q * self.scale
                attn = q @ k.transpose(-1, -2)
                if attn_mask is not None:
                    # NOTE: assumes mask is float and in correct shape
                    attn = attn + attn_mask
                attn = attn.softmax(dim=-1)
                attn = self.attn_drop(attn)
                o = attn @ v

        # reshape o into [b, hk, n, n]
        o = self._reshape_output(o, self.num_heads, H // self.query_strides[0], W // self.query_strides[1])
        x = self.output(o)
        return x

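
# Shape sketch for the spatial downsampling (illustrative, not part of the diff):
# with query_strides=2 the query grid is avg-pooled to H/2 x W/2 and the output is
# bilinearly upsampled back, while kv_stride=2 only shrinks the key/value grid.
#
#   attn = MultiQueryAttention2d(dim=64, num_heads=4, query_strides=2, kv_stride=2, padding='same')
#   x = torch.randn(2, 64, 16, 16)
#   y = attn(x)                      # q computed over 8x8 positions, k/v over 8x8 positions
#   assert y.shape == (2, 64, 16, 16)
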
class Attention2d(nn.Module):
    """ Multi-head attention for 2D NCHW tensors."""
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            num_heads: int = 32,
            bias: bool = True,
            expand_first: bool = False,
            head_first: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.
    ):
        super().__init__()
        dim_out = dim_out or dim
        dim_attn = dim_out if expand_first else dim
        self.num_heads = num_heads
        self.dim_head = dim_attn // num_heads
        self.head_first = head_first
        self.scale = num_heads ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        B, C, H, W = x.shape

        if self.head_first:
            q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2)
        else:
            q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1)

        if self.fused_attn:
            x = torch.nn.functional.scaled_dot_product_attention(
                q.transpose(-1, -2).contiguous(),
                k.transpose(-1, -2).contiguous(),
                v.transpose(-1, -2).contiguous(),
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
            ).transpose(-1, -2).reshape(B, -1, H, W)
        else:
            q = q * self.scale
            attn = q.transpose(-2, -1) @ k
            if attn_mask is not None:
                # NOTE: assumes mask is float and in correct shape
                attn = attn + attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
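
# Usage sketch (illustrative, not part of the diff). The default num_heads=32 targets
# large channel counts; smaller dims should pass an explicit head count so that
# dim_head = dim // num_heads stays a sensible size:
#
#   attn = Attention2d(dim=128, num_heads=8)
#   x = torch.randn(2, 128, 7, 7)    # tokens are the 49 spatial positions, kept NCHW throughout
#   y = attn(x)
#   assert y.shape == x.shape
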
@@ -5,12 +5,16 @@ BlurPool layer inspired by

Hacked together by Chris Ha and Ross Wightman
"""
from functools import partial
from typing import Optional, Type

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from .padding import get_padding
from .typing import LayerType


class BlurPool2d(nn.Module):
@@ -26,17 +30,62 @@ class BlurPool2d(nn.Module):
    Returns:
        torch.Tensor: the transformed tensor.
    """
    def __init__(self, channels, filt_size=3, stride=2) -> None:
    def __init__(
            self,
            channels: Optional[int] = None,
            filt_size: int = 3,
            stride: int = 2,
            pad_mode: str = 'reflect',
    ) -> None:
        super(BlurPool2d, self).__init__()
        assert filt_size > 1
        self.channels = channels
        self.filt_size = filt_size
        self.stride = stride
        self.pad_mode = pad_mode
        self.padding = [get_padding(filt_size, stride, dilation=1)] * 4

        coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32))
        blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1)
        blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :]
        if channels is not None:
            blur_filter = blur_filter.repeat(self.channels, 1, 1, 1)
        self.register_buffer('filt', blur_filter, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.pad(x, self.padding, 'reflect')
        return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels)
        x = F.pad(x, self.padding, mode=self.pad_mode)
        if self.channels is None:
            channels = x.shape[1]
            weight = self.filt.expand(channels, 1, self.filt_size, self.filt_size)
        else:
            channels = self.channels
            weight = self.filt
        return F.conv2d(x, weight, stride=self.stride, groups=channels)
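
# Worked example for the filter construction above (illustrative, not part of the diff):
# for filt_size=3, (np.poly1d((0.5, 0.5)) ** 2).coeffs == [0.25, 0.5, 0.25], and the 2D
# blur kernel is the outer product of those binomial coefficients:
#
#   c = (np.poly1d((0.5, 0.5)) ** 2).coeffs          # array([0.25, 0.5 , 0.25])
#   k = np.outer(c, c)                               # 3x3 low-pass kernel, k.sum() == 1.0
#
# With channels=None the single [1, 1, 3, 3] buffer is expanded to the input's channel
# count at forward time, so one module instance can serve any width.
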
def create_aa(
        aa_layer: LayerType,
        channels: Optional[int] = None,
        stride: int = 2,
        enable: bool = True,
        noop: Optional[Type[nn.Module]] = nn.Identity
) -> nn.Module:
    """ Anti-aliasing """
    if not aa_layer or not enable:
        return noop() if noop is not None else None

    if isinstance(aa_layer, str):
        aa_layer = aa_layer.lower().replace('_', '').replace('-', '')
        if aa_layer == 'avg' or aa_layer == 'avgpool':
            aa_layer = nn.AvgPool2d
        elif aa_layer == 'blur' or aa_layer == 'blurpool':
            aa_layer = BlurPool2d
        elif aa_layer == 'blurpc':
            aa_layer = partial(BlurPool2d, pad_mode='constant')
        else:
            assert False, f"Unknown anti-aliasing layer ({aa_layer})."

    try:
        return aa_layer(channels=channels, stride=stride)
    except TypeError as e:
        return aa_layer(stride)
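
# Usage sketch (illustrative, not part of the diff): string selectors map to layer
# classes, and the try/except covers layers like nn.AvgPool2d whose constructor takes
# the pooling size positionally rather than channels=/stride= keywords:
#
#   aa = create_aa('blurpc', channels=64, stride=2)   # BlurPool2d w/ constant (zero) padding
#   aa = create_aa('avg', channels=64, stride=2)      # TypeError fallback -> nn.AvgPool2d(2)
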
@@ -2,9 +2,12 @@

Hacked together by / Copyright 2020 Ross Wightman
"""
import functools
from typing import Any, Dict, Optional, Type

from torch import nn as nn

from .typing import LayerType, PadType
from .blur_pool import create_aa
from .create_conv2d import create_conv2d
from .create_norm_act import get_norm_act_layer

@@ -12,41 +15,58 @@ from .create_norm_act import get_norm_act_layer
class ConvNormAct(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding='',
            dilation=1,
            groups=1,
            bias=False,
            apply_act=True,
            norm_layer=nn.BatchNorm2d,
            norm_kwargs=None,
            act_layer=nn.ReLU,
            act_kwargs=None,
            drop_layer=None,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 1,
            stride: int = 1,
            padding: PadType = '',
            dilation: int = 1,
            groups: int = 1,
            bias: bool = False,
            apply_norm: bool = True,
            apply_act: bool = True,
            norm_layer: LayerType = nn.BatchNorm2d,
            act_layer: LayerType = nn.ReLU,
            drop_layer: Optional[Type[nn.Module]] = None,
            conv_kwargs: Optional[Dict[str, Any]] = None,
            norm_kwargs: Optional[Dict[str, Any]] = None,
            act_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super(ConvNormAct, self).__init__()
        conv_kwargs = conv_kwargs or {}
        norm_kwargs = norm_kwargs or {}
        act_kwargs = act_kwargs or {}

        self.conv = create_conv2d(
            in_channels, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias)

        # NOTE for backwards compatibility with models that use separate norm and act layer definitions
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
        if drop_layer:
            norm_kwargs['drop_layer'] = drop_layer
        self.bn = norm_act_layer(
            out_channels,
            apply_act=apply_act,
            act_kwargs=act_kwargs,
            **norm_kwargs,
        )
        self.conv = create_conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            **conv_kwargs,
        )

        if apply_norm:
            # NOTE for backwards compatibility with models that use separate norm and act layer definitions
            norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
            # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
            if drop_layer:
                norm_kwargs['drop_layer'] = drop_layer
            self.bn = norm_act_layer(
                out_channels,
                apply_act=apply_act,
                act_kwargs=act_kwargs,
                **norm_kwargs,
            )
        else:
            self.bn = nn.Sequential()
            if drop_layer:
                norm_kwargs['drop_layer'] = drop_layer
                self.bn.add_module('drop', drop_layer())

    @property
    def in_channels(self):
        return self.conv.in_channels
@@ -64,54 +84,61 @@ class ConvNormAct(nn.Module):
ConvBnAct = ConvNormAct


def create_aa(aa_layer, channels, stride=2, enable=True):
    if not aa_layer or not enable:
        return nn.Identity()
    if isinstance(aa_layer, functools.partial):
        if issubclass(aa_layer.func, nn.AvgPool2d):
            return aa_layer()
        else:
            return aa_layer(channels)
    elif issubclass(aa_layer, nn.AvgPool2d):
        return aa_layer(stride)
    else:
        return aa_layer(channels=channels, stride=stride)


class ConvNormActAa(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding='',
            dilation=1,
            groups=1,
            bias=False,
            apply_act=True,
            norm_layer=nn.BatchNorm2d,
            norm_kwargs=None,
            act_layer=nn.ReLU,
            act_kwargs=None,
            aa_layer=None,
            drop_layer=None,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 1,
            stride: int = 1,
            padding: PadType = '',
            dilation: int = 1,
            groups: int = 1,
            bias: bool = False,
            apply_norm: bool = True,
            apply_act: bool = True,
            norm_layer: LayerType = nn.BatchNorm2d,
            act_layer: LayerType = nn.ReLU,
            aa_layer: Optional[LayerType] = None,
            drop_layer: Optional[Type[nn.Module]] = None,
            conv_kwargs: Optional[Dict[str, Any]] = None,
            norm_kwargs: Optional[Dict[str, Any]] = None,
            act_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super(ConvNormActAa, self).__init__()
        use_aa = aa_layer is not None and stride == 2
        conv_kwargs = conv_kwargs or {}
        norm_kwargs = norm_kwargs or {}
        act_kwargs = act_kwargs or {}

        self.conv = create_conv2d(
            in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias)
            in_channels, out_channels, kernel_size,
            stride=1 if use_aa else stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            **conv_kwargs,
        )

        if apply_norm:
            # NOTE for backwards compatibility with models that use separate norm and act layer definitions
            norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
            # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
            if drop_layer:
                norm_kwargs['drop_layer'] = drop_layer
            self.bn = norm_act_layer(
                out_channels,
                apply_act=apply_act,
                act_kwargs=act_kwargs,
                **norm_kwargs,
            )
        else:
            self.bn = nn.Sequential()
            if drop_layer:
                norm_kwargs['drop_layer'] = drop_layer
                self.bn.add_module('drop', drop_layer())

        # NOTE for backwards compatibility with models that use separate norm and act layer definitions
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
        if drop_layer:
            norm_kwargs['drop_layer'] = drop_layer
        self.bn = norm_act_layer(out_channels, apply_act=apply_act, act_kwargs=act_kwargs, **norm_kwargs)
        self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)

    @property
@@ -19,21 +19,18 @@ from torch import nn as nn
from torch.nn import functional as F
from torchvision.ops.misc import FrozenBatchNorm2d

from .create_act import get_act_layer
from .create_act import create_act_layer
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
from .trace_utils import _assert


def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True):
    act_layer = get_act_layer(act_layer)  # string -> nn.Module
    act_kwargs = act_kwargs or {}
    if act_layer is not None and apply_act:
        if inplace:
            act_kwargs['inplace'] = inplace
        act = act_layer(**act_kwargs)
    else:
        act = nn.Identity()
    return act
    act_kwargs.setdefault('inplace', inplace)
    act = None
    if apply_act:
        act = create_act_layer(act_layer, **act_kwargs)
    return nn.Identity() if act is None else act


class BatchNormAct2d(nn.BatchNorm2d):
@@ -421,7 +418,6 @@ class LayerNormAct(nn.LayerNorm):
    ):
        super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        act_layer = get_act_layer(act_layer)  # string -> nn.Module
        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

        self._fast_norm = is_fast_norm()

@@ -71,6 +71,7 @@ from .vision_transformer import *
from .vision_transformer_hybrid import *
from .vision_transformer_relpos import *
from .vision_transformer_sam import *
from .vitamin import *
from .volo import *
from .vovnet import *
from .xception import *

@@ -2,18 +2,24 @@

Hacked together by / Copyright 2019, Ross Wightman
"""
from typing import Callable, Dict, Optional, Type

import torch
import torch.nn as nn
from torch.nn import functional as F

from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer
from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, create_aa, to_2tuple, LayerType, \
    ConvNormAct, ConvNormActAa, get_norm_act_layer, MultiQueryAttention2d, Attention2d

__all__ = [
    'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual']
    'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual',
    'UniversalInvertedResidual', 'MobileAttention'
]

ModuleType = Type[nn.Module]


def num_groups(group_size, channels):
def num_groups(group_size: Optional[int], channels: int):
    if not group_size:  # 0 or None
        return 1  # normal conv with 1 group
    else:
@@ -35,8 +41,15 @@ class SqueezeExcite(nn.Module):
    """

    def __init__(
            self, in_chs, rd_ratio=0.25, rd_channels=None, act_layer=nn.ReLU,
            gate_layer=nn.Sigmoid, force_act_layer=None, rd_round_fn=None):
            self,
            in_chs: int,
            rd_ratio: float = 0.25,
            rd_channels: Optional[int] = None,
            act_layer: LayerType = nn.ReLU,
            gate_layer: LayerType = nn.Sigmoid,
            force_act_layer: Optional[LayerType] = None,
            rd_round_fn: Optional[Callable] = None,
    ):
        super(SqueezeExcite, self).__init__()
        if rd_channels is None:
            rd_round_fn = rd_round_fn or round
@@ -59,16 +72,32 @@ class ConvBnAct(nn.Module):
    """ Conv + Norm Layer + Activation w/ optional skip connection
    """
    def __init__(
            self, in_chs, out_chs, kernel_size, stride=1, dilation=1, group_size=0, pad_type='',
            skip=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_path_rate=0.):
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int,
            stride: int = 1,
            dilation: int = 1,
            group_size: int = 0,
            pad_type: str = '',
            skip: bool = False,
            act_layer: LayerType = nn.ReLU,
            norm_layer: LayerType = nn.BatchNorm2d,
            aa_layer: Optional[LayerType] = None,
            drop_path_rate: float = 0.,
    ):
        super(ConvBnAct, self).__init__()
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        groups = num_groups(group_size, in_chs)
        self.has_skip = skip and stride == 1 and in_chs == out_chs
        use_aa = aa_layer is not None and stride > 1  # FIXME handle dilation

        self.conv = create_conv2d(
            in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, groups=groups, padding=pad_type)
            in_chs, out_chs, kernel_size,
            stride=1 if use_aa else stride,
            dilation=dilation, groups=groups, padding=pad_type)
        self.bn1 = norm_act_layer(out_chs, inplace=True)
        self.aa = create_aa(aa_layer, channels=out_chs, stride=stride, enable=use_aa)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()

    def feature_info(self, location):
@@ -81,29 +110,64 @@ class ConvBnAct(nn.Module):
        shortcut = x
        x = self.conv(x)
        x = self.bn1(x)
        x = self.aa(x)
        if self.has_skip:
            x = self.drop_path(x) + shortcut
        return x


class DepthwiseSeparableConv(nn.Module):
    """ DepthwiseSeparable block
    """ Depthwise-separable block
    Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion
    (factor of 1.0). This is an alternative to having an IR with an optional first pw conv.
    """
    def __init__(
            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='',
            noskip=False, pw_kernel_size=1, pw_act=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
            se_layer=None, drop_path_rate=0.):
            self,
            in_chs: int,
            out_chs: int,
            dw_kernel_size: int = 3,
            stride: int = 1,
            dilation: int = 1,
            group_size: int = 1,
            pad_type: str = '',
            noskip: bool = False,
            pw_kernel_size: int = 1,
            pw_act: bool = False,
            s2d: int = 0,
            act_layer: LayerType = nn.ReLU,
            norm_layer: LayerType = nn.BatchNorm2d,
            aa_layer: Optional[LayerType] = None,
            se_layer: Optional[ModuleType] = None,
            drop_path_rate: float = 0.,
    ):
        super(DepthwiseSeparableConv, self).__init__()
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        groups = num_groups(group_size, in_chs)
        self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip
        self.has_pw_act = pw_act  # activation after point-wise conv
        use_aa = aa_layer is not None and stride > 1  # FIXME handle dilation

        # Space to depth
        if s2d == 1:
            sd_chs = int(in_chs * 4)
            self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding='same')
            self.bn_s2d = norm_act_layer(sd_chs, sd_chs)
            dw_kernel_size = (dw_kernel_size + 1) // 2
            dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type
            in_chs = sd_chs
            use_aa = False  # disable AA
        else:
            self.conv_s2d = None
            self.bn_s2d = None
            dw_pad_type = pad_type

        groups = num_groups(group_size, in_chs)

        self.conv_dw = create_conv2d(
            in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, groups=groups)
            in_chs, in_chs, dw_kernel_size,
            stride=1 if use_aa else stride,
            dilation=dilation, padding=dw_pad_type, groups=groups)
        self.bn1 = norm_act_layer(in_chs, inplace=True)
        self.aa = create_aa(aa_layer, channels=out_chs, stride=stride, enable=use_aa)

        # Squeeze-and-excitation
        self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity()
@@ -120,8 +184,12 @@ class DepthwiseSeparableConv(nn.Module):

    def forward(self, x):
        shortcut = x
        if self.conv_s2d is not None:
            x = self.conv_s2d(x)
            x = self.bn_s2d(x)
        x = self.conv_dw(x)
        x = self.bn1(x)
        x = self.aa(x)
        x = self.se(x)
        x = self.conv_pw(x)
        x = self.bn2(x)
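
# Space-to-depth sketch (illustrative, not part of the diff): the learned 2x2/stride-2
# conv_s2d maps in_chs -> 4*in_chs, the same layout change F.pixel_unshuffle performs as
# a pure (non-learned) rearrangement:
#
#   x = torch.randn(1, 16, 32, 32)
#   y = F.pixel_unshuffle(x, 2)      # (1, 64, 16, 16): each 2x2 block folded into channels
#
# Halving the dw kernel via (dw_kernel_size + 1) // 2 keeps the receptive field roughly
# matched to the now half-resolution grid.
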
@@ -141,15 +209,48 @@ class InvertedResidual(nn.Module):
    """

    def __init__(
            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='',
            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
            self,
            in_chs: int,
            out_chs: int,
            dw_kernel_size: int = 3,
            stride: int = 1,
            dilation: int = 1,
            group_size: int = 1,
            pad_type: str = '',
            noskip: bool = False,
            exp_ratio: float = 1.0,
            exp_kernel_size: int = 1,
            pw_kernel_size: int = 1,
            s2d: int = 0,
            act_layer: LayerType = nn.ReLU,
            norm_layer: LayerType = nn.BatchNorm2d,
            aa_layer: Optional[LayerType] = None,
            se_layer: Optional[ModuleType] = None,
            conv_kwargs: Optional[Dict] = None,
            drop_path_rate: float = 0.,
    ):
        super(InvertedResidual, self).__init__()
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        conv_kwargs = conv_kwargs or {}
        self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
        use_aa = aa_layer is not None and stride > 1  # FIXME handle dilation

        # Space to depth
        if s2d == 1:
            sd_chs = int(in_chs * 4)
            self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding='same')
            self.bn_s2d = norm_act_layer(sd_chs, sd_chs)
            dw_kernel_size = (dw_kernel_size + 1) // 2
            dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type
            in_chs = sd_chs
            use_aa = False  # disable AA
        else:
            self.conv_s2d = None
            self.bn_s2d = None
            dw_pad_type = pad_type

        mid_chs = make_divisible(in_chs * exp_ratio)
        groups = num_groups(group_size, mid_chs)
        self.has_skip = (in_chs == out_chs and stride == 1) and not noskip

        # Point-wise expansion
        self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
@@ -157,9 +258,11 @@ class InvertedResidual(nn.Module):

        # Depth-wise convolution
        self.conv_dw = create_conv2d(
            mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
            groups=groups, padding=pad_type, **conv_kwargs)
            mid_chs, mid_chs, dw_kernel_size,
            stride=1 if use_aa else stride,
            dilation=dilation, groups=groups, padding=dw_pad_type, **conv_kwargs)
        self.bn2 = norm_act_layer(mid_chs, inplace=True)
        self.aa = create_aa(aa_layer, channels=mid_chs, stride=stride, enable=use_aa)

        # Squeeze-and-excitation
        self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
@@ -177,10 +280,14 @@ class InvertedResidual(nn.Module):

    def forward(self, x):
        shortcut = x
        if self.conv_s2d is not None:
            x = self.conv_s2d(x)
            x = self.bn_s2d(x)
        x = self.conv_pw(x)
        x = self.bn1(x)
        x = self.conv_dw(x)
        x = self.bn2(x)
        x = self.aa(x)
        x = self.se(x)
        x = self.conv_pwl(x)
        x = self.bn3(x)
@@ -189,23 +296,317 @@ class InvertedResidual(nn.Module):
        return x

class LayerScale2d(nn.Module):
    def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        gamma = self.gamma.view(1, -1, 1, 1)
        return x.mul_(gamma) if self.inplace else x * gamma


class UniversalInvertedResidual(nn.Module):
    """ Universal Inverted Residual Block (aka Universal Inverted Bottleneck, UIB)

    For MobileNetV4 - https://arxiv.org/abs/, referenced from
    https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L778
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            dw_kernel_size_start: int = 0,
            dw_kernel_size_mid: int = 3,
            dw_kernel_size_end: int = 0,
            stride: int = 1,
            dilation: int = 1,
            group_size: int = 1,
            pad_type: str = '',
            noskip: bool = False,
            exp_ratio: float = 1.0,
            act_layer: LayerType = nn.ReLU,
            norm_layer: LayerType = nn.BatchNorm2d,
            aa_layer: Optional[LayerType] = None,
            se_layer: Optional[ModuleType] = None,
            conv_kwargs: Optional[Dict] = None,
            drop_path_rate: float = 0.,
            layer_scale_init_value: Optional[float] = 1e-5,
    ):
        super(UniversalInvertedResidual, self).__init__()
        conv_kwargs = conv_kwargs or {}
        self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
        if stride > 1:
            assert dw_kernel_size_start or dw_kernel_size_mid or dw_kernel_size_end

        # FIXME dilation isn't right w/ extra ks > 1 convs
        if dw_kernel_size_start:
            dw_start_stride = stride if not dw_kernel_size_mid else 1
            dw_start_groups = num_groups(group_size, in_chs)
            self.dw_start = ConvNormActAa(
                in_chs, in_chs, dw_kernel_size_start,
                stride=dw_start_stride,
                dilation=dilation,  # FIXME
                groups=dw_start_groups,
                padding=pad_type,
                apply_act=False,
                act_layer=act_layer,
                norm_layer=norm_layer,
                aa_layer=aa_layer,
                **conv_kwargs,
            )
        else:
            self.dw_start = nn.Identity()

        # Point-wise expansion
        mid_chs = make_divisible(in_chs * exp_ratio)
        self.pw_exp = ConvNormAct(
            in_chs, mid_chs, 1,
            padding=pad_type,
            act_layer=act_layer,
            norm_layer=norm_layer,
            **conv_kwargs,
        )

        # Middle depth-wise convolution
        if dw_kernel_size_mid:
            groups = num_groups(group_size, mid_chs)
            self.dw_mid = ConvNormActAa(
                mid_chs, mid_chs, dw_kernel_size_mid,
                stride=stride,
                dilation=dilation,  # FIXME
                groups=groups,
                padding=pad_type,
                act_layer=act_layer,
                norm_layer=norm_layer,
                aa_layer=aa_layer,
                **conv_kwargs,
            )
        else:
            # keeping mid as identity so it can be hooked more easily for features
            self.dw_mid = nn.Identity()

        # Squeeze-and-excitation
        self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()

        # Point-wise linear projection
        self.pw_proj = ConvNormAct(
            mid_chs, out_chs, 1,
            padding=pad_type,
            apply_act=False,
            act_layer=act_layer,
            norm_layer=norm_layer,
            **conv_kwargs,
        )

        if dw_kernel_size_end:
            dw_end_stride = stride if not dw_kernel_size_start and not dw_kernel_size_mid else 1
            dw_end_groups = num_groups(group_size, out_chs)
            if dw_end_stride > 1:
                assert not aa_layer
            self.dw_end = ConvNormAct(
                out_chs, out_chs, dw_kernel_size_end,
                stride=dw_end_stride,
                dilation=dilation,
                groups=dw_end_groups,
                padding=pad_type,
                apply_act=False,
                act_layer=act_layer,
                norm_layer=norm_layer,
                **conv_kwargs,
            )
        else:
            self.dw_end = nn.Identity()

        if layer_scale_init_value is not None:
            self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value)
        else:
            self.layer_scale = nn.Identity()
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()

    def feature_info(self, location):
        if location == 'expansion':  # after SE, input to PWL
            return dict(module='pw_proj.conv', hook_type='forward_pre', num_chs=self.pw_proj.conv.in_channels)
        else:  # location == 'bottleneck', block output
            return dict(module='', num_chs=self.pw_proj.conv.out_channels)

    def forward(self, x):
        shortcut = x
        x = self.dw_start(x)
        x = self.pw_exp(x)
        x = self.dw_mid(x)
        x = self.se(x)
        x = self.pw_proj(x)
        x = self.dw_end(x)
        x = self.layer_scale(x)
        if self.has_skip:
            x = self.drop_path(x) + shortcut
        return x

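
# Configuration sketch (illustrative, not part of the diff): the UIB variants in the
# MobileNetV4 paper fall out of which depthwise convs are enabled, e.g. an "extra dw"
# style block with both a leading and a middle depthwise conv:
#
#   blk = UniversalInvertedResidual(
#       in_chs=64, out_chs=64,
#       dw_kernel_size_start=3,   # 0 disables the leading dw conv
#       dw_kernel_size_mid=5,     # 0 disables the mid dw conv (FFN-like variant)
#       exp_ratio=4.0,
#   )
#   y = blk(torch.randn(1, 64, 28, 28))   # stride 1 + matching chs -> residual applies
#   assert y.shape == (1, 64, 28, 28)
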
class MobileAttention(nn.Module):
    """ Mobile Attention Block

    For MobileNetV4 - https://arxiv.org/abs/, referenced from
    https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L1504
    """
    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            stride: int = 1,
            dw_kernel_size: int = 3,
            dilation: int = 1,
            group_size: int = 1,
            pad_type: str = '',
            num_heads: int = 8,
            key_dim: int = 64,
            value_dim: int = 64,
            use_multi_query: bool = False,
            query_strides: int = (1, 1),
            kv_stride: int = 1,
            cpe_dw_kernel_size: int = 3,
            noskip: bool = False,
            act_layer: LayerType = nn.ReLU,
            norm_layer: LayerType = nn.BatchNorm2d,
            aa_layer: Optional[LayerType] = None,
            drop_path_rate: float = 0.,
            attn_drop: float = 0.0,
            proj_drop: float = 0.0,
            layer_scale_init_value: Optional[float] = 1e-5,
            use_bias: bool = False,
            use_cpe: bool = False,
    ):
        super(MobileAttention, self).__init__()
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip
        self.query_strides = to_2tuple(query_strides)
        self.kv_stride = kv_stride
        self.has_query_stride = any([s > 1 for s in self.query_strides])

        # This CPE is different from the one suggested in the original paper.
        # https://arxiv.org/abs/2102.10882
        # 1. Rather than adding one CPE before the attention blocks, we add a CPE
        #    into every attention block.
        # 2. We replace the expensive Conv2D by a Separable DW Conv.
        if use_cpe:
            self.conv_cpe_dw = create_conv2d(
                in_chs, in_chs,
                kernel_size=cpe_dw_kernel_size,
                dilation=dilation,
                depthwise=True,
                bias=True,
            )
        else:
            self.conv_cpe_dw = None

        self.norm = norm_act_layer(in_chs, apply_act=False)

        if num_heads is None:
            assert in_chs % key_dim == 0
            num_heads = in_chs // key_dim

        if use_multi_query:
            self.attn = MultiQueryAttention2d(
                in_chs,
                dim_out=out_chs,
                num_heads=num_heads,
                key_dim=key_dim,
                value_dim=value_dim,
                query_strides=query_strides,
                kv_stride=kv_stride,
                dilation=dilation,
                padding=pad_type,
                dw_kernel_size=dw_kernel_size,
                attn_drop=attn_drop,
                proj_drop=proj_drop,
                # bias=use_bias,  # why not here if used w/ mhsa?
            )
        else:
            self.attn = Attention2d(
                in_chs,
                dim_out=out_chs,
                num_heads=num_heads,
                attn_drop=attn_drop,
                proj_drop=proj_drop,
                bias=use_bias,
            )

        if layer_scale_init_value is not None:
            self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value)
        else:
            self.layer_scale = nn.Identity()

        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()

    def feature_info(self, location):
        if location == 'expansion':  # after SE, input to PW
            return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
        else:  # location == 'bottleneck', block output
            return dict(module='', num_chs=self.conv_pw.out_channels)

    def forward(self, x):
        if self.conv_cpe_dw is not None:
            x_cpe = self.conv_cpe_dw(x)
            x = x + x_cpe

        shortcut = x
        x = self.norm(x)
        x = self.attn(x)
        x = self.layer_scale(x)
        if self.has_skip:
            x = self.drop_path(x) + shortcut

        return x

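
# Instantiation sketch (illustrative, not part of the diff): pre-norm 2D attention with
# multi-query attention, downsampled key/value, and a per-block conditional pos. embedding:
#
#   blk = MobileAttention(in_chs=96, out_chs=96, num_heads=4, key_dim=24, value_dim=24,
#                         use_multi_query=True, kv_stride=2, use_cpe=True)
#   y = blk(torch.randn(1, 96, 14, 14))   # shortcut + layer-scaled attention output
#   assert y.shape == (1, 96, 14, 14)
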
class CondConvResidual(InvertedResidual):
    """ Inverted residual block w/ CondConv routing"""

    def __init__(
            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='',
            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):
            self,
            in_chs: int,
            out_chs: int,
            dw_kernel_size: int = 3,
            stride: int = 1,
            dilation: int = 1,
            group_size: int = 1,
            pad_type: str = '',
            noskip: bool = False,
            exp_ratio: float = 1.0,
            exp_kernel_size: int = 1,
            pw_kernel_size: int = 1,
            act_layer: LayerType = nn.ReLU,
            norm_layer: LayerType = nn.BatchNorm2d,
            aa_layer: Optional[LayerType] = None,
            se_layer: Optional[ModuleType] = None,
            num_experts: int = 0,
            drop_path_rate: float = 0.,
    ):

        self.num_experts = num_experts
        conv_kwargs = dict(num_experts=self.num_experts)

        super(CondConvResidual, self).__init__(
            in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, group_size=group_size,
            pad_type=pad_type, act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
            pw_kernel_size=pw_kernel_size, se_layer=se_layer, norm_layer=norm_layer, conv_kwargs=conv_kwargs,
            drop_path_rate=drop_path_rate)
            in_chs,
            out_chs,
            dw_kernel_size=dw_kernel_size,
            stride=stride,
            dilation=dilation,
            group_size=group_size,
            pad_type=pad_type,
            noskip=noskip,
            exp_ratio=exp_ratio,
            exp_kernel_size=exp_kernel_size,
            pw_kernel_size=pw_kernel_size,
            act_layer=act_layer,
            norm_layer=norm_layer,
            aa_layer=aa_layer,
            se_layer=se_layer,
            conv_kwargs=conv_kwargs,
            drop_path_rate=drop_path_rate,
        )
        self.routing_fn = nn.Linear(in_chs, self.num_experts)

    def forward(self, x):
@@ -237,9 +638,24 @@ class EdgeResidual(nn.Module):
    """

    def __init__(
            self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, group_size=0, pad_type='',
            force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
            self,
            in_chs: int,
            out_chs: int,
            exp_kernel_size: int = 3,
            stride: int = 1,
            dilation: int = 1,
            group_size: int = 0,
            pad_type: str = '',
            force_in_chs: int = 0,
            noskip: bool = False,
            exp_ratio: float = 1.0,
            pw_kernel_size: int = 1,
            act_layer: LayerType = nn.ReLU,
            norm_layer: LayerType = nn.BatchNorm2d,
            aa_layer: Optional[LayerType] = None,
            se_layer: Optional[ModuleType] = None,
            drop_path_rate: float = 0.,
    ):
        super(EdgeResidual, self).__init__()
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        if force_in_chs > 0:
@@ -248,12 +664,17 @@ class EdgeResidual(nn.Module):
        mid_chs = make_divisible(in_chs * exp_ratio)
        groups = num_groups(group_size, in_chs)
        self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
        use_aa = aa_layer is not None and stride > 1  # FIXME handle dilation

        # Expansion convolution
        self.conv_exp = create_conv2d(
            in_chs, mid_chs, exp_kernel_size, stride=stride, dilation=dilation, groups=groups, padding=pad_type)
            in_chs, mid_chs, exp_kernel_size,
            stride=1 if use_aa else stride,
            dilation=dilation, groups=groups, padding=pad_type)
        self.bn1 = norm_act_layer(mid_chs, inplace=True)

        self.aa = create_aa(aa_layer, channels=mid_chs, stride=stride, enable=use_aa)

        # Squeeze-and-excitation
        self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()

@@ -272,6 +693,7 @@ class EdgeResidual(nn.Module):
        shortcut = x
        x = self.conv_exp(x)
        x = self.bn1(x)
        x = self.aa(x)
        x = self.se(x)
        x = self.conv_pwl(x)
        x = self.bn2(x)

@@ -5,6 +5,7 @@ Handles stride, dilation calculations, and selects feature extraction points.

Hacked together by / Copyright 2019, Ross Wightman
"""
from typing import Callable, Optional

import logging
import math
@@ -16,7 +17,7 @@ from typing import Any, Dict, List
import torch.nn as nn

from ._efficientnet_blocks import *
from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible
from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible, LayerType

__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
    'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
@@ -139,8 +140,8 @@ def _decode_block_str(block_str):

    # if act_layer is None, the model default (passed to model init) will be used
    act_layer = options['n'] if 'n' in options else None
    exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
    pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
    start_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
    end_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
    force_in_chs = int(options['fc']) if 'fc' in options else 0  # FIXME hack to deal with in_chs issue in TPU def
    num_repeat = int(options['r'])

@@ -154,29 +155,31 @@ def _decode_block_str(block_str):
    if block_type == 'ir':
        block_args.update(dict(
            dw_kernel_size=_parse_ksize(options['k']),
            exp_kernel_size=exp_kernel_size,
            pw_kernel_size=pw_kernel_size,
            exp_kernel_size=start_kernel_size,
            pw_kernel_size=end_kernel_size,
            exp_ratio=float(options['e']),
            se_ratio=float(options['se']) if 'se' in options else 0.,
            se_ratio=float(options.get('se', 0.)),
            noskip=skip is False,
            s2d=int(options.get('d', 0)) > 0,
        ))
        if 'cc' in options:
            block_args['num_experts'] = int(options['cc'])
    elif block_type == 'ds' or block_type == 'dsa':
        block_args.update(dict(
            dw_kernel_size=_parse_ksize(options['k']),
            pw_kernel_size=pw_kernel_size,
            se_ratio=float(options['se']) if 'se' in options else 0.,
            pw_kernel_size=end_kernel_size,
            se_ratio=float(options.get('se', 0.)),
            pw_act=block_type == 'dsa',
            noskip=block_type == 'dsa' or skip is False,
            s2d=int(options.get('d', 0)) > 0,
        ))
    elif block_type == 'er':
        block_args.update(dict(
            exp_kernel_size=_parse_ksize(options['k']),
            pw_kernel_size=pw_kernel_size,
            pw_kernel_size=end_kernel_size,
            exp_ratio=float(options['e']),
            force_in_chs=force_in_chs,
            se_ratio=float(options['se']) if 'se' in options else 0.,
            se_ratio=float(options.get('se', 0.)),
            noskip=skip is False,
        ))
    elif block_type == 'cn':
@@ -184,6 +187,38 @@ def _decode_block_str(block_str):
            kernel_size=int(options['k']),
            skip=skip is True,
        ))
    elif block_type == 'uir':
        # override exp / proj kernels for start/end in uir block
        start_kernel_size = _parse_ksize(options['a']) if 'a' in options else 0
        end_kernel_size = _parse_ksize(options['p']) if 'p' in options else 0
        block_args.update(dict(
            dw_kernel_size_start=start_kernel_size,  # overload exp ks arg for dw start
            dw_kernel_size_mid=_parse_ksize(options['k']),
            dw_kernel_size_end=end_kernel_size,  # overload pw ks arg for dw end
            exp_ratio=float(options['e']),
            se_ratio=float(options.get('se', 0.)),
            noskip=skip is False,
        ))
    elif block_type == 'mha':
        kv_dim = int(options['d'])
        block_args.update(dict(
            dw_kernel_size=_parse_ksize(options['k']),
            num_heads=int(options['h']),
            key_dim=kv_dim,
            value_dim=kv_dim,
            kv_stride=int(options.get('v', 1)),
            noskip=skip is False,
        ))
    elif block_type == 'mqa':
        kv_dim = int(options['d'])
        block_args.update(dict(
            dw_kernel_size=_parse_ksize(options['k']),
            num_heads=int(options['h']),
            key_dim=kv_dim,
            value_dim=kv_dim,
            kv_stride=int(options.get('v', 1)),
            noskip=skip is False,
        ))
    else:
        assert False, 'Unknown block type (%s)' % block_type
    if 'gs' in options:
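
# Decode sketch (hypothetical strings, for illustration only; option letters per the
# branches above, with r/s/c/e handled by the surrounding decoder):
#
#   'uir_r2_a3_k5_s2_e4_c80'   -> 2x UniversalInvertedResidual, dw_kernel_size_start=3 ('a3'),
#                                 dw_kernel_size_mid=5 ('k5'), stride 2, exp_ratio 4.0, out_chs 80
#   'mqa_r1_k3_h4_d64_v2_c160' -> MobileAttention w/ multi-query attn, 4 heads ('h4'),
#                                 key/value dim 64 ('d64'), kv_stride 2 ('v2'), out_chs 160
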
@@ -285,14 +320,27 @@ class EfficientNetBuilder:
    https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py

    """
    def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=False,
                 act_layer=None, norm_layer=None, se_layer=None, drop_path_rate=0., feature_location=''):
    def __init__(
            self,
            output_stride: int = 32,
            pad_type: str = '',
            round_chs_fn: Callable = round_channels,
            se_from_exp: bool = False,
            act_layer: Optional[LayerType] = None,
            norm_layer: Optional[LayerType] = None,
            aa_layer: Optional[LayerType] = None,
            se_layer: Optional[LayerType] = None,
            drop_path_rate: float = 0.,
            layer_scale_init_value: Optional[float] = None,
            feature_location: str = '',
    ):
        self.output_stride = output_stride
        self.pad_type = pad_type
        self.round_chs_fn = round_chs_fn
        self.se_from_exp = se_from_exp  # calculate se channel reduction from expanded (mid) chs
        self.act_layer = act_layer
        self.norm_layer = norm_layer
        self.aa_layer = aa_layer
        self.se_layer = get_attn(se_layer)
        try:
            self.se_layer(8, rd_ratio=1.0)  # test if attn layer accepts rd_ratio arg
@@ -300,6 +348,7 @@ class EfficientNetBuilder:
        except TypeError:
            self.se_has_ratio = False
        self.drop_path_rate = drop_path_rate
        self.layer_scale_init_value = layer_scale_init_value
        if feature_location == 'depthwise':
            # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
            _logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
@@ -317,6 +366,10 @@ class EfficientNetBuilder:
        bt = ba.pop('block_type')
        ba['in_chs'] = self.in_chs
        ba['out_chs'] = self.round_chs_fn(ba['out_chs'])
        s2d = ba.get('s2d', 0)
        if s2d > 0:
            # adjust while space2depth active
            ba['out_chs'] *= 4
        if 'force_in_chs' in ba and ba['force_in_chs']:
            # NOTE this is a hack to work around mismatch in TF EdgeEffNet impl
            ba['force_in_chs'] = self.round_chs_fn(ba['force_in_chs'])
@@ -326,16 +379,22 @@ class EfficientNetBuilder:
        assert ba['act_layer'] is not None
        ba['norm_layer'] = self.norm_layer
        ba['drop_path_rate'] = drop_path_rate
        if bt != 'cn':
            se_ratio = ba.pop('se_ratio')
            if se_ratio and self.se_layer is not None:
                if not self.se_from_exp:
                    # adjust se_ratio by expansion ratio if calculating se channels from block input
                    se_ratio /= ba.get('exp_ratio', 1.0)
                if self.se_has_ratio:
                    ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio)
                else:
                    ba['se_layer'] = self.se_layer

        if self.aa_layer is not None:
            ba['aa_layer'] = self.aa_layer

        se_ratio = ba.pop('se_ratio', None)
        if se_ratio and self.se_layer is not None:
            if not self.se_from_exp:
                # adjust se_ratio by expansion ratio if calculating se channels from block input
                se_ratio /= ba.get('exp_ratio', 1.0)
            if s2d == 1:
                # adjust for start of space2depth
                se_ratio /= 4
            if self.se_has_ratio:
                ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio)
            else:
                ba['se_layer'] = self.se_layer

        if bt == 'ir':
            _log_info_if('  InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
@@ -349,8 +408,17 @@ class EfficientNetBuilder:
        elif bt == 'cn':
            _log_info_if('  ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            block = ConvBnAct(**ba)
        elif bt == 'uir':
            _log_info_if('  UniversalInvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            block = UniversalInvertedResidual(**ba, layer_scale_init_value=self.layer_scale_init_value)
        elif bt == 'mqa':
            _log_info_if('  MobileMultiQueryAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            block = MobileAttention(**ba, use_multi_query=True, layer_scale_init_value=self.layer_scale_init_value)
        elif bt == 'mha':
            _log_info_if('  MobileMultiHeadAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            block = MobileAttention(**ba, layer_scale_init_value=self.layer_scale_init_value)
        else:
            assert False, 'Uknkown block type (%s) while building model.' % bt
            assert False, 'Unknown block type (%s) while building model.' % bt

        self.in_chs = ba['out_chs']  # update in_chs for arg of next block
        return block
@ -377,6 +445,7 @@ class EfficientNetBuilder:
|
||||
self.features.append(feature_info)
|
||||
|
||||
# outer list of block_args defines the stacks
|
||||
space2depth = 0
|
||||
for stack_idx, stack_args in enumerate(model_block_args):
|
||||
last_stack = stack_idx + 1 == len(model_block_args)
|
||||
_log_info_if('Stack: {}'.format(stack_idx), self.verbose)
|
||||
@ -392,6 +461,20 @@ class EfficientNetBuilder:
|
||||
if block_idx >= 1: # only the first block in any stack can have a stride > 1
|
||||
block_args['stride'] = 1
|
||||
|
||||
if not space2depth and block_args.pop('s2d', False):
|
||||
assert block_args['stride'] == 1
|
||||
space2depth = 1
|
||||
|
||||
if space2depth > 0:
|
||||
# FIXME s2d is a WIP
|
||||
if space2depth == 2 and block_args['stride'] == 2:
|
||||
block_args['stride'] = 1
|
||||
# to end s2d region, need to correct expansion and se ratio relative to input
|
||||
block_args['exp_ratio'] /= 4
|
||||
space2depth = 0
|
||||
else:
|
||||
block_args['s2d'] = space2depth
|
||||
|
||||
extract_features = False
|
||||
if last_block:
|
||||
next_stack_idx = stack_idx + 1
|
||||
@ -416,6 +499,9 @@ class EfficientNetBuilder:
|
||||
block = self._make_block(block_args, total_block_idx, total_block_count)
|
||||
blocks.append(block)
|
||||
|
||||
if space2depth == 1:
|
||||
space2depth = 2
|
||||
|
||||
# stash feature module name and channel info for model feature extraction
|
||||
if extract_features:
|
||||
feature_info = dict(
|
||||
|
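An aside on the s2d bookkeeping in the builder hunks above (not part of the diff): space-to-depth repacks each 2x2 spatial patch into channels, which is why out_chs is multiplied by 4 while a region is active and exp_ratio / se_ratio are divided by 4 so they keep meaning the same absolute width. A minimal sketch of the underlying transform:

import torch
import torch.nn.functional as F

x = torch.randn(1, 16, 64, 64)     # (B, C, H, W)
y = F.pixel_unshuffle(x, 2)        # space-to-depth with block size 2
assert y.shape == (1, 64, 32, 32)  # channels x4, spatial /2
# a ratio defined against C=16 must be divided by 4 to describe the same
# absolute channel count against C=64, matching se_ratio /= 4 and the
# exp_ratio /= 4 correction when the s2d region ends
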
@ -591,7 +591,7 @@ default_cfgs = generate_default_cfgs({
})


def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True):
def checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True):
state_dict = state_dict.get('model', state_dict)
state_dict = state_dict.get('module', state_dict)
# beit v2 didn't strip module
@ -637,7 +637,7 @@ def _create_beit(variant, pretrained=False, **kwargs):
out_indices = kwargs.pop('out_indices', 3)
model = build_model_with_cfg(
Beit, variant, pretrained,
pretrained_filter_fn=_beit_checkpoint_filter_fn,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)

@ -556,7 +556,7 @@ class EfficientFormer(nn.Module):
return x


def _checkpoint_filter_fn(state_dict, model):
def checkpoint_filter_fn(state_dict, model):
""" Remap original checkpoints -> timm """
if 'stem.0.weight' in state_dict:
return state_dict # non-original checkpoint, no remapping needed
@ -611,7 +611,7 @@ def _create_efficientformer(variant, pretrained=False, **kwargs):
out_indices = kwargs.pop('out_indices', 4)
model = build_model_with_cfg(
EfficientFormer, variant, pretrained,
pretrained_filter_fn=_checkpoint_filter_fn,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)

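For context on the beit / efficientformer renames above: build_model_with_cfg applies pretrained_filter_fn to a checkpoint before load_state_dict, and exposing the remapper under the consistent public name checkpoint_filter_fn also lets it be called directly, e.g. to convert an original checkpoint offline. A hedged sketch (checkpoint path hypothetical):

import torch
import timm
from timm.models.efficientformer import checkpoint_filter_fn  # public name per this commit

model = timm.create_model('efficientformer_l1', pretrained=False)
state_dict = torch.load('original_ckpt.pth', map_location='cpu')  # hypothetical path
model.load_state_dict(checkpoint_filter_fn(state_dict, model))
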
@ -36,7 +36,7 @@ the models and weights open source!
Hacked together by / Copyright 2019, Ross Wightman
"""
from functools import partial
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
@ -44,10 +44,10 @@ import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct
from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct, LayerType
from ._builder import build_model_with_cfg, pretrained_cfg_for_features
from ._efficientnet_blocks import SqueezeExcite
from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from ._features import FeatureInfo, FeatureHooks, feature_take_indices
from ._manipulate import checkpoint_seq
@ -74,21 +74,22 @@ class EfficientNet(nn.Module):

def __init__(
self,
block_args,
num_classes=1000,
num_features=1280,
in_chans=3,
stem_size=32,
fix_stem=False,
output_stride=32,
pad_type='',
round_chs_fn=round_channels,
act_layer=None,
norm_layer=None,
se_layer=None,
drop_rate=0.,
drop_path_rate=0.,
global_pool='avg'
block_args: BlockArgs,
num_classes: int = 1000,
num_features: int = 1280,
in_chans: int = 3,
stem_size: int = 32,
fix_stem: bool = False,
output_stride: int = 32,
pad_type: str = '',
act_layer: Optional[LayerType] = None,
norm_layer: Optional[LayerType] = None,
aa_layer: Optional[LayerType] = None,
se_layer: Optional[LayerType] = None,
round_chs_fn: Callable = round_channels,
drop_rate: float = 0.,
drop_path_rate: float = 0.,
global_pool: str = 'avg'
):
super(EfficientNet, self).__init__()
act_layer = act_layer or nn.ReLU
@ -113,6 +114,7 @@ class EfficientNet(nn.Module):
round_chs_fn=round_chs_fn,
act_layer=act_layer,
norm_layer=norm_layer,
aa_layer=aa_layer,
se_layer=se_layer,
drop_path_rate=drop_path_rate,
)
@ -270,20 +272,21 @@ class EfficientNetFeatures(nn.Module):

def __init__(
self,
block_args,
out_indices=(0, 1, 2, 3, 4),
feature_location='bottleneck',
in_chans=3,
stem_size=32,
fix_stem=False,
output_stride=32,
pad_type='',
round_chs_fn=round_channels,
act_layer=None,
norm_layer=None,
se_layer=None,
drop_rate=0.,
drop_path_rate=0.
block_args: BlockArgs,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
feature_location: str = 'bottleneck',
in_chans: int = 3,
stem_size: int = 32,
fix_stem: bool = False,
output_stride: int = 32,
pad_type: str = '',
act_layer: Optional[LayerType] = None,
norm_layer: Optional[LayerType] = None,
aa_layer: Optional[LayerType] = None,
se_layer: Optional[LayerType] = None,
round_chs_fn: Callable = round_channels,
drop_rate: float = 0.,
drop_path_rate: float = 0.,
):
super(EfficientNetFeatures, self).__init__()
act_layer = act_layer or nn.ReLU
@ -306,6 +309,7 @@ class EfficientNetFeatures(nn.Module):
round_chs_fn=round_chs_fn,
act_layer=act_layer,
norm_layer=norm_layer,
aa_layer=aa_layer,
se_layer=se_layer,
drop_path_rate=drop_path_rate,
feature_location=feature_location,
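The new aa_layer argument now threads from EfficientNet / EfficientNetFeatures into EfficientNetBuilder, so an anti-aliased variant can be requested at construction time. A hedged sketch, assuming kwargs forwarded by create_model reach EfficientNet.__init__ (the 'blurpc' string follows the usage in efficientnet_blur_b0 further down):

import timm

model = timm.create_model('efficientnet_b0', pretrained=False, aa_layer='blurpc')
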
@ -879,6 +883,88 @@ def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0
return model


def _gen_efficientnet_x(
variant, channel_multiplier=1.0, depth_multiplier=1.0, channel_divisor=8,
group_size=None, version=1, pretrained=False, **kwargs):
"""Creates an EfficientNet model.

Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
Paper: https://arxiv.org/abs/1905.11946

EfficientNet params
name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
'efficientnet-x-b0': (1.0, 1.0, 224, 0.2),
'efficientnet-x-b1': (1.0, 1.1, 240, 0.2),
'efficientnet-x-b2': (1.1, 1.2, 260, 0.3),
'efficientnet-x-b3': (1.2, 1.4, 300, 0.3),
'efficientnet-x-b4': (1.4, 1.8, 380, 0.4),
'efficientnet-x-b5': (1.6, 2.2, 456, 0.4),
'efficientnet-x-b6': (1.8, 2.6, 528, 0.5),
'efficientnet-x-b7': (2.0, 3.1, 600, 0.5),
'efficientnet-x-b8': (2.2, 3.6, 672, 0.5),
'efficientnet-l2': (4.3, 5.3, 800, 0.5),

Args:
channel_multiplier: multiplier to number of channels per layer
depth_multiplier: multiplier to number of repeats per stage

"""
"""
if version == 1:
blocks_args = [
'r1_k3_s11_e1_i32_o16_se0.25_d1_a0',
'r2_k3_s22_e6_i16_o24_se0.25_f1_d2_a1',
'r2_k5_s22_e6_i24_o40_se0.25_f1_a1',
'r3_k3_s22_e6_i40_o80_se0.25_a0',
'r3_k5_s11_e6_i80_o112_se0.25_a0',
'r4_k5_s22_e6_i112_o192_se0.25_a0',
'r1_k3_s11_e6_i192_o320_se0.25_a0',
]
elif version == 2:
blocks_args = [
'r1_k3_s11_e1_i32_o16_se0.25_d1_a0',
'r2_k3_s22_e4_i16_o24_se0.25_f1_d2_a1',
'r2_k5_s22_e4_i24_o40_se0.25_f1_a1',
'r3_k3_s22_e4_i40_o80_se0.25_a0',
'r3_k5_s11_e6_i80_o112_se0.25_a0',
'r4_k5_s22_e6_i112_o192_se0.25_a0',
'r1_k3_s11_e6_i192_o320_se0.25_a0',
]
"""
if version == 1:
arch_def = [
['ds_r1_k3_s1_e1_c16_se0.25_d1'],
['er_r2_k3_s2_e6_c24_se0.25_nre'],
['er_r2_k5_s2_e6_c40_se0.25_nre'],
['ir_r3_k3_s2_e6_c80_se0.25'],
['ir_r3_k5_s1_e6_c112_se0.25'],
['ir_r4_k5_s2_e6_c192_se0.25'],
['ir_r1_k3_s1_e6_c320_se0.25'],
]
else:
arch_def = [
['ds_r1_k3_s1_e1_c16_se0.25_d1'],
['er_r2_k3_s2_e4_c24_se0.25_nre'],
['er_r2_k5_s2_e4_c40_se0.25_nre'],
['ir_r3_k3_s2_e4_c80_se0.25'],
['ir_r3_k5_s1_e6_c112_se0.25'],
['ir_r4_k5_s2_e6_c192_se0.25'],
['ir_r1_k3_s1_e6_c320_se0.25'],
]
round_chs_fn = partial(round_channels, multiplier=channel_multiplier, divisor=channel_divisor)
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
num_features=round_chs_fn(1280),
stem_size=32,
round_chs_fn=round_chs_fn,
act_layer=resolve_act_layer(kwargs, 'silu'),
norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
**kwargs,
)
model = _create_effnet(variant, pretrained, **model_kwargs)
return model


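To make the channel rounding above concrete: num_features=round_chs_fn(1280) scales the head width by channel_multiplier and snaps to channel_divisor. A small check (import path assumed from this commit's layout):

from timm.models._efficientnet_builder import round_channels

assert round_channels(1280, multiplier=1.2, divisor=8) == 1536  # 1280 * 1.2 = 1536, already divisible by 8
assert round_channels(40, multiplier=1.2, divisor=8) == 48      # 40 * 1.2 = 48, likewise a multiple of 8
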
def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a MixNet Small model.

@ -1072,6 +1158,7 @@ default_cfgs = generate_default_cfgs({
input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0),
'efficientnet_b3_g8_gn.untrained': _cfg(
input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0),
'efficientnet_blur_b0.untrained': _cfg(),

'efficientnet_es.ra_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth',
@ -1768,6 +1855,17 @@ def efficientnet_b3_g8_gn(pretrained=False, **kwargs) -> EfficientNet:
return model


@register_model
def efficientnet_blur_b0(pretrained=False, **kwargs) -> EfficientNet:
""" EfficientNet-B0 w/ BlurPool """
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
model = _gen_efficientnet(
'efficientnet_blur_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained,
aa_layer='blurpc', **kwargs
)
return model


@register_model
def efficientnet_es(pretrained=False, **kwargs) -> EfficientNet:
""" EfficientNet-Edge Small. """
@ -2277,6 +2375,31 @@ def tf_efficientnetv2_b3(pretrained=False, **kwargs) -> EfficientNet:
return model


@register_model
def efficientnet_x_b3(pretrained=False, **kwargs) -> EfficientNet:
""" EfficientNet-B3 """
# NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
model = _gen_efficientnet_x(
'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
return model


@register_model
def efficientnet_x_b5(pretrained=False, **kwargs) -> EfficientNet:
""" EfficientNet-B5 """
model = _gen_efficientnet_x(
'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
return model


@register_model
def efficientnet_h_b5(pretrained=False, **kwargs) -> EfficientNet:
""" EfficientNet-B5 """
model = _gen_efficientnet_x(
'efficientnet_b5', channel_multiplier=1.92, depth_multiplier=2.2, version=2, pretrained=pretrained, **kwargs)
return model


@register_model
def mixnet_s(pretrained=False, **kwargs) -> EfficientNet:
"""Creates a MixNet Small model.

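The new blur / X / H EfficientNet variants above register without weight URLs (only '.untrained' cfgs), so they currently build untrained. A quick smoke-test sketch:

import timm
import torch

model = timm.create_model('efficientnet_x_b3', pretrained=False).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 300, 300))  # 300x300 per the B3 row in the params table above
print(out.shape)  # expected torch.Size([1, 1000]) with the default classifier
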
@ -7,7 +7,7 @@
#
import os
from functools import partial
from typing import Tuple, Optional, Union
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn, \
ClassifierHead
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs

@ -40,19 +41,19 @@ class MobileOneBlock(nn.Module):
"""

def __init__(
self,
in_chs: int,
out_chs: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
group_size: int = 0,
inference_mode: bool = False,
use_se: bool = False,
use_act: bool = True,
use_scale_branch: bool = True,
num_conv_branches: int = 1,
act_layer: nn.Module = nn.GELU,
self,
in_chs: int,
out_chs: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
group_size: int = 0,
inference_mode: bool = False,
use_se: bool = False,
use_act: bool = True,
use_scale_branch: bool = True,
num_conv_branches: int = 1,
act_layer: nn.Module = nn.GELU,
) -> None:
"""Construct a MobileOneBlock module.

@ -280,15 +281,16 @@ class ReparamLargeKernelConv(nn.Module):
"""

def __init__(
self,
in_chs: int,
out_chs: int,
kernel_size: int,
stride: int,
group_size: int,
small_kernel: Optional[int] = None,
inference_mode: bool = False,
act_layer: Optional[nn.Module] = None,
self,
in_chs: int,
out_chs: int,
kernel_size: int,
stride: int,
group_size: int,
small_kernel: Optional[int] = None,
use_se: bool = False,
act_layer: Optional[nn.Module] = None,
inference_mode: bool = False,
) -> None:
"""Construct a ReparamLargeKernelConv module.

@ -299,8 +301,8 @@ class ReparamLargeKernelConv(nn.Module):
stride: Stride size. Default: 1
group_size: Group size. Default: 1
small_kernel: Kernel size of small kernel conv branch.
inference_mode: If True, instantiates model in inference mode. Default: ``False``
act_layer: Activation module. Default: ``nn.GELU``
inference_mode: If True, instantiates model in inference mode. Default: ``False``
"""
super(ReparamLargeKernelConv, self).__init__()
self.stride = stride
@ -342,6 +344,7 @@ class ReparamLargeKernelConv(nn.Module):
groups=self.groups,
apply_act=False,
)
self.se = SqueezeExcite(out_chs, rd_ratio=0.25) if use_se else nn.Identity()
# FIXME output of this act was not used in original impl, likely due to bug
self.act = act_layer() if act_layer is not None else nn.Identity()

@ -352,6 +355,7 @@ class ReparamLargeKernelConv(nn.Module):
out = self.large_conv(x)
if self.small_conv is not None:
out = out + self.small_conv(x)
out = self.se(out)
out = self.act(out)
return out

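Background on why the small-kernel branch in ReparamLargeKernelConv can be merged away at inference (standard structural reparameterization, not shown in this diff): convolution is linear in its weights, so zero-padding the small kernel to the large kernel's size and summing gives one equivalent conv. A minimal sketch (the real blocks also fold their BatchNorms into the weights first):

import torch
import torch.nn.functional as F

large = torch.randn(8, 8, 7, 7)   # (out, in, K, K) large-kernel weights
small = torch.randn(8, 8, 3, 3)   # small-kernel branch weights
pad = (7 - 3) // 2
merged = large + F.pad(small, [pad] * 4)  # zero-pad 3x3 up to 7x7, then add

x = torch.randn(1, 8, 32, 32)
y_two_branch = F.conv2d(x, large, padding=3) + F.conv2d(x, small, padding=1)
y_merged = F.conv2d(x, merged, padding=3)
assert torch.allclose(y_two_branch, y_merged, atol=1e-4)
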
@ -472,12 +476,12 @@ class Attention(nn.Module):
fused_attn: torch.jit.Final[bool]

def __init__(
self,
dim: int,
head_dim: int = 32,
qkv_bias: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
self,
dim: int,
head_dim: int = 32,
qkv_bias: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
) -> None:
"""Build MHSA module that can handle 3D or 4D input tensors.

@ -535,14 +539,15 @@ class PatchEmbed(nn.Module):
"""Convolutional patch embedding layer."""

def __init__(
self,
patch_size: int,
stride: int,
in_chs: int,
embed_dim: int,
act_layer: nn.Module = nn.GELU,
lkc_use_act: bool = False,
inference_mode: bool = False,
self,
patch_size: int,
stride: int,
in_chs: int,
embed_dim: int,
act_layer: nn.Module = nn.GELU,
lkc_use_act: bool = False,
use_se: bool = False,
inference_mode: bool = False,
) -> None:
"""Build patch embedding layer.

@ -562,14 +567,16 @@ class PatchEmbed(nn.Module):
stride=stride,
group_size=1,
small_kernel=3,
inference_mode=inference_mode,
use_se=use_se,
act_layer=act_layer if lkc_use_act else None, # NOTE original weights didn't use this act
inference_mode=inference_mode,
),
MobileOneBlock(
in_chs=embed_dim,
out_chs=embed_dim,
kernel_size=1,
stride=1,
use_se=False,
act_layer=act_layer,
inference_mode=inference_mode,
)
@ -598,11 +605,11 @@ class RepMixer(nn.Module):
"""

def __init__(
self,
dim,
kernel_size=3,
layer_scale_init_value=1e-5,
inference_mode: bool = False,
self,
dim,
kernel_size=3,
layer_scale_init_value=1e-5,
inference_mode: bool = False,
):
"""Build RepMixer Module.

@ -648,7 +655,7 @@ class RepMixer(nn.Module):
if layer_scale_init_value is not None:
self.layer_scale = LayerScale2d(dim, layer_scale_init_value)
else:
self.layer_scale = nn.Identity
self.layer_scale = nn.Identity()

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.reparam_conv is not None:
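The one-character RepMixer fix above is worth unpacking: the old code assigned the class nn.Identity itself, so self.layer_scale(x) constructed a fresh Identity module (nn.Identity.__init__ swallows *args) instead of returning x. A demo of the difference:

import torch
import torch.nn as nn

x = torch.randn(2, 3)
bad = nn.Identity      # the class, as in the old code
good = nn.Identity()   # an instance, as in the fix

print(type(bad(x)))             # an nn.Identity module, not the input tensor!
print(torch.equal(good(x), x))  # True
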
@ -706,12 +713,12 @@ class ConvMlp(nn.Module):
"""Convolutional FFN Module."""

def __init__(
self,
in_chs: int,
hidden_channels: Optional[int] = None,
out_chs: Optional[int] = None,
act_layer: nn.Module = nn.GELU,
drop: float = 0.0,
self,
in_chs: int,
hidden_channels: Optional[int] = None,
out_chs: Optional[int] = None,
act_layer: nn.Module = nn.GELU,
drop: float = 0.0,
) -> None:
"""Build convolutional FFN module.

@ -764,11 +771,11 @@ class RepConditionalPosEnc(nn.Module):
"""

def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
inference_mode=False,
self,
dim: int,
dim_out: Optional[int] = None,
spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
inference_mode=False,
) -> None:
"""Build reparameterizable conditional positional encoding

@ -878,15 +885,15 @@ class RepMixerBlock(nn.Module):
"""

def __init__(
self,
dim: int,
kernel_size: int = 3,
mlp_ratio: float = 4.0,
act_layer: nn.Module = nn.GELU,
proj_drop: float = 0.0,
drop_path: float = 0.0,
layer_scale_init_value: float = 1e-5,
inference_mode: bool = False,
self,
dim: int,
kernel_size: int = 3,
mlp_ratio: float = 4.0,
act_layer: nn.Module = nn.GELU,
proj_drop: float = 0.0,
drop_path: float = 0.0,
layer_scale_init_value: float = 1e-5,
inference_mode: bool = False,
):
"""Build RepMixer Block.

@ -936,14 +943,14 @@ class AttentionBlock(nn.Module):
"""

def __init__(
self,
dim: int,
mlp_ratio: float = 4.0,
act_layer: nn.Module = nn.GELU,
norm_layer: nn.Module = nn.BatchNorm2d,
proj_drop: float = 0.0,
drop_path: float = 0.0,
layer_scale_init_value: float = 1e-5,
self,
dim: int,
mlp_ratio: float = 4.0,
act_layer: nn.Module = nn.GELU,
norm_layer: nn.Module = nn.BatchNorm2d,
proj_drop: float = 0.0,
drop_path: float = 0.0,
layer_scale_init_value: float = 1e-5,
):
"""Build Attention Block.

@ -993,6 +1000,7 @@ class FastVitStage(nn.Module):
depth: int,
token_mixer_type: str,
downsample: bool = True,
se_downsample: bool = False,
down_patch_size: int = 7,
down_stride: int = 2,
pos_emb_layer: Optional[nn.Module] = None,
@ -1030,6 +1038,7 @@ class FastVitStage(nn.Module):
stride=down_stride,
in_chs=dim,
embed_dim=dim_out,
use_se=se_downsample,
act_layer=act_layer,
lkc_use_act=lkc_use_act,
inference_mode=inference_mode,
@ -1090,29 +1099,30 @@ class FastVit(nn.Module):
"""

def __init__(
self,
in_chans: int = 3,
layers: Tuple[int, ...] = (2, 2, 6, 2),
token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"),
embed_dims: Tuple[int, ...] = (64, 128, 256, 512),
mlp_ratios: Tuple[float, ...] = (4,) * 4,
downsamples: Tuple[bool, ...] = (False, True, True, True),
repmixer_kernel_size: int = 3,
num_classes: int = 1000,
pos_embs: Tuple[Optional[nn.Module], ...] = (None,) * 4,
down_patch_size: int = 7,
down_stride: int = 2,
drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
layer_scale_init_value: float = 1e-5,
fork_feat: bool = False,
cls_ratio: float = 2.0,
global_pool: str = 'avg',
norm_layer: nn.Module = nn.BatchNorm2d,
act_layer: nn.Module = nn.GELU,
lkc_use_act: bool = False,
inference_mode: bool = False,
self,
in_chans: int = 3,
layers: Tuple[int, ...] = (2, 2, 6, 2),
token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"),
embed_dims: Tuple[int, ...] = (64, 128, 256, 512),
mlp_ratios: Tuple[float, ...] = (4,) * 4,
downsamples: Tuple[bool, ...] = (False, True, True, True),
se_downsamples: Tuple[bool, ...] = (False, False, False, False),
repmixer_kernel_size: int = 3,
num_classes: int = 1000,
pos_embs: Tuple[Optional[nn.Module], ...] = (None,) * 4,
down_patch_size: int = 7,
down_stride: int = 2,
drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
layer_scale_init_value: float = 1e-5,
lkc_use_act: bool = False,
fork_feat: bool = False,
cls_ratio: float = 2.0,
global_pool: str = 'avg',
norm_layer: nn.Module = nn.BatchNorm2d,
act_layer: nn.Module = nn.GELU,
inference_mode: bool = False,
) -> None:
super().__init__()
self.num_classes = 0 if fork_feat else num_classes
@ -1140,6 +1150,7 @@ class FastVit(nn.Module):
dim_out=embed_dims[i],
depth=layers[i],
downsample=downsample,
se_downsample=se_downsamples[i],
down_patch_size=down_patch_size,
down_stride=down_stride,
pos_emb_layer=pos_embs[i],
@ -1160,6 +1171,7 @@ class FastVit(nn.Module):
scale *= 2
self.feature_info += [dict(num_chs=prev_dim, reduction=4 * scale, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
self.num_stages = len(self.stages)
self.num_features = prev_dim

# For segmentation and detection, extract intermediate output
@ -1236,6 +1248,66 @@ class FastVit(nn.Module):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)

def forward_intermediates(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
""" Forward features that returns intermediates.

Args:
x: Input image tensor
indices: Take last n blocks if int, all if None, select matching indices if sequence
norm: Apply norm layer to compatible intermediates
stop_early: Stop iterating over blocks when last desired intermediate hit
output_fmt: Shape of intermediate feature outputs
intermediates_only: Only return intermediate features
Returns:

"""
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
intermediates = []
take_indices, max_index = feature_take_indices(len(self.stages), indices)

# forward pass
x = self.stem(x)
last_idx = self.num_stages - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
feat_idx = 0
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx in take_indices:
intermediates.append(x)

if intermediates_only:
return intermediates

if feat_idx == last_idx:
x = self.final_conv(x)

return x, intermediates

def prune_intermediate_layers(
self,
indices: Union[int, List[int], Tuple[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_head:
self.reset_classifier(0, '')
return take_indices

def forward_features(self, x: torch.Tensor) -> torch.Tensor:
# input embedding
x = self.stem(x)
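Hypothetical usage of the new forward_intermediates API, mirroring how the other timm models that gained it are driven:

import timm
import torch

model = timm.create_model('fastvit_t8', pretrained=False).eval()
x = torch.randn(1, 3, 256, 256)

# grab the feature maps from the last two stages only
feats = model.forward_intermediates(x, indices=2, intermediates_only=True)
for f in feats:
    print(f.shape)  # NCHW tensors at decreasing resolution
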
@ -1297,8 +1369,7 @@ default_cfgs = generate_default_cfgs({

"fastvit_ma36.apple_in1k": _cfg(
hf_hub_id='timm/',
crop_pct=0.95
),
crop_pct=0.95),

"fastvit_t8.apple_dist_in1k": _cfg(
hf_hub_id='timm/'),
@ -1318,15 +1389,111 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='timm/',
crop_pct=0.95
),

"fastvit_mci0.apple_mclip": _cfg(
#hf_hub_id='timm/',
url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s0.pt',
crop_pct=0.95,
num_classes=512, # CLIP proj dim
mean=(0., 0., 0.), std=(1., 1., 1.)
),
"fastvit_mci1.apple_mclip": _cfg(
# hf_hub_id='timm/',
url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s1.pt',
crop_pct=0.95,
num_classes=512, # CLIP proj dim
mean=(0., 0., 0.), std=(1., 1., 1.)
),
"fastvit_mci2.apple_mclip": _cfg(
# hf_hub_id='timm/',
url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s2.pt',
crop_pct=0.95,
num_classes=512, # CLIP proj dim
mean=(0., 0., 0.), std=(1., 1., 1.)
),
})


def checkpoint_filter_fn(state_dict, model):
""" Remap original checkpoints -> timm """
if 'stem.0.conv_kxk.0.conv.weight' in state_dict:
return state_dict # non-original checkpoint, no remapping needed

state_dict = state_dict.get('state_dict', state_dict)
if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
# remap MobileCLIP checkpoints
prefix = 'image_encoder.model.'
else:
prefix = ''

import re
import bisect

# find stage ends by locating downsample layers
stage_ends = []
for k, v in state_dict.items():
match = re.match(r'^(.*?)network\.(\d+)\.proj.*', k)
if match:
stage_ends.append(int(match.group(2)))
stage_ends = list(sorted(set(stage_ends)))

out_dict = {}
for k, v in state_dict.items():
if prefix:
if prefix not in k:
continue
k = k.replace(prefix, '')

# remap renamed layers
k = k.replace('patch_embed', 'stem')
k = k.replace('rbr_conv', 'conv_kxk')
k = k.replace('rbr_scale', 'conv_scale')
k = k.replace('rbr_skip', 'identity')
k = k.replace('conv_exp', 'final_conv') # to match byobnet, regnet, nfnet
k = k.replace('lkb_origin', 'large_conv')
k = k.replace('convffn', 'mlp')
k = k.replace('se.reduce', 'se.fc1')
k = k.replace('se.expand', 'se.fc2')
k = re.sub(r'layer_scale_([0-9])', r'layer_scale_\1.gamma', k)
if k.endswith('layer_scale'):
k = k.replace('layer_scale', 'layer_scale.gamma')
k = k.replace('dist_head', 'head_dist')
if k.startswith('head.'):
if k == 'head.proj' and hasattr(model.head, 'fc') and isinstance(model.head.fc, nn.Linear):
# if CLIP projection, map to head.fc w/ bias = zeros
k = k.replace('head.proj', 'head.fc.weight')
v = v.T
out_dict['head.fc.bias'] = torch.zeros(v.shape[0])
else:
k = k.replace('head.', 'head.fc.')

# remap flat sequential network to stages
match = re.match(r'^network\.(\d+)', k)
stage_idx, net_idx = None, None
if match:
net_idx = int(match.group(1))
stage_idx = bisect.bisect_right(stage_ends, net_idx)
if stage_idx is not None:
net_prefix = f'network.{net_idx}'
stage_prefix = f'stages.{stage_idx}'
if net_prefix + '.proj' in k:
k = k.replace(net_prefix + '.proj', stage_prefix + '.downsample.proj')
elif net_prefix + '.pe' in k:
k = k.replace(net_prefix + '.pe', stage_prefix + '.pos_emb.pos_enc')
else:
k = k.replace(net_prefix, stage_prefix + '.blocks')

out_dict[k] = v
return out_dict


def _create_fastvit(variant, pretrained=False, **kwargs):
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
model = build_model_with_cfg(
FastVit,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs
)
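The stage remap above leans on a small trick: downsample ('proj') layers mark stage boundaries in the flat network.N indexing, and bisect_right converts a flat index into its stage number (a proj module belongs to the stage it opens). In isolation, with hypothetical indices:

import bisect

stage_ends = [1, 3, 5]  # hypothetical flat indices of downsample/proj modules
for net_idx in range(7):
    stage_idx = bisect.bisect_right(stage_ends, net_idx)
    print(f'network.{net_idx} -> stages.{stage_idx}')
# indices at or past a proj land in the next stage; everything before stays in stage 0
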
@ -1419,3 +1586,48 @@ def fastvit_ma36(pretrained=False, **kwargs):
token_mixers=("repmixer", "repmixer", "repmixer", "attention")
)
return _create_fastvit('fastvit_ma36', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def fastvit_mci0(pretrained=False, **kwargs):
"""Instantiate MCi0 model variant."""
model_args = dict(
layers=(2, 6, 10, 2),
embed_dims=(64, 128, 256, 512),
mlp_ratios=(3, 3, 3, 3),
se_downsamples=(False, False, True, True),
pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
lkc_use_act=True,
)
return _create_fastvit('fastvit_mci0', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def fastvit_mci1(pretrained=False, **kwargs):
"""Instantiate MCi1 model variant."""
model_args = dict(
layers=(4, 12, 20, 4),
embed_dims=(64, 128, 256, 512),
mlp_ratios=(3, 3, 3, 3),
se_downsamples=(False, False, True, True),
pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
lkc_use_act=True,
)
return _create_fastvit('fastvit_mci1', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def fastvit_mci2(pretrained=False, **kwargs):
"""Instantiate MCi2 model variant."""
model_args = dict(
layers=(4, 12, 24, 4),
embed_dims=(80, 160, 320, 640),
mlp_ratios=(3, 3, 3, 3),
se_downsamples=(False, False, True, True),
pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))),
token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
lkc_use_act=True,
)
return _create_fastvit('fastvit_mci2', pretrained=pretrained, **dict(model_args, **kwargs))

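One more note on the MobileCLIP remap in the fastvit checkpoint_filter_fn above: a CLIP image projection is stored as a bare (width, proj_dim) matrix applied as x @ proj, while nn.Linear stores (out_features, in_features) and adds a bias, hence the transpose and the synthesized zero bias. Equivalence sketch (sizes hypothetical):

import torch
import torch.nn as nn

width, proj_dim = 512, 256
proj = torch.randn(width, proj_dim)  # checkpoint applies x @ proj
x = torch.randn(4, width)

fc = nn.Linear(width, proj_dim)
fc.weight.data = proj.T        # the v.T in the filter fn
fc.bias.data.zero_()           # the zeros bias it synthesizes
assert torch.allclose(x @ proj, fc(x), atol=1e-6)
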
@ -40,6 +40,7 @@ class MobileNetV3(nn.Module):
* HardCoRe-NAS - https://arxiv.org/abs/2102.11646 (defn in hardcorenas.py uses this class)
* FBNet-V3 - https://arxiv.org/abs/2006.02049
* LCNet - https://arxiv.org/abs/2109.15099
* MobileNet-V4 - https://arxiv.org/abs/2404.10518
"""

def __init__(
@ -51,14 +52,17 @@ class MobileNetV3(nn.Module):
fix_stem: bool = False,
num_features: int = 1280,
head_bias: bool = True,
pad_type: PadType = '',
head_norm: bool = False,
pad_type: str = '',
act_layer: Optional[LayerType] = None,
norm_layer: Optional[LayerType] = None,
aa_layer: Optional[LayerType] = None,
se_layer: Optional[LayerType] = None,
se_from_exp: bool = True,
round_chs_fn: Callable = round_channels,
drop_rate: float = 0.,
drop_path_rate: float = 0.,
layer_scale_init_value: Optional[float] = None,
global_pool: str = 'avg',
):
"""
@ -73,11 +77,13 @@ class MobileNetV3(nn.Module):
pad_type: Type of padding to use for convolution layers.
act_layer: Type of activation layer.
norm_layer: Type of normalization layer.
aa_layer: Type of anti-aliasing layer.
se_layer: Type of Squeeze-and-Excite layer.
se_from_exp: If True, calculate SE channel reduction from expanded mid channels.
round_chs_fn: Callable to round number of filters based on depth multiplier.
drop_rate: Dropout rate.
drop_path_rate: Stochastic depth rate.
layer_scale_init_value: Enable layer scale on compatible blocks if not None.
global_pool: Type of pooling to use for global pooling features of the FC head.
"""
super(MobileNetV3, self).__init__()
@ -104,8 +110,10 @@ class MobileNetV3(nn.Module):
se_from_exp=se_from_exp,
act_layer=act_layer,
norm_layer=norm_layer,
aa_layer=aa_layer,
se_layer=se_layer,
drop_path_rate=drop_path_rate,
layer_scale_init_value=layer_scale_init_value,
)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = builder.features
@ -115,8 +123,16 @@ class MobileNetV3(nn.Module):
# Head + Pooling
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
num_pooled_chs = head_chs * self.global_pool.feat_mult()
self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
self.act2 = act_layer(inplace=True)
if head_norm:
# mobilenet-v4 post-pooling PW conv is followed by a norm+act layer
self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type) # never bias
self.norm_head = norm_act_layer(self.num_features)
self.act2 = nn.Identity()
else:
# mobilenet-v3 and others only have an activation after final PW conv
self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
self.norm_head = nn.Identity()
self.act2 = act_layer(inplace=True)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

@ -125,7 +141,7 @@ class MobileNetV3(nn.Module):
def as_sequential(self):
layers = [self.conv_stem, self.bn1]
layers.extend(self.blocks)
layers.extend([self.global_pool, self.conv_head, self.act2])
layers.extend([self.global_pool, self.conv_head, self.norm_head, self.act2])
layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
return nn.Sequential(*layers)

@ -224,8 +240,10 @@ class MobileNetV3(nn.Module):
self.blocks = self.blocks[:max_index] # truncate blocks w/ stem as idx 0
if max_index < len(self.blocks):
self.conv_head = nn.Identity()
self.norm_head = nn.Identity()
if prune_head:
self.conv_head = nn.Identity()
self.norm_head = nn.Identity()
self.reset_classifier(0, '')
return take_indices

@ -241,6 +259,7 @@ class MobileNetV3(nn.Module):
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
x = self.global_pool(x)
x = self.conv_head(x)
x = self.norm_head(x)
x = self.act2(x)
x = self.flatten(x)
if pre_logits:
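The head refactor above boils down to two layouts after global pooling; as a plain-module sketch (concrete layer choices here are illustrative assumptions, with norm_act_layer standing in for the resolved norm+act):

import torch.nn as nn

# MobileNet-V3 style: PW conv (with bias) then activation only
head_v3 = nn.Sequential(
    nn.Conv2d(960, 1280, 1, bias=True),
    nn.Hardswish(inplace=True),
)

# MobileNet-V4 style: PW conv (never bias) then norm + activation
head_v4 = nn.Sequential(
    nn.Conv2d(960, 1280, 1, bias=False),
    nn.BatchNorm2d(1280),
    nn.ReLU(inplace=True),
)
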
@ -276,9 +295,11 @@ class MobileNetV3Features(nn.Module):
se_from_exp: bool = True,
act_layer: Optional[LayerType] = None,
norm_layer: Optional[LayerType] = None,
aa_layer: Optional[LayerType] = None,
se_layer: Optional[LayerType] = None,
drop_rate: float = 0.,
drop_path_rate: float = 0.,
layer_scale_init_value: Optional[float] = None,
):
"""
Args:
@ -297,6 +318,7 @@ class MobileNetV3Features(nn.Module):
se_layer: Type of Squeeze-and-Excite layer.
drop_rate: Dropout rate.
drop_path_rate: Stochastic depth rate.
layer_scale_init_value: Enable layer scale on compatible blocks if not None.
"""
super(MobileNetV3Features, self).__init__()
act_layer = act_layer or nn.ReLU
@ -320,8 +342,10 @@ class MobileNetV3Features(nn.Module):
se_from_exp=se_from_exp,
act_layer=act_layer,
norm_layer=norm_layer,
aa_layer=aa_layer,
se_layer=se_layer,
drop_path_rate=drop_path_rate,
layer_scale_init_value=layer_scale_init_value,
feature_location=feature_location,
)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
@ -370,7 +394,7 @@ def _create_mnv3(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV
if 'feature_cfg' in kwargs or 'feature_cls' in kwargs:
features_mode = 'cfg'
else:
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'head_norm', 'global_pool')
model_cls = MobileNetV3Features
features_mode = 'cls'

@ -622,6 +646,252 @@ def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool =
return model


def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
"""Creates a MobileNet-V4 model.

Ref impl: ?
Paper: https://arxiv.org/abs/1905.02244

Args:
channel_multiplier: multiplier to number of channels per layer.
"""
num_features = 1280
if 'hybrid' in variant:
layer_scale_init_value = 1e-5
if 'medium' in variant:
stem_size = 32
act_layer = resolve_act_layer(kwargs, 'relu')
arch_def = [
# stage 0, 112x112 in
[
'er_r1_k3_s2_e4_c48' # FusedIB (EdgeResidual)
],
# stage 1, 56x56 in
[
'uir_r1_a3_k5_s2_e4_c80', # ExtraDW
'uir_r1_a3_k3_s1_e2_c80', # ExtraDW
],
# stage 2, 28x28 in
[
'uir_r1_a3_k5_s2_e6_c160', # ExtraDW
'uir_r1_a0_k0_s1_e2_c160', # FFN
'uir_r1_a3_k3_s1_e4_c160', # ExtraDW
'uir_r1_a3_k5_s1_e4_c160', # ExtraDW
'mqa_r1_k3_h4_s1_v2_d64_c160', # MQA w/ KV downsample
'uir_r1_a3_k3_s1_e4_c160', # ExtraDW
'mqa_r1_k3_h4_s1_v2_d64_c160', # MQA w/ KV downsample
'uir_r1_a3_k0_s1_e4_c160', # ConvNeXt
'mqa_r1_k3_h4_s1_v2_d64_c160', # MQA w/ KV downsample
'uir_r1_a3_k3_s1_e4_c160', # ExtraDW
'mqa_r1_k3_h4_s1_v2_d64_c160', # MQA w/ KV downsample
'uir_r1_a3_k0_s1_e4_c160', # ConvNeXt
],
# stage 3, 14x14in
[
'uir_r1_a5_k5_s2_e6_c256', # ExtraDW
'uir_r1_a5_k5_s1_e4_c256', # ExtraDW
'uir_r2_a3_k5_s1_e4_c256', # ExtraDW
'uir_r1_a0_k0_s1_e2_c256', # FFN
'uir_r1_a3_k5_s1_e2_c256', # ExtraDW
'uir_r1_a0_k0_s1_e2_c256', # FFN
'uir_r1_a0_k0_s1_e4_c256', # FFN
'mqa_r1_k3_h4_s1_d64_c256', # MQA
'uir_r1_a3_k0_s1_e4_c256', # ConvNeXt
'mqa_r1_k3_h4_s1_d64_c256', # MQA
'uir_r1_a5_k5_s1_e4_c256', # ExtraDW
'mqa_r1_k3_h4_s1_d64_c256', # MQA
'uir_r1_a5_k0_s1_e4_c256', # ConvNeXt
'mqa_r1_k3_h4_s1_d64_c256', # MQA
'uir_r1_a5_k0_s1_e4_c256', # ConvNeXt
],
# stage 4, 7x7 in
[
'cn_r1_k1_s1_c960' # Conv
],
]
elif 'large' in variant:
stem_size = 24
act_layer = resolve_act_layer(kwargs, 'gelu')
arch_def = [
# stage 0, 112x112 in
[
'er_r1_k3_s2_e4_c48', # FusedIB (EdgeResidual)
],
# stage 1, 56x56 in
[
'uir_r1_a3_k5_s2_e4_c96', # ExtraDW
'uir_r1_a3_k3_s1_e4_c96', # ExtraDW
],
# stage 2, 28x28 in
[
'uir_r1_a3_k5_s2_e4_c192', # ExtraDW
'uir_r3_a3_k3_s1_e4_c192', # ExtraDW
'uir_r1_a3_k5_s1_e4_c192', # ExtraDW
'uir_r2_a5_k3_s1_e4_c192', # ExtraDW
'mqa_r1_k3_h8_s1_v2_d48_c192', # MQA w/ KV downsample
'uir_r1_a5_k3_s1_e4_c192', # ExtraDW
'mqa_r1_k3_h8_s1_v2_d48_c192', # MQA w/ KV downsample
'uir_r1_a5_k3_s1_e4_c192', # ExtraDW
'mqa_r1_k3_h8_s1_v2_d48_c192', # MQA w/ KV downsample
'uir_r1_a5_k3_s1_e4_c192', # ExtraDW
'mqa_r1_k3_h8_s1_v2_d48_c192', # MQA w/ KV downsample
'uir_r1_a3_k0_s1_e4_c192', # ConvNeXt
],
# stage 3, 14x14in
[
'uir_r4_a5_k5_s2_e4_c512', # ExtraDW
'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt
'uir_r1_a5_k3_s1_e4_c512', # ExtraDW
'uir_r2_a5_k0_s1_e4_c512', # ConvNeXt
'uir_r1_a5_k3_s1_e4_c512', # ExtraDW
'uir_r1_a5_k5_s1_e4_c512', # ExtraDW
'mqa_r1_k3_h8_s1_d64_c512', # MQA
'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt
'mqa_r1_k3_h8_s1_d64_c512', # MQA
'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt
'mqa_r1_k3_h8_s1_d64_c512', # MQA
'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt
'mqa_r1_k3_h8_s1_d64_c512', # MQA
'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt
],
# stage 4, 7x7 in
[
'cn_r1_k1_s1_c960', # Conv
],
]
else:
assert False, f'Unknown variant {variant}.'
else:
layer_scale_init_value = None
if 'small' in variant:
stem_size = 32
act_layer = resolve_act_layer(kwargs, 'relu')
arch_def = [
# stage 0, 112x112 in
[
'cn_r1_k3_s2_e1_c32', # Conv
'cn_r1_k1_s1_e1_c32', # Conv
],
# stage 1, 56x56 in
[
'cn_r1_k3_s2_e1_c96', # Conv
'cn_r1_k1_s1_e1_c64', # Conv
],
# stage 2, 28x28 in
[
'uir_r1_a5_k5_s2_e3_c96', # ExtraDW
'uir_r4_a0_k3_s1_e2_c96', # IR
'uir_r1_a3_k0_s1_e4_c96', # ConvNeXt
],
# stage 3, 14x14 in
[
'uir_r1_a3_k3_s2_e6_c128', # ExtraDW
'uir_r1_a5_k5_s1_e4_c128', # ExtraDW
'uir_r1_a0_k5_s1_e4_c128', # IR
'uir_r1_a0_k5_s1_e3_c128', # IR
'uir_r2_a0_k3_s1_e4_c128', # IR
],
# stage 4, 7x7 in
[
'cn_r1_k1_s1_c960', # Conv
],
]
elif 'medium' in variant:
stem_size = 32
act_layer = resolve_act_layer(kwargs, 'relu')
arch_def = [
# stage 0, 112x112 in
[
'er_r1_k3_s2_e4_c48', # FusedIB (EdgeResidual)
],
# stage 1, 56x56 in
[
'uir_r1_a3_k5_s2_e4_c80', # ExtraDW
'uir_r1_a3_k3_s1_e2_c80', # ExtraDW
],
# stage 2, 28x28 in
[
'uir_r1_a3_k5_s2_e6_c160', # ExtraDW
'uir_r2_a3_k3_s1_e4_c160', # ExtraDW
'uir_r1_a3_k5_s1_e4_c160', # ExtraDW
'uir_r1_a3_k3_s1_e4_c160', # ExtraDW
'uir_r1_a3_k0_s1_e4_c160', # ConvNeXt
'uir_r1_a0_k0_s1_e2_c160', # ExtraDW
'uir_r1_a3_k0_s1_e4_c160', # ConvNeXt
],
# stage 3, 14x14in
[
'uir_r1_a5_k5_s2_e6_c256', # ExtraDW
'uir_r1_a5_k5_s1_e4_c256', # ExtraDW
'uir_r2_a3_k5_s1_e4_c256', # ExtraDW
'uir_r1_a0_k0_s1_e4_c256', # FFN
'uir_r1_a3_k0_s1_e4_c256', # ConvNeXt
'uir_r1_a3_k5_s1_e2_c256', # ExtraDW
'uir_r1_a5_k5_s1_e4_c256', # ExtraDW
'uir_r2_a0_k0_s1_e4_c256', # FFN
'uir_r1_a5_k0_s1_e2_c256', # ConvNeXt
],
# stage 4, 7x7 in
[
'cn_r1_k1_s1_c960', # Conv
],
]
elif 'large' in variant:
stem_size = 24
act_layer = resolve_act_layer(kwargs, 'relu')
arch_def = [
# stage 0, 112x112 in
[
'er_r1_k3_s2_e4_c48', # FusedIB (EdgeResidual)
],
# stage 1, 56x56 in
[
'uir_r1_a3_k5_s2_e4_c96', # ExtraDW
'uir_r1_a3_k3_s1_e4_c96', # ExtraDW
],
# stage 2, 28x28 in
[
'uir_r1_a3_k5_s2_e4_c192', # ExtraDW
'uir_r3_a3_k3_s1_e4_c192', # ExtraDW
'uir_r1_a3_k5_s1_e4_c192', # ExtraDW
'uir_r5_a5_k3_s1_e4_c192', # ExtraDW
'uir_r1_a3_k0_s1_e4_c192', # ConvNeXt
],
# stage 3, 14x14in
[
'uir_r4_a5_k5_s2_e4_c512', # ExtraDW
'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt
'uir_r1_a5_k3_s1_e4_c512', # ExtraDW
'uir_r2_a5_k0_s1_e4_c512', # ConvNeXt
'uir_r1_a5_k3_s1_e4_c512', # ExtraDW
'uir_r1_a5_k5_s1_e4_c512', # ExtraDW
'uir_r3_a5_k0_s1_e4_c512', # ConvNeXt

],
# stage 4, 7x7 in
[
'cn_r1_k1_s1_c960', # Conv
],
]
else:
assert False, f'Unknown variant {variant}.'

model_kwargs = dict(
block_args=decode_arch_def(arch_def),
head_bias=False,
head_norm=True,
num_features=num_features,
stem_size=stem_size,
fix_stem=channel_multiplier < 1.0,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=act_layer,
layer_scale_init_value=layer_scale_init_value,
**kwargs,
)
model = _create_mnv3(variant, pretrained, **model_kwargs)
return model


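A decoding key for the arch_def strings above, as I read timm's block-definition mini-language (partly inferred, so treat the token meanings as an interpretation): tokens are underscore-separated; r = repeat count, s = stride, e = expansion ratio, c = output channels; for uir blocks a and k give the leading and mid depthwise kernel sizes, with 0 disabling that conv, which is how one string family covers the ExtraDW / ConvNeXt / FFN / IR shapes noted in the comments; for mqa blocks h = heads, d = key/value dim, and v2 = stride-2 KV downsampling. Programmatically:

from timm.models._efficientnet_builder import decode_arch_def

block_args = decode_arch_def([['uir_r2_a3_k5_s2_e4_c80']])
print(block_args[0][0])  # per-block kwargs dict; the r2 repeat yields two entries in stage 0
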
def _cfg(url: str = '', **kwargs):
return {
@ -725,6 +995,52 @@ default_cfgs = generate_default_cfgs({
interpolation='bicubic',
),
"lcnet_150.untrained": _cfg(),

'mobilenetv4_conv_small': _cfg(
# hf_hub_id='timm/',
interpolation='bicubic'),
'mobilenetv4_conv_medium.r224': _cfg(
# hf_hub_id='timm/',
crop_pct=0.95, interpolation='bicubic'),
'mobilenetv4_conv_medium.r256': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
'mobilenetv4_conv_large.r256': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
'mobilenetv4_conv_large.r384': _cfg(
# hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=0.95, interpolation='bicubic'),

'mobilenetv4_hybrid_small': _cfg(
# hf_hub_id='timm/',
interpolation='bicubic'),
'mobilenetv4_hybrid_medium.r224': _cfg(
# hf_hub_id='timm/',
crop_pct=0.95, interpolation='bicubic'),
'mobilenetv4_hybrid_medium.r256': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
'mobilenetv4_hybrid_large.r256': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
'mobilenetv4_hybrid_large.r384': _cfg(
# hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=0.95, interpolation='bicubic'),

# experimental
'mobilenetv4_conv_aa_medium.r256': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
'mobilenetv4_conv_blur_medium.r256': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
'mobilenetv4_hybrid_medium_075': _cfg(
# hf_hub_id='timm/',
crop_pct=0.95, interpolation='bicubic'),
'mobilenetv4_hybrid_large_075.r256': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
})


@ -881,6 +1197,69 @@ def lcnet_150(pretrained: bool = False, **kwargs) -> MobileNetV3:
return model


@register_model
def mobilenetv4_conv_small(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V4 """
model = _gen_mobilenet_v4('mobilenetv4_conv_small', 1.0, pretrained=pretrained, **kwargs)
return model


@register_model
def mobilenetv4_conv_medium(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V4 """
model = _gen_mobilenet_v4('mobilenetv4_conv_medium', 1.0, pretrained=pretrained, **kwargs)
return model


@register_model
def mobilenetv4_conv_large(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V4 """
model = _gen_mobilenet_v4('mobilenetv4_conv_large', 1.0, pretrained=pretrained, **kwargs)
return model


@register_model
def mobilenetv4_hybrid_medium(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V4 Hybrid """
model = _gen_mobilenet_v4('mobilenetv4_hybrid_medium', 1.0, pretrained=pretrained, **kwargs)
return model


@register_model
def mobilenetv4_hybrid_large(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V4 Hybrid"""
model = _gen_mobilenet_v4('mobilenetv4_hybrid_large', 1.0, pretrained=pretrained, **kwargs)
return model


@register_model
def mobilenetv4_conv_aa_medium(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V4 w/ AvgPool AA """
model = _gen_mobilenet_v4('mobilenetv4_conv_aa_medium', 1.0, pretrained=pretrained, aa_layer='avg', **kwargs)
return model


@register_model
def mobilenetv4_conv_blur_medium(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V4 Conv w/ Blur AA """
model = _gen_mobilenet_v4('mobilenetv4_conv_blur_medium', 1.0, pretrained=pretrained, aa_layer='blurpc', **kwargs)
return model


@register_model
def mobilenetv4_hybrid_medium_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V4 Hybrid """
model = _gen_mobilenet_v4('mobilenetv4_hybrid_medium_075', 0.75, pretrained=pretrained, **kwargs)
return model


@register_model
def mobilenetv4_hybrid_large_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V4 Hybrid"""
model = _gen_mobilenet_v4('mobilenetv4_hybrid_large', 0.75, pretrained=pretrained, **kwargs)
return model


register_model_deprecations(__name__, {
'mobilenetv3_large_100_miil': 'mobilenetv3_large_100.miil_in21k_ft_in1k',
'mobilenetv3_large_100_miil_in21k': 'mobilenetv3_large_100.miil_in21k',

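None of the mobilenetv4 cfgs above carry weights yet (the hf_hub_id lines are commented out), so as with EfficientNet-X these variants build untrained; a quick shape check:

import timm
import torch

model = timm.create_model('mobilenetv4_conv_small', pretrained=False).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)            # expected torch.Size([1, 1000])
print(model.feature_info[-1])  # channel/reduction metadata stashed by the builder
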
@ -403,7 +403,7 @@ class PyramidVisionTransformerV2(nn.Module):
return x


def _checkpoint_filter_fn(state_dict, model):
def checkpoint_filter_fn(state_dict, model):
""" Remap original checkpoints -> timm """
if 'patch_embed.proj.weight' in state_dict:
return state_dict # non-original checkpoint, no remapping needed
@ -430,7 +430,7 @@ def _create_pvt2(variant, pretrained=False, **kwargs):
PyramidVisionTransformerV2,
variant,
pretrained,
pretrained_filter_fn=_checkpoint_filter_fn,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs,
)

@ -17,7 +17,7 @@ import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \
get_attn, get_act_layer, get_norm_layer, create_classifier
get_attn, get_act_layer, get_norm_layer, create_classifier, create_aa
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
@ -31,15 +31,6 @@ def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int:
return padding


def create_aa(aa_layer: Type[nn.Module], channels: int, stride: int = 2, enable: bool = True) -> nn.Module:
if not aa_layer or not enable:
return nn.Identity()
if issubclass(aa_layer, nn.AvgPool2d):
return aa_layer(stride)
else:
return aa_layer(channels=channels, stride=stride)


class BasicBlock(nn.Module):
expansion = 1

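The hunk above deletes resnet's local create_aa in favor of the shared timm.layers.create_aa (exported via the layers/__init__ change in this commit). Going by the deleted function's signature, usage looks like:

import torch.nn as nn
from timm.layers import BlurPool2d, create_aa

aa = create_aa(BlurPool2d, channels=64, stride=2)  # anti-aliased downsample module
noop = create_aa(None, channels=64, stride=2)      # disabled -> nn.Identity
assert isinstance(noop, nn.Identity)
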
@ -409,6 +409,7 @@ class VisionTransformer(nn.Module):
            qk_norm: bool = False,
            init_values: Optional[float] = None,
            class_token: bool = True,
            pos_embed: str = 'learn',
            no_embed_class: bool = False,
            reg_tokens: int = 0,
            pre_norm: bool = False,
@ -460,6 +461,7 @@ class VisionTransformer(nn.Module):
        super().__init__()
        assert global_pool in ('', 'avg', 'token', 'map')
        assert class_token or global_pool != 'token'
        assert pos_embed in ('', 'none', 'learn')
        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
        norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
        act_layer = get_act_layer(act_layer) or nn.GELU
@ -494,7 +496,10 @@ class VisionTransformer(nn.Module):
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
        if not pos_embed or pos_embed == 'none':
            self.pos_embed = None
        else:
            self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        if patch_drop_rate > 0:
            self.patch_drop = PatchDropout(
@ -556,7 +561,8 @@ class VisionTransformer(nn.Module):
    def init_weights(self, mode: str = '') -> None:
        assert mode in ('jax', 'jax_nlhb', 'moco', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
        trunc_normal_(self.pos_embed, std=.02)
        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(get_init_weights_vit(mode, head_bias), self)
@ -583,6 +589,8 @@ class VisionTransformer(nn.Module):
    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True) -> None:
        self.grad_checkpointing = enable
        if hasattr(self.patch_embed, 'set_grad_checkpointing'):
            self.patch_embed.set_grad_checkpointing(enable)

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
@ -600,6 +608,9 @@ class VisionTransformer(nn.Module):
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        if self.pos_embed is None:
            return x.view(x.shape[0], -1, x.shape[-1])

        if self.dynamic_img_size:
            B, H, W, C = x.shape
            pos_embed = resample_abs_pos_embed(
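The new pos_embed constructor argument makes the absolute position embedding optional: with pos_embed='none', self.pos_embed stays None and _pos_embed becomes a plain flatten, which the vitamin_xlarge_* variants added later in this commit rely on. A hedged sketch of the flag, mirroring how it is used below (class_token=False / global_pool='avg' as in the vitamin configs; the small dims here are illustrative):

import torch
from timm.models.vision_transformer import VisionTransformer

model = VisionTransformer(
    patch_size=16, embed_dim=192, depth=2, num_heads=3,
    pos_embed='none', class_token=False, global_pool='avg',
)
assert model.pos_embed is None
print(model(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])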
@ -1066,10 +1077,13 @@ def checkpoint_filter_fn(
        # IJEPA, vit in an 'encoder' submodule
        state_dict = state_dict['encoder']
        prefix = 'module.'
    elif 'visual.trunk.pos_embed' in state_dict:
    elif 'visual.trunk.pos_embed' in state_dict or 'visual.trunk.blocks.0.norm1.weight' in state_dict:
        # OpenCLIP model with timm vision encoder
        # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
        prefix = 'visual.trunk.'
        if 'visual.head.proj.weight' in state_dict and isinstance(model.head, nn.Linear):
            # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
            out_dict['head.weight'] = state_dict['visual.head.proj.weight']
            out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])

    if prefix:
        # filter on & remove prefix string from keys

@ -15,18 +15,20 @@ Hacked together by / Copyright 2020, Ross Wightman
"""
import math
from functools import partial
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to
from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, Format, nchw_to

from ._builder import build_model_with_cfg
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
from .resnet import resnet26d, resnet50d
from .resnetv2 import ResNetV2, create_resnetv2_stem
from .vision_transformer import _create_vision_transformer, VisionTransformer
from .vision_transformer import VisionTransformer


class HybridEmbed(nn.Module):
@ -38,14 +40,15 @@ class HybridEmbed(nn.Module):

    def __init__(
            self,
            backbone,
            img_size=224,
            patch_size=1,
            feature_size=None,
            feature_ratio=None,
            in_chans=3,
            embed_dim=768,
            bias=True,
            backbone: nn.Module,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 1,
            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
            in_chans: int = 3,
            embed_dim: int = 768,
            bias: bool = True,
            proj: bool = True,
            flatten: bool = True,
            output_fmt: Optional[str] = None,
            strict_img_size: bool = True,
@ -95,7 +98,18 @@ class HybridEmbed(nn.Module):
        self.strict_img_size = strict_img_size
        self.dynamic_img_pad = dynamic_img_pad

        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
        if proj:
            self.proj = nn.Conv2d(
                feature_dim,
                embed_dim,
                kernel_size=patch_size,
                stride=patch_size,
                bias=bias,
            )
        else:
            assert feature_dim == embed_dim, \
                f'The feature dim ({feature_dim}) must match embed dim ({embed_dim}) when projection disabled.'
            self.proj = nn.Identity()

    def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
        total_reduction = (
@ -116,6 +130,13 @@ class HybridEmbed(nn.Module):
        else:
            return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1]

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        if hasattr(self.backbone, 'set_grad_checkpointing'):
            self.backbone.set_grad_checkpointing(enable=enable)
        elif hasattr(self.backbone, 'grad_checkpointing'):
            self.backbone.grad_checkpointing = enable

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
@ -139,24 +160,35 @@ class HybridEmbedWithSize(nn.Module):
    """
    def __init__(
            self,
            backbone,
            img_size=224,
            patch_size=1,
            feature_size=None,
            in_chans=3,
            embed_dim=768,
            backbone: nn.Module,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 1,
            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
            in_chans: int = 3,
            embed_dim: int = 768,
            bias=True,
            proj=True,
    ):
        super().__init__(
            backbone=backbone,
            img_size=img_size,
            patch_size=patch_size,
            feature_size=feature_size,
            feature_ratio=feature_ratio,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=bias,
            proj=proj,
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        if hasattr(self.backbone, 'set_grad_checkpointing'):
            self.backbone.set_grad_checkpointing(enable=enable)
        elif hasattr(self.backbone, 'grad_checkpointing'):
            self.backbone.grad_checkpointing = enable

    def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
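Both embed classes now expose set_grad_checkpointing, and VisionTransformer forwards the call into its patch_embed (see the hunk above), so toggling checkpointing on a hybrid model now reaches the CNN backbone too. A small sketch, assuming a hybrid variant whose timm backbone supports the call:

import timm

model = timm.create_model('vit_small_r26_s32_224')  # hybrid: R26 backbone + ViT-S
model.set_grad_checkpointing(True)                  # also toggles the backbone via HybridEmbed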
@ -165,10 +197,43 @@ class HybridEmbedWithSize(nn.Module):
        return x.flatten(2).transpose(1, 2), x.shape[-2:]


def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
    embed_layer = partial(HybridEmbed, backbone=backbone)
    kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set
    return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs)
class ConvStem(nn.Sequential):
    def __init__(
            self,
            in_chans: int = 3,
            depth: int = 3,
            channels: Union[int, Tuple[int, ...]] = 64,
            kernel_size: Union[int, Tuple[int, ...]] = 3,
            stride: Union[int, Tuple[int, ...]] = (2, 2, 2),
            padding: Union[str, int, Tuple[int, ...]] = "",
            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            act_layer: Type[nn.Module] = nn.ReLU,
    ):
        super().__init__()
        if isinstance(channels, int):
            # a default tiered channel strategy
            channels = tuple([channels // 2**i for i in range(depth)][::-1])

        kernel_size = to_ntuple(depth)(kernel_size)
        padding = to_ntuple(depth)(padding)
        assert depth == len(stride) == len(kernel_size) == len(channels)

        in_chs = in_chans
        for i in range(len(channels)):
            last_conv = i == len(channels) - 1
            self.add_module(f'{i}', ConvNormAct(
                in_chs,
                channels[i],
                kernel_size=kernel_size[i],
                stride=stride[i],
                padding=padding[i],
                bias=last_conv,
                apply_norm=not last_conv,
                apply_act=not last_conv,
                norm_layer=norm_layer,
                act_layer=act_layer,
            ))
            in_chs = channels[i]


def _resnetv2(layers=(3, 4, 9), **kwargs):
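A shape sketch of the new ConvStem, under the settings vit_base_mci_224 uses below: three ConvNormAct stages where only the last carries a bias and skips norm/act, and strides (4, 2, 2) give a 16x total reduction, so HybridEmbed can run with patch_size=1. Assuming this commit is installed:

import torch
import torch.nn as nn
from timm.models.vision_transformer_hybrid import ConvStem

stem = ConvStem(
    channels=(768 // 4, 768 // 4, 768),
    stride=(4, 2, 2), kernel_size=(4, 2, 2), padding=0,
    act_layer=nn.GELU,
)
print(stem(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 768, 14, 14])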
@ -186,6 +251,66 @@ def _resnetv2(layers=(3, 4, 9), **kwargs):
    return backbone


def _convert_mobileclip(state_dict, model, prefix='image_encoder.model.'):
    out = {}
    for k, v in state_dict.items():
        if not k.startswith(prefix):
            continue
        k = k.replace(prefix, '')
        k = k.replace('patch_emb.', 'patch_embed.backbone.')
        k = k.replace('block.conv', 'conv')
        k = k.replace('block.norm', 'bn')
        k = k.replace('post_transformer_norm.', 'norm.')
        k = k.replace('pre_norm_mha.0', 'norm1')
        k = k.replace('pre_norm_mha.1', 'attn')
        k = k.replace('pre_norm_ffn.0', 'norm2')
        k = k.replace('pre_norm_ffn.1', 'mlp.fc1')
        k = k.replace('pre_norm_ffn.4', 'mlp.fc2')
        k = k.replace('qkv_proj.', 'qkv.')
        k = k.replace('out_proj.', 'proj.')
        k = k.replace('transformer.', 'blocks.')
        if k == 'pos_embed.pos_embed.pos_embed':
            k = 'pos_embed'
            v = v.squeeze(0)
        if 'classifier.proj' in k:
            bias_k = k.replace('classifier.proj', 'head.bias')
            k = k.replace('classifier.proj', 'head.weight')
            v = v.T
            out[bias_k] = torch.zeros(v.shape[0])
        out[k] = v
    return out

def checkpoint_filter_fn(
        state_dict: Dict[str, torch.Tensor],
        model: VisionTransformer,
        interpolation: str = 'bicubic',
        antialias: bool = True,
) -> Dict[str, torch.Tensor]:
    from .vision_transformer import checkpoint_filter_fn as _filter_fn

    if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:
        state_dict = _convert_mobileclip(state_dict, model)

    return _filter_fn(state_dict, model, interpolation=interpolation, antialias=antialias)

def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs):
    out_indices = kwargs.pop('out_indices', 3)
    embed_args = embed_args or {}
    embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args)
    kwargs.setdefault('embed_layer', embed_layer)
    kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set
    return build_model_with_cfg(
        VisionTransformer,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
        **kwargs,
    )


def _cfg(url='', **kwargs):
    return {
        'url': url,
@ -260,6 +385,17 @@ default_cfgs = generate_default_cfgs({
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
    'vit_base_resnet50d_224.untrained': _cfg(
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),

    'vit_base_mci_224.apple_mclip': _cfg(
        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt',
        num_classes=512,
        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv',
    ),
    'vit_base_mci_224.apple_mclip_lt': _cfg(
        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt',
        num_classes=512,
        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv',
    ),
})
@ -407,6 +543,26 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs) -> VisionTransformer:
    return model


@register_model
def vit_base_mci_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ Custom ViT-Base hybrid w/ 3-stage MobileCLIP (MCi) conv stem. Pretrained via .apple_mclip cfgs above.
    """
    backbone = ConvStem(
        channels=(768//4, 768//4, 768),
        stride=(4, 2, 2),
        kernel_size=(4, 2, 2),
        padding=0,
        in_chans=kwargs.get('in_chans', 3),
        act_layer=nn.GELU,
    )
    model_args = dict(embed_dim=768, depth=12, num_heads=12, no_embed_class=True)
    model = _create_vision_transformer_hybrid(
        'vit_base_mci_224', backbone=backbone, embed_args=dict(proj=False),
        pretrained=pretrained, **dict(model_args, **kwargs)
    )
    return model


register_model_deprecations(__name__, {
    'vit_tiny_r_s16_p8_224_in21k': 'vit_tiny_r_s16_p8_224.augreg_in21k',
    'vit_small_r26_s32_224_in21k': 'vit_small_r26_s32_224.augreg_in21k',

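With the cfgs and entrypoint above in place, the Apple MobileCLIP-B image tower can be pulled straight through timm: _convert_mobileclip remaps the original keys on load, and the 512-dim CLIP projection becomes the head (its bias zero-filled). A usage sketch (weights download from the Apple URL in the cfg):

import timm
import torch

model = timm.create_model('vit_base_mci_224.apple_mclip', pretrained=True).eval()
emb = model(torch.randn(1, 3, 224, 224))
print(emb.shape)  # torch.Size([1, 512])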
603
timm/models/vitamin.py
Normal file
@ -0,0 +1,603 @@
""" ViTamin

Paper: Designing Scalable Vision Models in the Vision-Language Era
A family of model weights on Huggingface: https://huggingface.co/collections/jienengchen/vitamin-family-661048126b72debdaca060bf

@inproceedings{chen2024vitamin,
  title={ViTamin: Designing Scalable Vision Models in the Vision-language Era},
  author={Chen, Jieneng and Yu, Qihang and Shen, Xiaohui and Yuille, Alan and Chen, Liang-Chieh},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2024}
}

Based on Apache 2.0 licensed code at https://github.com/ViTamin/ViTamin

Modifications and timm support by Jieneng Chen 2024

Reference:
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer_hybrid.py
"""

import math
from dataclasses import dataclass, field
from functools import partial
from typing import Optional, Union, Tuple

import torch
import torch.nn as nn

from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import create_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, \
    make_divisible, DropPath
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq
from ._registry import register_model, generate_default_cfgs
from .vision_transformer import VisionTransformer, checkpoint_filter_fn
from .vision_transformer_hybrid import HybridEmbed

@dataclass
class VitConvCfg:
    expand_ratio: float = 4.0
    expand_output: bool = True  # calculate expansion channels from output (vs input chs)
    kernel_size: int = 3
    group_size: int = 1  # 1 == depthwise
    pre_norm_act: bool = False  # activation after pre-norm
    stride_mode: str = 'dw'  # stride done via one of 'pool', '1x1', 'dw'
    pool_type: str = 'avg2'
    downsample_pool_type: str = 'avg2'
    act_layer: str = 'gelu'  # stem & stage 1234
    norm_layer: str = ''
    norm_eps: float = 1e-5
    down_shortcut: Optional[bool] = True
    mlp: str = 'mlp'


@dataclass
class VitCfg:
    embed_dim: Tuple[Union[int, Tuple[int, ...]], ...] = (96, 192, 384, 768)
    depths: Tuple[Union[int, Tuple[int, ...]], ...] = (2, 3, 5, 2)
    stem_width: int = 64
    conv_cfg: VitConvCfg = field(default_factory=VitConvCfg)
    head_type: str = ""


def _init_conv(module, name, scheme=''):
    if isinstance(module, nn.Conv2d):
        fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
        fan_out //= module.groups
        nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))
        if module.bias is not None:
            nn.init.zeros_(module.bias)

class Stem(nn.Module):
    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            act_layer: str = 'gelu',
            norm_layer: str = 'layernorm2d',
            norm_eps: float = 1e-6,
            bias: bool = True,
    ):
        super().__init__()
        norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
        self.out_chs = out_chs

        self.conv1 = create_conv2d(in_chs, out_chs, 3, stride=2, bias=bias)
        self.norm1 = norm_act_layer(out_chs)
        self.conv2 = create_conv2d(out_chs, out_chs, 3, stride=1, bias=bias)

        named_apply(_init_conv, self)

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.conv2(x)
        return x


class Downsample2d(nn.Module):
    def __init__(
            self,
            dim: int,
            dim_out: int,
            pool_type: str = 'avg2',
            bias: bool = True,
    ):
        super().__init__()
        self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)

        if dim != dim_out:
            self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias)  # 1x1 conv
        else:
            self.expand = nn.Identity()

    def forward(self, x):
        x = self.pool(x)  # spatial downsample
        x = self.expand(x)  # expand chs
        return x

class StridedConv(nn.Module):
    """ Strided conv w/ layernorm pre-norm, downsamples 2d feature maps before the transformer stage.
    """
    def __init__(
            self,
            kernel_size=3,
            stride=2,
            padding=1,
            in_chans=3,
            embed_dim=768
    ):
        super().__init__()
        norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6)

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
        self.norm = norm_layer(in_chans)  # affine over C

    def forward(self, x):
        x = self.norm(x)
        x = self.proj(x)
        return x

class MbConvLNBlock(nn.Module):
    """ Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand)
    """
    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            stride: int = 1,
            drop_path: float = 0.,
            kernel_size: int = 3,
            norm_layer: str = 'layernorm2d',
            norm_eps: float = 1e-6,
            act_layer: str = 'gelu',
            expand_ratio: float = 4.0,
    ):
        super(MbConvLNBlock, self).__init__()
        self.stride, self.in_chs, self.out_chs = stride, in_chs, out_chs
        mid_chs = make_divisible(out_chs * expand_ratio)
        prenorm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)

        if stride == 2:
            self.shortcut = Downsample2d(in_chs, out_chs, pool_type='avg', bias=True)
        elif in_chs != out_chs:
            self.shortcut = nn.Conv2d(in_chs, out_chs, 1, bias=True)
        else:
            self.shortcut = nn.Identity()

        self.pre_norm = prenorm_act_layer(in_chs, apply_act=False)
        self.down = nn.Identity()
        self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=1, bias=True)
        self.act1 = create_act_layer(act_layer, inplace=True)
        self.conv2_kxk = create_conv2d(
            mid_chs, mid_chs, kernel_size, stride=stride, dilation=1, groups=mid_chs, bias=True)
        self.act2 = create_act_layer(act_layer, inplace=True)
        self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=True)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def init_weights(self, scheme=''):
        named_apply(partial(_init_conv, scheme=scheme), self)

    def forward(self, x):
        shortcut = self.shortcut(x)

        x = self.pre_norm(x)
        x = self.down(x)  # nn.Identity()

        # 1x1 expansion conv & act
        x = self.conv1_1x1(x)
        x = self.act1(x)

        # (strided) depthwise 3x3 conv & act
        x = self.conv2_kxk(x)
        x = self.act2(x)

        # 1x1 linear projection to output width
        x = self.conv3_1x1(x)
        x = self.drop_path(x) + shortcut

        return x

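A quick shape check of MbConvLNBlock: the pre-norm inverted bottleneck (1x1 expand, depthwise kxk, 1x1 project) is paired with a pooled shortcut when stride=2, or a 1x1 shortcut when only the width changes, so the residual add always lines up. A hedged sketch, assuming the module as defined above:

import torch
from timm.models.vitamin import MbConvLNBlock

blk = MbConvLNBlock(in_chs=64, out_chs=128, stride=2)
print(blk(torch.randn(1, 64, 56, 56)).shape)  # torch.Size([1, 128, 28, 28])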
class MbConvStages(nn.Module):
    """ MobileConv for stage 1 and stage 2 of ViTamin
    """
    def __init__(
            self,
            cfg: VitCfg,
            img_size: Union[int, Tuple[int, int]] = 224,  # placeholder
            in_chans: int = 3,
    ):
        super().__init__()
        self.grad_checkpointing = False

        self.stem = Stem(
            in_chs=in_chans,
            out_chs=cfg.stem_width,
        )

        stages = []
        self.num_stages = len(cfg.embed_dim)
        for s, dim in enumerate(cfg.embed_dim[:2]):  # stage
            stage_in_chs = cfg.embed_dim[s - 1] if s > 0 else cfg.stem_width
            blocks = [
                MbConvLNBlock(
                    in_chs=stage_in_chs if d == 0 else dim,
                    out_chs=dim,
                    stride=2 if d == 0 else 1,
                )
                for d in range(cfg.depths[s])
            ]
            stages += [nn.Sequential(*blocks)]
        self.stages = nn.Sequential(*stages)

        self.pool = StridedConv(
            stride=2,
            in_chans=cfg.embed_dim[1],
            embed_dim=cfg.embed_dim[2]
        )

    def forward(self, x):
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x)
        else:
            x = self.stages(x)
        x = self.pool(x)
        return x

class GeGluMlp(nn.Module):
    def __init__(
            self,
            in_features,
            hidden_features,
            act_layer='gelu',
            drop=0.0,
    ):
        super().__init__()
        norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6)

        self.norm = norm_layer(in_features)
        self.w0 = nn.Linear(in_features, hidden_features)
        self.act = create_act_layer(act_layer)
        self.w1 = nn.Linear(in_features, hidden_features)
        self.w2 = nn.Linear(hidden_features, in_features)

    def forward(self, x):
        x = self.norm(x)
        x = self.act(self.w0(x)) * self.w1(x)
        x = self.w2(x)
        return x

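GeGluMlp swaps the usual fc1 -> act -> fc2 MLP for a gated-GELU (GeGLU) unit with its own pre-norm: y = W2(GELU(W0 x) * W1 x), which is why the vitamin configs below pass mlp_ratio=2. rather than the usual 4. A plain-tensor sketch of the gating (biases and the norm omitted; dims here are illustrative):

import torch

d, h = 8, 16
x = torch.randn(d)
w0, w1, w2 = torch.randn(h, d), torch.randn(h, d), torch.randn(d, h)
y = w2 @ (torch.nn.functional.gelu(w0 @ x) * (w1 @ x))
print(y.shape)  # torch.Size([8])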
def _create_vitamin(variant, pretrained=False, embed_cfg=None, **kwargs):
    out_indices = kwargs.pop('out_indices', 3)
    assert embed_cfg is not None
    backbone = MbConvStages(cfg=embed_cfg)
    kwargs['embed_layer'] = partial(HybridEmbed, backbone=backbone, proj=False)
    kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set

    return build_model_with_cfg(
        VisionTransformer,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
        **kwargs,
    )


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': OPENAI_CLIP_MEAN, 'std': OPENAI_CLIP_STD,
        'first_conv': 'patch_embed.backbone.stem.conv1',
        'classifier': 'head',
        **kwargs
    }

default_cfgs = generate_default_cfgs({
    'vitamin_small_224.datacomp1b_clip_ltt': _cfg(
        hf_hub_id='jienengchen/ViTamin-S-LTT', num_classes=384),
    'vitamin_small_224.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-S', num_classes=384),
    'vitamin_base_224.datacomp1b_clip_ltt': _cfg(
        hf_hub_id='jienengchen/ViTamin-B-LTT', num_classes=768),
    'vitamin_base_224.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-B', num_classes=768),
    'vitamin_large_224.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L-224px', num_classes=768),
    'vitamin_large_256.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L-256px', num_classes=768,
        input_size=(3, 256, 256), crop_pct=1.0),
    'vitamin_large_336.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L-336px', num_classes=768,
        input_size=(3, 336, 336), crop_pct=1.0),
    'vitamin_large_384.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L-384px', num_classes=768,
        input_size=(3, 384, 384), crop_pct=1.0),
    'vitamin_large2_224.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L2-224px', num_classes=1024),
    'vitamin_large2_256.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L2-256px', num_classes=1024,
        input_size=(3, 256, 256), crop_pct=1.0),
    'vitamin_large2_336.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L2-336px', num_classes=1024,
        input_size=(3, 336, 336), crop_pct=1.0),
    'vitamin_large2_384.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-L2-384px', num_classes=1024,
        input_size=(3, 384, 384), crop_pct=1.0),
    'vitamin_xlarge_256.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-XL-256px', num_classes=1152,
        input_size=(3, 256, 256), crop_pct=1.0),
    'vitamin_xlarge_336.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-XL-336px', num_classes=1152,
        input_size=(3, 336, 336), crop_pct=1.0),
    'vitamin_xlarge_384.datacomp1b_clip': _cfg(
        hf_hub_id='jienengchen/ViTamin-XL-384px', num_classes=1152,
        input_size=(3, 384, 384), crop_pct=1.0),
})

@register_model
def vitamin_small_224(pretrained=False, **kwargs) -> VisionTransformer:
    embed_cfg = VitCfg(
        embed_dim=(64, 128, 384),
        depths=(2, 4, 1),
        stem_width=64,
        conv_cfg=VitConvCfg(
            norm_layer='layernorm2d',
            norm_eps=1e-6,
        ),
        head_type='1d',
    )
    model_args = dict(
        embed_dim=384, depth=14, num_heads=6, mlp_layer=GeGluMlp, mlp_ratio=2.,
        class_token=False, global_pool='avg', embed_cfg=embed_cfg
    )
    model = _create_vitamin('vitamin_small_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vitamin_base_224(pretrained=False, **kwargs) -> VisionTransformer:
    embed_cfg = VitCfg(
        embed_dim=(128, 256, 768),
        depths=(2, 4, 1),
        stem_width=128,
        conv_cfg=VitConvCfg(
            norm_layer='layernorm2d',
            norm_eps=1e-6,
        ),
        head_type='1d',
    )
    model_args = dict(
        embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2.,
        class_token=False, global_pool='avg', embed_cfg=embed_cfg)
    model = _create_vitamin('vitamin_base_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vitamin_large_224(pretrained=False, **kwargs) -> VisionTransformer:
    embed_cfg = VitCfg(
        embed_dim=(160, 320, 1024),
        depths=(2, 4, 1),
        stem_width=160,
        conv_cfg=VitConvCfg(
            norm_layer='layernorm2d',
            norm_eps=1e-6,
        ),
        head_type='1d',
    )
    model_args = dict(
        embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
        class_token=False, global_pool='avg', embed_cfg=embed_cfg,
    )
    model = _create_vitamin('vitamin_large_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model

@register_model
def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
    embed_cfg = VitCfg(
        embed_dim=(160, 320, 1024),
        depths=(2, 4, 1),
        stem_width=160,
        conv_cfg=VitConvCfg(
            norm_layer='layernorm2d',
            norm_eps=1e-6,
        ),
        head_type='1d',
    )
    model_args = dict(
        img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
        class_token=False, global_pool='avg', embed_cfg=embed_cfg)
    model = _create_vitamin('vitamin_large_256', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
    embed_cfg = VitCfg(
        embed_dim=(160, 320, 1024),
        depths=(2, 4, 1),
        stem_width=160,
        conv_cfg=VitConvCfg(
            norm_layer='layernorm2d',
            norm_eps=1e-6,
        ),
        head_type='1d',
    )
    model_args = dict(
        img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
        class_token=False, global_pool='avg', embed_cfg=embed_cfg
    )
    model = _create_vitamin('vitamin_large_336', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
    embed_cfg = VitCfg(
        embed_dim=(160, 320, 1024),
        depths=(2, 4, 1),
        stem_width=160,
        conv_cfg=VitConvCfg(
            norm_layer='layernorm2d',
            norm_eps=1e-6,
        ),
        head_type='1d',
    )
    model_args = dict(
        img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
        class_token=False, global_pool='avg', embed_cfg=embed_cfg)
    model = _create_vitamin('vitamin_large_384', pretrained=pretrained, **dict(model_args, **kwargs))
    return model

@register_model
def vitamin_large2_224(pretrained=False, **kwargs) -> VisionTransformer:
    embed_cfg = VitCfg(
        embed_dim=(160, 320, 1024),
        depths=(2, 4, 1),
        stem_width=160,
        conv_cfg=VitConvCfg(
            norm_layer='layernorm2d',
            norm_eps=1e-6,
        ),
        head_type='1d',
    )
    model_args = dict(
        embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
        class_token=False, global_pool='avg', embed_cfg=embed_cfg,
    )
    model = _create_vitamin('vitamin_large2_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vitamin_large2_256(pretrained=False, **kwargs) -> VisionTransformer:
    embed_cfg = VitCfg(
        embed_dim=(160, 320, 1024),
        depths=(2, 4, 1),
        stem_width=160,
        conv_cfg=VitConvCfg(
            norm_layer='layernorm2d',
            norm_eps=1e-6,
        ),
        head_type='1d',
    )
    model_args = dict(
        img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
        class_token=False, global_pool='avg', embed_cfg=embed_cfg)
    model = _create_vitamin('vitamin_large2_256', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vitamin_large2_336(pretrained=False, **kwargs) -> VisionTransformer:
    embed_cfg = VitCfg(
        embed_dim=(160, 320, 1024),
        depths=(2, 4, 1),
        stem_width=160,
        conv_cfg=VitConvCfg(
            norm_layer='layernorm2d',
            norm_eps=1e-6,
        ),
        head_type='1d',
    )
    model_args = dict(
        img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
        class_token=False, global_pool='avg', embed_cfg=embed_cfg
    )
    model = _create_vitamin('vitamin_large2_336', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vitamin_large2_384(pretrained=False, **kwargs) -> VisionTransformer:
    embed_cfg = VitCfg(
        embed_dim=(160, 320, 1024),
        depths=(2, 4, 1),
        stem_width=160,
        conv_cfg=VitConvCfg(
            norm_layer='layernorm2d',
            norm_eps=1e-6,
        ),
        head_type='1d',
    )
    model_args = dict(
        img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
        class_token=False, global_pool='avg', embed_cfg=embed_cfg)
    model = _create_vitamin('vitamin_large2_384', pretrained=pretrained, **dict(model_args, **kwargs))
    return model

@register_model
def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
    embed_cfg = VitCfg(
        embed_dim=(192, 384, 1152),
        depths=(2, 4, 1),
        stem_width=192,
        conv_cfg=VitConvCfg(
            norm_layer='layernorm2d',
            norm_eps=1e-6,
        ),
        head_type='1d',
    )
    model_args = dict(
        img_size=256, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
        class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg)
    model = _create_vitamin(
        'vitamin_xlarge_256', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
    embed_cfg = VitCfg(
        embed_dim=(192, 384, 1152),
        depths=(2, 4, 1),
        stem_width=192,
        conv_cfg=VitConvCfg(
            norm_layer='layernorm2d',
            norm_eps=1e-6,
        ),
        head_type='1d',
    )
    model_args = dict(
        img_size=336, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
        class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg)
    model = _create_vitamin('vitamin_xlarge_336', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
    embed_cfg = VitCfg(
        embed_dim=(192, 384, 1152),
        depths=(2, 4, 1),
        stem_width=192,
        conv_cfg=VitConvCfg(
            norm_layer='layernorm2d',
            norm_eps=1e-6,
        ),
        head_type='1d',
    )
    model_args = dict(
        img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2.,
        class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg)
    model = _create_vitamin('vitamin_xlarge_384', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
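Closing sketch: the registered vitamin entrypoints behave like any other timm model; these ship CLIP-trained weights, so a common pattern is creating them headless for embeddings. num_classes=0 below is an illustrative choice, not part of the diff:

import timm
import torch

model = timm.create_model('vitamin_small_224', num_classes=0)  # headless: pooled features
print(model(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 384]), the small embed_dim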