mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add support for tflite mnasnet pretrained weights and included spnasnet pretrained weights of my own.
* tensorflow 'SAME' padding support added to GenMobileNet models for tflite pretrained weights * folded batch norm support (made batch norm optional and enable conv bias) for tflite pretrained weights * add url for spnasnet1_00 weights that I recently trained * fix SE reduction size for semnasnet models
This commit is contained in:
parent
afb357ff68
commit
4663fc2132
39
models/conv2d_same.py
Normal file
39
models/conv2d_same.py
Normal file
@ -0,0 +1,39 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
class Conv2dSame(nn.Conv2d):
|
||||
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
||||
padding=0, dilation=1, groups=1, bias=True):
|
||||
super(Conv2dSame, self).__init__(
|
||||
in_channels, out_channels, kernel_size, stride, 0, dilation,
|
||||
groups, bias)
|
||||
|
||||
def forward(self, x):
|
||||
ih, iw = x.size()[-2:]
|
||||
kh, kw = self.weight.size()[-2:]
|
||||
oh = math.ceil(ih / self.stride[0])
|
||||
ow = math.ceil(iw / self.stride[1])
|
||||
pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
|
||||
pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
x = F.pad(x, [pad_w//2, pad_w - pad_w//2, pad_h//2, pad_h - pad_h//2])
|
||||
return F.conv2d(x, self.weight, self.bias, self.stride,
|
||||
self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
# helper method
|
||||
def sconv2d(in_chs, out_chs, kernel_size, **kwargs):
|
||||
padding = kwargs.pop('padding', 0)
|
||||
if isinstance(padding, str):
|
||||
if padding.lower() == 'same':
|
||||
return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
|
||||
else:
|
||||
# 'valid'
|
||||
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=0, **kwargs)
|
||||
else:
|
||||
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
|
@ -23,6 +23,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from models.helpers import load_pretrained
|
||||
from models.adaptive_avgmax_pool import SelectAdaptivePool2d
|
||||
from models.conv2d_same import sconv2d
|
||||
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
__all__ = ['GenMobileNet', 'mnasnet0_50', 'mnasnet0_75', 'mnasnet1_00', 'mnasnet1_40',
|
||||
@ -45,10 +46,12 @@ default_cfgs = {
|
||||
'mnasnet0_50': _cfg(url=''),
|
||||
'mnasnet0_75': _cfg(url=''),
|
||||
'mnasnet1_00': _cfg(url=''),
|
||||
'tflite_mnasnet1_00': _cfg(url='', interpolation='bicubic'),
|
||||
'mnasnet1_40': _cfg(url=''),
|
||||
'semnasnet0_50': _cfg(url=''),
|
||||
'semnasnet0_75': _cfg(url=''),
|
||||
'semnasnet1_00': _cfg(url=''),
|
||||
'tflite_semnasnet1_00': _cfg(url='', interpolation='bicubic'),
|
||||
'semnasnet1_40': _cfg(url=''),
|
||||
'mnasnet_small': _cfg(url=''),
|
||||
'mobilenetv1_1_00': _cfg(url=''),
|
||||
@ -56,7 +59,7 @@ default_cfgs = {
|
||||
'chamnetv1_1_00': _cfg(url=''),
|
||||
'chamnetv2_1_00': _cfg(url=''),
|
||||
'fbnetc_1_00': _cfg(url=''),
|
||||
'spnasnet1_00': _cfg(url=''),
|
||||
'spnasnet1_00': _cfg(url='https://www.dropbox.com/s/iieopt18rytkgaa/spnasnet1_00-048bc3f4.pth?dl=1'),
|
||||
}
|
||||
|
||||
_DEBUG = True
|
||||
@ -184,11 +187,15 @@ def _decode_block_str(block_str):
|
||||
return [deepcopy(block_args) for _ in range(num_repeat)]
|
||||
|
||||
|
||||
def _get_padding(kernel_size, stride, dilation):
|
||||
def _get_padding(kernel_size, stride, dilation=1):
|
||||
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
|
||||
return padding
|
||||
|
||||
|
||||
def _padding_arg(default, padding_same=False):
|
||||
return 'SAME' if padding_same else default
|
||||
|
||||
|
||||
def _decode_arch_args(string_list):
|
||||
block_args = []
|
||||
for block_str in string_list:
|
||||
@ -219,12 +226,15 @@ class _BlockBuilder:
|
||||
"""
|
||||
|
||||
def __init__(self, depth_multiplier=1.0, depth_divisor=8, min_depth=None,
|
||||
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT):
|
||||
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
|
||||
folded_bn=False, padding_same=False):
|
||||
self.depth_multiplier = depth_multiplier
|
||||
self.depth_divisor = depth_divisor
|
||||
self.min_depth = min_depth
|
||||
self.bn_momentum = bn_momentum
|
||||
self.bn_eps = bn_eps
|
||||
self.folded_bn = folded_bn
|
||||
self.padding_same = padding_same
|
||||
self.in_chs = None
|
||||
|
||||
def _round_channels(self, chs):
|
||||
@ -236,6 +246,8 @@ class _BlockBuilder:
|
||||
ba['out_chs'] = _round_channels(ba['out_chs'])
|
||||
ba['bn_momentum'] = self.bn_momentum
|
||||
ba['bn_eps'] = self.bn_eps
|
||||
ba['folded_bn'] = self.folded_bn
|
||||
ba['padding_same'] = self.padding_same
|
||||
if _DEBUG:
|
||||
print('args:', ba)
|
||||
# could replace this with lambdas or functools binding if variety increases
|
||||
@ -320,29 +332,37 @@ def _initialize_weight_default(m):
|
||||
class DepthwiseSeparableConv(nn.Module):
|
||||
def __init__(self, in_chs, out_chs, kernel_size,
|
||||
stride=1, act_fn=F.relu, noskip=False, pw_act=False,
|
||||
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT):
|
||||
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
|
||||
folded_bn=False, padding_same=False):
|
||||
super(DepthwiseSeparableConv, self).__init__()
|
||||
assert stride in [1, 2]
|
||||
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
|
||||
self.has_pw_act = pw_act # activation after point-wise conv
|
||||
self.act_fn = act_fn
|
||||
dw_padding = _padding_arg(kernel_size // 2, padding_same)
|
||||
pw_padding = _padding_arg(0, padding_same)
|
||||
|
||||
self.conv_dw = nn.Conv2d(
|
||||
self.conv_dw = sconv2d(
|
||||
in_chs, in_chs, kernel_size,
|
||||
stride=stride, padding=kernel_size // 2, groups=in_chs, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
|
||||
self.conv_pw = nn.Conv2d(in_chs, out_chs, 1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
|
||||
stride=stride, padding=dw_padding, groups=in_chs, bias=folded_bn)
|
||||
self.bn1 = None if folded_bn else nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
|
||||
self.conv_pw = sconv2d(in_chs, out_chs, 1, padding=pw_padding, bias=folded_bn)
|
||||
self.bn2 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
x = self.conv_dw(x)
|
||||
x = self.bn1(x)
|
||||
if self.bn1 is not None:
|
||||
x = self.bn1(x)
|
||||
x = self.act_fn(x)
|
||||
|
||||
x = self.conv_pw(x)
|
||||
x = self.bn2(x)
|
||||
if self.bn2 is not None:
|
||||
x = self.bn2(x)
|
||||
if self.has_pw_act:
|
||||
x = self.act_fn(x)
|
||||
|
||||
if self.has_residual:
|
||||
x += residual
|
||||
return x
|
||||
@ -351,24 +371,28 @@ class DepthwiseSeparableConv(nn.Module):
|
||||
class CascadeConv3x3(nn.Sequential):
|
||||
# FIXME lifted from maskrcnn_benchmark blocks, haven't used yet
|
||||
def __init__(self, in_chs, out_chs, stride, act_fn=F.relu, noskip=False,
|
||||
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT):
|
||||
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
|
||||
folded_bn=False, padding_same=False):
|
||||
super(CascadeConv3x3, self).__init__()
|
||||
assert stride in [1, 2]
|
||||
self.has_residual = not noskip and (stride == 1 and in_chs == out_chs)
|
||||
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
|
||||
self.act_fn = act_fn
|
||||
padding = _padding_arg(1, padding_same)
|
||||
|
||||
self.conv1 = nn.Conv2d(in_chs, in_chs, 3, stride=stride, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
|
||||
self.conv2 = nn.Conv2d(in_chs, out_chs, 3, stride=1, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
|
||||
self.conv1 = sconv2d(in_chs, in_chs, 3, stride=stride, padding=padding, bias=folded_bn)
|
||||
self.bn1 = None if folded_bn else nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
|
||||
self.conv2 = sconv2d(in_chs, out_chs, 3, stride=1, padding=padding, bias=folded_bn)
|
||||
self.bn2 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
if self.bn1 is not None:
|
||||
x = self.bn1(x)
|
||||
x = self.act_fn(x)
|
||||
x = self.conv2(x)
|
||||
x = self.bn2(x)
|
||||
if self.bn2 is not None:
|
||||
x = self.bn2(x)
|
||||
if self.has_residual:
|
||||
x += residual
|
||||
return x
|
||||
@ -396,10 +420,10 @@ class ChannelShuffle(nn.Module):
|
||||
|
||||
|
||||
class SqueezeExcite(nn.Module):
|
||||
def __init__(self, in_chs, se_ratio=0.25, act_fn=F.relu):
|
||||
def __init__(self, in_chs, reduce_chs=None, act_fn=F.relu):
|
||||
super(SqueezeExcite, self).__init__()
|
||||
self.act_fn = act_fn
|
||||
reduced_chs = max(1, int(in_chs * se_ratio))
|
||||
reduced_chs = reduce_chs or in_chs
|
||||
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
|
||||
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
|
||||
|
||||
@ -419,41 +443,44 @@ class InvertedResidual(nn.Module):
|
||||
def __init__(self, in_chs, out_chs, kernel_size,
|
||||
stride=1, act_fn=F.relu, exp_ratio=1.0, noskip=False,
|
||||
se_ratio=0., shuffle_type=None, pw_group=1,
|
||||
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT):
|
||||
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
|
||||
folded_bn=False, padding_same=False):
|
||||
super(InvertedResidual, self).__init__()
|
||||
mid_chs = int(in_chs * exp_ratio)
|
||||
self.has_se = se_ratio is not None and se_ratio > 0.
|
||||
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
|
||||
self.act_fn = act_fn
|
||||
dw_padding = _padding_arg(kernel_size // 2, padding_same)
|
||||
pw_padding = _padding_arg(0, padding_same)
|
||||
|
||||
# Point-wise expansion
|
||||
self.conv_pw = nn.Conv2d(in_chs, mid_chs, 1, groups=pw_group, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
|
||||
self.conv_pw = sconv2d(in_chs, mid_chs, 1, padding=pw_padding, groups=pw_group, bias=folded_bn)
|
||||
self.bn1 = None if folded_bn else nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
|
||||
|
||||
self.shuffle_type = shuffle_type
|
||||
if shuffle_type is not None:
|
||||
self.shuffle = ChannelShuffle(pw_group)
|
||||
|
||||
# Depth-wise convolution
|
||||
self.conv_dw = nn.Conv2d(
|
||||
mid_chs, mid_chs, kernel_size, padding=kernel_size // 2,
|
||||
stride=stride, groups=mid_chs, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
|
||||
self.conv_dw = sconv2d(
|
||||
mid_chs, mid_chs, kernel_size, padding=dw_padding, stride=stride, groups=mid_chs, bias=folded_bn)
|
||||
self.bn2 = None if folded_bn else nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
if self.has_se:
|
||||
self.se = SqueezeExcite(mid_chs, se_ratio)
|
||||
self.se = SqueezeExcite(mid_chs, reduce_chs=max(1, int(in_chs * se_ratio)))
|
||||
|
||||
# Point-wise linear projection
|
||||
self.conv_pwl = nn.Conv2d(mid_chs, out_chs, 1, groups=pw_group, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
|
||||
self.conv_pwl = sconv2d(mid_chs, out_chs, 1, padding=pw_padding, groups=pw_group, bias=folded_bn)
|
||||
self.bn3 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
# Point-wise expansion
|
||||
x = self.conv_pw(x)
|
||||
x = self.bn1(x)
|
||||
if self.bn1 is not None:
|
||||
x = self.bn1(x)
|
||||
x = self.act_fn(x)
|
||||
|
||||
# FIXME haven't tried this yet
|
||||
@ -463,7 +490,8 @@ class InvertedResidual(nn.Module):
|
||||
|
||||
# Depth-wise convolution
|
||||
x = self.conv_dw(x)
|
||||
x = self.bn2(x)
|
||||
if self.bn2 is not None:
|
||||
x = self.bn2(x)
|
||||
x = self.act_fn(x)
|
||||
|
||||
# Squeeze-and-excitation
|
||||
@ -472,7 +500,8 @@ class InvertedResidual(nn.Module):
|
||||
|
||||
# Point-wise linear projection
|
||||
x = self.conv_pwl(x)
|
||||
x = self.bn3(x)
|
||||
if self.bn3 is not None:
|
||||
x = self.bn3(x)
|
||||
|
||||
if self.has_residual:
|
||||
x += residual
|
||||
@ -498,7 +527,7 @@ class GenMobileNet(nn.Module):
|
||||
depth_multiplier=1.0, depth_divisor=8, min_depth=None,
|
||||
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
|
||||
drop_rate=0., act_fn=F.relu, global_pool='avg', skip_head_conv=False,
|
||||
weight_init='goog'):
|
||||
weight_init='goog', folded_bn=False, padding_same=False):
|
||||
super(GenMobileNet, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.depth_multiplier = depth_multiplier
|
||||
@ -507,13 +536,15 @@ class GenMobileNet(nn.Module):
|
||||
self.num_features = num_features
|
||||
|
||||
stem_size = _round_channels(stem_size, depth_multiplier, depth_divisor, min_depth)
|
||||
self.conv_stem = nn.Conv2d(in_chans, stem_size, 3, padding=1, stride=2, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(stem_size, momentum=bn_momentum, eps=bn_eps)
|
||||
self.conv_stem = sconv2d(
|
||||
in_chans, stem_size, 3,
|
||||
padding=_padding_arg(1, padding_same), stride=2, bias=folded_bn)
|
||||
self.bn1 = None if folded_bn else nn.BatchNorm2d(stem_size, momentum=bn_momentum, eps=bn_eps)
|
||||
in_chs = stem_size
|
||||
|
||||
builder = _BlockBuilder(
|
||||
depth_multiplier, depth_divisor, min_depth,
|
||||
bn_momentum, bn_eps)
|
||||
bn_momentum, bn_eps, folded_bn, padding_same)
|
||||
self.blocks = nn.Sequential(*builder(in_chs, block_args))
|
||||
in_chs = builder.in_chs
|
||||
|
||||
@ -521,8 +552,10 @@ class GenMobileNet(nn.Module):
|
||||
self.conv_head = None
|
||||
assert in_chs == self.num_features
|
||||
else:
|
||||
self.conv_head = nn.Conv2d(in_chs, self.num_features, 1, padding=0, stride=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps)
|
||||
self.conv_head = sconv2d(
|
||||
in_chs, self.num_features, 1,
|
||||
padding=_padding_arg(0, padding_same), bias=folded_bn)
|
||||
self.bn2 = None if folded_bn else nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps)
|
||||
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
self.classifier = nn.Linear(self.num_features, self.num_classes)
|
||||
@ -548,12 +581,14 @@ class GenMobileNet(nn.Module):
|
||||
|
||||
def forward_features(self, x, pool=True):
|
||||
x = self.conv_stem(x)
|
||||
x = self.bn1(x)
|
||||
if self.bn1 is not None:
|
||||
x = self.bn1(x)
|
||||
x = self.act_fn(x)
|
||||
x = self.blocks(x)
|
||||
if self.conv_head is not None:
|
||||
x = self.conv_head(x)
|
||||
x = self.bn2(x)
|
||||
if self.bn2 is not None:
|
||||
x = self.bn2(x)
|
||||
x = self.act_fn(x)
|
||||
if pool:
|
||||
x = self.global_pool(x)
|
||||
@ -909,6 +944,19 @@ def mnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
def tflite_mnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||
""" MNASNet B1, depth multiplier of 1.0. """
|
||||
default_cfg = default_cfgs['tflite_mnasnet1_00']
|
||||
# these two args are for compat with tflite pretrained weights
|
||||
kwargs['folded_bn'] = True
|
||||
kwargs['padding_same'] = True
|
||||
model = _gen_mnasnet_b1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
def mnasnet1_40(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||
""" MNASNet B1, depth multiplier of 1.4 """
|
||||
default_cfg = default_cfgs['mnasnet1_40']
|
||||
@ -949,6 +997,19 @@ def semnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
def tflite_semnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||
""" MNASNet A1, depth multiplier of 1.0. """
|
||||
default_cfg = default_cfgs['tflite_semnasnet1_00']
|
||||
# these two args are for compat with tflite pretrained weights
|
||||
kwargs['folded_bn'] = True
|
||||
kwargs['padding_same'] = True
|
||||
model = _gen_mnasnet_a1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
def semnasnet1_40(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||
""" MNASNet A1 (w/ SE), depth multiplier of 1.4. """
|
||||
default_cfg = default_cfgs['semnasnet1_40']
|
||||
|
@ -9,8 +9,8 @@ from models.senet import seresnet18, seresnet34, seresnet50, seresnet101, seresn
|
||||
from models.xception import xception
|
||||
from models.pnasnet import pnasnet5large
|
||||
from models.genmobilenet import \
|
||||
mnasnet0_50, mnasnet0_75, mnasnet1_00, mnasnet1_40,\
|
||||
semnasnet0_50, semnasnet0_75, semnasnet1_00, semnasnet1_40, mnasnet_small,\
|
||||
mnasnet0_50, mnasnet0_75, mnasnet1_00, mnasnet1_40, tflite_mnasnet1_00,\
|
||||
semnasnet0_50, semnasnet0_75, semnasnet1_00, semnasnet1_40, tflite_semnasnet1_00, mnasnet_small,\
|
||||
mobilenetv1_1_00, mobilenetv2_1_00, fbnetc_1_00, chamnetv1_1_00, chamnetv2_1_00,\
|
||||
spnasnet1_00
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user