convnext zepto, rmsnorm experiments

Ross Wightman 2024-09-18 12:26:48 -07:00
parent e3242a5258
commit 5d7bd2973e
4 changed files with 88 additions and 9 deletions

View File

@@ -34,7 +34,7 @@ from .linear import Linear
 from .mixed_conv2d import MixedConv2d
 from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
-from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d
 from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
     SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
 from .padding import get_padding, get_same_padding, pad_same
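
With RmsNorm2d exported here, it can be imported alongside the existing norm layers. A quick sanity check (my example, not part of the commit):

import torch
from timm.layers import RmsNorm2d

norm = RmsNorm2d(64)                 # per-channel affine weight, eps defaults to 1e-6
y = norm(torch.randn(2, 64, 7, 7))   # NCHW in -> NCHW out, same shape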

View File

@@ -10,7 +10,7 @@ from typing import Type
 import torch.nn as nn
-from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d
 from torchvision.ops.misc import FrozenBatchNorm2d

 _NORM_MAP = dict(
@@ -22,6 +22,7 @@ _NORM_MAP = dict(
     layernorm=LayerNorm,
     layernorm2d=LayerNorm2d,
     rmsnorm=RmsNorm,
+    rmsnorm2d=RmsNorm2d,
     frozenbatchnorm2d=FrozenBatchNorm2d,
 )
 _NORM_TYPES = {m for n, m in _NORM_MAP.items()}
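
With the new map entry, the string 'rmsnorm2d' resolves through the existing norm-layer factory. A minimal sketch (my example; get_norm_layer is the resolver that the convnext change below imports):

from timm.layers import get_norm_layer

norm_cls = get_norm_layer('rmsnorm2d')   # looked up in _NORM_MAP -> RmsNorm2d
norm = norm_cls(96)                      # instantiate for a 96-channel feature map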

View File

@@ -152,3 +152,41 @@ class RmsNorm(nn.Module):
         # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
         x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
         return x
+
+
+class RmsNorm2d(nn.Module):
+    """ RmsNorm w/ fast (apex) norm if available
+    """
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    normalized_shape: Tuple[int, ...]
+    eps: float
+    elementwise_affine: bool
+
+    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        normalized_shape = channels
+        if isinstance(normalized_shape, numbers.Integral):
+            # mypy error: incompatible types in assignment
+            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
+        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
+        self.eps = eps
+        self.elementwise_affine = affine
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
+        else:
+            self.register_parameter('weight', None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        if self.elementwise_affine:
+            nn.init.ones_(self.weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x.permute(0, 2, 3, 1)
+        # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
+        # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
+        x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        x = x.permute(0, 3, 1, 2)
+        return x
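
For intuition, a standalone sketch of what the permute-and-normalize forward above computes on an NCHW tensor (my example; it assumes the standard RMSNorm formula x * rsqrt(mean(x^2) + eps) * w over the channel dim, the fast_rms_norm path may differ in minor details such as where eps is applied):

import torch

def rms_norm2d_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x: (N, C, H, W); every spatial position is scaled by the inverse RMS of its C channel values
    inv_rms = x.pow(2).mean(dim=1, keepdim=True).add(eps).rsqrt()
    return x * inv_rms * weight.view(1, -1, 1, 1)

x = torch.randn(2, 32, 8, 8)
y = rms_norm2d_ref(x, torch.ones(32))   # same shape as x; unit weight = pure normalization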

View File

@@ -45,7 +45,7 @@ import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
-    LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
+    LayerNorm2d, LayerNorm, RmsNorm2d, RmsNorm, create_conv2d, get_act_layer, get_norm_layer, make_divisible, to_ntuple
 from timm.layers import NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
@@ -289,24 +289,27 @@ class ConvNeXt(nn.Module):
         super().__init__()
         assert output_stride in (8, 16, 32)
         kernel_sizes = to_ntuple(4)(kernel_sizes)
-        if norm_layer is None:
-            norm_layer = LayerNorm2d
-            norm_layer_cl = norm_layer if conv_mlp else LayerNorm
+        use_rms = isinstance(norm_layer, str) and norm_layer.startswith('rmsnorm')
+        if norm_layer is None or use_rms:
+            norm_layer = RmsNorm2d if use_rms else LayerNorm2d
+            norm_layer_cl = norm_layer if conv_mlp else (RmsNorm if use_rms else LayerNorm)
             if norm_eps is not None:
                 norm_layer = partial(norm_layer, eps=norm_eps)
                 norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
         else:
             assert conv_mlp,\
                 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
+            norm_layer = get_norm_layer(norm_layer)
             norm_layer_cl = norm_layer
             if norm_eps is not None:
                 norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
+        act_layer = get_act_layer(act_layer)

         self.num_classes = num_classes
         self.drop_rate = drop_rate
         self.feature_info = []

-        assert stem_type in ('patch', 'overlap', 'overlap_tiered')
+        assert stem_type in ('patch', 'overlap', 'overlap_tiered', 'overlap_act')
         if stem_type == 'patch':
             # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
             self.stem = nn.Sequential(
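
In effect, passing norm_layer='rmsnorm2d' (or any 'rmsnorm*' string) puts both the channels-first and channels-last paths on RMS norms while keeping the usual split. A condensed sketch of the selection above, ignoring norm_eps and the explicit-norm_layer branch (my paraphrase, not the actual code):

from timm.layers import LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d

def pick_norms(norm_layer=None, conv_mlp=False):
    # assumes norm_layer is None or an 'rmsnorm*' string, as in the first branch above
    use_rms = isinstance(norm_layer, str) and norm_layer.startswith('rmsnorm')
    norm_cf = RmsNorm2d if use_rms else LayerNorm2d                          # NCHW stem/downsample norms
    norm_cl = norm_cf if conv_mlp else (RmsNorm if use_rms else LayerNorm)   # channels-last block norms
    return norm_cf, norm_cl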
@@ -316,11 +319,12 @@ class ConvNeXt(nn.Module):
             stem_stride = patch_size
         else:
             mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
-            self.stem = nn.Sequential(
+            self.stem = nn.Sequential(*filter(None, [
                 nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
+                act_layer() if 'act' in stem_type else None,
                 nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
                 norm_layer(dims[0]),
-            )
+            ]))
             stem_stride = 4

         self.stages = nn.Sequential()
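
The *filter(None, [...]) change lets the same stem code serve both 'overlap' and the new 'overlap_act' type by simply dropping the optional activation. A standalone sketch (my example; the helper name and the GELU/LayerNorm2d stand-ins for the resolved act/norm layers are hypothetical):

import torch.nn as nn
from timm.layers import LayerNorm2d

def overlap_stem(in_chans=3, mid_chs=16, out_chs=32, with_act=False, conv_bias=True):
    # two stride-2 3x3 convs (overall stem stride 4); activation only for 'overlap_act'
    act_layer = nn.GELU                      # stand-in for the model's resolved act_layer
    norm_layer = LayerNorm2d                 # stand-in; rms variants resolve to RmsNorm2d
    return nn.Sequential(*filter(None, [
        nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
        act_layer() if with_act else None,   # filter(None, ...) drops this entry when it is None
        nn.Conv2d(mid_chs, out_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
        norm_layer(out_chs),
    ]))

stem = overlap_stem(with_act=True)           # 'overlap_act'-style stem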
@@ -592,6 +596,14 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='timm/',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'convnext_zepto_rms.untrained': _cfg(
+        # hf_hub_id='timm/',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        test_input_size=(3, 256, 256), test_crop_pct=0.95),
+    'convnext_zepto_rms_ols.untrained': _cfg(
+        # hf_hub_id='timm/',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        test_input_size=(3, 256, 256), test_crop_pct=0.95),
     'convnext_atto.d2_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
         hf_hub_id='timm/',
@@ -600,6 +612,9 @@ default_cfgs = generate_default_cfgs({
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
         hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=0.95),
+    'convnext_atto_rms.untrained': _cfg(
+        # hf_hub_id='timm/',
+        test_input_size=(3, 256, 256), test_crop_pct=0.95),
     'convnext_femto.d1_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
         hf_hub_id='timm/',
@@ -968,6 +983,23 @@ default_cfgs = generate_default_cfgs({
 })


+@register_model
+def convnext_zepto_rms(pretrained=False, **kwargs) -> ConvNeXt:
+    # timm zepto variant (experimental rmsnorm model)
+    model_args = dict(depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='rmsnorm2d')
+    model = _create_convnext('convnext_zepto_rms', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def convnext_zepto_rms_ols(pretrained=False, **kwargs) -> ConvNeXt:
+    # timm zepto variant w/ overlapping act stem (experimental rmsnorm model)
+    model_args = dict(
+        depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='rmsnorm2d', stem_type='overlap_act')
+    model = _create_convnext('convnext_zepto_rms_ols', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def convnext_atto(pretrained=False, **kwargs) -> ConvNeXt:
     # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
@@ -984,6 +1016,14 @@ def convnext_atto_ols(pretrained=False, **kwargs) -> ConvNeXt:
     return model


+@register_model
+def convnext_atto_rms(pretrained=False, **kwargs) -> ConvNeXt:
+    # timm atto variant (experimental rmsnorm model)
+    model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, norm_layer='rmsnorm2d')
+    model = _create_convnext('convnext_atto_rms', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def convnext_femto(pretrained=False, **kwargs) -> ConvNeXt:
     # timm femto variant
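
The new variants are registered with .untrained default cfgs (no pretrained weights on the hub yet), so they can be built for experiments via the usual factory; for example (my snippet, not part of the commit):

import timm
import torch

model = timm.create_model('convnext_zepto_rms', pretrained=False)
logits = model(torch.randn(1, 3, 224, 224))
print(sum(p.numel() for p in model.parameters()), logits.shape)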