Use torch F.rms_norm when possible, select fast vs normal paths appropriately, and test with torchscript

Ross Wightman 2024-12-29 14:05:07 -08:00
parent e0cacbfd15
commit 5809c2fe5e
2 changed files with 64 additions and 18 deletions
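
For context, the selection order the commit message describes is: use the fused Apex kernel when the fast path is enabled and Apex is installed, otherwise use the built-in F.rms_norm where the installed PyTorch provides it, and fall back to the module's plain Python rms_norm otherwise. The fallback's body is not part of this diff; the sketch below is a hedged reference of the standard RMS norm it stands in for (reference_rms_norm and pick_rms_norm are illustrative names, not functions from the module, and the real fallback may differ in reduction details):

import torch
import torch.nn.functional as F

# F.rms_norm only exists in newer PyTorch releases, so availability is feature-detected
has_torch_rms_norm = hasattr(F, 'rms_norm')

def reference_rms_norm(x, normalized_shape, weight=None, eps=1e-6):
    # standard RMS norm: x * rsqrt(mean(x^2) + eps), optionally scaled by weight
    dims = tuple(range(-len(normalized_shape), 0))
    x = x * torch.rsqrt(x.pow(2).mean(dim=dims, keepdim=True) + eps)
    if weight is not None:
        x = x * weight
    return x

def pick_rms_norm(x, normalized_shape, weight=None, eps=1e-6):
    # prefer the built-in op when present, otherwise use the reference above
    if has_torch_rms_norm:
        return F.rms_norm(x, normalized_shape, weight, eps)
    return reference_rms_norm(x, normalized_shape, weight, eps)

y = pick_rms_norm(torch.randn(2, 8, 64), [64], torch.ones(64))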

timm/layers/fast_norm.py

@@ -24,6 +24,8 @@ except ImportError:
has_apex_rmsnorm = False
has_torch_rms_norm = hasattr(F, 'rms_norm')
# fast (ie lower precision LN) can be disabled with this flag if issues crop up
_USE_FAST_NORM = False # defaulting to False for now
@@ -75,7 +77,6 @@ def fast_group_norm(
if is_autocast_enabled(x.device.type):
# normally native AMP casts GN inputs to float32
# here we use the low precision autocast dtype
# FIXME what to do re CPU autocast?
dt = get_autocast_dtype(x.device.type)
x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
@@ -101,14 +102,12 @@ def fast_layer_norm(
# normally native AMP casts LN inputs to float32
# apex LN does not, this is behaving like Apex
dt = get_autocast_dtype(x.device.type)
# FIXME what to do re CPU autocast?
x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
with torch.amp.autocast(device_type=x.device.type, enabled=False):
return F.layer_norm(x, normalized_shape, weight, bias, eps)
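
For context on the "behaving like Apex" comments: under native AMP, F.layer_norm is autocast to float32, while the fast path casts the inputs to the low precision autocast dtype itself and then runs with autocast disabled. A rough illustration of the resulting dtype difference, assuming a CUDA device:

import torch
import torch.nn.functional as F

x = torch.randn(8, 196, 384, device='cuda')
w = torch.ones(384, device='cuda')
b = torch.zeros(384, device='cuda')
with torch.autocast(device_type='cuda', dtype=torch.float16):
    y_amp = F.layer_norm(x, (384,), w, b, 1e-6)  # autocast upcasts norm inputs -> float32 output
    with torch.autocast(device_type='cuda', enabled=False):
        # cast manually and keep autocast off, as fast_layer_norm does -> float16 output
        y_fast = F.layer_norm(x.half(), (384,), w.half(), b.half(), 1e-6)
print(y_amp.dtype, y_fast.dtype)  # torch.float32 torch.float16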
def rms_norm(
x: torch.Tensor,
normalized_shape: List[int],
@@ -148,8 +147,19 @@ def fast_rms_norm(
else:
return fused_rms_norm_affine(x, weight, normalized_shape, eps)
# fallback
return rms_norm(x, normalized_shape, weight, eps)
if is_autocast_enabled(x.device.type):
# normally native AMP casts LN inputs to float32
# apex LN does not, this is behaving like Apex
dt = get_autocast_dtype(x.device.type)
x, weight = x.to(dt), weight.to(dt)
with torch.amp.autocast(device_type=x.device.type, enabled=False):
if has_torch_rms_norm:
x = F.rms_norm(x, normalized_shape, weight, eps)
else:
x = rms_norm(x, normalized_shape, weight, eps)
return x
def simple_norm(
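
A quick way to exercise the new fast_rms_norm path end to end, as a sketch: it assumes a CUDA device and uses the module's existing set_fast_norm() toggle to enable the fast path.

import torch
from timm.layers import fast_norm

fast_norm.set_fast_norm(True)  # flips the module-level _USE_FAST_NORM flag
w = torch.ones(64, device='cuda')
x = torch.randn(8, 196, 64, device='cuda')
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    y = fast_norm.fast_rms_norm(x, [64], w, 1e-6)
print(y.dtype)  # expected torch.bfloat16: inputs are cast to the autocast dtype, not float32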

timm/layers/norm.py

@@ -11,17 +11,24 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, fast_simple_norm
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, fast_simple_norm, simple_norm
try:
from torch.nn.functional import rms_norm
except ImportError:
from .fast_norm import rms_norm
class GroupNorm(nn.GroupNorm):
_fast_norm: torch.jit.Final[bool]
def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
# NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
super().__init__(num_groups, num_channels, eps=eps, affine=affine)
self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
def forward(self, x):
if self.fast_norm:
if self._fast_norm:
return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
else:
return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
@@ -31,13 +38,14 @@ class GroupNorm1(nn.GroupNorm):
""" Group Normalization with 1 group.
Input: tensor in shape [B, C, *]
"""
_fast_norm: torch.jit.Final[bool]
def __init__(self, num_channels, **kwargs):
super().__init__(1, num_channels, **kwargs)
self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.fast_norm:
if self._fast_norm:
return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
else:
return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
@@ -46,6 +54,8 @@ class GroupNorm1(nn.GroupNorm):
class LayerNorm(nn.LayerNorm):
""" LayerNorm w/ fast norm option
"""
_fast_norm: torch.jit.Final[bool]
def __init__(self, num_channels, eps=1e-6, affine=True):
super().__init__(num_channels, eps=eps, elementwise_affine=affine)
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
@@ -60,6 +70,8 @@ class LayerNorm(nn.LayerNorm):
class LayerNorm2d(nn.LayerNorm):
""" LayerNorm for channels of '2D' spatial NCHW tensors """
_fast_norm: torch.jit.Final[bool]
def __init__(self, num_channels, eps=1e-6, affine=True):
super().__init__(num_channels, eps=eps, elementwise_affine=affine)
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
@@ -121,10 +133,11 @@ class LayerNormExp2d(nn.LayerNorm):
class RmsNorm(nn.Module):
""" RmsNorm w/ fast (apex) norm if available
"""
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
_fast_norm: bool
def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
@@ -136,6 +149,8 @@ class RmsNorm(nn.Module):
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = affine
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
@@ -150,17 +165,21 @@ class RmsNorm(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# NOTE fast norm fallback needs our rms norm impl, so both paths go through here.
# Since there is no built-in PyTorch impl, always use APEX RmsNorm if it is installed.
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
if self._fast_norm:
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
else:
x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
return x
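
Since the commit message calls out torchscript, a minimal scripting check for the updated RmsNorm (a sketch, assuming RmsNorm is importable from timm.layers and the default non-fast path is active):

import torch
from timm.layers import RmsNorm

m = RmsNorm(64)
scripted = torch.jit.script(m)  # should script cleanly because _fast_norm is listed in __constants__
x = torch.randn(2, 196, 64)
torch.testing.assert_close(scripted(x), m(x))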
class RmsNorm2d(nn.Module):
""" RmsNorm w/ fast (apex) norm if available
"""
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
_fast_norm: bool
def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
@@ -172,6 +191,8 @@ class RmsNorm2d(nn.Module):
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = affine
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
@@ -187,7 +208,10 @@ class RmsNorm2d(nn.Module):
x = x.permute(0, 2, 3, 1)
# NOTE fast norm fallback needs our rms norm impl, so both paths go through here.
# Since there is no built-in PyTorch impl, always use APEX RmsNorm if it is installed.
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
if self._fast_norm:
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
else:
x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
x = x.permute(0, 3, 1, 2)
return x
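
The 2d variant keeps the same pattern as before: permute NCHW to channels-last so the reduction runs over the channel (now last) dimension, normalize, then permute back. A tiny standalone illustration of that round trip:

import torch

x = torch.randn(2, 64, 14, 14)  # N, C, H, W
xl = x.permute(0, 2, 3, 1)  # N, H, W, C
xl = xl * torch.rsqrt(xl.pow(2).mean(dim=-1, keepdim=True) + 1e-6)  # RMS-style norm over C
y = xl.permute(0, 3, 1, 2)  # back to N, C, H, W
print(y.shape)  # torch.Size([2, 64, 14, 14])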
@@ -195,10 +219,11 @@ class RmsNorm2d(nn.Module):
class SimpleNorm(nn.Module):
""" SimpleNorm (x / std(x))
"""
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
_fast_norm: bool
def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
@@ -210,6 +235,8 @@ class SimpleNorm(nn.Module):
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = affine
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
@@ -222,17 +249,21 @@ class SimpleNorm(nn.Module):
nn.init.ones_(self.weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
if self._fast_norm:
x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
else:
x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
return x
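
SimpleNorm's docstring describes the op as x / std(x), i.e. scaling by the feature standard deviation without recentering the output. A hedged reference of that operation (the actual simple_norm in fast_norm.py may differ in reduction and upcasting details):

import torch

def simple_norm_ref(x, weight=None, eps=1e-6):
    v = x.var(dim=-1, keepdim=True)  # variance over the normalized (last) dim
    x = x * torch.rsqrt(v + eps)
    if weight is not None:
        x = x * weight
    return x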
class SimpleNorm2d(nn.Module):
""" SimpleNorm for NCHW tensors
"""
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
_fast_norm: bool
def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
@@ -244,6 +275,8 @@ class SimpleNorm2d(nn.Module):
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = affine
self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
@@ -257,6 +290,9 @@ class SimpleNorm2d(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.permute(0, 2, 3, 1)
x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
if self._fast_norm:
x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
else:
x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
x = x.permute(0, 3, 1, 2)
return x
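
And, in the spirit of the "test with torchscript" note in the commit message, a broader smoke-test sketch (it assumes all of these classes are exported from timm.layers, as in recent timm releases):

import torch
from timm.layers import GroupNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d

cases = [
    (GroupNorm, (2, 64, 8, 8)),
    (LayerNorm2d, (2, 64, 8, 8)),
    (RmsNorm, (2, 16, 64)),
    (RmsNorm2d, (2, 64, 8, 8)),
    (SimpleNorm, (2, 16, 64)),
    (SimpleNorm2d, (2, 64, 8, 8)),
]
for cls, shape in cases:
    m = cls(64)
    scripted = torch.jit.script(m)  # the jit.Final / __constants__ flags keep these layers scriptable
    x = torch.randn(shape)
    torch.testing.assert_close(scripted(x), m(x))
print('scripted norm layers match eager outputs')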