Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Add ConvStem and MobileCLIP hybrid model for B variant. Add full norm disable support to ConvNormAct layers
This commit is contained in:
parent
3c9d8e5b33
commit
cc8a03daac
@@ -23,6 +23,7 @@ class ConvNormAct(nn.Module):
             dilation: int = 1,
             groups: int = 1,
             bias: bool = False,
+            apply_norm: bool = True,
             apply_act: bool = True,
             norm_layer: LayerType = nn.BatchNorm2d,
             act_layer: LayerType = nn.ReLU,
@@ -48,17 +49,23 @@ class ConvNormAct(nn.Module):
             **conv_kwargs,
         )
 
-        # NOTE for backwards compatibility with models that use separate norm and act layer definitions
-        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
-        # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
-        if drop_layer:
-            norm_kwargs['drop_layer'] = drop_layer
-        self.bn = norm_act_layer(
-            out_channels,
-            apply_act=apply_act,
-            act_kwargs=act_kwargs,
-            **norm_kwargs,
-        )
+        if apply_norm:
+            # NOTE for backwards compatibility with models that use separate norm and act layer definitions
+            norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
+            # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
+            if drop_layer:
+                norm_kwargs['drop_layer'] = drop_layer
+            self.bn = norm_act_layer(
+                out_channels,
+                apply_act=apply_act,
+                act_kwargs=act_kwargs,
+                **norm_kwargs,
+            )
+        else:
+            self.bn = nn.Sequential()
+            if drop_layer:
+                norm_kwargs['drop_layer'] = drop_layer
+                self.bn.add_module('drop', drop_layer())
 
     @property
     def in_channels(self):
@@ -88,6 +95,7 @@ class ConvNormActAa(nn.Module):
             dilation: int = 1,
             groups: int = 1,
             bias: bool = False,
+            apply_norm: bool = True,
             apply_act: bool = True,
             norm_layer: LayerType = nn.BatchNorm2d,
             act_layer: LayerType = nn.ReLU,
@@ -113,17 +121,24 @@ class ConvNormActAa(nn.Module):
             **conv_kwargs,
         )
 
-        # NOTE for backwards compatibility with models that use separate norm and act layer definitions
-        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
-        # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
-        if drop_layer:
-            norm_kwargs['drop_layer'] = drop_layer
-        self.bn = norm_act_layer(
-            out_channels,
-            apply_act=apply_act,
-            act_kwargs=act_kwargs,
-            **norm_kwargs,
-        )
+        if apply_norm:
+            # NOTE for backwards compatibility with models that use separate norm and act layer definitions
+            norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
+            # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
+            if drop_layer:
+                norm_kwargs['drop_layer'] = drop_layer
+            self.bn = norm_act_layer(
+                out_channels,
+                apply_act=apply_act,
+                act_kwargs=act_kwargs,
+                **norm_kwargs,
+            )
+        else:
+            self.bn = nn.Sequential()
+            if drop_layer:
+                norm_kwargs['drop_layer'] = drop_layer
+                self.bn.add_module('drop', drop_layer())
+
         self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)
 
     @property
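Taken together, the two ConvNormAct/ConvNormActAa hunks make `apply_norm=False` turn the `.bn` slot into an empty `nn.Sequential` (optionally holding just the drop layer), so the block degenerates to a plain convolution, usually paired with `bias=True`. A minimal usage sketch against the post-change `ConvNormAct` (the shapes and keyword choices are illustrative, not taken from the commit):

    import torch
    from timm.layers import ConvNormAct

    # Default block: conv (bias=False) -> BatchNorm2d -> ReLU
    block = ConvNormAct(3, 64, kernel_size=3, stride=2)

    # Norm fully disabled: plain biased conv, no BN, no activation.
    # This is the pattern the new ConvStem uses for its last stage.
    plain = ConvNormAct(3, 64, kernel_size=3, stride=2, bias=True, apply_norm=False, apply_act=False)

    x = torch.randn(1, 3, 32, 32)
    print(block(x).shape, plain(x).shape)  # both torch.Size([1, 64, 16, 16])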
@@ -19,21 +19,18 @@ from torch import nn as nn
 from torch.nn import functional as F
 from torchvision.ops.misc import FrozenBatchNorm2d
 
-from .create_act import get_act_layer
+from .create_act import create_act_layer
 from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
 from .trace_utils import _assert
 
 
 def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True):
-    act_layer = get_act_layer(act_layer)  # string -> nn.Module
     act_kwargs = act_kwargs or {}
-    if act_layer is not None and apply_act:
-        if inplace:
-            act_kwargs['inplace'] = inplace
-        act = act_layer(**act_kwargs)
-    else:
-        act = nn.Identity()
-    return act
+    act_kwargs.setdefault('inplace', inplace)
+    act = None
+    if apply_act:
+        act = create_act_layer(act_layer, **act_kwargs)
+    return nn.Identity() if act is None else act
 
 
 class BatchNormAct2d(nn.BatchNorm2d):
@@ -421,7 +418,6 @@ class LayerNormAct(nn.LayerNorm):
     ):
         super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
-        act_layer = get_act_layer(act_layer)  # string -> nn.Module
         self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
 
         self._fast_norm = is_fast_norm()
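The norm_act hunks simplify `_create_act`: string names are now resolved by `create_act_layer`, and a missing or disabled activation collapses to `nn.Identity()`, so callers such as `LayerNormAct` no longer need their own `get_act_layer` call. A small sketch exercised through the public `BatchNormAct2d` wrapper, which routes through `_create_act` (the constructor keywords shown are assumptions about timm's API, not part of this commit):

    import torch
    from timm.layers import BatchNormAct2d

    # Activation given as a string resolves via create_act_layer inside _create_act.
    bn_act = BatchNormAct2d(64, act_layer='gelu')

    # apply_act=False (or act_layer=None) falls back to nn.Identity, i.e. norm only.
    bn_only = BatchNormAct2d(64, apply_act=False)

    x = torch.randn(2, 64, 8, 8)
    print(bn_act(x).shape)             # torch.Size([2, 64, 8, 8])
    print(type(bn_only.act).__name__)  # Identity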
@@ -609,7 +609,7 @@ class VisionTransformer(nn.Module):
 
     def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
         if self.pos_embed is None:
-            return x
+            return x.view(x.shape[0], -1, x.shape[-1])
 
         if self.dynamic_img_size:
            B, H, W, C = x.shape
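The `VisionTransformer._pos_embed` change matters for hybrids that run without a position embedding: instead of returning the embed output untouched, it now flattens a (B, H, W, C) grid into the (B, N, C) token sequence the transformer blocks expect. A shape-only sketch (sizes are illustrative):

    import torch

    x = torch.randn(2, 14, 14, 768)          # NHWC token grid from the embed layer
    x = x.view(x.shape[0], -1, x.shape[-1])  # -> torch.Size([2, 196, 768])
    print(x.shape)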
@@ -15,14 +15,15 @@ Hacked together by / Copyright 2020, Ross Wightman
 """
 import math
 from functools import partial
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to
+from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, Format, nchw_to
 
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 from .resnet import resnet26d, resnet50d
 from .resnetv2 import ResNetV2, create_resnetv2_stem
@@ -191,8 +192,52 @@ class HybridEmbedWithSize(nn.Module):
         return x.flatten(2).transpose(1, 2), x.shape[-2:]
 
 
-def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
-    embed_layer = partial(HybridEmbed, backbone=backbone)
+class ConvStem(nn.Sequential):
+    def __init__(
+            self,
+            in_chans: int = 3,
+            depth: int = 3,
+            channels: Union[int, Tuple[int, ...]] = 64,
+            kernel_size: Union[int, Tuple[int, ...]] = 3,
+            stride: Union[int, Tuple[int, ...]] = (2, 2, 2),
+            padding: Union[str, int, Tuple[int, ...]] = "",
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            act_layer: Type[nn.Module] = nn.ReLU,
+    ):
+        super().__init__()
+        if isinstance(channels, int):
+            if depth == 4:
+                channels = (channels // 8, channels // 4, channels // 2, channels)
+            elif depth == 3:
+                channels = (channels // 4, channels // 2, channels)
+            else:
+                channels = to_ntuple(depth)(channels)
+
+        kernel_size = to_ntuple(depth)(kernel_size)
+        padding = to_ntuple(depth)(padding)
+        assert depth == len(stride) == len(kernel_size) == len(channels)
+
+        in_chs = in_chans
+        for i in range(len(channels)):
+            last_conv = i == len(channels) - 1
+            self.add_module(f'{i}', ConvNormAct(
+                in_chs,
+                channels[i],
+                kernel_size=kernel_size[i],
+                stride=stride[i],
+                padding=padding[i],
+                bias=last_conv,
+                apply_norm=not last_conv,
+                apply_act=not last_conv,
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+            ))
+            in_chs = channels[i]
+
+
+def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs):
+    embed_args = embed_args or {}
+    embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args)
     kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set
     return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs)
 
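Each `ConvStem` stage is a `ConvNormAct`; only the last stage keeps a bias and drops norm/activation, so the stem can project straight to the transformer width and `HybridEmbed` needs no extra projection. A minimal sketch of the MobileCLIP-B style configuration registered below (the import path and the 224x224 input are assumptions for illustration):

    import torch
    from timm.models.vision_transformer_hybrid import ConvStem  # module-internal class at this commit

    stem = ConvStem(
        channels=(768 // 4, 768 // 4, 768),  # last stage already emits the ViT embed dim
        stride=(4, 2, 2),
        kernel_size=(4, 2, 2),
        padding=0,
        act_layer=torch.nn.GELU,
    )
    x = torch.randn(1, 3, 224, 224)
    print(stem(x).shape)  # expected torch.Size([1, 768, 14, 14]), i.e. a 14x14 = 196 token grid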
@@ -433,6 +478,25 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_base_mci_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights.
+    """
+    backbone = ConvStem(
+        channels=(768//4, 768//4, 768),
+        stride=(4, 2, 2),
+        kernel_size=(4, 2, 2),
+        padding=0,
+        act_layer=nn.GELU,
+    )
+    model_args = dict(embed_dim=768, depth=12, num_heads=12, no_embed_class=True)
+    model = _create_vision_transformer_hybrid(
+        'vit_base_resnet50d_224', backbone=backbone, embed_args=dict(proj=False),
+        pretrained=pretrained, **dict(model_args, **kwargs)
+    )
+    return model
+
+
 register_model_deprecations(__name__, {
     'vit_tiny_r_s16_p8_224_in21k': 'vit_tiny_r_s16_p8_224.augreg_in21k',
     'vit_small_r26_s32_224_in21k': 'vit_small_r26_s32_224.augreg_in21k',
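A usage sketch for the new registration (note the variant string and docstring still read "resnet50d", apparently carried over from the neighbouring hybrid; `embed_args=dict(proj=False)` is forwarded to `HybridEmbed` so it skips its projection, since the ConvStem already outputs 768 channels). Assuming the registry resolves this name as it does for the other hybrids in this file:

    import timm
    import torch

    # No pretrained weights exist for this variant at this commit.
    model = timm.create_model('vit_base_mci_224', pretrained=False)
    x = torch.randn(1, 3, 224, 224)
    print(model(x).shape)  # expected torch.Size([1, 1000]) with the default classifier head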