More MobileNet-v4 fixes

* add missing final norm after the post-pooling 1x1 PW head conv
* improve model repr by flipping a few unused modules to None and grouping the MultiQueryAttention query/key/value/output branches into nn.Sequential
* allow layer scaling to be enabled/disabled at the model variant level; conv variants don't use it
Ross Wightman 2024-05-24 15:09:29 -07:00
parent 28d76a97db
commit 7fe96e7a92
4 changed files with 102 additions and 102 deletions
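Before the diffs, a minimal inspection sketch of what these changes amount to at the model level. It assumes the mobilenetv4_* variant names targeted by this commit are registered in timm; the block index picked below is purely illustrative.

import timm
import torch

m = timm.create_model('mobilenetv4_hybrid_medium', pretrained=False)

# head now runs conv_head -> norm_head -> act2 after global pooling
print(type(m.conv_head).__name__, type(m.norm_head).__name__, type(m.act2).__name__)

# unused sub-modules are plain None, so they no longer clutter the block repr
print(m.blocks[1][0])  # any 'uir' block; index is illustrative

# forward path is unchanged end to end
print(m(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])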

View File

@@ -107,6 +107,7 @@ class MultiQueryAttention2d(nn.Module):
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = nn.BatchNorm2d,
use_bias: bool = False,
):
"""Initializer.
@@ -130,26 +131,25 @@ class MultiQueryAttention2d(nn.Module):
self.fused_attn = use_fused_attn()
self.drop = attn_drop
self.query = nn.Sequential()
if self.has_query_strides:
# FIXME dilation
self.query_down_pool = create_pool2d(
self.query.add_module('down_pool', create_pool2d(
'avg',
kernel_size=self.query_strides,
padding=padding,
)
self.query_down_norm = norm_layer(dim)
else:
self.query_down_pool = nn.Identity()
self.query_down_norm = nn.Identity()
self.query_proj = create_conv2d(
))
self.query.add_module('norm', norm_layer(dim))
self.query.add_module('proj', create_conv2d(
dim,
self.num_heads * self.key_dim,
kernel_size=1,
)
bias=use_bias,
))
self.key = nn.Sequential()
if kv_stride > 1:
self.key_down_conv = create_conv2d(
self.key.add_module('down_conv', create_conv2d(
dim,
dim,
kernel_size=dw_kernel_size,
@@ -157,21 +157,19 @@ class MultiQueryAttention2d(nn.Module):
dilation=dilation,
padding=padding,
depthwise=True,
)
self.key_down_norm = norm_layer(dim)
else:
self.key_down_conv = nn.Identity()
self.key_down_norm = nn.Identity()
self.key_proj = create_conv2d(
))
self.key.add_module('norm', norm_layer(dim))
self.key.add_module('proj', create_conv2d(
dim,
self.key_dim,
kernel_size=1,
padding=padding,
)
bias=use_bias,
))
self.value = nn.Sequential()
if kv_stride > 1:
self.value_down_conv = create_conv2d(
self.value.add_module('down_conv', create_conv2d(
dim,
dim,
kernel_size=dw_kernel_size,
@@ -179,32 +177,28 @@ class MultiQueryAttention2d(nn.Module):
dilation=dilation,
padding=padding,
depthwise=True,
)
self.value_down_norm = norm_layer(dim)
else:
self.value_down_conv = nn.Identity()
self.value_down_norm = nn.Identity()
self.value_proj = create_conv2d(
))
self.value.add_module('norm', norm_layer(dim))
self.value.add_module('proj', create_conv2d(
dim,
self.value_dim,
kernel_size=1,
)
bias=use_bias,
))
self.attn_drop = nn.Dropout(attn_drop)
self.output = nn.Sequential()
if self.has_query_strides:
self.upsampling = nn.Upsample(self.query_strides, mode='bilinear', align_corners=False)
else:
self.upsampling = nn.Identity()
self.out_proj = create_conv2d(
self.output.add_module('upsample', nn.Upsample(self.query_strides, mode='bilinear', align_corners=False))
self.output.add_module('proj', create_conv2d(
self.value_dim * self.num_heads,
dim_out,
kernel_size=1,
)
bias=use_bias,
))
self.output.add_module('drop', nn.Dropout(proj_drop))
self.proj_drop = nn.Dropout(proj_drop)
self.einsum = False
def _reshape_input(self, t: torch.Tensor):
@@ -237,21 +231,15 @@ class MultiQueryAttention2d(nn.Module):
"""Run layer computation."""
B, C, H, W = s = x.shape
q = self.query_down_pool(x)
q = self.query_down_norm(q)
q = self.query_proj(q)
q = self.query(x)
# desired q shape: [b, h, k, n x n] - [b, l, h, k]
q = self._reshape_projected_query(q, self.num_heads, self.key_dim)
k = self.key_down_conv(x)
k = self.key_down_norm(k)
k = self.key_proj(k)
k = self.key(x)
# output shape of k: [b, k, p], p = m x m
k = self._reshape_input(k)
v = self.value_down_conv(x)
v = self.value_down_norm(v)
v = self.value_proj(v)
v = self.value(x)
# output shape of v: [ b, p, k], p = m x m
v = self._reshape_input(v)
@@ -285,10 +273,7 @@ class MultiQueryAttention2d(nn.Module):
# reshape o into [b, hk, n, n,]
o = self._reshape_output(o, self.num_heads, H // self.query_strides[0], W // self.query_strides[1])
o = self.upsampling(o)
x = self.out_proj(o)
x = self.proj_drop(x)
x = self.output(o)
return x
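The refactor above follows a simple pattern: each of query/key/value/output becomes an nn.Sequential, optional stages are only added when configured, and the forward collapses to a single call per branch (q = self.query(x), etc.). A generic, self-contained sketch of that pattern (names and layer choices are illustrative, not the timm implementation):

import torch
import torch.nn as nn

def make_key_branch(dim, key_dim, kv_stride=1, norm_layer=nn.BatchNorm2d):
    key = nn.Sequential()
    if kv_stride > 1:
        # optional depthwise downsample only exists when it is actually used
        key.add_module('down_conv', nn.Conv2d(dim, dim, 3, stride=kv_stride, padding=1, groups=dim))
        key.add_module('norm', norm_layer(dim))
    key.add_module('proj', nn.Conv2d(dim, key_dim, 1, bias=False))
    return key

key = make_key_branch(64, 16, kv_stride=2)
print(key)                                    # repr shows only the stages that are present
print(key(torch.randn(1, 64, 16, 16)).shape)  # torch.Size([1, 16, 8, 8])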

View File

@@ -174,13 +174,12 @@ class DepthwiseSeparableConv(nn.Module):
def forward(self, x):
shortcut = x
#print('ii', x.shape)
#print('ii', x.shape) # FIXME debug s2d
if self.conv_s2d is not None:
x = self.conv_s2d(x)
x = self.bn_s2d(x)
#print('id', x.shape)
#print('id', x.shape) # FIXME debug s2d
x = self.conv_dw(x)
#print('od', x.shape)
x = self.bn1(x)
x = self.se(x)
x = self.conv_pw(x)
@@ -296,7 +295,8 @@ class LayerScale2d(nn.Module):
class UniversalInvertedResidual(nn.Module):
""" Universal Inverted Residual Block
For MobileNetV4 - https://arxiv.org/abs/
For MobileNetV4 - https://arxiv.org/abs/, referenced from
https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L778
"""
def __init__(
@@ -338,8 +338,9 @@ class UniversalInvertedResidual(nn.Module):
)
self.norm_dw_start = dw_norm_act_layer(in_chs, apply_act=False)
else:
self.conv_dw_start = nn.Identity()
self.norm_dw_start = nn.Identity()
# start is None when not used for cleaner repr
self.conv_dw_start = None
self.norm_dw_start = None
# Point-wise expansion
mid_chs = make_divisible(in_chs * exp_ratio)
@@ -359,6 +360,7 @@ class UniversalInvertedResidual(nn.Module):
)
self.norm_dw_mid = dw_norm_act_layer(mid_chs, inplace=True)
else:
# keeping mid as identity so it can be hooked more easily for features
self.conv_dw_mid = nn.Identity()
self.norm_dw_mid = nn.Identity()
@@ -379,7 +381,7 @@ class UniversalInvertedResidual(nn.Module):
)
self.norm_dw_end = dw_norm_act_layer(out_chs, apply_act=False)
else:
# dw_end rarely used so keeping it out of repr by using None instead of nn.Identity()
# end is None when not in use for cleaner repr
self.conv_dw_end = None
self.norm_dw_end = None
@@ -397,6 +399,7 @@ class UniversalInvertedResidual(nn.Module):
def forward(self, x):
shortcut = x
if self.conv_dw_start is not None:
x = self.conv_dw_start(x)
x = self.norm_dw_start(x)
x = self.conv_pw(x)
@@ -418,7 +421,8 @@ class UniversalInvertedResidual(nn.Module):
class MobileAttention(nn.Module):
""" Mobile Attention Block
For MobileNetV4 - https://arxiv.org/abs/
For MobileNetV4 - https://arxiv.org/abs/, referenced from
https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L1504
"""
def __init__(
self,
@@ -476,9 +480,7 @@ class MobileAttention(nn.Module):
num_heads = in_chs // key_dim
if use_multi_query:
#if self.has_query_stride or self.kv_stride > 1:
self.attn = (
MultiQueryAttention2d(
self.attn = MultiQueryAttention2d(
in_chs,
dim_out=out_chs,
num_heads=num_heads,
@@ -493,17 +495,6 @@ class MobileAttention(nn.Module):
proj_drop=proj_drop,
#bias=use_bias, # why not here if used w/ mhsa?
)
)
# else:
# self.attn = MultiQueryAttentionV2(
# in_chs,
# dim_out=out_chs,
# num_heads=num_heads,
# key_dim=key_dim,
# value_dim=value_dim,
# attn_drop=attn_drop,
# proj_drop=proj_drop,
# )
else:
self.attn = Attention2d(
in_chs,
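The other repr cleanup in this file stores optional stages as None and guards them in forward, rather than keeping nn.Identity placeholders around. A self-contained sketch of that pattern (illustrative layers, not the actual UniversalInvertedResidual):

import torch
import torch.nn as nn

class OptionalDWBlock(nn.Module):
    def __init__(self, chs, use_dw_start=False):
        super().__init__()
        if use_dw_start:
            self.conv_dw_start = nn.Conv2d(chs, chs, 3, padding=1, groups=chs)
            self.norm_dw_start = nn.BatchNorm2d(chs)
        else:
            # None is not a registered child module, so it stays out of the repr
            self.conv_dw_start = None
            self.norm_dw_start = None
        self.conv_pw = nn.Conv2d(chs, chs, 1)

    def forward(self, x):
        if self.conv_dw_start is not None:
            x = self.norm_dw_start(self.conv_dw_start(x))
        return self.conv_pw(x)

print(OptionalDWBlock(16))                                   # only conv_pw appears
print(OptionalDWBlock(16)(torch.randn(1, 16, 8, 8)).shape)   # torch.Size([1, 16, 8, 8])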

View File

@@ -5,6 +5,7 @@ Handles stride, dilation calculations, and selects feature extraction points.
Hacked together by / Copyright 2019, Ross Wightman
"""
from typing import Callable, Optional
import logging
import math
@@ -321,15 +322,16 @@ class EfficientNetBuilder:
"""
def __init__(
self,
output_stride=32,
pad_type='',
round_chs_fn=round_channels,
se_from_exp=False,
act_layer=None,
norm_layer=None,
se_layer=None,
drop_path_rate=0.,
feature_location='',
output_stride: int = 32,
pad_type: str = '',
round_chs_fn: Callable = round_channels,
se_from_exp: bool = False,
act_layer: Optional[Callable] = None,
norm_layer: Optional[Callable] = None,
se_layer: Optional[Callable] = None,
drop_path_rate: float = 0.,
layer_scale_init_value: Optional[float] = None,
feature_location: str = '',
):
self.output_stride = output_stride
self.pad_type = pad_type
@@ -344,6 +346,7 @@ class EfficientNetBuilder:
except TypeError:
self.se_has_ratio = False
self.drop_path_rate = drop_path_rate
self.layer_scale_init_value = layer_scale_init_value
if feature_location == 'depthwise':
# old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
_logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
@@ -402,13 +405,13 @@ class EfficientNetBuilder:
block = ConvBnAct(**ba)
elif bt == 'uir':
_log_info_if(' UniversalInvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = UniversalInvertedResidual(**ba)
block = UniversalInvertedResidual(**ba, layer_scale_init_value=self.layer_scale_init_value)
elif bt == 'mqa':
_log_info_if(' MobileMultiQueryAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = MobileAttention(**ba, use_multi_query=True)
block = MobileAttention(**ba, use_multi_query=True, layer_scale_init_value=self.layer_scale_init_value)
elif bt == 'mha':
_log_info_if(' MobileMultiHeadAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = MobileAttention(**ba)
block = MobileAttention(**ba, layer_scale_init_value=self.layer_scale_init_value)
else:
assert False, 'Unknown block type (%s) while building model.' % bt
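The builder simply forwards a single layer_scale_init_value to every block type that supports it; None disables layer scale entirely. A minimal sketch of that gating, mirroring the behaviour of timm's LayerScale2d rather than its exact implementation:

import torch
import torch.nn as nn

class LayerScale2dSketch(nn.Module):
    def __init__(self, dim, init_value=1e-5):
        super().__init__()
        # learnable per-channel residual scaling, started near zero
        self.gamma = nn.Parameter(init_value * torch.ones(dim))

    def forward(self, x):
        return x * self.gamma.view(1, -1, 1, 1)

def make_layer_scale(dim, layer_scale_init_value=None):
    if layer_scale_init_value is None:
        return nn.Identity()                                     # disabled (conv variants)
    return LayerScale2dSketch(dim, layer_scale_init_value)       # enabled (hybrid variants)

x = torch.randn(2, 32, 8, 8)
print(make_layer_scale(32)(x).shape)
print(make_layer_scale(32, 1e-5)(x).shape)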

View File

@@ -51,6 +51,7 @@ class MobileNetV3(nn.Module):
fix_stem: bool = False,
num_features: int = 1280,
head_bias: bool = True,
head_norm: bool = False,
pad_type: PadType = '',
act_layer: Optional[LayerType] = None,
norm_layer: Optional[LayerType] = None,
@@ -59,6 +60,7 @@ class MobileNetV3(nn.Module):
round_chs_fn: Callable = round_channels,
drop_rate: float = 0.,
drop_path_rate: float = 0.,
layer_scale_init_value: Optional[float] = None,
global_pool: str = 'avg',
):
"""
@@ -78,6 +80,7 @@ class MobileNetV3(nn.Module):
round_chs_fn: Callable to round number of filters based on depth multiplier.
drop_rate: Dropout rate.
drop_path_rate: Stochastic depth rate.
layer_scale_init_value: Enable layer scale on compatible blocks if not None.
global_pool: Type of pooling to use for global pooling features of the FC head.
"""
super(MobileNetV3, self).__init__()
@@ -106,6 +109,7 @@ class MobileNetV3(nn.Module):
norm_layer=norm_layer,
se_layer=se_layer,
drop_path_rate=drop_path_rate,
layer_scale_init_value=layer_scale_init_value,
)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = builder.features
@@ -115,7 +119,15 @@ class MobileNetV3(nn.Module):
# Head + Pooling
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
num_pooled_chs = head_chs * self.global_pool.feat_mult()
if head_norm:
# mobilenet-v4 post-pooling PW conv is followed by a norm+act layer
self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type) # never bias
self.norm_head = norm_act_layer(self.num_features)
self.act2 = nn.Identity()
else:
# mobilenet-v3 and others only have an activation after final PW conv
self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
self.norm_head = nn.Identity()
self.act2 = act_layer(inplace=True)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
@@ -125,7 +137,7 @@ class MobileNetV3(nn.Module):
def as_sequential(self):
layers = [self.conv_stem, self.bn1]
layers.extend(self.blocks)
layers.extend([self.global_pool, self.conv_head, self.act2])
layers.extend([self.global_pool, self.conv_head, self.norm_head, self.act2])
layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
return nn.Sequential(*layers)
@@ -224,8 +236,10 @@ class MobileNetV3(nn.Module):
self.blocks = self.blocks[:max_index] # truncate blocks w/ stem as idx 0
if max_index < len(self.blocks):
self.conv_head = nn.Identity()
self.norm_head = nn.Identity()
if prune_head:
self.conv_head = nn.Identity()
self.norm_head = nn.Identity()
self.reset_classifier(0, '')
return take_indices
@@ -241,6 +255,7 @@ class MobileNetV3(nn.Module):
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
x = self.global_pool(x)
x = self.conv_head(x)
x = self.norm_head(x)
x = self.act2(x)
x = self.flatten(x)
if pre_logits:
@@ -632,6 +647,7 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained:
channel_multiplier: multiplier to number of channels per layer.
"""
if 'hybrid' in variant:
layer_scale_init_value = 1e-5
if 'medium' in variant:
stem_size = 32
num_features = 1280
@@ -730,6 +746,7 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained:
else:
assert False, f'Unknown variant {variant}.'
else:
layer_scale_init_value = None
if 'small' in variant:
stem_size = 32
num_features = 1280
@@ -836,9 +853,12 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained:
else:
assert False, f'Unknown variant {variant}.'
# NOTE SE not used in initial MobileNet-v4 definitions
se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
head_bias=False,
head_norm=True,
num_features=num_features,
stem_size=stem_size,
fix_stem=channel_multiplier < 0.75,
@@ -846,6 +866,7 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained:
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=act_layer,
se_layer=se_layer,
layer_scale_init_value=layer_scale_init_value,
**kwargs,
)
model = _create_mnv3(variant, pretrained, **model_kwargs)
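Tying the variant-level switch together: hybrid variants set layer_scale_init_value=1e-5 while conv variants leave it at None, so only the hybrids should end up with LayerScale2d modules. A hedged check, again assuming the mobilenetv4_* names are registered:

import timm

conv_m = timm.create_model('mobilenetv4_conv_medium', pretrained=False)
hyb_m = timm.create_model('mobilenetv4_hybrid_medium', pretrained=False)

print('LayerScale2d' in repr(conv_m))   # expected False: layer scale disabled for conv variants
print('LayerScale2d' in repr(hyb_m))    # expected True: layer scale enabled at 1e-5 for hybrids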