mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
DenseNet converted to support ABN (norm + act) modules. Experimenting with EvoNorm, IABN
This commit is contained in:
parent
022ed001f3
commit
14edacdf9a
@ -13,7 +13,7 @@ from torch.jit.annotations import List
|
|||||||
|
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from .helpers import load_pretrained
|
from .helpers import load_pretrained
|
||||||
from .layers import SelectAdaptivePool2d
|
from .layers import SelectAdaptivePool2d, BatchNormAct2d, EvoNormBatch2d, EvoNormSample2d
|
||||||
from .registry import register_model
|
from .registry import register_model
|
||||||
|
|
||||||
__all__ = ['DenseNet']
|
__all__ = ['DenseNet']
|
||||||
@ -35,90 +35,88 @@ default_cfgs = {
|
|||||||
'densenet169': _cfg(url='https://download.pytorch.org/models/densenet169-b2777c0a.pth'),
|
'densenet169': _cfg(url='https://download.pytorch.org/models/densenet169-b2777c0a.pth'),
|
||||||
'densenet201': _cfg(url='https://download.pytorch.org/models/densenet201-c1103571.pth'),
|
'densenet201': _cfg(url='https://download.pytorch.org/models/densenet201-c1103571.pth'),
|
||||||
'densenet161': _cfg(url='https://download.pytorch.org/models/densenet161-8d451a50.pth'),
|
'densenet161': _cfg(url='https://download.pytorch.org/models/densenet161-8d451a50.pth'),
|
||||||
|
'densenet264': _cfg(url=''),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class _DenseLayer(nn.Module):
|
class DenseLayer(nn.Module):
|
||||||
def __init__(self, num_input_features, growth_rate, bn_size, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
|
def __init__(self, num_input_features, growth_rate, bn_size, norm_act_layer=BatchNormAct2d,
|
||||||
drop_rate=0., memory_efficient=False):
|
drop_rate=0., memory_efficient=False):
|
||||||
super(_DenseLayer, self).__init__()
|
super(DenseLayer, self).__init__()
|
||||||
self.add_module('norm1', norm_layer(num_input_features)),
|
self.add_module('norm1', norm_act_layer(num_input_features)),
|
||||||
self.add_module('relu1', act_layer(inplace=True)),
|
|
||||||
self.add_module('conv1', nn.Conv2d(
|
self.add_module('conv1', nn.Conv2d(
|
||||||
num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)),
|
num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)),
|
||||||
self.add_module('norm2', norm_layer(bn_size * growth_rate)),
|
self.add_module('norm2', norm_act_layer(bn_size * growth_rate)),
|
||||||
self.add_module('relu2', act_layer(inplace=True)),
|
|
||||||
self.add_module('conv2', nn.Conv2d(
|
self.add_module('conv2', nn.Conv2d(
|
||||||
bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)),
|
bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)),
|
||||||
self.drop_rate = float(drop_rate)
|
self.drop_rate = float(drop_rate)
|
||||||
self.memory_efficient = memory_efficient
|
self.memory_efficient = memory_efficient
|
||||||
|
|
||||||
def bn_function(self, inputs):
|
def bottleneck_fn(self, xs):
|
||||||
# type: (List[torch.Tensor]) -> torch.Tensor
|
# type: (List[torch.Tensor]) -> torch.Tensor
|
||||||
concated_features = torch.cat(inputs, 1)
|
concated_features = torch.cat(xs, 1)
|
||||||
bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features))) # noqa: T484
|
bottleneck_output = self.conv1(self.norm1(concated_features)) # noqa: T484
|
||||||
return bottleneck_output
|
return bottleneck_output
|
||||||
|
|
||||||
# todo: rewrite when torchscript supports any
|
# todo: rewrite when torchscript supports any
|
||||||
def any_requires_grad(self, input):
|
def any_requires_grad(self, x):
|
||||||
# type: (List[torch.Tensor]) -> bool
|
# type: (List[torch.Tensor]) -> bool
|
||||||
for tensor in input:
|
for tensor in x:
|
||||||
if tensor.requires_grad:
|
if tensor.requires_grad:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@torch.jit.unused # noqa: T484
|
@torch.jit.unused # noqa: T484
|
||||||
def call_checkpoint_bottleneck(self, input):
|
def call_checkpoint_bottleneck(self, x):
|
||||||
# type: (List[torch.Tensor]) -> torch.Tensor
|
# type: (List[torch.Tensor]) -> torch.Tensor
|
||||||
def closure(*inputs):
|
def closure(*xs):
|
||||||
return self.bn_function(*inputs)
|
return self.bottleneck_fn(*xs)
|
||||||
|
|
||||||
return cp.checkpoint(closure, input)
|
return cp.checkpoint(closure, x)
|
||||||
|
|
||||||
@torch.jit._overload_method # noqa: F811
|
@torch.jit._overload_method # noqa: F811
|
||||||
def forward(self, input):
|
def forward(self, x):
|
||||||
# type: (List[torch.Tensor]) -> (torch.Tensor)
|
# type: (List[torch.Tensor]) -> (torch.Tensor)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@torch.jit._overload_method # noqa: F811
|
@torch.jit._overload_method # noqa: F811
|
||||||
def forward(self, input):
|
def forward(self, x):
|
||||||
# type: (torch.Tensor) -> (torch.Tensor)
|
# type: (torch.Tensor) -> (torch.Tensor)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# torchscript does not yet support *args, so we overload method
|
# torchscript does not yet support *args, so we overload method
|
||||||
# allowing it to take either a List[Tensor] or single Tensor
|
# allowing it to take either a List[Tensor] or single Tensor
|
||||||
def forward(self, input): # noqa: F811
|
def forward(self, x): # noqa: F811
|
||||||
if isinstance(input, torch.Tensor):
|
if isinstance(x, torch.Tensor):
|
||||||
prev_features = [input]
|
prev_features = [x]
|
||||||
else:
|
else:
|
||||||
prev_features = input
|
prev_features = x
|
||||||
|
|
||||||
if self.memory_efficient and self.any_requires_grad(prev_features):
|
if self.memory_efficient and self.any_requires_grad(prev_features):
|
||||||
if torch.jit.is_scripting():
|
if torch.jit.is_scripting():
|
||||||
raise Exception("Memory Efficient not supported in JIT")
|
raise Exception("Memory Efficient not supported in JIT")
|
||||||
bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
|
bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
|
||||||
else:
|
else:
|
||||||
bottleneck_output = self.bn_function(prev_features)
|
bottleneck_output = self.bottleneck_fn(prev_features)
|
||||||
|
|
||||||
new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
|
new_features = self.conv2(self.norm2(bottleneck_output))
|
||||||
if self.drop_rate > 0:
|
if self.drop_rate > 0:
|
||||||
new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
|
new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
|
||||||
return new_features
|
return new_features
|
||||||
|
|
||||||
|
|
||||||
class _DenseBlock(nn.ModuleDict):
|
class DenseBlock(nn.ModuleDict):
|
||||||
_version = 2
|
_version = 2
|
||||||
|
|
||||||
def __init__(self, num_layers, num_input_features, bn_size, growth_rate, act_layer=nn.ReLU,
|
def __init__(self, num_layers, num_input_features, bn_size, growth_rate, norm_act_layer=nn.ReLU,
|
||||||
norm_layer=nn.BatchNorm2d, drop_rate=0., memory_efficient=False):
|
drop_rate=0., memory_efficient=False):
|
||||||
super(_DenseBlock, self).__init__()
|
super(DenseBlock, self).__init__()
|
||||||
for i in range(num_layers):
|
for i in range(num_layers):
|
||||||
layer = _DenseLayer(
|
layer = DenseLayer(
|
||||||
num_input_features + i * growth_rate,
|
num_input_features + i * growth_rate,
|
||||||
growth_rate=growth_rate,
|
growth_rate=growth_rate,
|
||||||
bn_size=bn_size,
|
bn_size=bn_size,
|
||||||
act_layer=act_layer,
|
norm_act_layer=norm_act_layer,
|
||||||
norm_layer=norm_layer,
|
|
||||||
drop_rate=drop_rate,
|
drop_rate=drop_rate,
|
||||||
memory_efficient=memory_efficient,
|
memory_efficient=memory_efficient,
|
||||||
)
|
)
|
||||||
@ -132,11 +130,10 @@ class _DenseBlock(nn.ModuleDict):
|
|||||||
return torch.cat(features, 1)
|
return torch.cat(features, 1)
|
||||||
|
|
||||||
|
|
||||||
class _Transition(nn.Sequential):
|
class DenseTransition(nn.Sequential):
|
||||||
def __init__(self, num_input_features, num_output_features, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
def __init__(self, num_input_features, num_output_features, norm_act_layer=nn.BatchNorm2d):
|
||||||
super(_Transition, self).__init__()
|
super(DenseTransition, self).__init__()
|
||||||
self.add_module('norm', norm_layer(num_input_features))
|
self.add_module('norm', norm_act_layer(num_input_features))
|
||||||
self.add_module('relu', act_layer(inplace=True))
|
|
||||||
self.add_module('conv', nn.Conv2d(
|
self.add_module('conv', nn.Conv2d(
|
||||||
num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
|
num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
|
||||||
self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
|
self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
|
||||||
@ -149,7 +146,6 @@ class DenseNet(nn.Module):
|
|||||||
Args:
|
Args:
|
||||||
growth_rate (int) - how many filters to add each layer (`k` in paper)
|
growth_rate (int) - how many filters to add each layer (`k` in paper)
|
||||||
block_config (list of 4 ints) - how many layers in each pooling block
|
block_config (list of 4 ints) - how many layers in each pooling block
|
||||||
num_init_features (int) - the number of filters to learn in the first convolution layer
|
|
||||||
bn_size (int) - multiplicative factor for number of bottle neck layers
|
bn_size (int) - multiplicative factor for number of bottle neck layers
|
||||||
(i.e. bn_size * k features in the bottleneck layer)
|
(i.e. bn_size * k features in the bottleneck layer)
|
||||||
drop_rate (float) - dropout rate after each dense layer
|
drop_rate (float) - dropout rate after each dense layer
|
||||||
@ -158,67 +154,66 @@ class DenseNet(nn.Module):
|
|||||||
but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
|
but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,
|
def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), bn_size=4, stem_type='',
|
||||||
bn_size=4, stem_type='', num_classes=1000, in_chans=3, global_pool='avg',
|
num_classes=1000, in_chans=3, global_pool='avg',
|
||||||
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0, memory_efficient=False):
|
norm_act_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.drop_rate = drop_rate
|
self.drop_rate = drop_rate
|
||||||
deep_stem = 'deep' in stem_type
|
|
||||||
super(DenseNet, self).__init__()
|
super(DenseNet, self).__init__()
|
||||||
|
|
||||||
# First convolution
|
# Stem
|
||||||
|
deep_stem = 'deep' in stem_type # 3x3 deep stem
|
||||||
|
num_init_features = growth_rate * 2
|
||||||
if aa_layer is None:
|
if aa_layer is None:
|
||||||
max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||||
else:
|
else:
|
||||||
max_pool = nn.Sequential(*[
|
stem_pool = nn.Sequential(*[
|
||||||
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
|
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
|
||||||
aa_layer(channels=self.inplanes, stride=2)])
|
aa_layer(channels=num_init_features, stride=2)])
|
||||||
if deep_stem:
|
if deep_stem:
|
||||||
stem_chs_1 = stem_chs_2 = num_init_features // 2
|
stem_chs_1 = stem_chs_2 = growth_rate
|
||||||
if 'tiered' in stem_type:
|
if 'tiered' in stem_type:
|
||||||
stem_chs_1 = 3 * (num_init_features // 8)
|
stem_chs_1 = 3 * (growth_rate // 4)
|
||||||
stem_chs_2 = num_init_features if 'narrow' in stem_type else 6 * (num_init_features // 8)
|
stem_chs_2 = num_init_features if 'narrow' in stem_type else 6 * (growth_rate // 4)
|
||||||
self.features = nn.Sequential(OrderedDict([
|
self.features = nn.Sequential(OrderedDict([
|
||||||
('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False)),
|
('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False)),
|
||||||
('norm0', norm_layer(stem_chs_1)),
|
('norm0', norm_act_layer(stem_chs_1)),
|
||||||
('relu0', act_layer(inplace=True)),
|
|
||||||
('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False)),
|
('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False)),
|
||||||
('norm1', norm_layer(stem_chs_2)),
|
('norm1', norm_act_layer(stem_chs_2)),
|
||||||
('relu1', act_layer(inplace=True)),
|
|
||||||
('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False)),
|
('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False)),
|
||||||
('norm2', norm_layer(num_init_features)),
|
('norm2', norm_act_layer(num_init_features)),
|
||||||
('relu2', act_layer(inplace=True)),
|
('pool0', stem_pool),
|
||||||
('pool0', max_pool),
|
|
||||||
]))
|
]))
|
||||||
else:
|
else:
|
||||||
self.features = nn.Sequential(OrderedDict([
|
self.features = nn.Sequential(OrderedDict([
|
||||||
('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
|
('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
|
||||||
('norm0', norm_layer(num_init_features)),
|
('norm0', norm_act_layer(num_init_features)),
|
||||||
('relu0', act_layer(inplace=True)),
|
('pool0', stem_pool),
|
||||||
('pool0', max_pool),
|
|
||||||
]))
|
]))
|
||||||
|
|
||||||
# Each denseblock
|
# DenseBlocks
|
||||||
num_features = num_init_features
|
num_features = num_init_features
|
||||||
for i, num_layers in enumerate(block_config):
|
for i, num_layers in enumerate(block_config):
|
||||||
block = _DenseBlock(
|
block = DenseBlock(
|
||||||
num_layers=num_layers,
|
num_layers=num_layers,
|
||||||
num_input_features=num_features,
|
num_input_features=num_features,
|
||||||
bn_size=bn_size,
|
bn_size=bn_size,
|
||||||
growth_rate=growth_rate,
|
growth_rate=growth_rate,
|
||||||
|
norm_act_layer=norm_act_layer,
|
||||||
drop_rate=drop_rate,
|
drop_rate=drop_rate,
|
||||||
memory_efficient=memory_efficient
|
memory_efficient=memory_efficient
|
||||||
)
|
)
|
||||||
self.features.add_module('denseblock%d' % (i + 1), block)
|
self.features.add_module('denseblock%d' % (i + 1), block)
|
||||||
num_features = num_features + num_layers * growth_rate
|
num_features = num_features + num_layers * growth_rate
|
||||||
if i != len(block_config) - 1:
|
if i != len(block_config) - 1:
|
||||||
trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
|
trans = DenseTransition(
|
||||||
|
num_input_features=num_features, num_output_features=num_features // 2,
|
||||||
|
norm_act_layer=norm_act_layer)
|
||||||
self.features.add_module('transition%d' % (i + 1), trans)
|
self.features.add_module('transition%d' % (i + 1), trans)
|
||||||
num_features = num_features // 2
|
num_features = num_features // 2
|
||||||
|
|
||||||
# Final batch norm
|
# Final batch norm
|
||||||
self.features.add_module('norm5', norm_layer(num_features))
|
self.features.add_module('norm5', norm_act_layer(num_features))
|
||||||
self.act = act_layer(inplace=True)
|
|
||||||
|
|
||||||
# Linear layer
|
# Linear layer
|
||||||
self.num_features = num_features
|
self.num_features = num_features
|
||||||
@ -248,9 +243,7 @@ class DenseNet(nn.Module):
|
|||||||
self.classifier = nn.Identity()
|
self.classifier = nn.Identity()
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.features(x)
|
return self.features(x)
|
||||||
x = self.act(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.forward_features(x)
|
x = self.forward_features(x)
|
||||||
@ -275,7 +268,7 @@ def _filter_torchvision_pretrained(state_dict):
|
|||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
def _densenet(variant, growth_rate, block_config, num_init_features, pretrained, **kwargs):
|
def _densenet(variant, growth_rate, block_config, pretrained, **kwargs):
|
||||||
if kwargs.pop('features_only', False):
|
if kwargs.pop('features_only', False):
|
||||||
assert False, 'Not Implemented' # TODO
|
assert False, 'Not Implemented' # TODO
|
||||||
load_strict = False
|
load_strict = False
|
||||||
@ -285,8 +278,7 @@ def _densenet(variant, growth_rate, block_config, num_init_features, pretrained,
|
|||||||
load_strict = True
|
load_strict = True
|
||||||
model_class = DenseNet
|
model_class = DenseNet
|
||||||
default_cfg = default_cfgs[variant]
|
default_cfg = default_cfgs[variant]
|
||||||
model = model_class(
|
model = model_class(growth_rate=growth_rate, block_config=block_config, **kwargs)
|
||||||
growth_rate=growth_rate, block_config=block_config, num_init_features=num_init_features, **kwargs)
|
|
||||||
model.default_cfg = default_cfg
|
model.default_cfg = default_cfg
|
||||||
if pretrained:
|
if pretrained:
|
||||||
load_pretrained(
|
load_pretrained(
|
||||||
@ -304,8 +296,7 @@ def densenet121(pretrained=False, **kwargs):
|
|||||||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||||
"""
|
"""
|
||||||
model = _densenet(
|
model = _densenet(
|
||||||
'densenet121', growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,
|
'densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, **kwargs)
|
||||||
pretrained=pretrained, **kwargs)
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@ -315,8 +306,8 @@ def densenet121d(pretrained=False, **kwargs):
|
|||||||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||||
"""
|
"""
|
||||||
model = _densenet(
|
model = _densenet(
|
||||||
'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,
|
'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
|
||||||
stem_type='deep', pretrained=pretrained, **kwargs)
|
pretrained=pretrained, **kwargs)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@ -326,8 +317,42 @@ def densenet121tn(pretrained=False, **kwargs):
|
|||||||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||||
"""
|
"""
|
||||||
model = _densenet(
|
model = _densenet(
|
||||||
'densenet121tn', growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,
|
'densenet121tn', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep_tiered_narrow',
|
||||||
stem_type='deep_tiered_narrow', pretrained=pretrained, **kwargs)
|
pretrained=pretrained, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def densenet121d_evob(pretrained=False, **kwargs):
|
||||||
|
r"""Densenet-121 model from
|
||||||
|
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||||
|
"""
|
||||||
|
model = _densenet(
|
||||||
|
'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
|
||||||
|
norm_act_layer=EvoNormBatch2d, pretrained=pretrained, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def densenet121d_evos(pretrained=False, **kwargs):
|
||||||
|
r"""Densenet-121 model from
|
||||||
|
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||||
|
"""
|
||||||
|
model = _densenet(
|
||||||
|
'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
|
||||||
|
norm_act_layer=EvoNormSample2d, pretrained=pretrained, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def densenet121d_iabn(pretrained=False, **kwargs):
|
||||||
|
r"""Densenet-121 model from
|
||||||
|
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||||
|
"""
|
||||||
|
from inplace_abn import InPlaceABN
|
||||||
|
model = _densenet(
|
||||||
|
'densenet121tn', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
|
||||||
|
norm_act_layer=InPlaceABN, pretrained=pretrained, **kwargs)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@ -337,8 +362,7 @@ def densenet169(pretrained=False, **kwargs):
|
|||||||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||||
"""
|
"""
|
||||||
model = _densenet(
|
model = _densenet(
|
||||||
'densenet169', growth_rate=32, block_config=(6, 12, 32, 32), num_init_features=64,
|
'densenet169', growth_rate=32, block_config=(6, 12, 32, 32), pretrained=pretrained, **kwargs)
|
||||||
pretrained=pretrained, **kwargs)
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@ -348,17 +372,25 @@ def densenet201(pretrained=False, **kwargs):
|
|||||||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||||
"""
|
"""
|
||||||
model = _densenet(
|
model = _densenet(
|
||||||
'densenet201', growth_rate=32, block_config=(6, 12, 48, 32), num_init_features=64,
|
'densenet201', growth_rate=32, block_config=(6, 12, 48, 32), pretrained=pretrained, **kwargs)
|
||||||
pretrained=pretrained, **kwargs)
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def densenet161(pretrained=False, **kwargs):
|
def densenet161(pretrained=False, **kwargs):
|
||||||
r"""Densenet-201 model from
|
r"""Densenet-161 model from
|
||||||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||||
"""
|
"""
|
||||||
model = _densenet(
|
model = _densenet(
|
||||||
'densenet161', growth_rate=48, block_config=(6, 12, 36, 24), num_init_features=96,
|
'densenet161', growth_rate=48, block_config=(6, 12, 36, 24), pretrained=pretrained, **kwargs)
|
||||||
pretrained=pretrained, **kwargs)
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def densenet264(pretrained=False, **kwargs):
|
||||||
|
r"""Densenet-264 model from
|
||||||
|
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||||
|
"""
|
||||||
|
model = _densenet(
|
||||||
|
'densenet264', growth_rate=48, block_config=(6, 12, 64, 48), pretrained=pretrained, **kwargs)
|
||||||
return model
|
return model
|
||||||
|
@ -19,3 +19,5 @@ from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
|
|||||||
from .anti_aliasing import AntiAliasDownsampleLayer
|
from .anti_aliasing import AntiAliasDownsampleLayer
|
||||||
from .space_to_depth import SpaceToDepthModule
|
from .space_to_depth import SpaceToDepthModule
|
||||||
from .blur_pool import BlurPool2d
|
from .blur_pool import BlurPool2d
|
||||||
|
from .norm_act import BatchNormAct2d
|
||||||
|
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
|
134
timm/models/layers/evo_norm.py
Normal file
134
timm/models/layers/evo_norm.py
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
"""EvoNormB0 (Batched) and EvoNormS0 (Sample) in PyTorch
|
||||||
|
|
||||||
|
An attempt at getting decent performing EvoNorms running in PyTorch.
|
||||||
|
While currently faster than other impl, still quite a ways off the built-in BN
|
||||||
|
in terms of memory usage and throughput.
|
||||||
|
|
||||||
|
Still very much a WIP, fiddling with buffer usage, in-place optimizations, and layouts.
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
def evo_batch_jit(
        x: torch.Tensor, v: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, running_var: torch.Tensor,
        momentum: float, training: bool, nonlin: bool, eps: float):
    """Functional EvoNorm-B0 (batch variant), scripted.

    Args:
        x: input tensor; reduced over dims (0, 2, 3) and (2, 3), so it must be 4D.
        v: learnable scale applied to ``x`` inside the denominator (only read when ``nonlin``).
        weight: multiplicative affine term, broadcast against ``x``.
        bias: additive affine term, broadcast against ``x``.
        running_var: variance buffer; updated in place when ``training`` is True.
        momentum: weight of the current batch variance in the running average.
        training: use (and record) batch statistics if True, the running variance otherwise.
        nonlin: apply the normalisation/non-linearity; if False only the affine transform is applied.
        eps: value added to variances before the square root.

    Returns:
        A new tensor; ``x`` itself is never modified.
    """
    x_type = x.dtype
    # Write through a detached alias so the buffer itself never becomes part of the autograd graph.
    running_var = running_var.detach()
    if training:
        var = x.var(dim=[0, 2, 3], unbiased=False, keepdim=True)  # FIXME biased, unbiased?
        # The running statistic is bookkeeping only: feed it the detached batch variance so no
        # graph is built for (or kept alive by) the update.
        running_var.copy_(momentum * var.detach() + (1 - momentum) * running_var)
    else:
        # Only read below (``var.add`` is out-of-place), so no defensive copy is required.
        var = running_var

    if nonlin:
        # FIXME biased, unbiased?
        # d = max(v * x + instance_std(x), batch_std(x))
        d = (x * v.to(x_type)) + x.var(dim=[2, 3], unbiased=False, keepdim=True).add_(eps).sqrt_().to(dtype=x_type)
        d = d.max(var.add(eps).sqrt_().to(dtype=x_type))
        x = x / d
        # ``x`` is a fresh tensor at this point, so the in-place affine is safe.
        return x.mul_(weight).add_(bias)
    else:
        return x.mul(weight).add_(bias)
|
||||||
|
|
||||||
|
|
||||||
|
class EvoNormBatch2d(nn.Module):
    """EvoNorm-B0: a combined normalisation + activation layer using batch statistics.

    Args:
        num_features: number of channels; parameters are shaped ``(1, num_features, 1, 1)``.
        momentum: weight of the current batch variance in the running average.
        nonlin: apply the normalisation/non-linearity. When False the layer is a per-channel
            affine transform and no ``v`` parameter is created.
        eps: value added to variances before the square root.
        jit: dispatch to the scripted ``evo_batch_jit`` function instead of the eager code path.
    """

    def __init__(self, num_features, momentum=0.1, nonlin=True, eps=1e-5, jit=True):
        super(EvoNormBatch2d, self).__init__()
        self.momentum = momentum
        self.nonlin = nonlin
        self.eps = eps
        self.jit = jit
        param_shape = (1, num_features, 1, 1)
        self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
        if nonlin:
            self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True)
        else:
            # Keep the attribute defined (as None) so forward() can always read it; a None
            # parameter is not added to the state_dict, so checkpoints are unaffected.
            self.register_parameter('v', None)
        self.register_buffer('running_var', torch.ones(param_shape))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset weight/v to one and bias to zero."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.v is not None:
            nn.init.ones_(self.v)

    def forward(self, x):
        assert x.dim() == 4, 'expected 4D input'

        if self.jit:
            # The scripted function requires a Tensor for ``v`` even when it is not read
            # (nonlin=False); pass ``weight`` as an unused stand-in in that case.
            v = self.v if self.v is not None else self.weight
            return evo_batch_jit(
                x, v, self.weight, self.bias, self.running_var, self.momentum,
                self.training, self.nonlin, self.eps)

        x_type = x.dtype
        if self.training:
            # Biased variance, matching the scripted code path.
            var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
            # Update the buffer outside autograd; otherwise the buffer would become a
            # non-leaf tensor that chains the graphs of successive iterations together.
            with torch.no_grad():
                self.running_var.copy_(self.momentum * var + (1 - self.momentum) * self.running_var)
        else:
            # Only read below (``var.add`` is out-of-place).
            var = self.running_var

        if not self.nonlin:
            return x.mul(self.weight).add_(self.bias)

        # d = max(v * x + instance_std(x), batch_std(x))
        v = self.v.to(dtype=x_type)
        d = (x * v) + x.var(dim=(2, 3), unbiased=False, keepdim=True).add_(self.eps).sqrt_().to(dtype=x_type)
        d = d.max(var.add(self.eps).sqrt_().to(dtype=x_type))
        x = x / d
        # ``x`` is a fresh tensor here, so the in-place affine does not touch the caller's input.
        return x.mul_(self.weight).add_(self.bias)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
def evo_sample_jit(
        x: torch.Tensor, v: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
        groups: int, nonlin: bool, eps: float):
    """Functional EvoNorm-S0 (sample variant), scripted.

    Computes ``x * sigmoid(v * x) / group_std(x) * weight + bias`` when ``nonlin`` is True,
    and the plain affine ``x * weight + bias`` otherwise.

    Args:
        x: input tensor of shape (B, C, H, W); C must be divisible by ``groups``.
        v: learnable scale inside the sigmoid (only read when ``nonlin``).
        weight: multiplicative affine term, broadcast against ``x``.
        bias: additive affine term, broadcast against ``x``.
        groups: number of channel groups the standard deviation is computed over.
        nonlin: apply the normalisation/non-linearity.
        eps: value added to the group variance before the square root.

    Returns:
        A new tensor; ``x`` itself is never modified.
    """
    B, C, H, W = x.shape
    assert C % groups == 0
    if nonlin:
        # Numerator is x * sigmoid(v * x); the sigmoid runs in place on the temporary product.
        n = (x * (x * v).sigmoid_()).reshape(B, groups, -1)
        group_std = x.reshape(B, groups, -1).var(dim=-1, unbiased=False, keepdim=True).add_(eps).sqrt_()
        x = (n / group_std).reshape(B, C, H, W)
        # ``x`` now refers to a fresh tensor, so the in-place affine is safe.
        return x.mul_(weight).add_(bias)
    else:
        # Out-of-place multiply: the caller's tensor must not be overwritten.
        return x.mul(weight).add_(bias)
|
||||||
|
|
||||||
|
|
||||||
|
class EvoNormSample2d(nn.Module):
    """EvoNorm-S0: a combined normalisation + activation layer using per-sample group statistics.

    Args:
        num_features: number of channels; parameters are shaped ``(1, num_features, 1, 1)``.
        nonlin: apply the normalisation/non-linearity. When False the layer is a per-channel
            affine transform and no ``v`` parameter is created.
        groups: number of channel groups; the input channel count must be divisible by it.
        eps: value added to the group variance before the square root.
        jit: dispatch to the scripted ``evo_sample_jit`` function instead of the eager code path.
    """

    def __init__(self, num_features, nonlin=True, groups=8, eps=1e-5, jit=True):
        super(EvoNormSample2d, self).__init__()
        self.nonlin = nonlin
        self.groups = groups
        self.eps = eps
        self.jit = jit
        param_shape = (1, num_features, 1, 1)
        self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
        if nonlin:
            self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True)
        else:
            # Keep the attribute defined (as None) so forward() can always read it; a None
            # parameter is not added to the state_dict, so checkpoints are unaffected.
            self.register_parameter('v', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset weight/v to one and bias to zero."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.v is not None:
            nn.init.ones_(self.v)

    def forward(self, x):
        assert x.dim() == 4, 'expected 4D input'

        if self.jit:
            # The scripted function requires a Tensor for ``v`` even when it is not read
            # (nonlin=False); pass ``weight`` as an unused stand-in in that case.
            v = self.v if self.v is not None else self.weight
            return evo_sample_jit(
                x, v, self.weight, self.bias, self.groups, self.nonlin, self.eps)

        B, C, H, W = x.shape
        assert C % self.groups == 0
        if not self.nonlin:
            return x.mul(self.weight).add_(self.bias)

        # x * sigmoid(v * x) / sqrt(group_var(x) + eps), with eps inside the square root
        # so the eager and scripted code paths produce the same values.
        n = (x * (x * self.v).sigmoid()).reshape(B, self.groups, -1)
        group_std = x.reshape(B, self.groups, -1).var(dim=-1, unbiased=False, keepdim=True).add_(self.eps).sqrt_()
        x = (n / group_std).reshape(B, C, H, W)
        # ``x`` is a fresh tensor here, so the in-place affine does not touch the caller's input.
        return x.mul_(self.weight).add_(self.bias)
|
50
timm/models/layers/norm_act.py
Normal file
50
timm/models/layers/norm_act.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
""" Normalization + Activation Layers
|
||||||
|
"""
|
||||||
|
from torch import nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class BatchNormAct2d(nn.BatchNorm2d):
    """BatchNorm + Activation

    This module performs BatchNorm + Activation in a manner that will remain backwards
    compatible with weights trained with separate bn, act. This is why we inherit from BN
    instead of composing it as a .bn member.

    Args:
        num_features, eps, momentum, affine, track_running_stats: forwarded unchanged to
            ``nn.BatchNorm2d``.
        act_layer: activation class, constructed as ``act_layer(inplace=inplace)``. Pass
            ``None`` to apply no activation (normalisation only).
        inplace: forwarded to the activation constructor.
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, act_layer=nn.ReLU, inplace=True):
        super(BatchNormAct2d, self).__init__(num_features, eps, momentum, affine, track_running_stats)
        if act_layer is None:
            # Norm-only mode: keep ``self.act`` callable so forward() needs no branch.
            self.act = nn.Identity()
        else:
            self.act = act_layer(inplace=inplace)

    def forward(self, x):
        # FIXME cannot call parent forward() and maintain jit.script compatibility?
        # x = super(BatchNormAct2d, self).forward(x)

        # BEGIN nn.BatchNorm2d forward() cut & paste
        # self._check_input_dim(x)

        # exponential_average_factor is self.momentum set to
        # (when it is available) only so that if gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        x = F.batch_norm(
            x, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)
        # END BatchNorm2d forward()

        x = self.act(x)
        return x
|
Loading…
x
Reference in New Issue
Block a user