mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Experimenting with a custom MixNet-XL and MixNet-XXL definition
This commit is contained in:
parent
9816ca3ab4
commit
51a2375b0c
@ -138,6 +138,8 @@ default_cfgs = {
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth'),
|
||||
'mixnet_l': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth'),
|
||||
'mixnet_xl': _cfg(),
|
||||
'mixnet_xxl': _cfg(),
|
||||
'tf_mixnet_s': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth'),
|
||||
'tf_mixnet_m': _cfg(
|
||||
@ -312,21 +314,59 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
|
||||
else:
|
||||
assert False, 'Unknown block type (%s)' % block_type
|
||||
|
||||
# return a list of block args expanded by num_repeat and
|
||||
# scaled by depth_multiplier
|
||||
num_repeat = int(math.ceil(num_repeat * depth_multiplier))
|
||||
return [deepcopy(block_args) for _ in range(num_repeat)]
|
||||
return block_args, num_repeat
|
||||
|
||||
|
||||
def _decode_arch_def(arch_def, depth_multiplier=1.0):
|
||||
def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
|
||||
""" Per-stage depth scaling
|
||||
Scales the block repeats in each stage. This depth scaling impl maintains
|
||||
compatibility with the EfficientNet scaling method, while allowing sensible
|
||||
scaling for other models that may have multiple block arg definitions in each stage.
|
||||
"""
|
||||
|
||||
# We scale the total repeat count for each stage, there may be multiple
|
||||
# block arg defs per stage so we need to sum.
|
||||
num_repeat = sum(repeats)
|
||||
if depth_trunc == 'round':
|
||||
# Truncating to int by rounding allows stages with few repeats to remain
|
||||
# proportionally smaller for longer. This is a good choice when stage definitions
|
||||
# include single repeat stages that we'd prefer to keep that way as long as possible
|
||||
num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
|
||||
else:
|
||||
# The default for EfficientNet truncates repeats to int via 'ceil'.
|
||||
# Any multiplier > 1.0 will result in an increased depth for every stage.
|
||||
num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
|
||||
|
||||
# Proportionally distribute repeat count scaling to each block definition in the stage.
|
||||
# Allocation is done in reverse as it results in the first block being less likely to be scaled.
|
||||
# The first block makes less sense to repeat in most of the arch definitions.
|
||||
repeats_scaled = []
|
||||
for r in repeats[::-1]:
|
||||
rs = max(1, round((r / num_repeat * num_repeat_scaled)))
|
||||
repeats_scaled.append(rs)
|
||||
num_repeat -= r
|
||||
num_repeat_scaled -= rs
|
||||
repeats_scaled = repeats_scaled[::-1]
|
||||
|
||||
# Apply the calculated scaling to each block arg in the stage
|
||||
sa_scaled = []
|
||||
for ba, rep in zip(stack_args, repeats_scaled):
|
||||
sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
|
||||
return sa_scaled
|
||||
|
||||
|
||||
def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
|
||||
arch_args = []
|
||||
for stack_idx, block_strings in enumerate(arch_def):
|
||||
assert isinstance(block_strings, list)
|
||||
stack_args = []
|
||||
repeats = []
|
||||
for block_str in block_strings:
|
||||
assert isinstance(block_str, str)
|
||||
stack_args.extend(_decode_block_str(block_str, depth_multiplier))
|
||||
arch_args.append(stack_args)
|
||||
ba, rep = _decode_block_str(block_str)
|
||||
stack_args.append(ba)
|
||||
repeats.append(rep)
|
||||
arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
|
||||
return arch_args
|
||||
|
||||
|
||||
@ -1261,7 +1301,7 @@ def _gen_mixnet_s(channel_multiplier=1.0, num_classes=1000, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
def _gen_mixnet_m(channel_multiplier=1.0, num_classes=1000, **kwargs):
|
||||
def _gen_mixnet_m(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
|
||||
"""Creates a MixNet Medium-Large model.
|
||||
|
||||
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
|
||||
@ -1283,7 +1323,7 @@ def _gen_mixnet_m(channel_multiplier=1.0, num_classes=1000, **kwargs):
|
||||
# 7x7
|
||||
]
|
||||
model = GenEfficientNet(
|
||||
_decode_arch_def(arch_def),
|
||||
_decode_arch_def(arch_def, depth_multiplier=depth_multiplier, depth_trunc='round'),
|
||||
num_classes=num_classes,
|
||||
stem_size=24,
|
||||
num_features=1536,
|
||||
@ -1876,6 +1916,33 @@ def mixnet_l(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mixnet_xl(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Creates a MixNet Extra-Large model.
|
||||
"""
|
||||
default_cfg = default_cfgs['mixnet_xl']
|
||||
#kwargs['drop_connect_rate'] = 0.2
|
||||
model = _gen_mixnet_m(
|
||||
channel_multiplier=1.6, depth_multiplier=1.2, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mixnet_xxl(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Creates a MixNet Double Extra Large model.
|
||||
"""
|
||||
default_cfg = default_cfgs['mixnet_xxl']
|
||||
model = _gen_mixnet_m(
|
||||
channel_multiplier=2.4, depth_multiplier=1.3, num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""Creates a MixNet Small model. Tensorflow compatible variant
|
||||
|
Loading…
x
Reference in New Issue
Block a user