From 7fe96e7a92262dbbc81325b1fc7583446c0996c0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 May 2024 15:09:29 -0700 Subject: [PATCH] More MobileNet-v4 fixes * missed final norm after post pooling 1x1 PW head conv * improve repr of model by flipping a few modules to None when not used, nn.Sequential for MultiQueryAttention query/key/value/output * allow layer scaling to be enabled/disabled at model variant level, conv variants don't use it --- timm/layers/attention2d.py | 85 ++++++++++++---------------- timm/models/_efficientnet_blocks.py | 65 +++++++++------------ timm/models/_efficientnet_builder.py | 27 +++++---- timm/models/mobilenetv3.py | 27 ++++++++- 4 files changed, 102 insertions(+), 102 deletions(-) diff --git a/timm/layers/attention2d.py b/timm/layers/attention2d.py index 3d3f6d01..d1d38fb3 100644 --- a/timm/layers/attention2d.py +++ b/timm/layers/attention2d.py @@ -107,6 +107,7 @@ class MultiQueryAttention2d(nn.Module): attn_drop: float = 0., proj_drop: float = 0., norm_layer: nn.Module = nn.BatchNorm2d, + use_bias: bool = False, ): """Initializer. @@ -130,26 +131,25 @@ class MultiQueryAttention2d(nn.Module): self.fused_attn = use_fused_attn() self.drop = attn_drop + self.query = nn.Sequential() if self.has_query_strides: # FIXME dilation - self.query_down_pool = create_pool2d( - 'avg', - kernel_size=self.query_strides, - padding=padding, - ) - self.query_down_norm = norm_layer(dim) - else: - self.query_down_pool = nn.Identity() - self.query_down_norm = nn.Identity() - - self.query_proj = create_conv2d( + self.query.add_module('down_pool', create_pool2d( + 'avg', + kernel_size=self.query_strides, + padding=padding, + )) + self.query.add_module('norm', norm_layer(dim)) + self.query.add_module('proj', create_conv2d( dim, self.num_heads * self.key_dim, kernel_size=1, - ) + bias=use_bias, + )) + self.key = nn.Sequential() if kv_stride > 1: - self.key_down_conv = create_conv2d( + self.key.add_module('down_conv', create_conv2d( dim, dim, kernel_size=dw_kernel_size, @@ -157,21 +157,19 @@ class MultiQueryAttention2d(nn.Module): dilation=dilation, padding=padding, depthwise=True, - ) - self.key_down_norm = norm_layer(dim) - else: - self.key_down_conv = nn.Identity() - self.key_down_norm = nn.Identity() - - self.key_proj = create_conv2d( + )) + self.key.add_module('norm', norm_layer(dim)) + self.key.add_module('proj', create_conv2d( dim, self.key_dim, kernel_size=1, padding=padding, - ) + bias=use_bias, + )) + self.value = nn.Sequential() if kv_stride > 1: - self.value_down_conv = create_conv2d( + self.value.add_module('down_conv', create_conv2d( dim, dim, kernel_size=dw_kernel_size, @@ -179,32 +177,28 @@ class MultiQueryAttention2d(nn.Module): dilation=dilation, padding=padding, depthwise=True, - ) - self.value_down_norm = norm_layer(dim) - else: - self.value_down_conv = nn.Identity() - self.value_down_norm = nn.Identity() - - self.value_proj = create_conv2d( + )) + self.value.add_module('norm', norm_layer(dim)) + self.value.add_module('proj', create_conv2d( dim, self.value_dim, kernel_size=1, - ) + bias=use_bias, + )) self.attn_drop = nn.Dropout(attn_drop) + self.output = nn.Sequential() if self.has_query_strides: - self.upsampling = nn.Upsample(self.query_strides, mode='bilinear', align_corners=False) - else: - self.upsampling = nn.Identity() - - self.out_proj = create_conv2d( + self.output.add_module('upsample', nn.Upsample(self.query_strides, mode='bilinear', align_corners=False)) + self.output.add_module('proj', create_conv2d( self.value_dim * self.num_heads, dim_out, 
kernel_size=1, - ) + bias=use_bias, + )) + self.output.add_module('drop', nn.Dropout(proj_drop)) - self.proj_drop = nn.Dropout(proj_drop) self.einsum = False def _reshape_input(self, t: torch.Tensor): @@ -237,21 +231,15 @@ class MultiQueryAttention2d(nn.Module): """Run layer computation.""" B, C, H, W = s = x.shape - q = self.query_down_pool(x) - q = self.query_down_norm(q) - q = self.query_proj(q) + q = self.query(x) # desired q shape: [b, h, k, n x n] - [b, l, h, k] q = self._reshape_projected_query(q, self.num_heads, self.key_dim) - k = self.key_down_conv(x) - k = self.key_down_norm(k) - k = self.key_proj(k) + k = self.key(x) # output shape of k: [b, k, p], p = m x m k = self._reshape_input(k) - v = self.value_down_conv(x) - v = self.value_down_norm(v) - v = self.value_proj(v) + v = self.value(x) # output shape of v: [ b, p, k], p = m x m v = self._reshape_input(v) @@ -285,10 +273,7 @@ class MultiQueryAttention2d(nn.Module): # reshape o into [b, hk, n, n,] o = self._reshape_output(o, self.num_heads, H // self.query_strides[0], W // self.query_strides[1]) - o = self.upsampling(o) - - x = self.out_proj(o) - x = self.proj_drop(x) + x = self.output(o) return x diff --git a/timm/models/_efficientnet_blocks.py b/timm/models/_efficientnet_blocks.py index 41f3182d..be00b01c 100644 --- a/timm/models/_efficientnet_blocks.py +++ b/timm/models/_efficientnet_blocks.py @@ -174,13 +174,12 @@ class DepthwiseSeparableConv(nn.Module): def forward(self, x): shortcut = x - #print('ii', x.shape) + #print('ii', x.shape) # FIXME debug s2d if self.conv_s2d is not None: x = self.conv_s2d(x) x = self.bn_s2d(x) - #print('id', x.shape) + #print('id', x.shape) # FIXME debug s2d x = self.conv_dw(x) - #print('od', x.shape) x = self.bn1(x) x = self.se(x) x = self.conv_pw(x) @@ -296,7 +295,8 @@ class LayerScale2d(nn.Module): class UniversalInvertedResidual(nn.Module): """ Universal Inverted Residual Block - For MobileNetV4 - https://arxiv.org/abs/ + For MobileNetV4 - https://arxiv.org/abs/, referenced from + https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L778 """ def __init__( @@ -338,8 +338,9 @@ class UniversalInvertedResidual(nn.Module): ) self.norm_dw_start = dw_norm_act_layer(in_chs, apply_act=False) else: - self.conv_dw_start = nn.Identity() - self.norm_dw_start = nn.Identity() + # start is None when not used for cleaner repr + self.conv_dw_start = None + self.norm_dw_start = None # Point-wise expansion mid_chs = make_divisible(in_chs * exp_ratio) @@ -359,6 +360,7 @@ class UniversalInvertedResidual(nn.Module): ) self.norm_dw_mid = dw_norm_act_layer(mid_chs, inplace=True) else: + # keeping mid as identity so it can be hooked more easily for features self.conv_dw_mid = nn.Identity() self.norm_dw_mid = nn.Identity() @@ -379,7 +381,7 @@ class UniversalInvertedResidual(nn.Module): ) self.norm_dw_end = dw_norm_act_layer(out_chs, apply_act=False) else: - # dw_end rarely used so keeping it out of repr by not using None instead of nn.Identitty() + # end is None when not in use for cleaner repr self.conv_dw_end = None self.norm_dw_end = None @@ -397,8 +399,9 @@ class UniversalInvertedResidual(nn.Module): def forward(self, x): shortcut = x - x = self.conv_dw_start(x) - x = self.norm_dw_start(x) + if self.conv_dw_start is not None: + x = self.conv_dw_start(x) + x = self.norm_dw_start(x) x = self.conv_pw(x) x = self.norm_pw(x) x = self.conv_dw_mid(x) @@ -418,7 +421,8 @@ class UniversalInvertedResidual(nn.Module): class MobileAttention(nn.Module): 
""" Mobile Attention Block - For MobileNetV4 - https://arxiv.org/abs/ + For MobileNetV4 - https://arxiv.org/abs/, referenced from + https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L1504 """ def __init__( self, @@ -476,34 +480,21 @@ class MobileAttention(nn.Module): num_heads = in_chs // key_dim if use_multi_query: - #if self.has_query_stride or self.kv_stride > 1: - self.attn = ( - MultiQueryAttention2d( - in_chs, - dim_out=out_chs, - num_heads=num_heads, - key_dim=key_dim, - value_dim=value_dim, - query_strides=query_strides, - kv_stride=kv_stride, - dilation=dilation, - padding=pad_type, - dw_kernel_size=dw_kernel_size, - attn_drop=attn_drop, - proj_drop=proj_drop, - #bias=use_bias, # why not here if used w/ mhsa? - ) + self.attn = MultiQueryAttention2d( + in_chs, + dim_out=out_chs, + num_heads=num_heads, + key_dim=key_dim, + value_dim=value_dim, + query_strides=query_strides, + kv_stride=kv_stride, + dilation=dilation, + padding=pad_type, + dw_kernel_size=dw_kernel_size, + attn_drop=attn_drop, + proj_drop=proj_drop, + #bias=use_bias, # why not here if used w/ mhsa? ) - # else: - # self.attn = MultiQueryAttentionV2( - # in_chs, - # dim_out=out_chs, - # num_heads=num_heads, - # key_dim=key_dim, - # value_dim=value_dim, - # attn_drop=attn_drop, - # proj_drop=proj_drop, - # ) else: self.attn = Attention2d( in_chs, diff --git a/timm/models/_efficientnet_builder.py b/timm/models/_efficientnet_builder.py index 4cbd6342..7d96216a 100644 --- a/timm/models/_efficientnet_builder.py +++ b/timm/models/_efficientnet_builder.py @@ -5,6 +5,7 @@ Handles stride, dilation calculations, and selects feature extraction points. Hacked together by / Copyright 2019, Ross Wightman """ +from typing import Callable, Optional import logging import math @@ -321,15 +322,16 @@ class EfficientNetBuilder: """ def __init__( self, - output_stride=32, - pad_type='', - round_chs_fn=round_channels, - se_from_exp=False, - act_layer=None, - norm_layer=None, - se_layer=None, - drop_path_rate=0., - feature_location='', + output_stride: int = 32, + pad_type: str = '', + round_chs_fn: Callable = round_channels, + se_from_exp: bool = False, + act_layer: Optional[Callable] = None, + norm_layer: Optional[Callable] = None, + se_layer: Optional[Callable] = None, + drop_path_rate: float = 0., + layer_scale_init_value: Optional[float] = None, + feature_location: str = '', ): self.output_stride = output_stride self.pad_type = pad_type @@ -344,6 +346,7 @@ class EfficientNetBuilder: except TypeError: self.se_has_ratio = False self.drop_path_rate = drop_path_rate + self.layer_scale_init_value = layer_scale_init_value if feature_location == 'depthwise': # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense _logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'") @@ -402,13 +405,13 @@ class EfficientNetBuilder: block = ConvBnAct(**ba) elif bt == 'uir': _log_info_if(' UniversalInvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose) - block = UniversalInvertedResidual(**ba) + block = UniversalInvertedResidual(**ba, layer_scale_init_value=self.layer_scale_init_value) elif bt == 'mqa': _log_info_if(' MobileMultiQueryAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose) - block = MobileAttention(**ba, use_multi_query=True) + block = MobileAttention(**ba, use_multi_query=True, layer_scale_init_value=self.layer_scale_init_value) elif bt == 'mha': _log_info_if(' 
MobileMultiHeadAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose) - block = MobileAttention(**ba) + block = MobileAttention(**ba, layer_scale_init_value=self.layer_scale_init_value) else: assert False, 'Unknown block type (%s) while building model.' % bt diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index a9c63f28..e90a8df4 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -51,6 +51,7 @@ class MobileNetV3(nn.Module): fix_stem: bool = False, num_features: int = 1280, head_bias: bool = True, + head_norm: bool = False, pad_type: PadType = '', act_layer: Optional[LayerType] = None, norm_layer: Optional[LayerType] = None, @@ -59,6 +60,7 @@ class MobileNetV3(nn.Module): round_chs_fn: Callable = round_channels, drop_rate: float = 0., drop_path_rate: float = 0., + layer_scale_init_value: Optional[float] = None, global_pool: str = 'avg', ): """ @@ -78,6 +80,7 @@ class MobileNetV3(nn.Module): round_chs_fn: Callable to round number of filters based on depth multiplier. drop_rate: Dropout rate. drop_path_rate: Stochastic depth rate. + layer_scale_init_value: Enable layer scale on compatible blocks if not None global_pool: Type of pooling to use for global pooling features of the FC head. """ super(MobileNetV3, self).__init__() @@ -106,6 +109,7 @@ class MobileNetV3(nn.Module): norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate, + layer_scale_init_value=layer_scale_init_value, ) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = builder.features @@ -115,8 +119,16 @@ class MobileNetV3(nn.Module): # Head + Pooling self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) num_pooled_chs = head_chs * self.global_pool.feat_mult() - self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias) - self.act2 = act_layer(inplace=True) + if head_norm: + # mobilenet-v4 post-pooling PW conv is followed by a norm+act layer + self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type) # never bias + self.norm_head = norm_act_layer(self.num_features) + self.act2 = nn.Identity() + else: + # mobilenet-v3 and others only have an activation after final PW conv + self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias) + self.norm_head = nn.Identity() + self.act2 = act_layer(inplace=True) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() @@ -125,7 +137,7 @@ class MobileNetV3(nn.Module): def as_sequential(self): layers = [self.conv_stem, self.bn1] layers.extend(self.blocks) - layers.extend([self.global_pool, self.conv_head, self.act2]) + layers.extend([self.global_pool, self.conv_head, self.norm_head, self.act2]) layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) return nn.Sequential(*layers) @@ -224,8 +236,10 @@ class MobileNetV3(nn.Module): self.blocks = self.blocks[:max_index] # truncate blocks w/ stem as idx 0 if max_index < len(self.blocks): self.conv_head = nn.Identity() + self.norm_head = nn.Identity() if prune_head: self.conv_head = nn.Identity() + self.norm_head = nn.Identity() self.reset_classifier(0, '') return take_indices @@ -241,6 +255,7 @@ class MobileNetV3(nn.Module): def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: x = self.global_pool(x) x = self.conv_head(x) + x = 
self.norm_head(x) x = self.act2(x) x = self.flatten(x) if pre_logits: @@ -632,6 +647,7 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: channel_multiplier: multiplier to number of channels per layer. """ if 'hybrid' in variant: + layer_scale_init_value = 1e-5 if 'medium' in variant: stem_size = 32 num_features = 1280 @@ -730,6 +746,7 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: else: assert False, f'Unknown variant {variant}.' else: + layer_scale_init_value = None if 'small' in variant: stem_size = 32 num_features = 1280 @@ -836,9 +853,12 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: else: assert False, f'Unknown variant {variant}.' + # NOTE SE not used in initial MobileNet-v4 definitions se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels) model_kwargs = dict( block_args=decode_arch_def(arch_def), + head_bias=False, + head_norm=True, num_features=num_features, stem_size=stem_size, fix_stem=channel_multiplier < 0.75, @@ -846,6 +866,7 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), act_layer=act_layer, se_layer=se_layer, + layer_scale_init_value=layer_scale_init_value, **kwargs, ) model = _create_mnv3(variant, pretrained, **model_kwargs)
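
A minimal sketch of how the two knobs added in this patch are expected to surface, assuming a timm checkout with this change applied. The variant names mobilenetv4_conv_small and mobilenetv4_hybrid_medium are inferred from the _gen_mobilenet_v4 branches above (the exact registered names are an assumption); LayerScale2d is imported from the _efficientnet_blocks module touched in this patch.

    import torch
    import timm
    from timm.models._efficientnet_blocks import LayerScale2d

    # Conv variant: _gen_mobilenet_v4 leaves layer_scale_init_value=None, so no
    # LayerScale2d modules are built; the head still gains the new norm_head.
    conv_model = timm.create_model('mobilenetv4_conv_small', pretrained=False)
    print(conv_model.norm_head)  # norm (+ act) after the post-pooling 1x1 PW head conv
    print(any(isinstance(m, LayerScale2d) for m in conv_model.modules()))    # expect False

    # Hybrid variant: layer_scale_init_value=1e-5 is passed through
    # EfficientNetBuilder into the uir/mqa/mha blocks, enabling LayerScale2d.
    hybrid_model = timm.create_model('mobilenetv4_hybrid_medium', pretrained=False)
    print(any(isinstance(m, LayerScale2d) for m in hybrid_model.modules()))  # expect True

    # forward_head now runs global_pool -> conv_head -> norm_head -> act2 -> classifier.
    with torch.no_grad():
        out = conv_model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 1000])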