mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add norm_act factory method, move JIT of norm layers to factory
This commit is contained in:
parent
14edacdf9a
commit
780860d140
@ -4,6 +4,7 @@ fixed kwargs passthrough and addition of dynamic global avg/max pool.
|
||||
"""
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -13,7 +14,7 @@ from torch.jit.annotations import List
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import load_pretrained
|
||||
from .layers import SelectAdaptivePool2d, BatchNormAct2d, EvoNormBatch2d, EvoNormSample2d
|
||||
from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act
|
||||
from .registry import register_model
|
||||
|
||||
__all__ = ['DenseNet']
|
||||
@ -327,9 +328,11 @@ def densenet121d_evob(pretrained=False, **kwargs):
|
||||
r"""Densenet-121 model from
|
||||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||
"""
|
||||
def norm_act_fn(num_features, **kwargs):
|
||||
return create_norm_act('EvoNormBatch', num_features, jit=True, **kwargs)
|
||||
model = _densenet(
|
||||
'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
|
||||
norm_act_layer=EvoNormBatch2d, pretrained=pretrained, **kwargs)
|
||||
norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@ -338,9 +341,11 @@ def densenet121d_evos(pretrained=False, **kwargs):
|
||||
r"""Densenet-121 model from
|
||||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||
"""
|
||||
def norm_act_fn(num_features, **kwargs):
|
||||
return create_norm_act('EvoNormSample', num_features, jit=True, **kwargs)
|
||||
model = _densenet(
|
||||
'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
|
||||
norm_act_layer=EvoNormSample2d, pretrained=pretrained, **kwargs)
|
||||
norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@ -349,10 +354,11 @@ def densenet121d_iabn(pretrained=False, **kwargs):
|
||||
r"""Densenet-121 model from
|
||||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||
"""
|
||||
from inplace_abn import InPlaceABN
|
||||
def norm_act_fn(num_features, **kwargs):
|
||||
return create_norm_act('iabn', num_features, **kwargs)
|
||||
model = _densenet(
|
||||
'densenet121tn', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
|
||||
norm_act_layer=InPlaceABN, pretrained=pretrained, **kwargs)
|
||||
norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
|
@ -20,4 +20,5 @@ from .anti_aliasing import AntiAliasDownsampleLayer
|
||||
from .space_to_depth import SpaceToDepthModule
|
||||
from .blur_pool import BlurPool2d
|
||||
from .norm_act import BatchNormAct2d
|
||||
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
|
||||
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
|
||||
from .create_norm_act import create_norm_act
|
||||
|
37
timm/models/layers/create_norm_act.py
Normal file
37
timm/models/layers/create_norm_act.py
Normal file
@ -0,0 +1,37 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
|
||||
from .norm_act import BatchNormAct2d
|
||||
try:
|
||||
from inplace_abn import InPlaceABN
|
||||
has_iabn = True
|
||||
except ImportError:
|
||||
has_iabn = False
|
||||
|
||||
|
||||
def create_norm_act(layer_type, num_features, jit=False, **kwargs):
|
||||
layer_parts = layer_type.split('_')
|
||||
assert len(layer_parts) in (1, 2)
|
||||
layer_class = layer_parts[0].lower()
|
||||
#activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection
|
||||
|
||||
if layer_class == "batchnormact":
|
||||
layer = BatchNormAct2d(num_features, **kwargs) # defaults to RELU of no kwargs override
|
||||
elif layer_class == "batchnormrelu":
|
||||
assert 'act_layer' not in kwargs
|
||||
layer = BatchNormAct2d(num_features, act_layer=nn.ReLU, **kwargs)
|
||||
elif layer_class == "evonormbatch":
|
||||
layer = EvoNormBatch2d(num_features, **kwargs)
|
||||
elif layer_class == "evonormsample":
|
||||
layer = EvoNormSample2d(num_features, **kwargs)
|
||||
elif layer_class == "iabn" or layer_class == "inplaceabn":
|
||||
if not has_iabn:
|
||||
raise ImportError(
|
||||
"Pplease install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'")
|
||||
layer = InPlaceABN(num_features, **kwargs)
|
||||
else:
|
||||
assert False, "Invalid norm_act layer (%s)" % layer_class
|
||||
if jit:
|
||||
layer = torch.jit.script(layer)
|
||||
return layer
|
@ -13,35 +13,12 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def evo_batch_jit(
|
||||
x: torch.Tensor, v: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, running_var: torch.Tensor,
|
||||
momentum: float, training: bool, nonlin: bool, eps: float):
|
||||
x_type = x.dtype
|
||||
running_var = running_var.detach() # FIXME why is this needed, it's a buffer?
|
||||
if training:
|
||||
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) # FIXME biased, unbiased?
|
||||
running_var.copy_(momentum * var + (1 - momentum) * running_var)
|
||||
else:
|
||||
var = running_var.clone()
|
||||
|
||||
if nonlin:
|
||||
# FIXME biased, unbiased?
|
||||
d = (x * v.to(x_type)) + x.var(dim=(2, 3), unbiased=False, keepdim=True).add_(eps).sqrt_().to(dtype=x_type)
|
||||
d = d.max(var.add(eps).sqrt_().to(dtype=x_type))
|
||||
x = x / d
|
||||
return x.mul_(weight).add_(bias)
|
||||
else:
|
||||
return x.mul(weight).add_(bias)
|
||||
|
||||
|
||||
class EvoNormBatch2d(nn.Module):
|
||||
def __init__(self, num_features, momentum=0.1, nonlin=True, eps=1e-5, jit=True):
|
||||
def __init__(self, num_features, momentum=0.1, nonlin=True, eps=1e-5):
|
||||
super(EvoNormBatch2d, self).__init__()
|
||||
self.momentum = momentum
|
||||
self.nonlin = nonlin
|
||||
self.eps = eps
|
||||
self.jit = jit
|
||||
param_shape = (1, num_features, 1, 1)
|
||||
self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
|
||||
self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
|
||||
@ -58,50 +35,29 @@ class EvoNormBatch2d(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
assert x.dim() == 4, 'expected 4D input'
|
||||
|
||||
if self.jit:
|
||||
return evo_batch_jit(
|
||||
x, self.v, self.weight, self.bias, self.running_var, self.momentum,
|
||||
self.training, self.nonlin, self.eps)
|
||||
x_type = x.dtype
|
||||
if self.training:
|
||||
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
|
||||
self.running_var.copy_(self.momentum * var.detach() + (1 - self.momentum) * self.running_var)
|
||||
else:
|
||||
x_type = x.dtype
|
||||
if self.training:
|
||||
var = x.var(dim=(0, 2, 3), keepdim=True)
|
||||
self.running_var.copy_(self.momentum * var + (1 - self.momentum) * self.running_var)
|
||||
else:
|
||||
var = self.running_var.clone()
|
||||
var = self.running_var.clone()
|
||||
|
||||
if self.nonlin:
|
||||
v = self.v.to(dtype=x_type)
|
||||
d = (x * v) + x.var(dim=(2, 3), keepdim=True).add_(self.eps).sqrt_().to(dtype=x_type)
|
||||
d = d.max(var.add(self.eps).sqrt_().to(dtype=x_type))
|
||||
x = x / d
|
||||
return x.mul_(self.weight).add_(self.bias)
|
||||
else:
|
||||
return x.mul(self.weight).add_(self.bias)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def evo_sample_jit(
|
||||
x: torch.Tensor, v: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
|
||||
groups: int, nonlin: bool, eps: float):
|
||||
B, C, H, W = x.shape
|
||||
assert C % groups == 0
|
||||
if nonlin:
|
||||
n = (x * v).sigmoid_().reshape(B, groups, -1)
|
||||
x = x.reshape(B, groups, -1)
|
||||
x = n / x.var(dim=-1, unbiased=False, keepdim=True).add_(eps).sqrt_()
|
||||
x = x.reshape(B, C, H, W)
|
||||
return x.mul_(weight).add_(bias)
|
||||
if self.nonlin:
|
||||
v = self.v.to(dtype=x_type)
|
||||
d = (x * v) + x.var(dim=(2, 3), unbiased=False, keepdim=True).add_(self.eps).sqrt_().to(dtype=x_type)
|
||||
d = d.max(var.add_(self.eps).sqrt_().to(dtype=x_type))
|
||||
x = x / d
|
||||
return x.mul_(self.weight).add_(self.bias)
|
||||
else:
|
||||
return x.mul(self.weight).add_(self.bias)
|
||||
|
||||
|
||||
class EvoNormSample2d(nn.Module):
|
||||
def __init__(self, num_features, nonlin=True, groups=8, eps=1e-5, jit=True):
|
||||
def __init__(self, num_features, nonlin=True, groups=8, eps=1e-5):
|
||||
super(EvoNormSample2d, self).__init__()
|
||||
self.nonlin = nonlin
|
||||
self.groups = groups
|
||||
self.eps = eps
|
||||
self.jit = jit
|
||||
param_shape = (1, num_features, 1, 1)
|
||||
self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
|
||||
self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
|
||||
@ -117,18 +73,13 @@ class EvoNormSample2d(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
assert x.dim() == 4, 'expected 4D input'
|
||||
|
||||
if self.jit:
|
||||
return evo_sample_jit(
|
||||
x, self.v, self.weight, self.bias, self.groups, self.nonlin, self.eps)
|
||||
B, C, H, W = x.shape
|
||||
assert C % self.groups == 0
|
||||
if self.nonlin:
|
||||
n = (x * self.v).sigmoid().reshape(B, self.groups, -1)
|
||||
x = x.reshape(B, self.groups, -1)
|
||||
x = n / x.var(dim=-1, unbiased=False, keepdim=True).add_(self.eps).sqrt_()
|
||||
x = x.reshape(B, C, H, W)
|
||||
return x.mul_(self.weight).add_(self.bias)
|
||||
else:
|
||||
B, C, H, W = x.shape
|
||||
assert C % self.groups == 0
|
||||
if self.nonlin:
|
||||
n = (x * self.v).sigmoid().reshape(B, self.groups, -1)
|
||||
x = x.reshape(B, self.groups, -1)
|
||||
x = n / (x.std(dim=-1, unbiased=False, keepdim=True) + self.eps)
|
||||
x = x.reshape(B, C, H, W)
|
||||
return x.mul_(self.weight).add_(self.bias)
|
||||
else:
|
||||
return x.mul(self.weight).add_(self.bias)
|
||||
return x.mul(self.weight).add_(self.bias)
|
||||
|
Loading…
x
Reference in New Issue
Block a user