From 34b41b143cf82743abb8281cd9d366c1b25ec6a7 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 22 Mar 2024 17:55:02 -0700 Subject: [PATCH 01/27] Fiddling with efficientnet x/h defs, is it worth adding & training any? --- timm/models/_efficientnet_blocks.py | 149 +++++++++++++++++++++++---- timm/models/_efficientnet_builder.py | 46 ++++++++- timm/models/efficientnet.py | 107 +++++++++++++++++++ 3 files changed, 280 insertions(+), 22 deletions(-) diff --git a/timm/models/_efficientnet_blocks.py b/timm/models/_efficientnet_blocks.py index a5a6f30b..b519b230 100644 --- a/timm/models/_efficientnet_blocks.py +++ b/timm/models/_efficientnet_blocks.py @@ -35,8 +35,15 @@ class SqueezeExcite(nn.Module): """ def __init__( - self, in_chs, rd_ratio=0.25, rd_channels=None, act_layer=nn.ReLU, - gate_layer=nn.Sigmoid, force_act_layer=None, rd_round_fn=None): + self, + in_chs, + rd_ratio=0.25, + rd_channels=None, + act_layer=nn.ReLU, + gate_layer=nn.Sigmoid, + force_act_layer=None, + rd_round_fn=None, + ): super(SqueezeExcite, self).__init__() if rd_channels is None: rd_round_fn = rd_round_fn or round @@ -59,8 +66,19 @@ class ConvBnAct(nn.Module): """ Conv + Norm Layer + Activation w/ optional skip connection """ def __init__( - self, in_chs, out_chs, kernel_size, stride=1, dilation=1, group_size=0, pad_type='', - skip=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_path_rate=0.): + self, + in_chs, + out_chs, + kernel_size, + stride=1, + dilation=1, + group_size=0, + pad_type='', + skip=False, + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + drop_path_rate=0., + ): super(ConvBnAct, self).__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer) groups = num_groups(group_size, in_chs) @@ -92,17 +110,45 @@ class DepthwiseSeparableConv(nn.Module): (factor of 1.0). This is an alternative to having a IR with an optional first pw conv. 
""" def __init__( - self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='', - noskip=False, pw_kernel_size=1, pw_act=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, - se_layer=None, drop_path_rate=0.): + self, + in_chs, + out_chs, + dw_kernel_size=3, + stride=1, + dilation=1, + group_size=1, + pad_type='', + noskip=False, + pw_kernel_size=1, + pw_act=False, + s2d=0, + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + se_layer=None, + drop_path_rate=0., + ): super(DepthwiseSeparableConv, self).__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer) - groups = num_groups(group_size, in_chs) self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip self.has_pw_act = pw_act # activation after point-wise conv + # Space to depth + if s2d == 1: + sd_chs = int(in_chs * 4) + #sd_pad_type = 'sam' + self.conv_s2d = create_conv2d( + in_chs, sd_chs, kernel_size=2, stride=2, padding=0) #'same') + self.bn_s2d = norm_act_layer(sd_chs, sd_chs) + in_chs = sd_chs + else: + self.conv_s2d = None + self.bn_s2d = None + + groups = num_groups(group_size, in_chs) + + dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type self.conv_dw = create_conv2d( - in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, groups=groups) + in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=dw_pad_type, groups=groups) self.bn1 = norm_act_layer(in_chs, inplace=True) # Squeeze-and-excitation @@ -120,7 +166,13 @@ class DepthwiseSeparableConv(nn.Module): def forward(self, x): shortcut = x + #print('ii', x.shape) + if self.conv_s2d is not None: + x = self.conv_s2d(x) + x = self.bn_s2d(x) + #print('id', x.shape) x = self.conv_dw(x) + #print('od', x.shape) x = self.bn1(x) x = self.se(x) x = self.conv_pw(x) @@ -141,15 +193,42 @@ class InvertedResidual(nn.Module): """ def __init__( - self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='', - noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.): + self, + in_chs, + out_chs, + dw_kernel_size=3, + stride=1, + dilation=1, + group_size=1, + pad_type='', + noskip=False, + exp_ratio=1.0, + exp_kernel_size=1, + pw_kernel_size=1, + s2d=0, + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + se_layer=None, + conv_kwargs=None, + drop_path_rate=0., + ): super(InvertedResidual, self).__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer) conv_kwargs = conv_kwargs or {} + self.has_skip = (in_chs == out_chs and stride == 1) and not noskip + + # Space to depth + if s2d == 1: + sd_chs = int(in_chs * 4) + self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding=pad_type) + self.bn_s2d = norm_act_layer(sd_chs, sd_chs) + in_chs = sd_chs + else: + self.conv_s2d = None + self.bn_s2d = None + mid_chs = make_divisible(in_chs * exp_ratio) groups = num_groups(group_size, mid_chs) - self.has_skip = (in_chs == out_chs and stride == 1) and not noskip # Point-wise expansion self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) @@ -177,6 +256,9 @@ class InvertedResidual(nn.Module): def forward(self, x): shortcut = x + if self.conv_s2d is not None: + x = self.conv_s2d(x) + x = self.bn_s2d(x) x = self.conv_pw(x) x = self.bn1(x) x = self.conv_dw(x) @@ -193,9 +275,24 @@ class CondConvResidual(InvertedResidual): """ Inverted residual block w/ CondConv routing""" def __init__( - self, in_chs, 
out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='', - noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.): + self, + in_chs, + out_chs, + dw_kernel_size=3, + stride=1, + dilation=1, + group_size=1, + pad_type='', + noskip=False, + exp_ratio=1.0, + exp_kernel_size=1, + pw_kernel_size=1, + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + se_layer=None, + num_experts=0, + drop_path_rate=0., + ): self.num_experts = num_experts conv_kwargs = dict(num_experts=self.num_experts) @@ -237,9 +334,23 @@ class EdgeResidual(nn.Module): """ def __init__( - self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, group_size=0, pad_type='', - force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.): + self, + in_chs, + out_chs, + exp_kernel_size=3, + stride=1, + dilation=1, + group_size=0, + pad_type='', + force_in_chs=0, + noskip=False, + exp_ratio=1.0, + pw_kernel_size=1, + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + se_layer=None, + drop_path_rate=0., + ): super(EdgeResidual, self).__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer) if force_in_chs > 0: diff --git a/timm/models/_efficientnet_builder.py b/timm/models/_efficientnet_builder.py index 1e3161d6..aedd8b39 100644 --- a/timm/models/_efficientnet_builder.py +++ b/timm/models/_efficientnet_builder.py @@ -24,7 +24,7 @@ __all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights" _logger = logging.getLogger(__name__) -_DEBUG_BUILDER = False +_DEBUG_BUILDER = True # Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per # papers and TF reference implementations. 
PT momentum equiv for TF decay is (1 - TF decay) @@ -143,6 +143,7 @@ def _decode_block_str(block_str): pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 force_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def num_repeat = int(options['r']) + s2d = int(options['d']) if 'd' in options else 0 # each type of block has different valid arguments, fill accordingly block_args = dict( @@ -159,6 +160,7 @@ def _decode_block_str(block_str): exp_ratio=float(options['e']), se_ratio=float(options['se']) if 'se' in options else 0., noskip=skip is False, + s2d=s2d > 0, )) if 'cc' in options: block_args['num_experts'] = int(options['cc']) @@ -169,6 +171,7 @@ def _decode_block_str(block_str): se_ratio=float(options['se']) if 'se' in options else 0., pw_act=block_type == 'dsa', noskip=block_type == 'dsa' or skip is False, + s2d=s2d > 0, )) elif block_type == 'er': block_args.update(dict( @@ -285,8 +288,18 @@ class EfficientNetBuilder: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py """ - def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=False, - act_layer=None, norm_layer=None, se_layer=None, drop_path_rate=0., feature_location=''): + def __init__( + self, + output_stride=32, + pad_type='', + round_chs_fn=round_channels, + se_from_exp=False, + act_layer=None, + norm_layer=None, + se_layer=None, + drop_path_rate=0., + feature_location='', + ): self.output_stride = output_stride self.pad_type = pad_type self.round_chs_fn = round_chs_fn @@ -317,6 +330,11 @@ class EfficientNetBuilder: bt = ba.pop('block_type') ba['in_chs'] = self.in_chs ba['out_chs'] = self.round_chs_fn(ba['out_chs']) + s2d = ba.get('s2d', 0) + if s2d: + ba['out_chs'] *= 4 + if s2d == 1: + ba['dw_kernel_size'] = (ba['dw_kernel_size'] + 1) // 2 if 'force_in_chs' in ba and ba['force_in_chs']: # NOTE this is a hack to work around mismatch in TF EdgeEffNet impl ba['force_in_chs'] = self.round_chs_fn(ba['force_in_chs']) @@ -332,6 +350,9 @@ class EfficientNetBuilder: if not self.se_from_exp: # adjust se_ratio by expansion ratio if calculating se channels from block input se_ratio /= ba.get('exp_ratio', 1.0) + # adjust space2depth + if s2d == 1: + se_ratio /= 4 if self.se_has_ratio: ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio) else: @@ -377,6 +398,7 @@ class EfficientNetBuilder: self.features.append(feature_info) # outer list of block_args defines the stacks + space2depth = 0 for stack_idx, stack_args in enumerate(model_block_args): last_stack = stack_idx + 1 == len(model_block_args) _log_info_if('Stack: {}'.format(stack_idx), self.verbose) @@ -392,6 +414,21 @@ class EfficientNetBuilder: if block_idx >= 1: # only the first block in any stack can have a stride > 1 block_args['stride'] = 1 + if not space2depth and block_args.pop('s2d', False): + assert block_args['stride'] == 1 + space2depth = 1 + + if space2depth > 0: + if space2depth == 2 and block_args['stride'] == 2: + space2depth = 0 + block_args['stride'] = 1 + # to end s2d region, need to correct expansion and se ratio relative to input + # FIXME unify with _make_block logic? 
this is rather meh + block_args['exp_ratio'] /= 4 + #block_args['se_ratio'] /= 4 + else: + block_args['s2d'] = space2depth + extract_features = False if last_block: next_stack_idx = stack_idx + 1 @@ -416,6 +453,9 @@ class EfficientNetBuilder: block = self._make_block(block_args, total_block_idx, total_block_count) blocks.append(block) + if space2depth == 1: + space2depth = 2 + # stash feature module name and channel info for model feature extraction if extract_features: feature_info = dict( diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 6e61d1bf..09ab005b 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -799,6 +799,88 @@ def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0 return model +def _gen_efficientnet_x( + variant, channel_multiplier=1.0, depth_multiplier=1.0, channel_divisor=8, + group_size=None, version=1, pretrained=False, **kwargs): + """Creates an EfficientNet model. + + Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-x-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-x-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-x-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-x-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-x-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-x-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-x-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-x-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-x-b8': (2.2, 3.6, 672, 0.5), + 'efficientnet-l2': (4.3, 5.3, 800, 0.5), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + + """ + """ + if version == 1: + blocks_args = [ + 'r1_k3_s11_e1_i32_o16_se0.25_d1_a0', + 'r2_k3_s22_e6_i16_o24_se0.25_f1_d2_a1', + 'r2_k5_s22_e6_i24_o40_se0.25_f1_a1', + 'r3_k3_s22_e6_i40_o80_se0.25_a0', + 'r3_k5_s11_e6_i80_o112_se0.25_a0', + 'r4_k5_s22_e6_i112_o192_se0.25_a0', + 'r1_k3_s11_e6_i192_o320_se0.25_a0', + ] + elif version == 2: + blocks_args = [ + 'r1_k3_s11_e1_i32_o16_se0.25_d1_a0', + 'r2_k3_s22_e4_i16_o24_se0.25_f1_d2_a1', + 'r2_k5_s22_e4_i24_o40_se0.25_f1_a1', + 'r3_k3_s22_e4_i40_o80_se0.25_a0', + 'r3_k5_s11_e6_i80_o112_se0.25_a0', + 'r4_k5_s22_e6_i112_o192_se0.25_a0', + 'r1_k3_s11_e6_i192_o320_se0.25_a0', + ] + """ + if version == 1: + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25_d1'], + ['er_r2_k3_s2_e6_c24_se0.25_nre'], + ['er_r2_k5_s2_e6_c40_se0.25_nre'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25'], + ['ir_r4_k5_s2_e6_c192_se0.25'], + ['ir_r1_k3_s1_e6_c320_se0.25'], + ] + else: + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25_d1'], + ['er_r2_k3_s2_e4_c24_se0.25_nre'], + ['er_r2_k5_s2_e4_c40_se0.25_nre'], + ['ir_r3_k3_s2_e4_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25'], + ['ir_r4_k5_s2_e6_c192_se0.25'], + ['ir_r1_k3_s1_e6_c320_se0.25'], + ] + round_chs_fn = partial(round_channels, multiplier=channel_multiplier, divisor=channel_divisor) + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size), + num_features=round_chs_fn(1280), + stem_size=32, + round_chs_fn=round_chs_fn, + act_layer=resolve_act_layer(kwargs, 'silu'), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + def 
_gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """Creates a MixNet Small model. @@ -2197,6 +2279,31 @@ def tf_efficientnetv2_b3(pretrained=False, **kwargs) -> EfficientNet: return model +@register_model +def efficientnet_x_b3(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-B3 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + model = _gen_efficientnet_x( + 'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_x_b5(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-B5 """ + model = _gen_efficientnet_x( + 'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_h_b5(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-B5 """ + model = _gen_efficientnet_x( + 'efficientnet_b5', channel_multiplier=1.92, depth_multiplier=2.2, version=2, pretrained=pretrained, **kwargs) + return model + + @register_model def mixnet_s(pretrained=False, **kwargs) -> EfficientNet: """Creates a MixNet Small model. From 99d4c7d2027b67653d8e5fb15aec210283c2551d Mon Sep 17 00:00:00 2001 From: Beckschen Date: Sun, 5 May 2024 02:50:14 -0400 Subject: [PATCH 02/27] add ViTamin models --- timm/models/__init__.py | 1 + timm/models/vitamin.py | 561 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 562 insertions(+) create mode 100644 timm/models/vitamin.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 9d09efac..9c7bee6f 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -70,6 +70,7 @@ from .vision_transformer import * from .vision_transformer_hybrid import * from .vision_transformer_relpos import * from .vision_transformer_sam import * +from .vitamin import * from .volo import * from .vovnet import * from .xception import * diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py new file mode 100644 index 00000000..3eecb8db --- /dev/null +++ b/timm/models/vitamin.py @@ -0,0 +1,561 @@ +""" ViTamin + +Paper: Designing Scalable Vison Models in the Vision-Language Era + +@misc{chen2023designing, + title={Designing Scalable Vison Models in the Vision-Language Era}, + author={Jieneng Chen and Qihang Yu and Xiaohui Shen and Alan Yuille and Liang-Cheih Chen}, + year={2023}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} + +Based on Apache 2.0 licensed code at https://github.com/ViTamin/ViTamin + +Modifications and timm support by Jieneng Chen 2023 + +Reference: +https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py +https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer_hybrid.py +""" + +from functools import partial +from typing import List, Tuple +from dataclasses import dataclass, replace +from typing import Callable, Optional, Union, Tuple, List, Sequence +import math, time +from torch.jit import Final +import torch +import torch.nn as nn +import torch.nn.functional as F +import timm +from timm.layers import to_2tuple +from torch.utils.checkpoint import checkpoint +from timm.models.layers import create_attn, get_norm_layer, get_norm_act_layer, create_conv2d, make_divisible, trunc_normal_tf_ + +from timm.models._registry import register_model +from timm.layers import DropPath +from timm.layers.norm_act import _create_act + +from timm.models._manipulate import named_apply, checkpoint_seq +from 
timm.models._builder import build_model_with_cfg +from timm.models.vision_transformer import VisionTransformer, checkpoint_filter_fn +from timm.models.vision_transformer_hybrid import HybridEmbed + + +@dataclass +class VitConvCfg: + expand_ratio: float = 4.0 + expand_output: bool = True # calculate expansion channels from output (vs input chs) + kernel_size: int = 3 + group_size: int = 1 # 1 == depthwise + pre_norm_act: bool = False # activation after pre-norm + stride_mode: str = 'dw' # stride done via one of 'pool', '1x1', 'dw' + pool_type: str = 'avg2' + downsample_pool_type: str = 'avg2' + act_layer: str = 'gelu' # stem & stage 1234 + act_layer1: str = 'gelu' # stage 1234 + act_layer2: str = 'gelu' # stage 1234 + norm_layer: str = '' + norm_layer_cl: str = '' + norm_eps: Optional[float] = None + down_shortcut: Optional[bool] = True + mlp: str = 'mlp' + + def __post_init__(self): + # mbconv vs convnext blocks have different defaults, set in post_init to avoid explicit config args + use_mbconv = True + if not self.norm_layer: + self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d' + if not self.norm_layer_cl and not use_mbconv: + self.norm_layer_cl = 'layernorm' + if self.norm_eps is None: + self.norm_eps = 1e-5 if use_mbconv else 1e-6 + self.downsample_pool_type = self.downsample_pool_type or self.pool_type + +@dataclass +class VitCfg: + # embed_dim: Tuple[int, ...] = (96, 192, 384, 768) + embed_dim: Tuple[Union[int, Tuple[int, ...]], ...] = (96, 192, 384, 768) + depths: Tuple[Union[int, Tuple[int, ...]], ...] = (2, 3, 5, 2) + stem_width: int = 64 + conv_cfg: VitConvCfg = VitConvCfg() + weight_init: str = 'vit_eff' + head_type: str = "" + stem_type: str = "stem" + ln2d_permute: bool = True + # memory_format: str="" + + +def _init_conv(module, name, scheme=''): + if isinstance(module, nn.Conv2d): + fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels + fan_out //= module.groups + nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out)) + if module.bias is not None: + nn.init.zeros_(module.bias) + +class Stem(nn.Module): + def __init__( + self, + in_chs: int, + out_chs: int, + act_layer: str = 'gelu', + norm_layer: str = 'layernorm2d', + norm_eps: float = 1e-6, + bias: bool = True, + ): + super().__init__() + self.grad_checkpointing=False + norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps) + self.out_chs = out_chs + self.conv1 = create_conv2d(in_chs, out_chs, 3, stride=2, bias=bias) + self.norm1 = norm_act_layer(out_chs) + self.conv2 = create_conv2d(out_chs, out_chs, 3, stride=1, bias=bias) + named_apply(_init_conv, self) + + def forward(self, x): + if self.grad_checkpointing: + x = checkpoint(self.conv1, x) + x = self.norm1(x) + x = checkpoint(self.conv2, x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.conv2(x) + + return x + +class Downsample2d(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + pool_type: str = 'avg2', + bias: bool = True, + ): + super().__init__() + self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False) + + + if dim != dim_out: + self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias) # 1x1 conv + else: + self.expand = nn.Identity() + + def forward(self, x): + x = self.pool(x) # spatial downsample + x = self.expand(x) # expand chs + return x + + +class StridedConv(nn.Module): + """ downsample 2d as well + """ + def __init__( + self, + kernel_size=3, + stride=2, + padding=1, + in_chans=3, + embed_dim=768, + ln2d_permute=True + ): + 
super().__init__() + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) + self.permute = ln2d_permute # TODO: disable + norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6) + self.norm = norm_layer(in_chans) # affine over C + + def forward(self, x): + x = self.norm(x) + x = self.proj(x) + return x + + +class MbConvLNBlock(nn.Module): + """ Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand) + """ + def __init__( + self, + in_chs: int, + out_chs: int, + stride: int = 1, + drop_path: float = 0., + kernel_size: int = 3, + norm_layer: str = 'layernorm2d', + norm_eps: float = 1e-6, + act_layer: str = 'gelu', + expand_ratio: float = 4.0, + ): + super(MbConvLNBlock, self).__init__() + self.stride, self.in_chs, self.out_chs = stride, in_chs, out_chs + mid_chs = make_divisible(out_chs * expand_ratio) + prenorm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps) + + if stride == 2: + self.shortcut = Downsample2d(in_chs, out_chs, pool_type='avg', bias=True) + elif in_chs != out_chs: + self.shortcut = nn.Conv2d(in_chs, out_chs, 1, bias=True) + else: + self.shortcut = nn.Identity() + + self.pre_norm = prenorm_act_layer(in_chs, apply_act=False) + self.down = nn.Identity() + self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=1, bias=True) + self.act1 = _create_act(act_layer, inplace=True) + self.act2 = _create_act(act_layer, inplace=True) + + self.conv2_kxk = create_conv2d(mid_chs, mid_chs, kernel_size, stride=stride, dilation=1, groups=mid_chs, bias=True) + self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=True) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + + def init_weights(self, scheme=''): + named_apply(partial(_init_conv, scheme=scheme), self) + + def forward(self, x): + shortcut = self.shortcut(x) + + x = self.pre_norm(x) + x = self.down(x) # nn.Identity() + + # 1x1 expansion conv & act + x = self.conv1_1x1(x) + x = self.act1(x) + + # (strided) depthwise 3x3 conv & act + x = self.conv2_kxk(x) + x = self.act2(x) + + # 1x1 linear projection to output width + x = self.conv3_1x1(x) + x = self.drop_path(x) + shortcut + + return x + + +class MbConvStages(nn.Module): + """ MobileConv for stage 1 and stage 2 of ViTamin + """ + def __init__( + self, + cfg: VitCfg, + img_size: Union[int, Tuple[int, int]] = 224, # place holder + in_chans: int = 3, + ): + super().__init__() + self.grad_checkpointing = False + self.stem = Stem( + in_chs=in_chans, + out_chs=cfg.stem_width, + ) + stages = [] + self.num_stages = len(cfg.embed_dim) + for s, dim in enumerate(cfg.embed_dim[:2]): # stage + blocks = [] + stage_in_chs = cfg.embed_dim[s-1] if s>0 else cfg.stem_width + for d in range(cfg.depths[s]): + blocks += [MbConvLNBlock( + in_chs = stage_in_chs if d==0 else dim, + out_chs = dim, + stride = 2 if d == 0 else 1, + # cfg = cfg.conv_cfg, + )] + blocks = nn.Sequential(*blocks) + stages += [blocks] + + self.stages = nn.ModuleList(stages) + self.pool = StridedConv( + stride=2, + in_chans=cfg.embed_dim[1], + embed_dim=cfg.embed_dim[2] + ) + + def forward(self, x): + x = self.stem(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + for stage in self.stages: + x = checkpoint_seq(stage, x) + x = checkpoint(self.pool, x) + else: + for stage in self.stages: + x = stage(x) + x = self.pool(x) + + return x + +class GeGluMlp(nn.Module): + def __init__( + self, + in_features, + hidden_features, + act_layer = None, + drop = 0.0, + ): + super().__init__() + norm_layer = 
partial(get_norm_layer('layernorm'), eps=1e-6) + self.norm = norm_layer(in_features) + self.act = nn.GELU() + self.w0 = nn.Linear(in_features, hidden_features) + self.w1 = nn.Linear(in_features, hidden_features) + self.w2 = nn.Linear(hidden_features, in_features) + + def forward(self, x): + x = self.norm(x) + x = self.act(self.w0(x)) * self.w1(x) + x = self.w2(x) + return x + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + def __init__( + self, + backbone, + img_size=224, + patch_size=1, + feature_size=None, + in_chans=3, + embed_dim=1024, + bias=True, + dynamic_img_pad=False, + ): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.backbone = backbone + with torch.no_grad(): + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) + if isinstance(o, (list, tuple)): + o = o[-1] # last feature if backbone outputs list/tuple of features + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + + assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 + self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.proj = nn.Identity() + + def forward(self, x): + x = self.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + x = self.proj(x) + x = x.flatten(2).transpose(1, 2) + return x + +def _create_vision_transformer(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + if 'flexi' in variant: + # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed + # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation. 
+ _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False) + else: + _filter_fn = checkpoint_filter_fn + + return build_model_with_cfg( + VisionTransformer, + variant, + pretrained, + pretrained_filter_fn=_filter_fn, + **kwargs, + ) + + +def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs): + embed_layer = partial(HybridEmbed, backbone=backbone) + kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set + return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs) + + +@register_model +def vitamin_small(pretrained=False, **kwargs) -> VisionTransformer: + stage_1_2 = MbConvStages(cfg=VitCfg( + embed_dim=(64, 128, 384), + depths=(2, 4, 1), + stem_width=64, + conv_cfg = VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ), + ) + stage3_args = dict(embed_dim=384, depth=14, num_heads=6, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') + model = _create_vision_transformer_hybrid('vitamin_small', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs)) + return model + + +@register_model +def vitamin_base(pretrained=False, **kwargs) -> VisionTransformer: + stage_1_2 = MbConvStages(cfg=VitCfg( + embed_dim=(128, 256, 768), + depths=(2, 4, 1), + stem_width=128, + conv_cfg = VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ), + ) + stage3_args = dict(embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') + model = _create_vision_transformer_hybrid('vitamin_base', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs)) + return model + + +@register_model +def vitamin_large(pretrained=False, **kwargs) -> VisionTransformer: + stage_1_2 = MbConvStages(cfg=VitCfg( + embed_dim=(160, 320, 1024), + depths=(2, 4, 1), + stem_width=160, + conv_cfg = VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ), + ) + stage3_args = dict(embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') + model = _create_vision_transformer_hybrid( + 'vitamin_large', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs)) + return model + +@register_model +def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer: + backbone = MbConvStages(cfg=VitCfg( + embed_dim=(160, 320, 1024), + depths=(2, 4, 1), + stem_width=160, + conv_cfg = VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ), + ) + model_args = dict(img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') + model = _create_vision_transformer_hybrid( + 'vitamin_large_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) + return model + +@register_model +def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer: + backbone = MbConvStages(cfg=VitCfg( + embed_dim=(160, 320, 1024), + depths=(2, 4, 1), + stem_width=160, + conv_cfg = VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ), + ) + model_args = dict(img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') + model = _create_vision_transformer_hybrid( + 'vitamin_large_336', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) + return model + 
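+# NOTE: the registered vitamin_* variants in this file share one recipe: an MbConvStages
+# conv backbone supplies stages 1-2, and _create_vision_transformer_hybrid wraps it with
+# HybridEmbed (patch_size=1) so the ViT blocks form stage 3; only img_size, embed_dim,
+# depth and num_heads differ per variant. Illustrative usage sketch (assumes this module
+# is importable so the names below are registered with timm):
+#   model = timm.create_model('vitamin_large_384', pretrained=False)
+#   logits = model(torch.randn(1, 3, 384, 384))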
+@register_model +def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer: + backbone = MbConvStages(cfg=VitCfg( + embed_dim=(160, 320, 1024), + depths=(2, 4, 1), + stem_width=160, + conv_cfg = VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ), + ) + model_args = dict(img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') + model = _create_vision_transformer_hybrid( + 'vitamin_large_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) + return model + +@register_model +def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer: + backbone = MbConvStages(cfg=VitCfg( + embed_dim=(192, 384, 1152), + depths=(2, 4, 1), + stem_width=192, + conv_cfg = VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ), + ) + model_args = dict(img_size=256, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') + model = _create_vision_transformer_hybrid( + 'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) + return model + +@register_model +def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer: + backbone = MbConvStages(cfg=VitCfg( + embed_dim=(192, 384, 1152), + depths=(2, 4, 1), + stem_width=192, + conv_cfg = VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ), + ) + model_args = dict(img_size=336, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') + model = _create_vision_transformer_hybrid( + 'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) + return model + +@register_model +def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer: + backbone = MbConvStages(cfg=VitCfg( + embed_dim=(192, 384, 1152), + depths=(2, 4, 1), + stem_width=192, + conv_cfg = VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ), + ) + model_args = dict(img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') + model = _create_vision_transformer_hybrid( + 'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +def count_params(model: nn.Module): + return sum([m.numel() for m in model.parameters()]) + +def count_stage_params(model: nn.Module, prefix='none'): + collections = [] + for name, m in model.named_parameters(): + print(name) + if name.startswith(prefix): + collections.append(m.numel()) + return sum(collections) + + +if __name__ == "__main__": + model = timm.create_model('vitamin_large', num_classes=10).cuda() + # x = torch.rand([2,3,224,224]).cuda() + check_keys(model) From df304ffbf24114c2faf62eb0e6faae4c18320256 Mon Sep 17 00:00:00 2001 From: Beckschen Date: Tue, 14 May 2024 15:10:05 -0400 Subject: [PATCH 03/27] the dataclass init needs to use the default factory pattern, according to Ross --- timm/models/vitamin.py | 62 ++++++++++++------------------------------ 1 file changed, 17 insertions(+), 45 deletions(-) diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py index 3eecb8db..ad1b6883 100644 --- a/timm/models/vitamin.py +++ b/timm/models/vitamin.py @@ -21,7 +21,7 @@ https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision from functools import partial from typing import List, Tuple 
-from dataclasses import dataclass, replace +from dataclasses import dataclass, replace, field from typing import Callable, Optional, Union, Tuple, List, Sequence import math, time from torch.jit import Final @@ -29,16 +29,17 @@ import torch import torch.nn as nn import torch.nn.functional as F import timm -from timm.layers import to_2tuple + from torch.utils.checkpoint import checkpoint from timm.models.layers import create_attn, get_norm_layer, get_norm_act_layer, create_conv2d, make_divisible, trunc_normal_tf_ -from timm.models._registry import register_model +from timm.layers import to_2tuple from timm.layers import DropPath from timm.layers.norm_act import _create_act from timm.models._manipulate import named_apply, checkpoint_seq from timm.models._builder import build_model_with_cfg +from timm.models._registry import register_model from timm.models.vision_transformer import VisionTransformer, checkpoint_filter_fn from timm.models.vision_transformer_hybrid import HybridEmbed @@ -54,37 +55,19 @@ class VitConvCfg: pool_type: str = 'avg2' downsample_pool_type: str = 'avg2' act_layer: str = 'gelu' # stem & stage 1234 - act_layer1: str = 'gelu' # stage 1234 - act_layer2: str = 'gelu' # stage 1234 norm_layer: str = '' - norm_layer_cl: str = '' - norm_eps: Optional[float] = None + norm_eps: float = 1e-5 down_shortcut: Optional[bool] = True mlp: str = 'mlp' - def __post_init__(self): - # mbconv vs convnext blocks have different defaults, set in post_init to avoid explicit config args - use_mbconv = True - if not self.norm_layer: - self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d' - if not self.norm_layer_cl and not use_mbconv: - self.norm_layer_cl = 'layernorm' - if self.norm_eps is None: - self.norm_eps = 1e-5 if use_mbconv else 1e-6 - self.downsample_pool_type = self.downsample_pool_type or self.pool_type @dataclass class VitCfg: - # embed_dim: Tuple[int, ...] = (96, 192, 384, 768) embed_dim: Tuple[Union[int, Tuple[int, ...]], ...] = (96, 192, 384, 768) depths: Tuple[Union[int, Tuple[int, ...]], ...] 
= (2, 3, 5, 2) stem_width: int = 64 - conv_cfg: VitConvCfg = VitConvCfg() - weight_init: str = 'vit_eff' + conv_cfg: VitConvCfg = field(default_factory=VitConvCfg) head_type: str = "" - stem_type: str = "stem" - ln2d_permute: bool = True - # memory_format: str="" def _init_conv(module, name, scheme=''): @@ -95,6 +78,7 @@ def _init_conv(module, name, scheme=''): if module.bias is not None: nn.init.zeros_(module.bias) + class Stem(nn.Module): def __init__( self, @@ -126,6 +110,7 @@ class Stem(nn.Module): return x + class Downsample2d(nn.Module): def __init__( self, @@ -158,12 +143,10 @@ class StridedConv(nn.Module): stride=2, padding=1, in_chans=3, - embed_dim=768, - ln2d_permute=True + embed_dim=768 ): super().__init__() self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) - self.permute = ln2d_permute # TODO: disable norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6) self.norm = norm_layer(in_chans) # affine over C @@ -354,6 +337,7 @@ class HybridEmbed(nn.Module): x = x.flatten(2).transpose(1, 2) return x + def _create_vision_transformer(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') @@ -434,6 +418,7 @@ def vitamin_large(pretrained=False, **kwargs) -> VisionTransformer: 'vitamin_large', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs)) return model + @register_model def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer: backbone = MbConvStages(cfg=VitCfg( @@ -452,6 +437,7 @@ def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer: 'vitamin_large_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) return model + @register_model def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer: backbone = MbConvStages(cfg=VitCfg( @@ -470,6 +456,7 @@ def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer: 'vitamin_large_336', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) return model + @register_model def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer: backbone = MbConvStages(cfg=VitCfg( @@ -488,6 +475,7 @@ def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer: 'vitamin_large_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) return model + @register_model def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer: backbone = MbConvStages(cfg=VitCfg( @@ -506,6 +494,7 @@ def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer: 'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) return model + @register_model def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer: backbone = MbConvStages(cfg=VitCfg( @@ -524,6 +513,7 @@ def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer: 'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) return model + @register_model def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer: backbone = MbConvStages(cfg=VitCfg( @@ -540,22 +530,4 @@ def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer: model_args = dict(img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') model = _create_vision_transformer_hybrid( 'vitamin_xlarge_384', backbone=backbone, 
pretrained=pretrained, **dict(model_args, **kwargs)) - return model - - -def count_params(model: nn.Module): - return sum([m.numel() for m in model.parameters()]) - -def count_stage_params(model: nn.Module, prefix='none'): - collections = [] - for name, m in model.named_parameters(): - print(name) - if name.startswith(prefix): - collections.append(m.numel()) - return sum(collections) - - -if __name__ == "__main__": - model = timm.create_model('vitamin_large', num_classes=10).cuda() - # x = torch.rand([2,3,224,224]).cuda() - check_keys(model) + return model \ No newline at end of file From 530fb49e7e96fc90d4620baf5fc7de3c6edd12c9 Mon Sep 17 00:00:00 2001 From: Beckschen Date: Fri, 17 May 2024 06:48:59 -0400 Subject: [PATCH 04/27] Add link to model weights on Hugging Face --- timm/models/vitamin.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py index ad1b6883..75022c5f 100644 --- a/timm/models/vitamin.py +++ b/timm/models/vitamin.py @@ -1,18 +1,18 @@ """ ViTamin Paper: Designing Scalable Vison Models in the Vision-Language Era +Model Weights on Huggingface: https://huggingface.co/collections/jienengchen/vitamin-family-661048126b72debdaca060bf -@misc{chen2023designing, - title={Designing Scalable Vison Models in the Vision-Language Era}, - author={Jieneng Chen and Qihang Yu and Xiaohui Shen and Alan Yuille and Liang-Cheih Chen}, - year={2023}, - archivePrefix={arXiv}, - primaryClass={cs.CV} +@inproceedings{chen2024vitamin, + title={ViTamin: Designing Scalable Vision Models in the Vision-language Era}, + author={Chen, Jieneng and Yu, Qihang and Shen, Xiaohui and Yuille, Alan and Chen, Liang-Chieh}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + year={2024} } Based on Apache 2.0 licensed code at https://github.com/ViTamin/ViTamin -Modifications and timm support by Jieneng Chen 2023 +Modifications and timm support by Jieneng Chen 2024 Reference: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py @@ -122,7 +122,6 @@ class Downsample2d(nn.Module): super().__init__() self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False) - if dim != dim_out: self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias) # 1x1 conv else: @@ -530,4 +529,11 @@ def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer: model_args = dict(img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') model = _create_vision_transformer_hybrid( 'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) - return model \ No newline at end of file + return model + + +if __name__ == "__main__": + model = timm.create_model('vitamin_large', num_classes=10).cuda() + x = torch.rand([2,3,224,224]).cuda() + y = model(x) + print(y.shape) \ No newline at end of file From 7a2ad6bce1a4a3e230bd5fa7d27431a644030b4c Mon Sep 17 00:00:00 2001 From: Beckschen Date: Fri, 17 May 2024 06:51:35 -0400 Subject: [PATCH 05/27] Add link to model weights on Hugging Face --- timm/models/vitamin.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py index 75022c5f..f84a59d6 100644 --- a/timm/models/vitamin.py +++ b/timm/models/vitamin.py @@ -1,7 +1,7 @@ """ ViTamin Paper: Designing Scalable Vison Models in the Vision-Language Era -Model Weights on Huggingface: 
https://huggingface.co/collections/jienengchen/vitamin-family-661048126b72debdaca060bf +A family of model weights on Huggingface: https://huggingface.co/collections/jienengchen/vitamin-family-661048126b72debdaca060bf @inproceedings{chen2024vitamin, title={ViTamin: Designing Scalable Vision Models in the Vision-language Era}, @@ -529,11 +529,4 @@ def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer: model_args = dict(img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') model = _create_vision_transformer_hybrid( 'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) - return model - - -if __name__ == "__main__": - model = timm.create_model('vitamin_large', num_classes=10).cuda() - x = torch.rand([2,3,224,224]).cuda() - y = model(x) - print(y.shape) \ No newline at end of file + return model \ No newline at end of file From 6a8bb03330528143d95881273a65702d7845a54b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 23 May 2024 10:49:18 -0700 Subject: [PATCH 06/27] Initial MobileNetV4 pass --- timm/layers/__init__.py | 1 + timm/models/_efficientnet_blocks.py | 314 +++++++++++++++++++++++++-- timm/models/_efficientnet_builder.py | 101 ++++++--- timm/models/mobilenetv3.py | 309 ++++++++++++++++++++++++-- 4 files changed, 655 insertions(+), 70 deletions(-) diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index de077797..b44e1161 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -1,6 +1,7 @@ from .activations import * from .adaptive_avgmax_pool import \ adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d +from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2 from .attention_pool import AttentionPoolLatent from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding from .blur_pool import BlurPool2d diff --git a/timm/models/_efficientnet_blocks.py b/timm/models/_efficientnet_blocks.py index b519b230..41f3182d 100644 --- a/timm/models/_efficientnet_blocks.py +++ b/timm/models/_efficientnet_blocks.py @@ -2,15 +2,19 @@ Hacked together by / Copyright 2019, Ross Wightman """ +from typing import Optional import torch import torch.nn as nn from torch.nn import functional as F -from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer +from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, to_2tuple,\ + get_norm_act_layer, MultiQueryAttention2d, MultiQueryAttentionV2, Attention2d __all__ = [ - 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual'] + 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual', + 'UniversalInvertedResidual', 'MobileAttention' +] def num_groups(group_size, channels): @@ -85,7 +89,8 @@ class ConvBnAct(nn.Module): self.has_skip = skip and stride == 1 and in_chs == out_chs self.conv = create_conv2d( - in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, groups=groups, padding=pad_type) + in_chs, out_chs, kernel_size, + stride=stride, dilation=dilation, groups=groups, padding=pad_type) self.bn1 = norm_act_layer(out_chs, inplace=True) self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() @@ -105,7 +110,7 @@ class ConvBnAct(nn.Module): class DepthwiseSeparableConv(nn.Module): - """ DepthwiseSeparable block + """ 
Depthwise-separable block Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion (factor of 1.0). This is an alternative to having a IR with an optional first pw conv. """ @@ -139,16 +144,19 @@ class DepthwiseSeparableConv(nn.Module): self.conv_s2d = create_conv2d( in_chs, sd_chs, kernel_size=2, stride=2, padding=0) #'same') self.bn_s2d = norm_act_layer(sd_chs, sd_chs) + dw_kernel_size = (dw_kernel_size + 1) // 2 + dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type in_chs = sd_chs else: self.conv_s2d = None self.bn_s2d = None + dw_pad_type = pad_type groups = num_groups(group_size, in_chs) - dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type self.conv_dw = create_conv2d( - in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=dw_pad_type, groups=groups) + in_chs, in_chs, dw_kernel_size, + stride=stride, dilation=dilation, padding=dw_pad_type, groups=groups) self.bn1 = norm_act_layer(in_chs, inplace=True) # Squeeze-and-excitation @@ -222,10 +230,13 @@ class InvertedResidual(nn.Module): sd_chs = int(in_chs * 4) self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding=pad_type) self.bn_s2d = norm_act_layer(sd_chs, sd_chs) + dw_kernel_size = (dw_kernel_size + 1) // 2 + dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type in_chs = sd_chs else: self.conv_s2d = None self.bn_s2d = None + dw_pad_type = pad_type mid_chs = make_divisible(in_chs * exp_ratio) groups = num_groups(group_size, mid_chs) @@ -236,8 +247,8 @@ class InvertedResidual(nn.Module): # Depth-wise convolution self.conv_dw = create_conv2d( - mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation, - groups=groups, padding=pad_type, **conv_kwargs) + mid_chs, mid_chs, dw_kernel_size, + stride=stride, dilation=dilation, groups=groups, padding=dw_pad_type, **conv_kwargs) self.bn2 = norm_act_layer(mid_chs, inplace=True) # Squeeze-and-excitation @@ -271,6 +282,267 @@ class InvertedResidual(nn.Module): return x +class LayerScale2d(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + gamma = self.gamma.view(1, -1, 1, 1) + return x.mul_(gamma) if self.inplace else x * gamma + + +class UniversalInvertedResidual(nn.Module): + """ Universal Inverted Residual Block + + For MobileNetV4 - https://arxiv.org/abs/ + """ + + def __init__( + self, + in_chs, + out_chs, + dw_kernel_size_start: int = 0, + dw_kernel_size_mid: int = 3, + dw_kernel_size_end: int = 0, + stride=1, + dilation=1, + group_size=1, + pad_type='', + noskip=False, + exp_ratio=1.0, + act_layer=nn.ReLU, + dw_act_layer=None, + norm_layer=nn.BatchNorm2d, + se_layer=None, + conv_kwargs=None, + drop_path_rate=0., + layer_scale_init_value: Optional[float] = 1e-5, + ): + super(UniversalInvertedResidual, self).__init__() + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + dw_act_layer = dw_act_layer or act_layer + dw_norm_act_layer = get_norm_act_layer(norm_layer, dw_act_layer) + conv_kwargs = conv_kwargs or {} + self.has_skip = (in_chs == out_chs and stride == 1) and not noskip + + # FIXME dilation isn't right w/ extra ks > 1 convs + if dw_kernel_size_start: + self.conv_dw_start = create_conv2d( + in_chs, in_chs, dw_kernel_size_start, + dilation=dilation, # FIXME + depthwise=True, + padding=pad_type, + **conv_kwargs, + ) + self.norm_dw_start = dw_norm_act_layer(in_chs, apply_act=False) + else: + self.conv_dw_start = nn.Identity() + 
self.norm_dw_start = nn.Identity() + + # Point-wise expansion + mid_chs = make_divisible(in_chs * exp_ratio) + self.conv_pw = create_conv2d(in_chs, mid_chs, 1, padding=pad_type, **conv_kwargs) + self.norm_pw = norm_act_layer(mid_chs, inplace=True) + + # Depth-wise convolution + if dw_kernel_size_mid: + groups = num_groups(group_size, mid_chs) + self.conv_dw_mid = create_conv2d( + mid_chs, mid_chs, dw_kernel_size_mid, + stride=stride, + dilation=dilation, # FIXME + groups=groups, + padding=pad_type, + **conv_kwargs, + ) + self.norm_dw_mid = dw_norm_act_layer(mid_chs, inplace=True) + else: + self.conv_dw_mid = nn.Identity() + self.norm_dw_mid = nn.Identity() + + # Squeeze-and-excitation + self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity() + + # Point-wise linear projection + self.conv_pwl = create_conv2d(mid_chs, out_chs, 1, padding=pad_type, **conv_kwargs) + self.norm_pwl = norm_act_layer(out_chs, apply_act=False) + + if dw_kernel_size_end: + self.conv_dw_end = create_conv2d( + out_chs, out_chs, dw_kernel_size_end, + dilation=dilation, + depthwise=True, + padding=pad_type, + **conv_kwargs, + ) + self.norm_dw_end = dw_norm_act_layer(out_chs, apply_act=False) + else: + # dw_end rarely used so keeping it out of repr by not using None instead of nn.Identitty() + self.conv_dw_end = None + self.norm_dw_end = None + + if layer_scale_init_value is not None: + self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value) + else: + self.layer_scale = nn.Identity() + self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() + + def feature_info(self, location): + if location == 'expansion': # after SE, input to PWL + return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels) + else: # location == 'bottleneck', block output + return dict(module='', num_chs=self.conv_pwl.out_channels) + + def forward(self, x): + shortcut = x + x = self.conv_dw_start(x) + x = self.norm_dw_start(x) + x = self.conv_pw(x) + x = self.norm_pw(x) + x = self.conv_dw_mid(x) + x = self.norm_dw_mid(x) + x = self.se(x) + x = self.conv_pwl(x) + x = self.norm_pwl(x) + if self.conv_dw_end is not None: + x = self.conv_dw_end(x) + x = self.norm_dw_end(x) + x = self.layer_scale(x) + if self.has_skip: + x = self.drop_path(x) + shortcut + return x + + +class MobileAttention(nn.Module): + """ Mobile Attention Block + + For MobileNetV4 - https://arxiv.org/abs/ + """ + def __init__( + self, + in_chs, + out_chs, + stride=1, + dw_kernel_size=3, + dilation=1, + group_size=1, + pad_type='', + num_heads: int = 8, + key_dim: int = 64, + value_dim: int = 64, + use_multi_query: bool = False, + query_strides: int = (1, 1), + kv_stride: int = 1, + cpe_dw_kernel_size=3, + noskip=False, + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + drop_path_rate=0., + attn_drop=0.0, + proj_drop=0.0, + layer_scale_init_value: Optional[float] = 1e-5, + use_bias=False, + use_cpe=False, + ): + super(MobileAttention, self).__init__() + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip + self.query_strides = to_2tuple(query_strides) + self.kv_stride = kv_stride + self.has_query_stride = any([s > 1 for s in self.query_strides]) + + # This CPE is different than the one suggested in the original paper. + # https://arxiv.org/abs/2102.10882 + # 1. Rather than adding one CPE before the attention blocks, we add a CPE + # into every attention block. + # 2. We replace the expensive Conv2D by a Seperable DW Conv. 
+ if use_cpe: + self.conv_cpe_dw = create_conv2d( + in_chs, in_chs, + kernel_size=cpe_dw_kernel_size, + dilation=dilation, + depthwise=True, + bias=True, + ) + else: + self.conv_cpe_dw = None + + self.norm = norm_act_layer(in_chs, apply_act=False) + + if num_heads is None: + assert in_chs % key_dim == 0 + num_heads = in_chs // key_dim + + if use_multi_query: + #if self.has_query_stride or self.kv_stride > 1: + self.attn = ( + MultiQueryAttention2d( + in_chs, + dim_out=out_chs, + num_heads=num_heads, + key_dim=key_dim, + value_dim=value_dim, + query_strides=query_strides, + kv_stride=kv_stride, + dilation=dilation, + padding=pad_type, + dw_kernel_size=dw_kernel_size, + attn_drop=attn_drop, + proj_drop=proj_drop, + #bias=use_bias, # why not here if used w/ mhsa? + ) + ) + # else: + # self.attn = MultiQueryAttentionV2( + # in_chs, + # dim_out=out_chs, + # num_heads=num_heads, + # key_dim=key_dim, + # value_dim=value_dim, + # attn_drop=attn_drop, + # proj_drop=proj_drop, + # ) + else: + self.attn = Attention2d( + in_chs, + dim_out=out_chs, + num_heads=num_heads, + attn_drop=attn_drop, + proj_drop=proj_drop, + bias=use_bias, + ) + + if layer_scale_init_value is not None: + self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value) + else: + self.layer_scale = nn.Identity() + + self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() + + + def feature_info(self, location): + if location == 'expansion': # after SE, input to PW + return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels) + else: # location == 'bottleneck', block output + return dict(module='', num_chs=self.conv_pw.out_channels) + + def forward(self, x): + if self.conv_cpe_dw is not None: + x_cpe = self.conv_cpe_dw(x) + x = x + x_cpe + + shortcut = x + x = self.norm(x) + x = self.attn(x) + x = self.layer_scale(x) + if self.has_skip: + x = self.drop_path(x) + shortcut + + return x + + class CondConvResidual(InvertedResidual): """ Inverted residual block w/ CondConv routing""" @@ -296,13 +568,24 @@ class CondConvResidual(InvertedResidual): self.num_experts = num_experts conv_kwargs = dict(num_experts=self.num_experts) - super(CondConvResidual, self).__init__( - in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, group_size=group_size, - pad_type=pad_type, act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size, - pw_kernel_size=pw_kernel_size, se_layer=se_layer, norm_layer=norm_layer, conv_kwargs=conv_kwargs, - drop_path_rate=drop_path_rate) - + in_chs, + out_chs, + dw_kernel_size=dw_kernel_size, + stride=stride, + dilation=dilation, + group_size=group_size, + pad_type=pad_type, + act_layer=act_layer, + noskip=noskip, + exp_ratio=exp_ratio, + exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, + se_layer=se_layer, + norm_layer=norm_layer, + conv_kwargs=conv_kwargs, + drop_path_rate=drop_path_rate, + ) self.routing_fn = nn.Linear(in_chs, self.num_experts) def forward(self, x): @@ -362,7 +645,8 @@ class EdgeResidual(nn.Module): # Expansion convolution self.conv_exp = create_conv2d( - in_chs, mid_chs, exp_kernel_size, stride=stride, dilation=dilation, groups=groups, padding=pad_type) + in_chs, mid_chs, exp_kernel_size, + stride=stride, dilation=dilation, groups=groups, padding=pad_type) self.bn1 = norm_act_layer(mid_chs, inplace=True) # Squeeze-and-excitation diff --git a/timm/models/_efficientnet_builder.py b/timm/models/_efficientnet_builder.py index aedd8b39..4cbd6342 100644 --- 
a/timm/models/_efficientnet_builder.py +++ b/timm/models/_efficientnet_builder.py @@ -139,11 +139,10 @@ def _decode_block_str(block_str): # if act_layer is None, the model default (passed to model init) will be used act_layer = options['n'] if 'n' in options else None - exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 - pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 + start_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 + end_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 force_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def num_repeat = int(options['r']) - s2d = int(options['d']) if 'd' in options else 0 # each type of block has different valid arguments, fill accordingly block_args = dict( @@ -155,31 +154,31 @@ def _decode_block_str(block_str): if block_type == 'ir': block_args.update(dict( dw_kernel_size=_parse_ksize(options['k']), - exp_kernel_size=exp_kernel_size, - pw_kernel_size=pw_kernel_size, + exp_kernel_size=start_kernel_size, + pw_kernel_size=end_kernel_size, exp_ratio=float(options['e']), - se_ratio=float(options['se']) if 'se' in options else 0., + se_ratio=float(options.get('se', 0.)), noskip=skip is False, - s2d=s2d > 0, + s2d=int(options.get('d', 0)) > 0, )) if 'cc' in options: block_args['num_experts'] = int(options['cc']) elif block_type == 'ds' or block_type == 'dsa': block_args.update(dict( dw_kernel_size=_parse_ksize(options['k']), - pw_kernel_size=pw_kernel_size, - se_ratio=float(options['se']) if 'se' in options else 0., + pw_kernel_size=end_kernel_size, + se_ratio=float(options.get('se', 0.)), pw_act=block_type == 'dsa', noskip=block_type == 'dsa' or skip is False, - s2d=s2d > 0, + s2d=int(options.get('d', 0)) > 0, )) elif block_type == 'er': block_args.update(dict( exp_kernel_size=_parse_ksize(options['k']), - pw_kernel_size=pw_kernel_size, + pw_kernel_size=end_kernel_size, exp_ratio=float(options['e']), force_in_chs=force_in_chs, - se_ratio=float(options['se']) if 'se' in options else 0., + se_ratio=float(options.get('se', 0.)), noskip=skip is False, )) elif block_type == 'cn': @@ -187,6 +186,38 @@ def _decode_block_str(block_str): kernel_size=int(options['k']), skip=skip is True, )) + elif block_type == 'uir': + # override exp / proj kernels for start/end in uir block + start_kernel_size = _parse_ksize(options['a']) if 'a' in options else 0 + end_kernel_size = _parse_ksize(options['p']) if 'p' in options else 0 + block_args.update(dict( + dw_kernel_size_start=start_kernel_size, # overload exp ks arg for dw start + dw_kernel_size_mid=_parse_ksize(options['k']), + dw_kernel_size_end=end_kernel_size, # overload pw ks arg for dw end + exp_ratio=float(options['e']), + se_ratio=float(options.get('se', 0.)), + noskip=skip is False, + )) + elif block_type == 'mha': + kv_dim = int(options['d']) + block_args.update(dict( + dw_kernel_size=_parse_ksize(options['k']), + num_heads=int(options['h']), + key_dim=kv_dim, + value_dim=kv_dim, + kv_stride=int(options.get('v', 1)), + noskip=skip is False, + )) + elif block_type == 'mqa': + kv_dim = int(options['d']) + block_args.update(dict( + dw_kernel_size=_parse_ksize(options['k']), + num_heads=int(options['h']), + key_dim=kv_dim, + value_dim=kv_dim, + kv_stride=int(options.get('v', 1)), + noskip=skip is False, + )) else: assert False, 'Unknown block type (%s)' % block_type if 'gs' in options: @@ -331,10 +362,9 @@ class EfficientNetBuilder: ba['in_chs'] = self.in_chs ba['out_chs'] = 
self.round_chs_fn(ba['out_chs']) s2d = ba.get('s2d', 0) - if s2d: + if s2d > 0: + # adjust while space2depth active ba['out_chs'] *= 4 - if s2d == 1: - ba['dw_kernel_size'] = (ba['dw_kernel_size'] + 1) // 2 if 'force_in_chs' in ba and ba['force_in_chs']: # NOTE this is a hack to work around mismatch in TF EdgeEffNet impl ba['force_in_chs'] = self.round_chs_fn(ba['force_in_chs']) @@ -344,19 +374,19 @@ class EfficientNetBuilder: assert ba['act_layer'] is not None ba['norm_layer'] = self.norm_layer ba['drop_path_rate'] = drop_path_rate - if bt != 'cn': - se_ratio = ba.pop('se_ratio') - if se_ratio and self.se_layer is not None: - if not self.se_from_exp: - # adjust se_ratio by expansion ratio if calculating se channels from block input - se_ratio /= ba.get('exp_ratio', 1.0) - # adjust space2depth - if s2d == 1: - se_ratio /= 4 - if self.se_has_ratio: - ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio) - else: - ba['se_layer'] = self.se_layer + + se_ratio = ba.pop('se_ratio', None) + if se_ratio and self.se_layer is not None: + if not self.se_from_exp: + # adjust se_ratio by expansion ratio if calculating se channels from block input + se_ratio /= ba.get('exp_ratio', 1.0) + if s2d == 1: + # adjust for start of space2depth + se_ratio /= 4 + if self.se_has_ratio: + ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio) + else: + ba['se_layer'] = self.se_layer if bt == 'ir': _log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose) @@ -370,8 +400,17 @@ class EfficientNetBuilder: elif bt == 'cn': _log_info_if(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose) block = ConvBnAct(**ba) + elif bt == 'uir': + _log_info_if(' UniversalInvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose) + block = UniversalInvertedResidual(**ba) + elif bt == 'mqa': + _log_info_if(' MobileMultiQueryAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose) + block = MobileAttention(**ba, use_multi_query=True) + elif bt == 'mha': + _log_info_if(' MobileMultiHeadAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose) + block = MobileAttention(**ba) else: - assert False, 'Uknkown block type (%s) while building model.' % bt + assert False, 'Unknown block type (%s) while building model.' % bt self.in_chs = ba['out_chs'] # update in_chs for arg of next block return block @@ -420,12 +459,10 @@ class EfficientNetBuilder: if space2depth > 0: if space2depth == 2 and block_args['stride'] == 2: - space2depth = 0 block_args['stride'] = 1 # to end s2d region, need to correct expansion and se ratio relative to input - # FIXME unify with _make_block logic? this is rather meh block_args['exp_ratio'] /= 4 - #block_args['se_ratio'] /= 4 + space2depth = 0 else: block_args['s2d'] = space2depth diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index a9e3a1a8..0846e191 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -622,38 +622,215 @@ def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool = return model -def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs): - """ LCNet - Essentially a MobileNet-V3 crossed with a MobileNet-V1 - Paper: `PP-LCNet: A Lightweight CPU Convolutional Neural Network` - https://arxiv.org/abs/2109.15099 + +def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3: + """Creates a MobileNet-V4 model. + + Ref impl: ? 
+ Paper: https://arxiv.org/abs/1905.02244 Args: channel_multiplier: multiplier to number of channels per layer. """ - arch_def = [ - # stage 0, 112x112 in - ['dsa_r1_k3_s1_c32'], - # stage 1, 112x112 in - ['dsa_r2_k3_s2_c64'], - # stage 2, 56x56 in - ['dsa_r2_k3_s2_c128'], - # stage 3, 28x28 in - ['dsa_r1_k3_s2_c256', 'dsa_r1_k5_s1_c256'], - # stage 4, 14x14in - ['dsa_r4_k5_s1_c256'], - # stage 5, 14x14in - ['dsa_r2_k5_s2_c512_se0.25'], - # 7x7 - ] + if 'hybrid' in variant: + if 'medium' in variant: + stem_size = 32 + num_features = 1280 + act_layer = resolve_act_layer(kwargs, 'relu') + arch_def = [ + # stage 0, 112x112 in + ['er_r1_k3_s2_e4_c48'], + # stage 1, 56x56 in + ['uir_r1_a3_k5_s2_e4_c80', 'uir_r1_a3_k3_s1_e2_c80'], + # stage 2, 28x28 in + [ + 'uir_r1_a3_k5_s2_e6_c160', + 'uir_r1_a0_k0_s1_e2_c160', + 'uir_r1_a3_k3_s1_e4_c160', + 'uir_r1_a3_k5_s1_e4_c160', + 'mqa_r1_k3_h4_s1_v2_d64_c160', + 'uir_r1_a3_k3_s1_e4_c160', + 'mqa_r1_k3_h4_s1_v2_d64_c160', + 'uir_r1_a3_k0_s1_e4_c160', # convnext + 'mqa_r1_k3_h4_s1_v2_d64_c160', + 'uir_r1_a3_k3_s1_e4_c160', + 'mqa_r1_k3_h4_s1_v2_d64_c160', + 'uir_r1_a3_k0_s1_e4_c160', # convnext + ], + # stage 3, 14x14in + [ + 'uir_r1_a5_k5_s2_e6_c256', + 'uir_r1_a5_k5_s1_e4_c256', + 'uir_r2_a3_k5_s1_e4_c256', + 'uir_r1_a0_k0_s1_e2_c256', + 'uir_r1_a3_k5_s1_e2_c256', + 'uir_r1_a0_k0_s1_e2_c256', + 'uir_r1_a0_k0_s1_e4_c256', + 'mqa_r1_k3_h4_s1_d64_c256', + 'uir_r1_a3_k0_s1_e4_c256', # convnext + 'mqa_r1_k3_h4_s1_d64_c256', + 'uir_r1_a5_k5_s1_e4_c256', + 'mqa_r1_k3_h4_s1_d64_c256', + 'uir_r1_a5_k0_s1_e4_c256', # convnext4 + 'mqa_r1_k3_h4_s1_d64_c256', + 'uir_r1_a5_k0_s1_e4_c256', # convnext4 + ], + # stage 4, 7x7 in + ['cn_r1_k1_s1_c960'], + ] + elif 'large' in variant: + stem_size = 24 + num_features = 1280 + act_layer = resolve_act_layer(kwargs, 'gelu') + arch_def = [ + # stage 0, 112x112 in + ['er_r1_k3_s2_e4_c48'], + # stage 1, 56x56 in + ['uir_r1_a3_k5_s2_e4_c96', 'uir_r1_a3_k3_s1_e4_c96'], + # stage 2, 28x28 in + [ + 'uir_r1_a3_k5_s2_e4_c192', + 'uir_r3_a3_k3_s1_e4_c192', + 'uir_r1_a3_k5_s1_e4_c192', + 'uir_r2_a5_k3_s1_e4_c192', + 'mqa_r1_k3_h8_s1_v2_d48_c192', + 'uir_r1_a5_k3_s1_e4_c192', + 'mqa_r1_k3_h8_s1_v2_d48_c192', + 'uir_r1_a5_k3_s1_e4_c192', + 'mqa_r1_k3_h8_s1_v2_d48_c192', + 'uir_r1_a5_k3_s1_e4_c192', + 'mqa_r1_k3_h8_s1_v2_d48_c192', + 'uir_r1_a3_k0_s1_e4_c192', # convnext + ], + # stage 3, 14x14in + [ + 'uir_r4_a5_k5_s2_e4_c512', + 'uir_r1_a5_k0_s1_e4_c512', # convnext + 'uir_r1_a5_k3_s1_e4_c512', + 'uir_r2_a5_k0_s1_e4_c512', # convnext + 'uir_r1_a5_k3_s1_e4_c512', + 'uir_r1_a5_k5_s1_e4_c512', + 'mqa_r1_k3_h8_s1_d64_c512', + 'uir_r3_a5_k0_s1_e4_c512', # convnext + 'mqa_r1_k3_h8_s1_d64_c512', + 'uir_r3_a5_k0_s1_e4_c512', # convnext + 'mqa_r1_k3_h8_s1_d64_c512', + 'uir_r3_a5_k0_s1_e4_c512', # convnext + 'mqa_r1_k3_h8_s1_d64_c512', + 'uir_r3_a5_k0_s1_e4_c512', # convnext + ], + # stage 4, 7x7 in + ['cn_r1_k1_s1_c960'], + ] + else: + assert False, f'Unknown variant {variant}.' 
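To make the hybrid stage lists above easier to read: each 'mqa'/'mha' entry is decoded by the new branches added to _decode_block_str earlier in this patch. A rough illustration of one such string, assuming the pre-existing r/s/c conventions for repeats, stride, and output channels (the kwarg names follow the 'mqa' decode branch):

    block_str = 'mqa_r1_k3_h4_s1_v2_d64_c160'
    decoded = dict(
        block_type='mqa',   # -> MobileAttention(..., use_multi_query=True)
        num_repeat=1,       # r1
        dw_kernel_size=3,   # k3: kernel of the depthwise KV downsample conv
        num_heads=4,        # h4
        stride=1,           # s1
        kv_stride=2,        # v2: keys/values spatially downsampled 2x
        key_dim=64,         # d64 (value_dim gets the same value)
        value_dim=64,
        out_chs=160,        # c160
    )
    print(decoded)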
+ else: + if 'small' in variant: + stem_size = 32 + num_features = 1280 + act_layer = resolve_act_layer(kwargs, 'relu') + arch_def = [ + # stage 0, 112x112 in + ['cn_r1_k3_s2_e1_c32', 'cn_r1_k1_s1_e1_c32'], + # stage 1, 56x56 in + ['cn_r1_k3_s2_e1_c96', 'cn_r1_k1_s1_e1_c64'], + # stage 2, 28x28 in + [ + 'uir_r1_a5_k5_s2_e3_c96', # start dw + 'uir_r4_a0_k3_s1_e2_c96', # ir + 'uir_r1_a3_k0_s1_e4_c96', # convnext + ], + # stage 3, 14x14 in + [ + 'uir_r1_a3_k3_s2_e6_c128', # start dw + 'uir_r1_a5_k5_s1_e4_c128', # start dw + 'uir_r1_a0_k5_s1_e4_c128', # ir + 'uir_r1_a0_k5_s1_e3_c128', # ir + 'uir_r2_a0_k5_s1_e4_c128', # ir + ], + # stage 4, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + elif 'medium' in variant: + stem_size = 32 + num_features = 1280 + act_layer = resolve_act_layer(kwargs, 'relu') + arch_def = [ + # stage 0, 112x112 in + ['er_r1_k3_s2_e4_c48'], + # stage 1, 56x56 in + ['uir_r1_a3_k5_s2_e4_c80', 'uir_r1_a3_k3_s1_e2_c80'], + # stage 2, 28x28 in + [ + 'uir_r1_a5_k3_s2_e6_c160', + 'uir_r2_a3_k3_s1_e4_c160', + 'uir_r1_a3_k3_s1_e4_c160', + 'uir_r1_a3_k3_s1_e4_c160', + 'uir_r1_a3_k0_s1_e4_c160', # convnext + 'uir_r2_a0_k0_s1_e2_c160', + 'uir_r1_a3_k0_s1_e4_c160', # convnext + ], + # stage 3, 14x14in + [ + 'uir_r1_a5_k5_s2_e6_c256', + 'uir_r1_a5_k5_s1_e4_c256', + 'uir_r2_a3_k5_s1_e4_c256', + 'uir_r1_a0_k0_s1_e4_c256', + 'uir_r1_a3_k0_s1_e4_c256', # convnext + 'uir_r1_a3_k0_s1_e4_c256', # convnext + 'uir_r1_a3_k5_s1_e2_c256', + 'uir_r1_a5_k5_s1_e4_c256', + 'uir_r2_a0_k0_s1_e4_c256', + 'uir_r1_a5_k0_s1_e2_c256', # convnext + ], + # stage 4, 7x7 in + ['cn_r1_k1_s1_c960'], + ] + elif 'large' in variant: + stem_size = 24 + num_features = 1280 + act_layer = resolve_act_layer(kwargs, 'relu') + arch_def = [ + # stage 0, 112x112 in + ['er_r1_k3_s2_e4_c48'], + # stage 1, 56x56 in + ['uir_r1_a3_k5_s2_e4_c96', 'uir_r1_a3_k3_s1_e4_c96'], + # stage 2, 28x28 in + [ + 'uir_r1_a3_k5_s2_e4_c192', + 'uir_r3_a3_k3_s1_e4_c192', + 'uir_r1_a3_k5_s1_e4_c192', + 'uir_r5_a5_k3_s1_e4_c192', + 'uir_r1_a3_k0_s1_e4_c192', # convnext + ], + # stage 3, 14x14in + [ + 'uir_r4_a5_k5_s2_e4_c512', + 'uir_r1_a5_k0_s1_e4_c512', # convnext + 'uir_r1_a5_k3_s1_e4_c512', + 'uir_r2_a5_k0_s1_e4_c512', # convnext + 'uir_r1_a5_k3_s1_e4_c512', + 'uir_r1_a5_k5_s1_e4_c512', + 'uir_r3_a5_k0_s1_e4_c512', # convnext + + ], + # stage 4, 7x7 in + ['cn_r1_k1_s1_c960'], + ] + else: + assert False, f'Unknown variant {variant}.' 
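Similarly, the 'uir' strings map onto UniversalInvertedResidual kwargs via the 'uir' decode branch: 'a' sets the leading depthwise kernel, 'k' the mid depthwise kernel, 'e' the expansion ratio, and a 0 disables that depthwise conv, which is how the ExtraDW / IR / ConvNeXt / FFN variants noted in the comments are expressed. A few examples from the stage lists above (interpretations follow those comments; this is illustration, not part of the decode code):

    uir_examples = {
        'uir_r1_a3_k5_s2_e4_c80':  'ExtraDW: 3x3 dw before expand, 5x5 dw after, stride 2, 4x expand, 80 chs',
        'uir_r4_a0_k3_s1_e2_c96':  'IR: no leading dw, 3x3 mid dw only (classic inverted residual), 4 repeats',
        'uir_r1_a3_k0_s1_e4_c96':  'ConvNeXt-like: leading 3x3 dw, no mid dw, 4x expand',
        'uir_r1_a0_k0_s1_e2_c160': 'FFN: both dw convs disabled, just the 1x1 expand/project pair',
    }
    for s, meaning in uir_examples.items():
        print(f'{s:26s} {meaning}')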
+ + se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels) model_kwargs = dict( block_args=decode_arch_def(arch_def), - stem_size=16, + num_features=num_features, + stem_size=stem_size, + fix_stem=channel_multiplier < 0.75, round_chs_fn=partial(round_channels, multiplier=channel_multiplier), norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), - act_layer=resolve_act_layer(kwargs, 'hard_swish'), - se_layer=partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU), - num_features=1280, + act_layer=act_layer, + se_layer=se_layer, **kwargs, ) model = _create_mnv3(variant, pretrained, **model_kwargs) @@ -688,6 +865,9 @@ default_cfgs = generate_default_cfgs({ origin_url='https://github.com/Alibaba-MIIL/ImageNet21K', paper_ids='arXiv:2104.10972v4', interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), num_classes=11221), + 'mobilenetv3_large_150.untrained': _cfg( + interpolation='bicubic'), + 'mobilenetv3_small_050.lamb_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth', @@ -762,6 +942,32 @@ default_cfgs = generate_default_cfgs({ interpolation='bicubic', ), "lcnet_150.untrained": _cfg(), + + 'mobilenetv4_conv_small': _cfg( + # hf_hub_id='timm/', + interpolation='bicubic'), + 'mobilenetv4_conv_medium': _cfg( + #hf_hub_id='timm/', + interpolation='bicubic'), + 'mobilenetv4_conv_large': _cfg( + # hf_hub_id='timm/', + interpolation='bicubic'), + + 'mobilenetv4_hybrid_small': _cfg( + # hf_hub_id='timm/', + interpolation='bicubic'), + 'mobilenetv4_hybrid_medium': _cfg( + # hf_hub_id='timm/', + interpolation='bicubic'), + 'mobilenetv4_hybrid_large': _cfg( + # hf_hub_id='timm/', + interpolation='bicubic'), + 'mobilenetv4_hybrid_medium_075': _cfg( + # hf_hub_id='timm/', + interpolation='bicubic'), + 'mobilenetv4_hybrid_medium_150': _cfg( + # hf_hub_id='timm/', + interpolation='bicubic'), }) @@ -779,6 +985,13 @@ def mobilenetv3_large_100(pretrained: bool = False, **kwargs) -> MobileNetV3: return model +@register_model +def mobilenetv3_large_150(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.5, pretrained=pretrained, **kwargs) + return model + + @register_model def mobilenetv3_small_050(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ @@ -918,6 +1131,56 @@ def lcnet_150(pretrained: bool = False, **kwargs) -> MobileNetV3: return model +@register_model +def mobilenetv4_conv_small(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V4 """ + model = _gen_mobilenet_v4('mobilenetv4_conv_small', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv4_conv_medium(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V4 """ + model = _gen_mobilenet_v4('mobilenetv4_conv_medium', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv4_conv_large(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V4 """ + model = _gen_mobilenet_v4('mobilenetv4_conv_large', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv4_hybrid_medium_075(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V4 Hybrid """ + model = _gen_mobilenet_v4('mobilenetv4_hybrid_medium_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def 
mobilenetv4_hybrid_medium(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V4 Hybrid """ + model = _gen_mobilenet_v4('mobilenetv4_hybrid_medium', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv4_hybrid_medium_150(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V4 Hybrid """ + model = _gen_mobilenet_v4('mobilenetv4_hybrid_medium_150', 1.5, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv4_hybrid_large(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V4 Hybrid""" + model = _gen_mobilenet_v4('mobilenetv4_hybrid_large', 1.0, pretrained=pretrained, **kwargs) + return model + + + register_model_deprecations(__name__, { 'mobilenetv3_large_100_miil': 'mobilenetv3_large_100.miil_in21k_ft_in1k', 'mobilenetv3_large_100_miil_in21k': 'mobilenetv3_large_100.miil_in21k', From 2a1a6b12366fa0c975c038ca4829c8200f026ec0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 23 May 2024 11:06:32 -0700 Subject: [PATCH 07/27] Adding missing attention2d.py --- timm/layers/attention2d.py | 354 +++++++++++++++++++++++++++++++++++++ 1 file changed, 354 insertions(+) create mode 100644 timm/layers/attention2d.py diff --git a/timm/layers/attention2d.py b/timm/layers/attention2d.py new file mode 100644 index 00000000..3213a9f8 --- /dev/null +++ b/timm/layers/attention2d.py @@ -0,0 +1,354 @@ +from typing import List, Optional, Union + +import torch +from torch import nn as nn +from torch.nn import functional as F + +from .config import use_fused_attn +from .create_conv2d import create_conv2d +from .helpers import to_2tuple +from .pool2d_same import create_pool2d + + +class MultiQueryAttentionV2(nn.Module): + """Multi Query Attention. + + Fast Transformer Decoding: One Write-Head is All You Need + https://arxiv.org/pdf/1911.02150.pdf + + This is an acceletor optimized version - removing multiple unneccessary + tensor transpose by re-arranging indices according to the following rules: 1) + contracted indices are at the end, 2) other indices have the same order in the + input and output tensores. + + Compared to V1, this gives 3x speed up. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + num_heads: int = 8, + key_dim: int = 64, + value_dim: int = 64, + attn_drop: float = 0., + proj_drop: float = 0., + ): + """Initializer.""" + super().__init__() + dim_out = dim_out or dim + self.num_heads = num_heads + self.key_dim = key_dim + self.value_dim = value_dim + self.scale = key_dim ** -0.5 + + self.query_proj = nn.Parameter(torch.randn([self.num_heads, self.key_dim, dim])) + self.key_proj = nn.Parameter(torch.randn([dim, self.key_dim])) + self.value_proj = nn.Parameter(torch.randn([dim, self.value_dim])) + self.attn_drop = nn.Dropout(attn_drop) + self.out_proj = nn.Parameter(torch.randn([dim_out, self.num_heads, self.value_dim])) + self.proj_drop = nn.Dropout(proj_drop) + + def _reshape_input(self, t): + """Reshapes a tensor to three dimensions, keeping the first and last.""" + s = t.shape + # Propagate the shape statically where possible. 
+ #num = t.shape[1:-1].numel() + #return t.reshape(s[0], num, s[-1]) + return t.reshape(s[0], s[1], -1).transpose(1, 2) + + def forward(self, x, m: Optional[torch.Tensor] = None): + """Run layer computation.""" + s = x.shape + m = m or x + + reshaped_x = self._reshape_input(x) + reshaped_m = self._reshape_input(m) + + q = torch.einsum('bnd,hkd->bnhk', reshaped_x, self.query_proj) + k = torch.einsum('bmd,dk->bmk', reshaped_m, self.key_proj) + + attn = torch.einsum('bnhk,bmk->bnhm', q, k) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + v = torch.einsum('bmd,dv->bmv', reshaped_m, self.value_proj) + o = torch.einsum('bnhm,bmv->bnhv', attn, v) + result = torch.einsum('bnhv,dhv->bnd', o, self.out_proj) + result = self.proj_drop(result) + return result.reshape(s) + + +class MultiQueryAttention2d(nn.Module): + """Multi Query Attention with spatial downsampling. + + 3 parameters are introduced for the spatial downsampling: + 1. kv_stride: downsampling factor on Key and Values only. + 2. query_strides: horizontal & vertical strides on Query only. + + This is an optimized version. + 1. Projections in Attention is explict written out as 1x1 Conv2D. + 2. Additional reshapes are introduced to bring a up to 3x speed up. + """ + fused_attn: torch.jit.Final[bool] + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + num_heads: int = 8, + key_dim: Optional[int] = None, + value_dim: Optional[int] = None, + query_strides: int = 1, + kv_stride: int = 1, + dw_kernel_size: int = 3, + dilation: int = 1, + padding: Union[str, int, List[int]] = '', + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = nn.BatchNorm2d, + ): + """Initializer. + + Args: + num_heads: Number of attention heads. + key_dim: Size of the attention key dimension. + value_dim: Size of the attention value dimension. + query_strides: Vertical stride size for query only. + kv_stride: Key and value stride size. + dw_kernel_size: Spatial dimension of the depthwise kernel. 
+ """ + super().__init__() + dim_out = dim_out or dim + self.num_heads = num_heads + self.key_dim = key_dim or dim // num_heads + self.value_dim = value_dim or dim // num_heads + self.query_strides = to_2tuple(query_strides) + self.kv_stride = kv_stride + self.has_query_strides = any([s > 1 for s in self.query_strides]) + self.scale = self.key_dim ** -0.5 + self.fused_attn = use_fused_attn() + self.drop = attn_drop + + if self.has_query_strides: + # FIXME dilation + self.query_down_pool = create_pool2d( + 'avg', + kernel_size=self.query_strides, + padding=padding, + ) + self.query_down_norm = norm_layer(dim) + else: + self.query_down_pool = nn.Identity() + self.query_down_norm = nn.Identity() + + self.query_proj = create_conv2d( + dim, + self.num_heads * self.key_dim, + kernel_size=1, + ) + + if kv_stride > 1: + self.key_down_conv = create_conv2d( + dim, + dim, + kernel_size=dw_kernel_size, + stride=kv_stride, + dilation=dilation, + padding=padding, + depthwise=True, + ) + self.key_down_norm = norm_layer(dim) + else: + self.key_down_conv = nn.Identity() + self.key_down_norm = nn.Identity() + + self.key_proj = create_conv2d( + dim, + self.key_dim, + kernel_size=1, + padding=padding, + ) + + if kv_stride > 1: + self.value_down_conv = create_conv2d( + dim, + dim, + kernel_size=dw_kernel_size, + stride=kv_stride, + dilation=dilation, + padding=padding, + depthwise=True, + ) + self.value_down_norm = norm_layer(dim) + else: + self.value_down_conv = nn.Identity() + self.value_down_norm = nn.Identity() + + self.value_proj = create_conv2d( + dim, + self.value_dim, + kernel_size=1, + ) + + self.attn_drop = nn.Dropout(attn_drop) + + if self.has_query_strides: + self.upsampling = nn.Upsample(self.query_strides, mode='bilinear', align_corners=False) + else: + self.upsampling = nn.Identity() + + self.out_proj = create_conv2d( + self.value_dim * self.num_heads, + dim_out, + kernel_size=1, + ) + + self.proj_drop = nn.Dropout(proj_drop) + self.einsum = False + + def _reshape_input(self, t): + """Reshapes a tensor to three dimensions, keeping the batch and channels.""" + s = t.shape + t = t.reshape(s[0], s[1], -1).transpose(1, 2) + if self.einsum: + return t + else: + return t.unsqueeze(1).contiguous() + + def _reshape_projected_query(self, t, num_heads, key_dim): + """Reshapes projected query: [b, n, n, h x k] -> [b, n x n, h, k].""" + s = t.shape + t = t.reshape(s[0], num_heads, key_dim, -1) + if self.einsum: + return t.permute(0, 3, 1, 2).contiguous() + else: + return t.transpose(-1, -2).contiguous() + + def _reshape_output(self, t, num_heads, h_px, w_px): + """Reshape output:[b, n x n x h, k] -> [b, n, n, hk].""" + s = t.shape + feat_dim = s[-1] * num_heads + if not self.einsum: + t = t.transpose(1, 2) + return t.reshape(s[0], h_px, w_px, feat_dim).permute(0, 3, 1, 2).contiguous() + + + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + """Run layer computation.""" + B, C, H, W = s = x.shape + + q = self.query_down_pool(x) + q = self.query_down_norm(q) + q = self.query_proj(q) + # desired q shape: [b, h, k, n x n] - [b, l, h, k] + q = self._reshape_projected_query(q, self.num_heads, self.key_dim) + + k = self.key_down_conv(x) + k = self.key_down_norm(k) + k = self.key_proj(k) + # output shape of k: [b, k, p], p = m x m + k = self._reshape_input(k) + + v = self.value_down_conv(x) + v = self.value_down_norm(v) + v = self.value_proj(v) + # output shape of v: [ b, p, k], p = m x m + v = self._reshape_input(v) + + # desired q shape: [b, n x n, h, k] + # desired k shape: [b, m x m, k] + # 
desired logits shape: [b, n x n, h, m x m] + if self.einsum: + attn = torch.einsum('blhk,bpk->blhp', q, k) * self.scale + if attn_mask is not None: + # NOTE: assumes mask is float and in correct shape + attn = attn + attn_mask + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + o = torch.einsum('blhp,bpk->blhk', attn, v) + else: + if self.fused_attn: + o = F.scaled_dot_product_attention( + q, k, v, + attn_mask=attn_mask, + dropout_p=self.attn_drop.p if self.training else 0 + ) + else: + q = q * self.scale + attn = q @ k.transpose(-1, -2) + if attn_mask is not None: + # NOTE: assumes mask is float and in correct shape + attn = attn + attn_mask + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + o = attn @ v + + # reshape o into [b, hk, n, n,] + o = self._reshape_output(o, self.num_heads, H // self.query_strides[0], W // self.query_strides[1]) + o = self.upsampling(o) + + x = self.out_proj(o) + x = self.proj_drop(x) + return x + + +class Attention2d(nn.Module): + fused_attn: torch.jit.Final[bool] + + """ multi-head attention for 2D NCHW tensors""" + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + num_heads: int = 32, + bias: bool = True, + expand_first: bool = False, + head_first: bool = False, + attn_drop: float = 0., + proj_drop: float = 0. + ): + super().__init__() + dim_out = dim_out or dim + dim_attn = dim_out if expand_first else dim + self.num_heads = num_heads + self.dim_head = dim_attn // num_heads + self.head_first = head_first + self.scale = num_heads ** -0.5 + self.fused_attn = use_fused_attn() + + self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + B, C, H, W = x.shape + + if self.head_first: + q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2) + else: + q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1) + + if self.fused_attn: + x = torch.nn.functional.scaled_dot_product_attention( + q.transpose(-1, -2).contiguous(), + k.transpose(-1, -2).contiguous(), + v.transpose(-1, -2).contiguous(), + attn_mask=attn_mask, + dropout_p=self.attn_drop.p if self.training else 0., + ).transpose(-1, -2).reshape(B, -1, H, W) + else: + q = q * self.scale + attn = q.transpose(-2, -1) @ k + if attn_mask is not None: + # NOTE: assumes mask is float and in correct shape + attn = attn + attn_mask + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) + + x = self.proj(x) + x = self.proj_drop(x) + return x From 70176a2dae0039d68d8eea3a0b9c8c4564af2e7e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 23 May 2024 11:43:05 -0700 Subject: [PATCH 08/27] torchscript typing fixes --- timm/layers/attention2d.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/timm/layers/attention2d.py b/timm/layers/attention2d.py index 3213a9f8..3d3f6d01 100644 --- a/timm/layers/attention2d.py +++ b/timm/layers/attention2d.py @@ -207,7 +207,7 @@ class MultiQueryAttention2d(nn.Module): self.proj_drop = nn.Dropout(proj_drop) self.einsum = False - def _reshape_input(self, t): + def _reshape_input(self, t: torch.Tensor): """Reshapes a tensor to three dimensions, keeping the batch and channels.""" s = t.shape t = t.reshape(s[0], s[1], -1).transpose(1, 2) @@ -216,7 +216,7 @@ class MultiQueryAttention2d(nn.Module): else: return 
t.unsqueeze(1).contiguous() - def _reshape_projected_query(self, t, num_heads, key_dim): + def _reshape_projected_query(self, t: torch.Tensor, num_heads: int, key_dim: int): """Reshapes projected query: [b, n, n, h x k] -> [b, n x n, h, k].""" s = t.shape t = t.reshape(s[0], num_heads, key_dim, -1) @@ -225,7 +225,7 @@ class MultiQueryAttention2d(nn.Module): else: return t.transpose(-1, -2).contiguous() - def _reshape_output(self, t, num_heads, h_px, w_px): + def _reshape_output(self, t: torch.Tensor, num_heads: int, h_px: int, w_px: int): """Reshape output:[b, n x n x h, k] -> [b, n, n, hk].""" s = t.shape feat_dim = s[-1] * num_heads @@ -233,8 +233,6 @@ class MultiQueryAttention2d(nn.Module): t = t.transpose(1, 2) return t.reshape(s[0], h_px, w_px, feat_dim).permute(0, 3, 1, 2).contiguous() - - def forward(self, x, attn_mask: Optional[torch.Tensor] = None): """Run layer computation.""" B, C, H, W = s = x.shape @@ -273,7 +271,7 @@ class MultiQueryAttention2d(nn.Module): o = F.scaled_dot_product_attention( q, k, v, attn_mask=attn_mask, - dropout_p=self.attn_drop.p if self.training else 0 + dropout_p=self.attn_drop.p if self.training else 0. ) else: q = q * self.scale From cb33956b2072c100f08743ecaaf2423547c269c6 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 23 May 2024 14:24:32 -0700 Subject: [PATCH 09/27] Fix some mistakes in mnv4 model defs --- timm/models/mobilenetv3.py | 37 ++++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index ad6d8a85..1d12f34b 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -640,7 +640,10 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: # stage 0, 112x112 in ['er_r1_k3_s2_e4_c48'], # stage 1, 56x56 in - ['uir_r1_a3_k5_s2_e4_c80', 'uir_r1_a3_k3_s1_e2_c80'], + [ + 'uir_r1_a3_k5_s2_e4_c80', + 'uir_r1_a3_k3_s1_e2_c80', + ], # stage 2, 28x28 in [ 'uir_r1_a3_k5_s2_e6_c160', @@ -685,7 +688,10 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: # stage 0, 112x112 in ['er_r1_k3_s2_e4_c48'], # stage 1, 56x56 in - ['uir_r1_a3_k5_s2_e4_c96', 'uir_r1_a3_k3_s1_e4_c96'], + [ + 'uir_r1_a3_k5_s2_e4_c96', + 'uir_r1_a3_k3_s1_e4_c96', + ], # stage 2, 28x28 in [ 'uir_r1_a3_k5_s2_e4_c192', @@ -710,13 +716,13 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: 'uir_r1_a5_k3_s1_e4_c512', 'uir_r1_a5_k5_s1_e4_c512', 'mqa_r1_k3_h8_s1_d64_c512', - 'uir_r3_a5_k0_s1_e4_c512', # convnext + 'uir_r1_a5_k0_s1_e4_c512', # convnext 'mqa_r1_k3_h8_s1_d64_c512', - 'uir_r3_a5_k0_s1_e4_c512', # convnext + 'uir_r1_a5_k0_s1_e4_c512', # convnext 'mqa_r1_k3_h8_s1_d64_c512', - 'uir_r3_a5_k0_s1_e4_c512', # convnext + 'uir_r1_a5_k0_s1_e4_c512', # convnext 'mqa_r1_k3_h8_s1_d64_c512', - 'uir_r3_a5_k0_s1_e4_c512', # convnext + 'uir_r1_a5_k0_s1_e4_c512', # convnext ], # stage 4, 7x7 in ['cn_r1_k1_s1_c960'], @@ -758,15 +764,18 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: # stage 0, 112x112 in ['er_r1_k3_s2_e4_c48'], # stage 1, 56x56 in - ['uir_r1_a3_k5_s2_e4_c80', 'uir_r1_a3_k3_s1_e2_c80'], + [ + 'uir_r1_a3_k5_s2_e4_c80', + 'uir_r1_a3_k3_s1_e2_c80', + ], # stage 2, 28x28 in [ - 'uir_r1_a5_k3_s2_e6_c160', + 'uir_r1_a3_k5_s2_e6_c160', 'uir_r2_a3_k3_s1_e4_c160', - 'uir_r1_a3_k3_s1_e4_c160', + 'uir_r1_a3_k5_s1_e4_c160', 'uir_r1_a3_k3_s1_e4_c160', 'uir_r1_a3_k0_s1_e4_c160', # convnext - 'uir_r2_a0_k0_s1_e2_c160', + 'uir_r1_a0_k0_s1_e2_c160', 
'uir_r1_a3_k0_s1_e4_c160', # convnext ], # stage 3, 14x14in @@ -776,7 +785,6 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: 'uir_r2_a3_k5_s1_e4_c256', 'uir_r1_a0_k0_s1_e4_c256', 'uir_r1_a3_k0_s1_e4_c256', # convnext - 'uir_r1_a3_k0_s1_e4_c256', # convnext 'uir_r1_a3_k5_s1_e2_c256', 'uir_r1_a5_k5_s1_e4_c256', 'uir_r2_a0_k0_s1_e4_c256', @@ -793,7 +801,10 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: # stage 0, 112x112 in ['er_r1_k3_s2_e4_c48'], # stage 1, 56x56 in - ['uir_r1_a3_k5_s2_e4_c96', 'uir_r1_a3_k3_s1_e4_c96'], + [ + 'uir_r1_a3_k5_s2_e4_c96', + 'uir_r1_a3_k3_s1_e4_c96', + ], # stage 2, 28x28 in [ 'uir_r1_a3_k5_s2_e4_c192', @@ -986,7 +997,7 @@ def mobilenetv3_large_100(pretrained: bool = False, **kwargs) -> MobileNetV3: @register_model def mobilenetv3_large_150(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ - model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.5, pretrained=pretrained, **kwargs) + model = _gen_mobilenet_v3('mobilenetv3_large_150', 1.5, pretrained=pretrained, **kwargs) return model From 0c6a69e7ef9d3e2ca82ac967e7a5bb132220b0b9 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 23 May 2024 15:54:05 -0700 Subject: [PATCH 10/27] Add comments to MNV4 model defs with block variants --- timm/models/mobilenetv3.py | 220 +++++++++++++++++++------------------ 1 file changed, 113 insertions(+), 107 deletions(-) diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 1d12f34b..7b07b8a1 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -638,47 +638,47 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: act_layer = resolve_act_layer(kwargs, 'relu') arch_def = [ # stage 0, 112x112 in - ['er_r1_k3_s2_e4_c48'], + ['er_r1_k3_s2_e4_c48'], # FusedIB (EdgeResidual) # stage 1, 56x56 in [ - 'uir_r1_a3_k5_s2_e4_c80', - 'uir_r1_a3_k3_s1_e2_c80', + 'uir_r1_a3_k5_s2_e4_c80', # ExtraDW + 'uir_r1_a3_k3_s1_e2_c80', # ExtraDW ], # stage 2, 28x28 in [ - 'uir_r1_a3_k5_s2_e6_c160', - 'uir_r1_a0_k0_s1_e2_c160', - 'uir_r1_a3_k3_s1_e4_c160', - 'uir_r1_a3_k5_s1_e4_c160', - 'mqa_r1_k3_h4_s1_v2_d64_c160', - 'uir_r1_a3_k3_s1_e4_c160', - 'mqa_r1_k3_h4_s1_v2_d64_c160', - 'uir_r1_a3_k0_s1_e4_c160', # convnext - 'mqa_r1_k3_h4_s1_v2_d64_c160', - 'uir_r1_a3_k3_s1_e4_c160', - 'mqa_r1_k3_h4_s1_v2_d64_c160', - 'uir_r1_a3_k0_s1_e4_c160', # convnext + 'uir_r1_a3_k5_s2_e6_c160', # ExtraDW + 'uir_r1_a0_k0_s1_e2_c160', # FFN + 'uir_r1_a3_k3_s1_e4_c160', # ExtraDW + 'uir_r1_a3_k5_s1_e4_c160', # ExtraDW + 'mqa_r1_k3_h4_s1_v2_d64_c160', # MQA w/ KV downsample + 'uir_r1_a3_k3_s1_e4_c160', # ExtraDW + 'mqa_r1_k3_h4_s1_v2_d64_c160', # MQA w/ KV downsample + 'uir_r1_a3_k0_s1_e4_c160', # ConvNeXt + 'mqa_r1_k3_h4_s1_v2_d64_c160', # MQA w/ KV downsample + 'uir_r1_a3_k3_s1_e4_c160', # ExtraDW + 'mqa_r1_k3_h4_s1_v2_d64_c160', # MQA w/ KV downsample + 'uir_r1_a3_k0_s1_e4_c160', # ConvNeXt ], # stage 3, 14x14in [ - 'uir_r1_a5_k5_s2_e6_c256', - 'uir_r1_a5_k5_s1_e4_c256', - 'uir_r2_a3_k5_s1_e4_c256', - 'uir_r1_a0_k0_s1_e2_c256', - 'uir_r1_a3_k5_s1_e2_c256', - 'uir_r1_a0_k0_s1_e2_c256', - 'uir_r1_a0_k0_s1_e4_c256', - 'mqa_r1_k3_h4_s1_d64_c256', - 'uir_r1_a3_k0_s1_e4_c256', # convnext - 'mqa_r1_k3_h4_s1_d64_c256', - 'uir_r1_a5_k5_s1_e4_c256', - 'mqa_r1_k3_h4_s1_d64_c256', - 'uir_r1_a5_k0_s1_e4_c256', # convnext4 - 'mqa_r1_k3_h4_s1_d64_c256', - 'uir_r1_a5_k0_s1_e4_c256', # convnext4 + 'uir_r1_a5_k5_s2_e6_c256', # ExtraDW + 'uir_r1_a5_k5_s1_e4_c256', # ExtraDW 
+ 'uir_r2_a3_k5_s1_e4_c256', # ExtraDW + 'uir_r1_a0_k0_s1_e2_c256', # FFN + 'uir_r1_a3_k5_s1_e2_c256', # ExtraDW + 'uir_r1_a0_k0_s1_e2_c256', # FFN + 'uir_r1_a0_k0_s1_e4_c256', # FFN + 'mqa_r1_k3_h4_s1_d64_c256', # MQA + 'uir_r1_a3_k0_s1_e4_c256', # ConvNeXt + 'mqa_r1_k3_h4_s1_d64_c256', # MQA + 'uir_r1_a5_k5_s1_e4_c256', # ExtraDW + 'mqa_r1_k3_h4_s1_d64_c256', # MQA + 'uir_r1_a5_k0_s1_e4_c256', # ConvNeXt + 'mqa_r1_k3_h4_s1_d64_c256', # MQA + 'uir_r1_a5_k0_s1_e4_c256', # ConvNeXt ], # stage 4, 7x7 in - ['cn_r1_k1_s1_c960'], + ['cn_r1_k1_s1_c960'], # Conv ] elif 'large' in variant: stem_size = 24 @@ -686,43 +686,43 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: act_layer = resolve_act_layer(kwargs, 'gelu') arch_def = [ # stage 0, 112x112 in - ['er_r1_k3_s2_e4_c48'], + ['er_r1_k3_s2_e4_c48'], # FusedIB (EdgeResidual) # stage 1, 56x56 in [ - 'uir_r1_a3_k5_s2_e4_c96', - 'uir_r1_a3_k3_s1_e4_c96', + 'uir_r1_a3_k5_s2_e4_c96', # ExtraDW + 'uir_r1_a3_k3_s1_e4_c96', # ExtraDW ], # stage 2, 28x28 in [ - 'uir_r1_a3_k5_s2_e4_c192', - 'uir_r3_a3_k3_s1_e4_c192', - 'uir_r1_a3_k5_s1_e4_c192', - 'uir_r2_a5_k3_s1_e4_c192', - 'mqa_r1_k3_h8_s1_v2_d48_c192', - 'uir_r1_a5_k3_s1_e4_c192', - 'mqa_r1_k3_h8_s1_v2_d48_c192', - 'uir_r1_a5_k3_s1_e4_c192', - 'mqa_r1_k3_h8_s1_v2_d48_c192', - 'uir_r1_a5_k3_s1_e4_c192', - 'mqa_r1_k3_h8_s1_v2_d48_c192', - 'uir_r1_a3_k0_s1_e4_c192', # convnext + 'uir_r1_a3_k5_s2_e4_c192', # ExtraDW + 'uir_r3_a3_k3_s1_e4_c192', # ExtraDW + 'uir_r1_a3_k5_s1_e4_c192', # ExtraDW + 'uir_r2_a5_k3_s1_e4_c192', # ExtraDW + 'mqa_r1_k3_h8_s1_v2_d48_c192', # MQA w/ KV downsample + 'uir_r1_a5_k3_s1_e4_c192', # ExtraDW + 'mqa_r1_k3_h8_s1_v2_d48_c192', # MQA w/ KV downsample + 'uir_r1_a5_k3_s1_e4_c192', # ExtraDW + 'mqa_r1_k3_h8_s1_v2_d48_c192', # MQA w/ KV downsample + 'uir_r1_a5_k3_s1_e4_c192', # ExtraDW + 'mqa_r1_k3_h8_s1_v2_d48_c192', # MQA w/ KV downsample + 'uir_r1_a3_k0_s1_e4_c192', # ConvNeXt ], # stage 3, 14x14in [ - 'uir_r4_a5_k5_s2_e4_c512', - 'uir_r1_a5_k0_s1_e4_c512', # convnext - 'uir_r1_a5_k3_s1_e4_c512', - 'uir_r2_a5_k0_s1_e4_c512', # convnext - 'uir_r1_a5_k3_s1_e4_c512', - 'uir_r1_a5_k5_s1_e4_c512', - 'mqa_r1_k3_h8_s1_d64_c512', - 'uir_r1_a5_k0_s1_e4_c512', # convnext - 'mqa_r1_k3_h8_s1_d64_c512', - 'uir_r1_a5_k0_s1_e4_c512', # convnext - 'mqa_r1_k3_h8_s1_d64_c512', - 'uir_r1_a5_k0_s1_e4_c512', # convnext - 'mqa_r1_k3_h8_s1_d64_c512', - 'uir_r1_a5_k0_s1_e4_c512', # convnext + 'uir_r4_a5_k5_s2_e4_c512', # ExtraDW + 'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt + 'uir_r1_a5_k3_s1_e4_c512', # ExtraDW + 'uir_r2_a5_k0_s1_e4_c512', # ConvNeXt + 'uir_r1_a5_k3_s1_e4_c512', # ExtraDW + 'uir_r1_a5_k5_s1_e4_c512', # ExtraDW + 'mqa_r1_k3_h8_s1_d64_c512', # MQA + 'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt + 'mqa_r1_k3_h8_s1_d64_c512', # MQA + 'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt + 'mqa_r1_k3_h8_s1_d64_c512', # MQA + 'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt + 'mqa_r1_k3_h8_s1_d64_c512', # MQA + 'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt ], # stage 4, 7x7 in ['cn_r1_k1_s1_c960'], @@ -736,25 +736,31 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: act_layer = resolve_act_layer(kwargs, 'relu') arch_def = [ # stage 0, 112x112 in - ['cn_r1_k3_s2_e1_c32', 'cn_r1_k1_s1_e1_c32'], + [ + 'cn_r1_k3_s2_e1_c32', # Conv + 'cn_r1_k1_s1_e1_c32', # Conv + ], # stage 1, 56x56 in - ['cn_r1_k3_s2_e1_c96', 'cn_r1_k1_s1_e1_c64'], + [ + 'cn_r1_k3_s2_e1_c96', # Conv + 'cn_r1_k1_s1_e1_c64', # Conv + ], # stage 2, 28x28 in [ - 'uir_r1_a5_k5_s2_e3_c96', # start dw - 
'uir_r4_a0_k3_s1_e2_c96', # ir - 'uir_r1_a3_k0_s1_e4_c96', # convnext + 'uir_r1_a5_k5_s2_e3_c96', # ExtraDW + 'uir_r4_a0_k3_s1_e2_c96', # IR + 'uir_r1_a3_k0_s1_e4_c96', # ConvNeXt ], # stage 3, 14x14 in [ - 'uir_r1_a3_k3_s2_e6_c128', # start dw - 'uir_r1_a5_k5_s1_e4_c128', # start dw - 'uir_r1_a0_k5_s1_e4_c128', # ir - 'uir_r1_a0_k5_s1_e3_c128', # ir - 'uir_r2_a0_k5_s1_e4_c128', # ir + 'uir_r1_a3_k3_s2_e6_c128', # ExtraDW + 'uir_r1_a5_k5_s1_e4_c128', # ExtraDW + 'uir_r1_a0_k5_s1_e4_c128', # IR + 'uir_r1_a0_k5_s1_e3_c128', # IR + 'uir_r2_a0_k5_s1_e4_c128', # IR ], # stage 4, 7x7 in - ['cn_r1_k1_s1_c960'], # hard-swish + ['cn_r1_k1_s1_c960'], # Conv ] elif 'medium' in variant: stem_size = 32 @@ -762,36 +768,36 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: act_layer = resolve_act_layer(kwargs, 'relu') arch_def = [ # stage 0, 112x112 in - ['er_r1_k3_s2_e4_c48'], + ['er_r1_k3_s2_e4_c48'], # FusedIB (EdgeResidual) # stage 1, 56x56 in [ - 'uir_r1_a3_k5_s2_e4_c80', - 'uir_r1_a3_k3_s1_e2_c80', + 'uir_r1_a3_k5_s2_e4_c80', # ExtraDW + 'uir_r1_a3_k3_s1_e2_c80', # ExtraDW ], # stage 2, 28x28 in [ - 'uir_r1_a3_k5_s2_e6_c160', - 'uir_r2_a3_k3_s1_e4_c160', - 'uir_r1_a3_k5_s1_e4_c160', - 'uir_r1_a3_k3_s1_e4_c160', - 'uir_r1_a3_k0_s1_e4_c160', # convnext - 'uir_r1_a0_k0_s1_e2_c160', - 'uir_r1_a3_k0_s1_e4_c160', # convnext + 'uir_r1_a3_k5_s2_e6_c160', # ExtraDW + 'uir_r2_a3_k3_s1_e4_c160', # ExtraDW + 'uir_r1_a3_k5_s1_e4_c160', # ExtraDW + 'uir_r1_a3_k3_s1_e4_c160', # ExtraDW + 'uir_r1_a3_k0_s1_e4_c160', # ConvNeXt + 'uir_r1_a0_k0_s1_e2_c160', # ExtraDW + 'uir_r1_a3_k0_s1_e4_c160', # ConvNeXt ], # stage 3, 14x14in [ - 'uir_r1_a5_k5_s2_e6_c256', - 'uir_r1_a5_k5_s1_e4_c256', - 'uir_r2_a3_k5_s1_e4_c256', - 'uir_r1_a0_k0_s1_e4_c256', - 'uir_r1_a3_k0_s1_e4_c256', # convnext - 'uir_r1_a3_k5_s1_e2_c256', - 'uir_r1_a5_k5_s1_e4_c256', - 'uir_r2_a0_k0_s1_e4_c256', - 'uir_r1_a5_k0_s1_e2_c256', # convnext + 'uir_r1_a5_k5_s2_e6_c256', # ExtraDW + 'uir_r1_a5_k5_s1_e4_c256', # ExtraDW + 'uir_r2_a3_k5_s1_e4_c256', # ExtraDW + 'uir_r1_a0_k0_s1_e4_c256', # FFN + 'uir_r1_a3_k0_s1_e4_c256', # ConvNeXt + 'uir_r1_a3_k5_s1_e2_c256', # ExtraDW + 'uir_r1_a5_k5_s1_e4_c256', # ExtraDW + 'uir_r2_a0_k0_s1_e4_c256', # FFN + 'uir_r1_a5_k0_s1_e2_c256', # ConvNeXt ], # stage 4, 7x7 in - ['cn_r1_k1_s1_c960'], + ['cn_r1_k1_s1_c960'], # Conv ] elif 'large' in variant: stem_size = 24 @@ -799,33 +805,33 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: act_layer = resolve_act_layer(kwargs, 'relu') arch_def = [ # stage 0, 112x112 in - ['er_r1_k3_s2_e4_c48'], + ['er_r1_k3_s2_e4_c48'], # FusedIB (EdgeResidual) # stage 1, 56x56 in [ - 'uir_r1_a3_k5_s2_e4_c96', - 'uir_r1_a3_k3_s1_e4_c96', + 'uir_r1_a3_k5_s2_e4_c96', # ExtraDW + 'uir_r1_a3_k3_s1_e4_c96', # ExtraDW ], # stage 2, 28x28 in [ - 'uir_r1_a3_k5_s2_e4_c192', - 'uir_r3_a3_k3_s1_e4_c192', - 'uir_r1_a3_k5_s1_e4_c192', - 'uir_r5_a5_k3_s1_e4_c192', - 'uir_r1_a3_k0_s1_e4_c192', # convnext + 'uir_r1_a3_k5_s2_e4_c192', # ExtraDW + 'uir_r3_a3_k3_s1_e4_c192', # ExtraDW + 'uir_r1_a3_k5_s1_e4_c192', # ExtraDW + 'uir_r5_a5_k3_s1_e4_c192', # ExtraDW + 'uir_r1_a3_k0_s1_e4_c192', # ConvNeXt ], # stage 3, 14x14in [ - 'uir_r4_a5_k5_s2_e4_c512', - 'uir_r1_a5_k0_s1_e4_c512', # convnext - 'uir_r1_a5_k3_s1_e4_c512', - 'uir_r2_a5_k0_s1_e4_c512', # convnext - 'uir_r1_a5_k3_s1_e4_c512', - 'uir_r1_a5_k5_s1_e4_c512', - 'uir_r3_a5_k0_s1_e4_c512', # convnext + 'uir_r4_a5_k5_s2_e4_c512', # ExtraDW + 'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt + 
'uir_r1_a5_k3_s1_e4_c512', # ExtraDW + 'uir_r2_a5_k0_s1_e4_c512', # ConvNeXt + 'uir_r1_a5_k3_s1_e4_c512', # ExtraDW + 'uir_r1_a5_k5_s1_e4_c512', # ExtraDW + 'uir_r3_a5_k0_s1_e4_c512', # ConvNeXt ], # stage 4, 7x7 in - ['cn_r1_k1_s1_c960'], + ['cn_r1_k1_s1_c960'], # Conv ] else: assert False, f'Unknown variant {variant}.' From 28d76a97dbf092cb7c34aa9e5f68e4dedc716c58 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 May 2024 11:50:42 -0700 Subject: [PATCH 11/27] Mixed up kernel size for last blocks in mnv4-conv-small --- timm/models/mobilenetv3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 7b07b8a1..a9c63f28 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -757,7 +757,7 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: 'uir_r1_a5_k5_s1_e4_c128', # ExtraDW 'uir_r1_a0_k5_s1_e4_c128', # IR 'uir_r1_a0_k5_s1_e3_c128', # IR - 'uir_r2_a0_k5_s1_e4_c128', # IR + 'uir_r2_a0_k3_s1_e4_c128', # IR ], # stage 4, 7x7 in ['cn_r1_k1_s1_c960'], # Conv From 7fe96e7a92262dbbc81325b1fc7583446c0996c0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 May 2024 15:09:29 -0700 Subject: [PATCH 12/27] More MobileNet-v4 fixes * missed final norm after post pooling 1x1 PW head conv * improve repr of model by flipping a few modules to None when not used, nn.Sequential for MultiQueryAttention query/key/value/output * allow layer scaling to be enabled/disabled at model variant level, conv variants don't use it --- timm/layers/attention2d.py | 85 ++++++++++++---------------- timm/models/_efficientnet_blocks.py | 65 +++++++++------------ timm/models/_efficientnet_builder.py | 27 +++++---- timm/models/mobilenetv3.py | 27 ++++++++- 4 files changed, 102 insertions(+), 102 deletions(-) diff --git a/timm/layers/attention2d.py b/timm/layers/attention2d.py index 3d3f6d01..d1d38fb3 100644 --- a/timm/layers/attention2d.py +++ b/timm/layers/attention2d.py @@ -107,6 +107,7 @@ class MultiQueryAttention2d(nn.Module): attn_drop: float = 0., proj_drop: float = 0., norm_layer: nn.Module = nn.BatchNorm2d, + use_bias: bool = False, ): """Initializer. 
@@ -130,26 +131,25 @@ class MultiQueryAttention2d(nn.Module): self.fused_attn = use_fused_attn() self.drop = attn_drop + self.query = nn.Sequential() if self.has_query_strides: # FIXME dilation - self.query_down_pool = create_pool2d( - 'avg', - kernel_size=self.query_strides, - padding=padding, - ) - self.query_down_norm = norm_layer(dim) - else: - self.query_down_pool = nn.Identity() - self.query_down_norm = nn.Identity() - - self.query_proj = create_conv2d( + self.query.add_module('down_pool', create_pool2d( + 'avg', + kernel_size=self.query_strides, + padding=padding, + )) + self.query.add_module('norm', norm_layer(dim)) + self.query.add_module('proj', create_conv2d( dim, self.num_heads * self.key_dim, kernel_size=1, - ) + bias=use_bias, + )) + self.key = nn.Sequential() if kv_stride > 1: - self.key_down_conv = create_conv2d( + self.key.add_module('down_conv', create_conv2d( dim, dim, kernel_size=dw_kernel_size, @@ -157,21 +157,19 @@ class MultiQueryAttention2d(nn.Module): dilation=dilation, padding=padding, depthwise=True, - ) - self.key_down_norm = norm_layer(dim) - else: - self.key_down_conv = nn.Identity() - self.key_down_norm = nn.Identity() - - self.key_proj = create_conv2d( + )) + self.key.add_module('norm', norm_layer(dim)) + self.key.add_module('proj', create_conv2d( dim, self.key_dim, kernel_size=1, padding=padding, - ) + bias=use_bias, + )) + self.value = nn.Sequential() if kv_stride > 1: - self.value_down_conv = create_conv2d( + self.value.add_module('down_conv', create_conv2d( dim, dim, kernel_size=dw_kernel_size, @@ -179,32 +177,28 @@ class MultiQueryAttention2d(nn.Module): dilation=dilation, padding=padding, depthwise=True, - ) - self.value_down_norm = norm_layer(dim) - else: - self.value_down_conv = nn.Identity() - self.value_down_norm = nn.Identity() - - self.value_proj = create_conv2d( + )) + self.value.add_module('norm', norm_layer(dim)) + self.value.add_module('proj', create_conv2d( dim, self.value_dim, kernel_size=1, - ) + bias=use_bias, + )) self.attn_drop = nn.Dropout(attn_drop) + self.output = nn.Sequential() if self.has_query_strides: - self.upsampling = nn.Upsample(self.query_strides, mode='bilinear', align_corners=False) - else: - self.upsampling = nn.Identity() - - self.out_proj = create_conv2d( + self.output.add_module('upsample', nn.Upsample(self.query_strides, mode='bilinear', align_corners=False)) + self.output.add_module('proj', create_conv2d( self.value_dim * self.num_heads, dim_out, kernel_size=1, - ) + bias=use_bias, + )) + self.output.add_module('drop', nn.Dropout(proj_drop)) - self.proj_drop = nn.Dropout(proj_drop) self.einsum = False def _reshape_input(self, t: torch.Tensor): @@ -237,21 +231,15 @@ class MultiQueryAttention2d(nn.Module): """Run layer computation.""" B, C, H, W = s = x.shape - q = self.query_down_pool(x) - q = self.query_down_norm(q) - q = self.query_proj(q) + q = self.query(x) # desired q shape: [b, h, k, n x n] - [b, l, h, k] q = self._reshape_projected_query(q, self.num_heads, self.key_dim) - k = self.key_down_conv(x) - k = self.key_down_norm(k) - k = self.key_proj(k) + k = self.key(x) # output shape of k: [b, k, p], p = m x m k = self._reshape_input(k) - v = self.value_down_conv(x) - v = self.value_down_norm(v) - v = self.value_proj(v) + v = self.value(x) # output shape of v: [ b, p, k], p = m x m v = self._reshape_input(v) @@ -285,10 +273,7 @@ class MultiQueryAttention2d(nn.Module): # reshape o into [b, hk, n, n,] o = self._reshape_output(o, self.num_heads, H // self.query_strides[0], W // self.query_strides[1]) - o = 
self.upsampling(o) - - x = self.out_proj(o) - x = self.proj_drop(x) + x = self.output(o) return x diff --git a/timm/models/_efficientnet_blocks.py b/timm/models/_efficientnet_blocks.py index 41f3182d..be00b01c 100644 --- a/timm/models/_efficientnet_blocks.py +++ b/timm/models/_efficientnet_blocks.py @@ -174,13 +174,12 @@ class DepthwiseSeparableConv(nn.Module): def forward(self, x): shortcut = x - #print('ii', x.shape) + #print('ii', x.shape) # FIXME debug s2d if self.conv_s2d is not None: x = self.conv_s2d(x) x = self.bn_s2d(x) - #print('id', x.shape) + #print('id', x.shape) # FIXME debug s2d x = self.conv_dw(x) - #print('od', x.shape) x = self.bn1(x) x = self.se(x) x = self.conv_pw(x) @@ -296,7 +295,8 @@ class LayerScale2d(nn.Module): class UniversalInvertedResidual(nn.Module): """ Universal Inverted Residual Block - For MobileNetV4 - https://arxiv.org/abs/ + For MobileNetV4 - https://arxiv.org/abs/, referenced from + https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L778 """ def __init__( @@ -338,8 +338,9 @@ class UniversalInvertedResidual(nn.Module): ) self.norm_dw_start = dw_norm_act_layer(in_chs, apply_act=False) else: - self.conv_dw_start = nn.Identity() - self.norm_dw_start = nn.Identity() + # start is None when not used for cleaner repr + self.conv_dw_start = None + self.norm_dw_start = None # Point-wise expansion mid_chs = make_divisible(in_chs * exp_ratio) @@ -359,6 +360,7 @@ class UniversalInvertedResidual(nn.Module): ) self.norm_dw_mid = dw_norm_act_layer(mid_chs, inplace=True) else: + # keeping mid as identity so it can be hooked more easily for features self.conv_dw_mid = nn.Identity() self.norm_dw_mid = nn.Identity() @@ -379,7 +381,7 @@ class UniversalInvertedResidual(nn.Module): ) self.norm_dw_end = dw_norm_act_layer(out_chs, apply_act=False) else: - # dw_end rarely used so keeping it out of repr by not using None instead of nn.Identitty() + # end is None when not in use for cleaner repr self.conv_dw_end = None self.norm_dw_end = None @@ -397,8 +399,9 @@ class UniversalInvertedResidual(nn.Module): def forward(self, x): shortcut = x - x = self.conv_dw_start(x) - x = self.norm_dw_start(x) + if self.conv_dw_start is not None: + x = self.conv_dw_start(x) + x = self.norm_dw_start(x) x = self.conv_pw(x) x = self.norm_pw(x) x = self.conv_dw_mid(x) @@ -418,7 +421,8 @@ class UniversalInvertedResidual(nn.Module): class MobileAttention(nn.Module): """ Mobile Attention Block - For MobileNetV4 - https://arxiv.org/abs/ + For MobileNetV4 - https://arxiv.org/abs/, referenced from + https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L1504 """ def __init__( self, @@ -476,34 +480,21 @@ class MobileAttention(nn.Module): num_heads = in_chs // key_dim if use_multi_query: - #if self.has_query_stride or self.kv_stride > 1: - self.attn = ( - MultiQueryAttention2d( - in_chs, - dim_out=out_chs, - num_heads=num_heads, - key_dim=key_dim, - value_dim=value_dim, - query_strides=query_strides, - kv_stride=kv_stride, - dilation=dilation, - padding=pad_type, - dw_kernel_size=dw_kernel_size, - attn_drop=attn_drop, - proj_drop=proj_drop, - #bias=use_bias, # why not here if used w/ mhsa? 
- ) + self.attn = MultiQueryAttention2d( + in_chs, + dim_out=out_chs, + num_heads=num_heads, + key_dim=key_dim, + value_dim=value_dim, + query_strides=query_strides, + kv_stride=kv_stride, + dilation=dilation, + padding=pad_type, + dw_kernel_size=dw_kernel_size, + attn_drop=attn_drop, + proj_drop=proj_drop, + #bias=use_bias, # why not here if used w/ mhsa? ) - # else: - # self.attn = MultiQueryAttentionV2( - # in_chs, - # dim_out=out_chs, - # num_heads=num_heads, - # key_dim=key_dim, - # value_dim=value_dim, - # attn_drop=attn_drop, - # proj_drop=proj_drop, - # ) else: self.attn = Attention2d( in_chs, diff --git a/timm/models/_efficientnet_builder.py b/timm/models/_efficientnet_builder.py index 4cbd6342..7d96216a 100644 --- a/timm/models/_efficientnet_builder.py +++ b/timm/models/_efficientnet_builder.py @@ -5,6 +5,7 @@ Handles stride, dilation calculations, and selects feature extraction points. Hacked together by / Copyright 2019, Ross Wightman """ +from typing import Callable, Optional import logging import math @@ -321,15 +322,16 @@ class EfficientNetBuilder: """ def __init__( self, - output_stride=32, - pad_type='', - round_chs_fn=round_channels, - se_from_exp=False, - act_layer=None, - norm_layer=None, - se_layer=None, - drop_path_rate=0., - feature_location='', + output_stride: int = 32, + pad_type: str = '', + round_chs_fn: Callable = round_channels, + se_from_exp: bool = False, + act_layer: Optional[Callable] = None, + norm_layer: Optional[Callable] = None, + se_layer: Optional[Callable] = None, + drop_path_rate: float = 0., + layer_scale_init_value: Optional[float] = None, + feature_location: str = '', ): self.output_stride = output_stride self.pad_type = pad_type @@ -344,6 +346,7 @@ class EfficientNetBuilder: except TypeError: self.se_has_ratio = False self.drop_path_rate = drop_path_rate + self.layer_scale_init_value = layer_scale_init_value if feature_location == 'depthwise': # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense _logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'") @@ -402,13 +405,13 @@ class EfficientNetBuilder: block = ConvBnAct(**ba) elif bt == 'uir': _log_info_if(' UniversalInvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose) - block = UniversalInvertedResidual(**ba) + block = UniversalInvertedResidual(**ba, layer_scale_init_value=self.layer_scale_init_value) elif bt == 'mqa': _log_info_if(' MobileMultiQueryAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose) - block = MobileAttention(**ba, use_multi_query=True) + block = MobileAttention(**ba, use_multi_query=True, layer_scale_init_value=self.layer_scale_init_value) elif bt == 'mha': _log_info_if(' MobileMultiHeadAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose) - block = MobileAttention(**ba) + block = MobileAttention(**ba, layer_scale_init_value=self.layer_scale_init_value) else: assert False, 'Unknown block type (%s) while building model.' 
% bt diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index a9c63f28..e90a8df4 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -51,6 +51,7 @@ class MobileNetV3(nn.Module): fix_stem: bool = False, num_features: int = 1280, head_bias: bool = True, + head_norm: bool = False, pad_type: PadType = '', act_layer: Optional[LayerType] = None, norm_layer: Optional[LayerType] = None, @@ -59,6 +60,7 @@ class MobileNetV3(nn.Module): round_chs_fn: Callable = round_channels, drop_rate: float = 0., drop_path_rate: float = 0., + layer_scale_init_value: Optional[float] = None, global_pool: str = 'avg', ): """ @@ -78,6 +80,7 @@ class MobileNetV3(nn.Module): round_chs_fn: Callable to round number of filters based on depth multiplier. drop_rate: Dropout rate. drop_path_rate: Stochastic depth rate. + layer_scale_init_value: Enable layer scale on compatible blocks if not None global_pool: Type of pooling to use for global pooling features of the FC head. """ super(MobileNetV3, self).__init__() @@ -106,6 +109,7 @@ class MobileNetV3(nn.Module): norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate, + layer_scale_init_value=layer_scale_init_value, ) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = builder.features @@ -115,8 +119,16 @@ class MobileNetV3(nn.Module): # Head + Pooling self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) num_pooled_chs = head_chs * self.global_pool.feat_mult() - self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias) - self.act2 = act_layer(inplace=True) + if head_norm: + # mobilenet-v4 post-pooling PW conv is followed by a norm+act layer + self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type) # never bias + self.norm_head = norm_act_layer(self.num_features) + self.act2 = nn.Identity() + else: + # mobilenet-v3 and others only have an activation after final PW conv + self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias) + self.norm_head = nn.Identity() + self.act2 = act_layer(inplace=True) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() @@ -125,7 +137,7 @@ class MobileNetV3(nn.Module): def as_sequential(self): layers = [self.conv_stem, self.bn1] layers.extend(self.blocks) - layers.extend([self.global_pool, self.conv_head, self.act2]) + layers.extend([self.global_pool, self.conv_head, self.norm_head, self.act2]) layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) return nn.Sequential(*layers) @@ -224,8 +236,10 @@ class MobileNetV3(nn.Module): self.blocks = self.blocks[:max_index] # truncate blocks w/ stem as idx 0 if max_index < len(self.blocks): self.conv_head = nn.Identity() + self.norm_head = nn.Identity() if prune_head: self.conv_head = nn.Identity() + self.norm_head = nn.Identity() self.reset_classifier(0, '') return take_indices @@ -241,6 +255,7 @@ class MobileNetV3(nn.Module): def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: x = self.global_pool(x) x = self.conv_head(x) + x = self.norm_head(x) x = self.act2(x) x = self.flatten(x) if pre_logits: @@ -632,6 +647,7 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: channel_multiplier: multiplier to number of channels per layer. 
""" if 'hybrid' in variant: + layer_scale_init_value = 1e-5 if 'medium' in variant: stem_size = 32 num_features = 1280 @@ -730,6 +746,7 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: else: assert False, f'Unknown variant {variant}.' else: + layer_scale_init_value = None if 'small' in variant: stem_size = 32 num_features = 1280 @@ -836,9 +853,12 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: else: assert False, f'Unknown variant {variant}.' + # NOTE SE not used in initial MobileNet-v4 definitions se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels) model_kwargs = dict( block_args=decode_arch_def(arch_def), + head_bias=False, + head_norm=True, num_features=num_features, stem_size=stem_size, fix_stem=channel_multiplier < 0.75, @@ -846,6 +866,7 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), act_layer=act_layer, se_layer=se_layer, + layer_scale_init_value=layer_scale_init_value, **kwargs, ) model = _create_mnv3(variant, pretrained, **model_kwargs) From a12b72b5c4f4cd28985ff1bae347b886f8f41e6d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 May 2024 15:50:34 -0700 Subject: [PATCH 13/27] Fix missing head_norm arg pop for feature model --- timm/models/mobilenetv3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index e90a8df4..be1ea876 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -385,7 +385,7 @@ def _create_mnv3(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV if 'feature_cfg' in kwargs or 'feature_cls' in kwargs: features_mode = 'cfg' else: - kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool') + kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'head_norm', 'global_pool') model_cls = MobileNetV3Features features_mode = 'cls' From 4ff7c257665189aec80284d590fddf286a9aaf8a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 May 2024 16:44:50 -0700 Subject: [PATCH 14/27] Pass layer_scale_init_value to Mnv3Features module --- timm/models/mobilenetv3.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index be1ea876..40f201b9 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -80,7 +80,7 @@ class MobileNetV3(nn.Module): round_chs_fn: Callable to round number of filters based on depth multiplier. drop_rate: Dropout rate. drop_path_rate: Stochastic depth rate. - layer_scale_init_value: Enable layer scale on compatible blocks if not None + layer_scale_init_value: Enable layer scale on compatible blocks if not None. global_pool: Type of pooling to use for global pooling features of the FC head. """ super(MobileNetV3, self).__init__() @@ -294,6 +294,7 @@ class MobileNetV3Features(nn.Module): se_layer: Optional[LayerType] = None, drop_rate: float = 0., drop_path_rate: float = 0., + layer_scale_init_value: Optional[float] = None, ): """ Args: @@ -312,6 +313,7 @@ class MobileNetV3Features(nn.Module): se_layer: Type of Squeeze-and-Excite layer. drop_rate: Dropout rate. drop_path_rate: Stochastic depth rate. + layer_scale_init_value: Enable layer scale on compatible blocks if not None. 
""" super(MobileNetV3Features, self).__init__() act_layer = act_layer or nn.ReLU @@ -337,6 +339,7 @@ class MobileNetV3Features(nn.Module): norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate, + layer_scale_init_value=layer_scale_init_value, feature_location=feature_location, ) self.blocks = nn.Sequential(*builder(stem_size, block_args)) From 5fa6efa158e5c6a7edc3a322c33a4deba8badfce Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 27 May 2024 22:06:22 -0700 Subject: [PATCH 15/27] Add anti-aliasing support to mobilenetv3 and efficientnet family models. Update MobileNetV4 model defs, resolutions. Fix #599 * create_aa helper function centralized for all timm uses (resnet, convbnact helper) * allow BlurPool w/ pre-defined channels (expand) * mobilenetv4 UIB block using ConvNormAct layers for improved clarity, esp with AA added * improve more mobilenetv3 and efficientnet related type annotations --- timm/layers/__init__.py | 2 +- timm/layers/blur_pool.py | 57 +++- timm/layers/conv_bn_act.py | 110 ++++---- timm/models/_efficientnet_blocks.py | 372 +++++++++++++++------------ timm/models/_efficientnet_builder.py | 14 +- timm/models/efficientnet.py | 80 +++--- timm/models/mobilenetv3.py | 162 +++++++----- timm/models/resnet.py | 11 +- 8 files changed, 479 insertions(+), 329 deletions(-) diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index b44e1161..3f023572 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -4,7 +4,7 @@ from .adaptive_avgmax_pool import \ from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2 from .attention_pool import AttentionPoolLatent from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding -from .blur_pool import BlurPool2d +from .blur_pool import BlurPool2d, create_aa from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead from .cond_conv2d import CondConv2d, get_condconv_initializer from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \ diff --git a/timm/layers/blur_pool.py b/timm/layers/blur_pool.py index e73d8863..6a4b668c 100644 --- a/timm/layers/blur_pool.py +++ b/timm/layers/blur_pool.py @@ -5,12 +5,16 @@ BlurPool layer inspired by Hacked together by Chris Ha and Ross Wightman """ +from functools import partial +from typing import Optional, Type import torch import torch.nn as nn import torch.nn.functional as F import numpy as np + from .padding import get_padding +from .typing import LayerType class BlurPool2d(nn.Module): @@ -26,17 +30,62 @@ class BlurPool2d(nn.Module): Returns: torch.Tensor: the transformed tensor. 
""" - def __init__(self, channels, filt_size=3, stride=2) -> None: + def __init__( + self, + channels: Optional[int] = None, + filt_size: int = 3, + stride: int = 2, + pad_mode: str = 'reflect', + ) -> None: super(BlurPool2d, self).__init__() assert filt_size > 1 self.channels = channels self.filt_size = filt_size self.stride = stride + self.pad_mode = pad_mode self.padding = [get_padding(filt_size, stride, dilation=1)] * 4 + coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32)) - blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1) + blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :] + if channels is not None: + blur_filter = blur_filter.repeat(self.channels, 1, 1, 1) self.register_buffer('filt', blur_filter, persistent=False) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = F.pad(x, self.padding, 'reflect') - return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels) + x = F.pad(x, self.padding, mode=self.pad_mode) + if self.channels is None: + channels = x.shape[1] + weight = self.filt.expand(channels, 1, self.filt_size, self.filt_size) + else: + channels = self.channels + weight = self.filt + return F.conv2d(x, weight, stride=self.stride, groups=channels) + + +def create_aa( + aa_layer: LayerType, + channels: Optional[int] = None, + stride: int = 2, + enable: bool = True, + noop: Optional[Type[nn.Module]] = nn.Identity +) -> nn.Module: + """ Anti-aliasing """ + if not aa_layer or not enable: + return noop() if noop is not None else None + + if isinstance(aa_layer, str): + aa_layer = aa_layer.lower().replace('_', '').replace('-', '') + if aa_layer == 'avg' or aa_layer == 'avgpool': + aa_layer = nn.AvgPool2d + elif aa_layer == 'blur' or aa_layer == 'blurpool': + aa_layer = BlurPool2d + elif aa_layer == 'blurpc': + aa_layer = partial(BlurPool2d, pad_mode='constant') + + else: + assert False, f"Unknown anti-aliasing layer ({aa_layer})." 
+ + try: + return aa_layer(channels=channels, stride=stride) + except TypeError as e: + return aa_layer(stride) diff --git a/timm/layers/conv_bn_act.py b/timm/layers/conv_bn_act.py index 84aaf4bf..17847d76 100644 --- a/timm/layers/conv_bn_act.py +++ b/timm/layers/conv_bn_act.py @@ -2,9 +2,12 @@ Hacked together by / Copyright 2020 Ross Wightman """ -import functools +from typing import Any, Dict, Optional, Type + from torch import nn as nn +from .typing import LayerType, PadType +from .blur_pool import create_aa from .create_conv2d import create_conv2d from .create_norm_act import get_norm_act_layer @@ -12,28 +15,38 @@ from .create_norm_act import get_norm_act_layer class ConvNormAct(nn.Module): def __init__( self, - in_channels, - out_channels, - kernel_size=1, - stride=1, - padding='', - dilation=1, - groups=1, - bias=False, - apply_act=True, - norm_layer=nn.BatchNorm2d, - norm_kwargs=None, - act_layer=nn.ReLU, - act_kwargs=None, - drop_layer=None, + in_channels: int, + out_channels: int, + kernel_size: int = 1, + stride: int = 1, + padding: PadType = '', + dilation: int = 1, + groups: int = 1, + bias: bool = False, + apply_act: bool = True, + norm_layer: LayerType = nn.BatchNorm2d, + act_layer: LayerType = nn.ReLU, + drop_layer: Optional[Type[nn.Module]] = None, + conv_kwargs: Optional[Dict[str, Any]] = None, + norm_kwargs: Optional[Dict[str, Any]] = None, + act_kwargs: Optional[Dict[str, Any]] = None, ): super(ConvNormAct, self).__init__() + conv_kwargs = conv_kwargs or {} norm_kwargs = norm_kwargs or {} act_kwargs = act_kwargs or {} self.conv = create_conv2d( - in_channels, out_channels, kernel_size, stride=stride, - padding=padding, dilation=dilation, groups=groups, bias=bias) + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + **conv_kwargs, + ) # NOTE for backwards compatibility with models that use separate norm and act layer definitions norm_act_layer = get_norm_act_layer(norm_layer, act_layer) @@ -64,54 +77,53 @@ class ConvNormAct(nn.Module): ConvBnAct = ConvNormAct -def create_aa(aa_layer, channels, stride=2, enable=True): - if not aa_layer or not enable: - return nn.Identity() - if isinstance(aa_layer, functools.partial): - if issubclass(aa_layer.func, nn.AvgPool2d): - return aa_layer() - else: - return aa_layer(channels) - elif issubclass(aa_layer, nn.AvgPool2d): - return aa_layer(stride) - else: - return aa_layer(channels=channels, stride=stride) - - class ConvNormActAa(nn.Module): def __init__( self, - in_channels, - out_channels, - kernel_size=1, - stride=1, - padding='', - dilation=1, - groups=1, - bias=False, - apply_act=True, - norm_layer=nn.BatchNorm2d, - norm_kwargs=None, - act_layer=nn.ReLU, - act_kwargs=None, - aa_layer=None, - drop_layer=None, + in_channels: int, + out_channels: int, + kernel_size: int = 1, + stride: int = 1, + padding: PadType = '', + dilation: int = 1, + groups: int = 1, + bias: bool = False, + apply_act: bool = True, + norm_layer: LayerType = nn.BatchNorm2d, + act_layer: LayerType = nn.ReLU, + aa_layer: Optional[LayerType] = None, + drop_layer: Optional[Type[nn.Module]] = None, + conv_kwargs: Optional[Dict[str, Any]] = None, + norm_kwargs: Optional[Dict[str, Any]] = None, + act_kwargs: Optional[Dict[str, Any]] = None, ): super(ConvNormActAa, self).__init__() use_aa = aa_layer is not None and stride == 2 + conv_kwargs = conv_kwargs or {} norm_kwargs = norm_kwargs or {} act_kwargs = act_kwargs or {} self.conv = create_conv2d( - in_channels, out_channels, kernel_size, 
stride=1 if use_aa else stride, - padding=padding, dilation=dilation, groups=groups, bias=bias) + in_channels, out_channels, kernel_size, + stride=1 if use_aa else stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + **conv_kwargs, + ) # NOTE for backwards compatibility with models that use separate norm and act layer definitions norm_act_layer = get_norm_act_layer(norm_layer, act_layer) # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` if drop_layer: norm_kwargs['drop_layer'] = drop_layer - self.bn = norm_act_layer(out_channels, apply_act=apply_act, act_kwargs=act_kwargs, **norm_kwargs) + self.bn = norm_act_layer( + out_channels, + apply_act=apply_act, + act_kwargs=act_kwargs, + **norm_kwargs, + ) self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa) @property diff --git a/timm/models/_efficientnet_blocks.py b/timm/models/_efficientnet_blocks.py index be00b01c..f33dacd5 100644 --- a/timm/models/_efficientnet_blocks.py +++ b/timm/models/_efficientnet_blocks.py @@ -2,22 +2,24 @@ Hacked together by / Copyright 2019, Ross Wightman """ -from typing import Optional +from typing import Callable, Dict, Optional, Type import torch import torch.nn as nn from torch.nn import functional as F -from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, to_2tuple,\ - get_norm_act_layer, MultiQueryAttention2d, MultiQueryAttentionV2, Attention2d +from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, create_aa, to_2tuple, LayerType,\ + ConvNormAct, ConvNormActAa, get_norm_act_layer, MultiQueryAttention2d, Attention2d __all__ = [ 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual', 'UniversalInvertedResidual', 'MobileAttention' ] +ModuleType = Type[nn.Module] -def num_groups(group_size, channels): + +def num_groups(group_size: Optional[int], channels: int): if not group_size: # 0 or None return 1 # normal conv with 1 group else: @@ -40,13 +42,13 @@ class SqueezeExcite(nn.Module): def __init__( self, - in_chs, - rd_ratio=0.25, - rd_channels=None, - act_layer=nn.ReLU, - gate_layer=nn.Sigmoid, - force_act_layer=None, - rd_round_fn=None, + in_chs: int, + rd_ratio: float = 0.25, + rd_channels: Optional[int] = None, + act_layer: LayerType = nn.ReLU, + gate_layer: LayerType = nn.Sigmoid, + force_act_layer: Optional[LayerType] = None, + rd_round_fn: Optional[Callable] = None, ): super(SqueezeExcite, self).__init__() if rd_channels is None: @@ -71,27 +73,31 @@ class ConvBnAct(nn.Module): """ def __init__( self, - in_chs, - out_chs, - kernel_size, - stride=1, - dilation=1, - group_size=0, - pad_type='', - skip=False, - act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, - drop_path_rate=0., + in_chs: int, + out_chs: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + group_size: int = 0, + pad_type: str = '', + skip: bool = False, + act_layer: LayerType = nn.ReLU, + norm_layer: LayerType = nn.BatchNorm2d, + aa_layer: Optional[LayerType] = None, + drop_path_rate: float = 0., ): super(ConvBnAct, self).__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer) groups = num_groups(group_size, in_chs) self.has_skip = skip and stride == 1 and in_chs == out_chs + use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation self.conv = create_conv2d( in_chs, out_chs, kernel_size, - stride=stride, dilation=dilation, groups=groups, padding=pad_type) + stride=1 if use_aa else stride, + dilation=dilation, 
groups=groups, padding=pad_type) self.bn1 = norm_act_layer(out_chs, inplace=True) + self.aa = create_aa(aa_layer, channels=out_chs, stride=stride, enable=use_aa) self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() def feature_info(self, location): @@ -104,6 +110,7 @@ class ConvBnAct(nn.Module): shortcut = x x = self.conv(x) x = self.bn1(x) + x = self.aa(x) if self.has_skip: x = self.drop_path(x) + shortcut return x @@ -116,37 +123,38 @@ class DepthwiseSeparableConv(nn.Module): """ def __init__( self, - in_chs, - out_chs, - dw_kernel_size=3, - stride=1, - dilation=1, - group_size=1, - pad_type='', - noskip=False, - pw_kernel_size=1, - pw_act=False, - s2d=0, - act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, - se_layer=None, - drop_path_rate=0., + in_chs: int, + out_chs: int, + dw_kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + group_size: int = 1, + pad_type: str = '', + noskip: bool = False, + pw_kernel_size: int = 1, + pw_act: bool = False, + s2d: int = 0, + act_layer: LayerType = nn.ReLU, + norm_layer: LayerType = nn.BatchNorm2d, + aa_layer: Optional[LayerType] = None, + se_layer: Optional[ModuleType] = None, + drop_path_rate: float = 0., ): super(DepthwiseSeparableConv, self).__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer) self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip self.has_pw_act = pw_act # activation after point-wise conv + use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation # Space to depth if s2d == 1: sd_chs = int(in_chs * 4) - #sd_pad_type = 'sam' - self.conv_s2d = create_conv2d( - in_chs, sd_chs, kernel_size=2, stride=2, padding=0) #'same') + self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding='same') self.bn_s2d = norm_act_layer(sd_chs, sd_chs) dw_kernel_size = (dw_kernel_size + 1) // 2 dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type in_chs = sd_chs + use_aa = False # disable AA else: self.conv_s2d = None self.bn_s2d = None @@ -156,8 +164,10 @@ class DepthwiseSeparableConv(nn.Module): self.conv_dw = create_conv2d( in_chs, in_chs, dw_kernel_size, - stride=stride, dilation=dilation, padding=dw_pad_type, groups=groups) + stride=1 if use_aa else stride, + dilation=dilation, padding=dw_pad_type, groups=groups) self.bn1 = norm_act_layer(in_chs, inplace=True) + self.aa = create_aa(aa_layer, channels=out_chs, stride=stride, enable=use_aa) # Squeeze-and-excitation self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity() @@ -174,13 +184,12 @@ class DepthwiseSeparableConv(nn.Module): def forward(self, x): shortcut = x - #print('ii', x.shape) # FIXME debug s2d if self.conv_s2d is not None: x = self.conv_s2d(x) x = self.bn_s2d(x) - #print('id', x.shape) # FIXME debug s2d x = self.conv_dw(x) x = self.bn1(x) + x = self.aa(x) x = self.se(x) x = self.conv_pw(x) x = self.bn2(x) @@ -201,37 +210,40 @@ class InvertedResidual(nn.Module): def __init__( self, - in_chs, - out_chs, - dw_kernel_size=3, - stride=1, - dilation=1, - group_size=1, - pad_type='', - noskip=False, - exp_ratio=1.0, - exp_kernel_size=1, - pw_kernel_size=1, - s2d=0, - act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, - se_layer=None, - conv_kwargs=None, - drop_path_rate=0., + in_chs: int, + out_chs: int, + dw_kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + group_size: int = 1, + pad_type: str = '', + noskip: bool = False, + exp_ratio: float = 1.0, + exp_kernel_size: int = 1, + pw_kernel_size: int = 1, + s2d: int = 0, + act_layer: LayerType = nn.ReLU, + 
norm_layer: LayerType = nn.BatchNorm2d, + aa_layer: Optional[LayerType] = None, + se_layer: Optional[ModuleType] = None, + conv_kwargs: Optional[Dict] = None, + drop_path_rate: float = 0., ): super(InvertedResidual, self).__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer) conv_kwargs = conv_kwargs or {} self.has_skip = (in_chs == out_chs and stride == 1) and not noskip + use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation # Space to depth if s2d == 1: sd_chs = int(in_chs * 4) - self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding=pad_type) + self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding='same') self.bn_s2d = norm_act_layer(sd_chs, sd_chs) dw_kernel_size = (dw_kernel_size + 1) // 2 dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type in_chs = sd_chs + use_aa = False # disable AA else: self.conv_s2d = None self.bn_s2d = None @@ -247,8 +259,10 @@ class InvertedResidual(nn.Module): # Depth-wise convolution self.conv_dw = create_conv2d( mid_chs, mid_chs, dw_kernel_size, - stride=stride, dilation=dilation, groups=groups, padding=dw_pad_type, **conv_kwargs) + stride=1 if use_aa else stride, + dilation=dilation, groups=groups, padding=dw_pad_type, **conv_kwargs) self.bn2 = norm_act_layer(mid_chs, inplace=True) + self.aa = create_aa(aa_layer, channels=mid_chs, stride=stride, enable=use_aa) # Squeeze-and-excitation self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity() @@ -273,6 +287,7 @@ class InvertedResidual(nn.Module): x = self.bn1(x) x = self.conv_dw(x) x = self.bn2(x) + x = self.aa(x) x = self.se(x) x = self.conv_pwl(x) x = self.bn3(x) @@ -282,7 +297,7 @@ class InvertedResidual(nn.Module): class LayerScale2d(nn.Module): - def __init__(self, dim, init_values=1e-5, inplace=False): + def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False): super().__init__() self.inplace = inplace self.gamma = nn.Parameter(init_values * torch.ones(dim)) @@ -293,7 +308,7 @@ class LayerScale2d(nn.Module): class UniversalInvertedResidual(nn.Module): - """ Universal Inverted Residual Block + """ Universal Inverted Residual Block (aka Universal Inverted Bottleneck, UIB) For MobileNetV4 - https://arxiv.org/abs/, referenced from https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L778 @@ -301,89 +316,109 @@ class UniversalInvertedResidual(nn.Module): def __init__( self, - in_chs, - out_chs, + in_chs: int, + out_chs: int, dw_kernel_size_start: int = 0, dw_kernel_size_mid: int = 3, dw_kernel_size_end: int = 0, - stride=1, - dilation=1, - group_size=1, - pad_type='', - noskip=False, - exp_ratio=1.0, - act_layer=nn.ReLU, - dw_act_layer=None, - norm_layer=nn.BatchNorm2d, - se_layer=None, - conv_kwargs=None, - drop_path_rate=0., + stride: int = 1, + dilation: int = 1, + group_size: int = 1, + pad_type: str = '', + noskip: bool = False, + exp_ratio: float = 1.0, + act_layer: LayerType = nn.ReLU, + norm_layer: LayerType = nn.BatchNorm2d, + aa_layer: Optional[LayerType] = None, + se_layer: Optional[ModuleType] = None, + conv_kwargs: Optional[Dict] = None, + drop_path_rate: float = 0., layer_scale_init_value: Optional[float] = 1e-5, ): super(UniversalInvertedResidual, self).__init__() - norm_act_layer = get_norm_act_layer(norm_layer, act_layer) - dw_act_layer = dw_act_layer or act_layer - dw_norm_act_layer = get_norm_act_layer(norm_layer, dw_act_layer) conv_kwargs = conv_kwargs or {} self.has_skip = (in_chs == 
out_chs and stride == 1) and not noskip + if stride > 1: + assert dw_kernel_size_start or dw_kernel_size_mid or dw_kernel_size_end # FIXME dilation isn't right w/ extra ks > 1 convs if dw_kernel_size_start: - self.conv_dw_start = create_conv2d( + dw_start_stride = stride if not dw_kernel_size_mid else 1 + dw_start_groups = num_groups(group_size, in_chs) + self.dw_start = ConvNormActAa( in_chs, in_chs, dw_kernel_size_start, + stride=dw_start_stride, dilation=dilation, # FIXME - depthwise=True, + groups=dw_start_groups, padding=pad_type, + apply_act=False, + act_layer=act_layer, + norm_layer=norm_layer, + aa_layer=aa_layer, **conv_kwargs, ) - self.norm_dw_start = dw_norm_act_layer(in_chs, apply_act=False) else: - # start is None when not used for cleaner repr - self.conv_dw_start = None - self.norm_dw_start = None + self.dw_start = nn.Identity() # Point-wise expansion mid_chs = make_divisible(in_chs * exp_ratio) - self.conv_pw = create_conv2d(in_chs, mid_chs, 1, padding=pad_type, **conv_kwargs) - self.norm_pw = norm_act_layer(mid_chs, inplace=True) + self.pw_exp = ConvNormAct( + in_chs, mid_chs, 1, + padding=pad_type, + act_layer=act_layer, + norm_layer=norm_layer, + **conv_kwargs, + ) - # Depth-wise convolution + # Middle depth-wise convolution if dw_kernel_size_mid: groups = num_groups(group_size, mid_chs) - self.conv_dw_mid = create_conv2d( + self.dw_mid = ConvNormActAa( mid_chs, mid_chs, dw_kernel_size_mid, stride=stride, dilation=dilation, # FIXME groups=groups, padding=pad_type, + act_layer=act_layer, + norm_layer=norm_layer, + aa_layer=aa_layer, **conv_kwargs, ) - self.norm_dw_mid = dw_norm_act_layer(mid_chs, inplace=True) else: # keeping mid as identity so it can be hooked more easily for features - self.conv_dw_mid = nn.Identity() - self.norm_dw_mid = nn.Identity() + self.dw_mid = nn.Identity() # Squeeze-and-excitation self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity() # Point-wise linear projection - self.conv_pwl = create_conv2d(mid_chs, out_chs, 1, padding=pad_type, **conv_kwargs) - self.norm_pwl = norm_act_layer(out_chs, apply_act=False) + self.pw_proj = ConvNormAct( + mid_chs, out_chs, 1, + padding=pad_type, + apply_act=False, + act_layer=act_layer, + norm_layer=norm_layer, + **conv_kwargs, + ) if dw_kernel_size_end: - self.conv_dw_end = create_conv2d( + dw_end_stride = stride if not dw_kernel_size_start and not dw_kernel_size_mid else 1 + dw_end_groups = num_groups(group_size, out_chs) + if dw_end_stride > 1: + assert not aa_layer + self.dw_end = ConvNormAct( out_chs, out_chs, dw_kernel_size_end, + stride=dw_end_stride, dilation=dilation, - depthwise=True, + groups=dw_end_groups, padding=pad_type, + apply_act=False, + act_layer=act_layer, + norm_layer=norm_layer, **conv_kwargs, ) - self.norm_dw_end = dw_norm_act_layer(out_chs, apply_act=False) else: - # end is None when not in use for cleaner repr - self.conv_dw_end = None - self.norm_dw_end = None + self.dw_end = nn.Identity() if layer_scale_init_value is not None: self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value) @@ -393,25 +428,18 @@ class UniversalInvertedResidual(nn.Module): def feature_info(self, location): if location == 'expansion': # after SE, input to PWL - return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels) + return dict(module='pw_proj.conv', hook_type='forward_pre', num_chs=self.pw_proj.conv.in_channels) else: # location == 'bottleneck', block output - return dict(module='', num_chs=self.conv_pwl.out_channels) + return 
dict(module='', num_chs=self.pw_proj.conv.out_channels) def forward(self, x): shortcut = x - if self.conv_dw_start is not None: - x = self.conv_dw_start(x) - x = self.norm_dw_start(x) - x = self.conv_pw(x) - x = self.norm_pw(x) - x = self.conv_dw_mid(x) - x = self.norm_dw_mid(x) + x = self.dw_start(x) + x = self.pw_exp(x) + x = self.dw_mid(x) x = self.se(x) - x = self.conv_pwl(x) - x = self.norm_pwl(x) - if self.conv_dw_end is not None: - x = self.conv_dw_end(x) - x = self.norm_dw_end(x) + x = self.pw_proj(x) + x = self.dw_end(x) x = self.layer_scale(x) if self.has_skip: x = self.drop_path(x) + shortcut @@ -426,29 +454,30 @@ class MobileAttention(nn.Module): """ def __init__( self, - in_chs, - out_chs, - stride=1, - dw_kernel_size=3, - dilation=1, - group_size=1, - pad_type='', + in_chs: int, + out_chs: int, + stride: int = 1, + dw_kernel_size: int = 3, + dilation: int = 1, + group_size: int = 1, + pad_type: str = '', num_heads: int = 8, key_dim: int = 64, value_dim: int = 64, use_multi_query: bool = False, query_strides: int = (1, 1), kv_stride: int = 1, - cpe_dw_kernel_size=3, - noskip=False, - act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, - drop_path_rate=0., - attn_drop=0.0, - proj_drop=0.0, + cpe_dw_kernel_size: int = 3, + noskip: bool = False, + act_layer: LayerType = nn.ReLU, + norm_layer: LayerType = nn.BatchNorm2d, + aa_layer: Optional[LayerType] = None, + drop_path_rate: float = 0., + attn_drop: float = 0.0, + proj_drop: float = 0.0, layer_scale_init_value: Optional[float] = 1e-5, - use_bias=False, - use_cpe=False, + use_bias: bool = False, + use_cpe: bool = False, ): super(MobileAttention, self).__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer) @@ -512,7 +541,6 @@ class MobileAttention(nn.Module): self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() - def feature_info(self, location): if location == 'expansion': # after SE, input to PW return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels) @@ -539,22 +567,23 @@ class CondConvResidual(InvertedResidual): def __init__( self, - in_chs, - out_chs, - dw_kernel_size=3, - stride=1, - dilation=1, - group_size=1, - pad_type='', - noskip=False, - exp_ratio=1.0, - exp_kernel_size=1, - pw_kernel_size=1, - act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, - se_layer=None, - num_experts=0, - drop_path_rate=0., + in_chs: int, + out_chs: int, + dw_kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + group_size: int = 1, + pad_type: str = '', + noskip: bool = False, + exp_ratio: float = 1.0, + exp_kernel_size: int = 1, + pw_kernel_size: int = 1, + act_layer: LayerType = nn.ReLU, + norm_layer: LayerType = nn.BatchNorm2d, + aa_layer: Optional[LayerType] = None, + se_layer: Optional[ModuleType] = None, + num_experts: int = 0, + drop_path_rate: float = 0., ): self.num_experts = num_experts @@ -567,13 +596,14 @@ class CondConvResidual(InvertedResidual): dilation=dilation, group_size=group_size, pad_type=pad_type, - act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size, pw_kernel_size=pw_kernel_size, - se_layer=se_layer, + act_layer=act_layer, norm_layer=norm_layer, + aa_layer=aa_layer, + se_layer=se_layer, conv_kwargs=conv_kwargs, drop_path_rate=drop_path_rate, ) @@ -609,21 +639,22 @@ class EdgeResidual(nn.Module): def __init__( self, - in_chs, - out_chs, - exp_kernel_size=3, - stride=1, - dilation=1, - group_size=0, - pad_type='', - force_in_chs=0, - noskip=False, - exp_ratio=1.0, - pw_kernel_size=1, - act_layer=nn.ReLU, - 
norm_layer=nn.BatchNorm2d, - se_layer=None, - drop_path_rate=0., + in_chs: int, + out_chs: int, + exp_kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + group_size: int = 0, + pad_type: str = '', + force_in_chs: int = 0, + noskip: bool = False, + exp_ratio: float = 1.0, + pw_kernel_size: int = 1, + act_layer: LayerType = nn.ReLU, + norm_layer: LayerType = nn.BatchNorm2d, + aa_layer: Optional[LayerType] = None, + se_layer: Optional[ModuleType] = None, + drop_path_rate: float = 0., ): super(EdgeResidual, self).__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer) @@ -633,13 +664,17 @@ class EdgeResidual(nn.Module): mid_chs = make_divisible(in_chs * exp_ratio) groups = num_groups(group_size, in_chs) self.has_skip = (in_chs == out_chs and stride == 1) and not noskip + use_aa = aa_layer is not None and stride > 1 # FIXME handle dilation # Expansion convolution self.conv_exp = create_conv2d( in_chs, mid_chs, exp_kernel_size, - stride=stride, dilation=dilation, groups=groups, padding=pad_type) + stride=1 if use_aa else stride, + dilation=dilation, groups=groups, padding=pad_type) self.bn1 = norm_act_layer(mid_chs, inplace=True) + self.aa = create_aa(aa_layer, channels=mid_chs, stride=stride, enable=use_aa) + # Squeeze-and-excitation self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity() @@ -658,6 +693,7 @@ class EdgeResidual(nn.Module): shortcut = x x = self.conv_exp(x) x = self.bn1(x) + x = self.aa(x) x = self.se(x) x = self.conv_pwl(x) x = self.bn2(x) diff --git a/timm/models/_efficientnet_builder.py b/timm/models/_efficientnet_builder.py index 7d96216a..e9b789a4 100644 --- a/timm/models/_efficientnet_builder.py +++ b/timm/models/_efficientnet_builder.py @@ -17,7 +17,7 @@ from typing import Any, Dict, List import torch.nn as nn from ._efficientnet_blocks import * -from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible +from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible, LayerType __all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights", 'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT'] @@ -326,9 +326,10 @@ class EfficientNetBuilder: pad_type: str = '', round_chs_fn: Callable = round_channels, se_from_exp: bool = False, - act_layer: Optional[Callable] = None, - norm_layer: Optional[Callable] = None, - se_layer: Optional[Callable] = None, + act_layer: Optional[LayerType] = None, + norm_layer: Optional[LayerType] = None, + aa_layer: Optional[LayerType] = None, + se_layer: Optional[LayerType] = None, drop_path_rate: float = 0., layer_scale_init_value: Optional[float] = None, feature_location: str = '', @@ -339,6 +340,7 @@ class EfficientNetBuilder: self.se_from_exp = se_from_exp # calculate se channel reduction from expanded (mid) chs self.act_layer = act_layer self.norm_layer = norm_layer + self.aa_layer = aa_layer self.se_layer = get_attn(se_layer) try: self.se_layer(8, rd_ratio=1.0) # test if attn layer accepts rd_ratio arg @@ -378,6 +380,9 @@ class EfficientNetBuilder: ba['norm_layer'] = self.norm_layer ba['drop_path_rate'] = drop_path_rate + if self.aa_layer is not None: + ba['aa_layer'] = self.aa_layer + se_ratio = ba.pop('se_ratio', None) if se_ratio and self.se_layer is not None: if not self.se_from_exp: @@ -461,6 +466,7 @@ class EfficientNetBuilder: space2depth = 1 if space2depth > 0: + # FIXME s2d is a WIP if space2depth == 2 and block_args['stride'] == 2: 
block_args['stride'] = 1 # to end s2d region, need to correct expansion and se ratio relative to input diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index dcb0db7e..46c4e81e 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -36,7 +36,7 @@ the models and weights open source! Hacked together by / Copyright 2019, Ross Wightman """ from functools import partial -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -44,10 +44,10 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct +from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct, LayerType from ._builder import build_model_with_cfg, pretrained_cfg_for_features from ._efficientnet_blocks import SqueezeExcite -from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ +from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from ._features import FeatureInfo, FeatureHooks, feature_take_indices from ._manipulate import checkpoint_seq @@ -74,21 +74,22 @@ class EfficientNet(nn.Module): def __init__( self, - block_args, - num_classes=1000, - num_features=1280, - in_chans=3, - stem_size=32, - fix_stem=False, - output_stride=32, - pad_type='', - round_chs_fn=round_channels, - act_layer=None, - norm_layer=None, - se_layer=None, - drop_rate=0., - drop_path_rate=0., - global_pool='avg' + block_args: BlockArgs, + num_classes: int = 1000, + num_features: int = 1280, + in_chans: int = 3, + stem_size: int = 32, + fix_stem: bool = False, + output_stride: int = 32, + pad_type: str = '', + act_layer: Optional[LayerType] = None, + norm_layer: Optional[LayerType] = None, + aa_layer: Optional[LayerType] = None, + se_layer: Optional[LayerType] = None, + round_chs_fn: Callable = round_channels, + drop_rate: float = 0., + drop_path_rate: float = 0., + global_pool: str = 'avg' ): super(EfficientNet, self).__init__() act_layer = act_layer or nn.ReLU @@ -113,6 +114,7 @@ class EfficientNet(nn.Module): round_chs_fn=round_chs_fn, act_layer=act_layer, norm_layer=norm_layer, + aa_layer=aa_layer, se_layer=se_layer, drop_path_rate=drop_path_rate, ) @@ -270,20 +272,21 @@ class EfficientNetFeatures(nn.Module): def __init__( self, - block_args, - out_indices=(0, 1, 2, 3, 4), - feature_location='bottleneck', - in_chans=3, - stem_size=32, - fix_stem=False, - output_stride=32, - pad_type='', - round_chs_fn=round_channels, - act_layer=None, - norm_layer=None, - se_layer=None, - drop_rate=0., - drop_path_rate=0. + block_args: BlockArgs, + out_indices: Tuple[int, ...] 
= (0, 1, 2, 3, 4), + feature_location: str = 'bottleneck', + in_chans: int = 3, + stem_size: int = 32, + fix_stem: bool = False, + output_stride: int = 32, + pad_type: str = '', + act_layer: Optional[LayerType] = None, + norm_layer: Optional[LayerType] = None, + aa_layer: Optional[LayerType] = None, + se_layer: Optional[LayerType] = None, + round_chs_fn: Callable = round_channels, + drop_rate: float = 0., + drop_path_rate: float = 0., ): super(EfficientNetFeatures, self).__init__() act_layer = act_layer or nn.ReLU @@ -306,6 +309,7 @@ class EfficientNetFeatures(nn.Module): round_chs_fn=round_chs_fn, act_layer=act_layer, norm_layer=norm_layer, + aa_layer=aa_layer, se_layer=se_layer, drop_path_rate=drop_path_rate, feature_location=feature_location, @@ -1154,6 +1158,7 @@ default_cfgs = generate_default_cfgs({ input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0), 'efficientnet_b3_g8_gn.untrained': _cfg( input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0), + 'efficientnet_blur_b0.untrained': _cfg(), 'efficientnet_es.ra_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth', @@ -1850,6 +1855,17 @@ def efficientnet_b3_g8_gn(pretrained=False, **kwargs) -> EfficientNet: return model +@register_model +def efficientnet_blur_b0(pretrained=False, **kwargs) -> EfficientNet: + """ EfficientNet-B0 w/ BlurPool """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_blur_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, + aa_layer='blurpc', **kwargs + ) + return model + + @register_model def efficientnet_es(pretrained=False, **kwargs) -> EfficientNet: """ EfficientNet-Edge Small. """ diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 40f201b9..b25d87ba 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -40,6 +40,7 @@ class MobileNetV3(nn.Module): * HardCoRe-NAS - https://arxiv.org/abs/2102.11646 (defn in hardcorenas.py uses this class) * FBNet-V3 - https://arxiv.org/abs/2006.02049 * LCNet - https://arxiv.org/abs/2109.15099 + * MobileNet-V4 - https://arxiv.org/abs/2404.10518 """ def __init__( @@ -52,9 +53,10 @@ class MobileNetV3(nn.Module): num_features: int = 1280, head_bias: bool = True, head_norm: bool = False, - pad_type: PadType = '', + pad_type: str = '', act_layer: Optional[LayerType] = None, norm_layer: Optional[LayerType] = None, + aa_layer: Optional[LayerType] = None, se_layer: Optional[LayerType] = None, se_from_exp: bool = True, round_chs_fn: Callable = round_channels, @@ -75,6 +77,7 @@ class MobileNetV3(nn.Module): pad_type: Type of padding to use for convolution layers. act_layer: Type of activation layer. norm_layer: Type of normalization layer. + aa_layer: Type of anti-aliasing layer. se_layer: Type of Squeeze-and-Excite layer. se_from_exp: If True, calculate SE channel reduction from expanded mid channels. round_chs_fn: Callable to round number of filters based on depth multiplier. 
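For orientation, a minimal usage sketch of the anti-aliasing option being threaded through here; it relies on the efficientnet_blur_b0 entrypoint registered elsewhere in this patch and assumes that extra kwargs such as aa_layer='avg' are forwarded unchanged through create_model into these constructors:

    import torch
    import timm

    # efficientnet_blur_b0 is registered in this patch with aa_layer='blurpc'
    # (BlurPool2d with constant padding) wired in internally.
    model = timm.create_model('efficientnet_blur_b0', pretrained=False)

    # aa_layer should also be settable on other models in the family, since the
    # model constructors and EfficientNetBuilder now accept it.
    model_aa = timm.create_model('mobilenetv4_conv_medium', pretrained=False, aa_layer='avg')

    x = torch.randn(2, 3, 224, 224)
    print(model(x).shape, model_aa(x).shape)  # expected: torch.Size([2, 1000]) for both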
@@ -107,6 +110,7 @@ class MobileNetV3(nn.Module): se_from_exp=se_from_exp, act_layer=act_layer, norm_layer=norm_layer, + aa_layer=aa_layer, se_layer=se_layer, drop_path_rate=drop_path_rate, layer_scale_init_value=layer_scale_init_value, @@ -291,6 +295,7 @@ class MobileNetV3Features(nn.Module): se_from_exp: bool = True, act_layer: Optional[LayerType] = None, norm_layer: Optional[LayerType] = None, + aa_layer: Optional[LayerType] = None, se_layer: Optional[LayerType] = None, drop_rate: float = 0., drop_path_rate: float = 0., @@ -337,6 +342,7 @@ class MobileNetV3Features(nn.Module): se_from_exp=se_from_exp, act_layer=act_layer, norm_layer=norm_layer, + aa_layer=aa_layer, se_layer=se_layer, drop_path_rate=drop_path_rate, layer_scale_init_value=layer_scale_init_value, @@ -649,15 +655,17 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: Args: channel_multiplier: multiplier to number of channels per layer. """ + num_features = 1280 if 'hybrid' in variant: layer_scale_init_value = 1e-5 if 'medium' in variant: stem_size = 32 - num_features = 1280 act_layer = resolve_act_layer(kwargs, 'relu') arch_def = [ # stage 0, 112x112 in - ['er_r1_k3_s2_e4_c48'], # FusedIB (EdgeResidual) + [ + 'er_r1_k3_s2_e4_c48' # FusedIB (EdgeResidual) + ], # stage 1, 56x56 in [ 'uir_r1_a3_k5_s2_e4_c80', # ExtraDW @@ -689,23 +697,26 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: 'uir_r1_a0_k0_s1_e4_c256', # FFN 'mqa_r1_k3_h4_s1_d64_c256', # MQA 'uir_r1_a3_k0_s1_e4_c256', # ConvNeXt - 'mqa_r1_k3_h4_s1_d64_c256', # MQA + 'mqa_r1_k3_h4_s1_d64_c256', # MQA 'uir_r1_a5_k5_s1_e4_c256', # ExtraDW - 'mqa_r1_k3_h4_s1_d64_c256', # MQA + 'mqa_r1_k3_h4_s1_d64_c256', # MQA 'uir_r1_a5_k0_s1_e4_c256', # ConvNeXt 'mqa_r1_k3_h4_s1_d64_c256', # MQA 'uir_r1_a5_k0_s1_e4_c256', # ConvNeXt ], # stage 4, 7x7 in - ['cn_r1_k1_s1_c960'], # Conv + [ + 'cn_r1_k1_s1_c960' # Conv + ], ] elif 'large' in variant: stem_size = 24 - num_features = 1280 act_layer = resolve_act_layer(kwargs, 'gelu') arch_def = [ # stage 0, 112x112 in - ['er_r1_k3_s2_e4_c48'], # FusedIB (EdgeResidual) + [ + 'er_r1_k3_s2_e4_c48', # FusedIB (EdgeResidual) + ], # stage 1, 56x56 in [ 'uir_r1_a3_k5_s2_e4_c96', # ExtraDW @@ -734,17 +745,19 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: 'uir_r2_a5_k0_s1_e4_c512', # ConvNeXt 'uir_r1_a5_k3_s1_e4_c512', # ExtraDW 'uir_r1_a5_k5_s1_e4_c512', # ExtraDW - 'mqa_r1_k3_h8_s1_d64_c512', # MQA + 'mqa_r1_k3_h8_s1_d64_c512', # MQA 'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt - 'mqa_r1_k3_h8_s1_d64_c512', # MQA + 'mqa_r1_k3_h8_s1_d64_c512', # MQA 'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt - 'mqa_r1_k3_h8_s1_d64_c512', # MQA + 'mqa_r1_k3_h8_s1_d64_c512', # MQA 'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt - 'mqa_r1_k3_h8_s1_d64_c512', # MQA + 'mqa_r1_k3_h8_s1_d64_c512', # MQA 'uir_r1_a5_k0_s1_e4_c512', # ConvNeXt ], # stage 4, 7x7 in - ['cn_r1_k1_s1_c960'], + [ + 'cn_r1_k1_s1_c960', # Conv + ], ] else: assert False, f'Unknown variant {variant}.' 
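As a reading aid for the block-definition strings in the arch_def tables above, a small sketch of how one 'uir' entry is expected to expand through the builder's decoder; the option-letter mapping assumed here (a = leading depthwise kernel, k = middle depthwise kernel, s = stride, e = expansion ratio, c = output channels) should be checked against _decode_block_str:

    from timm.models._efficientnet_builder import decode_arch_def

    # One stage holding a single ExtraDW-style universal inverted bottleneck,
    # as used in the MobileNet-V4 definitions above.
    block_args = decode_arch_def([['uir_r1_a3_k5_s2_e4_c80']])
    print(block_args[0][0])
    # expected (roughly): block_type='uir', dw_kernel_size_start=3,
    # dw_kernel_size_mid=5, stride=2, exp_ratio=4.0, out_chs=80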
@@ -752,7 +765,6 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: layer_scale_init_value = None if 'small' in variant: stem_size = 32 - num_features = 1280 act_layer = resolve_act_layer(kwargs, 'relu') arch_def = [ # stage 0, 112x112 in @@ -780,15 +792,18 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: 'uir_r2_a0_k3_s1_e4_c128', # IR ], # stage 4, 7x7 in - ['cn_r1_k1_s1_c960'], # Conv + [ + 'cn_r1_k1_s1_c960', # Conv + ], ] elif 'medium' in variant: stem_size = 32 - num_features = 1280 act_layer = resolve_act_layer(kwargs, 'relu') arch_def = [ # stage 0, 112x112 in - ['er_r1_k3_s2_e4_c48'], # FusedIB (EdgeResidual) + [ + 'er_r1_k3_s2_e4_c48', # FusedIB (EdgeResidual) + ], # stage 1, 56x56 in [ 'uir_r1_a3_k5_s2_e4_c80', # ExtraDW @@ -817,15 +832,18 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: 'uir_r1_a5_k0_s1_e2_c256', # ConvNeXt ], # stage 4, 7x7 in - ['cn_r1_k1_s1_c960'], # Conv + [ + 'cn_r1_k1_s1_c960', # Conv + ], ] elif 'large' in variant: stem_size = 24 - num_features = 1280 act_layer = resolve_act_layer(kwargs, 'relu') arch_def = [ # stage 0, 112x112 in - ['er_r1_k3_s2_e4_c48'], # FusedIB (EdgeResidual) + [ + 'er_r1_k3_s2_e4_c48', # FusedIB (EdgeResidual) + ], # stage 1, 56x56 in [ 'uir_r1_a3_k5_s2_e4_c96', # ExtraDW @@ -851,24 +869,23 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: ], # stage 4, 7x7 in - ['cn_r1_k1_s1_c960'], # Conv + [ + 'cn_r1_k1_s1_c960', # Conv + ], ] else: assert False, f'Unknown variant {variant}.' - # NOTE SE not used in initial MobileNet-v4 definitions - se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels) model_kwargs = dict( block_args=decode_arch_def(arch_def), head_bias=False, head_norm=True, num_features=num_features, stem_size=stem_size, - fix_stem=channel_multiplier < 0.75, + fix_stem=channel_multiplier < 1.0, round_chs_fn=partial(round_channels, multiplier=channel_multiplier), norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), act_layer=act_layer, - se_layer=se_layer, layer_scale_init_value=layer_scale_init_value, **kwargs, ) @@ -904,9 +921,6 @@ default_cfgs = generate_default_cfgs({ origin_url='https://github.com/Alibaba-MIIL/ImageNet21K', paper_ids='arXiv:2104.10972v4', interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), num_classes=11221), - 'mobilenetv3_large_150.untrained': _cfg( - interpolation='bicubic'), - 'mobilenetv3_small_050.lamb_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth', @@ -985,28 +999,48 @@ default_cfgs = generate_default_cfgs({ 'mobilenetv4_conv_small': _cfg( # hf_hub_id='timm/', interpolation='bicubic'), - 'mobilenetv4_conv_medium': _cfg( - #hf_hub_id='timm/', - interpolation='bicubic'), - 'mobilenetv4_conv_large': _cfg( + 'mobilenetv4_conv_medium.r224': _cfg( # hf_hub_id='timm/', - interpolation='bicubic'), + crop_pct=0.95, interpolation='bicubic'), + 'mobilenetv4_conv_medium.r256': _cfg( + # hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'), + 'mobilenetv4_conv_large.r256': _cfg( + # hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'), + 'mobilenetv4_conv_large.r384': _cfg( + # hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=0.95, interpolation='bicubic'), 
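    # NOTE the .r224 / .r256 / .r384 tag suffixes in these entries denote the input
    # resolution each pretrained cfg targets (matching its input_size / pool_size).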
'mobilenetv4_hybrid_small': _cfg( # hf_hub_id='timm/', interpolation='bicubic'), - 'mobilenetv4_hybrid_medium': _cfg( + 'mobilenetv4_hybrid_medium.r224': _cfg( # hf_hub_id='timm/', - interpolation='bicubic'), - 'mobilenetv4_hybrid_large': _cfg( + crop_pct=0.95, interpolation='bicubic'), + 'mobilenetv4_hybrid_medium.r256': _cfg( # hf_hub_id='timm/', - interpolation='bicubic'), + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'), + 'mobilenetv4_hybrid_large.r256': _cfg( + # hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'), + 'mobilenetv4_hybrid_large.r384': _cfg( + # hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=0.95, interpolation='bicubic'), + + # experimental + 'mobilenetv4_conv_aa_medium.r256': _cfg( + # hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'), + 'mobilenetv4_conv_blur_medium.r256': _cfg( + # hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'), 'mobilenetv4_hybrid_medium_075': _cfg( # hf_hub_id='timm/', - interpolation='bicubic'), - 'mobilenetv4_hybrid_medium_150': _cfg( + crop_pct=0.95, interpolation='bicubic'), + 'mobilenetv4_hybrid_large_075.r256': _cfg( # hf_hub_id='timm/', - interpolation='bicubic'), + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'), }) @@ -1024,13 +1058,6 @@ def mobilenetv3_large_100(pretrained: bool = False, **kwargs) -> MobileNetV3: return model -@register_model -def mobilenetv3_large_150(pretrained: bool = False, **kwargs) -> MobileNetV3: - """ MobileNet V3 """ - model = _gen_mobilenet_v3('mobilenetv3_large_150', 1.5, pretrained=pretrained, **kwargs) - return model - - @register_model def mobilenetv3_small_050(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ @@ -1191,13 +1218,6 @@ def mobilenetv4_conv_large(pretrained: bool = False, **kwargs) -> MobileNetV3: return model -@register_model -def mobilenetv4_hybrid_medium_075(pretrained: bool = False, **kwargs) -> MobileNetV3: - """ MobileNet V4 Hybrid """ - model = _gen_mobilenet_v4('mobilenetv4_hybrid_medium_075', 0.75, pretrained=pretrained, **kwargs) - return model - - @register_model def mobilenetv4_hybrid_medium(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V4 Hybrid """ @@ -1205,13 +1225,6 @@ def mobilenetv4_hybrid_medium(pretrained: bool = False, **kwargs) -> MobileNetV3 return model -@register_model -def mobilenetv4_hybrid_medium_150(pretrained: bool = False, **kwargs) -> MobileNetV3: - """ MobileNet V4 Hybrid """ - model = _gen_mobilenet_v4('mobilenetv4_hybrid_medium_150', 1.5, pretrained=pretrained, **kwargs) - return model - - @register_model def mobilenetv4_hybrid_large(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V4 Hybrid""" @@ -1219,6 +1232,33 @@ def mobilenetv4_hybrid_large(pretrained: bool = False, **kwargs) -> MobileNetV3: return model +@register_model +def mobilenetv4_conv_aa_medium(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V4 w/ AvgPool AA """ + model = _gen_mobilenet_v4('mobilenetv4_conv_aa_medium', 1.0, pretrained=pretrained, aa_layer='avg', **kwargs) + return model + + +@register_model +def mobilenetv4_conv_blur_medium(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V4 Conv w/ Blur AA """ + model = _gen_mobilenet_v4('mobilenetv4_conv_blur_medium', 1.0, pretrained=pretrained, aa_layer='blurpc', **kwargs) + return 
model + + +@register_model +def mobilenetv4_hybrid_medium_075(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V4 Hybrid """ + model = _gen_mobilenet_v4('mobilenetv4_hybrid_medium_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv4_hybrid_large_075(pretrained: bool = False, **kwargs) -> MobileNetV3: + """ MobileNet V4 Hybrid""" + model = _gen_mobilenet_v4('mobilenetv4_hybrid_large', 0.75, pretrained=pretrained, **kwargs) + return model + register_model_deprecations(__name__, { 'mobilenetv3_large_100_miil': 'mobilenetv3_large_100.miil_in21k_ft_in1k', diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 53dfab9c..15f16997 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -17,7 +17,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \ - get_attn, get_act_layer, get_norm_layer, create_classifier + get_attn, get_act_layer, get_norm_layer, create_classifier, create_aa from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import checkpoint_seq @@ -31,15 +31,6 @@ def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int: return padding -def create_aa(aa_layer: Type[nn.Module], channels: int, stride: int = 2, enable: bool = True) -> nn.Module: - if not aa_layer or not enable: - return nn.Identity() - if issubclass(aa_layer, nn.AvgPool2d): - return aa_layer(stride) - else: - return aa_layer(channels=channels, stride=stride) - - class BasicBlock(nn.Module): expansion = 1 From a503639bcce2f0d3e8c6a7459cabb7dd6aafa4c2 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 30 May 2024 10:17:09 -0700 Subject: [PATCH 16/27] Add mobileclip fastvit model defs, support extra SE. Add forward_intermediates API to fastvit --- timm/models/fastvit.py | 405 +++++++++++++++++++++++++++++++---------- 1 file changed, 307 insertions(+), 98 deletions(-) diff --git a/timm/models/fastvit.py b/timm/models/fastvit.py index 74b6cc28..7c918887 100644 --- a/timm/models/fastvit.py +++ b/timm/models/fastvit.py @@ -7,7 +7,7 @@ # import os from functools import partial -from typing import Tuple, Optional, Union +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -16,6 +16,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn, \ ClassifierHead from ._builder import build_model_with_cfg +from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -40,19 +41,19 @@ class MobileOneBlock(nn.Module): """ def __init__( - self, - in_chs: int, - out_chs: int, - kernel_size: int, - stride: int = 1, - dilation: int = 1, - group_size: int = 0, - inference_mode: bool = False, - use_se: bool = False, - use_act: bool = True, - use_scale_branch: bool = True, - num_conv_branches: int = 1, - act_layer: nn.Module = nn.GELU, + self, + in_chs: int, + out_chs: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + group_size: int = 0, + inference_mode: bool = False, + use_se: bool = False, + use_act: bool = True, + use_scale_branch: bool = True, + num_conv_branches: int = 1, + act_layer: nn.Module = nn.GELU, ) -> None: """Construct a MobileOneBlock module. 
@@ -280,15 +281,16 @@ class ReparamLargeKernelConv(nn.Module): """ def __init__( - self, - in_chs: int, - out_chs: int, - kernel_size: int, - stride: int, - group_size: int, - small_kernel: Optional[int] = None, - inference_mode: bool = False, - act_layer: Optional[nn.Module] = None, + self, + in_chs: int, + out_chs: int, + kernel_size: int, + stride: int, + group_size: int, + small_kernel: Optional[int] = None, + use_se: bool = False, + act_layer: Optional[nn.Module] = None, + inference_mode: bool = False, ) -> None: """Construct a ReparamLargeKernelConv module. @@ -299,8 +301,8 @@ class ReparamLargeKernelConv(nn.Module): stride: Stride size. Default: 1 group_size: Group size. Default: 1 small_kernel: Kernel size of small kernel conv branch. - inference_mode: If True, instantiates model in inference mode. Default: ``False`` act_layer: Activation module. Default: ``nn.GELU`` + inference_mode: If True, instantiates model in inference mode. Default: ``False`` """ super(ReparamLargeKernelConv, self).__init__() self.stride = stride @@ -342,6 +344,7 @@ class ReparamLargeKernelConv(nn.Module): groups=self.groups, apply_act=False, ) + self.se = SqueezeExcite(out_chs, rd_ratio=0.25) if use_se else nn.Identity() # FIXME output of this act was not used in original impl, likely due to bug self.act = act_layer() if act_layer is not None else nn.Identity() @@ -352,6 +355,7 @@ class ReparamLargeKernelConv(nn.Module): out = self.large_conv(x) if self.small_conv is not None: out = out + self.small_conv(x) + out = self.se(out) out = self.act(out) return out @@ -472,12 +476,12 @@ class Attention(nn.Module): fused_attn: torch.jit.Final[bool] def __init__( - self, - dim: int, - head_dim: int = 32, - qkv_bias: bool = False, - attn_drop: float = 0.0, - proj_drop: float = 0.0, + self, + dim: int, + head_dim: int = 32, + qkv_bias: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, ) -> None: """Build MHSA module that can handle 3D or 4D input tensors. @@ -535,14 +539,15 @@ class PatchEmbed(nn.Module): """Convolutional patch embedding layer.""" def __init__( - self, - patch_size: int, - stride: int, - in_chs: int, - embed_dim: int, - act_layer: nn.Module = nn.GELU, - lkc_use_act: bool = False, - inference_mode: bool = False, + self, + patch_size: int, + stride: int, + in_chs: int, + embed_dim: int, + act_layer: nn.Module = nn.GELU, + lkc_use_act: bool = False, + use_se: bool = False, + inference_mode: bool = False, ) -> None: """Build patch embedding layer. @@ -562,14 +567,16 @@ class PatchEmbed(nn.Module): stride=stride, group_size=1, small_kernel=3, - inference_mode=inference_mode, + use_se=use_se, act_layer=act_layer if lkc_use_act else None, # NOTE original weights didn't use this act + inference_mode=inference_mode, ), MobileOneBlock( in_chs=embed_dim, out_chs=embed_dim, kernel_size=1, stride=1, + use_se=False, act_layer=act_layer, inference_mode=inference_mode, ) @@ -598,11 +605,11 @@ class RepMixer(nn.Module): """ def __init__( - self, - dim, - kernel_size=3, - layer_scale_init_value=1e-5, - inference_mode: bool = False, + self, + dim, + kernel_size=3, + layer_scale_init_value=1e-5, + inference_mode: bool = False, ): """Build RepMixer Module. 
@@ -648,7 +655,7 @@ class RepMixer(nn.Module): if layer_scale_init_value is not None: self.layer_scale = LayerScale2d(dim, layer_scale_init_value) else: - self.layer_scale = nn.Identity + self.layer_scale = nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: if self.reparam_conv is not None: @@ -706,12 +713,12 @@ class ConvMlp(nn.Module): """Convolutional FFN Module.""" def __init__( - self, - in_chs: int, - hidden_channels: Optional[int] = None, - out_chs: Optional[int] = None, - act_layer: nn.Module = nn.GELU, - drop: float = 0.0, + self, + in_chs: int, + hidden_channels: Optional[int] = None, + out_chs: Optional[int] = None, + act_layer: nn.Module = nn.GELU, + drop: float = 0.0, ) -> None: """Build convolutional FFN module. @@ -764,11 +771,11 @@ class RepConditionalPosEnc(nn.Module): """ def __init__( - self, - dim: int, - dim_out: Optional[int] = None, - spatial_shape: Union[int, Tuple[int, int]] = (7, 7), - inference_mode=False, + self, + dim: int, + dim_out: Optional[int] = None, + spatial_shape: Union[int, Tuple[int, int]] = (7, 7), + inference_mode=False, ) -> None: """Build reparameterizable conditional positional encoding @@ -878,15 +885,15 @@ class RepMixerBlock(nn.Module): """ def __init__( - self, - dim: int, - kernel_size: int = 3, - mlp_ratio: float = 4.0, - act_layer: nn.Module = nn.GELU, - proj_drop: float = 0.0, - drop_path: float = 0.0, - layer_scale_init_value: float = 1e-5, - inference_mode: bool = False, + self, + dim: int, + kernel_size: int = 3, + mlp_ratio: float = 4.0, + act_layer: nn.Module = nn.GELU, + proj_drop: float = 0.0, + drop_path: float = 0.0, + layer_scale_init_value: float = 1e-5, + inference_mode: bool = False, ): """Build RepMixer Block. @@ -936,14 +943,14 @@ class AttentionBlock(nn.Module): """ def __init__( - self, - dim: int, - mlp_ratio: float = 4.0, - act_layer: nn.Module = nn.GELU, - norm_layer: nn.Module = nn.BatchNorm2d, - proj_drop: float = 0.0, - drop_path: float = 0.0, - layer_scale_init_value: float = 1e-5, + self, + dim: int, + mlp_ratio: float = 4.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.BatchNorm2d, + proj_drop: float = 0.0, + drop_path: float = 0.0, + layer_scale_init_value: float = 1e-5, ): """Build Attention Block. @@ -993,6 +1000,7 @@ class FastVitStage(nn.Module): depth: int, token_mixer_type: str, downsample: bool = True, + se_downsample: bool = False, down_patch_size: int = 7, down_stride: int = 2, pos_emb_layer: Optional[nn.Module] = None, @@ -1030,6 +1038,7 @@ class FastVitStage(nn.Module): stride=down_stride, in_chs=dim, embed_dim=dim_out, + use_se=se_downsample, act_layer=act_layer, lkc_use_act=lkc_use_act, inference_mode=inference_mode, @@ -1090,29 +1099,30 @@ class FastVit(nn.Module): """ def __init__( - self, - in_chans: int = 3, - layers: Tuple[int, ...] = (2, 2, 6, 2), - token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"), - embed_dims: Tuple[int, ...] = (64, 128, 256, 512), - mlp_ratios: Tuple[float, ...] = (4,) * 4, - downsamples: Tuple[bool, ...] = (False, True, True, True), - repmixer_kernel_size: int = 3, - num_classes: int = 1000, - pos_embs: Tuple[Optional[nn.Module], ...] 
= (None,) * 4, - down_patch_size: int = 7, - down_stride: int = 2, - drop_rate: float = 0.0, - proj_drop_rate: float = 0.0, - drop_path_rate: float = 0.0, - layer_scale_init_value: float = 1e-5, - fork_feat: bool = False, - cls_ratio: float = 2.0, - global_pool: str = 'avg', - norm_layer: nn.Module = nn.BatchNorm2d, - act_layer: nn.Module = nn.GELU, - lkc_use_act: bool = False, - inference_mode: bool = False, + self, + in_chans: int = 3, + layers: Tuple[int, ...] = (2, 2, 6, 2), + token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"), + embed_dims: Tuple[int, ...] = (64, 128, 256, 512), + mlp_ratios: Tuple[float, ...] = (4,) * 4, + downsamples: Tuple[bool, ...] = (False, True, True, True), + se_downsamples: Tuple[bool, ...] = (False, False, False, False), + repmixer_kernel_size: int = 3, + num_classes: int = 1000, + pos_embs: Tuple[Optional[nn.Module], ...] = (None,) * 4, + down_patch_size: int = 7, + down_stride: int = 2, + drop_rate: float = 0.0, + proj_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + layer_scale_init_value: float = 1e-5, + lkc_use_act: bool = False, + fork_feat: bool = False, + cls_ratio: float = 2.0, + global_pool: str = 'avg', + norm_layer: nn.Module = nn.BatchNorm2d, + act_layer: nn.Module = nn.GELU, + inference_mode: bool = False, ) -> None: super().__init__() self.num_classes = 0 if fork_feat else num_classes @@ -1140,6 +1150,7 @@ class FastVit(nn.Module): dim_out=embed_dims[i], depth=layers[i], downsample=downsample, + se_downsample=se_downsamples[i], down_patch_size=down_patch_size, down_stride=down_stride, pos_emb_layer=pos_embs[i], @@ -1160,6 +1171,7 @@ class FastVit(nn.Module): scale *= 2 self.feature_info += [dict(num_chs=prev_dim, reduction=4 * scale, module=f'stages.{i}')] self.stages = nn.Sequential(*stages) + self.num_stages = len(self.stages) self.num_features = prev_dim # For segmentation and detection, extract intermediate output @@ -1236,6 +1248,66 @@ class FastVit(nn.Module): self.num_classes = num_classes self.head.reset(num_classes, global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int], Tuple[int]]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' 
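+        # NOTE: feature_take_indices() below resolves `indices` (last-n if int, all stages
+        # if None, explicit stage ids if a sequence) and also returns the highest selected
+        # index so stage iteration can terminate early when `stop_early` is set.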
+ intermediates = [] + take_indices, max_index = feature_take_indices(len(self.stages), indices) + + # forward pass + x = self.stem(x) + last_idx = self.num_stages - 1 + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + stages = self.stages + else: + stages = self.stages[:max_index + 1] + feat_idx = 0 + for feat_idx, stage in enumerate(stages): + x = stage(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + if feat_idx == last_idx: + x = self.final_conv(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x: torch.Tensor) -> torch.Tensor: # input embedding x = self.stem(x) @@ -1297,8 +1369,7 @@ default_cfgs = generate_default_cfgs({ "fastvit_ma36.apple_in1k": _cfg( hf_hub_id='timm/', - crop_pct=0.95 - ), + crop_pct=0.95), "fastvit_t8.apple_dist_in1k": _cfg( hf_hub_id='timm/'), @@ -1318,15 +1389,111 @@ default_cfgs = generate_default_cfgs({ hf_hub_id='timm/', crop_pct=0.95 ), + + "fastvit_mci0.apple_mclip": _cfg( + #hf_hub_id='timm/', + url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s0.pt', + crop_pct=0.95, + num_classes=512, # CLIP proj dim + mean=(0., 0., 0.), std=(1., 1., 1.) + ), + "fastvit_mci1.apple_mclip": _cfg( + # hf_hub_id='timm/', + url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s1.pt', + crop_pct=0.95, + num_classes=512, # CLIP proj dim + mean=(0., 0., 0.), std=(1., 1., 1.) + ), + "fastvit_mci2.apple_mclip": _cfg( + # hf_hub_id='timm/', + url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s2.pt', + crop_pct=0.95, + num_classes=512, # CLIP proj dim + mean=(0., 0., 0.), std=(1., 1., 1.) + ), }) +def _checkpoint_filter_fn(state_dict, model): + """ Remap original checkpoints -> timm """ + if 'stem.0.conv_kxk.0.conv.weight' in state_dict: + return state_dict # non-original checkpoint, no remapping needed + + state_dict = state_dict.get('state_dict', state_dict) + if 'image_encoder.model.head.proj' in state_dict: + # remap MobileCLIP checkpoints + prefix = 'image_encoder.model.' 
+ else: + prefix = '' + + import re + import bisect + + # find stage ends by locating downsample layers + stage_ends = [] + for k, v in state_dict.items(): + match = re.match(r'^(.*?)network\.(\d+)\.proj.*', k) + if match: + stage_ends.append(int(match.group(2))) + stage_ends = list(sorted(set(stage_ends))) + + out_dict = {} + for k, v in state_dict.items(): + if prefix: + if prefix not in k: + continue + k = k.replace(prefix, '') + + # remap renamed layers + k = k.replace('patch_embed', 'stem') + k = k.replace('rbr_conv', 'conv_kxk') + k = k.replace('rbr_scale', 'conv_scale') + k = k.replace('rbr_skip', 'identity') + k = k.replace('conv_exp', 'final_conv') # to match byobnet, regnet, nfnet + k = k.replace('lkb_origin', 'large_conv') + k = k.replace('convffn', 'mlp') + k = k.replace('se.reduce', 'se.fc1') + k = k.replace('se.expand', 'se.fc2') + k = re.sub(r'layer_scale_([0-9])', r'layer_scale_\1.gamma', k) + if k.endswith('layer_scale'): + k = k.replace('layer_scale', 'layer_scale.gamma') + k = k.replace('dist_head', 'head_dist') + if k.startswith('head.'): + if k == 'head.proj' and hasattr(model.head, 'fc') and isinstance(model.head.fc, nn.Linear): + # if CLIP projection, map to head.fc w/ bias = zeros + k = k.replace('head.proj', 'head.fc.weight') + v = v.T + out_dict['head.fc.bias'] = torch.zeros(v.shape[0]) + else: + k = k.replace('head.', 'head.fc.') + + # remap flat sequential network to stages + match = re.match(r'^network\.(\d+)', k) + stage_idx, net_idx = None, None + if match: + net_idx = int(match.group(1)) + stage_idx = bisect.bisect_right(stage_ends, net_idx) + if stage_idx is not None: + net_prefix = f'network.{net_idx}' + stage_prefix = f'stages.{stage_idx}' + if net_prefix + '.proj' in k: + k = k.replace(net_prefix + '.proj', stage_prefix + '.downsample.proj') + elif net_prefix + '.pe' in k: + k = k.replace(net_prefix + '.pe', stage_prefix + '.pos_emb.pos_enc') + else: + k = k.replace(net_prefix, stage_prefix + '.blocks') + + out_dict[k] = v + return out_dict + + def _create_fastvit(variant, pretrained=False, **kwargs): out_indices = kwargs.pop('out_indices', (0, 1, 2, 3)) model = build_model_with_cfg( FastVit, variant, pretrained, + pretrained_filter_fn=_checkpoint_filter_fn, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs ) @@ -1419,3 +1586,45 @@ def fastvit_ma36(pretrained=False, **kwargs): token_mixers=("repmixer", "repmixer", "repmixer", "attention") ) return _create_fastvit('fastvit_ma36', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def fastvit_mci0(pretrained=False, **kwargs): + """Instantiate MCi0 model variant.""" + model_args = dict( + layers=(2, 6, 10, 2), + embed_dims=(64, 128, 256, 512), + mlp_ratios=(3, 3, 3, 3), + se_downsamples=(False, False, True, True), + pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), + token_mixers=("repmixer", "repmixer", "repmixer", "attention"), + ) + return _create_fastvit('fastvit_mci0', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def fastvit_mci1(pretrained=False, **kwargs): + """Instantiate MCi1 model variant.""" + model_args = dict( + layers=(4, 12, 20, 4), + embed_dims=(64, 128, 256, 512), + mlp_ratios=(3, 3, 3, 3), + se_downsamples=(False, False, True, True), + pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), + token_mixers=("repmixer", "repmixer", "repmixer", "attention"), + ) + return _create_fastvit('fastvit_mci1', pretrained=pretrained, **dict(model_args, **kwargs)) + 
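As a rough usage sketch (assuming timm.create_model resolves the registrations above; weights omitted here, so the MobileCLIP checkpoint remapping is not exercised), the new MCi variants pair with the forward_intermediates() API added earlier in this file:

    import torch
    import timm

    # randomly initialized MobileCLIP-style image tower, classifier head removed
    model = timm.create_model('fastvit_mci1', pretrained=False, num_classes=0)
    x = torch.randn(1, 3, 256, 256)

    # final (post final_conv) features plus the outputs of stages 1 and 3
    final, feats = model.forward_intermediates(x, indices=(1, 3))

    # or only the intermediates from the last two stages
    feats_only = model.forward_intermediates(x, indices=2, intermediates_only=True)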
+ +@register_model +def fastvit_mci2(pretrained=False, **kwargs): + """Instantiate MCi2 model variant.""" + model_args = dict( + layers=(4, 12, 24, 4), + embed_dims=(80, 160, 320, 640), + mlp_ratios=(3, 3, 3, 3), + se_downsamples=(False, False, True, True), + pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), + token_mixers=("repmixer", "repmixer", "repmixer", "attention"), + ) + return _create_fastvit('fastvit_mci2', pretrained=pretrained, **dict(model_args, **kwargs)) \ No newline at end of file From ce637771dc2abc5ae81a96af14724300815edd15 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 30 May 2024 10:18:24 -0700 Subject: [PATCH 17/27] Add fastvit to forward_intermediates test --- tests/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 652ea355..7d9c041e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -52,7 +52,7 @@ FEAT_INTER_FILTERS = [ 'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos', 'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2', 'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet', - 'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', + 'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', ] # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output. From 7f96538052eb411003c2c38a7013353aa31ffc62 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 31 May 2024 11:59:51 -0700 Subject: [PATCH 18/27] Add missing lkc act for mobileclip fastvits --- timm/models/fastvit.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/timm/models/fastvit.py b/timm/models/fastvit.py index 7c918887..66142105 100644 --- a/timm/models/fastvit.py +++ b/timm/models/fastvit.py @@ -1420,7 +1420,7 @@ def _checkpoint_filter_fn(state_dict, model): return state_dict # non-original checkpoint, no remapping needed state_dict = state_dict.get('state_dict', state_dict) - if 'image_encoder.model.head.proj' in state_dict: + if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict: # remap MobileCLIP checkpoints prefix = 'image_encoder.model.' 
else: @@ -1598,6 +1598,7 @@ def fastvit_mci0(pretrained=False, **kwargs): se_downsamples=(False, False, True, True), pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), token_mixers=("repmixer", "repmixer", "repmixer", "attention"), + lkc_use_act=True, ) return _create_fastvit('fastvit_mci0', pretrained=pretrained, **dict(model_args, **kwargs)) @@ -1612,6 +1613,7 @@ def fastvit_mci1(pretrained=False, **kwargs): se_downsamples=(False, False, True, True), pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), token_mixers=("repmixer", "repmixer", "repmixer", "attention"), + lkc_use_act=True, ) return _create_fastvit('fastvit_mci1', pretrained=pretrained, **dict(model_args, **kwargs)) @@ -1626,5 +1628,6 @@ def fastvit_mci2(pretrained=False, **kwargs): se_downsamples=(False, False, True, True), pos_embs=(None, None, None, partial(RepConditionalPosEnc, spatial_shape=(7, 7))), token_mixers=("repmixer", "repmixer", "repmixer", "attention"), + lkc_use_act=True, ) - return _create_fastvit('fastvit_mci2', pretrained=pretrained, **dict(model_args, **kwargs)) \ No newline at end of file + return _create_fastvit('fastvit_mci2', pretrained=pretrained, **dict(model_args, **kwargs)) From 1b66ec7cf3cbd14390a2bc4f78904ba04608614d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 3 Jun 2024 17:14:03 -0700 Subject: [PATCH 19/27] Fixup ViTamin, add hub weight reference --- timm/models/vision_transformer.py | 22 +- timm/models/vision_transformer_hybrid.py | 44 ++- timm/models/vitamin.py | 343 +++++++++++------------ 3 files changed, 219 insertions(+), 190 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 2dd4754a..f25db6b5 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -409,6 +409,7 @@ class VisionTransformer(nn.Module): qk_norm: bool = False, init_values: Optional[float] = None, class_token: bool = True, + pos_embed: str = 'learn', no_embed_class: bool = False, reg_tokens: int = 0, pre_norm: bool = False, @@ -460,6 +461,7 @@ class VisionTransformer(nn.Module): super().__init__() assert global_pool in ('', 'avg', 'token', 'map') assert class_token or global_pool != 'token' + assert pos_embed in ('', 'none', 'learn') use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) act_layer = get_act_layer(act_layer) or nn.GELU @@ -494,7 +496,10 @@ class VisionTransformer(nn.Module): self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens - self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) + if not pos_embed or pos_embed == 'none': + self.pos_embed = None + else: + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) self.pos_drop = nn.Dropout(p=pos_drop_rate) if patch_drop_rate > 0: self.patch_drop = PatchDropout( @@ -556,7 +561,8 @@ class VisionTransformer(nn.Module): def init_weights(self, mode: str = '') -> None: assert mode in ('jax', 'jax_nlhb', 'moco', '') head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
- trunc_normal_(self.pos_embed, std=.02) + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) if self.cls_token is not None: nn.init.normal_(self.cls_token, std=1e-6) named_apply(get_init_weights_vit(mode, head_bias), self) @@ -583,6 +589,8 @@ class VisionTransformer(nn.Module): @torch.jit.ignore def set_grad_checkpointing(self, enable: bool = True) -> None: self.grad_checkpointing = enable + if hasattr(self.patch_embed, 'set_grad_checkpointing'): + self.patch_embed.set_grad_checkpointing(enable) @torch.jit.ignore def get_classifier(self) -> nn.Module: @@ -600,6 +608,9 @@ class VisionTransformer(nn.Module): self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: + if self.pos_embed is None: + return x + if self.dynamic_img_size: B, H, W, C = x.shape pos_embed = resample_abs_pos_embed( @@ -1066,10 +1077,13 @@ def checkpoint_filter_fn( # IJEPA, vit in an 'encoder' submodule state_dict = state_dict['encoder'] prefix = 'module.' - elif 'visual.trunk.pos_embed' in state_dict: + elif 'visual.trunk.pos_embed' in state_dict or 'visual.trunk.blocks.0.norm1.weight' in state_dict: # OpenCLIP model with timm vision encoder - # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj) prefix = 'visual.trunk.' + if 'visual.head.proj.weight' in state_dict and isinstance(model.head, nn.Linear): + # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj) + out_dict['head.weight'] = state_dict['visual.head.proj.weight'] + out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0]) if prefix: # filter on & remove prefix string from keys diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 25dd9c27..c2dd1e59 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -38,14 +38,15 @@ class HybridEmbed(nn.Module): def __init__( self, - backbone, - img_size=224, - patch_size=1, - feature_size=None, - feature_ratio=None, - in_chans=3, - embed_dim=768, - bias=True, + backbone: nn.Module, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 1, + feature_size: Optional[Union[int, Tuple[int, int]]] = None, + feature_ratio: Optional[Union[int, Tuple[int, int]]] = None, + in_chans: int = 3, + embed_dim: int = 768, + bias: bool = True, + proj: bool = True, flatten: bool = True, output_fmt: Optional[str] = None, strict_img_size: bool = True, @@ -95,7 +96,18 @@ class HybridEmbed(nn.Module): self.strict_img_size = strict_img_size self.dynamic_img_pad = dynamic_img_pad - self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + if proj: + self.proj = nn.Conv2d( + feature_dim, + embed_dim, + kernel_size=patch_size, + stride=patch_size, + bias=bias, + ) + else: + assert feature_dim == embed_dim,\ + f'The feature dim ({feature_dim} must match embed dim ({embed_dim}) when projection disabled.' 
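+            # NOTE: with proj disabled the backbone must already emit embed_dim channels;
+            # hybrid stems that do their own projection (e.g. the ViTamin conv stem, which
+            # passes proj=False) rely on this identity path.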
+ self.proj = nn.Identity() def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]: total_reduction = ( @@ -116,6 +128,13 @@ class HybridEmbed(nn.Module): else: return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1] + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True): + if hasattr(self.backbone, 'set_grad_checkpointing'): + self.backbone.set_grad_checkpointing(enable=enable) + elif hasattr(self.backbone, 'grad_checkpointing'): + self.backbone.grad_checkpointing = enable + def forward(self, x): x = self.backbone(x) if isinstance(x, (list, tuple)): @@ -157,6 +176,13 @@ class HybridEmbedWithSize(nn.Module): bias=bias, ) + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True): + if hasattr(self.backbone, 'set_grad_checkpointing'): + self.backbone.set_grad_checkpointing(enable=enable) + elif hasattr(self.backbone, 'grad_checkpointing'): + self.backbone.grad_checkpointing = enable + def forward(self, x) -> Tuple[torch.Tensor, List[int]]: x = self.backbone(x) if isinstance(x, (list, tuple)): diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py index f84a59d6..71d3b674 100644 --- a/timm/models/vitamin.py +++ b/timm/models/vitamin.py @@ -19,29 +19,22 @@ https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer_hybrid.py """ +import math +from dataclasses import dataclass, field from functools import partial -from typing import List, Tuple -from dataclasses import dataclass, replace, field -from typing import Callable, Optional, Union, Tuple, List, Sequence -import math, time -from torch.jit import Final +from typing import Optional, Union, Tuple + import torch import torch.nn as nn -import torch.nn.functional as F -import timm -from torch.utils.checkpoint import checkpoint -from timm.models.layers import create_attn, get_norm_layer, get_norm_act_layer, create_conv2d, make_divisible, trunc_normal_tf_ - -from timm.layers import to_2tuple -from timm.layers import DropPath -from timm.layers.norm_act import _create_act - -from timm.models._manipulate import named_apply, checkpoint_seq -from timm.models._builder import build_model_with_cfg -from timm.models._registry import register_model -from timm.models.vision_transformer import VisionTransformer, checkpoint_filter_fn -from timm.models.vision_transformer_hybrid import HybridEmbed +from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD +from timm.layers import create_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, \ + make_divisible, DropPath +from ._builder import build_model_with_cfg +from ._manipulate import named_apply, checkpoint_seq +from ._registry import register_model, generate_default_cfgs +from .vision_transformer import VisionTransformer, checkpoint_filter_fn +from .vision_transformer_hybrid import HybridEmbed @dataclass @@ -90,24 +83,19 @@ class Stem(nn.Module): bias: bool = True, ): super().__init__() - self.grad_checkpointing=False norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps) self.out_chs = out_chs + self.conv1 = create_conv2d(in_chs, out_chs, 3, stride=2, bias=bias) self.norm1 = norm_act_layer(out_chs) self.conv2 = create_conv2d(out_chs, out_chs, 3, stride=1, bias=bias) + named_apply(_init_conv, self) def forward(self, x): - if self.grad_checkpointing: - x = checkpoint(self.conv1, x) - x = self.norm1(x) - x = checkpoint(self.conv2, x) - else: - x = self.conv1(x) - x = self.norm1(x) - x 
= self.conv2(x) - + x = self.conv1(x) + x = self.norm1(x) + x = self.conv2(x) return x @@ -145,8 +133,9 @@ class StridedConv(nn.Module): embed_dim=768 ): super().__init__() - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6) + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) self.norm = norm_layer(in_chans) # affine over C def forward(self, x): @@ -185,10 +174,10 @@ class MbConvLNBlock(nn.Module): self.pre_norm = prenorm_act_layer(in_chs, apply_act=False) self.down = nn.Identity() self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=1, bias=True) - self.act1 = _create_act(act_layer, inplace=True) - self.act2 = _create_act(act_layer, inplace=True) - - self.conv2_kxk = create_conv2d(mid_chs, mid_chs, kernel_size, stride=stride, dilation=1, groups=mid_chs, bias=True) + self.act1 = create_act_layer(act_layer, inplace=True) + self.conv2_kxk = create_conv2d( + mid_chs, mid_chs, kernel_size, stride=stride, dilation=1, groups=mid_chs, bias=True) + self.act2 = create_act_layer(act_layer, inplace=True) self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=True) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -228,58 +217,57 @@ class MbConvStages(nn.Module): ): super().__init__() self.grad_checkpointing = False + self.stem = Stem( in_chs=in_chans, out_chs=cfg.stem_width, ) + stages = [] self.num_stages = len(cfg.embed_dim) for s, dim in enumerate(cfg.embed_dim[:2]): # stage - blocks = [] stage_in_chs = cfg.embed_dim[s-1] if s>0 else cfg.stem_width - for d in range(cfg.depths[s]): - blocks += [MbConvLNBlock( - in_chs = stage_in_chs if d==0 else dim, - out_chs = dim, - stride = 2 if d == 0 else 1, - # cfg = cfg.conv_cfg, - )] - blocks = nn.Sequential(*blocks) - stages += [blocks] + blocks = [ + MbConvLNBlock( + in_chs = stage_in_chs if d==0 else dim, + out_chs = dim, + stride = 2 if d == 0 else 1, + ) + for d in range(cfg.depths[s]) + ] + stages += [nn.Sequential(*blocks)] + self.stages = nn.Sequential(*stages) - self.stages = nn.ModuleList(stages) self.pool = StridedConv( - stride=2, - in_chans=cfg.embed_dim[1], - embed_dim=cfg.embed_dim[2] - ) + stride=2, + in_chans=cfg.embed_dim[1], + embed_dim=cfg.embed_dim[2] + ) def forward(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): - for stage in self.stages: - x = checkpoint_seq(stage, x) - x = checkpoint(self.pool, x) + x = checkpoint_seq(self.stages, x) else: - for stage in self.stages: - x = stage(x) - x = self.pool(x) - + x = self.stages(x) + x = self.pool(x) return x + class GeGluMlp(nn.Module): def __init__( self, in_features, hidden_features, - act_layer = None, + act_layer = 'gelu', drop = 0.0, ): super().__init__() norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6) + self.norm = norm_layer(in_features) - self.act = nn.GELU() self.w0 = nn.Linear(in_features, hidden_features) + self.act = create_act_layer(act_layer) self.w1 = nn.Linear(in_features, hidden_features) self.w2 = nn.Linear(hidden_features, in_features) @@ -290,118 +278,116 @@ class GeGluMlp(nn.Module): return x -class HybridEmbed(nn.Module): - """ CNN Feature Map Embedding - Extract feature map from CNN, flatten, project to embedding dim. 
- """ - def __init__( - self, - backbone, - img_size=224, - patch_size=1, - feature_size=None, - in_chans=3, - embed_dim=1024, - bias=True, - dynamic_img_pad=False, - ): - super().__init__() - assert isinstance(backbone, nn.Module) - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - self.img_size = img_size - self.patch_size = patch_size - self.backbone = backbone - with torch.no_grad(): - training = backbone.training - if training: - backbone.eval() - o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) - if isinstance(o, (list, tuple)): - o = o[-1] # last feature if backbone outputs list/tuple of features - feature_size = o.shape[-2:] - feature_dim = o.shape[1] - backbone.train(training) - - assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 - self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1]) - self.num_patches = self.grid_size[0] * self.grid_size[1] - self.proj = nn.Identity() - - def forward(self, x): - x = self.backbone(x) - if isinstance(x, (list, tuple)): - x = x[-1] # last feature if backbone outputs list/tuple of features - x = self.proj(x) - x = x.flatten(2).transpose(1, 2) - return x - - -def _create_vision_transformer(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') - - if 'flexi' in variant: - # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed - # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation. - _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False) - else: - _filter_fn = checkpoint_filter_fn +def _create_vitamin(variant, pretrained=False, embed_cfg=None, **kwargs): + assert embed_cfg is not None + backbone = MbConvStages(cfg=embed_cfg) + kwargs['embed_layer'] = partial(HybridEmbed, backbone=backbone, proj=False) + kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set return build_model_with_cfg( VisionTransformer, variant, pretrained, - pretrained_filter_fn=_filter_fn, + pretrained_filter_fn=checkpoint_filter_fn, **kwargs, ) -def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs): - embed_layer = partial(HybridEmbed, backbone=backbone) - kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set - return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs) +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': OPENAI_CLIP_MEAN, 'std': OPENAI_CLIP_STD, + 'first_conv': 'patch_embed.backbone.stem.conv1', + 'classifier': 'head', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'vitamin_small.datacomp1b_clip_ltt': _cfg( + hf_hub_id='jienengchen/ViTamin-S-LTT', num_classes=384), + 'vitamin_small.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-S', num_classes=384), + 'vitamin_base.datacomp1b_clip_ltt': _cfg( + hf_hub_id='jienengchen/ViTamin-B-LTT', num_classes=768), + 'vitamin_base.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-B', num_classes=768), + 'vitamin_large.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-L-224px', num_classes=1024), + 'vitamin_large_256.datacomp1b_clip_l2': _cfg( + hf_hub_id='jienengchen/ViTamin-L2-256px', 
num_classes=1024, + input_size=(3, 256, 256), crop_pct=1.0), + 'vitamin_large_256.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-L-256px', num_classes=1024, + input_size=(3, 256, 256), crop_pct=1.0), + 'vitamin_large_336.datacomp1b_clip_l2': _cfg( + hf_hub_id='jienengchen/ViTamin-L2-336px', num_classes=1024, + input_size=(3, 336, 336), crop_pct=1.0), + 'vitamin_large_336.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-L-336px', num_classes=1024, + input_size=(3, 336, 336), crop_pct=1.0), + 'vitamin_large_384.datacomp1b_clip_l2': _cfg( + hf_hub_id='jienengchen/ViTamin-L2-384px', num_classes=1024, + input_size=(3, 384, 384), crop_pct=1.0), + 'vitamin_large_384.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-L-384px', num_classes=1024, + input_size=(3, 384, 384), crop_pct=1.0), + 'vitamin_xlarge_256.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-XL-256px', num_classes=1152, + input_size=(3, 256, 256), crop_pct=1.0), + 'vitamin_xlarge_336.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-XL-336px', num_classes=1152, + input_size=(3, 336, 336), crop_pct=1.0), + 'vitamin_xlarge_384.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-XL-384px', num_classes=1152, + input_size=(3, 384, 384), crop_pct=1.0), +}) @register_model def vitamin_small(pretrained=False, **kwargs) -> VisionTransformer: - stage_1_2 = MbConvStages(cfg=VitCfg( - embed_dim=(64, 128, 384), - depths=(2, 4, 1), - stem_width=64, - conv_cfg = VitConvCfg( - norm_layer='layernorm2d', - norm_eps=1e-6, - ), - head_type='1d', + embed_cfg = VitCfg( + embed_dim=(64, 128, 384), + depths=(2, 4, 1), + stem_width=64, + conv_cfg = VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, ), + head_type='1d', ) - stage3_args = dict(embed_dim=384, depth=14, num_heads=6, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') - model = _create_vision_transformer_hybrid('vitamin_small', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs)) + model_args = dict( + embed_dim=384, depth=14, num_heads=6, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg + ) + model = _create_vitamin('vitamin_small', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vitamin_base(pretrained=False, **kwargs) -> VisionTransformer: - stage_1_2 = MbConvStages(cfg=VitCfg( - embed_dim=(128, 256, 768), - depths=(2, 4, 1), - stem_width=128, - conv_cfg = VitConvCfg( - norm_layer='layernorm2d', - norm_eps=1e-6, - ), - head_type='1d', + embed_cfg = VitCfg( + embed_dim=(128, 256, 768), + depths=(2, 4, 1), + stem_width=128, + conv_cfg = VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, ), + head_type='1d', ) - stage3_args = dict(embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') - model = _create_vision_transformer_hybrid('vitamin_base', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs)) + model_args = dict( + embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg) + model = _create_vitamin('vitamin_base', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vitamin_large(pretrained=False, **kwargs) -> VisionTransformer: - stage_1_2 = MbConvStages(cfg=VitCfg( + embed_cfg = VitCfg( embed_dim=(160, 320, 1024), depths=(2, 4, 1), stem_width=160, @@ -410,17 +396,18 @@ def vitamin_large(pretrained=False, **kwargs) -> 
VisionTransformer: norm_eps=1e-6, ), head_type='1d', - ), ) - stage3_args = dict(embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') - model = _create_vision_transformer_hybrid( - 'vitamin_large', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs)) + model_args = dict( + embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg, + ) + model = _create_vitamin('vitamin_large', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer: - backbone = MbConvStages(cfg=VitCfg( + embed_cfg = VitCfg( embed_dim=(160, 320, 1024), depths=(2, 4, 1), stem_width=160, @@ -429,17 +416,17 @@ def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer: norm_eps=1e-6, ), head_type='1d', - ), ) - model_args = dict(img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') - model = _create_vision_transformer_hybrid( - 'vitamin_large_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) + model_args = dict( + img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg) + model = _create_vitamin('vitamin_large_256', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer: - backbone = MbConvStages(cfg=VitCfg( + embed_cfg = VitCfg( embed_dim=(160, 320, 1024), depths=(2, 4, 1), stem_width=160, @@ -448,17 +435,18 @@ def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer: norm_eps=1e-6, ), head_type='1d', - ), ) - model_args = dict(img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') - model = _create_vision_transformer_hybrid( - 'vitamin_large_336', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) + model_args = dict( + img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg + ) + model = _create_vitamin('vitamin_large_336', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer: - backbone = MbConvStages(cfg=VitCfg( + embed_cfg = VitCfg( embed_dim=(160, 320, 1024), depths=(2, 4, 1), stem_width=160, @@ -467,17 +455,17 @@ def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer: norm_eps=1e-6, ), head_type='1d', - ), ) - model_args = dict(img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') - model = _create_vision_transformer_hybrid( - 'vitamin_large_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) + model_args = dict( + img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg) + model = _create_vitamin('vitamin_large_384', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer: - backbone = MbConvStages(cfg=VitCfg( + embed_cfg=VitCfg( embed_dim=(192, 384, 1152), depths=(2, 4, 1), 
stem_width=192, @@ -486,17 +474,18 @@ def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer: norm_eps=1e-6, ), head_type='1d', - ), ) - model_args = dict(img_size=256, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') - model = _create_vision_transformer_hybrid( - 'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) + model_args = dict( + img_size=256, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg) + model = _create_vitamin( + 'vitamin_xlarge_256', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer: - backbone = MbConvStages(cfg=VitCfg( + embed_cfg = VitCfg( embed_dim=(192, 384, 1152), depths=(2, 4, 1), stem_width=192, @@ -505,17 +494,17 @@ def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer: norm_eps=1e-6, ), head_type='1d', - ), ) - model_args = dict(img_size=336, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') - model = _create_vision_transformer_hybrid( - 'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) + model_args = dict( + img_size=336, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg) + model = _create_vitamin('vitamin_xlarge_336', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer: - backbone = MbConvStages(cfg=VitCfg( + embed_cfg = VitCfg( embed_dim=(192, 384, 1152), depths=(2, 4, 1), stem_width=192, @@ -524,9 +513,9 @@ def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer: norm_eps=1e-6, ), head_type='1d', - ), ) - model_args = dict(img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') - model = _create_vision_transformer_hybrid( - 'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) + model_args = dict( + img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg) + model = _create_vitamin('vitamin_xlarge_384', pretrained=pretrained, **dict(model_args, **kwargs)) return model \ No newline at end of file From 58591a97f7607f9b860bd7a77cc69f29a008057d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 4 Jun 2024 16:57:16 -0700 Subject: [PATCH 20/27] Enable features_only properly --- timm/models/vitamin.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py index 71d3b674..7c2c8735 100644 --- a/timm/models/vitamin.py +++ b/timm/models/vitamin.py @@ -279,6 +279,7 @@ class GeGluMlp(nn.Module): def _create_vitamin(variant, pretrained=False, embed_cfg=None, **kwargs): + out_indices = kwargs.pop('out_indices', 3) assert embed_cfg is not None backbone = MbConvStages(cfg=embed_cfg) kwargs['embed_layer'] = partial(HybridEmbed, backbone=backbone, proj=False) @@ -289,6 +290,7 @@ def _create_vitamin(variant, pretrained=False, embed_cfg=None, **kwargs): variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, + 
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), **kwargs, ) From 0e77c95ed7a45b14eb43c274d38a10be66491cb0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 5 Jun 2024 00:20:00 -0700 Subject: [PATCH 21/27] Add vitamin to non-std testing models --- tests/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 652ea355..1d1cc733 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -60,7 +60,7 @@ NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', - 'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*' + 'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*' ] NUM_NON_STD = len(NON_STD_FILTERS) From cc8a03daacbb710aeb344c19ec84568cb8fc14b9 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 6 Jun 2024 09:15:27 -0700 Subject: [PATCH 22/27] Add ConvStem and MobileCLIP hybrid model for B variant. Add full norm disable support to ConvNormAct layers --- timm/layers/conv_bn_act.py | 59 +++++++++++-------- timm/layers/norm_act.py | 16 ++---- timm/models/vision_transformer.py | 2 +- timm/models/vision_transformer_hybrid.py | 72 ++++++++++++++++++++++-- 4 files changed, 112 insertions(+), 37 deletions(-) diff --git a/timm/layers/conv_bn_act.py b/timm/layers/conv_bn_act.py index 17847d76..de738045 100644 --- a/timm/layers/conv_bn_act.py +++ b/timm/layers/conv_bn_act.py @@ -23,6 +23,7 @@ class ConvNormAct(nn.Module): dilation: int = 1, groups: int = 1, bias: bool = False, + apply_norm: bool = True, apply_act: bool = True, norm_layer: LayerType = nn.BatchNorm2d, act_layer: LayerType = nn.ReLU, @@ -48,17 +49,23 @@ class ConvNormAct(nn.Module): **conv_kwargs, ) - # NOTE for backwards compatibility with models that use separate norm and act layer definitions - norm_act_layer = get_norm_act_layer(norm_layer, act_layer) - # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` - if drop_layer: - norm_kwargs['drop_layer'] = drop_layer - self.bn = norm_act_layer( - out_channels, - apply_act=apply_act, - act_kwargs=act_kwargs, - **norm_kwargs, - ) + if apply_norm: + # NOTE for backwards compatibility with models that use separate norm and act layer definitions + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` + if drop_layer: + norm_kwargs['drop_layer'] = drop_layer + self.bn = norm_act_layer( + out_channels, + apply_act=apply_act, + act_kwargs=act_kwargs, + **norm_kwargs, + ) + else: + self.bn = nn.Sequential() + if drop_layer: + norm_kwargs['drop_layer'] = drop_layer + self.bn.add_module('drop', drop_layer()) @property def in_channels(self): @@ -88,6 +95,7 @@ class ConvNormActAa(nn.Module): dilation: int = 1, groups: int = 1, bias: bool = False, + apply_norm: bool = True, apply_act: bool = True, norm_layer: LayerType = nn.BatchNorm2d, act_layer: LayerType = nn.ReLU, @@ -113,17 +121,24 @@ class ConvNormActAa(nn.Module): **conv_kwargs, ) - # NOTE for backwards compatibility with models that use separate norm and act layer definitions - norm_act_layer = get_norm_act_layer(norm_layer, act_layer) - # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` - if drop_layer: - 
norm_kwargs['drop_layer'] = drop_layer - self.bn = norm_act_layer( - out_channels, - apply_act=apply_act, - act_kwargs=act_kwargs, - **norm_kwargs, - ) + if apply_norm: + # NOTE for backwards compatibility with models that use separate norm and act layer definitions + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` + if drop_layer: + norm_kwargs['drop_layer'] = drop_layer + self.bn = norm_act_layer( + out_channels, + apply_act=apply_act, + act_kwargs=act_kwargs, + **norm_kwargs, + ) + else: + self.bn = nn.Sequential() + if drop_layer: + norm_kwargs['drop_layer'] = drop_layer + self.bn.add_module('drop', drop_layer()) + self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa) @property diff --git a/timm/layers/norm_act.py b/timm/layers/norm_act.py index 49505c58..496efcfd 100644 --- a/timm/layers/norm_act.py +++ b/timm/layers/norm_act.py @@ -19,21 +19,18 @@ from torch import nn as nn from torch.nn import functional as F from torchvision.ops.misc import FrozenBatchNorm2d -from .create_act import get_act_layer +from .create_act import create_act_layer from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm from .trace_utils import _assert def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True): - act_layer = get_act_layer(act_layer) # string -> nn.Module act_kwargs = act_kwargs or {} - if act_layer is not None and apply_act: - if inplace: - act_kwargs['inplace'] = inplace - act = act_layer(**act_kwargs) - else: - act = nn.Identity() - return act + act_kwargs.setdefault('inplace', inplace) + act = None + if apply_act: + act = create_act_layer(act_layer, **act_kwargs) + return nn.Identity() if act is None else act class BatchNormAct2d(nn.BatchNorm2d): @@ -421,7 +418,6 @@ class LayerNormAct(nn.LayerNorm): ): super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine) self.drop = drop_layer() if drop_layer is not None else nn.Identity() - act_layer = get_act_layer(act_layer) # string -> nn.Module self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act) self._fast_norm = is_fast_norm() diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index f25db6b5..e3f1b8f2 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -609,7 +609,7 @@ class VisionTransformer(nn.Module): def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: if self.pos_embed is None: - return x + return x.view(x.shape[0], -1, x.shape[-1]) if self.dynamic_img_size: B, H, W, C = x.shape diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index c2dd1e59..af51fa98 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -15,14 +15,15 @@ Hacked together by / Copyright 2020, Ross Wightman """ import math from functools import partial -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Type, Union import torch import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to +from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, Format, nchw_to + from ._registry import generate_default_cfgs, register_model, register_model_deprecations from .resnet import resnet26d, resnet50d from 
.resnetv2 import ResNetV2, create_resnetv2_stem @@ -191,8 +192,52 @@ class HybridEmbedWithSize(nn.Module): return x.flatten(2).transpose(1, 2), x.shape[-2:] -def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs): - embed_layer = partial(HybridEmbed, backbone=backbone) +class ConvStem(nn.Sequential): + def __init__( + self, + in_chans: int = 3, + depth: int = 3, + channels: Union[int, Tuple[int, ...]] = 64, + kernel_size: Union[int, Tuple[int, ...]] = 3, + stride: Union[int, Tuple[int, ...]] = (2, 2, 2), + padding: Union[str, int, Tuple[int, ...]] = "", + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + act_layer: Type[nn.Module] = nn.ReLU, + ): + super().__init__() + if isinstance(channels, int): + if depth == 4: + channels = (channels // 8, channels // 4, channels // 2, channels) + elif depth == 3: + channels = (channels // 4, channels // 2, channels) + else: + channels = to_ntuple(depth)(channels) + + kernel_size = to_ntuple(depth)(kernel_size) + padding = to_ntuple(depth)(padding) + assert depth == len(stride) == len(kernel_size) == len(channels) + + in_chs = in_chans + for i in range(len(channels)): + last_conv = i == len(channels) - 1 + self.add_module(f'{i}', ConvNormAct( + in_chs, + channels[i], + kernel_size=kernel_size[i], + stride=stride[i], + padding=padding[i], + bias=last_conv, + apply_norm=not last_conv, + apply_act=not last_conv, + norm_layer=norm_layer, + act_layer=act_layer, + )) + in_chs = channels[i] + + +def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs): + embed_args = embed_args or {} + embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args) kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs) @@ -433,6 +478,25 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs) -> VisionTransformer: return model +@register_model +def vit_base_mci_224(pretrained=False, **kwargs) -> VisionTransformer: + """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights. 
+ """ + backbone = ConvStem( + channels=(768//4, 768//4, 768), + stride=(4, 2, 2), + kernel_size=(4, 2, 2), + padding=0, + act_layer=nn.GELU, + ) + model_args = dict(embed_dim=768, depth=12, num_heads=12, no_embed_class=True) + model = _create_vision_transformer_hybrid( + 'vit_base_resnet50d_224', backbone=backbone, embed_args=dict(proj=False), + pretrained=pretrained, **dict(model_args, **kwargs) + ) + return model + + register_model_deprecations(__name__, { 'vit_tiny_r_s16_p8_224_in21k': 'vit_tiny_r_s16_p8_224.augreg_in21k', 'vit_small_r26_s32_224_in21k': 'vit_small_r26_s32_224.augreg_in21k', From 7d4ada6d16760a3815a50c5e7f1b724955429d07 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 6 Jun 2024 09:16:43 -0700 Subject: [PATCH 23/27] Update ViTamin model defs --- timm/models/vitamin.py | 148 +++++++++++++++++++++++++++++++---------- 1 file changed, 114 insertions(+), 34 deletions(-) diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py index 7c2c8735..6e0c28f0 100644 --- a/timm/models/vitamin.py +++ b/timm/models/vitamin.py @@ -308,34 +308,36 @@ def _cfg(url='', **kwargs): default_cfgs = generate_default_cfgs({ - 'vitamin_small.datacomp1b_clip_ltt': _cfg( + 'vitamin_small_224.datacomp1b_clip_ltt': _cfg( hf_hub_id='jienengchen/ViTamin-S-LTT', num_classes=384), - 'vitamin_small.datacomp1b_clip': _cfg( + 'vitamin_small_224.datacomp1b_clip': _cfg( hf_hub_id='jienengchen/ViTamin-S', num_classes=384), - 'vitamin_base.datacomp1b_clip_ltt': _cfg( + 'vitamin_base_224.datacomp1b_clip_ltt': _cfg( hf_hub_id='jienengchen/ViTamin-B-LTT', num_classes=768), - 'vitamin_base.datacomp1b_clip': _cfg( + 'vitamin_base_224.datacomp1b_clip': _cfg( hf_hub_id='jienengchen/ViTamin-B', num_classes=768), - 'vitamin_large.datacomp1b_clip': _cfg( - hf_hub_id='jienengchen/ViTamin-L-224px', num_classes=1024), - 'vitamin_large_256.datacomp1b_clip_l2': _cfg( + 'vitamin_large_224.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-L-224px', num_classes=768), + 'vitamin_large_256.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-L-256px', num_classes=768, + input_size=(3, 256, 256), crop_pct=1.0), + 'vitamin_large_336.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-L-336px', num_classes=768, + input_size=(3, 336, 336), crop_pct=1.0), + 'vitamin_large_384.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-L-384px', num_classes=768, + input_size=(3, 384, 384), crop_pct=1.0), + 'vitamin_large2_224.datacomp1b_clip': _cfg( + hf_hub_id='jienengchen/ViTamin-L2-224px', num_classes=1024), + 'vitamin_large2_256.datacomp1b_clip': _cfg( hf_hub_id='jienengchen/ViTamin-L2-256px', num_classes=1024, input_size=(3, 256, 256), crop_pct=1.0), - 'vitamin_large_256.datacomp1b_clip': _cfg( - hf_hub_id='jienengchen/ViTamin-L-256px', num_classes=1024, - input_size=(3, 256, 256), crop_pct=1.0), - 'vitamin_large_336.datacomp1b_clip_l2': _cfg( + 'vitamin_large2_336.datacomp1b_clip': _cfg( hf_hub_id='jienengchen/ViTamin-L2-336px', num_classes=1024, input_size=(3, 336, 336), crop_pct=1.0), - 'vitamin_large_336.datacomp1b_clip': _cfg( - hf_hub_id='jienengchen/ViTamin-L-336px', num_classes=1024, - input_size=(3, 336, 336), crop_pct=1.0), - 'vitamin_large_384.datacomp1b_clip_l2': _cfg( + 'vitamin_large2_384.datacomp1b_clip': _cfg( hf_hub_id='jienengchen/ViTamin-L2-384px', num_classes=1024, input_size=(3, 384, 384), crop_pct=1.0), - 'vitamin_large_384.datacomp1b_clip': _cfg( - hf_hub_id='jienengchen/ViTamin-L-384px', num_classes=1024, - input_size=(3, 384, 384), crop_pct=1.0), 'vitamin_xlarge_256.datacomp1b_clip': 
_cfg( hf_hub_id='jienengchen/ViTamin-XL-256px', num_classes=1152, input_size=(3, 256, 256), crop_pct=1.0), @@ -349,12 +351,12 @@ default_cfgs = generate_default_cfgs({ @register_model -def vitamin_small(pretrained=False, **kwargs) -> VisionTransformer: +def vitamin_small_224(pretrained=False, **kwargs) -> VisionTransformer: embed_cfg = VitCfg( embed_dim=(64, 128, 384), depths=(2, 4, 1), stem_width=64, - conv_cfg = VitConvCfg( + conv_cfg=VitConvCfg( norm_layer='layernorm2d', norm_eps=1e-6, ), @@ -364,17 +366,17 @@ def vitamin_small(pretrained=False, **kwargs) -> VisionTransformer: embed_dim=384, depth=14, num_heads=6, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg', embed_cfg=embed_cfg ) - model = _create_vitamin('vitamin_small', pretrained=pretrained, **dict(model_args, **kwargs)) + model = _create_vitamin('vitamin_small_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model -def vitamin_base(pretrained=False, **kwargs) -> VisionTransformer: +def vitamin_base_224(pretrained=False, **kwargs) -> VisionTransformer: embed_cfg = VitCfg( embed_dim=(128, 256, 768), depths=(2, 4, 1), stem_width=128, - conv_cfg = VitConvCfg( + conv_cfg=VitConvCfg( norm_layer='layernorm2d', norm_eps=1e-6, ), @@ -383,17 +385,17 @@ def vitamin_base(pretrained=False, **kwargs) -> VisionTransformer: model_args = dict( embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg', embed_cfg=embed_cfg) - model = _create_vitamin('vitamin_base', pretrained=pretrained, **dict(model_args, **kwargs)) + model = _create_vitamin('vitamin_base_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model -def vitamin_large(pretrained=False, **kwargs) -> VisionTransformer: +def vitamin_large_224(pretrained=False, **kwargs) -> VisionTransformer: embed_cfg = VitCfg( embed_dim=(160, 320, 1024), depths=(2, 4, 1), stem_width=160, - conv_cfg = VitConvCfg( + conv_cfg=VitConvCfg( norm_layer='layernorm2d', norm_eps=1e-6, ), @@ -403,7 +405,7 @@ def vitamin_large(pretrained=False, **kwargs) -> VisionTransformer: embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg', embed_cfg=embed_cfg, ) - model = _create_vitamin('vitamin_large', pretrained=pretrained, **dict(model_args, **kwargs)) + model = _create_vitamin('vitamin_large_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -413,7 +415,7 @@ def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer: embed_dim=(160, 320, 1024), depths=(2, 4, 1), stem_width=160, - conv_cfg = VitConvCfg( + conv_cfg=VitConvCfg( norm_layer='layernorm2d', norm_eps=1e-6, ), @@ -432,7 +434,7 @@ def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer: embed_dim=(160, 320, 1024), depths=(2, 4, 1), stem_width=160, - conv_cfg = VitConvCfg( + conv_cfg=VitConvCfg( norm_layer='layernorm2d', norm_eps=1e-6, ), @@ -452,7 +454,7 @@ def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer: embed_dim=(160, 320, 1024), depths=(2, 4, 1), stem_width=160, - conv_cfg = VitConvCfg( + conv_cfg=VitConvCfg( norm_layer='layernorm2d', norm_eps=1e-6, ), @@ -465,13 +467,91 @@ def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer: return model +@register_model +def vitamin_large2_224(pretrained=False, **kwargs) -> VisionTransformer: + embed_cfg = VitCfg( + embed_dim=(160, 320, 1024), + depths=(2, 4, 1), + stem_width=160, + conv_cfg=VitConvCfg( + norm_layer='layernorm2d', + 
norm_eps=1e-6, + ), + head_type='1d', + ) + model_args = dict( + embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg, + ) + model = _create_vitamin('vitamin_large2_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vitamin_large2_256(pretrained=False, **kwargs) -> VisionTransformer: + embed_cfg = VitCfg( + embed_dim=(160, 320, 1024), + depths=(2, 4, 1), + stem_width=160, + conv_cfg=VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ) + model_args = dict( + img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg) + model = _create_vitamin('vitamin_large2_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vitamin_large2_336(pretrained=False, **kwargs) -> VisionTransformer: + embed_cfg = VitCfg( + embed_dim=(160, 320, 1024), + depths=(2, 4, 1), + stem_width=160, + conv_cfg=VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ) + model_args = dict( + img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg + ) + model = _create_vitamin('vitamin_large2_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vitamin_large2_384(pretrained=False, **kwargs) -> VisionTransformer: + embed_cfg = VitCfg( + embed_dim=(160, 320, 1024), + depths=(2, 4, 1), + stem_width=160, + conv_cfg=VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ) + model_args = dict( + img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., + class_token=False, global_pool='avg', embed_cfg=embed_cfg) + model = _create_vitamin('vitamin_large2_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + @register_model def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer: embed_cfg=VitCfg( embed_dim=(192, 384, 1152), depths=(2, 4, 1), stem_width=192, - conv_cfg = VitConvCfg( + conv_cfg=VitConvCfg( norm_layer='layernorm2d', norm_eps=1e-6, ), @@ -491,7 +571,7 @@ def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer: embed_dim=(192, 384, 1152), depths=(2, 4, 1), stem_width=192, - conv_cfg = VitConvCfg( + conv_cfg=VitConvCfg( norm_layer='layernorm2d', norm_eps=1e-6, ), @@ -500,7 +580,7 @@ def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer: model_args = dict( img_size=336, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg', pos_embed='none', embed_cfg=embed_cfg) - model = _create_vitamin('vitamin_xlarge_336', pretrained=pretrained, **dict(model_args, **kwargs)) + model = _create_vitamin('vitamin_xlarge_256', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -510,7 +590,7 @@ def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer: embed_dim=(192, 384, 1152), depths=(2, 4, 1), stem_width=192, - conv_cfg = VitConvCfg( + conv_cfg=VitConvCfg( norm_layer='layernorm2d', norm_eps=1e-6, ), From 88a1006e025c1a4e39fb2b4db7f8ad8cb85ae88f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 6 Jun 2024 12:38:52 -0700 Subject: [PATCH 24/27] checkpoint filter fns with consistent name, add mobileclip-b pretrained cfgs --- timm/models/beit.py | 4 +- 
timm/models/efficientformer.py | 4 +- timm/models/fastvit.py | 4 +- timm/models/pvt_v2.py | 4 +- timm/models/vision_transformer_hybrid.py | 109 ++++++++++++++++++----- 5 files changed, 95 insertions(+), 30 deletions(-) diff --git a/timm/models/beit.py b/timm/models/beit.py index 63b6db54..922d15e7 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -591,7 +591,7 @@ default_cfgs = generate_default_cfgs({ }) -def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True): +def checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True): state_dict = state_dict.get('model', state_dict) state_dict = state_dict.get('module', state_dict) # beit v2 didn't strip module @@ -637,7 +637,7 @@ def _create_beit(variant, pretrained=False, **kwargs): out_indices = kwargs.pop('out_indices', 3) model = build_model_with_cfg( Beit, variant, pretrained, - pretrained_filter_fn=_beit_checkpoint_filter_fn, + pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), **kwargs, ) diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py index c28538bc..32630683 100644 --- a/timm/models/efficientformer.py +++ b/timm/models/efficientformer.py @@ -556,7 +556,7 @@ class EfficientFormer(nn.Module): return x -def _checkpoint_filter_fn(state_dict, model): +def checkpoint_filter_fn(state_dict, model): """ Remap original checkpoints -> timm """ if 'stem.0.weight' in state_dict: return state_dict # non-original checkpoint, no remapping needed @@ -611,7 +611,7 @@ def _create_efficientformer(variant, pretrained=False, **kwargs): out_indices = kwargs.pop('out_indices', 4) model = build_model_with_cfg( EfficientFormer, variant, pretrained, - pretrained_filter_fn=_checkpoint_filter_fn, + pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), **kwargs, ) diff --git a/timm/models/fastvit.py b/timm/models/fastvit.py index 66142105..ef7ec3c9 100644 --- a/timm/models/fastvit.py +++ b/timm/models/fastvit.py @@ -1414,7 +1414,7 @@ default_cfgs = generate_default_cfgs({ }) -def _checkpoint_filter_fn(state_dict, model): +def checkpoint_filter_fn(state_dict, model): """ Remap original checkpoints -> timm """ if 'stem.0.conv_kxk.0.conv.weight' in state_dict: return state_dict # non-original checkpoint, no remapping needed @@ -1493,7 +1493,7 @@ def _create_fastvit(variant, pretrained=False, **kwargs): FastVit, variant, pretrained, - pretrained_filter_fn=_checkpoint_filter_fn, + pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs ) diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py index 1d9c6842..90ebfe7a 100644 --- a/timm/models/pvt_v2.py +++ b/timm/models/pvt_v2.py @@ -403,7 +403,7 @@ class PyramidVisionTransformerV2(nn.Module): return x -def _checkpoint_filter_fn(state_dict, model): +def checkpoint_filter_fn(state_dict, model): """ Remap original checkpoints -> timm """ if 'patch_embed.proj.weight' in state_dict: return state_dict # non-original checkpoint, no remapping needed @@ -430,7 +430,7 @@ def _create_pvt2(variant, pretrained=False, **kwargs): PyramidVisionTransformerV2, variant, pretrained, - pretrained_filter_fn=_checkpoint_filter_fn, + pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs, ) diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index af51fa98..3501565c 
100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -15,7 +15,7 @@ Hacked together by / Copyright 2020, Ross Wightman """ import math from functools import partial -from typing import List, Optional, Tuple, Type, Union +from typing import Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -24,10 +24,11 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, Format, nchw_to +from ._builder import build_model_with_cfg from ._registry import generate_default_cfgs, register_model, register_model_deprecations from .resnet import resnet26d, resnet50d from .resnetv2 import ResNetV2, create_resnetv2_stem -from .vision_transformer import _create_vision_transformer, VisionTransformer +from .vision_transformer import VisionTransformer class HybridEmbed(nn.Module): @@ -159,22 +160,26 @@ class HybridEmbedWithSize(nn.Module): """ def __init__( self, - backbone, - img_size=224, - patch_size=1, - feature_size=None, - in_chans=3, - embed_dim=768, + backbone: nn.Module, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 1, + feature_size: Optional[Union[int, Tuple[int, int]]] = None, + feature_ratio: Optional[Union[int, Tuple[int, int]]] = None, + in_chans: int = 3, + embed_dim: int = 768, bias=True, + proj=True, ): super().__init__( backbone=backbone, img_size=img_size, patch_size=patch_size, feature_size=feature_size, + feature_ratio=feature_ratio, in_chans=in_chans, embed_dim=embed_dim, bias=bias, + proj=proj, ) @torch.jit.ignore @@ -206,12 +211,8 @@ class ConvStem(nn.Sequential): ): super().__init__() if isinstance(channels, int): - if depth == 4: - channels = (channels // 8, channels // 4, channels // 2, channels) - elif depth == 3: - channels = (channels // 4, channels // 2, channels) - else: - channels = to_ntuple(depth)(channels) + # a default tiered channel strategy + channels = tuple([channels // 2**i for i in range(depth)][::-1]) kernel_size = to_ntuple(depth)(kernel_size) padding = to_ntuple(depth)(padding) @@ -235,13 +236,6 @@ class ConvStem(nn.Sequential): in_chs = channels[i] -def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs): - embed_args = embed_args or {} - embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args) - kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set - return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs) - - def _resnetv2(layers=(3, 4, 9), **kwargs): """ ResNet-V2 backbone helper""" padding_same = kwargs.get('padding_same', True) @@ -257,6 +251,66 @@ def _resnetv2(layers=(3, 4, 9), **kwargs): return backbone +def _convert_mobileclip(state_dict, model, prefix='image_encoder.model.'): + out = {} + for k, v in state_dict.items(): + if not k.startswith(prefix): + continue + k = k.replace(prefix, '') + k = k.replace('patch_emb.', 'patch_embed.backbone.') + k = k.replace('block.conv', 'conv') + k = k.replace('block.norm', 'bn') + k = k.replace('post_transformer_norm.', 'norm.') + k = k.replace('pre_norm_mha.0', 'norm1') + k = k.replace('pre_norm_mha.1', 'attn') + k = k.replace('pre_norm_ffn.0', 'norm2') + k = k.replace('pre_norm_ffn.1', 'mlp.fc1') + k = k.replace('pre_norm_ffn.4', 'mlp.fc2') + k = k.replace('qkv_proj.', 'qkv.') + k = k.replace('out_proj.', 'proj.') + k = 
k.replace('transformer.', 'blocks.') + if k == 'pos_embed.pos_embed.pos_embed': + k = 'pos_embed' + v = v.squeeze(0) + if 'classifier.proj' in k: + bias_k = k.replace('classifier.proj', 'head.bias') + k = k.replace('classifier.proj', 'head.weight') + v = v.T + out[bias_k] = torch.zeros(v.shape[0]) + out[k] = v + return out + + +def checkpoint_filter_fn( + state_dict: Dict[str, torch.Tensor], + model: VisionTransformer, + interpolation: str = 'bicubic', + antialias: bool = True, +) -> Dict[str, torch.Tensor]: + from .vision_transformer import checkpoint_filter_fn as _filter_fn + + if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict: + state_dict = _convert_mobileclip(state_dict, model) + + return _filter_fn(state_dict, model, interpolation=interpolation, antialias=antialias) + + +def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs): + out_indices = kwargs.pop('out_indices', 3) + embed_args = embed_args or {} + embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args) + kwargs.setdefault('embed_layer', embed_layer) + kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set + return build_model_with_cfg( + VisionTransformer, + variant, + pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), + **kwargs, + ) + + def _cfg(url='', **kwargs): return { 'url': url, @@ -331,6 +385,17 @@ default_cfgs = generate_default_cfgs({ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), 'vit_base_resnet50d_224.untrained': _cfg( mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), + + 'vit_base_mci_224.apple_mclip': _cfg( + url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt', + num_classes=512, + mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.conv1.0', + ), + 'vit_base_mci_224.apple_mclip_lt': _cfg( + url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt', + num_classes=512, + mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.conv1.0', + ), }) @@ -491,7 +556,7 @@ def vit_base_mci_224(pretrained=False, **kwargs) -> VisionTransformer: ) model_args = dict(embed_dim=768, depth=12, num_heads=12, no_embed_class=True) model = _create_vision_transformer_hybrid( - 'vit_base_resnet50d_224', backbone=backbone, embed_args=dict(proj=False), + 'vit_base_mci_224', backbone=backbone, embed_args=dict(proj=False), pretrained=pretrained, **dict(model_args, **kwargs) ) return model From fc1b66a51d35201ecc13c1c5697c4b2e28c9c49d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 6 Jun 2024 13:42:26 -0700 Subject: [PATCH 25/27] Fix first conv name for mci vit-b --- timm/models/vision_transformer_hybrid.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 3501565c..c16e7c78 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -389,12 +389,12 @@ default_cfgs = generate_default_cfgs({ 'vit_base_mci_224.apple_mclip': _cfg( url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt', num_classes=512, - mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.conv1.0', + mean=(0., 0., 0.), std=(1., 1., 1.), 
first_conv='patch_embed.backbone.0.conv.weight', ), 'vit_base_mci_224.apple_mclip_lt': _cfg( url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt', num_classes=512, - mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.conv1.0', + mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv.weight', ), }) From ad026e6e33b8db4d5640f16bfc7aedf2f7359931 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 6 Jun 2024 17:56:14 -0700 Subject: [PATCH 26/27] Fix in_chans switching on create --- timm/models/vision_transformer_hybrid.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index c16e7c78..0c690c35 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -389,12 +389,12 @@ default_cfgs = generate_default_cfgs({ 'vit_base_mci_224.apple_mclip': _cfg( url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt', num_classes=512, - mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv.weight', + mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv', ), 'vit_base_mci_224.apple_mclip_lt': _cfg( url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt', num_classes=512, - mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv.weight', + mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv', ), }) @@ -552,6 +552,7 @@ def vit_base_mci_224(pretrained=False, **kwargs) -> VisionTransformer: stride=(4, 2, 2), kernel_size=(4, 2, 2), padding=0, + in_chans=kwargs.get('in_chans', 3), act_layer=nn.GELU, ) model_args = dict(embed_dim=768, depth=12, num_heads=12, no_embed_class=True) From 7ccb10ebfff6a57002047fd4fabcd911ec4f9604 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 6 Jun 2024 21:50:27 -0700 Subject: [PATCH 27/27] Disable efficient_builder debug flag --- timm/models/_efficientnet_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/_efficientnet_builder.py b/timm/models/_efficientnet_builder.py index e9b789a4..57bde323 100644 --- a/timm/models/_efficientnet_builder.py +++ b/timm/models/_efficientnet_builder.py @@ -25,7 +25,7 @@ __all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights" _logger = logging.getLogger(__name__) -_DEBUG_BUILDER = True +_DEBUG_BUILDER = False # Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per # papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)