mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Add EfficientNet impl, change existing depth_multiplier -> channel_multiplier as the definitions have been confused
parent 6bff9c75dc
commit 7a6d61566e
@@ -30,6 +30,7 @@ I've included a few of my favourite models, but this is not an exhaustive collection
 * DPN (from [me](https://github.com/rwightman/pytorch-dpn-pretrained), weights hosted by Cadene)
   * DPN-68, DPN-68b, DPN-92, DPN-98, DPN-131, DPN-107
 * Generic MobileNet (from my standalone [GenMobileNet](https://github.com/rwightman/genmobilenet-pytorch)) - A generic model that implements many of the mobile optimized architecture search derived models that utilize similar DepthwiseSeparable and InvertedResidual blocks
+  * EfficientNet (B0-B4) (https://arxiv.org/abs/1905.11946) -- work in progress, validating
   * MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626)
   * MobileNet-V1 (https://arxiv.org/abs/1704.04861)
   * MobileNet-V2 (https://arxiv.org/abs/1801.04381)
@@ -1,6 +1,7 @@
 """ Generic MobileNet

 A generic MobileNet class with building blocks to support a variety of models:
+* EfficientNet (B0-B4 in code right now, work in progress, still verifying)
 * MNasNet B1, A1 (SE), Small
 * MobileNet V1, V2, and V3 (work in progress)
 * FBNet-C (TODO A & B)
@@ -30,7 +31,8 @@ _models = [
     'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_140', 'semnasnet_050', 'semnasnet_075',
     'semnasnet_100', 'semnasnet_140', 'mnasnet_small', 'mobilenetv1_100', 'mobilenetv2_100',
     'mobilenetv3_050', 'mobilenetv3_075', 'mobilenetv3_100', 'chamnetv1_100', 'chamnetv2_100',
-    'fbnetc_100', 'spnasnet_100', 'tflite_mnasnet_100', 'tflite_semnasnet_100']
+    'fbnetc_100', 'spnasnet_100', 'tflite_mnasnet_100', 'tflite_semnasnet_100', 'efficientnet_b0',
+    'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4']
 __all__ = ['GenMobileNet', 'genmobilenet_model_names'] + _models


@@ -67,6 +69,11 @@ default_cfgs = {
     'chamnetv2_100': _cfg(url=''),
     'fbnetc_100': _cfg(url='https://www.dropbox.com/s/0ku2tztuibrynld/fbnetc_100-f49a0c5f.pth?dl=1'),
     'spnasnet_100': _cfg(url='https://www.dropbox.com/s/iieopt18rytkgaa/spnasnet_100-048bc3f4.pth?dl=1'),
+    'efficientnet_b0': _cfg(url=''),
+    'efficientnet_b1': _cfg(url='', input_size=(3, 240, 240)),
+    'efficientnet_b2': _cfg(url='', input_size=(3, 260, 260)),
+    'efficientnet_b3': _cfg(url='', input_size=(3, 300, 300)),
+    'efficientnet_b4': _cfg(url='', input_size=(3, 380, 380)),
 }

 _DEBUG = False
@@ -101,23 +108,23 @@ def _resolve_bn_params(kwargs):
     return bn_momentum, bn_eps


-def _round_channels(channels, depth_multiplier=1.0, depth_divisor=8, min_depth=None):
+def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
     """Round number of filters based on depth multiplier."""
-    if not depth_multiplier:
+    if not multiplier:
         return channels

-    channels *= depth_multiplier
-    min_depth = min_depth or depth_divisor
+    channels *= multiplier
+    channel_min = channel_min or divisor
     new_channels = max(
-        int(channels + depth_divisor / 2) // depth_divisor * depth_divisor,
-        min_depth)
+        int(channels + divisor / 2) // divisor * divisor,
+        channel_min)
     # Make sure that round down does not go down by more than 10%.
     if new_channels < 0.9 * channels:
-        new_channels += depth_divisor
+        new_channels += divisor
     return new_channels


-def _decode_block_str(block_str):
+def _decode_block_str(block_str, depth_multiplier=1.0):
     """ Decode block definition string

     Gets a list of block arg (dicts) through a string notation of arguments.
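For reference, `_round_channels` snaps the scaled channel count to the nearest multiple of `divisor` and bumps it back up whenever plain rounding would lose more than 10% of the scaled value. A standalone copy with two worked values (illustrative only, not part of this commit):

```python
def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
    """Standalone copy of _round_channels above, for illustration."""
    if not multiplier:
        return channels
    channels *= multiplier
    channel_min = channel_min or divisor
    new_channels = max(int(channels + divisor / 2) // divisor * divisor, channel_min)
    if new_channels < 0.9 * channels:  # never round down by more than 10%
        new_channels += divisor
    return new_channels

print(round_channels(32, 1.4))    # 44.8 -> 48, the nearest multiple of 8
print(round_channels(24, 1.125))  # 27.0 -> 24 would lose >10%, so bumped to 32
```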
@@ -207,6 +214,7 @@ def _decode_block_str(block_str):
         block_type=block_type,
         kernel_size=int(options['k']),
         out_chs=int(options['c']),
+        se_ratio=float(options['se']) if 'se' in options else None,
         stride=int(options['s']),
         act_fn=act_fn,
         noskip=block_type == 'dsa' or noskip,
@@ -223,7 +231,9 @@ def _decode_block_str(block_str):
     else:
         assert False, 'Unknown block type (%s)' % block_type

-    # return a list of block args expanded by num_repeat
+    # return a list of block args expanded by num_repeat and
+    # scaled by depth_multiplier
+    num_repeat = int(math.ceil(num_repeat * depth_multiplier))
     return [deepcopy(block_args) for _ in range(num_repeat)]


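`_decode_block_str` now also scales the repeat count (the `r` segment of strings like `ir_r2_k3_s2_e6_c24_se0.25`) by `depth_multiplier`, rounding up so a stage never loses all of its blocks. Under the EfficientNet depth coefficients used further down, a stage defined with `r2` expands like this (quick check, plain Python):

```python
import math

for name, depth_multiplier in [('b0', 1.0), ('b1', 1.1), ('b2', 1.2), ('b3', 1.4), ('b4', 1.8)]:
    num_repeat = int(math.ceil(2 * depth_multiplier))  # an r2 stage
    print(name, num_repeat)
# b0 2, b1 3, b2 3, b3 3, b4 4
```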
@@ -243,14 +253,14 @@ def _decode_arch_args(string_list):
     return block_args


-def _decode_arch_def(arch_def):
+def _decode_arch_def(arch_def, depth_multiplier=1.0):
     arch_args = []
     for stack_idx, block_strings in enumerate(arch_def):
         assert isinstance(block_strings, list)
         stack_args = []
         for block_str in block_strings:
             assert isinstance(block_str, str)
-            stack_args.extend(_decode_block_str(block_str))
+            stack_args.extend(_decode_block_str(block_str, depth_multiplier))
         arch_args.append(stack_args)
     return arch_args

@@ -265,13 +275,13 @@ class _BlockBuilder:

     """

-    def __init__(self, depth_multiplier=1.0, depth_divisor=8, min_depth=None,
+    def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                  act_fn=None, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
                  bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
                  folded_bn=False, padding_same=False, verbose=False):
-        self.depth_multiplier = depth_multiplier
-        self.depth_divisor = depth_divisor
-        self.min_depth = min_depth
+        self.channel_multiplier = channel_multiplier
+        self.channel_divisor = channel_divisor
+        self.channel_min = channel_min
         self.act_fn = act_fn
         self.se_gate_fn = se_gate_fn
         self.se_reduce_mid = se_reduce_mid
@@ -283,7 +293,7 @@ class _BlockBuilder:
         self.in_chs = None

     def _round_channels(self, chs):
-        return _round_channels(chs, self.depth_multiplier, self.depth_divisor, self.min_depth)
+        return _round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)

     def _make_block(self, ba):
         bt = ba.pop('block_type')
@@ -327,7 +337,7 @@ class _BlockBuilder:
             blocks.append(block)
         return nn.Sequential(*blocks)

-    def __call__(self, in_chs, arch_def):
+    def __call__(self, in_chs, block_args):
         """ Build the blocks
         Args:
             in_chs: Number of input-channels passed to first block
@@ -336,13 +346,12 @@
         Return:
             List of block stacks (each stack wrapped in nn.Sequential)
         """
-        arch_args = _decode_arch_def(arch_def)  # convert and expand string defs to arg dicts
         if self.verbose:
-            print('Building model trunk with %d stacks (stages)...' % len(arch_args))
+            print('Building model trunk with %d stacks (stages)...' % len(block_args))
         self.in_chs = in_chs
         blocks = []
-        # outer list of arch_args defines the stacks ('stages' by some conventions)
-        for stack_idx, stack in enumerate(arch_args):
+        # outer list of block_args defines the stacks ('stages' by some conventions)
+        for stack_idx, stack in enumerate(block_args):
             if self.verbose:
                 print('stack', stack_idx)
             assert isinstance(stack, list)
@@ -381,6 +390,10 @@ def _initialize_weight_default(m):
     nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')


+def swish(x):
+    return x * torch.sigmoid(x)
+
+
 def hard_swish(x):
     return x * F.relu6(x + 3.) / 6.

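Side note on the activations: the newly added `swish` is the smooth counterpart of the existing `hard_swish`, and `hard_swish(x)` is exactly `x * hard_sigmoid(x)`. A self-contained check, using standalone copies of the functions above:

```python
import torch
import torch.nn.functional as F


def swish(x):
    return x * torch.sigmoid(x)


def hard_swish(x):
    return x * F.relu6(x + 3.) / 6.


def hard_sigmoid(x):
    return F.relu6(x + 3.) / 6.


x = torch.linspace(-6, 6, 121)
# hard_swish is x * hard_sigmoid(x), up to float associativity
assert torch.allclose(hard_swish(x), x * hard_sigmoid(x))
# the smooth and hard variants stay within ~0.15 of each other, worst near |x| = 3
print((swish(x) - hard_swish(x)).abs().max())
```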
@@ -389,98 +402,6 @@ def hard_sigmoid(x):
     return F.relu6(x + 3.) / 6.


-class ConvBnAct(nn.Module):
-    def __init__(self, in_chs, out_chs, kernel_size,
-                 stride=1, act_fn=F.relu,
-                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
-                 folded_bn=False, padding_same=False):
-        super(ConvBnAct, self).__init__()
-        assert stride in [1, 2]
-        self.act_fn = act_fn
-        padding = _padding_arg(_get_padding(kernel_size, stride), padding_same)
-
-        self.conv = sconv2d(
-            in_chs, out_chs, kernel_size,
-            stride=stride, padding=padding, bias=folded_bn)
-        self.bn1 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
-
-    def forward(self, x):
-        x = self.conv(x)
-        if self.bn1 is not None:
-            x = self.bn1(x)
-        x = self.act_fn(x)
-        return x
-
-
-class DepthwiseSeparableConv(nn.Module):
-    def __init__(self, in_chs, out_chs, kernel_size,
-                 stride=1, act_fn=F.relu, noskip=False, pw_act=False,
-                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
-                 folded_bn=False, padding_same=False):
-        super(DepthwiseSeparableConv, self).__init__()
-        assert stride in [1, 2]
-        self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
-        self.has_pw_act = pw_act  # activation after point-wise conv
-        self.act_fn = act_fn
-        dw_padding = _padding_arg(kernel_size // 2, padding_same)
-        pw_padding = _padding_arg(0, padding_same)
-
-        self.conv_dw = sconv2d(
-            in_chs, in_chs, kernel_size,
-            stride=stride, padding=dw_padding, groups=in_chs, bias=folded_bn)
-        self.bn1 = None if folded_bn else nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
-        self.conv_pw = sconv2d(in_chs, out_chs, 1, padding=pw_padding, bias=folded_bn)
-        self.bn2 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
-
-    def forward(self, x):
-        residual = x
-
-        x = self.conv_dw(x)
-        if self.bn1 is not None:
-            x = self.bn1(x)
-        x = self.act_fn(x)
-
-        x = self.conv_pw(x)
-        if self.bn2 is not None:
-            x = self.bn2(x)
-        if self.has_pw_act:
-            x = self.act_fn(x)
-
-        if self.has_residual:
-            x += residual
-        return x
-
-
-class CascadeConv(nn.Sequential):
-    # FIXME haven't used yet
-    def __init__(self, in_chs, out_chs, kernel_size=3, stride=2, act_fn=F.relu, noskip=False,
-                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
-                 folded_bn=False, padding_same=False):
-        super(CascadeConv, self).__init__()
-        assert stride in [1, 2]
-        self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
-        self.act_fn = act_fn
-        padding = _padding_arg(1, padding_same)
-
-        self.conv1 = sconv2d(in_chs, in_chs, kernel_size, stride=stride, padding=padding, bias=folded_bn)
-        self.bn1 = None if folded_bn else nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
-        self.conv2 = sconv2d(in_chs, out_chs, kernel_size, stride=1, padding=padding, bias=folded_bn)
-        self.bn2 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
-
-    def forward(self, x):
-        residual = x
-        x = self.conv1(x)
-        if self.bn1 is not None:
-            x = self.bn1(x)
-        x = self.act_fn(x)
-        x = self.conv2(x)
-        if self.bn2 is not None:
-            x = self.bn2(x)
-        if self.has_residual:
-            x += residual
-        return x
-
-
 class ChannelShuffle(nn.Module):
     # FIXME haven't used yet
     def __init__(self, groups):
@@ -521,6 +442,113 @@ class SqueezeExcite(nn.Module):
         return x


+class ConvBnAct(nn.Module):
+    def __init__(self, in_chs, out_chs, kernel_size,
+                 stride=1, act_fn=F.relu,
+                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
+                 folded_bn=False, padding_same=False):
+        super(ConvBnAct, self).__init__()
+        assert stride in [1, 2]
+        self.act_fn = act_fn
+        padding = _padding_arg(_get_padding(kernel_size, stride), padding_same)
+
+        self.conv = sconv2d(
+            in_chs, out_chs, kernel_size,
+            stride=stride, padding=padding, bias=folded_bn)
+        self.bn1 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
+
+    def forward(self, x):
+        x = self.conv(x)
+        if self.bn1 is not None:
+            x = self.bn1(x)
+        x = self.act_fn(x)
+        return x
+
+
+class DepthwiseSeparableConv(nn.Module):
+    """ DepthwiseSeparable block
+    Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion
+    factor of 1.0. This is an alternative to having a IR with optional first pw conv.
+    """
+    def __init__(self, in_chs, out_chs, kernel_size,
+                 stride=1, act_fn=F.relu, noskip=False, pw_act=False,
+                 se_ratio=0., se_gate_fn=torch.sigmoid,
+                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
+                 folded_bn=False, padding_same=False):
+        super(DepthwiseSeparableConv, self).__init__()
+        assert stride in [1, 2]
+        self.has_se = se_ratio is not None and se_ratio > 0.
+        self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
+        self.has_pw_act = pw_act  # activation after point-wise conv
+        self.act_fn = act_fn
+        dw_padding = _padding_arg(kernel_size // 2, padding_same)
+        pw_padding = _padding_arg(0, padding_same)
+
+        self.conv_dw = sconv2d(
+            in_chs, in_chs, kernel_size,
+            stride=stride, padding=dw_padding, groups=in_chs, bias=folded_bn)
+        self.bn1 = None if folded_bn else nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
+
+        # Squeeze-and-excitation
+        if self.has_se:
+            self.se = SqueezeExcite(
+                in_chs, reduce_chs=max(1, int(in_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn)
+
+        self.conv_pw = sconv2d(in_chs, out_chs, 1, padding=pw_padding, bias=folded_bn)
+        self.bn2 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
+
+    def forward(self, x):
+        residual = x
+
+        x = self.conv_dw(x)
+        if self.bn1 is not None:
+            x = self.bn1(x)
+        x = self.act_fn(x)
+
+        if self.has_se:
+            x = self.se(x)
+
+        x = self.conv_pw(x)
+        if self.bn2 is not None:
+            x = self.bn2(x)
+        if self.has_pw_act:
+            x = self.act_fn(x)
+
+        if self.has_residual:
+            x += residual  # FIXME add drop-connect
+        return x
+
+
+class CascadeConv(nn.Sequential):
+    # FIXME haven't used yet
+    def __init__(self, in_chs, out_chs, kernel_size=3, stride=2, act_fn=F.relu, noskip=False,
+                 bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
+                 folded_bn=False, padding_same=False):
+        super(CascadeConv, self).__init__()
+        assert stride in [1, 2]
+        self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
+        self.act_fn = act_fn
+        padding = _padding_arg(1, padding_same)
+
+        self.conv1 = sconv2d(in_chs, in_chs, kernel_size, stride=stride, padding=padding, bias=folded_bn)
+        self.bn1 = None if folded_bn else nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
+        self.conv2 = sconv2d(in_chs, out_chs, kernel_size, stride=1, padding=padding, bias=folded_bn)
+        self.bn2 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
+
+    def forward(self, x):
+        residual = x
+        x = self.conv1(x)
+        if self.bn1 is not None:
+            x = self.bn1(x)
+        x = self.act_fn(x)
+        x = self.conv2(x)
+        if self.bn2 is not None:
+            x = self.bn2(x)
+        if self.has_residual:
+            x += residual
+        return x
+
+
 class InvertedResidual(nn.Module):
     """ Inverted residual block w/ optional SE"""

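The squeeze-and-excitation path added to `DepthwiseSeparableConv` above sizes its bottleneck as `max(1, int(in_chs * se_ratio))`, so the `se0.25` blocks in the EfficientNet defs reduce to a quarter of the input channels. A minimal standalone sketch of an SE gate matching the call sites in this file (the `SqueezeExcite` internals are not shown in this diff, so this is an assumption about its shape, not a copy):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SqueezeExcite(nn.Module):
    """Minimal sketch of a squeeze-excite gate matching the call sites above."""
    def __init__(self, in_chs, reduce_chs=None, act_fn=F.relu, gate_fn=torch.sigmoid):
        super(SqueezeExcite, self).__init__()
        self.act_fn = act_fn
        self.gate_fn = gate_fn
        reduced_chs = reduce_chs or in_chs
        self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
        self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)

    def forward(self, x):
        x_se = x.mean((2, 3), keepdim=True)   # squeeze: global average pool
        x_se = self.act_fn(self.conv_reduce(x_se))
        x_se = self.gate_fn(self.conv_expand(x_se))
        return x * x_se                       # excite: channel-wise re-weighting


# se_ratio=0.25 on a 16-channel DS block -> reduce_chs = max(1, int(16 * 0.25)) = 4
se = SqueezeExcite(16, reduce_chs=max(1, int(16 * 0.25)))
y = se(torch.randn(2, 16, 56, 56))
print(y.shape)  # torch.Size([2, 16, 56, 56])
```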
@@ -554,8 +582,8 @@ class InvertedResidual(nn.Module):
         # Squeeze-and-excitation
         if self.has_se:
             se_base_chs = mid_chs if se_reduce_mid else in_chs
-            self.se = SqueezeExcite(mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)),
-                                    act_fn=act_fn, gate_fn=se_gate_fn)
+            self.se = SqueezeExcite(
+                mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn)

         # Point-wise linear projection
         self.conv_pwl = sconv2d(mid_chs, out_chs, 1, padding=pw_padding, groups=pw_group, bias=folded_bn)
@@ -591,7 +619,7 @@ class InvertedResidual(nn.Module):
         x = self.bn3(x)

         if self.has_residual:
-            x += residual
+            x += residual  # FIXME add drop-connect

         # NOTE maskrcnn_benchmark building blocks have an SE module defined here for some variants

@@ -609,22 +637,23 @@ class GenMobileNet(nn.Module):
       * FBNet A, B, and C
       * ChamNet (arch details are murky)
       * Single-Path NAS Pixel1
+      * EfficientNet
     """

     def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
-                 depth_multiplier=1.0, depth_divisor=8, min_depth=None,
+                 channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                  bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
                  drop_rate=0., act_fn=F.relu, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
                  global_pool='avg', head_conv='default', weight_init='goog',
                  folded_bn=False, padding_same=False):
         super(GenMobileNet, self).__init__()
         self.num_classes = num_classes
-        self.depth_multiplier = depth_multiplier
+        self.depth_multiplier = channel_multiplier
         self.drop_rate = drop_rate
         self.act_fn = act_fn
         self.num_features = num_features

-        stem_size = _round_channels(stem_size, depth_multiplier, depth_divisor, min_depth)
+        stem_size = _round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
         self.conv_stem = sconv2d(
             in_chans, stem_size, 3,
             padding=_padding_arg(1, padding_same), stride=2, bias=folded_bn)
@@ -632,7 +661,7 @@ class GenMobileNet(nn.Module):
         in_chs = stem_size

         builder = _BlockBuilder(
-            depth_multiplier, depth_divisor, min_depth,
+            channel_multiplier, channel_divisor, channel_min,
             act_fn, se_gate_fn, se_reduce_mid,
             bn_momentum, bn_eps, folded_bn, padding_same, verbose=_DEBUG)
         self.blocks = nn.Sequential(*builder(in_chs, block_args))
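Worth noting across these hunks: the model generators now decode `arch_def` themselves and hand `GenMobileNet` ready-made block args, rather than passing raw string defs for `_BlockBuilder.__call__` to decode internally. A hedged sketch of the new calling convention (assumes this module is importable as `gen_mobilenet`):

```python
from gen_mobilenet import GenMobileNet, _decode_arch_def

arch_def = [
    ['ds_r1_k3_s1_c16'],     # stack 0: one depthwise-separable block
    ['ir_r2_k3_s2_e6_c24'],  # stack 1: two inverted-residual blocks
]
model = GenMobileNet(
    _decode_arch_def(arch_def, depth_multiplier=1.0),  # decode first, then build
    num_classes=1000,
    stem_size=32,
    channel_multiplier=1.0,
)
```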
@@ -705,14 +734,14 @@ class GenMobileNet(nn.Module):
         return self.classifier(x)


-def _gen_mnasnet_a1(depth_multiplier, num_classes=1000, **kwargs):
+def _gen_mnasnet_a1(channel_multiplier, num_classes=1000, **kwargs):
     """Creates a mnasnet-a1 model.

     Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
     Paper: https://arxiv.org/pdf/1807.11626.pdf.

     Args:
-      depth_multiplier: multiplier to number of channels per layer.
+      channel_multiplier: multiplier to number of channels per layer.
     """
     arch_def = [
         # stage 0, 112x112 in
@@ -732,12 +761,12 @@ def _gen_mnasnet_a1(depth_multiplier, num_classes=1000, **kwargs):
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenMobileNet(
-        arch_def,
+        _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=32,
-        depth_multiplier=depth_multiplier,
-        depth_divisor=8,
-        min_depth=None,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
         bn_momentum=bn_momentum,
         bn_eps=bn_eps,
         **kwargs
@@ -745,14 +774,14 @@ def _gen_mnasnet_a1(depth_multiplier, num_classes=1000, **kwargs):
     return model


-def _gen_mnasnet_b1(depth_multiplier, num_classes=1000, **kwargs):
+def _gen_mnasnet_b1(channel_multiplier, num_classes=1000, **kwargs):
     """Creates a mnasnet-b1 model.

     Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
     Paper: https://arxiv.org/pdf/1807.11626.pdf.

     Args:
-      depth_multiplier: multiplier to number of channels per layer.
+      channel_multiplier: multiplier to number of channels per layer.
     """
     arch_def = [
         # stage 0, 112x112 in
@@ -772,12 +801,12 @@ def _gen_mnasnet_b1(depth_multiplier, num_classes=1000, **kwargs):
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenMobileNet(
-        arch_def,
+        _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=32,
-        depth_multiplier=depth_multiplier,
-        depth_divisor=8,
-        min_depth=None,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
         bn_momentum=bn_momentum,
         bn_eps=bn_eps,
         **kwargs
@@ -785,14 +814,14 @@ def _gen_mnasnet_b1(depth_multiplier, num_classes=1000, **kwargs):
     return model


-def _gen_mnasnet_small(depth_multiplier, num_classes=1000, **kwargs):
+def _gen_mnasnet_small(channel_multiplier, num_classes=1000, **kwargs):
     """Creates a mnasnet-b1 model.

     Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
     Paper: https://arxiv.org/pdf/1807.11626.pdf.

     Args:
-      depth_multiplier: multiplier to number of channels per layer.
+      channel_multiplier: multiplier to number of channels per layer.
     """
     arch_def = [
         ['ds_r1_k3_s1_c8'],
@@ -805,12 +834,12 @@ def _gen_mnasnet_small(depth_multiplier, num_classes=1000, **kwargs):
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenMobileNet(
-        arch_def,
+        _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=8,
-        depth_multiplier=depth_multiplier,
-        depth_divisor=8,
-        min_depth=None,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
         bn_momentum=bn_momentum,
         bn_eps=bn_eps,
         **kwargs
@@ -818,7 +847,7 @@ def _gen_mnasnet_small(depth_multiplier, num_classes=1000, **kwargs):
     return model


-def _gen_mobilenet_v1(depth_multiplier, num_classes=1000, **kwargs):
+def _gen_mobilenet_v1(channel_multiplier, num_classes=1000, **kwargs):
     """ Generate MobileNet-V1 network
     Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
     Paper: https://arxiv.org/abs/1801.04381
@@ -832,13 +861,13 @@ def _gen_mobilenet_v1(depth_multiplier, num_classes=1000, **kwargs):
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenMobileNet(
-        arch_def,
+        _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=32,
         num_features=1024,
-        depth_multiplier=depth_multiplier,
-        depth_divisor=8,
-        min_depth=None,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
         bn_momentum=bn_momentum,
         bn_eps=bn_eps,
         act_fn=F.relu6,
@@ -848,7 +877,7 @@ def _gen_mobilenet_v1(depth_multiplier, num_classes=1000, **kwargs):
     return model


-def _gen_mobilenet_v2(depth_multiplier, num_classes=1000, **kwargs):
+def _gen_mobilenet_v2(channel_multiplier, num_classes=1000, **kwargs):
     """ Generate MobileNet-V2 network
     Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
     Paper: https://arxiv.org/abs/1801.04381
@@ -864,12 +893,12 @@ def _gen_mobilenet_v2(depth_multiplier, num_classes=1000, **kwargs):
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenMobileNet(
-        arch_def,
+        _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=32,
-        depth_multiplier=depth_multiplier,
-        depth_divisor=8,
-        min_depth=None,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
         bn_momentum=bn_momentum,
         bn_eps=bn_eps,
         act_fn=F.relu6,
@@ -878,14 +907,14 @@ def _gen_mobilenet_v2(depth_multiplier, num_classes=1000, **kwargs):
     return model


-def _gen_mobilenet_v3(depth_multiplier, num_classes=1000, **kwargs):
+def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs):
     """Creates a MobileNet-V3 model.

     Ref impl: ?
     Paper: https://arxiv.org/abs/1905.02244

     Args:
-      depth_multiplier: multiplier to number of channels per layer.
+      channel_multiplier: multiplier to number of channels per layer.
     """
     arch_def = [
         # stage 0, 112x112 in
@@ -905,12 +934,12 @@ def _gen_mobilenet_v3(depth_multiplier, num_classes=1000, **kwargs):
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenMobileNet(
-        arch_def,
+        _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=16,
-        depth_multiplier=depth_multiplier,
-        depth_divisor=8,
-        min_depth=None,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
         bn_momentum=bn_momentum,
         bn_eps=bn_eps,
         act_fn=hard_swish,
@@ -922,7 +951,7 @@ def _gen_mobilenet_v3(depth_multiplier, num_classes=1000, **kwargs):
     return model


-def _gen_chamnet_v1(depth_multiplier, num_classes=1000, **kwargs):
+def _gen_chamnet_v1(channel_multiplier, num_classes=1000, **kwargs):
     """ Generate Chameleon Network (ChamNet)

     Paper: https://arxiv.org/abs/1812.08934
@@ -941,13 +970,13 @@ def _gen_chamnet_v1(depth_multiplier, num_classes=1000, **kwargs):
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenMobileNet(
-        arch_def,
+        _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=32,
         num_features=1280,  # no idea what this is? try mobile/mnasnet default?
-        depth_multiplier=depth_multiplier,
-        depth_divisor=8,
-        min_depth=None,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
         bn_momentum=bn_momentum,
         bn_eps=bn_eps,
         **kwargs
@@ -955,7 +984,7 @@ def _gen_chamnet_v1(depth_multiplier, num_classes=1000, **kwargs):
     return model


-def _gen_chamnet_v2(depth_multiplier, num_classes=1000, **kwargs):
+def _gen_chamnet_v2(channel_multiplier, num_classes=1000, **kwargs):
     """ Generate Chameleon Network (ChamNet)

     Paper: https://arxiv.org/abs/1812.08934
@@ -974,13 +1003,13 @@ def _gen_chamnet_v2(depth_multiplier, num_classes=1000, **kwargs):
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenMobileNet(
-        arch_def,
+        _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=32,
         num_features=1280,  # no idea what this is? try mobile/mnasnet default?
-        depth_multiplier=depth_multiplier,
-        depth_divisor=8,
-        min_depth=None,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
         bn_momentum=bn_momentum,
         bn_eps=bn_eps,
         **kwargs
@@ -988,7 +1017,7 @@ def _gen_chamnet_v2(depth_multiplier, num_classes=1000, **kwargs):
     return model


-def _gen_fbnetc(depth_multiplier, num_classes=1000, **kwargs):
+def _gen_fbnetc(channel_multiplier, num_classes=1000, **kwargs):
     """ FBNet-C

     Paper: https://arxiv.org/abs/1812.03443
@@ -1008,13 +1037,13 @@ def _gen_fbnetc(depth_multiplier, num_classes=1000, **kwargs):
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenMobileNet(
-        arch_def,
+        _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=16,
         num_features=1984,  # paper suggests this, but is not 100% clear
-        depth_multiplier=depth_multiplier,
-        depth_divisor=8,
-        min_depth=None,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
         bn_momentum=bn_momentum,
         bn_eps=bn_eps,
         **kwargs
@@ -1022,13 +1051,13 @@ def _gen_fbnetc(depth_multiplier, num_classes=1000, **kwargs):
     return model


-def _gen_spnasnet(depth_multiplier, num_classes=1000, **kwargs):
+def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs):
     """Creates the Single-Path NAS model from search targeted for Pixel1 phone.

     Paper: https://arxiv.org/abs/1904.02877

     Args:
-      depth_multiplier: multiplier to number of channels per layer.
+      channel_multiplier: multiplier to number of channels per layer.
     """
     arch_def = [
         # stage 0, 112x112 in
@@ -1048,12 +1077,12 @@ def _gen_spnasnet(depth_multiplier, num_classes=1000, **kwargs):
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
     model = GenMobileNet(
-        arch_def,
+        _decode_arch_def(arch_def),
        num_classes=num_classes,
        stem_size=32,
-        depth_multiplier=depth_multiplier,
-        depth_divisor=8,
-        min_depth=None,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
        bn_momentum=bn_momentum,
        bn_eps=bn_eps,
        **kwargs
@@ -1061,6 +1090,41 @@ def _gen_spnasnet(depth_multiplier, num_classes=1000, **kwargs):
     return model


+def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
+    """Creates a MobileNet-V3 model.
+
+    Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
+    Paper: https://arxiv.org/abs/1905.11946
+
+    Args:
+      channel_multiplier: multiplier to number of channels per layer
+      depth_multiplier: multiplier to number of repeats per stage
+    """
+    arch_def = [
+        ['ds_r1_k3_s1_e1_c16_se0.25'],
+        ['ir_r2_k3_s2_e6_c24_se0.25'],
+        ['ir_r2_k5_s2_e6_c40_se0.25'],
+        ['ir_r3_k3_s2_e6_c80_se0.25'],
+        ['ir_r3_k5_s1_e6_c112_se0.25'],
+        ['ir_r4_k5_s2_e6_c192_se0.25'],
+        ['ir_r1_k3_s1_e6_c320_se0.25'],
+    ]
+    bn_momentum, bn_eps = _resolve_bn_params(kwargs)
+    model = GenMobileNet(
+        _decode_arch_def(arch_def, depth_multiplier),
+        num_classes=num_classes,
+        stem_size=32,
+        channel_multiplier=channel_multiplier,
+        channel_divisor=8,
+        channel_min=None,
+        bn_momentum=bn_momentum,
+        bn_eps=bn_eps,
+        act_fn=swish,
+        **kwargs
+    )
+    return model
+
+
 def mnasnet_050(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
     """ MNASNet B1, depth multiplier of 0.5. """
     default_cfg = default_cfgs['mnasnet_050']
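A quick shape check for the generator above (a smoke test under the assumption that this file is importable as `gen_mobilenet`; B0 is defined for 224x224 inputs per the cfgs earlier in the diff):

```python
import torch

from gen_mobilenet import _gen_efficientnet

model = _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0)  # B0 coefficients
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000])
```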
@@ -1270,5 +1334,81 @@ def spnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs):
     return model


+# EfficientNet params
+# (width_coefficient, depth_coefficient, resolution, dropout_rate)
+# 'efficientnet-b0': (1.0, 1.0, 224, 0.2),
+# 'efficientnet-b1': (1.0, 1.1, 240, 0.2),
+# 'efficientnet-b2': (1.1, 1.2, 260, 0.3),
+# 'efficientnet-b3': (1.2, 1.4, 300, 0.3),
+# 'efficientnet-b4': (1.4, 1.8, 380, 0.4),
+# 'efficientnet-b5': (1.6, 2.2, 456, 0.4),
+# 'efficientnet-b6': (1.8, 2.6, 528, 0.5),
+# 'efficientnet-b7': (2.0, 3.1, 600, 0.5),
+
+def efficientnet_b0(num_classes, in_chans=3, pretrained=False, **kwargs):
+    """ EfficientNet """
+    default_cfg = default_cfgs['efficientnet_b0']
+    # NOTE dropout should be 0.2 for train
+    model = _gen_efficientnet(
+        channel_multiplier=1.0, depth_multiplier=1.0,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+def efficientnet_b1(num_classes, in_chans=3, pretrained=False, **kwargs):
+    """ EfficientNet """
+    default_cfg = default_cfgs['efficientnet_b1']
+    # NOTE dropout should be 0.2 for train
+    model = _gen_efficientnet(
+        channel_multiplier=1.0, depth_multiplier=1.1,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+def efficientnet_b2(num_classes, in_chans=3, pretrained=False, **kwargs):
+    """ EfficientNet """
+    default_cfg = default_cfgs['efficientnet_b2']
+    # NOTE dropout should be 0.3 for train
+    model = _gen_efficientnet(
+        channel_multiplier=1.1, depth_multiplier=1.2,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+def efficientnet_b3(num_classes, in_chans=3, pretrained=False, **kwargs):
+    """ EfficientNet """
+    default_cfg = default_cfgs['efficientnet_b3']
+    # NOTE dropout should be 0.3 for train
+    model = _gen_efficientnet(
+        channel_multiplier=1.2, depth_multiplier=1.4,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+def efficientnet_b4(num_classes, in_chans=3, pretrained=False, **kwargs):
+    """ EfficientNet """
+    default_cfg = default_cfgs['efficientnet_b4']
+    # NOTE dropout should be 0.4 for train
+    model = _gen_efficientnet(
+        channel_multiplier=1.4, depth_multiplier=1.8,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
 def genmobilenet_model_names():
     return set(_models)