Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Use torch F.rms_norm when possible, select fast vs normal paths appropriately and test with torchscript
This commit is contained in:
parent e0cacbfd15
commit 5809c2fe5e
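Before the diff hunks, a standalone sketch (illustrative only, not timm code; the helper names here are made up) of the selection pattern the commit message describes: prefer torch's built-in F.rms_norm when the running PyTorch provides it, otherwise fall back to a hand-written RMS norm.

import torch
import torch.nn.functional as F

# feature detection, mirroring the has_torch_rms_norm flag added in the first hunk
has_torch_rms_norm = hasattr(F, 'rms_norm')

def manual_rms_norm(x, normalized_shape, weight=None, eps=1e-6):
    # canonical RMS norm: scale by the reciprocal root-mean-square over the trailing dims
    dims = tuple(range(-len(normalized_shape), 0))
    x = x * torch.rsqrt(x.pow(2).mean(dims, keepdim=True) + eps)
    return x if weight is None else x * weight

def select_rms_norm(x, normalized_shape, weight=None, eps=1e-6):
    # "use torch F.rms_norm when possible", else the manual fallback
    if has_torch_rms_norm:
        return F.rms_norm(x, normalized_shape, weight, eps)
    return manual_rms_norm(x, normalized_shape, weight, eps)

Example call: select_rms_norm(torch.randn(2, 196, 64), [64], torch.ones(64)).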
@@ -24,6 +24,8 @@ except ImportError:
     has_apex_rmsnorm = False

+has_torch_rms_norm = hasattr(F, 'rms_norm')
+
 # fast (ie lower precision LN) can be disabled with this flag if issues crop up
 _USE_FAST_NORM = False  # defaulting to False for now
@@ -75,7 +77,6 @@ def fast_group_norm(
     if is_autocast_enabled(x.device.type):
         # normally native AMP casts GN inputs to float32
         # here we use the low precision autocast dtype
         # FIXME what to do re CPU autocast?
         dt = get_autocast_dtype(x.device.type)
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
@@ -101,14 +102,12 @@ def fast_layer_norm(
         # normally native AMP casts LN inputs to float32
         # apex LN does not, this is behaving like Apex
         dt = get_autocast_dtype(x.device.type)
         # FIXME what to do re CPU autocast?
         x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None

     with torch.amp.autocast(device_type=x.device.type, enabled=False):
         return F.layer_norm(x, normalized_shape, weight, bias, eps)


 def rms_norm(
         x: torch.Tensor,
         normalized_shape: List[int],
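The comments in the hunk above describe the AMP handling: native autocast would run layer norm in float32, while the "fast" path mimics Apex by casting inputs to the autocast dtype and then disabling autocast around the functional call. A minimal sketch of that mechanic (illustrative only; CPU and bfloat16 are used so it runs anywhere):

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 64)
w = torch.ones(64)
b = torch.zeros(64)

with torch.amp.autocast(device_type='cpu', dtype=torch.bfloat16):
    # stand-in for get_autocast_dtype(x.device.type) in the code above
    dt = torch.bfloat16
    xl, wl, bl = x.to(dt), w.to(dt), b.to(dt)
    # with autocast disabled, the norm actually computes in the low-precision dtype
    with torch.amp.autocast(device_type='cpu', enabled=False):
        y = F.layer_norm(xl, [64], wl, bl, 1e-6)

print(y.dtype)  # torch.bfloat16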
@@ -148,8 +147,19 @@ def fast_rms_norm(
         else:
             return fused_rms_norm_affine(x, weight, normalized_shape, eps)

-    # fallback
-    return rms_norm(x, normalized_shape, weight, eps)
+    if is_autocast_enabled(x.device.type):
+        # normally native AMP casts LN inputs to float32
+        # apex LN does not, this is behaving like Apex
+        dt = get_autocast_dtype(x.device.type)
+        x, weight = x.to(dt), weight.to(dt)
+
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
+        if has_torch_rms_norm:
+            x = F.rms_norm(x, normalized_shape, weight, eps)
+        else:
+            x = rms_norm(x, normalized_shape, weight, eps)
+
+    return x


 def simple_norm(
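A quick sanity check for the new F.rms_norm branch above (illustrative only; assumes a PyTorch build that exposes F.rms_norm): the built-in op should match the usual RMS norm formula x * rsqrt(mean(x^2) + eps) * weight.

import torch
import torch.nn.functional as F

if hasattr(F, 'rms_norm'):
    x = torch.randn(2, 5, 64, dtype=torch.float64)
    w = torch.randn(64, dtype=torch.float64)
    eps = 1e-6
    manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w
    print(torch.allclose(F.rms_norm(x, [64], w, eps), manual))  # expected: True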
@@ -11,17 +11,24 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, fast_simple_norm
+from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, fast_simple_norm, simple_norm
+
+try:
+    from torch.nn.functional import rms_norm
+except ImportError:
+    from .fast_norm import rms_norm


 class GroupNorm(nn.GroupNorm):
+    _fast_norm: torch.jit.Final[bool]
+
     def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
         # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
         super().__init__(num_groups, num_channels, eps=eps, affine=affine)
-        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

     def forward(self, x):
-        if self.fast_norm:
+        if self._fast_norm:
             return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
         else:
             return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
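The "can't script unless we have these flags here (no globals)" comment above is the reason for the _fast_norm: torch.jit.Final[bool] annotations: TorchScript cannot branch on module-level globals, so the choice made at construction time is stored on the instance and declared Final. A toy sketch of the pattern (hypothetical module, not timm code):

import torch
import torch.nn as nn

class ToyNorm(nn.Module):  # hypothetical example, not part of the commit
    _fast: torch.jit.Final[bool]

    def __init__(self, fast: bool = False):
        super().__init__()
        self._fast = fast

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._fast:
            return x * 2.0  # stand-in for the "fast" path
        return x + 1.0      # stand-in for the reference path

scripted = torch.jit.script(ToyNorm(fast=True))
print(scripted(torch.ones(2)))  # tensor([2., 2.])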
@@ -31,13 +38,14 @@ class GroupNorm1(nn.GroupNorm):
     """ Group Normalization with 1 group.
     Input: tensor in shape [B, C, *]
     """
+    _fast_norm: torch.jit.Final[bool]

     def __init__(self, num_channels, **kwargs):
         super().__init__(1, num_channels, **kwargs)
-        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.fast_norm:
+        if self._fast_norm:
             return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
         else:
             return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
@@ -46,6 +54,8 @@ class GroupNorm1(nn.GroupNorm):
 class LayerNorm(nn.LayerNorm):
     """ LayerNorm w/ fast norm option
     """
+    _fast_norm: torch.jit.Final[bool]
+
     def __init__(self, num_channels, eps=1e-6, affine=True):
         super().__init__(num_channels, eps=eps, elementwise_affine=affine)
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
@@ -60,6 +70,8 @@ class LayerNorm(nn.LayerNorm):
 class LayerNorm2d(nn.LayerNorm):
     """ LayerNorm for channels of '2D' spatial NCHW tensors """
+    _fast_norm: torch.jit.Final[bool]
+
     def __init__(self, num_channels, eps=1e-6, affine=True):
         super().__init__(num_channels, eps=eps, elementwise_affine=affine)
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
@@ -121,10 +133,11 @@ class LayerNormExp2d(nn.LayerNorm):
 class RmsNorm(nn.Module):
     """ RmsNorm w/ fast (apex) norm if available
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
+    _fast_norm: bool

     def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
@@ -136,6 +149,8 @@ class RmsNorm(nn.Module):
         self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
         self.eps = eps
         self.elementwise_affine = affine
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
         else:
@@ -150,17 +165,21 @@
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
-        # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
-        x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        if self._fast_norm:
+            x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
         return x


 class RmsNorm2d(nn.Module):
     """ RmsNorm w/ fast (apex) norm if available
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
+    _fast_norm: bool

     def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
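With the forward now branching on the Final _fast_norm flag and the fallback rms_norm imported at module level, the layer should be scriptable, per the "test with torchscript" note in the commit message. A hedged usage sketch (the timm.layers export location for RmsNorm is assumed here):

import torch
from timm.layers import RmsNorm  # assumed export location

m = RmsNorm(64)
scripted = torch.jit.script(m)
x = torch.randn(2, 196, 64)
print(torch.allclose(m(x), scripted(x)))  # eager and scripted outputs should agree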
@@ -172,6 +191,8 @@ class RmsNorm2d(nn.Module):
         self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
         self.eps = eps
         self.elementwise_affine = affine
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
         else:
@@ -187,7 +208,10 @@ class RmsNorm2d(nn.Module):
         x = x.permute(0, 2, 3, 1)
-        # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
-        # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
-        x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        if self._fast_norm:
+            x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
         x = x.permute(0, 3, 1, 2)
         return x
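RmsNorm2d normalizes an NCHW tensor over the channel dim by permuting to NHWC, normalizing the last dim, and permuting back. An illustrative check (using the canonical RMS formula directly, not timm's helpers) that this is equivalent to reducing over dim=1 in place:

import torch

x = torch.randn(2, 64, 14, 14)              # NCHW
w = torch.ones(64)
eps = 1e-6

# channels-last route, as in the forward above
y1 = x.permute(0, 2, 3, 1)
y1 = y1 * torch.rsqrt(y1.pow(2).mean(dim=-1, keepdim=True) + eps) * w
y1 = y1.permute(0, 3, 1, 2)

# direct route over the channel dim
y2 = x * torch.rsqrt(x.pow(2).mean(dim=1, keepdim=True) + eps) * w.view(1, -1, 1, 1)

print(torch.allclose(y1, y2, atol=1e-6))  # expected: True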
@@ -195,10 +219,11 @@ class RmsNorm2d(nn.Module):
 class SimpleNorm(nn.Module):
     """ SimpleNorm (x / std(x))
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
+    _fast_norm: bool

     def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
@@ -210,6 +235,8 @@
         self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
         self.eps = eps
         self.elementwise_affine = affine
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
         else:
@@ -222,17 +249,21 @@
             nn.init.ones_(self.weight)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        if self._fast_norm:
+            x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
         return x


 class SimpleNorm2d(nn.Module):
     """ SimpleNorm for NCHW tensors
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
+    _fast_norm: bool

     def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
@@ -244,6 +275,8 @@ class SimpleNorm2d(nn.Module):
         self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
         self.eps = eps
         self.elementwise_affine = affine
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
         else:
@@ -257,6 +290,9 @@ class SimpleNorm2d(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 2, 3, 1)
-        x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        if self._fast_norm:
+            x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
         x = x.permute(0, 3, 1, 2)
         return x