mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
One codepath for stdconv, switch layernorm to batchnorm so gain included. Tweak epsilon values for nfnet, resnetv2, vit hybrid.
This commit is contained in:
parent
2f5ed2dec1
commit
ba2ca4b464
@ -18,27 +18,20 @@ class StdConv2d(nn.Conv2d):
|
|||||||
https://arxiv.org/abs/1903.10520v2
|
https://arxiv.org/abs/1903.10520v2
|
||||||
"""
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self, in_channel, out_channels, kernel_size, stride=1, padding=None, dilation=1,
|
self, in_channel, out_channels, kernel_size, stride=1, padding=None,
|
||||||
groups=1, bias=False, eps=1e-5, use_layernorm=True):
|
dilation=1, groups=1, bias=False, eps=1e-6):
|
||||||
if padding is None:
|
if padding is None:
|
||||||
padding = get_padding(kernel_size, stride, dilation)
|
padding = get_padding(kernel_size, stride, dilation)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
in_channel, out_channels, kernel_size, stride=stride,
|
in_channel, out_channels, kernel_size, stride=stride,
|
||||||
padding=padding, dilation=dilation, groups=groups, bias=bias)
|
padding=padding, dilation=dilation, groups=groups, bias=bias)
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.use_layernorm = use_layernorm
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
if self.use_layernorm:
|
|
||||||
# NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op
|
|
||||||
weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
|
|
||||||
else:
|
|
||||||
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
|
|
||||||
weight = (self.weight - mean) / (std + self.eps)
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
weight = F.batch_norm(
|
||||||
|
self.weight.view(1, self.out_channels, -1), None, None,
|
||||||
|
eps=self.eps, training=True, momentum=0.).reshape_as(self.weight)
|
||||||
|
x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -49,29 +42,22 @@ class StdConv2dSame(nn.Conv2d):
|
|||||||
https://arxiv.org/abs/1903.10520v2
|
https://arxiv.org/abs/1903.10520v2
|
||||||
"""
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', dilation=1,
|
self, in_channel, out_channels, kernel_size, stride=1, padding='SAME',
|
||||||
groups=1, bias=False, eps=1e-5, use_layernorm=True):
|
dilation=1, groups=1, bias=False, eps=1e-6):
|
||||||
padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
|
padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
|
in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
|
||||||
groups=groups, bias=bias)
|
groups=groups, bias=bias)
|
||||||
self.same_pad = is_dynamic
|
self.same_pad = is_dynamic
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.use_layernorm = use_layernorm
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
if self.use_layernorm:
|
|
||||||
# NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op
|
|
||||||
weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
|
|
||||||
else:
|
|
||||||
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
|
|
||||||
weight = (self.weight - mean) / (std + self.eps)
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.same_pad:
|
if self.same_pad:
|
||||||
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
|
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
|
||||||
x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
weight = F.batch_norm(
|
||||||
|
self.weight.view(1, self.out_channels, -1), None, None,
|
||||||
|
eps=self.eps, training=True, momentum=0.).reshape_as(self.weight)
|
||||||
|
x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -85,8 +71,8 @@ class ScaledStdConv2d(nn.Conv2d):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
|
self, in_channels, out_channels, kernel_size, stride=1, padding=None,
|
||||||
bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=True):
|
dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0):
|
||||||
if padding is None:
|
if padding is None:
|
||||||
padding = get_padding(kernel_size, stride, dilation)
|
padding = get_padding(kernel_size, stride, dilation)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@ -95,19 +81,13 @@ class ScaledStdConv2d(nn.Conv2d):
|
|||||||
self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
|
self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
|
||||||
self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in)
|
self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in)
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory to hijack LN kernel
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
if self.use_layernorm:
|
|
||||||
# NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op
|
|
||||||
weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
|
|
||||||
else:
|
|
||||||
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
|
|
||||||
weight = (self.weight - mean) / (std + self.eps)
|
|
||||||
return weight.mul_(self.gain * self.scale)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
weight = F.batch_norm(
|
||||||
|
self.weight.view(1, self.out_channels, -1), None, None,
|
||||||
|
weight=(self.gain * self.scale).view(-1),
|
||||||
|
eps=self.eps, training=True, momentum=0.).reshape_as(self.weight)
|
||||||
|
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||||
|
|
||||||
|
|
||||||
class ScaledStdConv2dSame(nn.Conv2d):
|
class ScaledStdConv2dSame(nn.Conv2d):
|
||||||
@ -120,8 +100,8 @@ class ScaledStdConv2dSame(nn.Conv2d):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', dilation=1, groups=1,
|
self, in_channels, out_channels, kernel_size, stride=1, padding='SAME',
|
||||||
bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=True):
|
dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0):
|
||||||
padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
|
padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
|
in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
|
||||||
@ -130,18 +110,12 @@ class ScaledStdConv2dSame(nn.Conv2d):
|
|||||||
self.scale = gamma * self.weight[0].numel() ** -0.5
|
self.scale = gamma * self.weight[0].numel() ** -0.5
|
||||||
self.same_pad = is_dynamic
|
self.same_pad = is_dynamic
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory to hijack LN kernel
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
if self.use_layernorm:
|
|
||||||
# NOTE F.layer_norm is being used to compute (self.weight - mean) / (sqrt(var) + self.eps) in one op
|
|
||||||
weight = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
|
|
||||||
else:
|
|
||||||
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
|
|
||||||
weight = (self.weight - mean) / (std + self.eps)
|
|
||||||
return weight.mul_(self.gain * self.scale)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.same_pad:
|
if self.same_pad:
|
||||||
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
|
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
|
||||||
return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
weight = F.batch_norm(
|
||||||
|
self.weight.view(1, self.out_channels, -1), None, None,
|
||||||
|
weight=(self.gain * self.scale).view(-1),
|
||||||
|
eps=self.eps, training=True, momentum=0.).reshape_as(self.weight)
|
||||||
|
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||||
|
@ -167,7 +167,6 @@ class NfCfg:
|
|||||||
gamma_in_act: bool = False
|
gamma_in_act: bool = False
|
||||||
same_padding: bool = False
|
same_padding: bool = False
|
||||||
std_conv_eps: float = 1e-5
|
std_conv_eps: float = 1e-5
|
||||||
std_conv_ln: bool = True # use layer-norm impl to normalize in std-conv, works in PyTorch XLA, slightly faster
|
|
||||||
skipinit: bool = False # disabled by default, non-trivial performance impact
|
skipinit: bool = False # disabled by default, non-trivial performance impact
|
||||||
zero_init_fc: bool = False
|
zero_init_fc: bool = False
|
||||||
act_layer: str = 'silu'
|
act_layer: str = 'silu'
|
||||||
@ -484,11 +483,10 @@ class NormFreeNet(nn.Module):
|
|||||||
conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d
|
conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d
|
||||||
if cfg.gamma_in_act:
|
if cfg.gamma_in_act:
|
||||||
act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer])
|
act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer])
|
||||||
conv_layer = partial(conv_layer, eps=cfg.std_conv_eps, use_layernorm=cfg.std_conv_ln)
|
conv_layer = partial(conv_layer, eps=cfg.std_conv_eps)
|
||||||
else:
|
else:
|
||||||
act_layer = get_act_layer(cfg.act_layer)
|
act_layer = get_act_layer(cfg.act_layer)
|
||||||
conv_layer = partial(
|
conv_layer = partial(conv_layer, gamma=_nonlin_gamma[cfg.act_layer], eps=cfg.std_conv_eps)
|
||||||
conv_layer, gamma=_nonlin_gamma[cfg.act_layer], eps=cfg.std_conv_eps, use_layernorm=cfg.std_conv_ln)
|
|
||||||
attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
|
attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
|
||||||
|
|
||||||
stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div)
|
stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div)
|
||||||
|
@ -276,7 +276,7 @@ class ResNetStage(nn.Module):
|
|||||||
|
|
||||||
def create_resnetv2_stem(
|
def create_resnetv2_stem(
|
||||||
in_chs, out_chs=64, stem_type='', preact=True,
|
in_chs, out_chs=64, stem_type='', preact=True,
|
||||||
conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
|
conv_layer=partial(StdConv2d, eps=1e-8), norm_layer=partial(GroupNormAct, num_groups=32)):
|
||||||
stem = OrderedDict()
|
stem = OrderedDict()
|
||||||
assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same')
|
assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same')
|
||||||
|
|
||||||
@ -315,8 +315,8 @@ class ResNetV2(nn.Module):
|
|||||||
def __init__(self, layers, channels=(256, 512, 1024, 2048),
|
def __init__(self, layers, channels=(256, 512, 1024, 2048),
|
||||||
num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
|
num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
|
||||||
width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
|
width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
|
||||||
act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
|
act_layer=nn.ReLU, conv_layer=partial(StdConv2d, eps=1e-8),
|
||||||
drop_rate=0., drop_path_rate=0.):
|
norm_layer=partial(GroupNormAct, num_groups=32), drop_rate=0., drop_path_rate=0.):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.drop_rate = drop_rate
|
self.drop_rate = drop_rate
|
||||||
|
@ -116,12 +116,8 @@ def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwa
|
|||||||
def _resnetv2(layers=(3, 4, 9), **kwargs):
|
def _resnetv2(layers=(3, 4, 9), **kwargs):
|
||||||
""" ResNet-V2 backbone helper"""
|
""" ResNet-V2 backbone helper"""
|
||||||
padding_same = kwargs.get('padding_same', True)
|
padding_same = kwargs.get('padding_same', True)
|
||||||
if padding_same:
|
stem_type = 'same' if padding_same else ''
|
||||||
stem_type = 'same'
|
conv_layer = partial(StdConv2dSame, eps=1e-8) if padding_same else partial(StdConv2d, eps=1e-8)
|
||||||
conv_layer = partial(StdConv2dSame, eps=1e-5)
|
|
||||||
else:
|
|
||||||
stem_type = ''
|
|
||||||
conv_layer = StdConv2d
|
|
||||||
if len(layers):
|
if len(layers):
|
||||||
backbone = ResNetV2(
|
backbone = ResNetV2(
|
||||||
layers=layers, num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
|
layers=layers, num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user