Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Add norm_act factory method, move JIT of norm layers to factory
commit 780860d140
parent 14edacdf9a
timm/models/densenet.py

@@ -4,6 +4,7 @@ fixed kwargs passthrough and addition of dynamic global avg/max pool.
 """
 import re
 from collections import OrderedDict
+from functools import partial

 import torch
 import torch.nn as nn
@@ -13,7 +14,7 @@ from torch.jit.annotations import List

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained
-from .layers import SelectAdaptivePool2d, BatchNormAct2d, EvoNormBatch2d, EvoNormSample2d
+from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act
 from .registry import register_model

 __all__ = ['DenseNet']
@@ -327,9 +328,11 @@ def densenet121d_evob(pretrained=False, **kwargs):
     r"""Densenet-121 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
     """
+    def norm_act_fn(num_features, **kwargs):
+        return create_norm_act('EvoNormBatch', num_features, jit=True, **kwargs)
     model = _densenet(
         'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
-        norm_act_layer=EvoNormBatch2d, pretrained=pretrained, **kwargs)
+        norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
     return model


@@ -338,9 +341,11 @@ def densenet121d_evos(pretrained=False, **kwargs):
     r"""Densenet-121 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
     """
+    def norm_act_fn(num_features, **kwargs):
+        return create_norm_act('EvoNormSample', num_features, jit=True, **kwargs)
     model = _densenet(
         'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
-        norm_act_layer=EvoNormSample2d, pretrained=pretrained, **kwargs)
+        norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
     return model


@@ -349,10 +354,11 @@ def densenet121d_iabn(pretrained=False, **kwargs):
     r"""Densenet-121 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
     """
-    from inplace_abn import InPlaceABN
+    def norm_act_fn(num_features, **kwargs):
+        return create_norm_act('iabn', num_features, **kwargs)
     model = _densenet(
         'densenet121tn', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
-        norm_act_layer=InPlaceABN, pretrained=pretrained, **kwargs)
+        norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
     return model
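Review note: `_densenet` invokes `norm_act_layer(num_features, **kwargs)` when it builds each block, so the wrappers above bind the layer choice and the `jit` flag in a local `norm_act_fn` closure. The newly imported `functools.partial` would do the same job; a sketch of the equivalent, not part of this commit (import path per the `layers/__init__.py` change below):

    from functools import partial
    from timm.models.layers import create_norm_act

    # Equivalent to the norm_act_fn closure in densenet121d_evob: binds the
    # layer name and jit flag, leaving num_features to be supplied later.
    norm_act_layer = partial(create_norm_act, 'EvoNormBatch', jit=True)
    layer = norm_act_layer(32)  # -> create_norm_act('EvoNormBatch', 32, jit=True)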
timm/models/layers/__init__.py

@@ -20,4 +20,5 @@ from .anti_aliasing import AntiAliasDownsampleLayer
 from .space_to_depth import SpaceToDepthModule
 from .blur_pool import BlurPool2d
 from .norm_act import BatchNormAct2d
 from .evo_norm import EvoNormBatch2d, EvoNormSample2d
+from .create_norm_act import create_norm_act
timm/models/layers/create_norm_act.py (new file, 37 lines)

@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+
+from .evo_norm import EvoNormBatch2d, EvoNormSample2d
+from .norm_act import BatchNormAct2d
+try:
+    from inplace_abn import InPlaceABN
+    has_iabn = True
+except ImportError:
+    has_iabn = False
+
+
+def create_norm_act(layer_type, num_features, jit=False, **kwargs):
+    layer_parts = layer_type.split('_')
+    assert len(layer_parts) in (1, 2)
+    layer_class = layer_parts[0].lower()
+    #activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else ''  # FIXME support string act selection
+
+    if layer_class == "batchnormact":
+        layer = BatchNormAct2d(num_features, **kwargs)  # defaults to ReLU if no kwargs override
+    elif layer_class == "batchnormrelu":
+        assert 'act_layer' not in kwargs
+        layer = BatchNormAct2d(num_features, act_layer=nn.ReLU, **kwargs)
+    elif layer_class == "evonormbatch":
+        layer = EvoNormBatch2d(num_features, **kwargs)
+    elif layer_class == "evonormsample":
+        layer = EvoNormSample2d(num_features, **kwargs)
+    elif layer_class == "iabn" or layer_class == "inplaceabn":
+        if not has_iabn:
+            raise ImportError(
+                "Please install InplaceABN: 'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'")
+        layer = InPlaceABN(num_features, **kwargs)
+    else:
+        assert False, "Invalid norm_act layer (%s)" % layer_class
+    if jit:
+        layer = torch.jit.script(layer)
+    return layer
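A minimal usage sketch for the new factory (assuming the export added to `layers/__init__.py` above; the string is lower-cased and matched on the part before any `_`):

    import torch
    from timm.models.layers import create_norm_act

    bn_relu = create_norm_act('BatchNormRelu', 64)       # BatchNormAct2d with nn.ReLU
    evo = create_norm_act('EvoNormBatch', 64, jit=True)  # wrapped in torch.jit.script

    x = torch.randn(2, 64, 8, 8)
    assert bn_relu(x).shape == evo(x).shape == x.shape   # norm-act layers preserve shape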
timm/models/layers/evo_norm.py

@@ -13,35 +13,12 @@ import torch
 import torch.nn as nn


-@torch.jit.script
-def evo_batch_jit(
-        x: torch.Tensor, v: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, running_var: torch.Tensor,
-        momentum: float, training: bool, nonlin: bool, eps: float):
-    x_type = x.dtype
-    running_var = running_var.detach()  # FIXME why is this needed, it's a buffer?
-    if training:
-        var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)  # FIXME biased, unbiased?
-        running_var.copy_(momentum * var + (1 - momentum) * running_var)
-    else:
-        var = running_var.clone()
-
-    if nonlin:
-        # FIXME biased, unbiased?
-        d = (x * v.to(x_type)) + x.var(dim=(2, 3), unbiased=False, keepdim=True).add_(eps).sqrt_().to(dtype=x_type)
-        d = d.max(var.add(eps).sqrt_().to(dtype=x_type))
-        x = x / d
-        return x.mul_(weight).add_(bias)
-    else:
-        return x.mul(weight).add_(bias)
-
-
 class EvoNormBatch2d(nn.Module):
-    def __init__(self, num_features, momentum=0.1, nonlin=True, eps=1e-5, jit=True):
+    def __init__(self, num_features, momentum=0.1, nonlin=True, eps=1e-5):
         super(EvoNormBatch2d, self).__init__()
         self.momentum = momentum
         self.nonlin = nonlin
         self.eps = eps
-        self.jit = jit
         param_shape = (1, num_features, 1, 1)
         self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
         self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
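With the scripted free function and the `jit` constructor flag removed, JIT now applies to the module as a whole: `create_norm_act(..., jit=True)` wraps the constructed layer in `torch.jit.script` instead of the layer dispatching to `evo_batch_jit` inside `forward`. A sketch of what that amounts to (assumes the `EvoNormBatch2d` defined above):

    import torch
    from timm.models.layers import EvoNormBatch2d

    m = EvoNormBatch2d(32)        # plain eager module after this change
    m_jit = torch.jit.script(m)   # what create_norm_act(..., jit=True) now does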
@@ -58,50 +35,29 @@ class EvoNormBatch2d(nn.Module):

     def forward(self, x):
         assert x.dim() == 4, 'expected 4D input'
-        if self.jit:
-            return evo_batch_jit(
-                x, self.v, self.weight, self.bias, self.running_var, self.momentum,
-                self.training, self.nonlin, self.eps)
-        else:
-            x_type = x.dtype
-            if self.training:
-                var = x.var(dim=(0, 2, 3), keepdim=True)
-                self.running_var.copy_(self.momentum * var + (1 - self.momentum) * self.running_var)
-            else:
-                var = self.running_var.clone()
+        x_type = x.dtype
+        if self.training:
+            var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
+            self.running_var.copy_(self.momentum * var.detach() + (1 - self.momentum) * self.running_var)
+        else:
+            var = self.running_var.clone()

         if self.nonlin:
             v = self.v.to(dtype=x_type)
-            d = (x * v) + x.var(dim=(2, 3), keepdim=True).add_(self.eps).sqrt_().to(dtype=x_type)
-            d = d.max(var.add(self.eps).sqrt_().to(dtype=x_type))
+            d = (x * v) + x.var(dim=(2, 3), unbiased=False, keepdim=True).add_(self.eps).sqrt_().to(dtype=x_type)
+            d = d.max(var.add_(self.eps).sqrt_().to(dtype=x_type))
             x = x / d
             return x.mul_(self.weight).add_(self.bias)
         else:
             return x.mul(self.weight).add_(self.bias)


-@torch.jit.script
-def evo_sample_jit(
-        x: torch.Tensor, v: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
-        groups: int, nonlin: bool, eps: float):
-    B, C, H, W = x.shape
-    assert C % groups == 0
-    if nonlin:
-        n = (x * v).sigmoid_().reshape(B, groups, -1)
-        x = x.reshape(B, groups, -1)
-        x = n / x.var(dim=-1, unbiased=False, keepdim=True).add_(eps).sqrt_()
-        x = x.reshape(B, C, H, W)
-    return x.mul_(weight).add_(bias)
-
-
 class EvoNormSample2d(nn.Module):
-    def __init__(self, num_features, nonlin=True, groups=8, eps=1e-5, jit=True):
+    def __init__(self, num_features, nonlin=True, groups=8, eps=1e-5):
         super(EvoNormSample2d, self).__init__()
         self.nonlin = nonlin
         self.groups = groups
         self.eps = eps
-        self.jit = jit
         param_shape = (1, num_features, 1, 1)
         self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
         self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
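For intuition, the rewritten `EvoNormBatch2d.forward` above computes the EvoNorm-B0 denominator `max(v * x + instance_std, batch_std)` with biased (`unbiased=False`) variances. Written out eagerly (a sketch for the training-mode batch variance, ignoring the running-var update and dtype casts):

    import torch

    x = torch.randn(2, 32, 8, 8)
    v = torch.ones(1, 32, 1, 1)  # stands in for the per-channel self.v parameter
    eps = 1e-5
    instance_std = x.var(dim=(2, 3), unbiased=False, keepdim=True).add(eps).sqrt()
    batch_std = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True).add(eps).sqrt()
    d = (x * v + instance_std).max(batch_std)  # elementwise max, broadcast over (B, H, W)
    y = x / d                                  # then scaled by weight and shifted by bias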
@@ -117,18 +73,13 @@ class EvoNormSample2d(nn.Module):

     def forward(self, x):
         assert x.dim() == 4, 'expected 4D input'
-        if self.jit:
-            return evo_sample_jit(
-                x, self.v, self.weight, self.bias, self.groups, self.nonlin, self.eps)
-        else:
-            B, C, H, W = x.shape
-            assert C % self.groups == 0
-            if self.nonlin:
-                n = (x * self.v).sigmoid().reshape(B, self.groups, -1)
-                x = x.reshape(B, self.groups, -1)
-                x = n / (x.std(dim=-1, unbiased=False, keepdim=True) + self.eps)
-                x = x.reshape(B, C, H, W)
-                return x.mul_(self.weight).add_(self.bias)
-            else:
-                return x.mul(self.weight).add_(self.bias)
+        B, C, H, W = x.shape
+        assert C % self.groups == 0
+        if self.nonlin:
+            n = (x * self.v).sigmoid().reshape(B, self.groups, -1)
+            x = x.reshape(B, self.groups, -1)
+            x = n / x.var(dim=-1, unbiased=False, keepdim=True).add_(self.eps).sqrt_()
+            x = x.reshape(B, C, H, W)
+            return x.mul_(self.weight).add_(self.bias)
+        else:
+            return x.mul(self.weight).add_(self.bias)
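Note the sample variant now divides by `sqrt(var + eps)`, matching the removed JIT path, where the old eager branch used `std + eps`; the two differ slightly (`sqrt(var + eps)` vs `sqrt(var) + eps`). Group-wise, the new forward amounts to this sketch:

    import torch

    B, C, H, W, groups, eps = 2, 32, 8, 8, 8, 1e-5
    x = torch.randn(B, C, H, W)
    v = torch.ones(1, C, 1, 1)  # stands in for the per-channel self.v parameter
    n = (x * v).sigmoid().reshape(B, groups, -1)
    xg = x.reshape(B, groups, -1)
    y = n / xg.var(dim=-1, unbiased=False, keepdim=True).add(eps).sqrt()
    y = y.reshape(B, C, H, W)   # then scaled by weight and shifted by bias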