Add eps arg to LayerNorm2d, add 'tf' (tensorflow) variant of trunc_normal_ that applies scale/shift after sampling (instead of needing to move a/b)
parent 82c311d082
commit 7a9c6811c9
@@ -39,4 +39,4 @@ from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
 from .trace_utils import _assert, _float_to_int
-from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_
+from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
@@ -16,8 +16,8 @@ class GroupNorm(nn.GroupNorm):

 class LayerNorm2d(nn.LayerNorm):
     """ LayerNorm for channels of '2D' spatial BCHW tensors """
-    def __init__(self, num_channels):
-        super().__init__(num_channels)
+    def __init__(self, num_channels, eps=1e-6):
+        super().__init__(num_channels, eps=eps)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return F.layer_norm(
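With the hunk above, the normalization epsilon of LayerNorm2d becomes configurable instead of being fixed to nn.LayerNorm's default. A minimal usage sketch, illustrative only and assuming LayerNorm2d is importable from timm.models.layers (not shown in this diff):

    import torch
    from timm.models.layers import LayerNorm2d

    norm = LayerNorm2d(64, eps=1e-6)   # eps is now forwarded to nn.LayerNorm
    x = torch.randn(2, 64, 7, 7)       # BCHW feature map
    y = norm(x)                        # layer norm applied over the channel dim
    assert y.shape == x.shape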
@@ -49,6 +49,11 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
     with values outside :math:`[a, b]` redrawn until they are within
     the bounds. The method used for generating the random values works
     best when :math:`a \leq \text{mean} \leq b`.
+
+    NOTE: this impl is similar to the PyTorch trunc_normal_; the bounds [a, b] are
+    applied while sampling the normal with mean/std applied, therefore the a, b args
+    should be adjusted to match the range of the mean, std args.
+
     Args:
         tensor: an n-dimensional `torch.Tensor`
         mean: the mean of the normal distribution
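The NOTE added above has a practical consequence worth spelling out: the cutoffs live in output units, since they are applied after mean/std are already in effect. A hedged sketch, illustrative only (the +/-0.04 bounds are simply what a +/-2 std truncation works out to at std=0.02):

    import torch
    from timm.models.layers import trunc_normal_

    # With this impl, truncating at +/-2 std for std=0.02 means passing
    # the bounds in output units: a=-0.04, b=0.04 (not the default -2, 2).
    w = torch.empty(3, 5)
    trunc_normal_(w, mean=0., std=0.02, a=-0.04, b=0.04)
    assert w.abs().max() <= 0.04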
@@ -62,6 +67,35 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
     return _no_grad_trunc_normal_(tensor, mean, std, a, b)


+def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.):
+    # type: (Tensor, float, float, float, float) -> Tensor
+    r"""Fills the input Tensor with values drawn from a truncated
+    normal distribution. The values are effectively drawn from the
+    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+    with values outside :math:`[a, b]` redrawn until they are within
+    the bounds. The method used for generating the random values works
+    best when :math:`a \leq \text{mean} \leq b`.
+
+    NOTE: this 'tf' variant behaves closer to the TensorFlow / JAX impls, where the
+    bounds [a, b] are applied when sampling the normal with mean=0, std=1.0,
+    and the result is subsequently scaled and shifted by the mean and std args.
+
+    Args:
+        tensor: an n-dimensional `torch.Tensor`
+        mean: the mean of the normal distribution
+        std: the standard deviation of the normal distribution
+        a: the minimum cutoff value
+        b: the maximum cutoff value
+    Examples:
+        >>> w = torch.empty(3, 5)
+        >>> trunc_normal_tf_(w)
+    """
+    _no_grad_trunc_normal_(tensor, 0, 1.0, a, b)
+    with torch.no_grad():
+        tensor.mul_(std).add_(mean)
+    return tensor
+
+
 def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
     fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
     if mode == 'fan_in':
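By contrast, the new trunc_normal_tf_ clips the unit normal first and scales/shifts afterwards, so the default bounds already read as sigmas. A sketch of the same +/-2 std truncation with no adjustment of a, b, illustrative only and assuming the re-export added in the first hunk lands in timm.models.layers:

    import torch
    from timm.models.layers import trunc_normal_tf_

    # a=-2, b=2 are applied to N(0, 1) before the mul_(std).add_(mean),
    # so the effective cutoff is mean +/- 2 * std with the default bounds.
    w = torch.empty(3, 5)
    trunc_normal_tf_(w, mean=0., std=0.02)
    assert w.abs().max() <= 2 * 0.02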
@@ -75,7 +109,7 @@ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):

     if distribution == "truncated_normal":
         # constant is stddev of standard normal truncated to (-2, 2)
-        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
+        trunc_normal_tf_(tensor, std=math.sqrt(variance) / .87962566103423978)
     elif distribution == "normal":
         tensor.normal_(std=math.sqrt(variance))
     elif distribution == "uniform":
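The constant divided out above is the standard deviation of a standard normal truncated to (-2, 2); dividing sqrt(variance) by it makes the std of the truncated samples come out to sqrt(variance). A quick Monte Carlo check, illustrative only:

    import torch

    s = torch.empty(1_000_000).normal_()
    s = s[(s > -2) & (s < 2)]     # keep only the (-2, 2) truncated samples
    print(s.std())                # ~0.8796, matching .87962566103423978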