Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Fix RmsNorm (fixes #2380), an issue also noticed with aimv2 when comparing outputs. Still some work to do: need to look at AMP / fast-mode behaviour and dispatch to torch when possible. Add SimpleNorm for 'LayerNorm w/o centering and bias'.
Parent: e752b5d07c
Commit: 04a484a895
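For context on the fix: the previous rms_norm normalized by the variance (torch.var, which subtracts the mean inside the statistic), whereas RMSNorm is defined over the raw mean of squares. The old behaviour is exactly what the new simple_norm / SimpleNorm below now provide explicitly. A minimal illustration with plain tensor ops (not the library code):

import torch

x = torch.randn(2, 8, 384)
eps = 1e-6

# Old rms_norm (and the new simple_norm): divide by the standard deviation; the mean
# is subtracted inside the statistic but the output is not re-centered.
old = x * torch.rsqrt(x.var(dim=-1, keepdim=True) + eps)

# Fixed rms_norm: divide by the root mean square, no mean subtraction anywhere.
new = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

# The two differ whenever the mean over the normalized dim is non-zero (and slightly
# even when it is zero, since torch.var applies Bessel's correction by default).
print((old - new).abs().max())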
timm/layers/__init__.py
@@ -34,7 +34,7 @@ from .linear import Linear
 from .mixed_conv2d import MixedConv2d
 from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
-from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d
 from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
     SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
 from .padding import get_padding, get_same_padding, pad_same
timm/layers/fast_norm.py
@@ -108,6 +108,7 @@ def fast_layer_norm(
     return F.layer_norm(x, normalized_shape, weight, bias, eps)
 
 
 def rms_norm(
     x: torch.Tensor,
     normalized_shape: List[int],
@@ -115,15 +116,16 @@ def rms_norm(
     eps: float = 1e-5,
 ):
     norm_ndim = len(normalized_shape)
+    v = x.pow(2)
     if torch.jit.is_scripting():
         # ndim = len(x.shape)
         # dims = list(range(ndim - norm_ndim, ndim))  # this doesn't work on pytorch <= 1.13.x
         # NOTE -ve dims cause torchscript to crash in some cases, out of options to work around
         assert norm_ndim == 1
-        v = torch.var(x, dim=-1).unsqueeze(-1)  # ts crashes with -ve dim + keepdim=True
+        v = torch.mean(v, dim=-1).unsqueeze(-1)  # ts crashes with -ve dim + keepdim=True
     else:
         dims = tuple(range(-1, -norm_ndim - 1, -1))
-        v = torch.var(x, dim=dims, keepdim=True)
+        v = torch.mean(v, dim=dims, keepdim=True)
     x = x * torch.rsqrt(v + eps)
     if weight is not None:
         x = x * weight
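After this change, rms_norm should agree with PyTorch's own F.rms_norm (available in newer releases, roughly 2.4+), which is presumably what the "dispatch to torch when possible" note in the commit message refers to. A quick sanity check, assuming a recent enough PyTorch:

import torch
import torch.nn.functional as F
from timm.layers.fast_norm import rms_norm

x = torch.randn(4, 197, 768)
w = torch.rand(768)

out = rms_norm(x, [768], weight=w, eps=1e-6)
ref = F.rms_norm(x, (768,), weight=w, eps=1e-6)  # needs a PyTorch build with F.rms_norm
print(torch.allclose(out, ref, atol=1e-5))  # True with the fix; False with the old var() version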
@@ -148,3 +150,47 @@ def fast_rms_norm(
 
     # fallback
     return rms_norm(x, normalized_shape, weight, eps)
+
+
+def simple_norm(
+    x: torch.Tensor,
+    normalized_shape: List[int],
+    weight: Optional[torch.Tensor] = None,
+    eps: float = 1e-5,
+):
+    norm_ndim = len(normalized_shape)
+    if torch.jit.is_scripting():
+        # ndim = len(x.shape)
+        # dims = list(range(ndim - norm_ndim, ndim))  # this doesn't work on pytorch <= 1.13.x
+        # NOTE -ve dims cause torchscript to crash in some cases, out of options to work around
+        assert norm_ndim == 1
+        v = torch.var(x, dim=-1).unsqueeze(-1)  # ts crashes with -ve dim + keepdim=True
+    else:
+        dims = tuple(range(-1, -norm_ndim - 1, -1))
+        v = torch.var(x, dim=dims, keepdim=True)
+    x = x * torch.rsqrt(v + eps)
+    if weight is not None:
+        x = x * weight
+    return x
+
+
+def fast_simple_norm(
+    x: torch.Tensor,
+    normalized_shape: List[int],
+    weight: Optional[torch.Tensor] = None,
+    eps: float = 1e-5,
+) -> torch.Tensor:
+    if torch.jit.is_scripting():
+        # this must be by itself, cannot merge with has_apex_rmsnorm
+        return simple_norm(x, normalized_shape, weight, eps)
+
+    if is_autocast_enabled(x.device.type):
+        # normally native AMP casts LN inputs to float32
+        # apex LN does not, this is behaving like Apex
+        dt = get_autocast_dtype(x.device.type)
+        x, weight = x.to(dt), weight.to(dt)
+
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
+        x = simple_norm(x, normalized_shape, weight, eps)
+    return x
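Worth noting for the "AMP / fast mode behaviour" follow-up mentioned in the commit message: under autocast, fast_simple_norm casts input and weight to the autocast dtype and then disables autocast, mirroring the Apex-style behaviour of the other fast norms here rather than native AMP's float32 LayerNorm. A rough sketch of the observable effect (assumes bfloat16 autocast is usable on the chosen device; names as in this diff):

import torch
from timm.layers.fast_norm import fast_simple_norm

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(8, 196, 384, device=device)
w = torch.ones(384, device=device)

with torch.autocast(device_type=device, dtype=torch.bfloat16):
    y = fast_simple_norm(x, [384], w, 1e-6)

print(y.dtype)  # bfloat16: the norm ran in the autocast dtype rather than float32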
timm/layers/norm.py
@@ -11,7 +11,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm
+from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, fast_simple_norm
 
 
 class GroupNorm(nn.GroupNorm):
@@ -190,3 +190,73 @@ class RmsNorm2d(nn.Module):
         x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
         x = x.permute(0, 3, 1, 2)
         return x
+
+
+class SimpleNorm(nn.Module):
+    """ SimpleNorm (x / std(x))
+    """
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    normalized_shape: Tuple[int, ...]
+    eps: float
+    elementwise_affine: bool
+
+    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        normalized_shape = channels
+        if isinstance(normalized_shape, numbers.Integral):
+            # mypy error: incompatible types in assignment
+            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
+        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
+        self.eps = eps
+        self.elementwise_affine = affine
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
+        else:
+            self.register_parameter('weight', None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        if self.elementwise_affine:
+            nn.init.ones_(self.weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        return x
+
+
+class SimpleNorm2d(nn.Module):
+    """ SimpleNorm for NCHW tensors
+    """
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    normalized_shape: Tuple[int, ...]
+    eps: float
+    elementwise_affine: bool
+
+    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        normalized_shape = channels
+        if isinstance(normalized_shape, numbers.Integral):
+            # mypy error: incompatible types in assignment
+            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
+        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
+        self.eps = eps
+        self.elementwise_affine = affine
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
+        else:
+            self.register_parameter('weight', None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        if self.elementwise_affine:
+            nn.init.ones_(self.weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x.permute(0, 2, 3, 1)
+        x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        x = x.permute(0, 3, 1, 2)
+        return x
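Usage of the two new modules follows the existing RmsNorm / RmsNorm2d split: SimpleNorm normalizes the last dimension (channels-last), while SimpleNorm2d permutes internally so it can be dropped into NCHW conv stages. A small sketch:

import torch
from timm.layers import SimpleNorm, SimpleNorm2d

norm_cl = SimpleNorm(384)           # channels-last, normalizes the last dim
norm_2d = SimpleNorm2d(384)         # NCHW, permutes to NHWC internally and back

x_cl = torch.randn(2, 196, 384)     # (B, N, C)
x_2d = torch.randn(2, 384, 14, 14)  # (B, C, H, W)

print(norm_cl(x_cl).shape)  # torch.Size([2, 196, 384])
print(norm_2d(x_2d).shape)  # torch.Size([2, 384, 14, 14])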
timm/models/convnext.py
@@ -46,6 +46,7 @@ import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
     LayerNorm2d, LayerNorm, RmsNorm2d, RmsNorm, create_conv2d, get_act_layer, get_norm_layer, make_divisible, to_ntuple
+from timm.layers import SimpleNorm2d, SimpleNorm
 from timm.layers import NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
@@ -233,6 +234,34 @@ class ConvNeXtStage(nn.Module):
         x = self.blocks(x)
         return x
 
 
+# map of norm layers with NCHW (2D) and channels last variants
+_NORM_MAP = {
+    'layernorm': (LayerNorm2d, LayerNorm),
+    'layernorm2d': (LayerNorm2d, LayerNorm),
+    'simplenorm': (SimpleNorm2d, SimpleNorm),
+    'simplenorm2d': (SimpleNorm2d, SimpleNorm),
+    'rmsnorm': (RmsNorm2d, RmsNorm),
+    'rmsnorm2d': (RmsNorm2d, RmsNorm),
+}
+
+
+def _get_norm_layers(norm_layer: Union[Callable, str], conv_mlp: bool, norm_eps: float):
+    norm_layer = norm_layer or 'layernorm'
+    if norm_layer in _NORM_MAP:
+        norm_layer_cl = _NORM_MAP[norm_layer][0] if conv_mlp else _NORM_MAP[norm_layer][1]
+        norm_layer = _NORM_MAP[norm_layer][0]
+        if norm_eps is not None:
+            norm_layer = partial(norm_layer, eps=norm_eps)
+            norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
+    else:
+        assert conv_mlp, \
+            'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
+        norm_layer = get_norm_layer(norm_layer)
+        norm_layer_cl = norm_layer
+        if norm_eps is not None:
+            norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
+    return norm_layer, norm_layer_cl
+
+
 class ConvNeXt(nn.Module):
     r""" ConvNeXt
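The new helper resolves a norm_layer string (or callable) into the NCHW and channels-last variants ConvNeXt needs, applying the eps override to both. An illustrative call (this is an internal helper of convnext.py, so importing it directly is only for demonstration):

from timm.models.convnext import _get_norm_layers

# conv_mlp=True -> both positions get the NCHW (2d) class;
# conv_mlp=False -> the channels-last class is used for the MLP blocks.
norm_layer, norm_layer_cl = _get_norm_layers('simplenorm', conv_mlp=True, norm_eps=1e-6)
print(norm_layer, norm_layer_cl)  # functools.partial(SimpleNorm2d, eps=1e-06) for both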
@@ -289,20 +318,7 @@ class ConvNeXt(nn.Module):
         super().__init__()
         assert output_stride in (8, 16, 32)
         kernel_sizes = to_ntuple(4)(kernel_sizes)
-        use_rms = isinstance(norm_layer, str) and norm_layer.startswith('rmsnorm')
-        if norm_layer is None or use_rms:
-            norm_layer = RmsNorm2d if use_rms else LayerNorm2d
-            norm_layer_cl = norm_layer if conv_mlp else (RmsNorm if use_rms else LayerNorm)
-            if norm_eps is not None:
-                norm_layer = partial(norm_layer, eps=norm_eps)
-                norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
-        else:
-            assert conv_mlp,\
-                'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
-            norm_layer = get_norm_layer(norm_layer)
-            norm_layer_cl = norm_layer
-            if norm_eps is not None:
-                norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
+        norm_layer, norm_layer_cl = _get_norm_layers(norm_layer, conv_mlp, norm_eps)
         act_layer = get_act_layer(act_layer)
 
         self.num_classes = num_classes
@@ -975,7 +991,7 @@ default_cfgs = generate_default_cfgs({
 @register_model
 def convnext_zepto_rms(pretrained=False, **kwargs) -> ConvNeXt:
     # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
-    model_args = dict(depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='rmsnorm2d')
+    model_args = dict(depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='simplenorm')
     model = _create_convnext('convnext_zepto_rms', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
@@ -984,7 +1000,7 @@ def convnext_zepto_rms(pretrained=False, **kwargs) -> ConvNeXt:
 def convnext_zepto_rms_ols(pretrained=False, **kwargs) -> ConvNeXt:
     # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
     model_args = dict(
-        depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='rmsnorm2d', stem_type='overlap_act')
+        depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='simplenorm', stem_type='overlap_act')
     model = _create_convnext('convnext_zepto_rms_ols', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
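The two zepto variants now default to 'simplenorm', which computes the same statistic the old variance-based RmsNorm did, so existing convnext_zepto_rms weights presumably keep their behaviour while RmsNorm itself is corrected. Creating one as usual:

import timm
import torch

model = timm.create_model('convnext_zepto_rms', pretrained=False)
x = torch.randn(1, 3, 224, 224)
print(model(x).shape)  # torch.Size([1, 1000])

# the norm can still be overridden per model, e.g. back to the (now fixed) RMSNorm:
model_rms = timm.create_model('convnext_zepto_rms', pretrained=False, norm_layer='rmsnorm2d')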