diff --git a/timm/models/fastvit.py b/timm/models/fastvit.py
index f35b8a86..1bac6d31 100644
--- a/timm/models/fastvit.py
+++ b/timm/models/fastvit.py
@@ -12,9 +12,256 @@
 import torch
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, trunc_normal_
+from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn
 from ._registry import register_model
-from .byobnet import MobileOneBlock
+
+
+def num_groups(group_size, channels):
+    if not group_size:  # 0 or None
+        return 1  # normal conv with 1 group
+    else:
+        # NOTE group_size == 1 -> depthwise conv
+        assert channels % group_size == 0
+        return channels // group_size
+
+
+class MobileOneBlock(nn.Module):
+    """MobileOne building block.
+
+    This block has a multi-branched architecture at train-time
+    and a plain-CNN style architecture at inference time.
+    For more details, please refer to our paper:
+    `An Improved One millisecond Mobile Backbone` -
+    https://arxiv.org/pdf/2206.04040.pdf
+    """
+
+    def __init__(
+        self,
+        in_chs: int,
+        out_chs: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        group_size: int = 0,
+        inference_mode: bool = False,
+        use_se: bool = False,
+        use_act: bool = True,
+        use_scale_branch: bool = True,
+        num_conv_branches: int = 1,
+        act_layer: nn.Module = nn.GELU,
+    ) -> None:
+        """Construct a MobileOneBlock module.
+
+        Args:
+            in_chs: Number of channels in the input.
+            out_chs: Number of channels produced by the block.
+            kernel_size: Size of the convolution kernel.
+            stride: Stride size.
+            dilation: Kernel dilation factor.
+            group_size: Group size; 0 or None for a normal conv, 1 for depthwise.
+            inference_mode: If True, instantiates model in inference mode.
+            use_se: Whether to use SE-ReLU activations.
+            use_act: Whether to use activation. Default: ``True``
+            use_scale_branch: Whether to use scale branch. Default: ``True``
+            num_conv_branches: Number of linear conv branches.
+            act_layer: Activation layer. Default: ``nn.GELU``
+        """
+        super(MobileOneBlock, self).__init__()
+        self.inference_mode = inference_mode
+        self.groups = num_groups(group_size, in_chs)
+        self.stride = stride
+        self.dilation = dilation
+        self.kernel_size = kernel_size
+        self.in_chs = in_chs
+        self.out_chs = out_chs
+        self.num_conv_branches = num_conv_branches
+
+        # Check if SE-ReLU is requested
+        self.se = SqueezeExcite(out_chs) if use_se else nn.Identity()
+
+        if inference_mode:
+            self.reparam_conv = create_conv2d(
+                in_chs,
+                out_chs,
+                kernel_size=kernel_size,
+                stride=stride,
+                dilation=dilation,
+                groups=self.groups,
+                bias=True,
+            )
+        else:
+            self.reparam_conv = None
+
+            # Re-parameterizable skip connection
+            self.rbr_skip = (
+                nn.BatchNorm2d(num_features=in_chs)
+                if out_chs == in_chs and stride == 1
+                else None
+            )
+
+            # Re-parameterizable conv branches
+            if num_conv_branches > 0:
+                rbr_conv = list()
+                for _ in range(self.num_conv_branches):
+                    rbr_conv.append(ConvNormAct(
+                        self.in_chs,
+                        self.out_chs,
+                        kernel_size=kernel_size,
+                        stride=self.stride,
+                        groups=self.groups,
+                        apply_act=False,
+                    ))
+                self.rbr_conv = nn.ModuleList(rbr_conv)
+            else:
+                self.rbr_conv = None
+
+            # Re-parameterizable scale branch
+            self.rbr_scale = None
+            if kernel_size > 1 and use_scale_branch:
+                self.rbr_scale = ConvNormAct(
+                    self.in_chs,
+                    self.out_chs,
+                    kernel_size=1,
+                    stride=self.stride,
+                    groups=self.groups,
+                    apply_act=False,
+                )
+
+        self.act = act_layer() if use_act else nn.Identity()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply forward pass."""
+        # Inference mode forward pass.
+        if self.reparam_conv is not None:
+            return self.act(self.se(self.reparam_conv(x)))
+
+        # Multi-branched train-time forward pass.
+        # Skip branch output
+        identity_out = 0
+        if self.rbr_skip is not None:
+            identity_out = self.rbr_skip(x)
+
+        # Scale branch output
+        scale_out = 0
+        if self.rbr_scale is not None:
+            scale_out = self.rbr_scale(x)
+
+        # Other branches
+        out = scale_out + identity_out
+        if self.rbr_conv is not None:
+            for ix in range(self.num_conv_branches):
+                out += self.rbr_conv[ix](x)
+
+        return self.act(self.se(out))
+
+    def reparameterize(self):
+        """Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
+        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize the multi-branched
+        architecture used at training time to obtain a plain CNN-like structure
+        for inference.
+        """
+        if self.inference_mode:
+            return
+        kernel, bias = self._get_kernel_bias()
+        self.reparam_conv = create_conv2d(
+            in_channels=self.in_chs,
+            out_channels=self.out_chs,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            dilation=self.dilation,
+            groups=self.groups,
+            bias=True,
+        )
+        self.reparam_conv.weight.data = kernel
+        self.reparam_conv.bias.data = bias
+
+        # Delete unused branches
+        for para in self.parameters():
+            para.detach_()
+        self.__delattr__("rbr_conv")
+        self.__delattr__("rbr_scale")
+        if hasattr(self, "rbr_skip"):
+            self.__delattr__("rbr_skip")
+
+        self.inference_mode = True
+
+    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Method to obtain re-parameterized kernel and bias.
+        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
+
+        Returns:
+            Tuple of (kernel, bias) after fusing branches.
+        """
+        # get weights and bias of scale branch
+        kernel_scale = 0
+        bias_scale = 0
+        if self.rbr_scale is not None:
+            kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
+            # Pad scale branch kernel to match conv branch kernel size.
+            pad = self.kernel_size // 2
+            kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])
+
+        # get weights and bias of skip branch
+        kernel_identity = 0
+        bias_identity = 0
+        if self.rbr_skip is not None:
+            kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
+
+        # get weights and bias of conv branches
+        kernel_conv = 0
+        bias_conv = 0
+        if self.rbr_conv is not None:
+            for ix in range(self.num_conv_branches):
+                _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
+                kernel_conv += _kernel
+                bias_conv += _bias
+
+        kernel_final = kernel_conv + kernel_scale + kernel_identity
+        bias_final = bias_conv + bias_scale + bias_identity
+        return kernel_final, bias_final
+
+    def _fuse_bn_tensor(
+        self, branch: Union[ConvNormAct, nn.BatchNorm2d]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Method to fuse batchnorm layer with preceding conv layer.
+        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
+
+        Args:
+            branch: Sequence of ops to be fused.
+
+        Returns:
+            Tuple of (kernel, bias) after fusing batchnorm.
+        """
+        if isinstance(branch, ConvNormAct):
+            kernel = branch.conv.weight
+            running_mean = branch.bn.running_mean
+            running_var = branch.bn.running_var
+            gamma = branch.bn.weight
+            beta = branch.bn.bias
+            eps = branch.bn.eps
+        else:
+            assert isinstance(branch, nn.BatchNorm2d)
+            if not hasattr(self, "id_tensor"):
+                input_dim = self.in_chs // self.groups
+                kernel_value = torch.zeros(
+                    (self.in_chs, input_dim, self.kernel_size, self.kernel_size),
+                    dtype=branch.weight.dtype,
+                    device=branch.weight.device,
+                )
+                for i in range(self.in_chs):
+                    kernel_value[
+                        i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2
+                    ] = 1
+                self.id_tensor = kernel_value
+            kernel = self.id_tensor
+            running_mean = branch.running_mean
+            running_var = branch.running_var
+            gamma = branch.weight
+            beta = branch.bias
+            eps = branch.eps
+        std = (running_var + eps).sqrt()
+        t = (gamma / std).reshape(-1, 1, 1, 1)
+        return kernel * t, beta - running_mean * gamma / std
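Sanity check (not part of the patch): after `reparameterize()`, the fused conv should reproduce the multi-branch output up to float tolerance, provided the block is in eval mode so BN fusion sees the running stats. A minimal sketch, assuming the block is importable from the patched module:

    import torch
    from timm.models.fastvit import MobileOneBlock

    block = MobileOneBlock(64, 64, kernel_size=3, group_size=1, num_conv_branches=2)
    block.eval()  # eval mode: BN fusion uses running stats, so outputs match
    x = torch.randn(1, 64, 32, 32)
    with torch.no_grad():
        y_branched = block(x)
        block.reparameterize()
        y_fused = block(x)
    print(torch.allclose(y_branched, y_fused, atol=1e-5))  # expect True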
 
 
 class ReparamLargeKernelConv(nn.Module):
@@ -32,10 +278,10 @@ class ReparamLargeKernelConv(nn.Module):
         out_chs: int,
         kernel_size: int,
         stride: int,
-        groups: int,
-        small_kernel: int,
+        group_size: int,
+        small_kernel: Optional[int] = None,
         inference_mode: bool = False,
-        act_layer: nn.Module = nn.GELU(),
+        act_layer: Optional[nn.Module] = None,
     ) -> None:
         """Construct a ReparamLargeKernelConv module.
 
@@ -44,55 +290,65 @@ class ReparamLargeKernelConv(nn.Module):
             out_chs: Number of output channels.
             kernel_size: Kernel size of the large kernel conv branch.
             stride: Stride size. Default: 1
-            groups: Group number. Default: 1
+            group_size: Group size. Default: 1
             small_kernel: Kernel size of small kernel conv branch.
             inference_mode: If True, instantiates model in inference mode. Default: ``False``
-            act_layer: Activation module. Default: ``nn.GELU``
+            act_layer: Activation module. Default: ``None``
         """
         super(ReparamLargeKernelConv, self).__init__()
         self.stride = stride
-        self.groups = groups
+        self.groups = num_groups(group_size, in_chs)
         self.in_chs = in_chs
         self.out_chs = out_chs
-        self.act_layer = act_layer
 
         self.kernel_size = kernel_size
         self.small_kernel = small_kernel
-        self.padding = kernel_size // 2
         if inference_mode:
-            self.lkb_reparam = nn.Conv2d(
-                in_chs=in_chs,
-                out_chs=out_chs,
+            self.lkb_reparam = create_conv2d(
+                in_chs,
+                out_chs,
                 kernel_size=kernel_size,
                 stride=stride,
-                padding=self.padding,
                 dilation=1,
-                groups=groups,
+                groups=self.groups,
                 bias=True,
             )
         else:
-            self.lkb_origin = self._conv_bn(
-                kernel_size=kernel_size, padding=self.padding
+            self.lkb_reparam = None
+            self.lkb_origin = ConvNormAct(
+                in_chs,
+                out_chs,
+                kernel_size=kernel_size,
+                stride=self.stride,
+                groups=self.groups,
+                apply_act=False,
             )
+            self.small_conv = None
             if small_kernel is not None:
                 assert (
                     small_kernel <= kernel_size
                 ), "The kernel size for re-param cannot be larger than the large kernel!"
-                self.small_conv = self._conv_bn(
-                    kernel_size=small_kernel, padding=small_kernel // 2
+                self.small_conv = ConvNormAct(
+                    in_chs,
+                    out_chs,
+                    kernel_size=small_kernel,
+                    stride=self.stride,
+                    groups=self.groups,
+                    apply_act=False,
                 )
+        # FIXME output of this act was not used in original impl, likely due to bug
+        self.act = act_layer() if act_layer is not None else nn.Identity()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply forward pass."""
-        if hasattr(self, "lkb_reparam"):
+        if self.lkb_reparam is not None:
             out = self.lkb_reparam(x)
         else:
             out = self.lkb_origin(x)
-            if hasattr(self, "small_conv"):
-                out += self.small_conv(x)
-
-        self.act_layer(out)
+            if self.small_conv is not None:
+                out = out + self.small_conv(x)
+        out = self.act(out)
         return out
 
     def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -119,12 +373,11 @@ class ReparamLargeKernelConv(nn.Module):
         for inference.
         """
         eq_k, eq_b = self.get_kernel_bias()
-        self.lkb_reparam = nn.Conv2d(
-            in_chs=self.in_chs,
-            out_chs=self.out_chs,
+        self.lkb_reparam = create_conv2d(
+            self.in_chs,
+            self.out_chs,
             kernel_size=self.kernel_size,
             stride=self.stride,
-            padding=self.padding,
             dilation=self.lkb_origin.conv.dilation,
             groups=self.groups,
             bias=True,
@@ -159,35 +412,11 @@ class ReparamLargeKernelConv(nn.Module):
         t = (gamma / std).reshape(-1, 1, 1, 1)
         return kernel * t, beta - running_mean * gamma / std
 
-    def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential:
-        """Helper method to construct conv-batchnorm layers.
-
-        Args:
-            kernel_size: Size of the convolution kernel.
-            padding: Zero-padding size.
-
-        Returns:
-            A nn.Sequential Conv-BN module.
-        """
-        mod_list = nn.Sequential()
-        mod_list.add_module(
-            "conv",
-            nn.Conv2d(
-                in_chs=self.in_chs,
-                out_chs=self.out_chs,
-                kernel_size=kernel_size,
-                stride=self.stride,
-                padding=padding,
-                groups=self.groups,
-                bias=False,
-            ),
-        )
-        mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_chs))
-        return mod_list
-
 
 def convolutional_stem(
-    in_chs: int, out_chs: int, inference_mode: bool = False
+        in_chs: int,
+        out_chs: int,
+        inference_mode: bool = False,
 ) -> nn.Sequential:
     """Build convolutional stem with MobileOne blocks.
@@ -206,8 +435,6 @@ def convolutional_stem(
             kernel_size=3,
             stride=2,
             inference_mode=inference_mode,
-            use_se=False,
-            num_conv_branches=1,
         ),
         MobileOneBlock(
             in_chs=out_chs,
@@ -216,8 +443,6 @@ def convolutional_stem(
             stride=2,
             group_size=1,
             inference_mode=inference_mode,
-            use_se=False,
-            num_conv_branches=1,
         ),
         MobileOneBlock(
             in_chs=out_chs,
@@ -225,8 +450,6 @@ def convolutional_stem(
             kernel_size=1,
             stride=1,
             inference_mode=inference_mode,
-            use_se=False,
-            num_conv_branches=1,
         ),
     )
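For reference, the stem stacks a stride-2 regular conv block, a stride-2 depthwise block (`group_size=1`), and a 1x1 pointwise block, giving a 4x spatial reduction overall. A rough shape check, assuming the function as patched above:

    import torch
    from timm.models.fastvit import convolutional_stem

    stem = convolutional_stem(3, 64)
    out = stem(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 64, 56, 56]); two stride-2 blocks -> 4x downsample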
@@ -237,6 +460,7 @@ class Attention(nn.Module):
     Source modified from:
     https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
     """
+    fused_attn: torch.jit.Final[bool]
 
     def __init__(
         self,
@@ -259,7 +483,8 @@ class Attention(nn.Module):
         assert dim % head_dim == 0, "dim should be divisible by head_dim"
         self.head_dim = head_dim
         self.num_heads = dim // head_dim
-        self.scale = head_dim**-0.5
+        self.scale = head_dim ** -0.5
+        self.fused_attn = use_fused_attn()
 
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
@@ -267,11 +492,9 @@ class Attention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        shape = x.shape
-        B, C, H, W = shape
+        B, C, H, W = x.shape
         N = H * W
-        if len(shape) == 4:
-            x = torch.flatten(x, start_dim=2).transpose(-2, -1)  # (B, N, C)
+        x = x.flatten(2).transpose(-2, -1)  # (B, N, C)
         qkv = (
             self.qkv(x)
             .reshape(B, N, 3, self.num_heads, self.head_dim)
@@ -279,16 +502,22 @@ class Attention(nn.Module):
         )
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
 
-        # trick here to make q@k.t more stable
-        attn = (q * self.scale) @ k.transpose(-2, -1)
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
+        if self.fused_attn:
+            x = torch.nn.functional.scaled_dot_product_attention(
+                q, k, v,
+                dropout_p=self.attn_drop.p if self.training else 0.,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v
 
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = x.transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
-        if len(shape) == 4:
-            x = x.transpose(-2, -1).reshape(B, C, H, W)
+        x = x.transpose(-2, -1).reshape(B, C, H, W)
 
         return x
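The two branches of the new forward are numerically equivalent (up to float tolerance) when dropout is inactive; `scaled_dot_product_attention` applies the same 1/sqrt(head_dim) scaling by default. A standalone sketch:

    import torch
    import torch.nn.functional as F

    B, H, N, D = 2, 4, 49, 32
    q, k, v = (torch.randn(B, H, N, D) for _ in range(3))
    fused = F.scaled_dot_product_attention(q, k, v)  # fused kernel path
    manual = ((q * D ** -0.5) @ k.transpose(-2, -1)).softmax(dim=-1) @ v  # eager path
    print(torch.allclose(fused, manual, atol=1e-5))  # expect True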
@@ -314,32 +543,25 @@ class PatchEmbed(nn.Module):
             inference_mode: Flag to instantiate model in inference mode.
                 Default: ``False``
         """
         super().__init__()
-        block = list()
-        block.append(
+        self.proj = nn.Sequential(
             ReparamLargeKernelConv(
                 in_chs=in_chs,
                 out_chs=embed_dim,
                 kernel_size=patch_size,
                 stride=stride,
-                groups=in_chs,
+                group_size=1,
                 small_kernel=3,
                 inference_mode=inference_mode,
-            )
-        )
-        block.append(
+                act_layer=None,  # activation was not used in original impl
+            ),
             MobileOneBlock(
                 in_chs=embed_dim,
                 out_chs=embed_dim,
                 kernel_size=1,
                 stride=1,
-                padding=0,
-                groups=1,
                 inference_mode=inference_mode,
-                use_se=False,
-                num_conv_branches=1,
-            )
+            ),
         )
-        self.proj = nn.Sequential(*block)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.proj(x)
@@ -377,8 +599,8 @@ class RepMixer(nn.Module):
 
         if inference_mode:
             self.reparam_conv = nn.Conv2d(
-                in_chs=self.dim,
-                out_chs=self.dim,
+                self.dim,
+                self.dim,
                 kernel_size=self.kernel_size,
                 stride=1,
                 padding=self.kernel_size // 2,
@@ -386,6 +608,7 @@ class RepMixer(nn.Module):
                 bias=True,
             )
         else:
+            self.reparam_conv = None
             self.norm = MobileOneBlock(
                 dim,
                 dim,
@@ -404,12 +627,10 @@ class RepMixer(nn.Module):
             )
         self.use_layer_scale = use_layer_scale
         if use_layer_scale:
-            self.layer_scale = nn.Parameter(
-                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
-            )
+            self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if hasattr(self, "reparam_conv"):
+        if self.reparam_conv is not None:
             x = self.reparam_conv(x)
             return x
         else:
@@ -444,12 +665,11 @@ class RepMixer(nn.Module):
         )
 
         b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
 
-        self.reparam_conv = nn.Conv2d(
-            in_chs=self.dim,
-            out_chs=self.dim,
+        self.reparam_conv = create_conv2d(
+            self.dim,
+            self.dim,
             kernel_size=self.kernel_size,
             stride=1,
-            padding=self.kernel_size // 2,
             groups=self.dim,
             bias=True,
         )
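RepMixer computes `x + layer_scale * (mixer(x) - norm(x))` at train time; after `reparameterize()` the identity and both depthwise branches collapse into the single conv created above. The same eval-mode equivalence check applies, assuming the constructor signature `RepMixer(dim, kernel_size=3, ...)` used in this file:

    import torch
    from timm.models.fastvit import RepMixer

    mixer = RepMixer(64).eval()  # eval mode so BN fusion is exact
    x = torch.randn(1, 64, 28, 28)
    with torch.no_grad():
        y_branched = mixer(x)
        mixer.reparameterize()
        y_fused = mixer(x)
    print(torch.allclose(y_branched, y_fused, atol=1e-5))  # expect True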
@@ -487,19 +707,13 @@ class ConvMlp(nn.Module):
         super().__init__()
         out_chs = out_chs or in_chs
         hidden_channels = hidden_channels or in_chs
-        self.conv = nn.Sequential()
-        self.conv.add_module(
-            "conv",
-            nn.Conv2d(
-                in_chs=in_chs,
-                out_chs=out_chs,
-                kernel_size=7,
-                padding=3,
-                groups=in_chs,
-                bias=False,
-            ),
-        )
-        self.conv.add_module("bn", nn.BatchNorm2d(num_features=out_chs))
+        self.conv = ConvNormAct(
+            in_chs,
+            out_chs,
+            kernel_size=7,
+            groups=in_chs,
+            apply_act=False,
+        )
         self.fc1 = nn.Conv2d(in_chs, hidden_channels, kernel_size=1)
         self.act = act_layer()
         self.fc2 = nn.Conv2d(hidden_channels, out_chs, kernel_size=1)
@@ -565,27 +792,28 @@ class RepCPE(nn.Module):
 
         if inference_mode:
             self.reparam_conv = nn.Conv2d(
-                in_chs=self.in_chs,
-                out_chs=self.embed_dim,
+                self.in_chs,
+                self.embed_dim,
                 kernel_size=self.spatial_shape,
                 stride=1,
-                padding=int(self.spatial_shape[0] // 2),
+                padding=spatial_shape[0] // 2,
                 groups=self.embed_dim,
                 bias=True,
             )
         else:
+            self.reparam_conv = None
             self.pe = nn.Conv2d(
                 in_chs,
                 embed_dim,
                 spatial_shape,
                 1,
                 int(spatial_shape[0] // 2),
-                bias=True,
                 groups=embed_dim,
+                bias=True,
             )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if hasattr(self, "reparam_conv"):
+        if self.reparam_conv is not None:
             x = self.reparam_conv(x)
             return x
         else:
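RepCPE adds a depthwise-conv positional encoding to its input (`x + self.pe(x)`); the reparameterization in the next hunk folds the identity into the conv weight. Since no BN is involved, the fused result is exact up to float error. A hypothetical check, assuming the constructor `RepCPE(in_chs, embed_dim, spatial_shape)` as defined in this file:

    import torch
    from timm.models.fastvit import RepCPE

    cpe = RepCPE(64, embed_dim=64, spatial_shape=(7, 7))
    x = torch.randn(1, 64, 14, 14)
    with torch.no_grad():
        y_train = cpe(x)
        cpe.reparameterize()
        y_fused = cpe(x)
    print(torch.allclose(y_train, y_fused, atol=1e-6))  # expect True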
@@ -620,8 +848,8 @@ class RepCPE(nn.Module):
 
         # Introduce reparam conv
         self.reparam_conv = nn.Conv2d(
-            in_chs=self.in_chs,
-            out_chs=self.embed_dim,
+            self.in_chs,
+            self.embed_dim,
             kernel_size=self.spatial_shape,
             stride=1,
             padding=int(self.spatial_shape[0] // 2),
@@ -682,10 +910,9 @@ class RepMixerBlock(nn.Module):
         assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
             mlp_ratio
         )
-        mlp_hidden_dim = int(dim * mlp_ratio)
         self.convffn = ConvMlp(
             in_chs=dim,
-            hidden_channels=mlp_hidden_dim,
+            hidden_channels=int(dim * mlp_ratio),
             act_layer=act_layer,
             drop=drop,
         )
@@ -696,9 +923,7 @@ class RepMixerBlock(nn.Module):
         # Layer Scale
         self.use_layer_scale = use_layer_scale
         if use_layer_scale:
-            self.layer_scale = nn.Parameter(
-                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
-            )
+            self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
 
     def forward(self, x):
         if self.use_layer_scale:
@@ -763,12 +988,8 @@ class AttentionBlock(nn.Module):
         # Layer Scale
         self.use_layer_scale = use_layer_scale
         if use_layer_scale:
-            self.layer_scale_1 = nn.Parameter(
-                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
-            )
-            self.layer_scale_2 = nn.Parameter(
-                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
-            )
+            self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
+            self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim, 1, 1)))
 
     def forward(self, x):
         if self.use_layer_scale:
@@ -823,32 +1044,28 @@ def basic_blocks(
             / (sum(num_blocks) - 1)
         )
         if token_mixer_type == "repmixer":
-            blocks.append(
-                RepMixerBlock(
-                    dim,
-                    kernel_size=kernel_size,
-                    mlp_ratio=mlp_ratio,
-                    act_layer=act_layer,
-                    drop=drop_rate,
-                    drop_path=block_dpr,
-                    use_layer_scale=use_layer_scale,
-                    layer_scale_init_value=layer_scale_init_value,
-                    inference_mode=inference_mode,
-                )
-            )
+            blocks.append(RepMixerBlock(
+                dim,
+                kernel_size=kernel_size,
+                mlp_ratio=mlp_ratio,
+                act_layer=act_layer,
+                drop=drop_rate,
+                drop_path=block_dpr,
+                use_layer_scale=use_layer_scale,
+                layer_scale_init_value=layer_scale_init_value,
+                inference_mode=inference_mode,
+            ))
         elif token_mixer_type == "attention":
-            blocks.append(
-                AttentionBlock(
-                    dim,
-                    mlp_ratio=mlp_ratio,
-                    act_layer=act_layer,
-                    norm_layer=norm_layer,
-                    drop=drop_rate,
-                    drop_path=block_dpr,
-                    use_layer_scale=use_layer_scale,
-                    layer_scale_init_value=layer_scale_init_value,
-                )
-            )
+            blocks.append(AttentionBlock(
+                dim,
+                mlp_ratio=mlp_ratio,
+                act_layer=act_layer,
+                norm_layer=norm_layer,
+                drop=drop_rate,
+                drop_path=block_dpr,
+                use_layer_scale=use_layer_scale,
+                layer_scale_init_value=layer_scale_init_value,
+            ))
         else:
             raise ValueError(
                 "Token mixer type: {} not supported".format(token_mixer_type)
@@ -932,15 +1149,13 @@ class FastVit(nn.Module):
 
             # Patch merging/downsampling between stages.
             if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
-                network.append(
-                    PatchEmbed(
-                        patch_size=down_patch_size,
-                        stride=down_stride,
-                        in_chs=embed_dims[i],
-                        embed_dim=embed_dims[i + 1],
-                        inference_mode=inference_mode,
-                    )
-                )
+                network += [PatchEmbed(
+                    patch_size=down_patch_size,
+                    stride=down_stride,
+                    in_chs=embed_dims[i],
+                    embed_dim=embed_dims[i + 1],
+                    inference_mode=inference_mode,
+                )]
 
         self.network = nn.ModuleList(network)
 
@@ -1054,6 +1269,8 @@
 default_cfgs = {
     "fastvit_t": _cfg(crop_pct=0.9),
     "fastvit_s": _cfg(crop_pct=0.9),
     "fastvit_m": _cfg(crop_pct=0.95),
+    "fastvit_t8": _cfg(
+        url="https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t8.pth.tar"),
 }
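Finally, a hypothetical deployment helper (not part of this patch) showing how the pieces fit together: every re-parameterizable module above exposes `reparameterize()`, so a trained model can be fused in one pass before export:

    import torch.nn as nn

    def reparameterize_model(model: nn.Module) -> nn.Module:
        # Fuse all multi-branch blocks in place; call only on a trained model,
        # in eval mode, so BN fusion uses running statistics.
        model.eval()
        for module in model.modules():
            if hasattr(module, "reparameterize"):
                module.reparameterize()
        return model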