diff --git a/timm/layers/conv_bn_act.py b/timm/layers/conv_bn_act.py
index 17847d76..de738045 100644
--- a/timm/layers/conv_bn_act.py
+++ b/timm/layers/conv_bn_act.py
@@ -23,6 +23,7 @@ class ConvNormAct(nn.Module):
             dilation: int = 1,
             groups: int = 1,
             bias: bool = False,
+            apply_norm: bool = True,
             apply_act: bool = True,
             norm_layer: LayerType = nn.BatchNorm2d,
             act_layer: LayerType = nn.ReLU,
@@ -48,17 +49,23 @@ class ConvNormAct(nn.Module):
             **conv_kwargs,
         )
 
-        # NOTE for backwards compatibility with models that use separate norm and act layer definitions
-        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
-        # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
-        if drop_layer:
-            norm_kwargs['drop_layer'] = drop_layer
-        self.bn = norm_act_layer(
-            out_channels,
-            apply_act=apply_act,
-            act_kwargs=act_kwargs,
-            **norm_kwargs,
-        )
+        if apply_norm:
+            # NOTE for backwards compatibility with models that use separate norm and act layer definitions
+            norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
+            # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
+            if drop_layer:
+                norm_kwargs['drop_layer'] = drop_layer
+            self.bn = norm_act_layer(
+                out_channels,
+                apply_act=apply_act,
+                act_kwargs=act_kwargs,
+                **norm_kwargs,
+            )
+        else:
+            self.bn = nn.Sequential()
+            if drop_layer:
+                norm_kwargs['drop_layer'] = drop_layer
+                self.bn.add_module('drop', drop_layer())
 
     @property
     def in_channels(self):
@@ -88,6 +95,7 @@ class ConvNormActAa(nn.Module):
             dilation: int = 1,
             groups: int = 1,
             bias: bool = False,
+            apply_norm: bool = True,
             apply_act: bool = True,
             norm_layer: LayerType = nn.BatchNorm2d,
             act_layer: LayerType = nn.ReLU,
@@ -113,17 +121,24 @@ class ConvNormActAa(nn.Module):
             **conv_kwargs,
         )
 
-        # NOTE for backwards compatibility with models that use separate norm and act layer definitions
-        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
-        # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
-        if drop_layer:
-            norm_kwargs['drop_layer'] = drop_layer
-        self.bn = norm_act_layer(
-            out_channels,
-            apply_act=apply_act,
-            act_kwargs=act_kwargs,
-            **norm_kwargs,
-        )
+        if apply_norm:
+            # NOTE for backwards compatibility with models that use separate norm and act layer definitions
+            norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
+            # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
+            if drop_layer:
+                norm_kwargs['drop_layer'] = drop_layer
+            self.bn = norm_act_layer(
+                out_channels,
+                apply_act=apply_act,
+                act_kwargs=act_kwargs,
+                **norm_kwargs,
+            )
+        else:
+            self.bn = nn.Sequential()
+            if drop_layer:
+                norm_kwargs['drop_layer'] = drop_layer
+                self.bn.add_module('drop', drop_layer())
+
         self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)
 
     @property
diff --git a/timm/layers/norm_act.py b/timm/layers/norm_act.py
index 49505c58..496efcfd 100644
--- a/timm/layers/norm_act.py
+++ b/timm/layers/norm_act.py
@@ -19,21 +19,18 @@ from torch import nn as nn
 from torch.nn import functional as F
 from torchvision.ops.misc import FrozenBatchNorm2d
 
-from .create_act import get_act_layer
+from .create_act import create_act_layer
 from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
 from .trace_utils import _assert
 
 
 def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True):
-    act_layer = get_act_layer(act_layer)  # string -> nn.Module
     act_kwargs = act_kwargs or {}
-    if act_layer is not None and apply_act:
-        if inplace:
-            act_kwargs['inplace'] = inplace
-        act = act_layer(**act_kwargs)
-    else:
-        act = nn.Identity()
-    return act
+    act_kwargs.setdefault('inplace', inplace)
+    act = None
+    if apply_act:
+        act = create_act_layer(act_layer, **act_kwargs)
+    return nn.Identity() if act is None else act
 
 
 class BatchNormAct2d(nn.BatchNorm2d):
@@ -421,7 +418,6 @@ class LayerNormAct(nn.LayerNorm):
     ):
         super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
-        act_layer = get_act_layer(act_layer)  # string -> nn.Module
         self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
 
         self._fast_norm = is_fast_norm()
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index f25db6b5..e3f1b8f2 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -609,7 +609,7 @@ class VisionTransformer(nn.Module):
 
     def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
         if self.pos_embed is None:
-            return x
+            return x.view(x.shape[0], -1, x.shape[-1])
 
         if self.dynamic_img_size:
             B, H, W, C = x.shape
diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py
index c2dd1e59..af51fa98 100644
--- a/timm/models/vision_transformer_hybrid.py
+++ b/timm/models/vision_transformer_hybrid.py
@@ -15,14 +15,15 @@ Hacked together by / Copyright 2020, Ross Wightman
 """
 import math
 from functools import partial
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to
+from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, Format, nchw_to
+
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 from .resnet import resnet26d, resnet50d
 from .resnetv2 import ResNetV2, create_resnetv2_stem
@@ -191,8 +192,52 @@ class HybridEmbedWithSize(nn.Module):
         return x.flatten(2).transpose(1, 2), x.shape[-2:]
 
 
-def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
-    embed_layer = partial(HybridEmbed, backbone=backbone)
+class ConvStem(nn.Sequential):
+    def __init__(
+            self,
+            in_chans: int = 3,
+            depth: int = 3,
+            channels: Union[int, Tuple[int, ...]] = 64,
+            kernel_size: Union[int, Tuple[int, ...]] = 3,
+            stride: Union[int, Tuple[int, ...]] = (2, 2, 2),
+            padding: Union[str, int, Tuple[int, ...]] = "",
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            act_layer: Type[nn.Module] = nn.ReLU,
+    ):
+        super().__init__()
+        if isinstance(channels, int):
+            if depth == 4:
+                channels = (channels // 8, channels // 4, channels // 2, channels)
+            elif depth == 3:
+                channels = (channels // 4, channels // 2, channels)
+            else:
+                channels = to_ntuple(depth)(channels)
+
+        kernel_size = to_ntuple(depth)(kernel_size)
+        padding = to_ntuple(depth)(padding)
+        assert depth == len(stride) == len(kernel_size) == len(channels)
+
+        in_chs = in_chans
+        for i in range(len(channels)):
+            last_conv = i == len(channels) - 1
+            self.add_module(f'{i}', ConvNormAct(
+                in_chs,
+                channels[i],
+                kernel_size=kernel_size[i],
+                stride=stride[i],
+                padding=padding[i],
+                bias=last_conv,
+                apply_norm=not last_conv,
+                apply_act=not last_conv,
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+            ))
+            in_chs = channels[i]
+
+
+def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs):
+    embed_args = embed_args or {}
+    embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args)
     kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set
     return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs)
 
@@ -433,6 +478,25 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_base_mci_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ Custom ViT base hybrid w/ simple 3-conv stem (overall stride 16). No pretrained weights.
+    """
+    backbone = ConvStem(
+        channels=(768//4, 768//4, 768),
+        stride=(4, 2, 2),
+        kernel_size=(4, 2, 2),
+        padding=0,
+        act_layer=nn.GELU,
+    )
+    model_args = dict(embed_dim=768, depth=12, num_heads=12, no_embed_class=True)
+    model = _create_vision_transformer_hybrid(
+        'vit_base_resnet50d_224', backbone=backbone, embed_args=dict(proj=False),
+        pretrained=pretrained, **dict(model_args, **kwargs)
+    )
+    return model
+
+
 register_model_deprecations(__name__, {
     'vit_tiny_r_s16_p8_224_in21k': 'vit_tiny_r_s16_p8_224.augreg_in21k',
     'vit_small_r26_s32_224_in21k': 'vit_small_r26_s32_224.augreg_in21k',
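
A rough usage sketch of the pieces added above (assuming this patch is applied to timm; the shapes in the comments are what the layer arithmetic implies, not captured output):

import torch
import torch.nn as nn
from timm.layers import ConvNormAct
from timm.models.vision_transformer_hybrid import ConvStem

# apply_norm=False reduces ConvNormAct to a bare conv (plus an optional drop layer),
# which is how ConvStem builds its final, projection-style conv.
conv_only = ConvNormAct(3, 64, kernel_size=3, stride=2, bias=True, apply_norm=False, apply_act=False)

# The MCi-style stem from vit_base_mci_224: overall stride 4 * 2 * 2 = 16; the last conv
# keeps its bias and drops norm/act so it can serve as the patch projection itself.
stem = ConvStem(
    channels=(768 // 4, 768 // 4, 768),
    stride=(4, 2, 2),
    kernel_size=(4, 2, 2),
    padding=0,
    act_layer=nn.GELU,
)

x = torch.randn(1, 3, 224, 224)
print(conv_only(x).shape)  # expected: torch.Size([1, 64, 112, 112])
print(stem(x).shape)       # expected: torch.Size([1, 768, 14, 14]) -> 196 tokens of dim 768

Because the stem's last conv already projects to the embedding dim, vit_base_mci_224 passes embed_args=dict(proj=False) so that HybridEmbed wraps the backbone without adding a second projection.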