Merge pull request #1789 from huggingface/mw-final

Final push to get remaining models using multi-weight pretrained configs and HF hub weights

commit bd5f9a341f
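The common thread of the diff below: each converted model file replaces its local `default_cfgs` dict of download URLs with `generate_default_cfgs(...)` entries keyed by a pretrained tag (e.g. `.in1k`) and pointing at the Hugging Face hub via `hf_hub_id='timm/'`. A minimal sketch of what this enables at the user level (model/tag names taken from the coat diff below; `create_model`/`list_models` are the public timm API used in the updated tests):

    import timm

    # A tagged weight variant is a first-class model identifier...
    model = timm.create_model('coat_lite_mini.in1k', pretrained=True)  # weights resolved on the HF hub

    # ...and the bare name still resolves to its default tag.
    model = timm.create_model('coat_lite_mini', pretrained=True)

    # Tagged pretrained configs are enumerable, as the updated tests rely on:
    names = timm.list_models(pretrained=True)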
@@ -58,6 +58,8 @@ else:
     EXCLUDE_FILTERS = ['*enormous*']
+    NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*']

+EXCLUDE_JIT_FILTERS = []

 TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
 TARGET_BWD_SIZE = 128
 MAX_BWD_SIZE = 320
@@ -277,7 +279,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size):


 if 'GITHUB_ACTIONS' not in os.environ:
-    @pytest.mark.timeout(120)
+    @pytest.mark.timeout(240)
     @pytest.mark.parametrize('model_name', list_models(pretrained=True))
     @pytest.mark.parametrize('batch_size', [1])
     def test_model_load_pretrained(model_name, batch_size):
@@ -286,19 +288,13 @@ if 'GITHUB_ACTIONS' not in os.environ:
         create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=5)
         create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=0)

-    @pytest.mark.timeout(120)
+    @pytest.mark.timeout(240)
     @pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=NON_STD_FILTERS))
     @pytest.mark.parametrize('batch_size', [1])
     def test_model_features_pretrained(model_name, batch_size):
         """Create that pretrained weights load when features_only==True."""
         create_model(model_name, pretrained=True, features_only=True)

-EXCLUDE_JIT_FILTERS = [
-    '*iabn*', 'tresnet*',  # models using inplace abn unlikely to ever be scriptable
-    'dla*', 'hrnet*', 'ghostnet*'  # hopefully fix at some point
-    'vit_large_*', 'vit_huge_*', 'vit_gi*',
-]
-

 @pytest.mark.torchscript
 @pytest.mark.timeout(120)
@@ -52,6 +52,7 @@ def create_classifier(
         pool_type: str = 'avg',
         use_conv: bool = False,
         input_fmt: str = 'NCHW',
+        drop_rate: Optional[float] = None,
 ):
     global_pool, num_pooled_features = _create_pool(
         num_features,
@@ -65,6 +66,9 @@ def create_classifier(
         num_classes,
         use_conv=use_conv,
     )
+    if drop_rate is not None:
+        dropout = nn.Dropout(drop_rate)
+        return global_pool, dropout, fc
     return global_pool, fc

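Note the new contract: `create_classifier` still returns its usual `(global_pool, fc)` pair, but when `drop_rate` is passed it returns a `(global_pool, dropout, fc)` triple. A sketch of both calling conventions (feature/class counts illustrative; the DLA diff further down uses the three-tuple form):

    from timm.layers import create_classifier

    # Existing two-tuple form, unchanged when drop_rate is left as None:
    global_pool, fc = create_classifier(512, 1000, pool_type='avg')

    # New three-tuple form when a drop rate is supplied:
    global_pool, head_drop, fc = create_classifier(512, 1000, pool_type='avg', drop_rate=0.1)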
@@ -11,9 +11,26 @@ from .create_norm_act import get_norm_act_layer

 class ConvNormAct(nn.Module):
     def __init__(
-            self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
-            bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, drop_layer=None):
+            self,
+            in_channels,
+            out_channels,
+            kernel_size=1,
+            stride=1,
+            padding='',
+            dilation=1,
+            groups=1,
+            bias=False,
+            apply_act=True,
+            norm_layer=nn.BatchNorm2d,
+            norm_kwargs=None,
+            act_layer=nn.ReLU,
+            act_kwargs=None,
+            drop_layer=None,
+    ):
         super(ConvNormAct, self).__init__()
+        norm_kwargs = norm_kwargs or {}
+        act_kwargs = act_kwargs or {}
+
         self.conv = create_conv2d(
             in_channels, out_channels, kernel_size, stride=stride,
             padding=padding, dilation=dilation, groups=groups, bias=bias)
@@ -21,8 +38,14 @@ class ConvNormAct(nn.Module):
         # NOTE for backwards compatibility with models that use separate norm and act layer definitions
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
         # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
-        norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
-        self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
+        if drop_layer:
+            norm_kwargs['drop_layer'] = drop_layer
+        self.bn = norm_act_layer(
+            out_channels,
+            apply_act=apply_act,
+            act_kwargs=act_kwargs,
+            **norm_kwargs,
+        )

     @property
     def in_channels(self):
@@ -57,10 +80,27 @@ def create_aa(aa_layer, channels, stride=2, enable=True):

 class ConvNormActAa(nn.Module):
     def __init__(
-            self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
-            bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None):
+            self,
+            in_channels,
+            out_channels,
+            kernel_size=1,
+            stride=1,
+            padding='',
+            dilation=1,
+            groups=1,
+            bias=False,
+            apply_act=True,
+            norm_layer=nn.BatchNorm2d,
+            norm_kwargs=None,
+            act_layer=nn.ReLU,
+            act_kwargs=None,
+            aa_layer=None,
+            drop_layer=None,
+    ):
         super(ConvNormActAa, self).__init__()
         use_aa = aa_layer is not None and stride == 2
+        norm_kwargs = norm_kwargs or {}
+        act_kwargs = act_kwargs or {}
+
         self.conv = create_conv2d(
             in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
@@ -69,8 +109,9 @@ class ConvNormActAa(nn.Module):
         # NOTE for backwards compatibility with models that use separate norm and act layer definitions
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
         # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
-        norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
-        self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
+        if drop_layer:
+            norm_kwargs['drop_layer'] = drop_layer
+        self.bn = norm_act_layer(out_channels, apply_act=apply_act, act_kwargs=act_kwargs, **norm_kwargs)
         self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)

     @property
@@ -24,6 +24,18 @@ from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
 from .trace_utils import _assert


+def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True):
+    act_layer = get_act_layer(act_layer)  # string -> nn.Module
+    act_kwargs = act_kwargs or {}
+    if act_layer is not None and apply_act:
+        if inplace:
+            act_kwargs['inplace'] = inplace
+        act = act_layer(**act_kwargs)
+    else:
+        act = nn.Identity()
+    return act
+
+
 class BatchNormAct2d(nn.BatchNorm2d):
     """BatchNorm + Activation

@@ -40,31 +52,33 @@ class BatchNormAct2d(nn.BatchNorm2d):
             track_running_stats=True,
             apply_act=True,
             act_layer=nn.ReLU,
-            act_params=None,  # FIXME not the final approach
+            act_kwargs=None,
             inplace=True,
             drop_layer=None,
             device=None,
-            dtype=None
+            dtype=None,
     ):
         try:
             factory_kwargs = {'device': device, 'dtype': dtype}
             super(BatchNormAct2d, self).__init__(
-                num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats,
-                **factory_kwargs
+                num_features,
+                eps=eps,
+                momentum=momentum,
+                affine=affine,
+                track_running_stats=track_running_stats,
+                **factory_kwargs,
             )
         except TypeError:
             # NOTE for backwards compat with old PyTorch w/o factory device/dtype support
             super(BatchNormAct2d, self).__init__(
-                num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
+                num_features,
+                eps=eps,
+                momentum=momentum,
+                affine=affine,
+                track_running_stats=track_running_stats,
+            )
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
-        act_layer = get_act_layer(act_layer)  # string -> nn.Module
-        if act_layer is not None and apply_act:
-            act_args = dict(inplace=True) if inplace else {}
-            if act_params is not None:
-                act_args['negative_slope'] = act_params
-            self.act = act_layer(**act_args)
-        else:
-            self.act = nn.Identity()
+        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

     def forward(self, x):
         # cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing
@@ -188,6 +202,7 @@ class FrozenBatchNormAct2d(torch.nn.Module):
             eps: float = 1e-5,
             apply_act=True,
             act_layer=nn.ReLU,
+            act_kwargs=None,
             inplace=True,
             drop_layer=None,
     ):
@@ -199,12 +214,7 @@ class FrozenBatchNormAct2d(torch.nn.Module):
         self.register_buffer("running_var", torch.ones(num_features))

         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
-        act_layer = get_act_layer(act_layer)  # string -> nn.Module
-        if act_layer is not None and apply_act:
-            act_args = dict(inplace=True) if inplace else {}
-            self.act = act_layer(**act_args)
-        else:
-            self.act = nn.Identity()
+        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)

     def _load_from_state_dict(
             self,
@@ -344,6 +354,7 @@ class GroupNormAct(nn.GroupNorm):
             group_size=None,
             apply_act=True,
             act_layer=nn.ReLU,
+            act_kwargs=None,
             inplace=True,
             drop_layer=None,
     ):
@@ -354,12 +365,8 @@ class GroupNormAct(nn.GroupNorm):
             affine=affine,
         )
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
-        act_layer = get_act_layer(act_layer)  # string -> nn.Module
-        if act_layer is not None and apply_act:
-            act_args = dict(inplace=True) if inplace else {}
-            self.act = act_layer(**act_args)
-        else:
-            self.act = nn.Identity()
+        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
+
         self._fast_norm = is_fast_norm()

     def forward(self, x):
@@ -380,17 +387,14 @@ class GroupNorm1Act(nn.GroupNorm):
             affine=True,
             apply_act=True,
             act_layer=nn.ReLU,
+            act_kwargs=None,
             inplace=True,
             drop_layer=None,
     ):
         super(GroupNorm1Act, self).__init__(1, num_channels, eps=eps, affine=affine)
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
-        act_layer = get_act_layer(act_layer)  # string -> nn.Module
-        if act_layer is not None and apply_act:
-            act_args = dict(inplace=True) if inplace else {}
-            self.act = act_layer(**act_args)
-        else:
-            self.act = nn.Identity()
+        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
+
         self._fast_norm = is_fast_norm()

     def forward(self, x):
@@ -411,17 +415,15 @@ class LayerNormAct(nn.LayerNorm):
             affine=True,
             apply_act=True,
             act_layer=nn.ReLU,
+            act_kwargs=None,
             inplace=True,
             drop_layer=None,
     ):
         super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
-        act_layer = get_act_layer(act_layer)  # string -> nn.Module
-        if act_layer is not None and apply_act:
-            act_args = dict(inplace=True) if inplace else {}
-            self.act = act_layer(**act_args)
-        else:
-            self.act = nn.Identity()
+        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
+
         self._fast_norm = is_fast_norm()

     def forward(self, x):
@@ -442,17 +444,13 @@ class LayerNormAct2d(nn.LayerNorm):
             affine=True,
             apply_act=True,
             act_layer=nn.ReLU,
+            act_kwargs=None,
             inplace=True,
             drop_layer=None,
     ):
         super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine)
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
-        act_layer = get_act_layer(act_layer)  # string -> nn.Module
-        if act_layer is not None and apply_act:
-            act_args = dict(inplace=True) if inplace else {}
-            self.act = act_layer(**act_args)
-        else:
-            self.act = nn.Identity()
+        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
         self._fast_norm = is_fast_norm()

     def forward(self, x):
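The repeated per-class activation blocks are factored into the single `_create_act` helper above, and the old single-purpose `act_params` (which could only carry a `negative_slope`) is generalized to an `act_kwargs` dict forwarded to the activation constructor. A sketch of what the generalized argument allows, assuming `BatchNormAct2d` is exported from `timm.layers` like the other layers in this diff (values illustrative):

    import torch
    import torch.nn as nn
    from timm.layers import BatchNormAct2d

    # act_kwargs reaches the activation constructor via _create_act,
    # so any constructor argument works, not just negative_slope.
    bn_act = BatchNormAct2d(64, act_layer=nn.LeakyReLU, act_kwargs=dict(negative_slope=0.2))
    y = bn_act(torch.randn(2, 64, 8, 8))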
@@ -29,7 +29,7 @@ class PatchEmbed(nn.Module):

     def __init__(
             self,
-            img_size: int = 224,
+            img_size: Optional[int] = 224,
             patch_size: int = 16,
             in_chans: int = 3,
             embed_dim: int = 768,
@@ -39,12 +39,16 @@ class PatchEmbed(nn.Module):
             bias: bool = True,
     ):
         super().__init__()
-        img_size = to_2tuple(img_size)
-        patch_size = to_2tuple(patch_size)
-        self.img_size = img_size
-        self.patch_size = patch_size
-        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
-        self.num_patches = self.grid_size[0] * self.grid_size[1]
+        self.patch_size = to_2tuple(patch_size)
+        if img_size is not None:
+            self.img_size = to_2tuple(img_size)
+            self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
+            self.num_patches = self.grid_size[0] * self.grid_size[1]
+        else:
+            self.img_size = None
+            self.grid_size = None
+            self.num_patches = None

         if output_fmt is not None:
             self.flatten = False
             self.output_fmt = Format(output_fmt)
@@ -58,8 +62,10 @@ class PatchEmbed(nn.Module):

     def forward(self, x):
         B, C, H, W = x.shape
-        _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
-        _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
+        if self.img_size is not None:
+            _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
+            _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
+
         x = self.proj(x)
         if self.flatten:
             x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
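With `img_size` now optional, passing `img_size=None` skips the grid/patch-count bookkeeping and the input-size asserts in `forward`, so one embedding module can serve variable input resolutions. A short sketch (class and arguments as in the diff; output shapes assume the default NCHW -> NLC flatten):

    import torch
    from timm.layers import PatchEmbed

    embed = PatchEmbed(img_size=None, patch_size=16, in_chans=3, embed_dim=768)
    # No size assert: any H, W divisible by the patch size is accepted.
    tokens = embed(torch.randn(1, 3, 224, 224))  # (1, 196, 768)
    tokens = embed(torch.randn(1, 3, 320, 320))  # (1, 400, 768)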
@@ -3,6 +3,8 @@ import torch.nn as nn


 class SpaceToDepth(nn.Module):
+    bs: torch.jit.Final[int]
+
     def __init__(self, block_size=4):
         super().__init__()
         assert block_size == 4
@@ -12,7 +14,7 @@ class SpaceToDepth(nn.Module):
         N, C, H, W = x.size()
         x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs)  # (N, C, H//bs, bs, W//bs, bs)
         x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
-        x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs)  # (N, C*bs^2, H//bs, W//bs)
+        x = x.view(N, C * self.bs * self.bs, H // self.bs, W // self.bs)  # (N, C*bs^2, H//bs, W//bs)
         return x

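Two small scriptability tweaks here: `bs` is declared `torch.jit.Final[int]` so TorchScript treats it as a constant attribute, and the `C * (self.bs ** 2)` view is rewritten as `C * self.bs * self.bs`. A self-contained sketch mirroring the module as it stands after this hunk:

    import torch
    import torch.nn as nn

    class SpaceToDepth(nn.Module):
        bs: torch.jit.Final[int]  # Final lets TorchScript fold bs into the graph as a constant

        def __init__(self, block_size=4):
            super().__init__()
            assert block_size == 4
            self.bs = block_size

        def forward(self, x):
            N, C, H, W = x.size()
            x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs)
            x = x.permute(0, 3, 5, 1, 2, 4).contiguous()
            return x.view(N, C * self.bs * self.bs, H // self.bs, W // self.bs)

    m = torch.jit.script(SpaceToDepth())
    y = m(torch.randn(2, 3, 16, 16))
    assert y.shape == (2, 48, 4, 4)  # channels x bs^2, spatial / bs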
@@ -21,7 +21,6 @@ from .eva import *
 from .focalnet import *
 from .gcvit import *
 from .ghostnet import *
-from .gluon_xception import *
 from .hardcorenas import *
 from .hrnet import *
 from .inception_resnet_v2 import *
@@ -15,43 +15,13 @@ import torch.nn as nn
 import torch.nn.functional as F

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert
+from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, LayerNorm
 from ._builder import build_model_with_cfg
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs

 __all__ = ['CoaT']


-def _cfg_coat(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
-        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'patch_embed1.proj', 'classifier': 'head',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    'coat_tiny': _cfg_coat(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_tiny-473c2a20.pth'
-    ),
-    'coat_mini': _cfg_coat(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_mini-2c6baf49.pth'
-    ),
-    'coat_lite_tiny': _cfg_coat(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_tiny-461b07a7.pth'
-    ),
-    'coat_lite_mini': _cfg_coat(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_mini-d7842000.pth'
-    ),
-    'coat_lite_small': _cfg_coat(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_small-fea1d5a1.pth'
-    ),
-}
-
-
 class ConvRelPosEnc(nn.Module):
     """ Convolutional relative position encoding. """
     def __init__(self, head_chs, num_heads, window):
@@ -147,7 +117,7 @@ class FactorAttnConvRelPosEnc(nn.Module):

         # Generate Q, K, V.
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]  # [B, h, N, Ch]
+        q, k, v = qkv.unbind(0)  # [B, h, N, Ch]

         # Factorized attention.
         k_softmax = k.softmax(dim=2)
@@ -334,7 +304,12 @@ class ParallelBlock(nn.Module):

         img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W)
         img_tokens = F.interpolate(
-            img_tokens, scale_factor=scale_factor, recompute_scale_factor=False, mode='bilinear', align_corners=False)
+            img_tokens,
+            scale_factor=scale_factor,
+            recompute_scale_factor=False,
+            mode='bilinear',
+            align_corners=False,
+        )
         img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2)

         out = torch.cat((cls_token, img_tokens), dim=1)
@@ -384,17 +359,17 @@ class CoaT(nn.Module):
             patch_size=16,
             in_chans=3,
             num_classes=1000,
-            embed_dims=(0, 0, 0, 0),
-            serial_depths=(0, 0, 0, 0),
+            embed_dims=(64, 128, 320, 512),
+            serial_depths=(3, 4, 6, 3),
             parallel_depth=0,
-            num_heads=0,
-            mlp_ratios=(0, 0, 0, 0),
+            num_heads=8,
+            mlp_ratios=(4, 4, 4, 4),
             qkv_bias=True,
-            drop_rate=0.,
+            proj_drop_rate=0.,
             attn_drop_rate=0.,
             drop_path_rate=0.,
-            norm_layer=partial(nn.LayerNorm, eps=1e-6),
+            norm_layer=LayerNorm,
             return_interm_layers=False,
             out_features=None,
             crpe_window=None,
@@ -711,6 +686,7 @@ def remove_cls(x):

 def checkpoint_filter_fn(state_dict, model):
     out_dict = {}
+    state_dict = state_dict.get('model', state_dict)
     for k, v in state_dict.items():
         # original model had unused norm layers, removing them requires filtering pretrained checkpoints
         if k.startswith('norm1') or \
@@ -726,52 +702,100 @@ def _create_coat(variant, pretrained=False, default_cfg=None, **kwargs):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')

     model = build_model_with_cfg(
-        CoaT, variant, pretrained,
+        CoaT,
+        variant,
+        pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
-        **kwargs)
+        **kwargs,
+    )
     return model


+def _cfg_coat(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'patch_embed1.proj', 'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'coat_tiny.in1k': _cfg_coat(hf_hub_id='timm/'),
+    'coat_mini.in1k': _cfg_coat(hf_hub_id='timm/'),
+    'coat_small.in1k': _cfg_coat(hf_hub_id='timm/'),
+    'coat_lite_tiny.in1k': _cfg_coat(hf_hub_id='timm/'),
+    'coat_lite_mini.in1k': _cfg_coat(hf_hub_id='timm/'),
+    'coat_lite_small.in1k': _cfg_coat(hf_hub_id='timm/'),
+    'coat_lite_medium.in1k': _cfg_coat(hf_hub_id='timm/'),
+    'coat_lite_medium_384.in1k': _cfg_coat(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash',
+    ),
+})
+
+
 @register_model
 def coat_tiny(pretrained=False, **kwargs):
     model_cfg = dict(
-        patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6,
-        num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
-    model = _create_coat('coat_tiny', pretrained=pretrained, **model_cfg)
+        patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6)
+    model = _create_coat('coat_tiny', pretrained=pretrained, **dict(model_cfg, **kwargs))
     return model


 @register_model
 def coat_mini(pretrained=False, **kwargs):
     model_cfg = dict(
-        patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6,
-        num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
-    model = _create_coat('coat_mini', pretrained=pretrained, **model_cfg)
+        patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6)
+    model = _create_coat('coat_mini', pretrained=pretrained, **dict(model_cfg, **kwargs))
     return model


+@register_model
+def coat_small(pretrained=False, **kwargs):
+    model_cfg = dict(
+        patch_size=4, embed_dims=[152, 320, 320, 320], serial_depths=[2, 2, 2, 2], parallel_depth=6, **kwargs)
+    model = _create_coat('coat_small', pretrained=pretrained, **dict(model_cfg, **kwargs))
+    return model
+
+
 @register_model
 def coat_lite_tiny(pretrained=False, **kwargs):
     model_cfg = dict(
-        patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0,
-        num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
-    model = _create_coat('coat_lite_tiny', pretrained=pretrained, **model_cfg)
+        patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], mlp_ratios=[8, 8, 4, 4])
+    model = _create_coat('coat_lite_tiny', pretrained=pretrained, **dict(model_cfg, **kwargs))
     return model


 @register_model
 def coat_lite_mini(pretrained=False, **kwargs):
     model_cfg = dict(
-        patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0,
-        num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
-    model = _create_coat('coat_lite_mini', pretrained=pretrained, **model_cfg)
+        patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], mlp_ratios=[8, 8, 4, 4])
+    model = _create_coat('coat_lite_mini', pretrained=pretrained, **dict(model_cfg, **kwargs))
     return model


 @register_model
 def coat_lite_small(pretrained=False, **kwargs):
     model_cfg = dict(
-        patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0,
-        num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
-    model = _create_coat('coat_lite_small', pretrained=pretrained, **model_cfg)
+        patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], mlp_ratios=[8, 8, 4, 4])
+    model = _create_coat('coat_lite_small', pretrained=pretrained, **dict(model_cfg, **kwargs))
     return model
+
+
+@register_model
+def coat_lite_medium(pretrained=False, **kwargs):
+    model_cfg = dict(
+        patch_size=4, embed_dims=[128, 256, 320, 512], serial_depths=[3, 6, 10, 8])
+    model = _create_coat('coat_lite_medium', pretrained=pretrained, **dict(model_cfg, **kwargs))
+    return model
+
+
+@register_model
+def coat_lite_medium_384(pretrained=False, **kwargs):
+    model_cfg = dict(
+        img_size=384, patch_size=4, embed_dims=[128, 256, 320, 512], serial_depths=[3, 6, 10, 8])
+    model = _create_coat('coat_lite_medium_384', pretrained=pretrained, **dict(model_cfg, **kwargs))
+    return model
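The entrypoint rewrite repeated across coat above (and convit, convmixer, dla below) moves `**kwargs` out of the defaults dict and merges at the call site with `dict(model_cfg, **kwargs)`, so caller kwargs override the variant defaults instead of colliding with them. The precedence, sketched on a toy dict:

    model_cfg = dict(patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2])
    kwargs = dict(serial_depths=[3, 4, 6, 3], num_classes=0)

    merged = dict(model_cfg, **kwargs)
    assert merged['serial_depths'] == [3, 4, 6, 3]  # caller's kwargs win
    assert merged['patch_size'] == 4                # untouched defaults survive

    # The old pattern, dict(serial_depths=[2, 2, 2, 2], ..., **kwargs), would instead raise:
    # TypeError: dict() got multiple values for keyword argument 'serial_depths'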
@@ -28,37 +28,16 @@ import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp
+from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp, LayerNorm
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_module
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs
 from .vision_transformer_hybrid import HybridEmbed


 __all__ = ['ConViT']

-
-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
-        'first_conv': 'patch_embed.proj', 'classifier': 'head',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    # ConViT
-    'convit_tiny': _cfg(
-        url="https://dl.fbaipublicfiles.com/convit/convit_tiny.pth"),
-    'convit_small': _cfg(
-        url="https://dl.fbaipublicfiles.com/convit/convit_small.pth"),
-    'convit_base': _cfg(
-        url="https://dl.fbaipublicfiles.com/convit/convit_base.pth")
-}
-

 @register_notrace_module  # reason: FX can't symbolically trace control flow in forward method
 class GPSA(nn.Module):
     def __init__(
@@ -218,7 +197,7 @@ class Block(nn.Module):
             attn_drop=0.,
             drop_path=0.,
             act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm,
+            norm_layer=LayerNorm,
             use_gpsa=True,
             locality_strength=1.,
     ):
@@ -280,7 +259,7 @@ class ConViT(nn.Module):
             attn_drop_rate=0.,
             drop_path_rate=0.,
             hybrid_backbone=None,
-            norm_layer=nn.LayerNorm,
+            norm_layer=LayerNorm,
             local_up_to_layer=3,
             locality_strength=1.,
             use_pos_embed=True,
@@ -300,7 +279,11 @@ class ConViT(nn.Module):
                 hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
         else:
             self.patch_embed = PatchEmbed(
-                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+                img_size=img_size,
+                patch_size=patch_size,
+                in_chans=in_chans,
+                embed_dim=embed_dim,
+            )
         num_patches = self.patch_embed.num_patches
         self.num_patches = num_patches

@@ -405,28 +388,43 @@ def _create_convit(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(ConViT, variant, pretrained, **kwargs)


+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
+        'first_conv': 'patch_embed.proj', 'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    # ConViT
+    'convit_tiny.fb_in1k': _cfg(hf_hub_id='timm/'),
+    'convit_small.fb_in1k': _cfg(hf_hub_id='timm/'),
+    'convit_base.fb_in1k': _cfg(hf_hub_id='timm/')
+})
+
+
 @register_model
 def convit_tiny(pretrained=False, **kwargs):
     model_args = dict(
-        local_up_to_layer=10, locality_strength=1.0, embed_dim=48,
-        num_heads=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
-    model = _create_convit(variant='convit_tiny', pretrained=pretrained, **model_args)
+        local_up_to_layer=10, locality_strength=1.0, embed_dim=48, num_heads=4)
+    model = _create_convit(variant='convit_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def convit_small(pretrained=False, **kwargs):
     model_args = dict(
-        local_up_to_layer=10, locality_strength=1.0, embed_dim=48,
-        num_heads=9, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
-    model = _create_convit(variant='convit_small', pretrained=pretrained, **model_args)
+        local_up_to_layer=10, locality_strength=1.0, embed_dim=48, num_heads=9)
+    model = _create_convit(variant='convit_small', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def convit_base(pretrained=False, **kwargs):
     model_args = dict(
-        local_up_to_layer=10, locality_strength=1.0, embed_dim=48,
-        num_heads=16, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
-    model = _create_convit(variant='convit_base', pretrained=pretrained, **model_args)
+        local_up_to_layer=10, locality_strength=1.0, embed_dim=48, num_heads=16)
+    model = _create_convit(variant='convit_base', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -6,31 +6,13 @@ import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import SelectAdaptivePool2d
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq

 __all__ = ['ConvMixer']

-
-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
-        'crop_pct': .96, 'interpolation': 'bicubic',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head',
-        'first_conv': 'stem.0',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    'convmixer_1536_20': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_1536_20_ks9_p7.pth.tar'),
-    'convmixer_768_32': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_768_32_ks7_p7_relu.pth.tar'),
-    'convmixer_1024_20_ks9_p14': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_1024_20_ks9_p14.pth.tar')
-}
-

 class Residual(nn.Module):
     def __init__(self, fn):
         super().__init__()
@@ -122,6 +104,25 @@ def _create_convmixer(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(ConvMixer, variant, pretrained, **kwargs)


+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .96, 'interpolation': 'bicubic',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head',
+        'first_conv': 'stem.0',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'convmixer_1536_20.in1k': _cfg(hf_hub_id='timm/'),
+    'convmixer_768_32.in1k': _cfg(hf_hub_id='timm/'),
+    'convmixer_1024_20_ks9_p14.in1k': _cfg(hf_hub_id='timm/')
+})
+
+
 @register_model
 def convmixer_1536_20(pretrained=False, **kwargs):
     model_args = dict(dim=1536, depth=20, kernel_size=9, patch_size=7, **kwargs)
@@ -36,56 +36,12 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, to_2tuple, trunc_normal_, _assert
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs
 from .vision_transformer import Block

 __all__ = ['CrossViT']  # model_registry will add each entrypoint fn to this

-
-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None, 'crop_pct': 0.875,
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
-        'first_conv': ('patch_embed.0.proj', 'patch_embed.1.proj'),
-        'classifier': ('head.0', 'head.1'),
-        **kwargs
-    }
-
-
-default_cfgs = {
-    'crossvit_15_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_224.pth'),
-    'crossvit_15_dagger_240': _cfg(
-        url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_224.pth',
-        first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
-    ),
-    'crossvit_15_dagger_408': _cfg(
-        url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_384.pth',
-        input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
-    ),
-    'crossvit_18_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_224.pth'),
-    'crossvit_18_dagger_240': _cfg(
-        url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_224.pth',
-        first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
-    ),
-    'crossvit_18_dagger_408': _cfg(
-        url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_384.pth',
-        input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
-    ),
-    'crossvit_9_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_224.pth'),
-    'crossvit_9_dagger_240': _cfg(
-        url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_dagger_224.pth',
-        first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
-    ),
-    'crossvit_base_240': _cfg(
-        url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_base_224.pth'),
-    'crossvit_small_240': _cfg(
-        url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_small_224.pth'),
-    'crossvit_tiny_240': _cfg(
-        url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_tiny_224.pth'),
-}
-

 class PatchEmbed(nn.Module):
     """ Image to Patch Embedding
     """
@@ -531,6 +487,47 @@ def _create_crossvit(variant, pretrained=False, **kwargs):
     )


+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None, 'crop_pct': 0.875,
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
+        'first_conv': ('patch_embed.0.proj', 'patch_embed.1.proj'),
+        'classifier': ('head.0', 'head.1'),
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'crossvit_15_240.in1k': _cfg(hf_hub_id='timm/'),
+    'crossvit_15_dagger_240.in1k': _cfg(
+        hf_hub_id='timm/',
+        first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
+    ),
+    'crossvit_15_dagger_408.in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
+    ),
+    'crossvit_18_240.in1k': _cfg(hf_hub_id='timm/'),
+    'crossvit_18_dagger_240.in1k': _cfg(
+        hf_hub_id='timm/',
+        first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
+    ),
+    'crossvit_18_dagger_408.in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
+    ),
+    'crossvit_9_240.in1k': _cfg(hf_hub_id='timm/'),
+    'crossvit_9_dagger_240.in1k': _cfg(
+        hf_hub_id='timm/',
+        first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
+    ),
+    'crossvit_base_240.in1k': _cfg(hf_hub_id='timm/'),
+    'crossvit_small_240.in1k': _cfg(hf_hub_id='timm/'),
+    'crossvit_tiny_240.in1k': _cfg(hf_hub_id='timm/'),
+})
+
+
 @register_model
 def crossvit_tiny_240(pretrained=False, **kwargs):
     model_args = dict(
@@ -1,5 +1,5 @@
 """ Deep Layer Aggregation and DLA w/ Res2Net
-DLA original adapted from Official Pytorch impl at:
+DLA original adapted from Official Pytorch impl at: https://github.com/ucbdrive/dla
 DLA Paper: `Deep Layer Aggregation` - https://arxiv.org/abs/1707.06484

 Res2Net additions from: https://github.com/gasvn/Res2Net/
@@ -15,55 +15,28 @@ import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import create_classifier
 from ._builder import build_model_with_cfg
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs

 __all__ = ['DLA']

-
-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bilinear',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'base_layer.0', 'classifier': 'fc',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    'dla34': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dla34-2b83ff04.pth'),
-    'dla46_c': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dla46_c-9b68d685.pth'),
-    'dla46x_c': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dla46x_c-6bc5b5c8.pth'),
-    'dla60x_c': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dla60x_c-a38e054a.pth'),
-    'dla60': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dla60-9e91bd4d.pth'),
-    'dla60x': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dla60x-6818f6bb.pth'),
-    'dla102': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dla102-21f57b54.pth'),
-    'dla102x': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dla102x-7ec0aa2a.pth'),
-    'dla102x2': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dla102x2-ac4239c4.pth'),
-    'dla169': _cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dla169-7c767967.pth'),
-    'dla60_res2net': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net_dla60_4s-d88db7f9.pth'),
-    'dla60_res2next': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next_dla60_4s-d327927b.pth'),
-}
-

 class DlaBasic(nn.Module):
     """DLA Basic"""

     def __init__(self, inplanes, planes, stride=1, dilation=1, **_):
         super(DlaBasic, self).__init__()
         self.conv1 = nn.Conv2d(
-            inplanes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation)
+            inplanes, planes, kernel_size=3,
+            stride=stride, padding=dilation, bias=False, dilation=dilation)
         self.bn1 = nn.BatchNorm2d(planes)
         self.relu = nn.ReLU(inplace=True)
         self.conv2 = nn.Conv2d(
-            planes, planes, kernel_size=3, stride=1, padding=dilation, bias=False, dilation=dilation)
+            planes, planes, kernel_size=3,
+            stride=1, padding=dilation, bias=False, dilation=dilation)
         self.bn2 = nn.BatchNorm2d(planes)
         self.stride = stride

-    def forward(self, x, shortcut=None, children: Optional[List[torch.Tensor]] = None):
+    def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None):
         if shortcut is None:
             shortcut = x

@@ -93,8 +66,8 @@ class DlaBottleneck(nn.Module):
         self.conv1 = nn.Conv2d(inplanes, mid_planes, kernel_size=1, bias=False)
         self.bn1 = nn.BatchNorm2d(mid_planes)
         self.conv2 = nn.Conv2d(
-            mid_planes, mid_planes, kernel_size=3, stride=stride, padding=dilation,
-            bias=False, dilation=dilation, groups=cardinality)
+            mid_planes, mid_planes, kernel_size=3,
+            stride=stride, padding=dilation, bias=False, dilation=dilation, groups=cardinality)
         self.bn2 = nn.BatchNorm2d(mid_planes)
         self.conv3 = nn.Conv2d(mid_planes, outplanes, kernel_size=1, bias=False)
         self.bn3 = nn.BatchNorm2d(outplanes)
@@ -143,8 +116,8 @@ class DlaBottle2neck(nn.Module):
         bns = []
         for _ in range(num_scale_convs):
             convs.append(nn.Conv2d(
-                mid_planes, mid_planes, kernel_size=3, stride=stride,
-                padding=dilation, dilation=dilation, groups=cardinality, bias=False))
+                mid_planes, mid_planes, kernel_size=3,
+                stride=stride, padding=dilation, dilation=dilation, groups=cardinality, bias=False))
             bns.append(nn.BatchNorm2d(mid_planes))
         self.convs = nn.ModuleList(convs)
         self.bns = nn.ModuleList(bns)
@@ -211,8 +184,20 @@ class DlaRoot(nn.Module):

 class DlaTree(nn.Module):
     def __init__(
-            self, levels, block, in_channels, out_channels, stride=1, dilation=1, cardinality=1,
-            base_width=64, level_root=False, root_dim=0, root_kernel_size=1, root_shortcut=False):
+            self,
+            levels,
+            block,
+            in_channels,
+            out_channels,
+            stride=1,
+            dilation=1,
+            cardinality=1,
+            base_width=64,
+            level_root=False,
+            root_dim=0,
+            root_kernel_size=1,
+            root_shortcut=False,
+    ):
         super(DlaTree, self).__init__()
         if root_dim == 0:
             root_dim = 2 * out_channels
@@ -235,9 +220,22 @@ class DlaTree(nn.Module):
         else:
             cargs.update(dict(root_kernel_size=root_kernel_size, root_shortcut=root_shortcut))
             self.tree1 = DlaTree(
-                levels - 1, block, in_channels, out_channels, stride, root_dim=0, **cargs)
+                levels - 1,
+                block,
+                in_channels,
+                out_channels,
+                stride,
+                root_dim=0,
+                **cargs,
+            )
             self.tree2 = DlaTree(
-                levels - 1, block, out_channels, out_channels, root_dim=root_dim + out_channels, **cargs)
+                levels - 1,
+                block,
+                out_channels,
+                out_channels,
+                root_dim=root_dim + out_channels,
+                **cargs,
+            )
+            self.root = None
         self.level_root = level_root
         self.root_dim = root_dim
@@ -262,20 +260,31 @@ class DlaTree(nn.Module):

 class DLA(nn.Module):
     def __init__(
-            self, levels, channels, output_stride=32, num_classes=1000, in_chans=3, global_pool='avg',
-            cardinality=1, base_width=64, block=DlaBottle2neck, shortcut_root=False, drop_rate=0.0):
+            self,
+            levels,
+            channels,
+            output_stride=32,
+            num_classes=1000,
+            in_chans=3,
+            global_pool='avg',
+            cardinality=1,
+            base_width=64,
+            block=DlaBottle2neck,
+            shortcut_root=False,
+            drop_rate=0.0,
+    ):
         super(DLA, self).__init__()
         self.channels = channels
         self.num_classes = num_classes
         self.cardinality = cardinality
         self.base_width = base_width
-        self.drop_rate = drop_rate
         assert output_stride == 32  # FIXME support dilation

         self.base_layer = nn.Sequential(
             nn.Conv2d(in_chans, channels[0], kernel_size=7, stride=1, padding=3, bias=False),
             nn.BatchNorm2d(channels[0]),
-            nn.ReLU(inplace=True))
+            nn.ReLU(inplace=True),
+        )
         self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
         self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2)
         cargs = dict(cardinality=cardinality, base_width=base_width, root_shortcut=shortcut_root)
@@ -293,8 +302,13 @@ class DLA(nn.Module):
         ]

         self.num_features = channels[-1]
-        self.global_pool, self.fc = create_classifier(
-            self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
+        self.global_pool, self.head_drop, self.fc = create_classifier(
+            self.num_features,
+            self.num_classes,
+            pool_type=global_pool,
+            use_conv=True,
+            drop_rate=drop_rate,
+        )
         self.flatten = nn.Flatten(1) if global_pool else nn.Identity()

         for m in self.modules():
@@ -310,7 +324,8 @@ class DLA(nn.Module):
         for i in range(convs):
             modules.extend([
                 nn.Conv2d(
-                    inplanes, planes, kernel_size=3, stride=stride if i == 0 else 1,
+                    inplanes, planes, kernel_size=3,
+                    stride=stride if i == 0 else 1,
                     padding=dilation, bias=False, dilation=dilation),
                 nn.BatchNorm2d(planes),
                 nn.ReLU(inplace=True)])
@@ -356,8 +371,7 @@ class DLA(nn.Module):

     def forward_head(self, x, pre_logits: bool = False):
         x = self.global_pool(x)
-        if self.drop_rate > 0.:
-            x = F.dropout(x, p=self.drop_rate, training=self.training)
+        x = self.head_drop(x)
         if pre_logits:
             return self.flatten(x)
         x = self.fc(x)
@@ -371,103 +385,131 @@ class DLA(nn.Module):

 def _create_dla(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
-        DLA, variant, pretrained,
+        DLA,
+        variant,
+        pretrained,
         pretrained_strict=False,
         feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)),
-        **kwargs)
+        **kwargs,
+    )


+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bilinear',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'base_layer.0', 'classifier': 'fc',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'dla34.in1k': _cfg(hf_hub_id='timm/'),
+    'dla46_c.in1k': _cfg(hf_hub_id='timm/'),
+    'dla46x_c.in1k': _cfg(hf_hub_id='timm/'),
+    'dla60x_c.in1k': _cfg(hf_hub_id='timm/'),
+    'dla60.in1k': _cfg(hf_hub_id='timm/'),
+    'dla60x.in1k': _cfg(hf_hub_id='timm/'),
+    'dla102.in1k': _cfg(hf_hub_id='timm/'),
+    'dla102x.in1k': _cfg(hf_hub_id='timm/'),
+    'dla102x2.in1k': _cfg(hf_hub_id='timm/'),
+    'dla169': _cfg(hf_hub_id='timm/'),
+    'dla60_res2net.in1k': _cfg(hf_hub_id='timm/'),
+    'dla60_res2next.in1k': _cfg(hf_hub_id='timm/'),
+})
+
+
 @register_model
 def dla60_res2net(pretrained=False, **kwargs):
-    model_kwargs = dict(
+    model_args = dict(
         levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024),
-        block=DlaBottle2neck, cardinality=1, base_width=28, **kwargs)
-    return _create_dla('dla60_res2net', pretrained, **model_kwargs)
+        block=DlaBottle2neck, cardinality=1, base_width=28)
+    return _create_dla('dla60_res2net', pretrained, **dict(model_args, **kwargs))


 @register_model
 def dla60_res2next(pretrained=False,**kwargs):
-    model_kwargs = dict(
+    model_args = dict(
         levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024),
-        block=DlaBottle2neck, cardinality=8, base_width=4, **kwargs)
-    return _create_dla('dla60_res2next', pretrained, **model_kwargs)
+        block=DlaBottle2neck, cardinality=8, base_width=4)
+    return _create_dla('dla60_res2next', pretrained, **dict(model_args, **kwargs))


 @register_model
 def dla34(pretrained=False, **kwargs):  # DLA-34
-    model_kwargs = dict(
-        levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 128, 256, 512],
-        block=DlaBasic, **kwargs)
-    return _create_dla('dla34', pretrained, **model_kwargs)
+    model_args = dict(
+        levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 128, 256, 512], block=DlaBasic)
+    return _create_dla('dla34', pretrained, **dict(model_args, **kwargs))


 @register_model
 def dla46_c(pretrained=False, **kwargs):  # DLA-46-C
-    model_kwargs = dict(
-        levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256],
-        block=DlaBottleneck, **kwargs)
-    return _create_dla('dla46_c', pretrained, **model_kwargs)
+    model_args = dict(
+        levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256], block=DlaBottleneck)
+    return _create_dla('dla46_c', pretrained, **dict(model_args, **kwargs))


 @register_model
 def dla46x_c(pretrained=False, **kwargs):  # DLA-X-46-C
-    model_kwargs = dict(
+    model_args = dict(
         levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256],
-        block=DlaBottleneck, cardinality=32, base_width=4, **kwargs)
-    return _create_dla('dla46x_c', pretrained, **model_kwargs)
+        block=DlaBottleneck, cardinality=32, base_width=4)
+    return _create_dla('dla46x_c', pretrained, **dict(model_args, **kwargs))


 @register_model
 def dla60x_c(pretrained=False, **kwargs):  # DLA-X-60-C
-    model_kwargs = dict(
+    model_args = dict(
         levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 64, 64, 128, 256],
-        block=DlaBottleneck, cardinality=32, base_width=4, **kwargs)
-    return _create_dla('dla60x_c', pretrained, **model_kwargs)
+        block=DlaBottleneck, cardinality=32, base_width=4)
+    return _create_dla('dla60x_c', pretrained, **dict(model_args, **kwargs))


 @register_model
 def dla60(pretrained=False, **kwargs):  # DLA-60
-    model_kwargs = dict(
+    model_args = dict(
         levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, **kwargs)
-    return _create_dla('dla60', pretrained, **model_kwargs)
+        block=DlaBottleneck)
+    return _create_dla('dla60', pretrained, **dict(model_args, **kwargs))


 @register_model
 def dla60x(pretrained=False, **kwargs):  # DLA-X-60
-    model_kwargs = dict(
+    model_args = dict(
         levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, cardinality=32, base_width=4, **kwargs)
-    return _create_dla('dla60x', pretrained, **model_kwargs)
+        block=DlaBottleneck, cardinality=32, base_width=4)
+    return _create_dla('dla60x', pretrained, **dict(model_args, **kwargs))


 @register_model
 def dla102(pretrained=False, **kwargs):  # DLA-102
-    model_kwargs = dict(
+    model_args = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, shortcut_root=True, **kwargs)
-    return _create_dla('dla102', pretrained, **model_kwargs)
+        block=DlaBottleneck, shortcut_root=True)
+    return _create_dla('dla102', pretrained, **dict(model_args, **kwargs))


 @register_model
 def dla102x(pretrained=False, **kwargs):  # DLA-X-102
-    model_kwargs = dict(
+    model_args = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, cardinality=32, base_width=4, shortcut_root=True, **kwargs)
-    return _create_dla('dla102x', pretrained, **model_kwargs)
+        block=DlaBottleneck, cardinality=32, base_width=4, shortcut_root=True)
+    return _create_dla('dla102x', pretrained, **dict(model_args, **kwargs))


 @register_model
 def dla102x2(pretrained=False, **kwargs):  # DLA-X-102 64
-    model_kwargs = dict(
+    model_args = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, cardinality=64, base_width=4, shortcut_root=True, **kwargs)
-    return _create_dla('dla102x2', pretrained, **model_kwargs)
+        block=DlaBottleneck, cardinality=64, base_width=4, shortcut_root=True)
+    return _create_dla('dla102x2', pretrained, **dict(model_args, **kwargs))


 @register_model
 def dla169(pretrained=False, **kwargs):  # DLA-169
-    model_kwargs = dict(
+    model_args = dict(
         levels=[1, 1, 2, 3, 5, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, shortcut_root=True, **kwargs)
-    return _create_dla('dla169', pretrained, **model_kwargs)
+        block=DlaBottleneck, shortcut_root=True)
+    return _create_dla('dla169', pretrained, **dict(model_args, **kwargs))
@ -17,7 +17,8 @@ import torch.nn.functional as F
|
|||
from torch import nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d
|
||||
from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d, \
|
||||
use_fused_attn
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_module
|
||||
from ._manipulate import named_apply, checkpoint_seq
|
||||
|
@ -37,14 +38,16 @@ class PositionalEncodingFourier(nn.Module):
|
|||
self.dim = dim
|
||||
|
||||
def forward(self, shape: Tuple[int, int, int]):
|
||||
inv_mask = ~torch.zeros(shape).to(device=self.token_projection.weight.device, dtype=torch.bool)
|
||||
y_embed = inv_mask.cumsum(1, dtype=torch.float32)
|
||||
x_embed = inv_mask.cumsum(2, dtype=torch.float32)
|
||||
device = self.token_projection.weight.device
|
||||
dtype = self.token_projection.weight.dtype
|
||||
inv_mask = ~torch.zeros(shape).to(device=device, dtype=torch.bool)
|
||||
y_embed = inv_mask.cumsum(1, dtype=dtype)
|
||||
x_embed = inv_mask.cumsum(2, dtype=dtype)
|
||||
eps = 1e-6
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=inv_mask.device)
|
||||
dim_t = torch.arange(self.hidden_dim, dtype=dtype, device=device)
|
||||
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
|
@ -129,9 +132,9 @@ class CrossCovarianceAttn(nn.Module):
|
|||
attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) * self.temperature
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = (attn @ v)
|
||||
|
||||
x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
|
||||
|
||||
x = x.permute(0, 3, 1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
@@ -494,25 +497,25 @@ def _cfg(url='', **kwargs):

default_cfgs = generate_default_cfgs({
    'edgenext_xx_small.in1k': _cfg(
        url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_xx_small.pth",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'edgenext_x_small.in1k': _cfg(
        url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_x_small.pth",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'edgenext_small.usi_in1k': _cfg(  # USI weights
        url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth",
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
    ),
    'edgenext_base.usi_in1k': _cfg(  # USI weights
        url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth",
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
    ),
    'edgenext_base.in21k_ft_in1k': _cfg(  # IN-21K pretrained, fine-tuned on IN-1K
        url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.21/edgenext_base_IN21K.pth",
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
    ),
    'edgenext_small_rw.sw_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 320, 320), test_crop_pct=1.0,
    ),
})
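Under the multi-weight scheme each pretrained cfg is keyed as `model_name.tag` and can be pulled from the HF hub. A quick usage sketch using timm's public `list_models`/`create_model` API (the printed list is indicative, not exhaustive):

```python
import timm

# List the weight tags registered for a model family.
print(timm.list_models('edgenext_base*', pretrained=True))
# e.g. ['edgenext_base.in21k_ft_in1k', 'edgenext_base.usi_in1k']

# Select one specific pretrained cfg by its tag.
model = timm.create_model('edgenext_base.usi_in1k', pretrained=True).eval()
```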
@@ -961,10 +961,10 @@ default_cfgs = generate_default_cfgs({
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b4_ra2_320-7eb33cd5.pth',
        hf_hub_id='timm/',
        input_size=(3, 320, 320), pool_size=(10, 10), test_input_size=(3, 384, 384), crop_pct=1.0),
-   'efficientnet_b5.in12k_ft_in1k': _cfg(
+   'efficientnet_b5.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, crop_mode='squash'),
-   'efficientnet_b5.in12k': _cfg(
+   'efficientnet_b5.sw_in12k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 416, 416), pool_size=(13, 13), crop_pct=0.95, num_classes=11821),
    'efficientnet_b6.untrained': _cfg(
@@ -1149,6 +1149,19 @@ default_cfgs = generate_default_cfgs({
        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
        input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),

+   'tf_efficientnet_b5.ra_in1k': _cfg(
+       url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth',
+       hf_hub_id='timm/',
+       input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
+   'tf_efficientnet_b7.ra_in1k': _cfg(
+       url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth',
+       hf_hub_id='timm/',
+       input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
+   'tf_efficientnet_b8.ra_in1k': _cfg(
+       url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth',
+       hf_hub_id='timm/',
+       input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
+
    'tf_efficientnet_b0.aa_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
        hf_hub_id='timm/',

@@ -1169,22 +1182,44 @@ default_cfgs = generate_default_cfgs({
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth',
        hf_hub_id='timm/',
        input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
-   'tf_efficientnet_b5.ra_in1k': _cfg(
-       url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth',
+   'tf_efficientnet_b5.aa_in1k': _cfg(
+       url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_aa-99018a74.pth',
        hf_hub_id='timm/',
        input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
    'tf_efficientnet_b6.aa_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth',
        hf_hub_id='timm/',
        input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
-   'tf_efficientnet_b7.ra_in1k': _cfg(
-       url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth',
+   'tf_efficientnet_b7.aa_in1k': _cfg(
+       url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_aa-076e3472.pth',
        hf_hub_id='timm/',
        input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
-   'tf_efficientnet_b8.ra_in1k': _cfg(
-       url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth',
-       hf_hub_id='timm/',
-       input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),

+   'tf_efficientnet_b0.in1k': _cfg(
+       url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0-0af12548.pth',
+       #hf_hub_id='timm/',
+       input_size=(3, 224, 224)),
+   'tf_efficientnet_b1.in1k': _cfg(
+       url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1-5c1377c4.pth',
+       #hf_hub_id='timm/',
+       input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
+   'tf_efficientnet_b2.in1k': _cfg(
+       url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2-e393ef04.pth',
+       #hf_hub_id='timm/',
+       input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
+   'tf_efficientnet_b3.in1k': _cfg(
+       url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3-e3bd6955.pth',
+       #hf_hub_id='timm/',
+       input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
+   'tf_efficientnet_b4.in1k': _cfg(
+       url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4-74ee3bed.pth',
+       #hf_hub_id='timm/',
+       input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
+   'tf_efficientnet_b5.in1k': _cfg(
+       url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5-c6949ce9.pth',
+       #hf_hub_id='timm/',
+       input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
+
    'tf_efficientnet_es.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
@@ -16,63 +16,62 @@ from timm.layers import SelectAdaptivePool2d, Linear, make_divisible
from ._builder import build_model_with_cfg
from ._efficientnet_blocks import SqueezeExcite, ConvBnAct
from ._manipulate import checkpoint_seq
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs

__all__ = ['GhostNet']


-def _cfg(url='', **kwargs):
-    return {
-        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bilinear',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'conv_stem', 'classifier': 'classifier',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    'ghostnet_050': _cfg(url=''),
-    'ghostnet_100': _cfg(
-        url='https://github.com/huawei-noah/CV-backbones/releases/download/ghostnet_pth/ghostnet_1x.pth'),
-    'ghostnet_130': _cfg(url=''),
-}
-
-
_SE_LAYER = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=partial(make_divisible, divisor=4))


class GhostModule(nn.Module):
-   def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True):
+   def __init__(
+           self,
+           in_chs,
+           out_chs,
+           kernel_size=1,
+           ratio=2,
+           dw_size=3,
+           stride=1,
+           relu=True,
+   ):
        super(GhostModule, self).__init__()
-       self.oup = oup
-       init_channels = math.ceil(oup / ratio)
-       new_channels = init_channels * (ratio - 1)
+       self.out_chs = out_chs
+       init_chs = math.ceil(out_chs / ratio)
+       new_chs = init_chs * (ratio - 1)

        self.primary_conv = nn.Sequential(
-           nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False),
-           nn.BatchNorm2d(init_channels),
-           nn.ReLU(inplace=True) if relu else nn.Sequential(),
+           nn.Conv2d(in_chs, init_chs, kernel_size, stride, kernel_size // 2, bias=False),
+           nn.BatchNorm2d(init_chs),
+           nn.ReLU(inplace=True) if relu else nn.Identity(),
        )

        self.cheap_operation = nn.Sequential(
-           nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False),
-           nn.BatchNorm2d(new_channels),
-           nn.ReLU(inplace=True) if relu else nn.Sequential(),
+           nn.Conv2d(init_chs, new_chs, dw_size, 1, dw_size//2, groups=init_chs, bias=False),
+           nn.BatchNorm2d(new_chs),
+           nn.ReLU(inplace=True) if relu else nn.Identity(),
        )

    def forward(self, x):
        x1 = self.primary_conv(x)
        x2 = self.cheap_operation(x1)
        out = torch.cat([x1, x2], dim=1)
-       return out[:, :self.oup, :, :]
+       return out[:, :self.out_chs, :, :]
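The renamed channel math is easiest to see with numbers: for `out_chs=64, ratio=2`, the primary conv produces `init_chs = ceil(64/2) = 32` channels, the cheap depthwise conv adds `new_chs = 32 * (2 - 1) = 32` "ghost" channels, and the concat is sliced back to `out_chs`. A standalone sketch of that bookkeeping (illustrative values, not the timm module):

```python
import math
import torch

out_chs, ratio = 64, 2
init_chs = math.ceil(out_chs / ratio)   # 32 "real" feature channels
new_chs = init_chs * (ratio - 1)        # 32 ghost channels from the cheap dw conv

x1 = torch.randn(1, init_chs, 56, 56)   # stands in for primary_conv output
x2 = torch.randn(1, new_chs, 56, 56)    # stands in for cheap_operation output
out = torch.cat([x1, x2], dim=1)[:, :out_chs]
print(out.shape)  # torch.Size([1, 64, 56, 56])
```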


class GhostBottleneck(nn.Module):
    """ Ghost bottleneck w/ optional SE"""

-   def __init__(self, in_chs, mid_chs, out_chs, dw_kernel_size=3,
-                stride=1, act_layer=nn.ReLU, se_ratio=0.):
+   def __init__(
+           self,
+           in_chs,
+           mid_chs,
+           out_chs,
+           dw_kernel_size=3,
+           stride=1,
+           act_layer=nn.ReLU,
+           se_ratio=0.,
+   ):
        super(GhostBottleneck, self).__init__()
        has_se = se_ratio is not None and se_ratio > 0.
        self.stride = stride

@@ -133,7 +132,15 @@ class GhostBottleneck(nn.Module):

class GhostNet(nn.Module):
    def __init__(
-           self, cfgs, num_classes=1000, width=1.0, in_chans=3, output_stride=32, global_pool='avg', drop_rate=0.2):
+           self,
+           cfgs,
+           num_classes=1000,
+           width=1.0,
+           in_chans=3,
+           output_stride=32,
+           global_pool='avg',
+           drop_rate=0.2,
+   ):
        super(GhostNet, self).__init__()
        # setting of inverted residual blocks
        assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported'

@@ -275,9 +282,30 @@ def _create_ghostnet(variant, width=1.0, pretrained=False, **kwargs):
        **kwargs,
    )
    return build_model_with_cfg(
-       GhostNet, variant, pretrained,
+       GhostNet,
+       variant,
+       pretrained,
        feature_cfg=dict(flatten_sequential=True),
-       **model_kwargs)
+       **model_kwargs,
+   )


+def _cfg(url='', **kwargs):
+   return {
+       'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+       'crop_pct': 0.875, 'interpolation': 'bilinear',
+       'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+       'first_conv': 'conv_stem', 'classifier': 'classifier',
+       **kwargs
+   }
+
+
+default_cfgs = generate_default_cfgs({
+   'ghostnet_050.untrained': _cfg(),
+   'ghostnet_100.in1k': _cfg(
+       url='https://github.com/huawei-noah/CV-backbones/releases/download/ghostnet_pth/ghostnet_1x.pth'),
+   'ghostnet_130.untrained': _cfg(),
+})
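`generate_default_cfgs` groups the flat `'model.tag'` keys into per-model cfg sets (untagged/`untrained` entries simply carry no weights). A rough illustration of what the keying convention expresses; this is a hypothetical simplification, the real helper in timm's registry returns structured cfg objects rather than plain dicts:

```python
# Hypothetical illustration of the 'model.tag' keying convention only.
cfgs = {
    'ghostnet_050.untrained': {},
    'ghostnet_100.in1k': {'url': '...'},
}
by_model = {}
for key, cfg in cfgs.items():
    model, _, tag = key.partition('.')
    by_model.setdefault(model, {})[tag] = cfg
print(by_model['ghostnet_100'])  # {'in1k': {'url': '...'}}
```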


@register_model
@@ -1,267 +0,0 @@
-"""Pytorch impl of Gluon Xception
-This is a port of the Gluon Xception code and weights, itself ported from a PyTorch DeepLab impl.
-
-Gluon model: (https://gluon-cv.mxnet.io/_modules/gluoncv/model_zoo/xception.html)
-Original PyTorch DeepLab impl: https://github.com/jfzhang95/pytorch-deeplab-xception
-
-Hacked together by / Copyright 2020 Ross Wightman
-"""
-from collections import OrderedDict
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import create_classifier, get_padding
-from ._builder import build_model_with_cfg
-from ._registry import register_model
-
-__all__ = ['Xception65']
-
-default_cfgs = {
-    'gluon_xception65': {
-        'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_xception-7015a15c.pth',
-        'input_size': (3, 299, 299),
-        'crop_pct': 0.903,
-        'pool_size': (10, 10),
-        'interpolation': 'bicubic',
-        'mean': IMAGENET_DEFAULT_MEAN,
-        'std': IMAGENET_DEFAULT_STD,
-        'num_classes': 1000,
-        'first_conv': 'conv1',
-        'classifier': 'fc'
-        # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
-    },
-}
-
-""" PADDING NOTES
-The original PyTorch and Gluon impl of these models dutifully reproduced the
-aligned padding added to Tensorflow models for Deeplab. This padding was compensating
-for Tensorflow 'SAME' padding. PyTorch symmetric padding behaves the way we'd want it to.
-"""
-
-
-class SeparableConv2d(nn.Module):
-    def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, norm_layer=None):
-        super(SeparableConv2d, self).__init__()
-        self.kernel_size = kernel_size
-        self.dilation = dilation
-
-        # depthwise convolution
-        padding = get_padding(kernel_size, stride, dilation)
-        self.conv_dw = nn.Conv2d(
-            inplanes, inplanes, kernel_size, stride=stride,
-            padding=padding, dilation=dilation, groups=inplanes, bias=bias)
-        self.bn = norm_layer(num_features=inplanes)
-        # pointwise convolution
-        self.conv_pw = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias)
-
-    def forward(self, x):
-        x = self.conv_dw(x)
-        x = self.bn(x)
-        x = self.conv_pw(x)
-        return x
-
-
-class Block(nn.Module):
-    def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True, norm_layer=None):
-        super(Block, self).__init__()
-        if isinstance(planes, (list, tuple)):
-            assert len(planes) == 3
-        else:
-            planes = (planes,) * 3
-        outplanes = planes[-1]
-
-        if outplanes != inplanes or stride != 1:
-            self.skip = nn.Sequential()
-            self.skip.add_module('conv1', nn.Conv2d(
-                inplanes, outplanes, 1, stride=stride, bias=False)),
-            self.skip.add_module('bn1', norm_layer(num_features=outplanes))
-        else:
-            self.skip = None
-
-        rep = OrderedDict()
-        for i in range(3):
-            rep['act%d' % (i + 1)] = nn.ReLU(inplace=True)
-            rep['conv%d' % (i + 1)] = SeparableConv2d(
-                inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation, norm_layer=norm_layer)
-            rep['bn%d' % (i + 1)] = norm_layer(planes[i])
-            inplanes = planes[i]
-
-        if not start_with_relu:
-            del rep['act1']
-        else:
-            rep['act1'] = nn.ReLU(inplace=False)
-        self.rep = nn.Sequential(rep)
-
-    def forward(self, x):
-        skip = x
-        if self.skip is not None:
-            skip = self.skip(skip)
-        x = self.rep(x) + skip
-        return x
-
-
-class Xception65(nn.Module):
-    """Modified Aligned Xception.
-
-    NOTE: only the 65 layer version is included here, the 71 layer variant
-    was not correct and had no pretrained weights
-    """
-
-    def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
-                 drop_rate=0., global_pool='avg'):
-        super(Xception65, self).__init__()
-        self.num_classes = num_classes
-        self.drop_rate = drop_rate
-        if output_stride == 32:
-            entry_block3_stride = 2
-            exit_block20_stride = 2
-            middle_dilation = 1
-            exit_dilation = (1, 1)
-        elif output_stride == 16:
-            entry_block3_stride = 2
-            exit_block20_stride = 1
-            middle_dilation = 1
-            exit_dilation = (1, 2)
-        elif output_stride == 8:
-            entry_block3_stride = 1
-            exit_block20_stride = 1
-            middle_dilation = 2
-            exit_dilation = (2, 4)
-        else:
-            raise NotImplementedError
-
-        # Entry flow
-        self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False)
-        self.bn1 = norm_layer(num_features=32)
-        self.act1 = nn.ReLU(inplace=True)
-
-        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
-        self.bn2 = norm_layer(num_features=64)
-        self.act2 = nn.ReLU(inplace=True)
-
-        self.block1 = Block(64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer)
-        self.block1_act = nn.ReLU(inplace=True)
-        self.block2 = Block(128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer)
-        self.block3 = Block(256, 728, stride=entry_block3_stride, norm_layer=norm_layer)
-
-        # Middle flow
-        self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block(
-            728, 728, stride=1, dilation=middle_dilation, norm_layer=norm_layer)) for i in range(4, 20)]))
-
-        # Exit flow
-        self.block20 = Block(
-            728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_dilation[0], norm_layer=norm_layer)
-        self.block20_act = nn.ReLU(inplace=True)
-
-        self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
-        self.bn3 = norm_layer(num_features=1536)
-        self.act3 = nn.ReLU(inplace=True)
-
-        self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
-        self.bn4 = norm_layer(num_features=1536)
-        self.act4 = nn.ReLU(inplace=True)
-
-        self.num_features = 2048
-        self.conv5 = SeparableConv2d(
-            1536, self.num_features, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
-        self.bn5 = norm_layer(num_features=self.num_features)
-        self.act5 = nn.ReLU(inplace=True)
-        self.feature_info = [
-            dict(num_chs=64, reduction=2, module='act2'),
-            dict(num_chs=128, reduction=4, module='block1_act'),
-            dict(num_chs=256, reduction=8, module='block3.rep.act1'),
-            dict(num_chs=728, reduction=16, module='block20.rep.act1'),
-            dict(num_chs=2048, reduction=32, module='act5'),
-        ]
-
-        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
-
-    @torch.jit.ignore
-    def group_matcher(self, coarse=False):
-        matcher = dict(
-            stem=r'^conv[12]|bn[12]',
-            blocks=[
-                (r'^mid\.block(\d+)', None),
-                (r'^block(\d+)', None),
-                (r'^conv[345]|bn[345]', (99,)),
-            ],
-        )
-        return matcher
-
-    @torch.jit.ignore
-    def set_grad_checkpointing(self, enable=True):
-        assert not enable, "gradient checkpointing not supported"
-
-    @torch.jit.ignore
-    def get_classifier(self):
-        return self.fc
-
-    def reset_classifier(self, num_classes, global_pool='avg'):
-        self.num_classes = num_classes
-        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
-
-    def forward_features(self, x):
-        # Entry flow
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.act1(x)
-
-        x = self.conv2(x)
-        x = self.bn2(x)
-        x = self.act2(x)
-
-        x = self.block1(x)
-        x = self.block1_act(x)
-        # c1 = x
-        x = self.block2(x)
-        # c2 = x
-        x = self.block3(x)
-
-        # Middle flow
-        x = self.mid(x)
-        # c3 = x
-
-        # Exit flow
-        x = self.block20(x)
-        x = self.block20_act(x)
-        x = self.conv3(x)
-        x = self.bn3(x)
-        x = self.act3(x)
-
-        x = self.conv4(x)
-        x = self.bn4(x)
-        x = self.act4(x)
-
-        x = self.conv5(x)
-        x = self.bn5(x)
-        x = self.act5(x)
-        return x
-
-    def forward_head(self, x):
-        x = self.global_pool(x)
-        if self.drop_rate:
-            F.dropout(x, self.drop_rate, training=self.training)
-        x = self.fc(x)
-        return x
-
-    def forward(self, x):
-        x = self.forward_features(x)
-        x = self.forward_head(x)
-        return x
-
-
-def _create_gluon_xception(variant, pretrained=False, **kwargs):
-    return build_model_with_cfg(
-        Xception65, variant, pretrained,
-        feature_cfg=dict(feature_cls='hook'),
-        **kwargs)
-
-
-@register_model
-def gluon_xception65(pretrained=False, **kwargs):
-    """ Modified Aligned Xception-65
-    """
-    return _create_gluon_xception('gluon_xception65', pretrained, **kwargs)
(File diff suppressed because it is too large)
@@ -2,75 +2,41 @@
Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
"""
+from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import create_classifier
+from timm.layers import create_classifier, ConvNormAct
from ._builder import build_model_with_cfg
from ._manipulate import flatten_modules
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs, register_model_deprecations

__all__ = ['InceptionResnetV2']

-default_cfgs = {
-    # ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
-    'inception_resnet_v2': {
-        'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth',
-        'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
-        'crop_pct': 0.8975, 'interpolation': 'bicubic',
-        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
-        'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
-        'label_offset': 1,  # 1001 classes in pretrained weights
-    },
-    # ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz
-    'ens_adv_inception_resnet_v2': {
-        'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth',
-        'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
-        'crop_pct': 0.8975, 'interpolation': 'bicubic',
-        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
-        'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
-        'label_offset': 1,  # 1001 classes in pretrained weights
-    }
-}
-
-
-class BasicConv2d(nn.Module):
-    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
-        super(BasicConv2d, self).__init__()
-        self.conv = nn.Conv2d(
-            in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
-        self.bn = nn.BatchNorm2d(out_planes, eps=.001)
-        self.relu = nn.ReLU(inplace=False)
-
-    def forward(self, x):
-        x = self.conv(x)
-        x = self.bn(x)
-        x = self.relu(x)
-        return x
-
-
class Mixed_5b(nn.Module):
-   def __init__(self):
+   def __init__(self, conv_block=None):
        super(Mixed_5b, self).__init__()
+       conv_block = conv_block or ConvNormAct

-       self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)
+       self.branch0 = conv_block(192, 96, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
-           BasicConv2d(192, 48, kernel_size=1, stride=1),
-           BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2)
+           conv_block(192, 48, kernel_size=1, stride=1),
+           conv_block(48, 64, kernel_size=5, stride=1, padding=2)
        )

        self.branch2 = nn.Sequential(
-           BasicConv2d(192, 64, kernel_size=1, stride=1),
-           BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
-           BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
+           conv_block(192, 64, kernel_size=1, stride=1),
+           conv_block(64, 96, kernel_size=3, stride=1, padding=1),
+           conv_block(96, 96, kernel_size=3, stride=1, padding=1)
        )

        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
-           BasicConv2d(192, 64, kernel_size=1, stride=1)
+           conv_block(192, 64, kernel_size=1, stride=1)
        )

    def forward(self, x):
@@ -83,26 +49,26 @@ class Mixed_5b(nn.Module):


class Block35(nn.Module):
-   def __init__(self, scale=1.0):
+   def __init__(self, scale=1.0, conv_block=None):
        super(Block35, self).__init__()

        self.scale = scale
+       conv_block = conv_block or ConvNormAct

-       self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)
+       self.branch0 = conv_block(320, 32, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
-           BasicConv2d(320, 32, kernel_size=1, stride=1),
-           BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
+           conv_block(320, 32, kernel_size=1, stride=1),
+           conv_block(32, 32, kernel_size=3, stride=1, padding=1)
        )

        self.branch2 = nn.Sequential(
-           BasicConv2d(320, 32, kernel_size=1, stride=1),
-           BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
-           BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1)
+           conv_block(320, 32, kernel_size=1, stride=1),
+           conv_block(32, 48, kernel_size=3, stride=1, padding=1),
+           conv_block(48, 64, kernel_size=3, stride=1, padding=1)
        )

        self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
-       self.relu = nn.ReLU(inplace=False)
+       self.act = nn.ReLU()

    def forward(self, x):
        x0 = self.branch0(x)

@@ -111,20 +77,21 @@ class Block35(nn.Module):
        out = torch.cat((x0, x1, x2), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
-       out = self.relu(out)
+       out = self.act(out)
        return out


class Mixed_6a(nn.Module):
-   def __init__(self):
+   def __init__(self, conv_block=None):
        super(Mixed_6a, self).__init__()
+       conv_block = conv_block or ConvNormAct

-       self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)
+       self.branch0 = conv_block(320, 384, kernel_size=3, stride=2)

        self.branch1 = nn.Sequential(
-           BasicConv2d(320, 256, kernel_size=1, stride=1),
-           BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
-           BasicConv2d(256, 384, kernel_size=3, stride=2)
+           conv_block(320, 256, kernel_size=1, stride=1),
+           conv_block(256, 256, kernel_size=3, stride=1, padding=1),
+           conv_block(256, 384, kernel_size=3, stride=2)
        )

        self.branch2 = nn.MaxPool2d(3, stride=2)

@@ -138,21 +105,21 @@ class Mixed_6a(nn.Module):


class Block17(nn.Module):
-   def __init__(self, scale=1.0):
+   def __init__(self, scale=1.0, conv_block=None):
        super(Block17, self).__init__()

        self.scale = scale
+       conv_block = conv_block or ConvNormAct

-       self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)
+       self.branch0 = conv_block(1088, 192, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
-           BasicConv2d(1088, 128, kernel_size=1, stride=1),
-           BasicConv2d(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)),
-           BasicConv2d(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0))
+           conv_block(1088, 128, kernel_size=1, stride=1),
+           conv_block(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)),
+           conv_block(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0))
        )

        self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)
-       self.relu = nn.ReLU(inplace=False)
+       self.act = nn.ReLU()

    def forward(self, x):
        x0 = self.branch0(x)

@@ -160,28 +127,29 @@ class Block17(nn.Module):
        out = torch.cat((x0, x1), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
-       out = self.relu(out)
+       out = self.act(out)
        return out


class Mixed_7a(nn.Module):
-   def __init__(self):
+   def __init__(self, conv_block=None):
        super(Mixed_7a, self).__init__()
+       conv_block = conv_block or ConvNormAct

        self.branch0 = nn.Sequential(
-           BasicConv2d(1088, 256, kernel_size=1, stride=1),
-           BasicConv2d(256, 384, kernel_size=3, stride=2)
+           conv_block(1088, 256, kernel_size=1, stride=1),
+           conv_block(256, 384, kernel_size=3, stride=2)
        )

        self.branch1 = nn.Sequential(
-           BasicConv2d(1088, 256, kernel_size=1, stride=1),
-           BasicConv2d(256, 288, kernel_size=3, stride=2)
+           conv_block(1088, 256, kernel_size=1, stride=1),
+           conv_block(256, 288, kernel_size=3, stride=2)
        )

        self.branch2 = nn.Sequential(
-           BasicConv2d(1088, 256, kernel_size=1, stride=1),
-           BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
-           BasicConv2d(288, 320, kernel_size=3, stride=2)
+           conv_block(1088, 256, kernel_size=1, stride=1),
+           conv_block(256, 288, kernel_size=3, stride=1, padding=1),
+           conv_block(288, 320, kernel_size=3, stride=2)
        )

        self.branch3 = nn.MaxPool2d(3, stride=2)

@@ -197,21 +165,21 @@ class Mixed_7a(nn.Module):

class Block8(nn.Module):

-   def __init__(self, scale=1.0, no_relu=False):
+   def __init__(self, scale=1.0, no_relu=False, conv_block=None):
        super(Block8, self).__init__()

        self.scale = scale
+       conv_block = conv_block or ConvNormAct

-       self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)
+       self.branch0 = conv_block(2080, 192, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
-           BasicConv2d(2080, 192, kernel_size=1, stride=1),
-           BasicConv2d(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)),
-           BasicConv2d(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
+           conv_block(2080, 192, kernel_size=1, stride=1),
+           conv_block(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)),
+           conv_block(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
        )

        self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)
-       self.relu = None if no_relu else nn.ReLU(inplace=False)
+       self.relu = None if no_relu else nn.ReLU()

    def forward(self, x):
        x0 = self.branch0(x)

@@ -225,81 +193,58 @@ class Block8(nn.Module):


class InceptionResnetV2(nn.Module):
-   def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., output_stride=32, global_pool='avg'):
+   def __init__(
+           self,
+           num_classes=1000,
+           in_chans=3,
+           drop_rate=0.,
+           output_stride=32,
+           global_pool='avg',
+           norm_layer='batchnorm2d',
+           norm_eps=1e-3,
+           act_layer='relu',
+   ):
        super(InceptionResnetV2, self).__init__()
        self.drop_rate = drop_rate
        self.num_classes = num_classes
        self.num_features = 1536
        assert output_stride == 32
+       conv_block = partial(
+           ConvNormAct,
+           padding=0,
+           norm_layer=norm_layer,
+           act_layer=act_layer,
+           norm_kwargs=dict(eps=norm_eps),
+           act_kwargs=dict(inplace=True),
+       )

-       self.conv2d_1a = BasicConv2d(in_chans, 32, kernel_size=3, stride=2)
-       self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
-       self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
+       self.conv2d_1a = conv_block(in_chans, 32, kernel_size=3, stride=2)
+       self.conv2d_2a = conv_block(32, 32, kernel_size=3, stride=1)
+       self.conv2d_2b = conv_block(32, 64, kernel_size=3, stride=1, padding=1)
        self.feature_info = [dict(num_chs=64, reduction=2, module='conv2d_2b')]

        self.maxpool_3a = nn.MaxPool2d(3, stride=2)
-       self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
-       self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
+       self.conv2d_3b = conv_block(64, 80, kernel_size=1, stride=1)
+       self.conv2d_4a = conv_block(80, 192, kernel_size=3, stride=1)
        self.feature_info += [dict(num_chs=192, reduction=4, module='conv2d_4a')]

        self.maxpool_5a = nn.MaxPool2d(3, stride=2)
-       self.mixed_5b = Mixed_5b()
-       self.repeat = nn.Sequential(
-           Block35(scale=0.17),
-           Block35(scale=0.17),
-           Block35(scale=0.17),
-           Block35(scale=0.17),
-           Block35(scale=0.17),
-           Block35(scale=0.17),
-           Block35(scale=0.17),
-           Block35(scale=0.17),
-           Block35(scale=0.17),
-           Block35(scale=0.17)
-       )
+       self.mixed_5b = Mixed_5b(conv_block=conv_block)
+       self.repeat = nn.Sequential(*[Block35(scale=0.17, conv_block=conv_block) for _ in range(10)])
        self.feature_info += [dict(num_chs=320, reduction=8, module='repeat')]

-       self.mixed_6a = Mixed_6a()
-       self.repeat_1 = nn.Sequential(
-           Block17(scale=0.10),
-           Block17(scale=0.10),
-           Block17(scale=0.10),
-           Block17(scale=0.10),
-           Block17(scale=0.10),
-           Block17(scale=0.10),
-           Block17(scale=0.10),
-           Block17(scale=0.10),
-           Block17(scale=0.10),
-           Block17(scale=0.10),
-           Block17(scale=0.10),
-           Block17(scale=0.10),
-           Block17(scale=0.10),
-           Block17(scale=0.10),
-           Block17(scale=0.10),
-           Block17(scale=0.10),
-           Block17(scale=0.10),
-           Block17(scale=0.10),
-           Block17(scale=0.10),
-           Block17(scale=0.10)
-       )
+       self.mixed_6a = Mixed_6a(conv_block=conv_block)
+       self.repeat_1 = nn.Sequential(*[Block17(scale=0.10, conv_block=conv_block) for _ in range(20)])
        self.feature_info += [dict(num_chs=1088, reduction=16, module='repeat_1')]

-       self.mixed_7a = Mixed_7a()
-       self.repeat_2 = nn.Sequential(
-           Block8(scale=0.20),
-           Block8(scale=0.20),
-           Block8(scale=0.20),
-           Block8(scale=0.20),
-           Block8(scale=0.20),
-           Block8(scale=0.20),
-           Block8(scale=0.20),
-           Block8(scale=0.20),
-           Block8(scale=0.20)
-       )
-       self.block8 = Block8(no_relu=True)
-       self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1)
+       self.mixed_7a = Mixed_7a(conv_block=conv_block)
+       self.repeat_2 = nn.Sequential(*[Block8(scale=0.20, conv_block=conv_block) for _ in range(9)])
+
+       self.block8 = Block8(no_relu=True, conv_block=conv_block)
+       self.conv2d_7b = conv_block(2080, self.num_features, kernel_size=1, stride=1)
        self.feature_info += [dict(num_chs=self.num_features, reduction=32, module='conv2d_7b')]

-       self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+       self.global_pool, self.head_drop, self.classif = create_classifier(
+           self.num_features, self.num_classes, pool_type=global_pool, drop_rate=drop_rate)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
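The `partial(ConvNormAct, ...)` above bakes the norm/act configuration in once so every block builds identical conv-bn-act units, and the list comprehensions replace the hand-unrolled `nn.Sequential` stacks. A compact sketch of both idioms; `conv_norm_act` here is an assumed stand-in to keep the example self-contained, not timm's actual layer:

```python
from functools import partial
import torch.nn as nn

# Assumed stand-in for timm's ConvNormAct.
def conv_norm_act(in_chs, out_chs, kernel_size=1, stride=1, padding=0, eps=1e-3):
    return nn.Sequential(
        nn.Conv2d(in_chs, out_chs, kernel_size, stride, padding, bias=False),
        nn.BatchNorm2d(out_chs, eps=eps),
        nn.ReLU(inplace=True),
    )

conv_block = partial(conv_norm_act, eps=1e-3)  # config captured once
stem = conv_block(3, 32, kernel_size=3, stride=2)

# 10 identical blocks without writing the constructor 10 times.
repeat = nn.Sequential(*[conv_block(32, 32, kernel_size=3, padding=1) for _ in range(10)])
```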
@@ -352,8 +297,7 @@ class InceptionResnetV2(nn.Module):

    def forward_head(self, x, pre_logits: bool = False):
        x = self.global_pool(x)
-       if self.drop_rate > 0:
-           x = F.dropout(x, p=self.drop_rate, training=self.training)
+       x = self.head_drop(x)
        return x if pre_logits else self.classif(x)

    def forward(self, x):
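Swapping the `F.dropout` call for an `nn.Dropout` module held on the model (`head_drop`) is behavior-preserving — the module applies the same `p` and disables itself in `eval()` — while making the dropout visible to `named_modules()` and friendlier to scripting. A minimal equivalence sketch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

drop_rate = 0.2
head_drop = nn.Dropout(drop_rate)
x = torch.randn(4, 1536)

head_drop.eval()  # in eval mode both paths are the identity
assert torch.equal(head_drop(x), x)
assert torch.equal(F.dropout(x, p=drop_rate, training=False), x)
```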
@@ -366,18 +310,36 @@ def _create_inception_resnet_v2(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(InceptionResnetV2, variant, pretrained, **kwargs)


+default_cfgs = generate_default_cfgs({
+   # ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
+   'inception_resnet_v2.tf_in1k': {
+       'hf_hub_id': 'timm/',
+       'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth',
+       'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
+       'crop_pct': 0.8975, 'interpolation': 'bicubic',
+       'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+       'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
+       'label_offset': 1,  # 1001 classes in pretrained weights
+   },
+   # As per https://arxiv.org/abs/1705.07204 and
+   # ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz
+   'inception_resnet_v2.tf_ens_adv_in1k': {
+       'hf_hub_id': 'timm/',
+       'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth',
+       'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
+       'crop_pct': 0.8975, 'interpolation': 'bicubic',
+       'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+       'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
+       'label_offset': 1,  # 1001 classes in pretrained weights
+   }
+})
+
+
@register_model
def inception_resnet_v2(pretrained=False, **kwargs):
    r"""InceptionResnetV2 model architecture from the
    `"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>` paper.
    """
    return _create_inception_resnet_v2('inception_resnet_v2', pretrained=pretrained, **kwargs)


-@register_model
-def ens_adv_inception_resnet_v2(pretrained=False, **kwargs):
-   r""" Ensemble Adversarially trained InceptionResnetV2 model architecture
-   As per https://arxiv.org/abs/1705.07204 and
-   https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models.
-   """
-   return _create_inception_resnet_v2('ens_adv_inception_resnet_v2', pretrained=pretrained, **kwargs)
+register_model_deprecations(__name__, {
+   'ens_adv_inception_resnet_v2': 'inception_resnet_v2.tf_ens_adv_in1k',
+})
@@ -3,61 +3,27 @@
Originally from torchvision Inception3 model
Licensed BSD-Clause 3 https://github.com/pytorch/vision/blob/master/LICENSE
"""
+from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import trunc_normal_, create_classifier, Linear
+from timm.layers import trunc_normal_, create_classifier, Linear, ConvNormAct
from ._builder import build_model_with_cfg
+from ._builder import resolve_pretrained_cfg
from ._manipulate import flatten_modules
-from ._registry import register_model
+from ._registry import register_model, generate_default_cfgs, register_model_deprecations

-__all__ = ['InceptionV3', 'InceptionV3Aux']  # model_registry will add each entrypoint fn to this
-
-
-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
-        'crop_pct': 0.875, 'interpolation': 'bicubic',
-        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
-        'first_conv': 'Conv2d_1a_3x3.conv', 'classifier': 'fc',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    # original PyTorch weights, ported from Tensorflow but modified
-    'inception_v3': _cfg(
-        # NOTE checkpoint has aux logit layer weights
-        url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth'),
-    # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
-    'tf_inception_v3': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth',
-        num_classes=1000, label_offset=1),
-    # my port of Tensorflow adversarially trained Inception V3 from
-    # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
-    'adv_inception_v3': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth',
-        num_classes=1000, label_offset=1),
-    # from gluon pretrained models, best performing in terms of accuracy/loss metrics
-    # https://gluon-cv.mxnet.io/model_zoo/classification.html
-    'gluon_inception_v3': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth',
-        mean=IMAGENET_DEFAULT_MEAN,  # also works well with inception defaults
-        std=IMAGENET_DEFAULT_STD,  # also works well with inception defaults
-    )
-}
+__all__ = ['InceptionV3']  # model_registry will add each entrypoint fn to this


class InceptionA(nn.Module):

    def __init__(self, in_channels, pool_features, conv_block=None):
        super(InceptionA, self).__init__()
-       if conv_block is None:
-           conv_block = BasicConv2d
+       conv_block = conv_block or ConvNormAct
        self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)

@@ -94,8 +60,7 @@ class InceptionB(nn.Module):

    def __init__(self, in_channels, conv_block=None):
        super(InceptionB, self).__init__()
-       if conv_block is None:
-           conv_block = BasicConv2d
+       conv_block = conv_block or ConvNormAct
        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)

@@ -123,8 +88,7 @@ class InceptionC(nn.Module):

    def __init__(self, in_channels, channels_7x7, conv_block=None):
        super(InceptionC, self).__init__()
-       if conv_block is None:
-           conv_block = BasicConv2d
+       conv_block = conv_block or ConvNormAct
        self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)

        c7 = channels_7x7

@@ -168,8 +132,7 @@ class InceptionD(nn.Module):

    def __init__(self, in_channels, conv_block=None):
        super(InceptionD, self).__init__()
-       if conv_block is None:
-           conv_block = BasicConv2d
+       conv_block = conv_block or ConvNormAct
        self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)

@@ -200,8 +163,7 @@ class InceptionE(nn.Module):

    def __init__(self, in_channels, conv_block=None):
        super(InceptionE, self).__init__()
-       if conv_block is None:
-           conv_block = BasicConv2d
+       conv_block = conv_block or ConvNormAct
        self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)

@@ -248,8 +210,7 @@ class InceptionAux(nn.Module):

    def __init__(self, in_channels, num_classes, conv_block=None):
        super(InceptionAux, self).__init__()
-       if conv_block is None:
-           conv_block = BasicConv2d
+       conv_block = conv_block or ConvNormAct
        self.conv0 = conv_block(in_channels, 128, kernel_size=1)
        self.conv1 = conv_block(128, 768, kernel_size=5)
        self.conv1.stddev = 0.01
@@ -274,52 +235,56 @@ class InceptionAux(nn.Module):
        return x


-class BasicConv2d(nn.Module):
-
-   def __init__(self, in_channels, out_channels, **kwargs):
-       super(BasicConv2d, self).__init__()
-       self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
-       self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
-
-   def forward(self, x):
-       x = self.conv(x)
-       x = self.bn(x)
-       return F.relu(x, inplace=True)
-
-
class InceptionV3(nn.Module):
-   """Inception-V3 with no AuxLogits
-   FIXME two class defs are redundant, but less screwing around with torchsript fussyness and inconsistent returns
+   """Inception-V3
    """
    aux_logits: torch.jit.Final[bool]

-   def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=False):
+   def __init__(
+           self,
+           num_classes=1000,
+           in_chans=3,
+           drop_rate=0.,
+           global_pool='avg',
+           aux_logits=False,
+           norm_layer='batchnorm2d',
+           norm_eps=1e-3,
+           act_layer='relu',
+   ):
        super(InceptionV3, self).__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.aux_logits = aux_logits
+       conv_block = partial(
+           ConvNormAct,
+           padding=0,
+           norm_layer=norm_layer,
+           act_layer=act_layer,
+           norm_kwargs=dict(eps=norm_eps),
+           act_kwargs=dict(inplace=True),
+       )

-       self.Conv2d_1a_3x3 = BasicConv2d(in_chans, 32, kernel_size=3, stride=2)
-       self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
-       self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
+       self.Conv2d_1a_3x3 = conv_block(in_chans, 32, kernel_size=3, stride=2)
+       self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
+       self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
        self.Pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
-       self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
-       self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
+       self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
+       self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
        self.Pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
-       self.Mixed_5b = InceptionA(192, pool_features=32)
-       self.Mixed_5c = InceptionA(256, pool_features=64)
-       self.Mixed_5d = InceptionA(288, pool_features=64)
-       self.Mixed_6a = InceptionB(288)
-       self.Mixed_6b = InceptionC(768, channels_7x7=128)
-       self.Mixed_6c = InceptionC(768, channels_7x7=160)
-       self.Mixed_6d = InceptionC(768, channels_7x7=160)
-       self.Mixed_6e = InceptionC(768, channels_7x7=192)
+       self.Mixed_5b = InceptionA(192, pool_features=32, conv_block=conv_block)
+       self.Mixed_5c = InceptionA(256, pool_features=64, conv_block=conv_block)
+       self.Mixed_5d = InceptionA(288, pool_features=64, conv_block=conv_block)
+       self.Mixed_6a = InceptionB(288, conv_block=conv_block)
+       self.Mixed_6b = InceptionC(768, channels_7x7=128, conv_block=conv_block)
+       self.Mixed_6c = InceptionC(768, channels_7x7=160, conv_block=conv_block)
+       self.Mixed_6d = InceptionC(768, channels_7x7=160, conv_block=conv_block)
+       self.Mixed_6e = InceptionC(768, channels_7x7=192, conv_block=conv_block)
        if aux_logits:
-           self.AuxLogits = InceptionAux(768, num_classes)
+           self.AuxLogits = InceptionAux(768, num_classes, conv_block=conv_block)
        else:
            self.AuxLogits = None
-       self.Mixed_7a = InceptionD(768)
-       self.Mixed_7b = InceptionE(1280)
-       self.Mixed_7c = InceptionE(2048)
+       self.Mixed_7a = InceptionD(768, conv_block=conv_block)
+       self.Mixed_7b = InceptionE(1280, conv_block=conv_block)
+       self.Mixed_7c = InceptionE(2048, conv_block=conv_block)
        self.feature_info = [
            dict(num_chs=64, reduction=2, module='Conv2d_2b_3x3'),
            dict(num_chs=192, reduction=4, module='Conv2d_4a_3x3'),

@@ -329,7 +294,12 @@ class InceptionV3(nn.Module):
        ]

        self.num_features = 2048
-       self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+       self.global_pool, self.head_drop, self.fc = create_classifier(
+           self.num_features,
+           self.num_classes,
+           pool_type=global_pool,
+           drop_rate=drop_rate,
+       )

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
@@ -394,85 +364,99 @@ class InceptionV3(nn.Module):

    def forward_features(self, x):
        x = self.forward_preaux(x)
+       if self.aux_logits:
+           aux = self.AuxLogits(x)
+           x = self.forward_postaux(x)
+           return x, aux
        x = self.forward_postaux(x)
        return x

    def forward_head(self, x):
        x = self.global_pool(x)
-       if self.drop_rate > 0:
-           x = F.dropout(x, p=self.drop_rate, training=self.training)
+       x = self.head_drop(x)
        x = self.fc(x)
        return x

    def forward(self, x):
+       if self.aux_logits:
+           x, aux = self.forward_features(x)
+           x = self.forward_head(x)
+           return x, aux
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


-class InceptionV3Aux(InceptionV3):
-   """InceptionV3 with AuxLogits
-   """
-
-   def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=True):
-       super(InceptionV3Aux, self).__init__(
-           num_classes, in_chans, drop_rate, global_pool, aux_logits)
-
-   def forward_features(self, x):
-       x = self.forward_preaux(x)
-       aux = self.AuxLogits(x) if self.training else None
-       x = self.forward_postaux(x)
-       return x, aux
-
-   def forward(self, x):
-       x, aux = self.forward_features(x)
-       x = self.forward_head(x)
-       return x, aux


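With the two class definitions merged, a single `aux_logits` flag toggles the tuple return. A brief usage sketch under the assumption that the entrypoint below is registered as shown (no pretrained weights needed for the shapes):

```python
import torch
import timm

model = timm.create_model('inception_v3', aux_logits=True)
logits, aux = model(torch.randn(1, 3, 299, 299))
print(logits.shape, aux.shape)  # torch.Size([1, 1000]) torch.Size([1, 1000])
```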
def _create_inception_v3(variant, pretrained=False, **kwargs):
+   pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
-   aux_logits = kwargs.pop('aux_logits', False)
+   aux_logits = kwargs.get('aux_logits', False)
+   has_aux_logits = False
+   if pretrained_cfg:
+       # only torchvision pretrained weights have aux logits
+       has_aux_logits = pretrained_cfg.tag == 'tv_in1k'
    if aux_logits:
        assert not kwargs.pop('features_only', False)
-       model_cls = InceptionV3Aux
-       load_strict = variant == 'inception_v3'
+       load_strict = has_aux_logits
    else:
-       model_cls = InceptionV3
-       load_strict = variant != 'inception_v3'
+       load_strict = not has_aux_logits

    return build_model_with_cfg(
-       model_cls, variant, pretrained,
+       InceptionV3,
+       variant,
+       pretrained,
+       pretrained_cfg=pretrained_cfg,
        pretrained_strict=load_strict,
-       **kwargs)
+       **kwargs,
+   )
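The strict-loading rule follows from the weights: only the torchvision-ported `tv_in1k` checkpoint carries aux-logit parameters, so it is the one combination that can (and should) load strictly with the aux head enabled. A usage sketch, assuming that tag is registered as below:

```python
import torch
import timm

# tv_in1k weights include AuxLogits params, so this loads strictly.
model = timm.create_model('inception_v3.tv_in1k', pretrained=True, aux_logits=True)
logits, aux = model(torch.randn(1, 3, 299, 299))
```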


+def _cfg(url='', **kwargs):
+   return {
+       'url': url,
+       'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
+       'crop_pct': 0.875, 'interpolation': 'bicubic',
+       'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+       'first_conv': 'Conv2d_1a_3x3.conv', 'classifier': 'fc',
+       **kwargs
+   }
+
+
+default_cfgs = generate_default_cfgs({
+   # original PyTorch weights, ported from Tensorflow but modified
+   'inception_v3.tv_in1k': _cfg(
+       # NOTE checkpoint has aux logit layer weights
+       hf_hub_id='timm/',
+       url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth'),
+   # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
+   'inception_v3.tf_in1k': _cfg(
+       hf_hub_id='timm/',
+       url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth',
+       num_classes=1000, label_offset=1),
+   # my port of Tensorflow adversarially trained Inception V3 from
+   # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
+   'inception_v3.tf_adv_in1k': _cfg(
+       hf_hub_id='timm/',
+       url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth',
+       num_classes=1000, label_offset=1),
+   # from gluon pretrained models, best performing in terms of accuracy/loss metrics
+   # https://gluon-cv.mxnet.io/model_zoo/classification.html
+   'inception_v3.gluon_in1k': _cfg(
+       hf_hub_id='timm/',
+       url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth',
+       mean=IMAGENET_DEFAULT_MEAN,  # also works well with inception defaults
+       std=IMAGENET_DEFAULT_STD,  # also works well with inception defaults
+   )
+})


@register_model
def inception_v3(pretrained=False, **kwargs):
    # original PyTorch weights, ported from Tensorflow but modified
    model = _create_inception_v3('inception_v3', pretrained=pretrained, **kwargs)
    return model


-@register_model
-def tf_inception_v3(pretrained=False, **kwargs):
-   # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
-   model = _create_inception_v3('tf_inception_v3', pretrained=pretrained, **kwargs)
-   return model
-
-
-@register_model
-def adv_inception_v3(pretrained=False, **kwargs):
-   # my port of Tensorflow adversarially trained Inception V3 from
-   # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
-   model = _create_inception_v3('adv_inception_v3', pretrained=pretrained, **kwargs)
-   return model
-
-
-@register_model
-def gluon_inception_v3(pretrained=False, **kwargs):
-   # from gluon pretrained models, best performing in terms of accuracy/loss metrics
-   # https://gluon-cv.mxnet.io/model_zoo/classification.html
-   model = _create_inception_v3('gluon_inception_v3', pretrained=pretrained, **kwargs)
-   return model
+register_model_deprecations(__name__, {
+   'tf_inception_v3': 'inception_v3.tf_in1k',
+   'adv_inception_v3': 'inception_v3.tf_adv_in1k',
+   'gluon_inception_v3': 'inception_v3.gluon_in1k',
+})
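Assuming `register_model_deprecations` remaps old names at `create_model` time (as the mapping above suggests), existing user code keeps working while steering people to the tagged names:

```python
import timm

# Resolves through the deprecation map: builds 'inception_v3.tf_in1k'
# under the old name, typically with a deprecation warning.
model = timm.create_model('tf_inception_v3', pretrained=True)
```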
@ -2,49 +2,24 @@
|
|||
Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
|
||||
based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
|
||||
"""
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from timm.layers import create_classifier
|
||||
from timm.layers import create_classifier, ConvNormAct
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._registry import register_model
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
|
||||
__all__ = ['InceptionV4']
|
||||
|
||||
default_cfgs = {
|
||||
'inception_v4': {
|
||||
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/inceptionv4-8e4777a0.pth',
|
||||
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
||||
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
||||
'first_conv': 'features.0.conv', 'classifier': 'last_linear',
|
||||
'label_offset': 1, # 1001 classes in pretrained weights
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class BasicConv2d(nn.Module):
|
||||
def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
|
||||
super(BasicConv2d, self).__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
|
||||
self.bn = nn.BatchNorm2d(out_planes, eps=0.001)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class Mixed3a(nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, conv_block=ConvNormAct):
|
||||
super(Mixed3a, self).__init__()
|
||||
self.maxpool = nn.MaxPool2d(3, stride=2)
|
||||
self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2)
|
||||
self.conv = conv_block(64, 96, kernel_size=3, stride=2)
|
||||
|
||||
def forward(self, x):
|
||||
x0 = self.maxpool(x)
|
||||
|
@ -54,19 +29,19 @@ class Mixed3a(nn.Module):


class Mixed4a(nn.Module):
    def __init__(self):
    def __init__(self, conv_block=ConvNormAct):
        super(Mixed4a, self).__init__()

        self.branch0 = nn.Sequential(
            BasicConv2d(160, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1)
            conv_block(160, 64, kernel_size=1, stride=1),
            conv_block(64, 96, kernel_size=3, stride=1)
        )

        self.branch1 = nn.Sequential(
            BasicConv2d(160, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(64, 96, kernel_size=(3, 3), stride=1)
            conv_block(160, 64, kernel_size=1, stride=1),
            conv_block(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            conv_block(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            conv_block(64, 96, kernel_size=(3, 3), stride=1)
        )

    def forward(self, x):

@ -77,9 +52,9 @@ class Mixed4a(nn.Module):


class Mixed5a(nn.Module):
    def __init__(self):
    def __init__(self, conv_block=ConvNormAct):
        super(Mixed5a, self).__init__()
        self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2)
        self.conv = conv_block(192, 192, kernel_size=3, stride=2)
        self.maxpool = nn.MaxPool2d(3, stride=2)

    def forward(self, x):

@ -90,24 +65,24 @@ class Mixed5a(nn.Module):


class InceptionA(nn.Module):
    def __init__(self):
    def __init__(self, conv_block=ConvNormAct):
        super(InceptionA, self).__init__()
        self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1)
        self.branch0 = conv_block(384, 96, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(384, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1)
            conv_block(384, 64, kernel_size=1, stride=1),
            conv_block(64, 96, kernel_size=3, stride=1, padding=1)
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(384, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
            conv_block(384, 64, kernel_size=1, stride=1),
            conv_block(64, 96, kernel_size=3, stride=1, padding=1),
            conv_block(96, 96, kernel_size=3, stride=1, padding=1)
        )

        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(384, 96, kernel_size=1, stride=1)
            conv_block(384, 96, kernel_size=1, stride=1)
        )

    def forward(self, x):

@ -120,14 +95,14 @@ class InceptionA(nn.Module):


class ReductionA(nn.Module):
    def __init__(self):
    def __init__(self, conv_block=ConvNormAct):
        super(ReductionA, self).__init__()
        self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2)
        self.branch0 = conv_block(384, 384, kernel_size=3, stride=2)

        self.branch1 = nn.Sequential(
            BasicConv2d(384, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1),
            BasicConv2d(224, 256, kernel_size=3, stride=2)
            conv_block(384, 192, kernel_size=1, stride=1),
            conv_block(192, 224, kernel_size=3, stride=1, padding=1),
            conv_block(224, 256, kernel_size=3, stride=2)
        )

        self.branch2 = nn.MaxPool2d(3, stride=2)

@ -141,27 +116,27 @@ class ReductionA(nn.Module):


class InceptionB(nn.Module):
    def __init__(self):
    def __init__(self, conv_block=ConvNormAct):
        super(InceptionB, self).__init__()
        self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1)
        self.branch0 = conv_block(1024, 384, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(1024, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0))
            conv_block(1024, 192, kernel_size=1, stride=1),
            conv_block(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            conv_block(224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0))
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(1024, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3))
            conv_block(1024, 192, kernel_size=1, stride=1),
            conv_block(192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            conv_block(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            conv_block(224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            conv_block(224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3))
        )

        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(1024, 128, kernel_size=1, stride=1)
            conv_block(1024, 128, kernel_size=1, stride=1)
        )

    def forward(self, x):

@ -174,19 +149,19 @@ class InceptionB(nn.Module):


class ReductionB(nn.Module):
    def __init__(self):
    def __init__(self, conv_block=ConvNormAct):
        super(ReductionB, self).__init__()

        self.branch0 = nn.Sequential(
            BasicConv2d(1024, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 192, kernel_size=3, stride=2)
            conv_block(1024, 192, kernel_size=1, stride=1),
            conv_block(192, 192, kernel_size=3, stride=2)
        )

        self.branch1 = nn.Sequential(
            BasicConv2d(1024, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(320, 320, kernel_size=3, stride=2)
            conv_block(1024, 256, kernel_size=1, stride=1),
            conv_block(256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            conv_block(256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            conv_block(320, 320, kernel_size=3, stride=2)
        )

        self.branch2 = nn.MaxPool2d(3, stride=2)

@ -200,24 +175,24 @@ class ReductionB(nn.Module):


class InceptionC(nn.Module):
    def __init__(self):
    def __init__(self, conv_block=ConvNormAct):
        super(InceptionC, self).__init__()

        self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1)
        self.branch0 = conv_block(1536, 256, kernel_size=1, stride=1)

        self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
        self.branch1_1a = BasicConv2d(384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch1_1b = BasicConv2d(384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
        self.branch1_0 = conv_block(1536, 384, kernel_size=1, stride=1)
        self.branch1_1a = conv_block(384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch1_1b = conv_block(384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))

        self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
        self.branch2_1 = BasicConv2d(384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0))
        self.branch2_2 = BasicConv2d(448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch2_3a = BasicConv2d(512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch2_3b = BasicConv2d(512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
        self.branch2_0 = conv_block(1536, 384, kernel_size=1, stride=1)
        self.branch2_1 = conv_block(384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0))
        self.branch2_2 = conv_block(448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch2_3a = conv_block(512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch2_3b = conv_block(512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))

        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(1536, 256, kernel_size=1, stride=1)
            conv_block(1536, 256, kernel_size=1, stride=1)
        )

    def forward(self, x):
@ -242,37 +217,44 @@ class InceptionC(nn.Module):


class InceptionV4(nn.Module):
    def __init__(self, num_classes=1000, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg'):
    def __init__(
            self,
            num_classes=1000,
            in_chans=3,
            output_stride=32,
            drop_rate=0.,
            global_pool='avg',
            norm_layer='batchnorm2d',
            norm_eps=1e-3,
            act_layer='relu',
    ):
        super(InceptionV4, self).__init__()
        assert output_stride == 32
        self.drop_rate = drop_rate
        self.num_classes = num_classes
        self.num_features = 1536

        self.features = nn.Sequential(
            BasicConv2d(in_chans, 32, kernel_size=3, stride=2),
            BasicConv2d(32, 32, kernel_size=3, stride=1),
            BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1),
            Mixed3a(),
            Mixed4a(),
            Mixed5a(),
            InceptionA(),
            InceptionA(),
            InceptionA(),
            InceptionA(),
            ReductionA(),  # Mixed6a
            InceptionB(),
            InceptionB(),
            InceptionB(),
            InceptionB(),
            InceptionB(),
            InceptionB(),
            InceptionB(),
            ReductionB(),  # Mixed7a
            InceptionC(),
            InceptionC(),
            InceptionC(),
        conv_block = partial(
            ConvNormAct,
            padding=0,
            norm_layer=norm_layer,
            act_layer=act_layer,
            norm_kwargs=dict(eps=norm_eps),
            act_kwargs=dict(inplace=True),
        )

        features = [
            conv_block(in_chans, 32, kernel_size=3, stride=2),
            conv_block(32, 32, kernel_size=3, stride=1),
            conv_block(32, 64, kernel_size=3, stride=1, padding=1),
            Mixed3a(conv_block),
            Mixed4a(conv_block),
            Mixed5a(conv_block),
        ]
        features += [InceptionA(conv_block) for _ in range(4)]
        features += [ReductionA(conv_block)]  # Mixed6a
        features += [InceptionB(conv_block) for _ in range(7)]
        features += [ReductionB(conv_block)]  # Mixed7a
        features += [InceptionC(conv_block) for _ in range(3)]
        self.features = nn.Sequential(*features)
        self.feature_info = [
            dict(num_chs=64, reduction=2, module='features.2'),
            dict(num_chs=160, reduction=4, module='features.3'),
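The `__init__` rewrite above swaps the hard-coded `BasicConv2d` for an injected `conv_block` built with `functools.partial`, so norm/act settings are bound once and every call site only passes shapes. A minimal sketch of the pattern with a stand-in block (the `ConvBnAct` class here is illustrative, not timm's `ConvNormAct`):

    from functools import partial

    import torch.nn as nn

    class ConvBnAct(nn.Sequential):
        # Stand-in conv -> norm -> act block.
        def __init__(self, in_chs, out_chs, kernel_size=1, stride=1, padding=0, norm_eps=1e-5):
            super().__init__(
                nn.Conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=padding, bias=False),
                nn.BatchNorm2d(out_chs, eps=norm_eps),
                nn.ReLU(inplace=True),
            )

    # Bind the shared settings once; swapping the norm eps (or the block class
    # itself) becomes a one-line change instead of dozens of edits per file.
    conv_block = partial(ConvBnAct, norm_eps=1e-3)
    stem = conv_block(3, 32, kernel_size=3, stride=2)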
@ -280,8 +262,8 @@ class InceptionV4(nn.Module):
            dict(num_chs=1024, reduction=16, module='features.17'),
            dict(num_chs=1536, reduction=32, module='features.21'),
        ]
        self.global_pool, self.last_linear = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool)
        self.global_pool, self.head_drop, self.last_linear = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool, drop_rate=drop_rate)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):

@ -308,8 +290,7 @@ class InceptionV4(nn.Module):

    def forward_head(self, x, pre_logits: bool = False):
        x = self.global_pool(x)
        if self.drop_rate > 0:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        x = self.head_drop(x)
        return x if pre_logits else self.last_linear(x)

    def forward(self, x):
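This head rewrite recurs across the PR: the conditional `F.dropout` call becomes an `nn.Dropout` module handed back by `create_classifier`. A module is a no-op at `p=0`, shows up in `print(model)`, and is friendlier to scripting/tracing since the Python branch disappears. A self-contained sketch of the equivalent head (plain torch, nothing timm-specific):

    import torch
    import torch.nn as nn

    class Head(nn.Module):
        def __init__(self, num_features, num_classes, drop_rate=0.):
            super().__init__()
            self.global_pool = nn.AdaptiveAvgPool2d(1)
            self.head_drop = nn.Dropout(drop_rate)  # respects self.training automatically
            self.last_linear = nn.Linear(num_features, num_classes)

        def forward(self, x, pre_logits: bool = False):
            x = self.global_pool(x).flatten(1)
            x = self.head_drop(x)
            return x if pre_logits else self.last_linear(x)

    logits = Head(1536, 1000, drop_rate=0.25)(torch.randn(2, 1536, 8, 8))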
@ -320,9 +301,25 @@ class InceptionV4(nn.Module):

def _create_inception_v4(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
        InceptionV4, variant, pretrained,
        InceptionV4,
        variant,
        pretrained,
        feature_cfg=dict(flatten_sequential=True),
        **kwargs)
        **kwargs,
    )


default_cfgs = generate_default_cfgs({
    'inception_v4.tf_in1k': {
        'hf_hub_id': 'timm/',
        'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/inceptionv4-8e4777a0.pth',
        'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
        'crop_pct': 0.875, 'interpolation': 'bicubic',
        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
        'first_conv': 'features.0.conv', 'classifier': 'last_linear',
        'label_offset': 1,  # 1001 classes in pretrained weights
    }
})


@register_model
@ -23,77 +23,13 @@ from torch import nn
from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath, is_exportable
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_module
from ._registry import register_model
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups
from .vision_transformer import Block as TransformerBlock

__all__ = []


def _cfg(url='', **kwargs):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
        'crop_pct': 0.9, 'interpolation': 'bicubic',
        'mean': (0., 0., 0.), 'std': (1., 1., 1.),
        'first_conv': 'stem.conv', 'classifier': 'head.fc',
        'fixed_input_size': False,
        **kwargs
    }


default_cfgs = {
    'mobilevit_xxs': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_xxs-ad385b40.pth'),
    'mobilevit_xs': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_xs-8fbd6366.pth'),
    'mobilevit_s': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_s-38a5a959.pth'),
    'semobilevit_s': _cfg(),

    'mobilevitv2_050': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_050-49951ee2.pth',
        crop_pct=0.888),
    'mobilevitv2_075': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_075-b5556ef6.pth',
        crop_pct=0.888),
    'mobilevitv2_100': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_100-e464ef3b.pth',
        crop_pct=0.888),
    'mobilevitv2_125': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_125-0ae35027.pth',
        crop_pct=0.888),
    'mobilevitv2_150': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150-737c5019.pth',
        crop_pct=0.888),
    'mobilevitv2_175': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175-16462ee2.pth',
        crop_pct=0.888),
    'mobilevitv2_200': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200-b3422f67.pth',
        crop_pct=0.888),

    'mobilevitv2_150_in22ft1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150_in22ft1k-0b555d7b.pth',
        crop_pct=0.888),
    'mobilevitv2_175_in22ft1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175_in22ft1k-4117fa1f.pth',
        crop_pct=0.888),
    'mobilevitv2_200_in22ft1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200_in22ft1k-1d7c8927.pth',
        crop_pct=0.888),

    'mobilevitv2_150_384_in22ft1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150_384_in22ft1k-9e142854.pth',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
    'mobilevitv2_175_384_in22ft1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175_384_in22ft1k-059cbe56.pth',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
    'mobilevitv2_200_384_in22ft1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200_384_in22ft1k-32c87503.pth',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
}


def _inverted_residual_block(d, c, s, br=4.0):
    # inverted residual is a bottleneck block with bottle_ratio > 1 applied to in_chs, linear output, gs=1 (depthwise)
    return ByoBlockCfg(
@ -600,7 +536,6 @@ class MobileVitV2Block(nn.Module):
        x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
        x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)


        x = self.conv_proj(x)
        return x


@ -625,6 +560,66 @@ def _create_mobilevit2(variant, cfg_variant=None, pretrained=False, **kwargs):
        **kwargs)


def _cfg(url='', **kwargs):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
        'crop_pct': 0.9, 'interpolation': 'bicubic',
        'mean': (0., 0., 0.), 'std': (1., 1., 1.),
        'first_conv': 'stem.conv', 'classifier': 'head.fc',
        'fixed_input_size': False,
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'mobilevit_xxs.cvnets_in1k': _cfg(hf_hub_id='timm/'),
    'mobilevit_xs.cvnets_in1k': _cfg(hf_hub_id='timm/'),
    'mobilevit_s.cvnets_in1k': _cfg(hf_hub_id='timm/'),

    'mobilevitv2_050.cvnets_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_075.cvnets_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_100.cvnets_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_125.cvnets_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_150.cvnets_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_175.cvnets_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_200.cvnets_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),

    'mobilevitv2_150.cvnets_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_175.cvnets_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_200.cvnets_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),

    'mobilevitv2_150.cvnets_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
    'mobilevitv2_175.cvnets_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
    'mobilevitv2_200.cvnets_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
})


@register_model
def mobilevit_xxs(pretrained=False, **kwargs):
    return _create_mobilevit('mobilevit_xxs', pretrained=pretrained, **kwargs)

@ -640,11 +635,6 @@ def mobilevit_s(pretrained=False, **kwargs):
    return _create_mobilevit('mobilevit_s', pretrained=pretrained, **kwargs)


@register_model
def semobilevit_s(pretrained=False, **kwargs):
    return _create_mobilevit('semobilevit_s', pretrained=pretrained, **kwargs)


@register_model
def mobilevitv2_050(pretrained=False, **kwargs):
    return _create_mobilevit('mobilevitv2_050', pretrained=pretrained, **kwargs)

@ -680,37 +670,12 @@ def mobilevitv2_200(pretrained=False, **kwargs):
    return _create_mobilevit('mobilevitv2_200', pretrained=pretrained, **kwargs)


@register_model
def mobilevitv2_150_in22ft1k(pretrained=False, **kwargs):
    return _create_mobilevit(
        'mobilevitv2_150_in22ft1k', cfg_variant='mobilevitv2_150', pretrained=pretrained, **kwargs)
register_model_deprecations(__name__, {
    'mobilevitv2_150_in22ft1k': 'mobilevitv2_150.cvnets_in22k_ft_in1k',
    'mobilevitv2_175_in22ft1k': 'mobilevitv2_175.cvnets_in22k_ft_in1k',
    'mobilevitv2_200_in22ft1k': 'mobilevitv2_200.cvnets_in22k_ft_in1k',


@register_model
def mobilevitv2_175_in22ft1k(pretrained=False, **kwargs):
    return _create_mobilevit(
        'mobilevitv2_175_in22ft1k', cfg_variant='mobilevitv2_175', pretrained=pretrained, **kwargs)


@register_model
def mobilevitv2_200_in22ft1k(pretrained=False, **kwargs):
    return _create_mobilevit(
        'mobilevitv2_200_in22ft1k', cfg_variant='mobilevitv2_200', pretrained=pretrained, **kwargs)


@register_model
def mobilevitv2_150_384_in22ft1k(pretrained=False, **kwargs):
    return _create_mobilevit(
        'mobilevitv2_150_384_in22ft1k', cfg_variant='mobilevitv2_150', pretrained=pretrained, **kwargs)


@register_model
def mobilevitv2_175_384_in22ft1k(pretrained=False, **kwargs):
    return _create_mobilevit(
        'mobilevitv2_175_384_in22ft1k', cfg_variant='mobilevitv2_175', pretrained=pretrained, **kwargs)


@register_model
def mobilevitv2_200_384_in22ft1k(pretrained=False, **kwargs):
    return _create_mobilevit(
        'mobilevitv2_200_384_in22ft1k', cfg_variant='mobilevitv2_200', pretrained=pretrained, **kwargs)
    'mobilevitv2_150_384_in22ft1k': 'mobilevitv2_150.cvnets_in22k_ft_in1k_384',
    'mobilevitv2_175_384_in22ft1k': 'mobilevitv2_175.cvnets_in22k_ft_in1k_384',
    'mobilevitv2_200_384_in22ft1k': 'mobilevitv2_200.cvnets_in22k_ft_in1k_384',
})
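The mobilevit changes above show the PR-wide naming scheme: one architecture key plus a `.tag` per weight source, with `register_model_deprecations` bridging the old flat names. A hedged usage sketch (requires a timm version that ships these hub weights):

    import timm

    # 'arch.tag' selects a specific weight set; the bare arch name picks the default tag.
    model = timm.create_model('mobilevitv2_150.cvnets_in22k_ft_in1k_384', pretrained=True)

    # Legacy names like 'mobilevitv2_150_384_in22ft1k' still resolve via the
    # deprecation table above, emitting a warning that points at the new name.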
@ -10,25 +10,10 @@ import torch.nn.functional as F

from timm.layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier
from ._builder import build_model_with_cfg
from ._registry import register_model
from ._registry import register_model, generate_default_cfgs

__all__ = ['NASNetALarge']

default_cfgs = {
    'nasnetalarge': {
        'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nasnetalarge-dc4a7b8b.pth',
        'input_size': (3, 331, 331),
        'pool_size': (11, 11),
        'crop_pct': 0.911,
        'interpolation': 'bicubic',
        'mean': (0.5, 0.5, 0.5),
        'std': (0.5, 0.5, 0.5),
        'num_classes': 1000,
        'first_conv': 'conv0.conv',
        'classifier': 'last_linear',
        'label_offset': 1,  # 1001 classes in pretrained weights
    },
}


class ActConvBn(nn.Module):

@ -408,14 +393,22 @@ class NASNetALarge(nn.Module):
    """NASNetALarge (6 @ 4032) """

    def __init__(
            self, num_classes=1000, in_chans=3, stem_size=96, channel_multiplier=2,
            num_features=4032, output_stride=32, drop_rate=0., global_pool='avg', pad_type='same'):
            self,
            num_classes=1000,
            in_chans=3,
            stem_size=96,
            channel_multiplier=2,
            num_features=4032,
            output_stride=32,
            drop_rate=0.,
            global_pool='avg',
            pad_type='same',
    ):
        super(NASNetALarge, self).__init__()
        self.num_classes = num_classes
        self.stem_size = stem_size
        self.num_features = num_features
        self.channel_multiplier = channel_multiplier
        self.drop_rate = drop_rate
        assert output_stride == 32

        channels = self.num_features // 24

@ -501,8 +494,8 @@ class NASNetALarge(nn.Module):
            dict(num_chs=4032, reduction=32, module='act'),
        ]

        self.global_pool, self.last_linear = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool)
        self.global_pool, self.head_drop, self.last_linear = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool, drop_rate=drop_rate)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):

@ -562,8 +555,7 @@ class NASNetALarge(nn.Module):

    def forward_head(self, x):
        x = self.global_pool(x)
        if self.drop_rate > 0:
            x = F.dropout(x, self.drop_rate, training=self.training)
        x = self.head_drop(x)
        x = self.last_linear(x)
        return x

@ -575,9 +567,30 @@ class NASNetALarge(nn.Module):

def _create_nasnet(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
        NASNetALarge, variant, pretrained,
        NASNetALarge,
        variant,
        pretrained,
        feature_cfg=dict(feature_cls='hook', no_rewrite=True),  # not possible to re-write this model
        **kwargs)
        **kwargs,
    )


default_cfgs = generate_default_cfgs({
    'nasnetalarge.tf_in1k': {
        'hf_hub_id': 'timm/',
        'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nasnetalarge-dc4a7b8b.pth',
        'input_size': (3, 331, 331),
        'pool_size': (11, 11),
        'crop_pct': 0.911,
        'interpolation': 'bicubic',
        'mean': (0.5, 0.5, 0.5),
        'std': (0.5, 0.5, 0.5),
        'num_classes': 1000,
        'first_conv': 'conv0.conv',
        'classifier': 'last_linear',
        'label_offset': 1,  # 1001 classes in pretrained weights
    },
})


@register_model
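`'label_offset': 1` appears in all of these TF-ported configs because the original checkpoints predict 1001 classes, with index 0 reserved for background. One way a consumer can apply the offset at decode time (a sketch, not timm's internal handling):

    import torch

    logits = torch.randn(1, 1001)  # raw head output of a 1001-class TF-ported checkpoint
    label_offset = 1
    class_id = int(logits.argmax(dim=1)) - label_offset  # back to 0..999 ImageNet-1k ids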
@ -26,52 +26,30 @@ from torch import nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_, _assert
from timm.layers import create_conv2d, create_pool2d, to_ntuple
from timm.layers import create_conv2d, create_pool2d, to_ntuple, use_fused_attn, LayerNorm
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq, named_apply
from ._registry import register_model
from ._registry import register_model, generate_default_cfgs, register_model_deprecations

__all__ = ['Nest']  # model_registry will add each entrypoint fn to this

_logger = logging.getLogger(__name__)


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': [14, 14],
        'crop_pct': .875, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = {
    # (weights from official Google JAX impl)
    'nest_base': _cfg(),
    'nest_small': _cfg(),
    'nest_tiny': _cfg(),
    'jx_nest_base': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_base-8bc41011.pth'),
    'jx_nest_small': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_small-422eaded.pth'),
    'jx_nest_tiny': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_tiny-e3428fb9.pth'),
}


class Attention(nn.Module):
    """
    This is much like `.vision_transformer.Attention` but uses *localised* self attention by accepting an input with
    an extra "image block" dim
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, 3*dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)

@ -87,12 +65,17 @@ class Attention(nn.Module):
        qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, H, T, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        if self.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)  # (B, H, T, N, N)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        # (B, H, T, N, C'), permute -> (B, T, N, C', H)
        x = (attn @ v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C)
        x = x.permute(0, 2, 3, 4, 1).reshape(B, T, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x  # (B, T, N, C)
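This attention change is the fused-attention pattern applied throughout the PR: dispatch to `F.scaled_dot_product_attention` when available, keep the explicit matmul path as a fallback. A small equivalence check (shapes illustrative; PyTorch >= 2.0 for the fused kernel):

    import torch
    import torch.nn.functional as F

    q = torch.randn(2, 4, 16, 32)  # (batch, heads, tokens, head_dim)
    k, v = torch.randn_like(q), torch.randn_like(q)

    out_fused = F.scaled_dot_product_attention(q, k, v)  # scales by head_dim**-0.5 internally

    attn = (q * q.shape[-1] ** -0.5) @ k.transpose(-2, -1)
    out_ref = attn.softmax(dim=-1) @ v

    print(torch.allclose(out_fused, out_ref, atol=1e-5))  # True, up to kernel numerics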
@ -118,11 +101,22 @@ class TransformerLayer(nn.Module):
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=proj_drop,
        )

    def forward(self, x):
        y = self.norm1(x)

@ -317,7 +311,7 @@ class Nest(nn.Module):
        self.num_classes = num_classes
        self.num_features = embed_dims[-1]
        self.feature_info = []
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        norm_layer = norm_layer or LayerNorm
        act_layer = act_layer or nn.GELU
        self.drop_rate = drop_rate
        self.num_levels = num_levels

@ -490,14 +484,39 @@ def checkpoint_filter_fn(state_dict, model):

def _create_nest(variant, pretrained=False, **kwargs):
    model = build_model_with_cfg(
        Nest, variant, pretrained,
        Nest,
        variant,
        pretrained,
        feature_cfg=dict(out_indices=(0, 1, 2), flatten_sequential=True),
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs)
        **kwargs,
    )

    return model


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': [14, 14],
        'crop_pct': .875, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'nest_base.untrained': _cfg(),
    'nest_small.untrained': _cfg(),
    'nest_tiny.untrained': _cfg(),
    # (weights from official Google JAX impl, require 'SAME' padding)
    'nest_base_jx.goog_in1k': _cfg(hf_hub_id='timm/'),
    'nest_small_jx.goog_in1k': _cfg(hf_hub_id='timm/'),
    'nest_tiny_jx.goog_in1k': _cfg(hf_hub_id='timm/'),
})


@register_model
def nest_base(pretrained=False, **kwargs):
    """ Nest-B @ 224x224

@ -527,30 +546,38 @@ def nest_tiny(pretrained=False, **kwargs):


@register_model
def jx_nest_base(pretrained=False, **kwargs):
    """ Nest-B @ 224x224, Pretrained weights converted from official Jax impl.
def nest_base_jx(pretrained=False, **kwargs):
    """ Nest-B @ 224x224
    """
    kwargs['pad_type'] = 'same'
    model_kwargs = dict(embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), **kwargs)
    model = _create_nest('jx_nest_base', pretrained=pretrained, **model_kwargs)
    kwargs.setdefault('pad_type', 'same')
    model_kwargs = dict(
        embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), **kwargs)
    model = _create_nest('nest_base_jx', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def jx_nest_small(pretrained=False, **kwargs):
    """ Nest-S @ 224x224, Pretrained weights converted from official Jax impl.
def nest_small_jx(pretrained=False, **kwargs):
    """ Nest-S @ 224x224
    """
    kwargs['pad_type'] = 'same'
    kwargs.setdefault('pad_type', 'same')
    model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), **kwargs)
    model = _create_nest('jx_nest_small', pretrained=pretrained, **model_kwargs)
    model = _create_nest('nest_small_jx', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def jx_nest_tiny(pretrained=False, **kwargs):
    """ Nest-T @ 224x224, Pretrained weights converted from official Jax impl.
def nest_tiny_jx(pretrained=False, **kwargs):
    """ Nest-T @ 224x224
    """
    kwargs['pad_type'] = 'same'
    kwargs.setdefault('pad_type', 'same')
    model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), **kwargs)
    model = _create_nest('jx_nest_tiny', pretrained=pretrained, **model_kwargs)
    model = _create_nest('nest_tiny_jx', pretrained=pretrained, **model_kwargs)
    return model


register_model_deprecations(__name__, {
    'jx_nest_base': 'nest_base_jx',
    'jx_nest_small': 'nest_small_jx',
    'jx_nest_tiny': 'nest_tiny_jx',
})
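A detail worth noting in the renamed entrypoints above: `kwargs.setdefault('pad_type', 'same')` replaces the old hard assignment, so the JAX-style default no longer clobbers a caller's explicit padding choice:

    kwargs = {'pad_type': ''}  # caller explicitly asked for symmetric PyTorch padding
    kwargs.setdefault('pad_type', 'same')
    print(kwargs['pad_type'])  # '' -- the caller's value survives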
@ -14,57 +14,21 @@ Modifications for timm by / Copyright 2020 Ross Wightman
import math
import re
from functools import partial
from typing import Tuple
from typing import Sequence, Tuple

import torch
from torch import nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, to_2tuple
from timm.layers import trunc_normal_, to_2tuple, LayerNorm
from ._builder import build_model_with_cfg
from ._registry import register_model
from ._registry import register_model, generate_default_cfgs
from .vision_transformer import Block


__all__ = ['PoolingVisionTransformer']  # model_registry will add each entrypoint fn to this


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.conv', 'classifier': 'head',
        **kwargs
    }


default_cfgs = {
    # deit models (FB weights)
    'pit_ti_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_730.pth'),
    'pit_xs_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_781.pth'),
    'pit_s_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_809.pth'),
    'pit_b_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_820.pth'),
    'pit_ti_distilled_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_distill_746.pth',
        classifier=('head', 'head_dist')),
    'pit_xs_distilled_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_distill_791.pth',
        classifier=('head', 'head_dist')),
    'pit_s_distilled_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_distill_819.pth',
        classifier=('head', 'head_dist')),
    'pit_b_distilled_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_distill_840.pth',
        classifier=('head', 'head_dist')),
}


class SequentialTuple(nn.Sequential):
    """ This module exists to work around torchscript typing issues list -> list"""
    def __init__(self, *args):

@ -87,11 +51,13 @@ class Transformer(nn.Module):
            proj_drop=.0,
            attn_drop=.0,
            drop_path_prob=None,
            norm_layer=None,
    ):
        super(Transformer, self).__init__()
        self.layers = nn.ModuleList([])
        embed_dim = base_dim * heads

        self.pool = pool
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
        self.blocks = nn.Sequential(*[
            Block(
                dim=embed_dim,

@ -105,30 +71,29 @@ class Transformer(nn.Module):
            )
            for i in range(depth)])

        self.pool = pool

    def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        x, cls_tokens = x
        B, C, H, W = x.shape
        token_length = cls_tokens.shape[1]
        if self.pool is not None:
            x, cls_tokens = self.pool(x, cls_tokens)

        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = torch.cat((cls_tokens, x), dim=1)

        x = self.norm(x)
        x = self.blocks(x)

        cls_tokens = x[:, :token_length]
        x = x[:, token_length:]
        x = x.transpose(1, 2).reshape(B, C, H, W)

        if self.pool is not None:
            x, cls_tokens = self.pool(x, cls_tokens)
        return x, cls_tokens


class ConvHeadPooling(nn.Module):
class Pooling(nn.Module):
    def __init__(self, in_feature, out_feature, stride, padding_mode='zeros'):
        super(ConvHeadPooling, self).__init__()
        super(Pooling, self).__init__()

        self.conv = nn.Conv2d(
            in_feature,

@ -148,10 +113,26 @@ class ConvHeadPooling(nn.Module):


class ConvEmbedding(nn.Module):
    def __init__(self, in_channels, out_channels, patch_size, stride, padding):
    def __init__(
            self,
            in_channels,
            out_channels,
            img_size: int = 224,
            patch_size: int = 16,
            stride: int = 8,
            padding: int = 0,
    ):
        super(ConvEmbedding, self).__init__()
        padding = padding
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        self.height = math.floor((self.img_size[0] + 2 * padding - self.patch_size[0]) / stride + 1)
        self.width = math.floor((self.img_size[1] + 2 * padding - self.patch_size[1]) / stride + 1)
        self.grid_size = (self.height, self.width)

        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=patch_size, stride=stride, padding=padding, bias=True)
            in_channels, out_channels, kernel_size=patch_size,
            stride=stride, padding=padding, bias=True)

    def forward(self, x):
        x = self.conv(x)
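The new `ConvEmbedding` derives its token grid from the conv geometry. Plugging in the PiT defaults above: floor((224 + 2*0 - 16) / 8 + 1) = 27, i.e. a 27x27 grid:

    import math

    img_size, patch_size, stride, padding = 224, 16, 8, 0
    height = math.floor((img_size + 2 * padding - patch_size) / stride + 1)
    print(height)  # 27 -> 27 * 27 = 729 spatial tokens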
@ -166,13 +147,14 @@ class PoolingVisionTransformer(nn.Module):
    """
    def __init__(
            self,
            img_size,
            patch_size,
            stride,
            base_dims,
            depth,
            heads,
            mlp_ratio,
            img_size: int = 224,
            patch_size: int = 16,
            stride: int = 8,
            stem_type: str = 'overlap',
            base_dims: Sequence[int] = (48, 48, 48),
            depth: Sequence[int] = (2, 6, 4),
            heads: Sequence[int] = (2, 4, 8),
            mlp_ratio: float = 4,
            num_classes=1000,
            in_chans=3,
            global_pool='token',

@ -186,50 +168,48 @@ class PoolingVisionTransformer(nn.Module):
        super(PoolingVisionTransformer, self).__init__()
        assert global_pool in ('token',)

        padding = 0
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        height = math.floor((img_size[0] + 2 * padding - patch_size[0]) / stride + 1)
        width = math.floor((img_size[1] + 2 * padding - patch_size[1]) / stride + 1)

        self.base_dims = base_dims
        self.heads = heads
        embed_dim = base_dims[0] * heads[0]
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_tokens = 2 if distilled else 1
        self.feature_info = []

        self.patch_size = patch_size
        self.pos_embed = nn.Parameter(torch.randn(1, base_dims[0] * heads[0], height, width))
        self.patch_embed = ConvEmbedding(in_chans, base_dims[0] * heads[0], patch_size, stride, padding)

        self.cls_token = nn.Parameter(torch.randn(1, self.num_tokens, base_dims[0] * heads[0]))
        self.patch_embed = ConvEmbedding(in_chans, embed_dim, img_size, patch_size, stride)
        self.pos_embed = nn.Parameter(torch.randn(1, embed_dim, self.patch_embed.height, self.patch_embed.width))
        self.cls_token = nn.Parameter(torch.randn(1, self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=pos_drop_drate)

        transformers = []
        # stochastic depth decay rule
        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depth)).split(depth)]
        for stage in range(len(depth)):
        prev_dim = embed_dim
        for i in range(len(depth)):
            pool = None
            if stage < len(heads) - 1:
                pool = ConvHeadPooling(
                    base_dims[stage] * heads[stage],
                    base_dims[stage + 1] * heads[stage + 1],
            embed_dim = base_dims[i] * heads[i]
            if i > 0:
                pool = Pooling(
                    prev_dim,
                    embed_dim,
                    stride=2,
                )
            transformers += [Transformer(
                base_dims[stage],
                depth[stage],
                heads[stage],
                base_dims[i],
                depth[i],
                heads[i],
                mlp_ratio,
                pool=pool,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path_prob=dpr[stage],
                )
            ]
                drop_path_prob=dpr[i],
            )]
            prev_dim = embed_dim
            self.feature_info += [dict(num_chs=prev_dim, reduction=(stride - 1) * 2**i, module=f'transformers.{i}')]

        self.transformers = SequentialTuple(*transformers)
        self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6)
        self.num_features = self.embed_dim = base_dims[-1] * heads[-1]
        self.num_features = self.embed_dim = embed_dim

        # Classifier head
        self.head_drop = nn.Dropout(drop_rate)

@ -318,25 +298,58 @@ def checkpoint_filter_fn(state_dict, model):
        # if k == 'pos_embed' and v.shape != model.pos_embed.shape:
        #     # To resize pos embedding when using model at different size from pretrained weights
        #     v = resize_pos_embed(v, model.pos_embed)
        k = p_blocks.sub(lambda exp: f'transformers.{int(exp.group(1))}.pool.', k)
        k = p_blocks.sub(lambda exp: f'transformers.{int(exp.group(1)) + 1}.pool.', k)
        out_dict[k] = v
    return out_dict


def _create_pit(variant, pretrained=False, **kwargs):
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')
    default_out_indices = tuple(range(3))
    out_indices = kwargs.pop('out_indices', default_out_indices)

    model = build_model_with_cfg(
        PoolingVisionTransformer,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(feature_cls='hook', no_rewrite=True, out_indices=out_indices),
        **kwargs,
    )
    return model


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.conv', 'classifier': 'head',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    # deit models (FB weights)
    'pit_ti_224.in1k': _cfg(hf_hub_id='timm/'),
    'pit_xs_224.in1k': _cfg(hf_hub_id='timm/'),
    'pit_s_224.in1k': _cfg(hf_hub_id='timm/'),
    'pit_b_224.in1k': _cfg(hf_hub_id='timm/'),
    'pit_ti_distilled_224.in1k': _cfg(
        hf_hub_id='timm/',
        classifier=('head', 'head_dist')),
    'pit_xs_distilled_224.in1k': _cfg(
        hf_hub_id='timm/',
        classifier=('head', 'head_dist')),
    'pit_s_distilled_224.in1k': _cfg(
        hf_hub_id='timm/',
        classifier=('head', 'head_dist')),
    'pit_b_distilled_224.in1k': _cfg(
        hf_hub_id='timm/',
        classifier=('head', 'head_dist')),
})


@register_model
def pit_b_224(pretrained=False, **kwargs):
    model_args = dict(
@ -14,26 +14,10 @@ import torch.nn.functional as F

from timm.layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier
from ._builder import build_model_with_cfg
from ._registry import register_model
from ._registry import register_model, generate_default_cfgs

__all__ = ['PNASNet5Large']

default_cfgs = {
    'pnasnet5large': {
        'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/pnasnet5large-bf079911.pth',
        'input_size': (3, 331, 331),
        'pool_size': (11, 11),
        'crop_pct': 0.911,
        'interpolation': 'bicubic',
        'mean': (0.5, 0.5, 0.5),
        'std': (0.5, 0.5, 0.5),
        'num_classes': 1000,
        'first_conv': 'conv_0.conv',
        'classifier': 'last_linear',
        'label_offset': 1,  # 1001 classes in pretrained weights
    },
}


class SeparableConv2d(nn.Module):

@ -185,8 +169,16 @@ class CellStem0(CellBase):

class Cell(CellBase):

    def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type='',
                 is_reduction=False, match_prev_layer_dims=False):
    def __init__(
            self,
            in_chs_left,
            out_chs_left,
            in_chs_right,
            out_chs_right,
            pad_type='',
            is_reduction=False,
            match_prev_layer_dims=False,
    ):
        super(Cell, self).__init__()

        # If `is_reduction` is set to `True` stride 2 is used for

@ -236,10 +228,17 @@ class Cell(CellBase):


class PNASNet5Large(nn.Module):
    def __init__(self, num_classes=1000, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg', pad_type=''):
    def __init__(
            self,
            num_classes=1000,
            in_chans=3,
            output_stride=32,
            drop_rate=0.,
            global_pool='avg',
            pad_type='',
    ):
        super(PNASNet5Large, self).__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.num_features = 4320
        assert output_stride == 32

@ -293,8 +292,8 @@ class PNASNet5Large(nn.Module):
            dict(num_chs=4320, reduction=32, module='act'),
        ]

        self.global_pool, self.last_linear = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool)
        self.global_pool, self.head_drop, self.last_linear = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool, drop_rate=drop_rate)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):

@ -334,8 +333,7 @@ class PNASNet5Large(nn.Module):

    def forward_head(self, x, pre_logits: bool = False):
        x = self.global_pool(x)
        if self.drop_rate > 0:
            x = F.dropout(x, self.drop_rate, training=self.training)
        x = self.head_drop(x)
        return x if pre_logits else self.last_linear(x)

    def forward(self, x):

@ -346,9 +344,30 @@ class PNASNet5Large(nn.Module):

def _create_pnasnet(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
        PNASNet5Large, variant, pretrained,
        PNASNet5Large,
        variant,
        pretrained,
        feature_cfg=dict(feature_cls='hook', no_rewrite=True),  # not possible to re-write this model
        **kwargs)
        **kwargs,
    )


default_cfgs = generate_default_cfgs({
    'pnasnet5large.tf_in1k': {
        'hf_hub_id': 'timm/',
        'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/pnasnet5large-bf079911.pth',
        'input_size': (3, 331, 331),
        'pool_size': (11, 11),
        'crop_pct': 0.911,
        'interpolation': 'bicubic',
        'mean': (0.5, 0.5, 0.5),
        'std': (0.5, 0.5, 0.5),
        'num_classes': 1000,
        'first_conv': 'conv_0.conv',
        'classifier': 'last_linear',
        'label_offset': 1,  # 1001 classes in pretrained weights
    },
})


@register_model
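Like NASNet above, PNASNet keeps `feature_cls='hook'` with `no_rewrite=True` because its graph cannot be flattened for feature extraction; features come from forward hooks at the modules named in `feature_info`. A hedged usage sketch (API as in recent timm releases):

    import timm

    # features_only works for hook-based models too; channel counts and strides
    # come from the feature_info entries the model declares.
    model = timm.create_model('pnasnet5large', features_only=True, pretrained=False)
    print(model.feature_info.channels(), model.feature_info.reduction())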
@ -16,42 +16,21 @@ Modifications and timm support by / Copyright 2022, Ross Wightman
"""

import math
from functools import partial
from typing import Tuple, List, Callable, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_
from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn
from ._builder import build_model_with_cfg
from ._registry import register_model
from ._registry import register_model, generate_default_cfgs

__all__ = ['PyramidVisionTransformerV2']


def _cfg(url='', **kwargs):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.9, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': False,
        **kwargs
    }


default_cfgs = {
    'pvt_v2_b0': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b0.pth'),
    'pvt_v2_b1': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b1.pth'),
    'pvt_v2_b2': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b2.pth'),
    'pvt_v2_b3': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b3.pth'),
    'pvt_v2_b4': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b4.pth'),
    'pvt_v2_b5': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b5.pth'),
    'pvt_v2_b2_li': _cfg(url='https://github.com/whai362/PVT/releases/download/v2/pvt_v2_b2_li.pth')
}


class MlpWithDepthwiseConv(nn.Module):
    def __init__(
            self,

@ -87,6 +66,8 @@ class MlpWithDepthwiseConv(nn.Module):


class Attention(nn.Module):
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim,

@ -104,6 +85,7 @@ class Attention(nn.Module):
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)

@ -132,26 +114,31 @@ class Attention(nn.Module):
        q = self.q(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)

        if self.pool is not None:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            x_ = self.act(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            x = x.permute(0, 2, 1).reshape(B, C, H, W)
            x = self.sr(self.pool(x)).reshape(B, C, -1).permute(0, 2, 1)
            x = self.norm(x)
            x = self.act(x)
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        else:
            if self.sr is not None:
                x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
                x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
                x_ = self.norm(x_)
                kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
                x = x.permute(0, 2, 1).reshape(B, C, H, W)
                x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)
                x = self.norm(x)
                kv = self.kv(x).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            else:
                kv = self.kv(x).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        if self.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
@ -171,7 +158,7 @@ class Block(nn.Module):
            attn_drop=0.,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            norm_layer=LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)

@ -184,7 +171,8 @@ class Block(nn.Module):
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = MlpWithDepthwiseConv(
            in_features=dim,

@ -193,10 +181,11 @@ class Block(nn.Module):
            drop=proj_drop,
            extra_relu=linear_attn,
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x, feat_size: List[int]):
        x = x + self.drop_path(self.attn(self.norm1(x), feat_size))
        x = x + self.drop_path(self.mlp(self.norm2(x), feat_size))
        x = x + self.drop_path1(self.attn(self.norm1(x), feat_size))
        x = x + self.drop_path2(self.mlp(self.norm2(x), feat_size))

        return x

@ -216,10 +205,9 @@ class OverlapPatchEmbed(nn.Module):

    def forward(self, x):
        x = self.proj(x)
        feat_size = x.shape[-2:]
        x = x.flatten(2).transpose(1, 2)
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        return x, feat_size
        return x


class PyramidVisionTransformerStage(nn.Module):

@ -237,7 +225,7 @@ class PyramidVisionTransformerStage(nn.Module):
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            drop_path: Union[List[float], float] = 0.0,
            norm_layer: Callable = nn.LayerNorm,
            norm_layer: Callable = LayerNorm,
    ):
        super().__init__()
        self.grad_checkpointing = False

@ -247,7 +235,8 @@ class PyramidVisionTransformerStage(nn.Module):
                patch_size=3,
                stride=2,
                in_chans=dim,
                embed_dim=dim_out)
                embed_dim=dim_out,
            )
        else:
            assert dim == dim_out
            self.downsample = None

@ -267,23 +256,27 @@ class PyramidVisionTransformerStage(nn.Module):

        self.norm = norm_layer(dim_out)

    def forward(self, x, feat_size: List[int]) -> Tuple[torch.Tensor, List[int]]:
    def forward(self, x):
        # x is either B, C, H, W (if downsample) or B, H, W, C if not
        if self.downsample is not None:
            x, feat_size = self.downsample(x)
            # input to downsample is B, C, H, W
            x = self.downsample(x)  # output B, H, W, C
        B, H, W, C = x.shape
        feat_size = (H, W)
        x = x.reshape(B, -1, C)
        for blk in self.blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint.checkpoint(blk, x, feat_size)
            else:
                x = blk(x, feat_size)
        x = self.norm(x)
        x = x.reshape(x.shape[0], feat_size[0], feat_size[1], -1).permute(0, 3, 1, 2).contiguous()
        return x, feat_size
        x = x.reshape(B, feat_size[0], feat_size[1], -1).permute(0, 3, 1, 2).contiguous()
        return x


class PyramidVisionTransformerV2(nn.Module):
    def __init__(
            self,
            img_size=None,
            in_chans=3,
            num_classes=1000,
            global_pool='avg',
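The stage rewrite above standardises the data handoff: the downsample consumes NCHW and emits NHWC, blocks run on flattened (B, N, C) tokens, and the stage returns NCHW for the next stage. A toy sketch of that round trip, independent of the model code:

    import torch

    x = torch.randn(2, 56, 56, 64)            # B, H, W, C as the downsample emits it
    B, H, W, C = x.shape
    tokens = x.reshape(B, H * W, C)           # (B, N, C) for the transformer blocks
    # ... attention / MLP blocks would run here ...
    x = tokens.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()  # back to B, C, H, W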
@ -298,7 +291,7 @@ class PyramidVisionTransformerV2(nn.Module):
            proj_drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
            norm_layer=nn.LayerNorm,
            norm_layer=LayerNorm,
    ):
        super().__init__()
        self.num_classes = num_classes

@ -310,19 +303,21 @@ class PyramidVisionTransformerV2(nn.Module):
        num_heads = to_ntuple(num_stages)(num_heads)
        sr_ratios = to_ntuple(num_stages)(sr_ratios)
        assert(len(embed_dims)) == num_stages
        self.feature_info = []

        self.patch_embed = OverlapPatchEmbed(
            patch_size=7,
            stride=4,
            in_chans=in_chans,
            embed_dim=embed_dims[0])
            embed_dim=embed_dims[0],
        )

        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
        cur = 0
        prev_dim = embed_dims[0]
        self.stages = nn.ModuleList()
        stages = []
        for i in range(num_stages):
            self.stages.append(PyramidVisionTransformerStage(
            stages += [PyramidVisionTransformerStage(
                dim=prev_dim,
                dim_out=embed_dims[i],
                depth=depths[i],

@ -336,9 +331,11 @@ class PyramidVisionTransformerV2(nn.Module):
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
            ))
            )]
            prev_dim = embed_dims[i]
            cur += depths[i]
            self.feature_info += [dict(num_chs=prev_dim, reduction=4 * 2**i, module=f'stages.{i}')]
        self.stages = nn.Sequential(*stages)

        # classification head
        self.num_features = embed_dims[-1]

@ -390,9 +387,8 @@ class PyramidVisionTransformerV2(nn.Module):
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x, feat_size = self.patch_embed(x)
        for stage in self.stages:
            x, feat_size = stage(x, feat_size=feat_size)
        x = self.patch_embed(x)
        x = self.stages(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):

@ -428,69 +424,80 @@ def _checkpoint_filter_fn(state_dict, model):


def _create_pvt2(variant, pretrained=False, **kwargs):
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')
    default_out_indices = tuple(range(4))
    out_indices = kwargs.pop('out_indices', default_out_indices)
    model = build_model_with_cfg(
        PyramidVisionTransformerV2, variant, pretrained,
        PyramidVisionTransformerV2,
        variant,
        pretrained,
        pretrained_filter_fn=_checkpoint_filter_fn,
        **kwargs
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **kwargs,
    )
    return model


def _cfg(url='', **kwargs):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.9, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': False,
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'pvt_v2_b0': _cfg(hf_hub_id='timm/'),
    'pvt_v2_b1': _cfg(hf_hub_id='timm/'),
    'pvt_v2_b2': _cfg(hf_hub_id='timm/'),
    'pvt_v2_b3': _cfg(hf_hub_id='timm/'),
    'pvt_v2_b4': _cfg(hf_hub_id='timm/'),
    'pvt_v2_b5': _cfg(hf_hub_id='timm/'),
    'pvt_v2_b2_li': _cfg(hf_hub_id='timm/'),
})


@register_model
def pvt_v2_b0(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
depths=(2, 2, 2, 2), embed_dims=(32, 64, 160, 256), num_heads=(1, 2, 5, 8),
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
return _create_pvt2('pvt_v2_b0', pretrained=pretrained, **model_kwargs)
|
||||
model_args = dict(depths=(2, 2, 2, 2), embed_dims=(32, 64, 160, 256), num_heads=(1, 2, 5, 8))
|
||||
return _create_pvt2('pvt_v2_b0', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def pvt_v2_b1(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
depths=(2, 2, 2, 2), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8),
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
return _create_pvt2('pvt_v2_b1', pretrained=pretrained, **model_kwargs)
|
||||
model_args = dict(depths=(2, 2, 2, 2), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8))
|
||||
return _create_pvt2('pvt_v2_b1', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def pvt_v2_b2(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
depths=(3, 4, 6, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8),
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
return _create_pvt2('pvt_v2_b2', pretrained=pretrained, **model_kwargs)
|
||||
model_args = dict(depths=(3, 4, 6, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8))
|
||||
return _create_pvt2('pvt_v2_b2', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def pvt_v2_b3(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
depths=(3, 4, 18, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8),
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
return _create_pvt2('pvt_v2_b3', pretrained=pretrained, **model_kwargs)
|
||||
model_args = dict(depths=(3, 4, 18, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8))
|
||||
return _create_pvt2('pvt_v2_b3', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def pvt_v2_b4(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
depths=(3, 8, 27, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8),
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
return _create_pvt2('pvt_v2_b4', pretrained=pretrained, **model_kwargs)
|
||||
model_args = dict(depths=(3, 8, 27, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8))
|
||||
return _create_pvt2('pvt_v2_b4', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def pvt_v2_b5(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
depths=(3, 6, 40, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8),
|
||||
mlp_ratios=(4, 4, 4, 4), norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
**kwargs)
|
||||
return _create_pvt2('pvt_v2_b5', pretrained=pretrained, **model_kwargs)
|
||||
model_args = dict(
|
||||
depths=(3, 6, 40, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), mlp_ratios=(4, 4, 4, 4))
|
||||
return _create_pvt2('pvt_v2_b5', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def pvt_v2_b2_li(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
depths=(3, 4, 6, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8),
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), linear=True, **kwargs)
|
||||
return _create_pvt2('pvt_v2_b2_li', pretrained=pretrained, **model_kwargs)
|
||||
model_args = dict(
|
||||
depths=(3, 4, 6, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), linear=True)
|
||||
return _create_pvt2('pvt_v2_b2_li', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
|
|
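The `_create_pvt2()` change above drops the old `features_only` RuntimeError in favor of a `feature_cfg`, so the PVT-V2 models can now act as feature backbones. A minimal usage sketch, assuming the stage channels and strides implied by the `feature_info` entries in the diff (exact values are whatever `feature_info` reports for the chosen variant):

```python
import torch
import timm

# features_only now works for pvt_v2 variants; stage i has reduction 4 * 2**i per the diff
model = timm.create_model('pvt_v2_b0', features_only=True, out_indices=(0, 1, 2, 3))
feats = model(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])        # one NCHW tensor per stage (strides 4, 8, 16, 32)
print(model.feature_info.channels())   # [32, 64, 160, 256] for pvt_v2_b0 per the diff
```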
@@ -9,41 +9,12 @@ import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from ._builder import build_model_with_cfg
from ._registry import register_model
from ._registry import register_model, generate_default_cfgs
from .resnet import ResNet

__all__ = []


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'conv1', 'classifier': 'fc',
        **kwargs
    }


default_cfgs = {
    'res2net50_26w_4s': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_4s-06e79181.pth'),
    'res2net50_48w_2s': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_48w_2s-afed724a.pth'),
    'res2net50_14w_8s': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_14w_8s-6527dddc.pth'),
    'res2net50_26w_6s': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_6s-19041792.pth'),
    'res2net50_26w_8s': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_8s-2c7c9f12.pth'),
    'res2net101_26w_4s': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net101_26w_4s-02a759a1.pth'),
    'res2next50': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next50_4s-6ef7e7bf.pth'),
}


class Bottle2neck(nn.Module):
    """ Res2Net/Res2NeXT Bottleneck
    Adapted from https://github.com/gasvn/Res2Net/blob/master/res2net.py

@@ -149,11 +120,33 @@ def _create_res2net(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(ResNet, variant, pretrained, **kwargs)


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'conv1', 'classifier': 'fc',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'res2net50_26w_4s.in1k': _cfg(hf_hub_id='timm/'),
    'res2net50_48w_2s.in1k': _cfg(hf_hub_id='timm/'),
    'res2net50_14w_8s.in1k': _cfg(hf_hub_id='timm/'),
    'res2net50_26w_6s.in1k': _cfg(hf_hub_id='timm/'),
    'res2net50_26w_8s.in1k': _cfg(hf_hub_id='timm/'),
    'res2net101_26w_4s.in1k': _cfg(hf_hub_id='timm/'),
    'res2next50.in1k': _cfg(hf_hub_id='timm/'),
    'res2net50d.in1k': _cfg(hf_hub_id='timm/', first_conv='conv1.0'),
    'res2net101d.in1k': _cfg(hf_hub_id='timm/', first_conv='conv1.0'),
})


@register_model
def res2net50_26w_4s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50 26w4s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model_args = dict(
        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4))

@@ -163,8 +156,6 @@ def res2net50_26w_4s(pretrained=False, **kwargs):
@register_model
def res2net101_26w_4s(pretrained=False, **kwargs):
    """Constructs a Res2Net-101 26w4s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model_args = dict(
        block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4))

@@ -174,8 +165,6 @@ def res2net101_26w_4s(pretrained=False, **kwargs):
@register_model
def res2net50_26w_6s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50 26w6s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model_args = dict(
        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6))

@@ -185,8 +174,6 @@ def res2net50_26w_6s(pretrained=False, **kwargs):
@register_model
def res2net50_26w_8s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50 26w8s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model_args = dict(
        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8))

@@ -196,8 +183,6 @@ def res2net50_26w_8s(pretrained=False, **kwargs):
@register_model
def res2net50_48w_2s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50 48w2s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model_args = dict(
        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2))

@@ -207,8 +192,6 @@ def res2net50_48w_2s(pretrained=False, **kwargs):
@register_model
def res2net50_14w_8s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50 14w8s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model_args = dict(
        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8))

@@ -218,9 +201,27 @@ def res2net50_14w_8s(pretrained=False, **kwargs):
@register_model
def res2next50(pretrained=False, **kwargs):
    """Construct Res2NeXt-50 4s
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model_args = dict(
        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4))
    return _create_res2net('res2next50', pretrained, **dict(model_args, **kwargs))

@register_model
def res2net50d(pretrained=False, **kwargs):
    """Construct Res2Net-50-D.
    """
    model_args = dict(
        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, stem_type='deep',
        avg_down=True, stem_width=32, block_args=dict(scale=4))
    return _create_res2net('res2net50d', pretrained, **dict(model_args, **kwargs))


@register_model
def res2net101d(pretrained=False, **kwargs):
    """Construct Res2Net-101-D.
    """
    model_args = dict(
        block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, stem_type='deep',
        avg_down=True, stem_width=32, block_args=dict(scale=4))
    return _create_res2net('res2net101d', pretrained, **dict(model_args, **kwargs))

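The entrypoints above follow the pattern applied throughout this PR: per-model defaults in `model_args`, merged via `dict(model_args, **kwargs)` so caller kwargs win, with pretrained tags (`.in1k`) on each cfg key resolved through `generate_default_cfgs`. A hedged sketch of what that enables:

```python
import timm

# a bare name resolves to its default pretrained tag; an explicit tag pins one weight set
model = timm.create_model('res2net50d.in1k', pretrained=True)  # weights fetched via hf_hub_id='timm/'

# caller kwargs override the registered model_args because kwargs are merged last
model = timm.create_model('res2net50_26w_4s', num_classes=10)
```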
@@ -11,45 +11,10 @@ from torch import nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SplitAttn
from ._builder import build_model_with_cfg
from ._registry import register_model
from ._registry import register_model, generate_default_cfgs
from .resnet import ResNet


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'conv1.0', 'classifier': 'fc',
        **kwargs
    }

default_cfgs = {
    'resnest14d': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth'),
    'resnest26d': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth'),
    'resnest50d': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50-528c19ca.pth'),
    'resnest101e': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest101-22405ba7.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'resnest200e': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest200-75117900.pth',
        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=0.909, interpolation='bicubic'),
    'resnest269e': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest269-0cc87c48.pth',
        input_size=(3, 416, 416), pool_size=(13, 13), crop_pct=0.928, interpolation='bicubic'),
    'resnest50d_4s2x40d': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_4s2x40d-41d14ed0.pth',
        interpolation='bicubic'),
    'resnest50d_1s4x24d': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_1s4x24d-d4a4f76f.pth',
        interpolation='bicubic')
}


class ResNestBottleneck(nn.Module):
    """ResNet Bottleneck
    """

@@ -153,7 +118,45 @@ class ResNestBottleneck(nn.Module):


def _create_resnest(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(ResNet, variant, pretrained, **kwargs)
    return build_model_with_cfg(
        ResNet,
        variant,
        pretrained,
        **kwargs,
    )


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'conv1.0', 'classifier': 'fc',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'resnest14d.gluon_in1k': _cfg(hf_hub_id='timm/'),
    'resnest26d.gluon_in1k': _cfg(hf_hub_id='timm/'),
    'resnest50d.in1k': _cfg(hf_hub_id='timm/'),
    'resnest101e.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'resnest200e.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=0.909, interpolation='bicubic'),
    'resnest269e.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 416, 416), pool_size=(13, 13), crop_pct=0.928, interpolation='bicubic'),
    'resnest50d_4s2x40d.in1k': _cfg(
        hf_hub_id='timm/',
        interpolation='bicubic'),
    'resnest50d_1s4x24d.in1k': _cfg(
        hf_hub_id='timm/',
        interpolation='bicubic')
})


@register_model
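The ResNeSt cfgs keep their per-variant eval settings (e.g. 320x320 and crop_pct=0.909 for `resnest200e`), so downstream code should derive transforms from the resolved cfg rather than hard-coding 224. A sketch using the long-standing `timm.data` helpers, which are not part of this diff:

```python
import timm
from timm.data import resolve_data_config, create_transform

model = timm.create_model('resnest200e.in1k', pretrained=True)
config = resolve_data_config({}, model=model)  # picks up input_size=(3, 320, 320), crop_pct=0.909, bicubic
transform = create_transform(**config)
```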
@@ -18,41 +18,11 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import create_classifier
from ._builder import build_model_with_cfg
from ._registry import register_model
from ._registry import register_model, generate_default_cfgs

__all__ = ['SelecSLS']  # model_registry will add each entrypoint fn to this


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (4, 4),
        'crop_pct': 0.875, 'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.0', 'classifier': 'fc',
        **kwargs
    }


default_cfgs = {
    'selecsls42': _cfg(
        url='',
        interpolation='bicubic'),
    'selecsls42b': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-selecsls/selecsls42b-8af30141.pth',
        interpolation='bicubic'),
    'selecsls60': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-selecsls/selecsls60-bbf87526.pth',
        interpolation='bicubic'),
    'selecsls60b': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-selecsls/selecsls60b-94e619b5.pth',
        interpolation='bicubic'),
    'selecsls84': _cfg(
        url='',
        interpolation='bicubic'),
}


class SequentialList(nn.Sequential):

    def __init__(self, *args):

@@ -155,7 +125,6 @@ class SelecSLS(nn.Module):

    def __init__(self, cfg, num_classes=1000, in_chans=3, drop_rate=0.0, global_pool='avg'):
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        super(SelecSLS, self).__init__()

        self.stem = conv_bn(in_chans, 32, stride=2)

@@ -165,14 +134,16 @@ class SelecSLS(nn.Module):
        self.num_features = cfg['num_features']
        self.feature_info = cfg['feature_info']

        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
        self.global_pool, self.head_drop, self.fc = create_classifier(
            self.num_features,
            self.num_classes,
            pool_type=global_pool,
            drop_rate=drop_rate,
        )

        for n, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1.)
                nn.init.constant_(m.bias, 0.)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):

@@ -202,8 +173,7 @@ class SelecSLS(nn.Module):

    def forward_head(self, x, pre_logits: bool = False):
        x = self.global_pool(x)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        x = self.head_drop(x)
        return x if pre_logits else self.fc(x)

    def forward(self, x):

@@ -336,10 +306,41 @@ def _create_selecsls(variant, pretrained, **kwargs):

    # this model can do 6 feature levels by default, unlike most others, leave as 0-4 to avoid surprises?
    return build_model_with_cfg(
        SelecSLS, variant, pretrained,
        SelecSLS,
        variant,
        pretrained,
        model_cfg=cfg,
        feature_cfg=dict(out_indices=(0, 1, 2, 3, 4), flatten_sequential=True),
        **kwargs)
        **kwargs,
    )


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (4, 4),
        'crop_pct': 0.875, 'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.0', 'classifier': 'fc',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'selecsls42.untrained': _cfg(
        interpolation='bicubic'),
    'selecsls42b.in1k': _cfg(
        hf_hub_id='timm/',
        interpolation='bicubic'),
    'selecsls60.in1k': _cfg(
        hf_hub_id='timm/',
        interpolation='bicubic'),
    'selecsls60b.in1k': _cfg(
        hf_hub_id='timm/',
        interpolation='bicubic'),
    'selecsls84.untrained': _cfg(
        interpolation='bicubic'),
})


@register_model
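The SelecSLS head change swaps the `F.dropout` call in `forward_head` for an `nn.Dropout` module handed back by `create_classifier(..., drop_rate=...)`, which keeps dropout visible in the module tree and under `.eval()` control. A rough equivalence sketch; the feature and class dims here are made up for illustration:

```python
import torch
from timm.layers import create_classifier

# with drop_rate given, create_classifier returns (global_pool, dropout, fc)
global_pool, head_drop, fc = create_classifier(1024, 1000, pool_type='avg', drop_rate=0.2)

x = torch.randn(2, 1024, 7, 7)
x = global_pool(x)   # pooled and flattened -> (2, 1024)
x = head_drop(x)     # nn.Dropout(p=0.2), a no-op in eval mode
logits = fc(x)       # -> (2, 1000)
```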
@@ -21,45 +21,11 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import create_classifier
from ._builder import build_model_with_cfg
from ._registry import register_model
from ._registry import register_model, generate_default_cfgs

__all__ = ['SENet']


def _cfg(url='', **kwargs):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'layer0.conv1', 'classifier': 'last_linear',
        **kwargs
    }


default_cfgs = {
    'legacy_senet154': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/legacy_senet154-e9eb9fe6.pth'),
    'legacy_seresnet18': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet18-4bb0ce65.pth',
        interpolation='bicubic'),
    'legacy_seresnet34': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet34-a4004e63.pth'),
    'legacy_seresnet50': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet50-ce0d4300.pth'),
    'legacy_seresnet101': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet101-7e38fcc6.pth'),
    'legacy_seresnet152': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet152-d17c99b7.pth'),
    'legacy_seresnext26_32x4d': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26_32x4d-65ebdb501.pth',
        interpolation='bicubic'),
    'legacy_seresnext50_32x4d': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/legacy_se_resnext50_32x4d-f3651bad.pth'),
    'legacy_seresnext101_32x4d': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/legacy_se_resnext101_32x4d-37725eac.pth'),
}


def _weight_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

@@ -401,6 +367,40 @@ def _create_senet(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(SENet, variant, pretrained, **kwargs)


def _cfg(url='', **kwargs):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'layer0.conv1', 'classifier': 'last_linear',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'legacy_senet154.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/legacy_senet154-e9eb9fe6.pth'),
    'legacy_seresnet18.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet18-4bb0ce65.pth',
        interpolation='bicubic'),
    'legacy_seresnet34.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet34-a4004e63.pth'),
    'legacy_seresnet50.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet50-ce0d4300.pth'),
    'legacy_seresnet101.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet101-7e38fcc6.pth'),
    'legacy_seresnet152.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet152-d17c99b7.pth'),
    'legacy_seresnext26_32x4d.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26_32x4d-65ebdb501.pth',
        interpolation='bicubic'),
    'legacy_seresnext50_32x4d.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/legacy_se_resnext50_32x4d-f3651bad.pth'),
    'legacy_seresnext101_32x4d.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/legacy_se_resnext101_32x4d-37725eac.pth'),
})


@register_model
def legacy_seresnet18(pretrained=False, **kwargs):
    model_args = dict(

@@ -8,36 +8,19 @@ Paper: `Sequencer: Deep LSTM for Image Classification` - https://arxiv.org/pdf/2

import math
from functools import partial
from itertools import accumulate
from typing import Tuple

import torch
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
from timm.layers import lecun_normal_, DropPath, Mlp, PatchEmbed as TimmPatchEmbed
from timm.layers import lecun_normal_, DropPath, Mlp, PatchEmbed, ClassifierHead
from ._builder import build_model_with_cfg
from ._manipulate import named_apply
from ._registry import register_model
from ._registry import register_model, generate_default_cfgs

__all__ = ['Sequencer2D']  # model_registry will add each entrypoint fn to this


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': DEFAULT_CROP_PCT, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = dict(
    sequencer2d_s=_cfg(url="https://github.com/okojoalg/sequencer/releases/download/weights/sequencer2d_s.pth"),
    sequencer2d_m=_cfg(url="https://github.com/okojoalg/sequencer/releases/download/weights/sequencer2d_m.pth"),
    sequencer2d_l=_cfg(url="https://github.com/okojoalg/sequencer/releases/download/weights/sequencer2d_l.pth"),
)
__all__ = ['Sequencer2d']  # model_registry will add each entrypoint fn to this


def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False):

@@ -73,27 +56,6 @@ def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=Fals
        module.init_weights()


def get_stage(
        index, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer, rnn_layer, mlp_layer,
        norm_layer, act_layer, num_layers, bidirectional, union,
        with_fc, drop=0., drop_path_rate=0., **kwargs):
    assert len(layers) == len(patch_sizes) == len(embed_dims) == len(hidden_sizes) == len(mlp_ratios)
    blocks = []
    for block_idx in range(layers[index]):
        drop_path = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
        blocks.append(block_layer(
            embed_dims[index], hidden_sizes[index], mlp_ratio=mlp_ratios[index],
            rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer,
            num_layers=num_layers, bidirectional=bidirectional, union=union, with_fc=with_fc,
            drop=drop, drop_path=drop_path))

    if index < len(embed_dims) - 1:
        blocks.append(Downsample2D(embed_dims[index], embed_dims[index + 1], patch_sizes[index + 1]))

    blocks = nn.Sequential(*blocks)
    return blocks


class RNNIdentity(nn.Module):
    def __init__(self, *args, **kwargs):
        super(RNNIdentity, self).__init__()

@@ -102,12 +64,18 @@ class RNNIdentity(nn.Module):
        return x, None


class RNN2DBase(nn.Module):
class RNN2dBase(nn.Module):

    def __init__(
            self, input_size: int, hidden_size: int,
            num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
            union="cat", with_fc=True):
            self,
            input_size: int,
            hidden_size: int,
            num_layers: int = 1,
            bias: bool = True,
            bidirectional: bool = True,
            union="cat",
            with_fc=True,
    ):
        super().__init__()

        self.input_size = input_size

@@ -190,29 +158,67 @@ class RNN2DBase(nn.Module):
        return x


class LSTM2D(RNN2DBase):
class LSTM2d(RNN2dBase):

    def __init__(
            self, input_size: int, hidden_size: int,
            num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
            union="cat", with_fc=True):
            self,
            input_size: int,
            hidden_size: int,
            num_layers: int = 1,
            bias: bool = True,
            bidirectional: bool = True,
            union="cat",
            with_fc=True,
    ):
        super().__init__(input_size, hidden_size, num_layers, bias, bidirectional, union, with_fc)
        if self.with_vertical:
            self.rnn_v = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional)
            self.rnn_v = nn.LSTM(
                input_size,
                hidden_size,
                num_layers,
                batch_first=True,
                bias=bias,
                bidirectional=bidirectional,
            )
        if self.with_horizontal:
            self.rnn_h = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional)
            self.rnn_h = nn.LSTM(
                input_size,
                hidden_size,
                num_layers,
                batch_first=True,
                bias=bias,
                bidirectional=bidirectional,
            )


class Sequencer2DBlock(nn.Module):
class Sequencer2dBlock(nn.Module):
    def __init__(
            self, dim, hidden_size, mlp_ratio=3.0, rnn_layer=LSTM2D, mlp_layer=Mlp,
            norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU,
            num_layers=1, bidirectional=True, union="cat", with_fc=True, drop=0., drop_path=0.):
            self,
            dim,
            hidden_size,
            mlp_ratio=3.0,
            rnn_layer=LSTM2d,
            mlp_layer=Mlp,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            act_layer=nn.GELU,
            num_layers=1,
            bidirectional=True,
            union="cat",
            with_fc=True,
            drop=0.,
            drop_path=0.,
    ):
        super().__init__()
        channels_dim = int(mlp_ratio * dim)
        self.norm1 = norm_layer(dim)
        self.rnn_tokens = rnn_layer(dim, hidden_size, num_layers=num_layers, bidirectional=bidirectional,
                                    union=union, with_fc=with_fc)
        self.rnn_tokens = rnn_layer(
            dim,
            hidden_size,
            num_layers=num_layers,
            bidirectional=bidirectional,
            union=union,
            with_fc=with_fc,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop)

@@ -223,17 +229,6 @@ class Sequencer2DBlock(nn.Module):
        return x


class PatchEmbed(TimmPatchEmbed):
    def forward(self, x):
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        else:
            x = x.permute(0, 2, 3, 1)  # BCHW -> BHWC
        x = self.norm(x)
        return x


class Shuffle(nn.Module):
    def __init__(self):
        super().__init__()

@@ -247,7 +242,7 @@ class Shuffle(nn.Module):
        return x


class Downsample2D(nn.Module):
class Downsample2d(nn.Module):
    def __init__(self, input_dim, output_dim, patch_size):
        super().__init__()
        self.down = nn.Conv2d(input_dim, output_dim, kernel_size=patch_size, stride=patch_size)

@@ -259,20 +254,74 @@ class Downsample2D(nn.Module):
        return x


class Sequencer2D(nn.Module):
class Sequencer2dStage(nn.Module):
    def __init__(
            self,
            dim,
            dim_out,
            depth,
            patch_size,
            hidden_size,
            mlp_ratio,
            downsample=False,
            block_layer=Sequencer2dBlock,
            rnn_layer=LSTM2d,
            mlp_layer=Mlp,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            act_layer=nn.GELU,
            num_layers=1,
            bidirectional=True,
            union="cat",
            with_fc=True,
            drop=0.,
            drop_path=0.,
    ):
        super().__init__()
        if downsample:
            self.downsample = Downsample2d(dim, dim_out, patch_size)
        else:
            assert dim == dim_out
            self.downsample = nn.Identity()

        blocks = []
        for block_idx in range(depth):
            blocks.append(block_layer(
                dim_out,
                hidden_size,
                mlp_ratio=mlp_ratio,
                rnn_layer=rnn_layer,
                mlp_layer=mlp_layer,
                norm_layer=norm_layer,
                act_layer=act_layer,
                num_layers=num_layers,
                bidirectional=bidirectional,
                union=union,
                with_fc=with_fc,
                drop=drop,
                drop_path=drop_path[block_idx] if isinstance(drop_path, (list, tuple)) else drop_path,
            ))
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        x = self.downsample(x)
        x = self.blocks(x)
        return x


class Sequencer2d(nn.Module):
    def __init__(
            self,
            num_classes=1000,
            img_size=224,
            in_chans=3,
            global_pool='avg',
            layers=[4, 3, 8, 3],
            patch_sizes=[7, 2, 1, 1],
            embed_dims=[192, 384, 384, 384],
            hidden_sizes=[48, 96, 96, 96],
            mlp_ratios=[3.0, 3.0, 3.0, 3.0],
            block_layer=Sequencer2DBlock,
            rnn_layer=LSTM2D,
            layers=(4, 3, 8, 3),
            patch_sizes=(7, 2, 2, 1),
            embed_dims=(192, 384, 384, 384),
            hidden_sizes=(48, 96, 96, 96),
            mlp_ratios=(3.0, 3.0, 3.0, 3.0),
            block_layer=Sequencer2dBlock,
            rnn_layer=LSTM2d,
            mlp_layer=Mlp,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            act_layer=nn.GELU,

@@ -291,23 +340,56 @@ class Sequencer2D(nn.Module):
        self.global_pool = global_pool
        self.num_features = embed_dims[-1]  # num_features for consistency with other models
        self.feature_dim = -1  # channel dim index for feature outputs (rank 4, NHWC)
        self.embed_dims = embed_dims
        self.output_fmt = 'NHWC'
        self.feature_info = []

        self.stem = PatchEmbed(
            img_size=img_size, patch_size=patch_sizes[0], in_chans=in_chans,
            embed_dim=embed_dims[0], norm_layer=norm_layer if stem_norm else None,
            flatten=False)
            img_size=None,
            patch_size=patch_sizes[0],
            in_chans=in_chans,
            embed_dim=embed_dims[0],
            norm_layer=norm_layer if stem_norm else None,
            flatten=False,
            output_fmt='NHWC',
        )

        self.blocks = nn.Sequential(*[
            get_stage(
                i, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer=block_layer,
                rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer,
                num_layers=num_rnn_layers, bidirectional=bidirectional,
                union=union, with_fc=with_fc, drop=drop_rate, drop_path_rate=drop_path_rate,
            )
            for i, _ in enumerate(embed_dims)])
        assert len(layers) == len(patch_sizes) == len(embed_dims) == len(hidden_sizes) == len(mlp_ratios)
        reductions = list(accumulate(patch_sizes, lambda x, y: x * y))
        stages = []
        prev_dim = embed_dims[0]
        for i, _ in enumerate(embed_dims):
            stages += [Sequencer2dStage(
                prev_dim,
                embed_dims[i],
                depth=layers[i],
                downsample=i > 0,
                patch_size=patch_sizes[i],
                hidden_size=hidden_sizes[i],
                mlp_ratio=mlp_ratios[i],
                block_layer=block_layer,
                rnn_layer=rnn_layer,
                mlp_layer=mlp_layer,
                norm_layer=norm_layer,
                act_layer=act_layer,
                num_layers=num_rnn_layers,
                bidirectional=bidirectional,
                union=union,
                with_fc=with_fc,
                drop=drop_rate,
                drop_path=drop_path_rate,
            )]
            prev_dim = embed_dims[i]
            self.feature_info += [dict(num_chs=prev_dim, reduction=reductions[i], module=f'stages.{i}')]

        self.stages = nn.Sequential(*stages)
        self.norm = norm_layer(embed_dims[-1])
        self.head = nn.Linear(embed_dims[-1], self.num_classes) if num_classes > 0 else nn.Identity()
        self.head = ClassifierHead(
            self.num_features,
            num_classes,
            pool_type=global_pool,
            drop_rate=drop_rate,
            input_fmt=self.output_fmt,
        )

        self.init_weights(nlhb=nlhb)

@@ -320,8 +402,11 @@ class Sequencer2D(nn.Module):
        return dict(
            stem=r'^stem',
            blocks=[
                (r'^blocks\.(\d+)\..*\.down', (99999,)),
                (r'^blocks\.(\d+)', None) if coarse else (r'^blocks\.(\d+)\.(\d+)', None),
                (r'^stages\.(\d+)', None),
                (r'^norm', (99999,))
            ] if coarse else [
                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
                (r'^stages\.(\d+)\.downsample', (0,)),
                (r'^norm', (99999,))
            ]
        )

@@ -336,21 +421,16 @@ class Sequencer2D(nn.Module):

    def reset_classifier(self, num_classes, global_pool=None):
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'avg')
            self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.head.reset(num_classes, pool_type=global_pool)

    def forward_features(self, x):
        x = self.stem(x)
        x = self.blocks(x)
        x = self.stages(x)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool == 'avg':
            x = x.mean(dim=(1, 2))
        return x if pre_logits else self.head(x)
        return self.head(x, pre_logits=True) if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)

@@ -358,15 +438,56 @@ class Sequencer2D(nn.Module):
        return x


def _create_sequencer2d(variant, pretrained=False, **kwargs):
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Sequencer2D models.')
def checkpoint_filter_fn(state_dict, model):
    """ Remap original checkpoints -> timm """
    if 'stages.0.blocks.0.norm1.weight' in state_dict:
        return state_dict  # already translated checkpoint
    if 'model' in state_dict:
        state_dict = state_dict['model']

    model = build_model_with_cfg(Sequencer2D, variant, pretrained, **kwargs)
    import re
    out_dict = {}
    for k, v in state_dict.items():
        k = re.sub(r'blocks.([0-9]+).([0-9]+).down', lambda x: f'stages.{int(x.group(1)) + 1}.downsample.down', k)
        k = re.sub(r'blocks.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
        k = k.replace('head.', 'head.fc.')
        out_dict[k] = v

    return out_dict


def _create_sequencer2d(variant, pretrained=False, **kwargs):
    default_out_indices = tuple(range(3))
    out_indices = kwargs.pop('out_indices', default_out_indices)

    model = build_model_with_cfg(
        Sequencer2d,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **kwargs,
    )
    return model


# main
def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': DEFAULT_CROP_PCT, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.proj', 'classifier': 'head.fc',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'sequencer2d_s.in1k': _cfg(hf_hub_id='timm/'),
    'sequencer2d_m.in1k': _cfg(hf_hub_id='timm/'),
    'sequencer2d_l.in1k': _cfg(hf_hub_id='timm/'),
})


@register_model
def sequencer2d_s(pretrained=False, **kwargs):

@@ -376,12 +497,12 @@ def sequencer2d_s(pretrained=False, **kwargs):
        embed_dims=[192, 384, 384, 384],
        hidden_sizes=[48, 96, 96, 96],
        mlp_ratios=[3.0, 3.0, 3.0, 3.0],
        rnn_layer=LSTM2D,
        rnn_layer=LSTM2d,
        bidirectional=True,
        union="cat",
        with_fc=True,
        **kwargs)
    model = _create_sequencer2d('sequencer2d_s', pretrained=pretrained, **model_args)
    )
    model = _create_sequencer2d('sequencer2d_s', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@@ -393,12 +514,12 @@ def sequencer2d_m(pretrained=False, **kwargs):
        embed_dims=[192, 384, 384, 384],
        hidden_sizes=[48, 96, 96, 96],
        mlp_ratios=[3.0, 3.0, 3.0, 3.0],
        rnn_layer=LSTM2D,
        rnn_layer=LSTM2d,
        bidirectional=True,
        union="cat",
        with_fc=True,
        **kwargs)
    model = _create_sequencer2d('sequencer2d_m', pretrained=pretrained, **model_args)
    model = _create_sequencer2d('sequencer2d_m', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@@ -410,10 +531,10 @@ def sequencer2d_l(pretrained=False, **kwargs):
        embed_dims=[192, 384, 384, 384],
        hidden_sizes=[48, 96, 96, 96],
        mlp_ratios=[3.0, 3.0, 3.0, 3.0],
        rnn_layer=LSTM2D,
        rnn_layer=LSTM2d,
        bidirectional=True,
        union="cat",
        with_fc=True,
        **kwargs)
    model = _create_sequencer2d('sequencer2d_l', pretrained=pretrained, **model_args)
    model = _create_sequencer2d('sequencer2d_l', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
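`checkpoint_filter_fn` above translates original Sequencer checkpoints onto the new `stages` layout. A standalone illustration of the key rewriting it performs; the example keys follow the original `blocks.{stage}.{idx}` format and are hypothetical:

```python
import re

def remap(k):
    # a stage-terminating downsample moves to the head of the *next* stage
    k = re.sub(r'blocks.([0-9]+).([0-9]+).down', lambda x: f'stages.{int(x.group(1)) + 1}.downsample.down', k)
    # flat 'blocks.{stage}.{idx}' -> nested 'stages.{stage}.blocks.{idx}'
    k = re.sub(r'blocks.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
    return k.replace('head.', 'head.fc.')  # head is now a ClassifierHead with an fc child

print(remap('blocks.0.4.down.weight'))   # -> stages.1.downsample.down.weight
print(remap('blocks.2.1.norm1.weight'))  # -> stages.2.blocks.1.norm1.weight
print(remap('head.weight'))              # -> head.fc.weight
```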
@@ -15,34 +15,10 @@ from torch import nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SelectiveKernel, ConvNormAct, create_attn
from ._builder import build_model_with_cfg
from ._registry import register_model
from ._registry import register_model, generate_default_cfgs
from .resnet import ResNet


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'conv1', 'classifier': 'fc',
        **kwargs
    }


default_cfgs = {
    'skresnet18': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth'),
    'skresnet34': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth'),
    'skresnet50': _cfg(),
    'skresnet50d': _cfg(
        first_conv='conv1.0'),
    'skresnext50_32x4d': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth'),
}


class SelectiveKernelBasic(nn.Module):
    expansion = 1


@@ -166,7 +142,33 @@ class SelectiveKernelBottleneck(nn.Module):


def _create_skresnet(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(ResNet, variant, pretrained, **kwargs)
    return build_model_with_cfg(
        ResNet,
        variant,
        pretrained,
        **kwargs,
    )


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'conv1', 'classifier': 'fc',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'skresnet18.ra_in1k': _cfg(hf_hub_id='timm/'),
    'skresnet34.ra_in1k': _cfg(hf_hub_id='timm/'),
    'skresnet50.untrained': _cfg(),
    'skresnet50d.untrained': _cfg(
        first_conv='conv1.0'),
    'skresnext50_32x4d.ra_in1k': _cfg(hf_hub_id='timm/'),
})


@register_model

@@ -23,44 +23,11 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, DropPath, to_2tuple, trunc_normal_, use_fused_attn
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_module
from ._registry import register_model
from ._registry import register_model, generate_default_cfgs
from .vision_transformer import Attention

__all__ = ['Twins']  # model_registry will add each entrypoint fn to this


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embeds.0.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = {
    'twins_pcpvt_small': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_small-e70e7e7a.pth',
    ),
    'twins_pcpvt_base': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_base-e5ecb09b.pth',
    ),
    'twins_pcpvt_large': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_large-d273f802.pth',
    ),
    'twins_svt_small': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_small-42e5f78c.pth',
    ),
    'twins_svt_base': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_base-c2265010.pth',
    ),
    'twins_svt_large': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_large-90f6aaa9.pth',
    ),
}

Size_ = Tuple[int, int]


@@ -469,6 +436,27 @@ def _create_twins(variant, pretrained=False, **kwargs):
    return model


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embeds.0.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'twins_pcpvt_small.in1k': _cfg(hf_hub_id='timm/'),
    'twins_pcpvt_base.in1k': _cfg(hf_hub_id='timm/'),
    'twins_pcpvt_large.in1k': _cfg(hf_hub_id='timm/'),
    'twins_svt_small.in1k': _cfg(hf_hub_id='timm/'),
    'twins_svt_base.in1k': _cfg(hf_hub_id='timm/'),
    'twins_svt_large.in1k': _cfg(hf_hub_id='timm/'),
})


@register_model
def twins_pcpvt_small(pretrained=False, **kwargs):
    model_args = dict(

@@ -15,34 +15,11 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import ClassifierHead
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_module
from ._registry import register_model
from ._registry import register_model, generate_default_cfgs

__all__ = ['VGG']


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'features.0', 'classifier': 'head.fc',
        **kwargs
    }


default_cfgs = {
    'vgg11': _cfg(url='https://download.pytorch.org/models/vgg11-bbd30ac9.pth'),
    'vgg13': _cfg(url='https://download.pytorch.org/models/vgg13-c768596a.pth'),
    'vgg16': _cfg(url='https://download.pytorch.org/models/vgg16-397923af.pth'),
    'vgg19': _cfg(url='https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'),
    'vgg11_bn': _cfg(url='https://download.pytorch.org/models/vgg11_bn-6002323d.pth'),
    'vgg13_bn': _cfg(url='https://download.pytorch.org/models/vgg13_bn-abd245e5.pth'),
    'vgg16_bn': _cfg(url='https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'),
    'vgg19_bn': _cfg(url='https://download.pytorch.org/models/vgg19_bn-c79401a0.pth'),
}


cfgs: Dict[str, List[Union[str, int]]] = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],

@@ -55,8 +32,15 @@ cfgs: Dict[str, List[Union[str, int]]] = {
class ConvMlp(nn.Module):

    def __init__(
            self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0,
            drop_rate: float = 0.2, act_layer: nn.Module = None, conv_layer: nn.Module = None):
            self,
            in_features=512,
            out_features=4096,
            kernel_size=7,
            mlp_ratio=1.0,
            drop_rate: float = 0.2,
            act_layer: nn.Module = None,
            conv_layer: nn.Module = None,
    ):
        super(ConvMlp, self).__init__()
        self.input_kernel_size = kernel_size
        mid_features = int(out_features * mlp_ratio)

@@ -124,10 +108,20 @@ class VGG(nn.Module):
        self.feature_info.append(dict(num_chs=prev_chs, reduction=net_stride, module=f'features.{len(layers) - 1}'))

        self.pre_logits = ConvMlp(
            prev_chs, self.num_features, 7, mlp_ratio=mlp_ratio,
            drop_rate=drop_rate, act_layer=act_layer, conv_layer=conv_layer)
            prev_chs,
            self.num_features,
            7,
            mlp_ratio=mlp_ratio,
            drop_rate=drop_rate,
            act_layer=act_layer,
            conv_layer=conv_layer,
        )
        self.head = ClassifierHead(
            self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
            self.num_features,
            num_classes,
            pool_type=global_pool,
            drop_rate=drop_rate,
        )

        self._initialize_weights()

@@ -147,7 +141,11 @@ class VGG(nn.Module):
    def reset_classifier(self, num_classes, global_pool='avg'):
        self.num_classes = num_classes
        self.head = ClassifierHead(
            self.num_features, self.num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
            self.num_features,
            self.num_classes,
            pool_type=global_pool,
            drop_rate=self.drop_rate,
        )

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)

@@ -197,14 +195,40 @@ def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG:
    # NOTE: VGG is one of few models with stride==1 features w/ 6 out_indices [0..5]
    out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4, 5))
    model = build_model_with_cfg(
        VGG, variant, pretrained,
        VGG,
        variant,
        pretrained,
        model_cfg=cfgs[cfg],
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        pretrained_filter_fn=_filter_fn,
        **kwargs)
        **kwargs,
    )
    return model


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'features.0', 'classifier': 'head.fc',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'vgg11.tv_in1k': _cfg(hf_hub_id='timm/'),
    'vgg13.tv_in1k': _cfg(hf_hub_id='timm/'),
    'vgg16.tv_in1k': _cfg(hf_hub_id='timm/'),
    'vgg19.tv_in1k': _cfg(hf_hub_id='timm/'),
    'vgg11_bn.tv_in1k': _cfg(hf_hub_id='timm/'),
    'vgg13_bn.tv_in1k': _cfg(hf_hub_id='timm/'),
    'vgg16_bn.tv_in1k': _cfg(hf_hub_id='timm/'),
    'vgg19_bn.tv_in1k': _cfg(hf_hub_id='timm/'),
})


@register_model
def vgg11(pretrained: bool = False, **kwargs: Any) -> VGG:
    r"""VGG 11-layer model (configuration "A") from
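Per the NOTE retained in `_create_vgg()`, VGG exposes six feature levels including a stride-1 level, hence the default `out_indices=(0, 1, 2, 3, 4, 5)`. A quick sketch; the reduction values shown are the expected doubling pattern starting at 1, not verified output:

```python
import torch
import timm

model = timm.create_model('vgg16.tv_in1k', features_only=True)  # all 6 levels by default
feats = model(torch.randn(1, 3, 224, 224))
print(len(feats))                      # 6
print(model.feature_info.reduction())  # expected [1, 2, 4, 8, 16, 32]
```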
@ -14,30 +14,11 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier, use_fused_attn
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model
from ._registry import register_model, generate_default_cfgs

__all__ = ['Visformer']


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.0', 'classifier': 'head',
        **kwargs
    }


default_cfgs = dict(
    visformer_tiny=_cfg(),
    visformer_small=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/visformer_small-839e1f5b.pth'
    ),
)


class SpatialMlp(nn.Module):
    def __init__(
            self,
@ -464,6 +445,23 @@ def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs):
    return model


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.0', 'classifier': 'head',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'visformer_tiny.in1k': _cfg(hf_hub_id='timm/'),
    'visformer_small.in1k': _cfg(hf_hub_id='timm/'),
})


@register_model
def visformer_tiny(pretrained=False, **kwargs):
    model_cfg = dict(
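With generate_default_cfgs, each weight set is addressed as model.tag, and a bare architecture name falls back to its default tag. A sketch of both spellings (assuming the hub weights are available):

import timm

m1 = timm.create_model('visformer_small.in1k', pretrained=True)  # explicit tag
m2 = timm.create_model('visformer_small', pretrained=True)       # resolves to the default tag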
@ -61,11 +61,12 @@ class RelPosAttention(nn.Module):
        k = self.k_norm(k)

        if self.fused_attn:
            attn_bias = None
            if self.rel_pos is not None:
                attn_bias = self.rel_pos.get_bias()
            elif shared_rel_pos is not None:
                attn_bias = shared_rel_pos
            else:
                attn_bias = None

            x = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
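The change above only hoists the attn_bias = None default ahead of the if/elif chain before the optional relative-position bias reaches the fused kernel. A self-contained sketch of that call (shapes are illustrative):

import torch
import torch.nn.functional as F

B, H, N, D = 2, 8, 197, 64
q = k = v = torch.randn(B, H, N, D)
attn_bias = torch.randn(B, H, N, N)   # additive rel-pos bias; may also be None

# attn_mask accepts a float bias broadcastable to (B, H, N, N)
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)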
@ -27,26 +27,10 @@ import torch.nn.functional as F

from timm.layers import create_classifier
from ._builder import build_model_with_cfg
from ._registry import register_model
from ._registry import register_model, generate_default_cfgs, register_model_deprecations

__all__ = ['Xception']

default_cfgs = {
    'xception': {
        'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/xception-43020ad28.pth',
        'input_size': (3, 299, 299),
        'pool_size': (10, 10),
        'crop_pct': 0.8975,
        'interpolation': 'bicubic',
        'mean': (0.5, 0.5, 0.5),
        'std': (0.5, 0.5, 0.5),
        'num_classes': 1000,
        'first_conv': 'conv1',
        'classifier': 'fc'
        # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
    }
}


class SeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1):
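For reference, the 333 in the removed config's comment follows from crop_pct: the eval transform resizes to roughly crop size / crop_pct before center-cropping, which timm floors to an int. A quick check:

import math

crop_size, crop_pct = 299, 0.8975
print(int(math.floor(crop_size / crop_pct)))   # 333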
@ -244,6 +228,29 @@ def _xception(variant, pretrained=False, **kwargs):
        **kwargs)


default_cfgs = generate_default_cfgs({
    'legacy_xception.tf_in1k': {
        'hf_hub_id': 'timm/',
        'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/xception-43020ad28.pth',
        'input_size': (3, 299, 299),
        'pool_size': (10, 10),
        'crop_pct': 0.8975,
        'interpolation': 'bicubic',
        'mean': (0.5, 0.5, 0.5),
        'std': (0.5, 0.5, 0.5),
        'num_classes': 1000,
        'first_conv': 'conv1',
        'classifier': 'fc'
        # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
    }
})


@register_model
def xception(pretrained=False, **kwargs):
    return _xception('xception', pretrained=pretrained, **kwargs)
def legacy_xception(pretrained=False, **kwargs):
    return _xception('legacy_xception', pretrained=pretrained, **kwargs)


register_model_deprecations(__name__, {
    'xception': 'legacy_xception',
})
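register_model_deprecations keeps the old entrypoint alive: creating 'xception' should emit a deprecation warning and build 'legacy_xception'. A sketch:

import timm

m = timm.create_model('xception', pretrained=True)         # old name, deprecation warning
m = timm.create_model('legacy_xception', pretrained=True)  # new name, no warning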