mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add support for tflite mnasnet pretrained weights and included spnasnet pretrained weights of my own.
* tensorflow 'SAME' padding support added to GenMobileNet models for tflite pretrained weights * folded batch norm support (made batch norm optional and enable conv bias) for tflite pretrained weights * add url for spnasnet1_00 weights that I recently trained * fix SE reduction size for semnasnet models
This commit is contained in:
parent
afb357ff68
commit
4663fc2132
39
models/conv2d_same.py
Normal file
39
models/conv2d_same.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
class Conv2dSame(nn.Conv2d):
|
||||||
|
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
|
||||||
|
"""
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
||||||
|
padding=0, dilation=1, groups=1, bias=True):
|
||||||
|
super(Conv2dSame, self).__init__(
|
||||||
|
in_channels, out_channels, kernel_size, stride, 0, dilation,
|
||||||
|
groups, bias)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
ih, iw = x.size()[-2:]
|
||||||
|
kh, kw = self.weight.size()[-2:]
|
||||||
|
oh = math.ceil(ih / self.stride[0])
|
||||||
|
ow = math.ceil(iw / self.stride[1])
|
||||||
|
pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
|
||||||
|
pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
|
||||||
|
if pad_h > 0 or pad_w > 0:
|
||||||
|
x = F.pad(x, [pad_w//2, pad_w - pad_w//2, pad_h//2, pad_h - pad_h//2])
|
||||||
|
return F.conv2d(x, self.weight, self.bias, self.stride,
|
||||||
|
self.padding, self.dilation, self.groups)
|
||||||
|
|
||||||
|
|
||||||
|
# helper method
|
||||||
|
def sconv2d(in_chs, out_chs, kernel_size, **kwargs):
|
||||||
|
padding = kwargs.pop('padding', 0)
|
||||||
|
if isinstance(padding, str):
|
||||||
|
if padding.lower() == 'same':
|
||||||
|
return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
|
||||||
|
else:
|
||||||
|
# 'valid'
|
||||||
|
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=0, **kwargs)
|
||||||
|
else:
|
||||||
|
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
|
@ -23,6 +23,7 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from models.helpers import load_pretrained
|
from models.helpers import load_pretrained
|
||||||
from models.adaptive_avgmax_pool import SelectAdaptivePool2d
|
from models.adaptive_avgmax_pool import SelectAdaptivePool2d
|
||||||
|
from models.conv2d_same import sconv2d
|
||||||
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
|
|
||||||
__all__ = ['GenMobileNet', 'mnasnet0_50', 'mnasnet0_75', 'mnasnet1_00', 'mnasnet1_40',
|
__all__ = ['GenMobileNet', 'mnasnet0_50', 'mnasnet0_75', 'mnasnet1_00', 'mnasnet1_40',
|
||||||
@ -45,10 +46,12 @@ default_cfgs = {
|
|||||||
'mnasnet0_50': _cfg(url=''),
|
'mnasnet0_50': _cfg(url=''),
|
||||||
'mnasnet0_75': _cfg(url=''),
|
'mnasnet0_75': _cfg(url=''),
|
||||||
'mnasnet1_00': _cfg(url=''),
|
'mnasnet1_00': _cfg(url=''),
|
||||||
|
'tflite_mnasnet1_00': _cfg(url='', interpolation='bicubic'),
|
||||||
'mnasnet1_40': _cfg(url=''),
|
'mnasnet1_40': _cfg(url=''),
|
||||||
'semnasnet0_50': _cfg(url=''),
|
'semnasnet0_50': _cfg(url=''),
|
||||||
'semnasnet0_75': _cfg(url=''),
|
'semnasnet0_75': _cfg(url=''),
|
||||||
'semnasnet1_00': _cfg(url=''),
|
'semnasnet1_00': _cfg(url=''),
|
||||||
|
'tflite_semnasnet1_00': _cfg(url='', interpolation='bicubic'),
|
||||||
'semnasnet1_40': _cfg(url=''),
|
'semnasnet1_40': _cfg(url=''),
|
||||||
'mnasnet_small': _cfg(url=''),
|
'mnasnet_small': _cfg(url=''),
|
||||||
'mobilenetv1_1_00': _cfg(url=''),
|
'mobilenetv1_1_00': _cfg(url=''),
|
||||||
@ -56,7 +59,7 @@ default_cfgs = {
|
|||||||
'chamnetv1_1_00': _cfg(url=''),
|
'chamnetv1_1_00': _cfg(url=''),
|
||||||
'chamnetv2_1_00': _cfg(url=''),
|
'chamnetv2_1_00': _cfg(url=''),
|
||||||
'fbnetc_1_00': _cfg(url=''),
|
'fbnetc_1_00': _cfg(url=''),
|
||||||
'spnasnet1_00': _cfg(url=''),
|
'spnasnet1_00': _cfg(url='https://www.dropbox.com/s/iieopt18rytkgaa/spnasnet1_00-048bc3f4.pth?dl=1'),
|
||||||
}
|
}
|
||||||
|
|
||||||
_DEBUG = True
|
_DEBUG = True
|
||||||
@ -184,11 +187,15 @@ def _decode_block_str(block_str):
|
|||||||
return [deepcopy(block_args) for _ in range(num_repeat)]
|
return [deepcopy(block_args) for _ in range(num_repeat)]
|
||||||
|
|
||||||
|
|
||||||
def _get_padding(kernel_size, stride, dilation):
|
def _get_padding(kernel_size, stride, dilation=1):
|
||||||
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
|
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
|
||||||
return padding
|
return padding
|
||||||
|
|
||||||
|
|
||||||
|
def _padding_arg(default, padding_same=False):
|
||||||
|
return 'SAME' if padding_same else default
|
||||||
|
|
||||||
|
|
||||||
def _decode_arch_args(string_list):
|
def _decode_arch_args(string_list):
|
||||||
block_args = []
|
block_args = []
|
||||||
for block_str in string_list:
|
for block_str in string_list:
|
||||||
@ -219,12 +226,15 @@ class _BlockBuilder:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, depth_multiplier=1.0, depth_divisor=8, min_depth=None,
|
def __init__(self, depth_multiplier=1.0, depth_divisor=8, min_depth=None,
|
||||||
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT):
|
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
|
||||||
|
folded_bn=False, padding_same=False):
|
||||||
self.depth_multiplier = depth_multiplier
|
self.depth_multiplier = depth_multiplier
|
||||||
self.depth_divisor = depth_divisor
|
self.depth_divisor = depth_divisor
|
||||||
self.min_depth = min_depth
|
self.min_depth = min_depth
|
||||||
self.bn_momentum = bn_momentum
|
self.bn_momentum = bn_momentum
|
||||||
self.bn_eps = bn_eps
|
self.bn_eps = bn_eps
|
||||||
|
self.folded_bn = folded_bn
|
||||||
|
self.padding_same = padding_same
|
||||||
self.in_chs = None
|
self.in_chs = None
|
||||||
|
|
||||||
def _round_channels(self, chs):
|
def _round_channels(self, chs):
|
||||||
@ -236,6 +246,8 @@ class _BlockBuilder:
|
|||||||
ba['out_chs'] = _round_channels(ba['out_chs'])
|
ba['out_chs'] = _round_channels(ba['out_chs'])
|
||||||
ba['bn_momentum'] = self.bn_momentum
|
ba['bn_momentum'] = self.bn_momentum
|
||||||
ba['bn_eps'] = self.bn_eps
|
ba['bn_eps'] = self.bn_eps
|
||||||
|
ba['folded_bn'] = self.folded_bn
|
||||||
|
ba['padding_same'] = self.padding_same
|
||||||
if _DEBUG:
|
if _DEBUG:
|
||||||
print('args:', ba)
|
print('args:', ba)
|
||||||
# could replace this with lambdas or functools binding if variety increases
|
# could replace this with lambdas or functools binding if variety increases
|
||||||
@ -320,29 +332,37 @@ def _initialize_weight_default(m):
|
|||||||
class DepthwiseSeparableConv(nn.Module):
|
class DepthwiseSeparableConv(nn.Module):
|
||||||
def __init__(self, in_chs, out_chs, kernel_size,
|
def __init__(self, in_chs, out_chs, kernel_size,
|
||||||
stride=1, act_fn=F.relu, noskip=False, pw_act=False,
|
stride=1, act_fn=F.relu, noskip=False, pw_act=False,
|
||||||
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT):
|
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
|
||||||
|
folded_bn=False, padding_same=False):
|
||||||
super(DepthwiseSeparableConv, self).__init__()
|
super(DepthwiseSeparableConv, self).__init__()
|
||||||
assert stride in [1, 2]
|
assert stride in [1, 2]
|
||||||
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
|
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
|
||||||
self.has_pw_act = pw_act # activation after point-wise conv
|
self.has_pw_act = pw_act # activation after point-wise conv
|
||||||
self.act_fn = act_fn
|
self.act_fn = act_fn
|
||||||
|
dw_padding = _padding_arg(kernel_size // 2, padding_same)
|
||||||
|
pw_padding = _padding_arg(0, padding_same)
|
||||||
|
|
||||||
self.conv_dw = nn.Conv2d(
|
self.conv_dw = sconv2d(
|
||||||
in_chs, in_chs, kernel_size,
|
in_chs, in_chs, kernel_size,
|
||||||
stride=stride, padding=kernel_size // 2, groups=in_chs, bias=False)
|
stride=stride, padding=dw_padding, groups=in_chs, bias=folded_bn)
|
||||||
self.bn1 = nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
|
self.bn1 = None if folded_bn else nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
|
||||||
self.conv_pw = nn.Conv2d(in_chs, out_chs, 1, bias=False)
|
self.conv_pw = sconv2d(in_chs, out_chs, 1, padding=pw_padding, bias=folded_bn)
|
||||||
self.bn2 = nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
|
self.bn2 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
residual = x
|
residual = x
|
||||||
|
|
||||||
x = self.conv_dw(x)
|
x = self.conv_dw(x)
|
||||||
x = self.bn1(x)
|
if self.bn1 is not None:
|
||||||
|
x = self.bn1(x)
|
||||||
x = self.act_fn(x)
|
x = self.act_fn(x)
|
||||||
|
|
||||||
x = self.conv_pw(x)
|
x = self.conv_pw(x)
|
||||||
x = self.bn2(x)
|
if self.bn2 is not None:
|
||||||
|
x = self.bn2(x)
|
||||||
if self.has_pw_act:
|
if self.has_pw_act:
|
||||||
x = self.act_fn(x)
|
x = self.act_fn(x)
|
||||||
|
|
||||||
if self.has_residual:
|
if self.has_residual:
|
||||||
x += residual
|
x += residual
|
||||||
return x
|
return x
|
||||||
@ -351,24 +371,28 @@ class DepthwiseSeparableConv(nn.Module):
|
|||||||
class CascadeConv3x3(nn.Sequential):
|
class CascadeConv3x3(nn.Sequential):
|
||||||
# FIXME lifted from maskrcnn_benchmark blocks, haven't used yet
|
# FIXME lifted from maskrcnn_benchmark blocks, haven't used yet
|
||||||
def __init__(self, in_chs, out_chs, stride, act_fn=F.relu, noskip=False,
|
def __init__(self, in_chs, out_chs, stride, act_fn=F.relu, noskip=False,
|
||||||
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT):
|
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
|
||||||
|
folded_bn=False, padding_same=False):
|
||||||
super(CascadeConv3x3, self).__init__()
|
super(CascadeConv3x3, self).__init__()
|
||||||
assert stride in [1, 2]
|
assert stride in [1, 2]
|
||||||
self.has_residual = not noskip and (stride == 1 and in_chs == out_chs)
|
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
|
||||||
self.act_fn = act_fn
|
self.act_fn = act_fn
|
||||||
|
padding = _padding_arg(1, padding_same)
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(in_chs, in_chs, 3, stride=stride, padding=1, bias=False)
|
self.conv1 = sconv2d(in_chs, in_chs, 3, stride=stride, padding=padding, bias=folded_bn)
|
||||||
self.bn1 = nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
|
self.bn1 = None if folded_bn else nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
|
||||||
self.conv2 = nn.Conv2d(in_chs, out_chs, 3, stride=1, padding=1, bias=False)
|
self.conv2 = sconv2d(in_chs, out_chs, 3, stride=1, padding=padding, bias=folded_bn)
|
||||||
self.bn2 = nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
|
self.bn2 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
residual = x
|
residual = x
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
x = self.bn1(x)
|
if self.bn1 is not None:
|
||||||
|
x = self.bn1(x)
|
||||||
x = self.act_fn(x)
|
x = self.act_fn(x)
|
||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
x = self.bn2(x)
|
if self.bn2 is not None:
|
||||||
|
x = self.bn2(x)
|
||||||
if self.has_residual:
|
if self.has_residual:
|
||||||
x += residual
|
x += residual
|
||||||
return x
|
return x
|
||||||
@ -396,10 +420,10 @@ class ChannelShuffle(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class SqueezeExcite(nn.Module):
|
class SqueezeExcite(nn.Module):
|
||||||
def __init__(self, in_chs, se_ratio=0.25, act_fn=F.relu):
|
def __init__(self, in_chs, reduce_chs=None, act_fn=F.relu):
|
||||||
super(SqueezeExcite, self).__init__()
|
super(SqueezeExcite, self).__init__()
|
||||||
self.act_fn = act_fn
|
self.act_fn = act_fn
|
||||||
reduced_chs = max(1, int(in_chs * se_ratio))
|
reduced_chs = reduce_chs or in_chs
|
||||||
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
|
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
|
||||||
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
|
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
|
||||||
|
|
||||||
@ -419,41 +443,44 @@ class InvertedResidual(nn.Module):
|
|||||||
def __init__(self, in_chs, out_chs, kernel_size,
|
def __init__(self, in_chs, out_chs, kernel_size,
|
||||||
stride=1, act_fn=F.relu, exp_ratio=1.0, noskip=False,
|
stride=1, act_fn=F.relu, exp_ratio=1.0, noskip=False,
|
||||||
se_ratio=0., shuffle_type=None, pw_group=1,
|
se_ratio=0., shuffle_type=None, pw_group=1,
|
||||||
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT):
|
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
|
||||||
|
folded_bn=False, padding_same=False):
|
||||||
super(InvertedResidual, self).__init__()
|
super(InvertedResidual, self).__init__()
|
||||||
mid_chs = int(in_chs * exp_ratio)
|
mid_chs = int(in_chs * exp_ratio)
|
||||||
self.has_se = se_ratio is not None and se_ratio > 0.
|
self.has_se = se_ratio is not None and se_ratio > 0.
|
||||||
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
|
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
|
||||||
self.act_fn = act_fn
|
self.act_fn = act_fn
|
||||||
|
dw_padding = _padding_arg(kernel_size // 2, padding_same)
|
||||||
|
pw_padding = _padding_arg(0, padding_same)
|
||||||
|
|
||||||
# Point-wise expansion
|
# Point-wise expansion
|
||||||
self.conv_pw = nn.Conv2d(in_chs, mid_chs, 1, groups=pw_group, bias=False)
|
self.conv_pw = sconv2d(in_chs, mid_chs, 1, padding=pw_padding, groups=pw_group, bias=folded_bn)
|
||||||
self.bn1 = nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
|
self.bn1 = None if folded_bn else nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
|
||||||
|
|
||||||
self.shuffle_type = shuffle_type
|
self.shuffle_type = shuffle_type
|
||||||
if shuffle_type is not None:
|
if shuffle_type is not None:
|
||||||
self.shuffle = ChannelShuffle(pw_group)
|
self.shuffle = ChannelShuffle(pw_group)
|
||||||
|
|
||||||
# Depth-wise convolution
|
# Depth-wise convolution
|
||||||
self.conv_dw = nn.Conv2d(
|
self.conv_dw = sconv2d(
|
||||||
mid_chs, mid_chs, kernel_size, padding=kernel_size // 2,
|
mid_chs, mid_chs, kernel_size, padding=dw_padding, stride=stride, groups=mid_chs, bias=folded_bn)
|
||||||
stride=stride, groups=mid_chs, bias=False)
|
self.bn2 = None if folded_bn else nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
|
||||||
self.bn2 = nn.BatchNorm2d(mid_chs, momentum=bn_momentum, eps=bn_eps)
|
|
||||||
|
|
||||||
# Squeeze-and-excitation
|
# Squeeze-and-excitation
|
||||||
if self.has_se:
|
if self.has_se:
|
||||||
self.se = SqueezeExcite(mid_chs, se_ratio)
|
self.se = SqueezeExcite(mid_chs, reduce_chs=max(1, int(in_chs * se_ratio)))
|
||||||
|
|
||||||
# Point-wise linear projection
|
# Point-wise linear projection
|
||||||
self.conv_pwl = nn.Conv2d(mid_chs, out_chs, 1, groups=pw_group, bias=False)
|
self.conv_pwl = sconv2d(mid_chs, out_chs, 1, padding=pw_padding, groups=pw_group, bias=folded_bn)
|
||||||
self.bn3 = nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
|
self.bn3 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
residual = x
|
residual = x
|
||||||
|
|
||||||
# Point-wise expansion
|
# Point-wise expansion
|
||||||
x = self.conv_pw(x)
|
x = self.conv_pw(x)
|
||||||
x = self.bn1(x)
|
if self.bn1 is not None:
|
||||||
|
x = self.bn1(x)
|
||||||
x = self.act_fn(x)
|
x = self.act_fn(x)
|
||||||
|
|
||||||
# FIXME haven't tried this yet
|
# FIXME haven't tried this yet
|
||||||
@ -463,7 +490,8 @@ class InvertedResidual(nn.Module):
|
|||||||
|
|
||||||
# Depth-wise convolution
|
# Depth-wise convolution
|
||||||
x = self.conv_dw(x)
|
x = self.conv_dw(x)
|
||||||
x = self.bn2(x)
|
if self.bn2 is not None:
|
||||||
|
x = self.bn2(x)
|
||||||
x = self.act_fn(x)
|
x = self.act_fn(x)
|
||||||
|
|
||||||
# Squeeze-and-excitation
|
# Squeeze-and-excitation
|
||||||
@ -472,7 +500,8 @@ class InvertedResidual(nn.Module):
|
|||||||
|
|
||||||
# Point-wise linear projection
|
# Point-wise linear projection
|
||||||
x = self.conv_pwl(x)
|
x = self.conv_pwl(x)
|
||||||
x = self.bn3(x)
|
if self.bn3 is not None:
|
||||||
|
x = self.bn3(x)
|
||||||
|
|
||||||
if self.has_residual:
|
if self.has_residual:
|
||||||
x += residual
|
x += residual
|
||||||
@ -498,7 +527,7 @@ class GenMobileNet(nn.Module):
|
|||||||
depth_multiplier=1.0, depth_divisor=8, min_depth=None,
|
depth_multiplier=1.0, depth_divisor=8, min_depth=None,
|
||||||
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
|
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
|
||||||
drop_rate=0., act_fn=F.relu, global_pool='avg', skip_head_conv=False,
|
drop_rate=0., act_fn=F.relu, global_pool='avg', skip_head_conv=False,
|
||||||
weight_init='goog'):
|
weight_init='goog', folded_bn=False, padding_same=False):
|
||||||
super(GenMobileNet, self).__init__()
|
super(GenMobileNet, self).__init__()
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.depth_multiplier = depth_multiplier
|
self.depth_multiplier = depth_multiplier
|
||||||
@ -507,13 +536,15 @@ class GenMobileNet(nn.Module):
|
|||||||
self.num_features = num_features
|
self.num_features = num_features
|
||||||
|
|
||||||
stem_size = _round_channels(stem_size, depth_multiplier, depth_divisor, min_depth)
|
stem_size = _round_channels(stem_size, depth_multiplier, depth_divisor, min_depth)
|
||||||
self.conv_stem = nn.Conv2d(in_chans, stem_size, 3, padding=1, stride=2, bias=False)
|
self.conv_stem = sconv2d(
|
||||||
self.bn1 = nn.BatchNorm2d(stem_size, momentum=bn_momentum, eps=bn_eps)
|
in_chans, stem_size, 3,
|
||||||
|
padding=_padding_arg(1, padding_same), stride=2, bias=folded_bn)
|
||||||
|
self.bn1 = None if folded_bn else nn.BatchNorm2d(stem_size, momentum=bn_momentum, eps=bn_eps)
|
||||||
in_chs = stem_size
|
in_chs = stem_size
|
||||||
|
|
||||||
builder = _BlockBuilder(
|
builder = _BlockBuilder(
|
||||||
depth_multiplier, depth_divisor, min_depth,
|
depth_multiplier, depth_divisor, min_depth,
|
||||||
bn_momentum, bn_eps)
|
bn_momentum, bn_eps, folded_bn, padding_same)
|
||||||
self.blocks = nn.Sequential(*builder(in_chs, block_args))
|
self.blocks = nn.Sequential(*builder(in_chs, block_args))
|
||||||
in_chs = builder.in_chs
|
in_chs = builder.in_chs
|
||||||
|
|
||||||
@ -521,8 +552,10 @@ class GenMobileNet(nn.Module):
|
|||||||
self.conv_head = None
|
self.conv_head = None
|
||||||
assert in_chs == self.num_features
|
assert in_chs == self.num_features
|
||||||
else:
|
else:
|
||||||
self.conv_head = nn.Conv2d(in_chs, self.num_features, 1, padding=0, stride=1, bias=False)
|
self.conv_head = sconv2d(
|
||||||
self.bn2 = nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps)
|
in_chs, self.num_features, 1,
|
||||||
|
padding=_padding_arg(0, padding_same), bias=folded_bn)
|
||||||
|
self.bn2 = None if folded_bn else nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps)
|
||||||
|
|
||||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||||
self.classifier = nn.Linear(self.num_features, self.num_classes)
|
self.classifier = nn.Linear(self.num_features, self.num_classes)
|
||||||
@ -548,12 +581,14 @@ class GenMobileNet(nn.Module):
|
|||||||
|
|
||||||
def forward_features(self, x, pool=True):
|
def forward_features(self, x, pool=True):
|
||||||
x = self.conv_stem(x)
|
x = self.conv_stem(x)
|
||||||
x = self.bn1(x)
|
if self.bn1 is not None:
|
||||||
|
x = self.bn1(x)
|
||||||
x = self.act_fn(x)
|
x = self.act_fn(x)
|
||||||
x = self.blocks(x)
|
x = self.blocks(x)
|
||||||
if self.conv_head is not None:
|
if self.conv_head is not None:
|
||||||
x = self.conv_head(x)
|
x = self.conv_head(x)
|
||||||
x = self.bn2(x)
|
if self.bn2 is not None:
|
||||||
|
x = self.bn2(x)
|
||||||
x = self.act_fn(x)
|
x = self.act_fn(x)
|
||||||
if pool:
|
if pool:
|
||||||
x = self.global_pool(x)
|
x = self.global_pool(x)
|
||||||
@ -909,6 +944,19 @@ def mnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def tflite_mnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||||
|
""" MNASNet B1, depth multiplier of 1.0. """
|
||||||
|
default_cfg = default_cfgs['tflite_mnasnet1_00']
|
||||||
|
# these two args are for compat with tflite pretrained weights
|
||||||
|
kwargs['folded_bn'] = True
|
||||||
|
kwargs['padding_same'] = True
|
||||||
|
model = _gen_mnasnet_b1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def mnasnet1_40(num_classes, in_chans=3, pretrained=False, **kwargs):
|
def mnasnet1_40(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||||
""" MNASNet B1, depth multiplier of 1.4 """
|
""" MNASNet B1, depth multiplier of 1.4 """
|
||||||
default_cfg = default_cfgs['mnasnet1_40']
|
default_cfg = default_cfgs['mnasnet1_40']
|
||||||
@ -949,6 +997,19 @@ def semnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def tflite_semnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||||
|
""" MNASNet A1, depth multiplier of 1.0. """
|
||||||
|
default_cfg = default_cfgs['tflite_semnasnet1_00']
|
||||||
|
# these two args are for compat with tflite pretrained weights
|
||||||
|
kwargs['folded_bn'] = True
|
||||||
|
kwargs['padding_same'] = True
|
||||||
|
model = _gen_mnasnet_a1(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def semnasnet1_40(num_classes, in_chans=3, pretrained=False, **kwargs):
|
def semnasnet1_40(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||||
""" MNASNet A1 (w/ SE), depth multiplier of 1.4. """
|
""" MNASNet A1 (w/ SE), depth multiplier of 1.4. """
|
||||||
default_cfg = default_cfgs['semnasnet1_40']
|
default_cfg = default_cfgs['semnasnet1_40']
|
||||||
|
@ -9,8 +9,8 @@ from models.senet import seresnet18, seresnet34, seresnet50, seresnet101, seresn
|
|||||||
from models.xception import xception
|
from models.xception import xception
|
||||||
from models.pnasnet import pnasnet5large
|
from models.pnasnet import pnasnet5large
|
||||||
from models.genmobilenet import \
|
from models.genmobilenet import \
|
||||||
mnasnet0_50, mnasnet0_75, mnasnet1_00, mnasnet1_40,\
|
mnasnet0_50, mnasnet0_75, mnasnet1_00, mnasnet1_40, tflite_mnasnet1_00,\
|
||||||
semnasnet0_50, semnasnet0_75, semnasnet1_00, semnasnet1_40, mnasnet_small,\
|
semnasnet0_50, semnasnet0_75, semnasnet1_00, semnasnet1_40, tflite_semnasnet1_00, mnasnet_small,\
|
||||||
mobilenetv1_1_00, mobilenetv2_1_00, fbnetc_1_00, chamnetv1_1_00, chamnetv2_1_00,\
|
mobilenetv1_1_00, mobilenetv2_1_00, fbnetc_1_00, chamnetv1_1_00, chamnetv2_1_00,\
|
||||||
spnasnet1_00
|
spnasnet1_00
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user