From 5d7bd2973e63889f038c7da4b13be886bb06e4ab Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 18 Sep 2024 12:26:48 -0700
Subject: [PATCH] convnext zepto, rmsnorm experiments

---
 timm/layers/__init__.py    |  2 +-
 timm/layers/create_norm.py |  3 ++-
 timm/layers/norm.py        | 38 +++++++++++++++++++++++++++
 timm/models/convnext.py    | 54 +++++++++++++++++++++++++++++++++-----
 4 files changed, 88 insertions(+), 9 deletions(-)

diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py
index 49ffa0ce..f631e868 100644
--- a/timm/layers/__init__.py
+++ b/timm/layers/__init__.py
@@ -34,7 +34,7 @@ from .linear import Linear
 from .mixed_conv2d import MixedConv2d
 from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
-from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d
 from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
     SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
 from .padding import get_padding, get_same_padding, pad_same
diff --git a/timm/layers/create_norm.py b/timm/layers/create_norm.py
index fbf58985..74e893d8 100644
--- a/timm/layers/create_norm.py
+++ b/timm/layers/create_norm.py
@@ -10,7 +10,7 @@ from typing import Type
 
 import torch.nn as nn
 
-from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d
 from torchvision.ops.misc import FrozenBatchNorm2d
 
 _NORM_MAP = dict(
@@ -22,6 +22,7 @@ _NORM_MAP = dict(
     layernorm=LayerNorm,
     layernorm2d=LayerNorm2d,
     rmsnorm=RmsNorm,
+    rmsnorm2d=RmsNorm2d,
     frozenbatchnorm2d=FrozenBatchNorm2d,
 )
 _NORM_TYPES = {m for n, m in _NORM_MAP.items()}
diff --git a/timm/layers/norm.py b/timm/layers/norm.py
index 4b81dcef..e9f9c27d 100644
--- a/timm/layers/norm.py
+++ b/timm/layers/norm.py
@@ -152,3 +152,41 @@ class RmsNorm(nn.Module):
         # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
         x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
         return x
+
+
+class RmsNorm2d(nn.Module):
+    """ RmsNorm w/ fast (apex) norm if available
+    """
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    normalized_shape: Tuple[int, ...]
+    eps: float
+    elementwise_affine: bool
+
+    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        normalized_shape = channels
+        if isinstance(normalized_shape, numbers.Integral):
+            # mypy error: incompatible types in assignment
+            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
+        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
+        self.eps = eps
+        self.elementwise_affine = affine
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
+        else:
+            self.register_parameter('weight', None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        if self.elementwise_affine:
+            nn.init.ones_(self.weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x.permute(0, 2, 3, 1)
+        # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
+        # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
+        x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        x = x.permute(0, 3, 1, 2)
+        return x
diff --git a/timm/models/convnext.py b/timm/models/convnext.py
index 004d8e83..cddefb17 100644
--- a/timm/models/convnext.py
+++ b/timm/models/convnext.py
@@ -45,7 +45,7 @@ import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
-    LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
+    LayerNorm2d, LayerNorm, RmsNorm2d, RmsNorm, create_conv2d, get_act_layer, get_norm_layer, make_divisible, to_ntuple
 from timm.layers import NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
@@ -289,24 +289,27 @@ class ConvNeXt(nn.Module):
         super().__init__()
         assert output_stride in (8, 16, 32)
         kernel_sizes = to_ntuple(4)(kernel_sizes)
-        if norm_layer is None:
-            norm_layer = LayerNorm2d
-            norm_layer_cl = norm_layer if conv_mlp else LayerNorm
+        use_rms = isinstance(norm_layer, str) and norm_layer.startswith('rmsnorm')
+        if norm_layer is None or use_rms:
+            norm_layer = RmsNorm2d if use_rms else LayerNorm2d
+            norm_layer_cl = norm_layer if conv_mlp else (RmsNorm if use_rms else LayerNorm)
             if norm_eps is not None:
                 norm_layer = partial(norm_layer, eps=norm_eps)
                 norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
         else:
             assert conv_mlp,\
                 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
+            norm_layer = get_norm_layer(norm_layer)
             norm_layer_cl = norm_layer
             if norm_eps is not None:
                 norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
+        act_layer = get_act_layer(act_layer)
 
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         self.feature_info = []
 
-        assert stem_type in ('patch', 'overlap', 'overlap_tiered')
+        assert stem_type in ('patch', 'overlap', 'overlap_tiered', 'overlap_act')
         if stem_type == 'patch':
             # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
             self.stem = nn.Sequential(
@@ -316,11 +319,12 @@ class ConvNeXt(nn.Module):
             stem_stride = patch_size
         else:
             mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
-            self.stem = nn.Sequential(
+            self.stem = nn.Sequential(*filter(None, [
                 nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
+                act_layer() if 'act' in stem_type else None,
                 nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
                 norm_layer(dims[0]),
-            )
+            ]))
             stem_stride = 4
 
         self.stages = nn.Sequential()
@@ -592,6 +596,14 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='timm/',
         crop_pct=0.95, test_input_size=(3, 288, 288),
         test_crop_pct=1.0),
+    'convnext_zepto_rms.untrained': _cfg(
+        #hf_hub_id='timm/',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        test_input_size=(3, 256, 256), test_crop_pct=0.95),
+    'convnext_zepto_rms_ols.untrained': _cfg(
+        # hf_hub_id='timm/',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        test_input_size=(3, 256, 256), test_crop_pct=0.95),
     'convnext_atto.d2_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
         hf_hub_id='timm/',
@@ -600,6 +612,9 @@ default_cfgs = generate_default_cfgs({
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
         hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=0.95),
+    'convnext_atto_rms.untrained': _cfg(
+        #hf_hub_id='timm/',
+        test_input_size=(3, 256, 256), test_crop_pct=0.95),
     'convnext_femto.d1_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
         hf_hub_id='timm/',
@@ -968,6 +983,23 @@ default_cfgs = generate_default_cfgs({
 })
 
 
+@register_model
+def convnext_zepto_rms(pretrained=False, **kwargs) -> ConvNeXt:
+    # timm zepto variant (NOTE: rmsnorm experiment, still tweaking depths)
+    model_args = dict(depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='rmsnorm2d')
+    model = _create_convnext('convnext_zepto_rms', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def convnext_zepto_rms_ols(pretrained=False, **kwargs) -> ConvNeXt:
+    # timm zepto variant (NOTE: rmsnorm experiment, still tweaking depths)
+    model_args = dict(
+        depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='rmsnorm2d', stem_type='overlap_act')
+    model = _create_convnext('convnext_zepto_rms_ols', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def convnext_atto(pretrained=False, **kwargs) -> ConvNeXt:
     # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
@@ -984,6 +1016,14 @@ def convnext_atto_ols(pretrained=False, **kwargs) -> ConvNeXt:
     return model
 
 
+@register_model
+def convnext_atto_rms(pretrained=False, **kwargs) -> ConvNeXt:
+    # timm atto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M)
+    model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, norm_layer='rmsnorm2d')
+    model = _create_convnext('convnext_atto_rms', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def convnext_femto(pretrained=False, **kwargs) -> ConvNeXt:
     # timm femto variant
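
Reviewer note (not part of the patch): the new RmsNorm2d permutes NCHW input to NHWC so fast_rms_norm can normalize over the trailing channel dim, then permutes back. Below is a minimal pure-PyTorch sketch of the same computation, useful for eyeballing the math; it does not use the fused/APEX path, and the name rms_norm_2d_reference is illustrative only.

    import torch

    def rms_norm_2d_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Mean of squares over the channel dim (dim=1 for NCHW); equivalent to the
        # permute(0, 2, 3, 1) -> rms norm over last dim -> permute(0, 3, 1, 2) in RmsNorm2d.forward.
        v = x.pow(2).mean(dim=1, keepdim=True)
        x = x * torch.rsqrt(v + eps)
        # Per-channel affine scale (RmsNorm has no bias term).
        return x * weight.view(1, -1, 1, 1)

    x = torch.randn(2, 8, 14, 14)
    w = torch.ones(8)
    print(rms_norm_2d_reference(x, w).shape)  # torch.Size([2, 8, 14, 14])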
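Assuming the patch is applied to a local timm checkout, the new entrypoints can be smoke-tested like any other registered model; the configs added above are still marked .untrained, so pretrained=False is used here.

    import torch
    import timm

    for name in ('convnext_zepto_rms', 'convnext_zepto_rms_ols', 'convnext_atto_rms'):
        # create the registered model without pretrained weights and run a dummy forward pass
        model = timm.create_model(name, pretrained=False)
        n_params = sum(p.numel() for p in model.parameters())
        out = model(torch.randn(1, 3, 224, 224))
        print(f'{name}: {n_params / 1e6:.2f}M params, output {tuple(out.shape)}')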