mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Add 'fast' layer norm that doesn't cast to float32, support APEX LN impl for slight speed gain, update norm and act factories, tweak SE for ability to disable bias (needed by GCVit)
parent c486aa71f8
commit 43aa84e861
timm/models/layers/__init__.py

@@ -11,11 +11,13 @@ from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
 from .create_act import create_act_layer, get_act_layer, get_act_fn
 from .create_attn import get_attn, create_attn
 from .create_conv2d import create_conv2d
+from .create_norm import get_norm_layer, create_norm_layer
 from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
 from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
 from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
 from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
     EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
+from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
 from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
 from .gather_excite import GatherExcite
 from .global_context import GlobalContext
@@ -25,7 +27,7 @@ from .linear import Linear
 from .mixed_conv2d import MixedConv2d
 from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
-from .norm import GroupNorm, GroupNorm1, LayerNorm2d
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
 from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_embed import PatchEmbed
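In short, after these __init__ changes the new norm factory, the fast-norm toggles, and the plain LayerNorm wrapper become importable from the layers package; a quick illustrative check (assuming the package paths above):

from timm.models.layers import get_norm_layer, create_norm_layer    # new norm factory
from timm.models.layers import is_fast_norm, set_fast_norm          # fast-norm global toggle
from timm.models.layers import LayerNorm, LayerNorm2d               # LayerNorm now exported alongside LayerNorm2d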
timm/models/layers/create_act.py

@@ -145,4 +145,10 @@ def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs):
     act_layer = get_act_layer(name)
     if act_layer is None:
         return None
-    return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs)
+    if inplace is None:
+        return act_layer(**kwargs)
+    try:
+        return act_layer(inplace=inplace, **kwargs)
+    except TypeError:
+        # recover if act layer doesn't have inplace arg
+        return act_layer(**kwargs)
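The new fallback means requesting an inplace activation no longer fails when the resolved layer has no inplace argument; a minimal sketch of the behavior (nn.GELU is used here only because it lacks an inplace parameter):

import torch.nn as nn
from timm.models.layers import create_act_layer

act = create_act_layer(nn.GELU, inplace=True)   # TypeError is caught, falls back to nn.GELU()
relu = create_act_layer('relu', inplace=True)   # nn.ReLU supports inplace and receives it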
timm/models/layers/create_norm.py (new file, 56 lines)

@@ -0,0 +1,56 @@
""" Norm Layer Factory

Create norm modules by string (to mirror create_act and create_norm_act fns)

Copyright 2022 Ross Wightman
"""
import types
import functools

import torch.nn as nn

from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d

_NORM_MAP = dict(
    batchnorm=nn.BatchNorm2d,
    batchnorm2d=nn.BatchNorm2d,
    batchnorm1d=nn.BatchNorm1d,
    groupnorm=GroupNorm,
    groupnorm1=GroupNorm1,
    layernorm=LayerNorm,
    layernorm2d=LayerNorm2d,
)
_NORM_TYPES = {m for n, m in _NORM_MAP.items()}


def create_norm_layer(layer_name, num_features, **kwargs):
    layer = get_norm_layer(layer_name)
    layer_instance = layer(num_features, **kwargs)
    return layer_instance


def get_norm_layer(norm_layer):
    assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
    norm_kwargs = {}

    # unbind partial fn, so args can be rebound later
    if isinstance(norm_layer, functools.partial):
        norm_kwargs.update(norm_layer.keywords)
        norm_layer = norm_layer.func

    if isinstance(norm_layer, str):
        layer_name = norm_layer.replace('_', '')
        norm_layer = _NORM_MAP.get(layer_name, None)
    elif norm_layer in _NORM_TYPES:
        norm_layer = norm_layer
    elif isinstance(norm_layer, types.FunctionType):
        # if function type, assume it is a lambda/fn that creates a norm layer
        norm_layer = norm_layer
    else:
        type_name = norm_layer.__name__.lower().replace('_', '')
        norm_layer = _NORM_MAP.get(type_name, None)
        assert norm_layer is not None, f"No equivalent norm layer for {type_name}"

    if norm_kwargs:
        norm_layer = functools.partial(norm_layer, **norm_kwargs)  # bind/rebind args
    return norm_layer
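A brief usage sketch of the factory (illustrative; assumes the exports added to __init__ above):

import functools
import torch.nn as nn
from timm.models.layers import get_norm_layer, create_norm_layer

norm_cls = get_norm_layer('layernorm2d')                            # string lookup -> LayerNorm2d
norm_cls = get_norm_layer(nn.BatchNorm2d)                           # types already in the map pass through
bound = get_norm_layer(functools.partial(nn.GroupNorm, eps=1e-4))   # maps to timm GroupNorm, partial kwargs rebound
ln = create_norm_layer('layernorm', 192)                            # instantiate with num_features=192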
timm/models/layers/fast_norm.py (new file, 68 lines)

@@ -0,0 +1,68 @@
from typing import List, Optional

import torch
from torch.nn import functional as F

try:
    from apex.normalization.fused_layer_norm import fused_layer_norm_affine
    has_apex = True
except ImportError:
    has_apex = False


# fast (ie lower precision LN) can be disabled with this flag if issues crop up
_USE_FAST_NORM = False  # defaulting to False for now


def is_fast_norm():
    return _USE_FAST_NORM


def set_fast_norm(enable=True):
    global _USE_FAST_NORM
    _USE_FAST_NORM = enable


def fast_group_norm(
    x: torch.Tensor,
    num_groups: int,
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: float = 1e-5
) -> torch.Tensor:
    if torch.jit.is_scripting():
        # currently cannot use is_autocast_enabled within torchscript
        return F.group_norm(x, num_groups, weight, bias, eps)

    if torch.is_autocast_enabled():
        # normally native AMP casts GN inputs to float32
        # here we use the low precision autocast dtype
        dt = torch.get_autocast_gpu_dtype()
        x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)

    with torch.cuda.amp.autocast(enabled=False):
        return F.group_norm(x, num_groups, weight, bias, eps)


def fast_layer_norm(
    x: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: float = 1e-5
) -> torch.Tensor:
    if torch.jit.is_scripting():
        # currently cannot use is_autocast_enabled within torchscript
        return F.layer_norm(x, normalized_shape, weight, bias, eps)

    if has_apex:
        return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)

    if torch.is_autocast_enabled():
        # normally native AMP casts LN inputs to float32
        # apex LN does not, this is behaving like Apex
        dt = torch.get_autocast_gpu_dtype()
        x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)

    with torch.cuda.amp.autocast(enabled=False):
        return F.layer_norm(x, normalized_shape, weight, bias, eps)
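A rough sketch of what the fast path changes under AMP (assumes a CUDA device; without Apex installed the fast call keeps LN in the low-precision autocast dtype instead of upcasting to float32):

import torch
import torch.nn.functional as F
from timm.models.layers import fast_layer_norm

x = torch.randn(8, 197, 384, device='cuda')
w = torch.ones(384, device='cuda')
b = torch.zeros(384, device='cuda')
with torch.cuda.amp.autocast():
    y_ref = F.layer_norm(x, (384,), w, b, 1e-6)      # native AMP runs LN in float32
    y_fast = fast_layer_norm(x, [384], w, b, 1e-6)   # stays in the autocast dtype (float16 by default)
print(y_ref.dtype, y_fast.dtype)  # torch.float32 torch.float16 when Apex is absent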
timm/models/layers/norm.py

@@ -1,17 +1,24 @@
 """ Normalization layers and wrappers
 """
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
+
 
 class GroupNorm(nn.GroupNorm):
     def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
         # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
         super().__init__(num_groups, num_channels, eps=eps, affine=affine)
+        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
 
     def forward(self, x):
-        return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        if self.fast_norm:
+            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        else:
+            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
 
 
 class GroupNorm1(nn.GroupNorm):
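Note the flag is captured at construction time (so forward can be scripted without touching a global), which means set_fast_norm should be called before modules are built; a small illustration:

from timm.models.layers import GroupNorm, set_fast_norm

gn_slow = GroupNorm(64)    # built with the default flag (fast norm off)
set_fast_norm(True)
gn_fast = GroupNorm(64)    # picks up the enabled flag
set_fast_norm(False)       # later toggles do not affect modules that are already built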
@@ -21,22 +28,48 @@ class GroupNorm1(nn.GroupNorm):
 
     def __init__(self, num_channels, **kwargs):
         super().__init__(1, num_channels, **kwargs)
+        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
 
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.fast_norm:
+            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        else:
+            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+
+
+class LayerNorm(nn.LayerNorm):
+    """ LayerNorm w/ fast norm option
+    """
+    def __init__(self, num_channels, eps=1e-6, affine=True):
+        super().__init__(num_channels, eps=eps, elementwise_affine=affine)
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self._fast_norm:
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        else:
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        return x
 
 
 class LayerNorm2d(nn.LayerNorm):
     """ LayerNorm for channels of '2D' spatial NCHW tensors """
     def __init__(self, num_channels, eps=1e-6, affine=True):
         super().__init__(num_channels, eps=eps, elementwise_affine=affine)
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return F.layer_norm(
-            x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
+        x = x.permute(0, 2, 3, 1)
+        if self._fast_norm:
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        else:
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        x = x.permute(0, 3, 1, 2)
+        return x
 
 
 def _is_contiguous(tensor: torch.Tensor) -> bool:
     # jit is oh so lovely :/
     # if torch.jit.is_tracing():
     #     return True
     if torch.jit.is_scripting():
         return tensor.is_contiguous()
     else:
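For context, LayerNorm2d still takes NCHW input and handles the NHWC permute internally; only the inner norm call changes with the flag. A quick shape check (illustrative):

import torch
from timm.models.layers import LayerNorm2d

ln2d = LayerNorm2d(96)
y = ln2d(torch.randn(2, 96, 14, 14))   # normalized over channels, output stays (2, 96, 14, 14)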
@@ -51,6 +84,14 @@ def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
     return x
 
 
+def _layer_norm_cf_sqm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
+    u = x.mean(dim=1, keepdim=True)
+    s = ((x * x).mean(dim=1, keepdim=True) - (u * u)).clamp(0)
+    x = (x - u) * torch.rsqrt(s + eps)
+    x = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)
+    return x
+
+
 class LayerNormExp2d(nn.LayerNorm):
     """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
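The added helper computes channels-first LayerNorm using the E[x^2] - E[x]^2 ("square of mean") form of the variance, clamped at zero. A small numerical check against the permute + F.layer_norm reference (illustrative; imports the private helper directly):

import torch
import torch.nn.functional as F
from timm.models.layers.norm import _layer_norm_cf_sqm

x = torch.randn(2, 8, 4, 4)
w, b = torch.ones(8), torch.zeros(8)
ref = F.layer_norm(x.permute(0, 2, 3, 1), (8,), w, b, 1e-5).permute(0, 3, 1, 2)
out = _layer_norm_cf_sqm(x, w, b, 1e-5)
print(torch.allclose(out, ref, atol=1e-4))  # True, up to float32 rounding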
timm/models/layers/norm_act.py

@@ -6,8 +6,9 @@ import torch
 from torch import nn as nn
 from torch.nn import functional as F
 
-from .trace_utils import _assert
 from .create_act import get_act_layer
+from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
+from .trace_utils import _assert
 
 
 class BatchNormAct2d(nn.BatchNorm2d):
@@ -177,9 +178,13 @@ class GroupNormAct(nn.GroupNorm):
             self.act = act_layer(**act_args)
         else:
             self.act = nn.Identity()
+        self._fast_norm = is_fast_norm()
 
     def forward(self, x):
-        x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        if self._fast_norm:
+            x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        else:
+            x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
         x = self.drop(x)
         x = self.act(x)
         return x
@@ -197,9 +202,13 @@ class LayerNormAct(nn.LayerNorm):
             self.act = act_layer(**act_args)
         else:
             self.act = nn.Identity()
+        self._fast_norm = is_fast_norm()
 
     def forward(self, x):
-        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        if self._fast_norm:
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        else:
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         x = self.drop(x)
         x = self.act(x)
         return x
@@ -219,8 +228,13 @@ class LayerNormAct2d(nn.LayerNorm):
             self.act = nn.Identity()
+        self._fast_norm = is_fast_norm()
 
     def forward(self, x):
-        x = F.layer_norm(
-            x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
+        x = x.permute(0, 2, 3, 1)
+        if self._fast_norm:
+            x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        else:
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        x = x.permute(0, 3, 1, 2)
         x = self.drop(x)
         x = self.act(x)
         return x
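The act-fused variants now route through the same fast path when the flag is enabled at construction. A hedged usage sketch (GroupNormAct is exported from timm.models.layers; apply_act=False is assumed to swap the activation for nn.Identity, as in the code above):

import torch
from timm.models.layers import GroupNormAct

gna = GroupNormAct(64)                       # GroupNorm (32 groups by default) followed by ReLU
gn_only = GroupNormAct(64, apply_act=False)  # norm only, act becomes nn.Identity()
y = gna(torch.randn(2, 64, 8, 8))            # forward: norm -> drop -> act, shape preserved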
timm/models/layers/squeeze_excite.py

@@ -27,15 +27,15 @@ class SEModule(nn.Module):
     """
     def __init__(
             self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False,
-            act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'):
+            bias=True, act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'):
         super(SEModule, self).__init__()
         self.add_maxpool = add_maxpool
         if not rd_channels:
             rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
-        self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True)
+        self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=bias)
         self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity()
         self.act = create_act_layer(act_layer, inplace=True)
-        self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True)
+        self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=bias)
         self.gate = create_act_layer(gate_layer)
 
     def forward(self, x):
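This is the bias toggle the commit message mentions for GCVit-style usage: the squeeze/excite 1x1 convs can drop their bias when a norm layer between them supplies the affine terms. A brief illustrative construction (not taken from the GCVit code itself):

import torch
import torch.nn as nn
from timm.models.layers import SEModule

se = SEModule(256, rd_ratio=1. / 16, bias=False, norm_layer=nn.BatchNorm2d)
y = se(torch.randn(2, 256, 14, 14))   # gated output, shape (2, 256, 14, 14) preserved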