mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add Single-Path NAS pixel1 model
This commit is contained in:
parent
419555be62
commit
34cd76899f
@ -5,6 +5,7 @@ A generic MobileNet class with building blocks to support a variety of models:
|
||||
* MobileNetV2
|
||||
* FBNet-C (TODO A & B)
|
||||
* ChamNet (TODO still guessing at architecture definition)
|
||||
* Single-Path NAS Pixel1
|
||||
* ShuffleNetV2 (TODO add IR shuffle block)
|
||||
* And likely more...
|
||||
|
||||
@ -25,8 +26,9 @@ from models.adaptive_avgmax_pool import SelectAdaptivePool2d
|
||||
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
__all__ = ['GenMobileNet', 'mnasnet0_50', 'mnasnet0_75', 'mnasnet1_00', 'mnasnet1_40',
|
||||
'semnasnet0_50', 'semnasnet0_75', 'semnasnet1_00', 'semnasnet1_40',
|
||||
'mnasnet_small']
|
||||
'semnasnet0_50', 'semnasnet0_75', 'semnasnet1_00', 'semnasnet1_40', 'mnasnet_small',
|
||||
'mobilenetv1_1_00', 'mobilenetv2_1_00', 'chamnetv1_1_00', 'chamnetv2_1_00',
|
||||
'fbnetc_1_00', 'spnasnet1_00']
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
@ -54,6 +56,7 @@ default_cfgs = {
|
||||
'chamnetv1_1_00': _cfg(url=''),
|
||||
'chamnetv2_1_00': _cfg(url=''),
|
||||
'fbnetc_1_00': _cfg(url=''),
|
||||
'spnasnet1_00': _cfg(url=''),
|
||||
}
|
||||
|
||||
_DEBUG = True
|
||||
@ -476,6 +479,7 @@ class GenMobileNet(nn.Module):
|
||||
* MNASNet A1, B1, and small
|
||||
* FBNet A, B, and C
|
||||
* ChamNet (arch details are murky)
|
||||
* Single-Path NAS Pixel1
|
||||
"""
|
||||
|
||||
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
|
||||
@ -820,6 +824,45 @@ def _gen_fbnetc(depth_multiplier, num_classes=1000, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
def _gen_spnasnet(depth_multiplier, num_classes=1000, **kwargs):
|
||||
"""Creates the Single-Path NAS model from search targeted for Pixel1 phone.
|
||||
|
||||
Paper: https://arxiv.org/abs/1904.02877
|
||||
|
||||
Args:
|
||||
depth_multiplier: multiplier to number of channels per layer.
|
||||
"""
|
||||
arch_def = [
|
||||
# stage 0, 112x112 in
|
||||
['ds_r1_k3_s1_c16_noskip'],
|
||||
# stage 1, 112x112 in
|
||||
['ir_r3_k3_s2_e3_c24'],
|
||||
# stage 2, 56x56 in
|
||||
['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'],
|
||||
# stage 3, 28x28 in
|
||||
['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'],
|
||||
# stage 4, 14x14in
|
||||
['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'],
|
||||
# stage 5, 14x14in
|
||||
['ir_r4_k5_s2_e6_c192'],
|
||||
# stage 6, 7x7 in
|
||||
['ir_r1_k3_s1_e6_c320_noskip']
|
||||
]
|
||||
bn_momentum, bn_eps = _resolve_bn_params(kwargs)
|
||||
model = GenMobileNet(
|
||||
arch_def,
|
||||
num_classes=num_classes,
|
||||
stem_size=32,
|
||||
depth_multiplier=depth_multiplier,
|
||||
depth_divisor=8,
|
||||
min_depth=None,
|
||||
bn_momentum=bn_momentum,
|
||||
bn_eps=bn_eps,
|
||||
**kwargs
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def mnasnet0_50(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
|
||||
""" MNASNet B1, depth multiplier of 0.5. """
|
||||
default_cfg = default_cfgs['mnasnet0_50']
|
||||
@ -958,3 +1001,13 @@ def chamnetv2_1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
def spnasnet1_00(num_classes, in_chans=3, pretrained=False, **kwargs):
|
||||
""" Single-Path NAS Pixel1"""
|
||||
default_cfg = default_cfgs['spnasnet1_00']
|
||||
model = _gen_spnasnet(1.0, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
@ -11,7 +11,8 @@ from models.pnasnet import pnasnet5large
|
||||
from models.genmobilenet import \
|
||||
mnasnet0_50, mnasnet0_75, mnasnet1_00, mnasnet1_40,\
|
||||
semnasnet0_50, semnasnet0_75, semnasnet1_00, semnasnet1_40, mnasnet_small,\
|
||||
mobilenetv1_1_00, mobilenetv2_1_00, fbnetc_1_00, chamnetv1_1_00, chamnetv2_1_00
|
||||
mobilenetv1_1_00, mobilenetv2_1_00, fbnetc_1_00, chamnetv1_1_00, chamnetv2_1_00,\
|
||||
spnasnet1_00
|
||||
|
||||
from models.helpers import load_checkpoint
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user