Add ConvStem and MobileCLIP hybrid model for B variant. Add full norm disable support to ConvNormAct layers

Ross Wightman 2024-06-06 09:15:27 -07:00
parent 3c9d8e5b33
commit cc8a03daac
4 changed files with 112 additions and 37 deletions

timm/layers/conv_bn_act.py

@@ -23,6 +23,7 @@ class ConvNormAct(nn.Module):
dilation: int = 1,
groups: int = 1,
bias: bool = False,
apply_norm: bool = True,
apply_act: bool = True,
norm_layer: LayerType = nn.BatchNorm2d,
act_layer: LayerType = nn.ReLU,
@@ -48,6 +49,7 @@
**conv_kwargs,
)
if apply_norm:
# NOTE for backwards compatibility with models that use separate norm and act layer definitions
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
# NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
@@ -59,6 +61,11 @@
act_kwargs=act_kwargs,
**norm_kwargs,
)
else:
self.bn = nn.Sequential()
if drop_layer:
norm_kwargs['drop_layer'] = drop_layer
self.bn.add_module('drop', drop_layer())
@property
def in_channels(self):
@@ -88,6 +95,7 @@ class ConvNormActAa(nn.Module):
dilation: int = 1,
groups: int = 1,
bias: bool = False,
apply_norm: bool = True,
apply_act: bool = True,
norm_layer: LayerType = nn.BatchNorm2d,
act_layer: LayerType = nn.ReLU,
@@ -113,6 +121,7 @@
**conv_kwargs,
)
if apply_norm:
# NOTE for backwards compatibility with models that use separate norm and act layer definitions
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
# NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
@@ -124,6 +133,12 @@
act_kwargs=act_kwargs,
**norm_kwargs,
)
else:
self.bn = nn.Sequential()
if drop_layer:
norm_kwargs['drop_layer'] = drop_layer
self.bn.add_module('drop', drop_layer())
self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)
@property

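Note: with apply_norm=False the block reduces to a bare convolution plus an optional drop layer, so ConvNormAct can double as a plain projection conv. A minimal sketch of the new flag (shapes are illustrative, not part of this commit):

import torch
from timm.layers import ConvNormAct

# conv + BN + ReLU, as before
block = ConvNormAct(3, 64, kernel_size=3, stride=2)
# norm and act disabled: behaves like a plain nn.Conv2d; bias is enabled
# because there is no longer a BatchNorm shift to absorb it
proj = ConvNormAct(64, 768, kernel_size=1, bias=True, apply_norm=False, apply_act=False)

x = torch.randn(1, 3, 224, 224)
print(proj(block(x)).shape)  # torch.Size([1, 768, 112, 112])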
timm/layers/norm_act.py

@@ -19,21 +19,18 @@ from torch import nn as nn
from torch.nn import functional as F
from torchvision.ops.misc import FrozenBatchNorm2d
-from .create_act import get_act_layer
+from .create_act import create_act_layer
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
from .trace_utils import _assert
def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True):
-    act_layer = get_act_layer(act_layer)  # string -> nn.Module
     act_kwargs = act_kwargs or {}
-    if act_layer is not None and apply_act:
-        if inplace:
-            act_kwargs['inplace'] = inplace
-        act = act_layer(**act_kwargs)
-    else:
-        act = nn.Identity()
-    return act
+    act_kwargs.setdefault('inplace', inplace)
+    act = None
+    if apply_act:
+        act = create_act_layer(act_layer, **act_kwargs)
+    return nn.Identity() if act is None else act
class BatchNormAct2d(nn.BatchNorm2d):
@@ -421,7 +418,6 @@ class LayerNormAct(nn.LayerNorm):
    ):
        super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
-        act_layer = get_act_layer(act_layer)  # string -> nn.Module
self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
self._fast_norm = is_fast_norm()

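Note: the rewritten _create_act routes everything through create_act_layer, which resolves strings, classes, and None in one place and only forwards inplace where the activation supports it. Behavior sketch using timm's public helper:

from torch import nn
from timm.layers import create_act_layer

relu = create_act_layer('relu', inplace=True)   # -> nn.ReLU(inplace=True)
gelu = create_act_layer(nn.GELU, inplace=True)  # GELU has no inplace arg; helper falls back to nn.GELU()
none = create_act_layer(None)                   # -> None; _create_act maps this to nn.Identity()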
timm/models/vision_transformer.py

@@ -609,7 +609,7 @@ class VisionTransformer(nn.Module):
    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        if self.pos_embed is None:
-            return x
+            return x.view(x.shape[0], -1, x.shape[-1])
if self.dynamic_img_size:
B, H, W, C = x.shape

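Note: this one-liner fixes the early-return path for models built without a position embedding; with dynamic_img_size enabled the patch embed hands _pos_embed an NHWC grid, which previously escaped un-flattened. Shape-wise (illustrative values):

import torch

x = torch.randn(2, 14, 14, 768)               # (B, H, W, C) grid from the patch embed
tokens = x.view(x.shape[0], -1, x.shape[-1])  # (B, N, C) token sequence
print(tokens.shape)                           # torch.Size([2, 196, 768])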
timm/models/vision_transformer_hybrid.py

@@ -15,14 +15,15 @@ Hacked together by / Copyright 2020, Ross Wightman
"""
import math
from functools import partial
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Type, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to
+from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, Format, nchw_to
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
from .resnet import resnet26d, resnet50d
from .resnetv2 import ResNetV2, create_resnetv2_stem
@@ -191,8 +192,52 @@ class HybridEmbedWithSize(nn.Module):
return x.flatten(2).transpose(1, 2), x.shape[-2:]
-def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
-    embed_layer = partial(HybridEmbed, backbone=backbone)
class ConvStem(nn.Sequential):
def __init__(
self,
in_chans: int = 3,
depth: int = 3,
channels: Union[int, Tuple[int, ...]] = 64,
kernel_size: Union[int, Tuple[int, ...]] = 3,
stride: Union[int, Tuple[int, ...]] = (2, 2, 2),
padding: Union[str, int, Tuple[int, ...]] = "",
norm_layer: Type[nn.Module] = nn.BatchNorm2d,
act_layer: Type[nn.Module] = nn.ReLU,
):
super().__init__()
if isinstance(channels, int):
if depth == 4:
channels = (channels // 8, channels // 4, channels // 2, channels)
elif depth == 3:
channels = (channels // 4, channels // 2, channels)
else:
channels = to_ntuple(depth)(channels)
        kernel_size = to_ntuple(depth)(kernel_size)
        stride = to_ntuple(depth)(stride)  # accept a plain int stride, matching the type hint
        padding = to_ntuple(depth)(padding)
        assert depth == len(stride) == len(kernel_size) == len(channels)
in_chs = in_chans
for i in range(len(channels)):
last_conv = i == len(channels) - 1
self.add_module(f'{i}', ConvNormAct(
in_chs,
channels[i],
kernel_size=kernel_size[i],
stride=stride[i],
padding=padding[i],
bias=last_conv,
apply_norm=not last_conv,
apply_act=not last_conv,
norm_layer=norm_layer,
act_layer=act_layer,
))
in_chs = channels[i]
def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs):
embed_args = embed_args or {}
embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args)
kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs)
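Note: ConvStem stacks depth ConvNormAct blocks and leaves the final one as a bare biased conv (apply_norm=False, apply_act=False above), so its output can feed the transformer directly. Quick shape check, assuming this module path:

import torch
import torch.nn as nn
from timm.models.vision_transformer_hybrid import ConvStem

stem = ConvStem(
    channels=(192, 192, 768),  # = (768//4, 768//4, 768)
    stride=(4, 2, 2),          # 4*2*2 = 16x total reduction, like a /16 patch embed
    kernel_size=(4, 2, 2),
    padding=0,
    act_layer=nn.GELU,
)
x = torch.randn(1, 3, 224, 224)
print(stem(x).shape)  # torch.Size([1, 768, 14, 14])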
@@ -433,6 +478,25 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs) -> VisionTransformer:
return model
@register_model
def vit_base_mci_224(pretrained=False, **kwargs) -> VisionTransformer:
""" Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights.
"""
backbone = ConvStem(
channels=(768//4, 768//4, 768),
stride=(4, 2, 2),
kernel_size=(4, 2, 2),
padding=0,
act_layer=nn.GELU,
)
model_args = dict(embed_dim=768, depth=12, num_heads=12, no_embed_class=True)
model = _create_vision_transformer_hybrid(
        'vit_base_mci_224', backbone=backbone, embed_args=dict(proj=False),
pretrained=pretrained, **dict(model_args, **kwargs)
)
return model
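Note: because the stem already projects to embed_dim at stride 16, HybridEmbed is created with proj=False and the default patch_size=1, so each stem output pixel becomes one token. With a timm build that includes this commit's registration, the model constructs like any other:

import torch
import timm

model = timm.create_model('vit_base_mci_224', pretrained=False)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000]) with the default classifier head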
register_model_deprecations(__name__, {
'vit_tiny_r_s16_p8_224_in21k': 'vit_tiny_r_s16_p8_224.augreg_in21k',
'vit_small_r26_s32_224_in21k': 'vit_small_r26_s32_224.augreg_in21k',