mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
convnext zepto, rmsnorm experiments
This commit is contained in:
parent
e3242a5258
commit
5d7bd2973e
@ -34,7 +34,7 @@ from .linear import Linear
|
|||||||
from .mixed_conv2d import MixedConv2d
|
from .mixed_conv2d import MixedConv2d
|
||||||
from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
|
from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
|
||||||
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
|
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
|
||||||
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
|
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d
|
||||||
from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
|
from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
|
||||||
SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
|
SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
|
||||||
from .padding import get_padding, get_same_padding, pad_same
|
from .padding import get_padding, get_same_padding, pad_same
|
||||||
|
@ -10,7 +10,7 @@ from typing import Type
|
|||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
|
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d
|
||||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||||
|
|
||||||
_NORM_MAP = dict(
|
_NORM_MAP = dict(
|
||||||
@ -22,6 +22,7 @@ _NORM_MAP = dict(
|
|||||||
layernorm=LayerNorm,
|
layernorm=LayerNorm,
|
||||||
layernorm2d=LayerNorm2d,
|
layernorm2d=LayerNorm2d,
|
||||||
rmsnorm=RmsNorm,
|
rmsnorm=RmsNorm,
|
||||||
|
rmsnorm2d=RmsNorm2d,
|
||||||
frozenbatchnorm2d=FrozenBatchNorm2d,
|
frozenbatchnorm2d=FrozenBatchNorm2d,
|
||||||
)
|
)
|
||||||
_NORM_TYPES = {m for n, m in _NORM_MAP.items()}
|
_NORM_TYPES = {m for n, m in _NORM_MAP.items()}
|
||||||
|
@ -152,3 +152,41 @@ class RmsNorm(nn.Module):
|
|||||||
# Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
|
# Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
|
||||||
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
|
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class RmsNorm2d(nn.Module):
|
||||||
|
""" RmsNorm w/ fast (apex) norm if available
|
||||||
|
"""
|
||||||
|
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
|
||||||
|
normalized_shape: Tuple[int, ...]
|
||||||
|
eps: float
|
||||||
|
elementwise_affine: bool
|
||||||
|
|
||||||
|
def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
|
||||||
|
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||||
|
super().__init__()
|
||||||
|
normalized_shape = channels
|
||||||
|
if isinstance(normalized_shape, numbers.Integral):
|
||||||
|
# mypy error: incompatible types in assignment
|
||||||
|
normalized_shape = (normalized_shape,) # type: ignore[assignment]
|
||||||
|
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
|
||||||
|
self.eps = eps
|
||||||
|
self.elementwise_affine = affine
|
||||||
|
if self.elementwise_affine:
|
||||||
|
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
|
||||||
|
else:
|
||||||
|
self.register_parameter('weight', None)
|
||||||
|
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self) -> None:
|
||||||
|
if self.elementwise_affine:
|
||||||
|
nn.init.ones_(self.weight)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = x.permute(0, 2, 3, 1)
|
||||||
|
# NOTE fast norm fallback needs our rms norm impl, so both paths through here.
|
||||||
|
# Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
|
||||||
|
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
|
||||||
|
x = x.permute(0, 3, 1, 2)
|
||||||
|
return x
|
||||||
|
@ -45,7 +45,7 @@ import torch.nn as nn
|
|||||||
|
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||||
from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
|
from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
|
||||||
LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
|
LayerNorm2d, LayerNorm, RmsNorm2d, RmsNorm, create_conv2d, get_act_layer, get_norm_layer, make_divisible, to_ntuple
|
||||||
from timm.layers import NormMlpClassifierHead, ClassifierHead
|
from timm.layers import NormMlpClassifierHead, ClassifierHead
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
from ._features import feature_take_indices
|
from ._features import feature_take_indices
|
||||||
@ -289,24 +289,27 @@ class ConvNeXt(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
assert output_stride in (8, 16, 32)
|
assert output_stride in (8, 16, 32)
|
||||||
kernel_sizes = to_ntuple(4)(kernel_sizes)
|
kernel_sizes = to_ntuple(4)(kernel_sizes)
|
||||||
if norm_layer is None:
|
use_rms = isinstance(norm_layer, str) and norm_layer.startswith('rmsnorm')
|
||||||
norm_layer = LayerNorm2d
|
if norm_layer is None or use_rms:
|
||||||
norm_layer_cl = norm_layer if conv_mlp else LayerNorm
|
norm_layer = RmsNorm2d if use_rms else LayerNorm2d
|
||||||
|
norm_layer_cl = norm_layer if conv_mlp else (RmsNorm if use_rms else LayerNorm)
|
||||||
if norm_eps is not None:
|
if norm_eps is not None:
|
||||||
norm_layer = partial(norm_layer, eps=norm_eps)
|
norm_layer = partial(norm_layer, eps=norm_eps)
|
||||||
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
|
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
|
||||||
else:
|
else:
|
||||||
assert conv_mlp,\
|
assert conv_mlp,\
|
||||||
'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
|
'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
|
||||||
|
norm_layer = get_norm_layer(norm_layer)
|
||||||
norm_layer_cl = norm_layer
|
norm_layer_cl = norm_layer
|
||||||
if norm_eps is not None:
|
if norm_eps is not None:
|
||||||
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
|
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
|
||||||
|
act_layer = get_act_layer(act_layer)
|
||||||
|
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.drop_rate = drop_rate
|
self.drop_rate = drop_rate
|
||||||
self.feature_info = []
|
self.feature_info = []
|
||||||
|
|
||||||
assert stem_type in ('patch', 'overlap', 'overlap_tiered')
|
assert stem_type in ('patch', 'overlap', 'overlap_tiered', 'overlap_act')
|
||||||
if stem_type == 'patch':
|
if stem_type == 'patch':
|
||||||
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
|
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
|
||||||
self.stem = nn.Sequential(
|
self.stem = nn.Sequential(
|
||||||
@ -316,11 +319,12 @@ class ConvNeXt(nn.Module):
|
|||||||
stem_stride = patch_size
|
stem_stride = patch_size
|
||||||
else:
|
else:
|
||||||
mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
|
mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
|
||||||
self.stem = nn.Sequential(
|
self.stem = nn.Sequential(*filter(None, [
|
||||||
nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
|
nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
|
||||||
|
act_layer() if 'act' in stem_type else None,
|
||||||
nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
|
nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
|
||||||
norm_layer(dims[0]),
|
norm_layer(dims[0]),
|
||||||
)
|
]))
|
||||||
stem_stride = 4
|
stem_stride = 4
|
||||||
|
|
||||||
self.stages = nn.Sequential()
|
self.stages = nn.Sequential()
|
||||||
@ -592,6 +596,14 @@ default_cfgs = generate_default_cfgs({
|
|||||||
hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||||
|
|
||||||
|
'convnext_zepto_rms.untrained': _cfg(
|
||||||
|
#hf_hub_id='timm/',
|
||||||
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||||
|
test_input_size=(3, 256, 256), test_crop_pct=0.95),
|
||||||
|
'convnext_zepto_rms_ols.untrained': _cfg(
|
||||||
|
# hf_hub_id='timm/',
|
||||||
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||||
|
test_input_size=(3, 256, 256), test_crop_pct=0.95),
|
||||||
'convnext_atto.d2_in1k': _cfg(
|
'convnext_atto.d2_in1k': _cfg(
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
|
||||||
hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
@ -600,6 +612,9 @@ default_cfgs = generate_default_cfgs({
|
|||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
|
||||||
hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
test_input_size=(3, 288, 288), test_crop_pct=0.95),
|
test_input_size=(3, 288, 288), test_crop_pct=0.95),
|
||||||
|
'convnext_atto_rms.untrained': _cfg(
|
||||||
|
#hf_hub_id='timm/',
|
||||||
|
test_input_size=(3, 256, 256), test_crop_pct=0.95),
|
||||||
'convnext_femto.d1_in1k': _cfg(
|
'convnext_femto.d1_in1k': _cfg(
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
|
||||||
hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
@ -968,6 +983,23 @@ default_cfgs = generate_default_cfgs({
|
|||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def convnext_zepto_rms(pretrained=False, **kwargs) -> ConvNeXt:
|
||||||
|
# timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
|
||||||
|
model_args = dict(depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='rmsnorm2d')
|
||||||
|
model = _create_convnext('convnext_zepto_rms', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def convnext_zepto_rms_ols(pretrained=False, **kwargs) -> ConvNeXt:
|
||||||
|
# timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
|
||||||
|
model_args = dict(
|
||||||
|
depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='rmsnorm2d', stem_type='overlap_act')
|
||||||
|
model = _create_convnext('convnext_zepto_rms_oas', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def convnext_atto(pretrained=False, **kwargs) -> ConvNeXt:
|
def convnext_atto(pretrained=False, **kwargs) -> ConvNeXt:
|
||||||
# timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
|
# timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
|
||||||
@ -984,6 +1016,14 @@ def convnext_atto_ols(pretrained=False, **kwargs) -> ConvNeXt:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def convnext_atto_rms(pretrained=False, **kwargs) -> ConvNeXt:
|
||||||
|
# timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
|
||||||
|
model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, norm_layer='rmsnorm2d')
|
||||||
|
model = _create_convnext('convnext_atto_rms', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def convnext_femto(pretrained=False, **kwargs) -> ConvNeXt:
|
def convnext_femto(pretrained=False, **kwargs) -> ConvNeXt:
|
||||||
# timm femto variant
|
# timm femto variant
|
||||||
|
Loading…
x
Reference in New Issue
Block a user