mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00

EfficientNetV2 official impl w/ weights ported from TF. Cleanup/refactor of related EfficientNet classes and models.

parent c16d65a8a7
commit c4f482a08b
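
Usage sketch (not part of the diff): with this commit in place the new models are reachable through timm's standard factory; `timm.create_model` and the model names below come straight from the registry functions added in this diff.

import torch
import timm

# 'efficientnetv2_rw_s' ships weights in this commit; the official-arch
# 'efficientnetv2_s' has an empty weight URL here, so it would need pretrained=False.
model = timm.create_model('efficientnetv2_rw_s', pretrained=True).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 288, 288))  # 288x288 train size per default_cfgs
print(logits.shape)  # torch.Size([1, 1000])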
timm/models/efficientnet.py
@@ -2,6 +2,9 @@
 An implementation of EfficientNet that covers a variety of related models with efficient architectures:

+* EfficientNet-V2
+  - `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
+
 * EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent weight ports)
   - EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946
   - CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971
@@ -22,23 +25,26 @@ An implementation of EfficientNet that covers a variety of related models with efficient architectures:
 * And likely more...

-Hacked together by / Copyright 2020 Ross Wightman
+Hacked together by / Copyright 2021 Ross Wightman
 """
+from functools import partial
+from typing import List
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from typing import List
-
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
-from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
+from .efficientnet_blocks import SqueezeExcite
+from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\
+    round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from .features import FeatureInfo, FeatureHooks
 from .helpers import build_model_with_cfg, default_cfg_for_features
 from .layers import create_conv2d, create_classifier
 from .registry import register_model

-__all__ = ['EfficientNet']
+__all__ = ['EfficientNet', 'EfficientNetFeatures']


 def _cfg(url='', **kwargs):
@@ -149,9 +155,20 @@ default_cfgs = {
         url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb3_pruned_5abcc29f.pth',
         input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),

-    'efficientnet_v2s': _cfg(
+    'efficientnetv2_rw_s': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_v2s_ra2_288-a6477665.pth',
-        input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0),  # FIXME WIP
+        input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0),
+
+    'efficientnetv2_s': _cfg(
+        url='',
+        input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0),
+    'efficientnetv2_m': _cfg(
+        url='',
+        input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0),
+    'efficientnetv2_l': _cfg(
+        url='',
+        input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),

     'tf_efficientnet_b0': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
@@ -298,6 +315,58 @@ default_cfgs = {
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
         input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.920, interpolation='bilinear'),

+    'tf_efficientnetv2_s': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s-eb54923e.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
+    'tf_efficientnetv2_m': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
+    'tf_efficientnetv2_l': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
+
+    'tf_efficientnetv2_s_21kft1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21kft1k-d7dafa41.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
+    'tf_efficientnetv2_m_21kft1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21kft1k-bf41664a.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
+    'tf_efficientnetv2_l_21kft1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21kft1k-60127a9d.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
+
+    'tf_efficientnetv2_s_21k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21k-6337ad01.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843,
+        input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
+    'tf_efficientnetv2_m_21k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21k-361418a2.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843,
+        input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
+    'tf_efficientnetv2_l_21k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21k-91a19ec9.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843,
+        input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
+
+    'tf_efficientnetv2_b0': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b0-c7cc451f.pth',
+        input_size=(3, 192, 192), test_input_size=(3, 224, 224), pool_size=(6, 6)),
+    'tf_efficientnetv2_b1': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b1-be6e41b0.pth',
+        input_size=(3, 192, 192), test_input_size=(3, 240, 240), pool_size=(6, 6), crop_pct=0.882),
+    'tf_efficientnetv2_b2': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b2-847de54e.pth',
+        input_size=(3, 208, 208), test_input_size=(3, 260, 260), pool_size=(7, 7), crop_pct=0.890),
+    'tf_efficientnetv2_b3': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b3-57773f13.pth',
+        input_size=(3, 240, 240), test_input_size=(3, 300, 300), pool_size=(8, 8), crop_pct=0.904),

     'mixnet_s': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth'),
     'mixnet_m': _cfg(
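
Reading sketch (not part of the diff): EfficientNetV2 weights are trained at the smaller `input_size` and evaluated at `test_input_size` with no center-crop shrink.

cfg = default_cfgs['tf_efficientnetv2_s']   # within timm.models.efficientnet
cfg['input_size']       # (3, 300, 300) -> size used during training
cfg['test_input_size']  # (3, 384, 384) -> recommended eval size
cfg['crop_pct']         # 1.0 -> full-image eval, no crop shrink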
@@ -316,13 +385,12 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth'),
 }

-_DEBUG = False


 class EfficientNet(nn.Module):
     """ (Generic) EfficientNet

     A flexible and performant PyTorch implementation of efficient network architectures, including:
+    * EfficientNet-V2 Small, Medium, Large & B0-B3
     * EfficientNet B0-B8, L2
     * EfficientNet-EdgeTPU
     * EfficientNet-CondConv
@@ -333,35 +401,35 @@ class EfficientNet(nn.Module):

     """

-    def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32,
-                 channel_multiplier=1.0, channel_divisor=8, channel_min=None,
-                 output_stride=32, pad_type='', fix_stem=False, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
-                 se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
+    def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32, fix_stem=False,
+                 output_stride=32, pad_type='', round_chs_fn=round_channels, act_layer=None, norm_layer=None,
+                 se_layer=None, drop_rate=0., drop_path_rate=0., global_pool='avg'):
         super(EfficientNet, self).__init__()
-        norm_kwargs = norm_kwargs or {}
-
+        act_layer = act_layer or nn.ReLU
+        norm_layer = norm_layer or nn.BatchNorm2d
+        se_layer = se_layer or SqueezeExcite
         self.num_classes = num_classes
         self.num_features = num_features
         self.drop_rate = drop_rate

         # Stem
         if not fix_stem:
-            stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
+            stem_size = round_chs_fn(stem_size)
         self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
-        self.bn1 = norm_layer(stem_size, **norm_kwargs)
+        self.bn1 = norm_layer(stem_size)
         self.act1 = act_layer(inplace=True)

         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
-            channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
-            norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
+            output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn,
+            act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate)
         self.blocks = nn.Sequential(*builder(stem_size, block_args))
         self.feature_info = builder.features
         head_chs = builder.in_chs

         # Head + Pooling
         self.conv_head = create_conv2d(head_chs, self.num_features, 1, padding=pad_type)
-        self.bn2 = norm_layer(self.num_features, **norm_kwargs)
+        self.bn2 = norm_layer(self.num_features)
         self.act2 = act_layer(inplace=True)
         self.global_pool, self.classifier = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)
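
Calling-convention sketch (not part of the diff; values illustrative): channel scaling becomes a single callable and BN kwargs are bound into the norm layer class.

from functools import partial
import torch.nn as nn

# old style (removed): channel_multiplier=1.2, channel_divisor=8, norm_kwargs=dict(eps=1e-3)
# new style (added):
model = EfficientNet(
    block_args=decode_arch_def(arch_def),                  # arch_def as in the _gen_* helpers below
    round_chs_fn=partial(round_channels, multiplier=1.2),  # replaces multiplier/divisor/min args
    norm_layer=partial(nn.BatchNorm2d, eps=1e-3),          # replaces norm_kwargs
    act_layer=nn.SiLU,
)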
@@ -408,25 +476,27 @@ class EfficientNetFeatures(nn.Module):
     and object detection models.
     """

-    def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
-                 in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
-                 output_stride=32, pad_type='', fix_stem=False, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
-                 se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
+    def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
+                 stem_size=32, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels,
+                 act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
         super(EfficientNetFeatures, self).__init__()
-        norm_kwargs = norm_kwargs or {}
+        act_layer = act_layer or nn.ReLU
+        norm_layer = norm_layer or nn.BatchNorm2d
+        se_layer = se_layer or SqueezeExcite
         self.drop_rate = drop_rate

         # Stem
         if not fix_stem:
-            stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
+            stem_size = round_chs_fn(stem_size)
         self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
-        self.bn1 = norm_layer(stem_size, **norm_kwargs)
+        self.bn1 = norm_layer(stem_size)
         self.act1 = act_layer(inplace=True)

         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
-            channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
-            norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
+            output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn,
+            act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate,
+            feature_location=feature_location)
         self.blocks = nn.Sequential(*builder(stem_size, block_args))
         self.feature_info = FeatureInfo(builder.features, out_indices)
         self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
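
Usage sketch (not part of the diff): EfficientNetFeatures backs timm's `features_only=True` path, so the refactored class is typically reached like this.

import torch
import timm

feat_model = timm.create_model('efficientnetv2_rw_s', features_only=True, out_indices=(0, 1, 2, 3, 4))
feats = feat_model(torch.randn(1, 3, 288, 288))
for f in feats:
    print(f.shape)  # one feature map per out_index, at strides 2/4/8/16/32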
@@ -505,8 +575,8 @@ def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
     model_kwargs = dict(
         block_args=decode_arch_def(arch_def),
         stem_size=32,
-        channel_multiplier=channel_multiplier,
-        norm_kwargs=resolve_bn_args(kwargs),
+        round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
+        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
         **kwargs
     )
     model = _create_effnet(variant, pretrained, **model_kwargs)
@@ -541,8 +611,8 @@ def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
     model_kwargs = dict(
         block_args=decode_arch_def(arch_def),
         stem_size=32,
-        channel_multiplier=channel_multiplier,
-        norm_kwargs=resolve_bn_args(kwargs),
+        round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
+        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
         **kwargs
     )
     model = _create_effnet(variant, pretrained, **model_kwargs)
@@ -570,8 +640,8 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
     model_kwargs = dict(
         block_args=decode_arch_def(arch_def),
         stem_size=8,
-        channel_multiplier=channel_multiplier,
-        norm_kwargs=resolve_bn_args(kwargs),
+        round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
+        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
        **kwargs
     )
     model = _create_effnet(variant, pretrained, **model_kwargs)
@@ -593,13 +663,14 @@ def _gen_mobilenet_v2(
         ['ir_r3_k3_s2_e6_c160'],
         ['ir_r1_k3_s1_e6_c320'],
     ]
+    round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
     model_kwargs = dict(
         block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
-        num_features=1280 if fix_stem_head else round_channels(1280, channel_multiplier, 8, None),
+        num_features=1280 if fix_stem_head else round_chs_fn(1280),
         stem_size=32,
         fix_stem=fix_stem_head,
-        channel_multiplier=channel_multiplier,
-        norm_kwargs=resolve_bn_args(kwargs),
+        round_chs_fn=round_chs_fn,
+        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
         act_layer=resolve_act_layer(kwargs, 'relu6'),
         **kwargs
     )
@@ -629,8 +700,8 @@ def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
         block_args=decode_arch_def(arch_def),
         stem_size=16,
         num_features=1984,  # paper suggests this, but is not 100% clear
-        channel_multiplier=channel_multiplier,
-        norm_kwargs=resolve_bn_args(kwargs),
+        round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
+        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
         **kwargs
     )
     model = _create_effnet(variant, pretrained, **model_kwargs)
@@ -664,8 +735,8 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
     model_kwargs = dict(
         block_args=decode_arch_def(arch_def),
         stem_size=32,
-        channel_multiplier=channel_multiplier,
-        norm_kwargs=resolve_bn_args(kwargs),
+        round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
+        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
         **kwargs
     )
     model = _create_effnet(variant, pretrained, **model_kwargs)
@@ -705,13 +776,14 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
         ['ir_r4_k5_s2_e6_c192_se0.25'],
         ['ir_r1_k3_s1_e6_c320_se0.25'],
     ]
+    round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
     model_kwargs = dict(
         block_args=decode_arch_def(arch_def, depth_multiplier),
-        num_features=round_channels(1280, channel_multiplier, 8, None),
+        num_features=round_chs_fn(1280),
         stem_size=32,
-        channel_multiplier=channel_multiplier,
+        round_chs_fn=round_chs_fn,
         act_layer=resolve_act_layer(kwargs, 'swish'),
-        norm_kwargs=resolve_bn_args(kwargs),
+        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
         **kwargs,
     )
     model = _create_effnet(variant, pretrained, **model_kwargs)
@@ -734,12 +806,13 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0
         ['ir_r4_k5_s1_e8_c144'],
         ['ir_r2_k5_s2_e8_c192'],
     ]
+    round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
     model_kwargs = dict(
         block_args=decode_arch_def(arch_def, depth_multiplier),
-        num_features=round_channels(1280, channel_multiplier, 8, None),
+        num_features=round_chs_fn(1280),
         stem_size=32,
-        channel_multiplier=channel_multiplier,
-        norm_kwargs=resolve_bn_args(kwargs),
+        round_chs_fn=round_chs_fn,
+        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
         act_layer=resolve_act_layer(kwargs, 'relu'),
         **kwargs,
     )
@@ -764,12 +837,13 @@ def _gen_efficientnet_condconv(
     ]
     # NOTE unlike official impl, this one uses `cc<x>` option where x is the base number of experts for each stage and
     # the expert_multiplier increases that on a per-model basis as with depth/channel multipliers
+    round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
     model_kwargs = dict(
         block_args=decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier),
-        num_features=round_channels(1280, channel_multiplier, 8, None),
+        num_features=round_chs_fn(1280),
         stem_size=32,
-        channel_multiplier=channel_multiplier,
-        norm_kwargs=resolve_bn_args(kwargs),
+        round_chs_fn=round_chs_fn,
+        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
         act_layer=resolve_act_layer(kwargs, 'swish'),
         **kwargs,
     )
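
Plumbing sketch (not part of the diff) for the pattern shared by all the `_gen_*` changes above — `resolve_bn_args` pops BN overrides out of kwargs so the remainder passes through to the model:

from functools import partial
import torch.nn as nn

kwargs = dict(bn_eps=1e-3, drop_rate=0.2)
bn_args = resolve_bn_args(kwargs)                # pops bn_eps/bn_momentum -> {'eps': 0.001}
norm_layer = partial(nn.BatchNorm2d, **bn_args)  # the new norm_layer=... binding
assert kwargs == {'drop_rate': 0.2}              # remaining kwargs reach the model untouched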
@@ -809,45 +883,137 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0
         num_features=1280,
         stem_size=32,
         fix_stem=True,
-        channel_multiplier=channel_multiplier,
+        round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
         act_layer=resolve_act_layer(kwargs, 'relu6'),
-        norm_kwargs=resolve_bn_args(kwargs),
+        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
         **kwargs,
     )
     model = _create_effnet(variant, pretrained, **model_kwargs)
     return model


-def _gen_efficientnet_v2s(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
-    """ Creates an EfficientNet-V2s model
+def _gen_efficientnetv2_base(
+        variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
+    """ Creates an EfficientNet-V2 base model

-    NOTE: this is a preliminary definition based on paper, awaiting official code release for details
-    and weights
-
-    Ref impl:
+    Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
+    Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
+    """
+    arch_def = [
-        # FIXME it's not clear if the FusedMBConv layers have SE enabled for the Small variant,
-        # Table 4 suggests no. 23.94M params w/o, 23.98 with which is closer to 24M.
-        # ['er_r2_k3_s1_e1_c24_se0.25'],
-        # ['er_r4_k3_s2_e4_c48_se0.25'],
-        # ['er_r4_k3_s2_e4_c64_se0.25'],
-        ['er_r2_k3_s1_e1_c24'],
+        ['cn_r1_k3_s1_e1_c16_skip'],
+        ['er_r2_k3_s2_e4_c32'],
+        ['er_r2_k3_s2_e4_c48'],
+        ['ir_r3_k3_s2_e4_c96_se0.25'],
+        ['ir_r5_k3_s1_e6_c112_se0.25'],
+        ['ir_r8_k3_s2_e6_c192_se0.25'],
+    ]
+    round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.)
+    model_kwargs = dict(
+        block_args=decode_arch_def(arch_def, depth_multiplier),
+        num_features=round_chs_fn(1280),
+        stem_size=32,
+        round_chs_fn=round_chs_fn,
+        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+        act_layer=resolve_act_layer(kwargs, 'silu'),
+        **kwargs,
+    )
+    model = _create_effnet(variant, pretrained, **model_kwargs)
+    return model
+
+
+def _gen_efficientnetv2_s(
+        variant, channel_multiplier=1.0, depth_multiplier=1.0, rw=False, pretrained=False, **kwargs):
+    """ Creates an EfficientNet-V2 Small model
+
+    Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
     Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
+
+    NOTE: the `rw` flag sets up the 'small' variant to behave like my initial v2 small model,
+    from before the reference impl was released.
     """
     arch_def = [
+        ['cn_r2_k3_s1_e1_c24_skip'],
         ['er_r4_k3_s2_e4_c48'],
         ['er_r4_k3_s2_e4_c64'],
         ['ir_r6_k3_s2_e4_c128_se0.25'],
         ['ir_r9_k3_s1_e6_c160_se0.25'],
-        ['ir_r15_k3_s2_e6_c272_se0.25'],
+        ['ir_r15_k3_s2_e6_c256_se0.25'],
     ]
+    num_features = 1280
+    if rw:
+        # my original variant, based on the paper figure; differs from the official release
+        arch_def[0] = ['er_r2_k3_s1_e1_c24']
+        arch_def[-1] = ['ir_r15_k3_s2_e6_c272_se0.25']
+        num_features = 1792
+
+    round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
     model_kwargs = dict(
         block_args=decode_arch_def(arch_def, depth_multiplier),
-        num_features=round_channels(1792, channel_multiplier, 8, None),
+        num_features=round_chs_fn(num_features),
         stem_size=24,
-        channel_multiplier=channel_multiplier,
-        norm_kwargs=resolve_bn_args(kwargs),
-        act_layer=resolve_act_layer(kwargs, 'silu'),  # FIXME this is an assumption, paper does not mention
+        round_chs_fn=round_chs_fn,
+        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+        act_layer=resolve_act_layer(kwargs, 'silu'),
         **kwargs,
     )
     model = _create_effnet(variant, pretrained, **model_kwargs)
     return model


+def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
+    """ Creates an EfficientNet-V2 Medium model
+
+    Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
+    Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
+    """
+
+    arch_def = [
+        ['cn_r3_k3_s1_e1_c24_skip'],
+        ['er_r5_k3_s2_e4_c48'],
+        ['er_r5_k3_s2_e4_c80'],
+        ['ir_r7_k3_s2_e4_c160_se0.25'],
+        ['ir_r14_k3_s1_e6_c176_se0.25'],
+        ['ir_r18_k3_s2_e6_c304_se0.25'],
+        ['ir_r5_k3_s1_e6_c512_se0.25'],
+    ]
+
+    model_kwargs = dict(
+        block_args=decode_arch_def(arch_def, depth_multiplier),
+        num_features=1280,
+        stem_size=24,
+        round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
+        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+        act_layer=resolve_act_layer(kwargs, 'silu'),
+        **kwargs,
+    )
+    model = _create_effnet(variant, pretrained, **model_kwargs)
+    return model
+
+
+def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
+    """ Creates an EfficientNet-V2 Large model
+
+    Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
+    Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
+    """
+
+    arch_def = [
+        ['cn_r4_k3_s1_e1_c32_skip'],
+        ['er_r7_k3_s2_e4_c64'],
+        ['er_r7_k3_s2_e4_c96'],
+        ['ir_r10_k3_s2_e4_c192_se0.25'],
+        ['ir_r19_k3_s1_e6_c224_se0.25'],
+        ['ir_r25_k3_s2_e6_c384_se0.25'],
+        ['ir_r7_k3_s1_e6_c640_se0.25'],
+    ]
+
+    model_kwargs = dict(
+        block_args=decode_arch_def(arch_def, depth_multiplier),
+        num_features=1280,
+        stem_size=32,
+        round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
+        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+        act_layer=resolve_act_layer(kwargs, 'silu'),
+        **kwargs,
+    )
+    model = _create_effnet(variant, pretrained, **model_kwargs)
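
For readers new to the arch_def notation above, a decoding sketch (not part of the diff; the authoritative parser is `_decode_block_str` in efficientnet_builder.py, changed further below):

# 'ir_r15_k3_s2_e6_c256_se0.25' reads as:
#   ir     -> InvertedResidual block (er = EdgeResidual, ds = DepthwiseSeparable, cn = ConvBnAct)
#   r15    -> repeat 15 times in this stage
#   k3     -> 3x3 (depthwise) kernel
#   s2     -> stride 2 on the stage's first block
#   e6     -> expansion ratio 6.0
#   c256   -> 256 output channels
#   se0.25 -> squeeze-excite ratio 0.25
# '_skip' / '_noskip' suffixes map to the new tri-state skip flag introduced below.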
@@ -879,8 +1045,8 @@ def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
         block_args=decode_arch_def(arch_def),
         num_features=1536,
         stem_size=16,
-        channel_multiplier=channel_multiplier,
-        norm_kwargs=resolve_bn_args(kwargs),
+        round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
+        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
         **kwargs
     )
     model = _create_effnet(variant, pretrained, **model_kwargs)
@@ -912,8 +1078,8 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai
         block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'),
         num_features=1536,
         stem_size=24,
-        channel_multiplier=channel_multiplier,
-        norm_kwargs=resolve_bn_args(kwargs),
+        round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
+        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
         **kwargs
     )
     model = _create_effnet(variant, pretrained, **model_kwargs)
@@ -1290,13 +1456,35 @@ def efficientnet_b3_pruned(pretrained=False, **kwargs):


 @register_model
-def efficientnet_v2s(pretrained=False, **kwargs):
-    """ EfficientNet-V2 Small. """
-    model = _gen_efficientnet_v2s(
-        'efficientnet_v2s', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+def efficientnetv2_rw_s(pretrained=False, **kwargs):
+    """ EfficientNet-V2 Small.
+    NOTE: This is my initial (pre official code release) version, with some differences.
+    See efficientnetv2_s and tf_efficientnetv2_s for versions that match the official model
+    (with PyTorch and TF padding, respectively).
+    """
+    model = _gen_efficientnetv2_s('efficientnetv2_rw_s', rw=True, pretrained=pretrained, **kwargs)
     return model


+@register_model
+def efficientnetv2_s(pretrained=False, **kwargs):
+    """ EfficientNet-V2 Small. """
+    model = _gen_efficientnetv2_s('efficientnetv2_s', pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnetv2_m(pretrained=False, **kwargs):
+    """ EfficientNet-V2 Medium. """
+    model = _gen_efficientnetv2_m('efficientnetv2_m', pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def efficientnetv2_l(pretrained=False, **kwargs):
+    """ EfficientNet-V2 Large. """
+    model = _gen_efficientnetv2_l('efficientnetv2_l', pretrained=pretrained, **kwargs)
+    return model
+
+
 @register_model
 def tf_efficientnet_b0(pretrained=False, **kwargs):
@@ -1708,6 +1896,127 @@ def tf_efficientnet_lite4(pretrained=False, **kwargs):
     return model


+@register_model
+def tf_efficientnetv2_s(pretrained=False, **kwargs):
+    """ EfficientNet-V2 Small. Tensorflow compatible variant """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnetv2_s('tf_efficientnetv2_s', pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_efficientnetv2_m(pretrained=False, **kwargs):
+    """ EfficientNet-V2 Medium. Tensorflow compatible variant """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnetv2_m('tf_efficientnetv2_m', pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_efficientnetv2_l(pretrained=False, **kwargs):
+    """ EfficientNet-V2 Large. Tensorflow compatible variant """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnetv2_l('tf_efficientnetv2_l', pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_efficientnetv2_s_21kft1k(pretrained=False, **kwargs):
+    """ EfficientNet-V2 Small. Tensorflow compatible variant """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnetv2_s('tf_efficientnetv2_s_21kft1k', pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_efficientnetv2_m_21kft1k(pretrained=False, **kwargs):
+    """ EfficientNet-V2 Medium. Tensorflow compatible variant """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnetv2_m('tf_efficientnetv2_m_21kft1k', pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_efficientnetv2_l_21kft1k(pretrained=False, **kwargs):
+    """ EfficientNet-V2 Large. Tensorflow compatible variant """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnetv2_l('tf_efficientnetv2_l_21kft1k', pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_efficientnetv2_s_21k(pretrained=False, **kwargs):
+    """ EfficientNet-V2 Small w/ ImageNet-21k pretrained weights. Tensorflow compatible variant """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnetv2_s('tf_efficientnetv2_s_21k', pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_efficientnetv2_m_21k(pretrained=False, **kwargs):
+    """ EfficientNet-V2 Medium w/ ImageNet-21k pretrained weights. Tensorflow compatible variant """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnetv2_m('tf_efficientnetv2_m_21k', pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_efficientnetv2_l_21k(pretrained=False, **kwargs):
+    """ EfficientNet-V2 Large w/ ImageNet-21k pretrained weights. Tensorflow compatible variant """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnetv2_l('tf_efficientnetv2_l_21k', pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_efficientnetv2_b0(pretrained=False, **kwargs):
+    """ EfficientNet-V2-B0. Tensorflow compatible variant """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnetv2_base('tf_efficientnetv2_b0', pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_efficientnetv2_b1(pretrained=False, **kwargs):
+    """ EfficientNet-V2-B1. Tensorflow compatible variant """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnetv2_base(
+        'tf_efficientnetv2_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_efficientnetv2_b2(pretrained=False, **kwargs):
+    """ EfficientNet-V2-B2. Tensorflow compatible variant """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnetv2_base(
+        'tf_efficientnetv2_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tf_efficientnetv2_b3(pretrained=False, **kwargs):
+    """ EfficientNet-V2-B3. Tensorflow compatible variant """
+    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnetv2_base(
+        'tf_efficientnetv2_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+    return model
+
+
 @register_model
 def mixnet_s(pretrained=False, **kwargs):
     """Creates a MixNet Small model.
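
Sketch (not part of the diff) of what every TF-compatible wrapper above actually toggles:

kwargs = {}
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT  # 1e-3, vs the PyTorch BatchNorm2d default of 1e-5
kwargs['pad_type'] = 'same'           # TF 'SAME'-style padding via create_conv2d
# resolve_bn_args() later pops bn_eps into the partial(nn.BatchNorm2d, eps=1e-3) binding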
timm/models/efficientnet_blocks.py
@@ -7,106 +7,34 @@ import torch
 import torch.nn as nn
 from torch.nn import functional as F

-from .layers import create_conv2d, drop_path, get_act_layer
+from .layers import create_conv2d, drop_path, make_divisible
 from .layers.activations import sigmoid

-# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
-# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
-# NOTE: momentum varies btw .99 and .9997 depending on source
-# .99 in official TF TPU impl
-# .9997 (/w .999 in search space) for paper
-BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
-BN_EPS_TF_DEFAULT = 1e-3
-_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
-
-
-def get_bn_args_tf():
-    return _BN_ARGS_TF.copy()
-
-
-def resolve_bn_args(kwargs):
-    bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
-    bn_momentum = kwargs.pop('bn_momentum', None)
-    if bn_momentum is not None:
-        bn_args['momentum'] = bn_momentum
-    bn_eps = kwargs.pop('bn_eps', None)
-    if bn_eps is not None:
-        bn_args['eps'] = bn_eps
-    return bn_args
-
-
-_SE_ARGS_DEFAULT = dict(
-    gate_fn=sigmoid,
-    act_layer=None,
-    reduce_mid=False,
-    divisor=1)
-
-
-def resolve_se_args(kwargs, in_chs, act_layer=None):
-    se_kwargs = kwargs.copy() if kwargs is not None else {}
-    # fill in args that aren't specified with the defaults
-    for k, v in _SE_ARGS_DEFAULT.items():
-        se_kwargs.setdefault(k, v)
-    # some models, like MobileNetV3, calculate SE reduction chs from the containing block's mid_chs instead of in_chs
-    if not se_kwargs.pop('reduce_mid'):
-        se_kwargs['reduced_base_chs'] = in_chs
-    # act_layer override, if it remains None, the containing block's act_layer will be used
-    if se_kwargs['act_layer'] is None:
-        assert act_layer is not None
-        se_kwargs['act_layer'] = act_layer
-    return se_kwargs
-
-
-def resolve_act_layer(kwargs, default='relu'):
-    act_layer = kwargs.pop('act_layer', default)
-    if isinstance(act_layer, str):
-        act_layer = get_act_layer(act_layer)
-    return act_layer
-
-
-def make_divisible(v, divisor=8, min_value=None):
-    min_value = min_value or divisor
-    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
-    # Make sure that round down does not go down by more than 10%.
-    if new_v < 0.9 * v:
-        new_v += divisor
-    return new_v
-
-
-def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
-    """Round number of filters based on width multiplier."""
-    if not multiplier:
-        return channels
-    channels *= multiplier
-    return make_divisible(channels, divisor, channel_min)
-
-
-class ChannelShuffle(nn.Module):
-    # FIXME haven't used yet
-    def __init__(self, groups):
-        super(ChannelShuffle, self).__init__()
-        self.groups = groups
-
-    def forward(self, x):
-        """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]"""
-        N, C, H, W = x.size()
-        g = self.groups
-        assert C % g == 0, "Incompatible group size {} for input channel {}".format(g, C)
-        return (
-            x.view(N, g, int(C / g), H, W)
-            .permute(0, 2, 1, 3, 4)
-            .contiguous()
-            .view(N, C, H, W)
-        )
+__all__ = [
+    'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual']


 class SqueezeExcite(nn.Module):
-    def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
-                 act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1, **_):
+    """ Squeeze-and-Excitation w/ specific features for EfficientNet/MobileNet family
+
+    Args:
+        in_chs (int): input channels to layer
+        se_ratio (float): ratio of squeeze reduction
+        act_layer (nn.Module): activation layer of containing block
+        gate_fn (Callable): attention gate function
+        block_in_chs (int): input channels of containing block (for calculating reduction from)
+        reduce_from_block (bool): calculate reduction from block input channels if True
+        force_act_layer (nn.Module): override block's activation fn if this is set/bound
+        divisor (int): make reduction channels divisible by this
+    """
+
+    def __init__(
+            self, in_chs, se_ratio=0.25, act_layer=nn.ReLU, gate_fn=sigmoid,
+            block_in_chs=None, reduce_from_block=True, force_act_layer=None, divisor=1):
         super(SqueezeExcite, self).__init__()
-        reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
+        reduced_chs = (block_in_chs or in_chs) if reduce_from_block else in_chs
+        reduced_chs = make_divisible(reduced_chs * se_ratio, divisor)
+        act_layer = force_act_layer or act_layer
         self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
         self.act1 = act_layer(inplace=True)
         self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
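
Usage sketch (not part of the diff) for the refactored SqueezeExcite, assuming the new signature shown above:

import torch
from timm.models.efficientnet_blocks import SqueezeExcite

se = SqueezeExcite(64, se_ratio=0.25)  # squeeze width = make_divisible(64 * 0.25) = 16 chs
x = torch.randn(2, 64, 32, 32)
y = se(x)                              # channel-wise re-weighting, shape preserved
assert y.shape == x.shape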
@@ -121,13 +49,16 @@ class SqueezeExcite(nn.Module):


 class ConvBnAct(nn.Module):
-    def __init__(self, in_chs, out_chs, kernel_size,
-                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU,
-                 norm_layer=nn.BatchNorm2d, norm_kwargs=None):
+    """ Conv + Norm Layer + Activation w/ optional skip connection
+    """
+    def __init__(
+            self, in_chs, out_chs, kernel_size, stride=1, dilation=1, pad_type='',
+            skip=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_path_rate=0.):
         super(ConvBnAct, self).__init__()
-        norm_kwargs = norm_kwargs or {}
+        self.has_residual = skip and stride == 1 and in_chs == out_chs
+        self.drop_path_rate = drop_path_rate
         self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
-        self.bn1 = norm_layer(out_chs, **norm_kwargs)
+        self.bn1 = norm_layer(out_chs)
         self.act1 = act_layer(inplace=True)

     def feature_info(self, location):
@@ -138,9 +69,14 @@ class ConvBnAct(nn.Module):
         return info

     def forward(self, x):
+        shortcut = x
         x = self.conv(x)
         x = self.bn1(x)
         x = self.act1(x)
+        if self.has_residual:
+            if self.drop_path_rate > 0.:
+                x = drop_path(x, self.drop_path_rate, self.training)
+            x += shortcut
         return x
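
Behaviour sketch (not part of the diff) for the new optional residual — this is what the EfficientNetV2 'cn_..._skip' stage strings enable:

import torch
from timm.models.efficientnet_blocks import ConvBnAct

cba = ConvBnAct(16, 16, kernel_size=3, skip=True)  # stride 1 and in==out -> has_residual
y = cba(torch.randn(1, 16, 32, 32))                # y = act(bn(conv(x))) + x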
@@ -149,31 +85,26 @@ class DepthwiseSeparableConv(nn.Module):
     Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion
     (factor of 1.0). This is an alternative to having an IR with an optional first pw conv.
     """
-    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
-                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
-                 pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None,
-                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0.):
+    def __init__(
+            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
+            noskip=False, pw_kernel_size=1, pw_act=False, se_ratio=0.,
+            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
         super(DepthwiseSeparableConv, self).__init__()
-        norm_kwargs = norm_kwargs or {}
-        has_se = se_ratio is not None and se_ratio > 0.
+        has_se = se_layer is not None and se_ratio > 0.
         self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
         self.has_pw_act = pw_act  # activation after point-wise conv
         self.drop_path_rate = drop_path_rate

         self.conv_dw = create_conv2d(
             in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
-        self.bn1 = norm_layer(in_chs, **norm_kwargs)
+        self.bn1 = norm_layer(in_chs)
         self.act1 = act_layer(inplace=True)

         # Squeeze-and-excitation
-        if has_se:
-            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
-            self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs)
-        else:
-            self.se = None
+        self.se = se_layer(in_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity()

         self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
-        self.bn2 = norm_layer(out_chs, **norm_kwargs)
+        self.bn2 = norm_layer(out_chs)
         self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()

     def feature_info(self, location):
@@ -190,8 +121,7 @@ class DepthwiseSeparableConv(nn.Module):
         x = self.bn1(x)
         x = self.act1(x)

-        if self.se is not None:
-            x = self.se(x)
+        x = self.se(x)

         x = self.conv_pw(x)
         x = self.bn2(x)
@@ -214,41 +144,36 @@ class InvertedResidual(nn.Module):
      * MobileNet-V3 - https://arxiv.org/abs/1905.02244
     """

-    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
-                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
-                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
-                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
-                 conv_kwargs=None, drop_path_rate=0.):
+    def __init__(
+            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
+            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
+            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
         super(InvertedResidual, self).__init__()
-        norm_kwargs = norm_kwargs or {}
         conv_kwargs = conv_kwargs or {}
         mid_chs = make_divisible(in_chs * exp_ratio)
-        has_se = se_ratio is not None and se_ratio > 0.
+        has_se = se_layer is not None and se_ratio > 0.
         self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
         self.drop_path_rate = drop_path_rate

         # Point-wise expansion
         self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
-        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
+        self.bn1 = norm_layer(mid_chs)
         self.act1 = act_layer(inplace=True)

         # Depth-wise convolution
         self.conv_dw = create_conv2d(
             mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
             padding=pad_type, depthwise=True, **conv_kwargs)
-        self.bn2 = norm_layer(mid_chs, **norm_kwargs)
+        self.bn2 = norm_layer(mid_chs)
         self.act2 = act_layer(inplace=True)

         # Squeeze-and-excitation
-        if has_se:
-            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
-            self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
-        else:
-            self.se = None
+        self.se = se_layer(
+            mid_chs, se_ratio=se_ratio, act_layer=act_layer, block_in_chs=in_chs) if has_se else nn.Identity()

         # Point-wise linear projection
         self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
-        self.bn3 = norm_layer(out_chs, **norm_kwargs)
+        self.bn3 = norm_layer(out_chs)

     def feature_info(self, location):
         if location == 'expansion':  # after SE, input to PWL
@@ -271,8 +196,7 @@ class InvertedResidual(nn.Module):
         x = self.act2(x)

         # Squeeze-and-excitation
-        if self.se is not None:
-            x = self.se(x)
+        x = self.se(x)

         # Point-wise linear projection
         x = self.conv_pwl(x)
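
Worked numbers (not part of the diff) for the SE hookup above — the squeeze width now derives from the block's input channels rather than the expanded mid channels:

# InvertedResidual with in_chs=24 and exp_ratio=4.0 -> mid_chs = 96.
# SE operates on the 96-ch tensor but reduces relative to the 24-ch block input:
se = SqueezeExcite(96, se_ratio=0.25, block_in_chs=24)  # reduced chs = 24 * 0.25 = 6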
@@ -289,11 +213,10 @@ class InvertedResidual(nn.Module):
 class CondConvResidual(InvertedResidual):
     """ Inverted residual block w/ CondConv routing"""

-    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
-                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
-                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
-                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
-                 num_experts=0, drop_path_rate=0.):
+    def __init__(
+            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
+            noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
+            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):

         self.num_experts = num_experts
         conv_kwargs = dict(num_experts=self.num_experts)
@@ -301,9 +224,8 @@ class CondConvResidual(InvertedResidual):
         super(CondConvResidual, self).__init__(
             in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type,
             act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
-            pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs,
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs,
-            drop_path_rate=drop_path_rate)
+            pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_layer=se_layer,
+            norm_layer=norm_layer, conv_kwargs=conv_kwargs, drop_path_rate=drop_path_rate)

         self.routing_fn = nn.Linear(in_chs, self.num_experts)

@@ -325,8 +247,7 @@ class CondConvResidual(InvertedResidual):
         x = self.act2(x)

         # Squeeze-and-excitation
-        if self.se is not None:
-            x = self.se(x)
+        x = self.se(x)

         # Point-wise linear projection
         x = self.conv_pwl(x, routing_weights)
@@ -351,36 +272,32 @@ class EdgeResidual(nn.Module):
      * EfficientNet-V2 - https://arxiv.org/abs/2104.00298
     """

-    def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
-                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
-                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
-                 drop_path_rate=0.):
+    def __init__(
+            self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, pad_type='',
+            force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, se_ratio=0.,
+            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
         super(EdgeResidual, self).__init__()
-        norm_kwargs = norm_kwargs or {}
-        if fake_in_chs > 0:
-            mid_chs = make_divisible(fake_in_chs * exp_ratio)
+        if force_in_chs > 0:
+            mid_chs = make_divisible(force_in_chs * exp_ratio)
         else:
             mid_chs = make_divisible(in_chs * exp_ratio)
-        has_se = se_ratio is not None and se_ratio > 0.
+        has_se = se_layer is not None and se_ratio > 0.
         self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
         self.drop_path_rate = drop_path_rate

         # Expansion convolution
         self.conv_exp = create_conv2d(
             in_chs, mid_chs, exp_kernel_size, stride=stride, dilation=dilation, padding=pad_type)
-        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
+        self.bn1 = norm_layer(mid_chs)
         self.act1 = act_layer(inplace=True)

         # Squeeze-and-excitation
-        if has_se:
-            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
-            self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
-        else:
-            self.se = None
+        self.se = SqueezeExcite(
+            mid_chs, se_ratio=se_ratio, act_layer=act_layer, block_in_chs=in_chs) if has_se else nn.Identity()

         # Point-wise linear projection
         self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)
-        self.bn2 = norm_layer(out_chs, **norm_kwargs)
+        self.bn2 = norm_layer(out_chs)

     def feature_info(self, location):
         if location == 'expansion':  # after SE, before PWL
@@ -398,8 +315,7 @@ class EdgeResidual(nn.Module):
         x = self.act1(x)

         # Squeeze-and-excitation
-        if self.se is not None:
-            x = self.se(x)
+        x = self.se(x)

         # Point-wise linear projection
         x = self.conv_pwl(x)
timm/models/efficientnet_builder.py
@@ -14,13 +14,55 @@ from copy import deepcopy
 import torch.nn as nn

 from .efficientnet_blocks import *
-from .layers import CondConv2d, get_condconv_initializer
+from .layers import CondConv2d, get_condconv_initializer, get_act_layer, make_divisible

-__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights"]
+__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
+           'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']

 _logger = logging.getLogger(__name__)


+_DEBUG_BUILDER = False
+
+# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
+# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
+# NOTE: momentum varies btw .99 and .9997 depending on source
+# .99 in official TF TPU impl
+# .9997 (/w .999 in search space) for paper
+BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
+BN_EPS_TF_DEFAULT = 1e-3
+_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
+
+
+def get_bn_args_tf():
+    return _BN_ARGS_TF.copy()
+
+
+def resolve_bn_args(kwargs):
+    bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
+    bn_momentum = kwargs.pop('bn_momentum', None)
+    if bn_momentum is not None:
+        bn_args['momentum'] = bn_momentum
+    bn_eps = kwargs.pop('bn_eps', None)
+    if bn_eps is not None:
+        bn_args['eps'] = bn_eps
+    return bn_args
+
+
+def resolve_act_layer(kwargs, default='relu'):
+    act_layer = kwargs.pop('act_layer', default)
+    if isinstance(act_layer, str):
+        act_layer = get_act_layer(act_layer)
+    return act_layer
+
+
+def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9):
+    """Round number of filters based on width multiplier."""
+    if not multiplier:
+        return channels
+    return make_divisible(channels * multiplier, divisor, channel_min, round_limit=round_limit)
+
+
 def _log_info_if(msg, condition):
     if condition:
         _logger.info(msg)
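
Worked examples (not part of the diff) for the relocated round_channels:

round_channels(48, multiplier=1.2)    # make_divisible(57.6, 8)   -> 56
round_channels(1280, multiplier=1.1)  # make_divisible(1408.0, 8) -> 1408
round_channels(32)                    # multiplier 1.0, already divisible -> 32
# round_limit=0. (used by _gen_efficientnetv2_base) disables make_divisible's
# "never round down by more than 10%" bump-up.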
@@ -63,11 +105,13 @@ def _decode_block_str(block_str):
     block_type = ops[0]  # take the block type off the front
     ops = ops[1:]
     options = {}
-    noskip = False
+    skip = None
     for op in ops:
         # string options being checked on individual basis, combine if they grow
         if op == 'noskip':
-            noskip = True
+            skip = False  # force no skip connection
+        elif op == 'skip':
+            skip = True  # force a skip connection
         elif op.startswith('n'):
             # activation fn
             key = op[0]
@@ -94,7 +138,7 @@ def _decode_block_str(block_str):
     act_layer = options['n'] if 'n' in options else None
     exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
     pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
-    fake_in_chs = int(options['fc']) if 'fc' in options else 0  # FIXME hack to deal with in_chs issue in TPU def
+    force_in_chs = int(options['fc']) if 'fc' in options else 0  # FIXME hack to deal with in_chs issue in TPU def

     num_repeat = int(options['r'])
     # each type of block has different valid arguments, fill accordingly
@@ -106,10 +150,10 @@ def _decode_block_str(block_str):
         pw_kernel_size=pw_kernel_size,
         out_chs=int(options['c']),
         exp_ratio=float(options['e']),
-        se_ratio=float(options['se']) if 'se' in options else None,
+        se_ratio=float(options['se']) if 'se' in options else 0.,
         stride=int(options['s']),
         act_layer=act_layer,
-        noskip=noskip,
+        noskip=skip is False,
     )
     if 'cc' in options:
         block_args['num_experts'] = int(options['cc'])
@@ -119,11 +163,11 @@ def _decode_block_str(block_str):
         dw_kernel_size=_parse_ksize(options['k']),
         pw_kernel_size=pw_kernel_size,
         out_chs=int(options['c']),
-        se_ratio=float(options['se']) if 'se' in options else None,
+        se_ratio=float(options['se']) if 'se' in options else 0.,
         stride=int(options['s']),
         act_layer=act_layer,
         pw_act=block_type == 'dsa',
-        noskip=block_type == 'dsa' or noskip,
+        noskip=block_type == 'dsa' or skip is False,
     )
 elif block_type == 'er':
     block_args = dict(
@@ -132,11 +176,11 @@ def _decode_block_str(block_str):
         pw_kernel_size=pw_kernel_size,
         out_chs=int(options['c']),
         exp_ratio=float(options['e']),
-        fake_in_chs=fake_in_chs,
-        se_ratio=float(options['se']) if 'se' in options else None,
+        force_in_chs=force_in_chs,
+        se_ratio=float(options['se']) if 'se' in options else 0.,
         stride=int(options['s']),
         act_layer=act_layer,
-        noskip=noskip,
+        noskip=skip is False,
     )
 elif block_type == 'cn':
     block_args = dict(
@@ -145,6 +189,7 @@ def _decode_block_str(block_str):
         out_chs=int(options['c']),
         stride=int(options['s']),
         act_layer=act_layer,
+        skip=skip is True,
     )
 else:
     assert False, 'Unknown block type (%s)' % block_type
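
Usage sketch (not part of the diff), assuming the post-commit import location:

from timm.models.efficientnet_builder import decode_arch_def

block_args = decode_arch_def([['cn_r2_k3_s1_e1_c24_skip'], ['ir_r4_k3_s2_e6_c64_se0.25']])
print(block_args[0][0])    # {'block_type': 'cn', 'out_chs': 24, 'stride': 1, 'skip': True, ...}
print(len(block_args[1]))  # 4 repeats expanded; stride 2 on the first block, 1 on the rest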
@ -219,19 +264,14 @@ class EfficientNetBuilder:
    https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py

    """
    def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                 output_stride=32, pad_type='', act_layer=None, se_kwargs=None,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0., feature_location='',
                 verbose=False):
        self.channel_multiplier = channel_multiplier
        self.channel_divisor = channel_divisor
        self.channel_min = channel_min
    def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels,
                 act_layer=None, norm_layer=None, se_layer=None, drop_path_rate=0., feature_location=''):
        self.output_stride = output_stride
        self.pad_type = pad_type
        self.round_chs_fn = round_chs_fn
        self.act_layer = act_layer
        self.se_kwargs = se_kwargs
        self.norm_layer = norm_layer
        self.norm_kwargs = norm_kwargs
        self.se_layer = se_layer
        self.drop_path_rate = drop_path_rate
        if feature_location == 'depthwise':
            # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
@ -239,45 +279,39 @@ class EfficientNetBuilder:
            feature_location = 'expansion'
        self.feature_location = feature_location
        assert feature_location in ('bottleneck', 'expansion', '')
        self.verbose = verbose
        self.verbose = _DEBUG_BUILDER

        # state updated during build, consumed by model
        self.in_chs = None
        self.features = []

    def _round_channels(self, chs):
        return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
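The net effect of this hunk: the three channel-rounding knobs (`channel_multiplier`, `channel_divisor`, `channel_min`) collapse into a single `round_chs_fn` callable. A sketch of how a caller now binds a multiplier, assuming the post-refactor `round_channels` accepts a `multiplier` keyword (the `partial(round_channels, multiplier=...)` pattern also appears in the MobileNetV3 hunks further down):

from functools import partial

from timm.models.efficientnet_builder import round_channels

# default: no width scaling, round to a hardware-friendly multiple
round_chs_fn = round_channels
# width-scaled variant, e.g. a 1.25x model
round_chs_fn = partial(round_channels, multiplier=1.25)
print(round_chs_fn(24))  # 24 * 1.25 = 30.0 -> rounded up to 32 (divisor 8)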
    def _make_block(self, ba, block_idx, block_count):
        drop_path_rate = self.drop_path_rate * block_idx / block_count
        bt = ba.pop('block_type')
        ba['in_chs'] = self.in_chs
        ba['out_chs'] = self._round_channels(ba['out_chs'])
        if 'fake_in_chs' in ba and ba['fake_in_chs']:
            # FIXME this is a hack to work around mismatch in origin impl input filters
            ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
        ba['norm_layer'] = self.norm_layer
        ba['norm_kwargs'] = self.norm_kwargs
        ba['out_chs'] = self.round_chs_fn(ba['out_chs'])
        if 'force_in_chs' in ba and ba['force_in_chs']:
            # NOTE this is a hack to work around mismatch in TF EdgeEffNet impl
            ba['force_in_chs'] = self.round_chs_fn(ba['force_in_chs'])
        ba['pad_type'] = self.pad_type
        # block act fn overrides the model default
        ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
        assert ba['act_layer'] is not None
        if bt == 'ir':
        ba['norm_layer'] = self.norm_layer
        if bt != 'cn':
            ba['se_layer'] = self.se_layer
            ba['drop_path_rate'] = drop_path_rate
            ba['se_kwargs'] = self.se_kwargs

        if bt == 'ir':
            _log_info_if('  InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            if ba.get('num_experts', 0) > 0:
                block = CondConvResidual(**ba)
            else:
                block = InvertedResidual(**ba)
        elif bt == 'ds' or bt == 'dsa':
            ba['drop_path_rate'] = drop_path_rate
            ba['se_kwargs'] = self.se_kwargs
            _log_info_if('  DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            block = DepthwiseSeparableConv(**ba)
        elif bt == 'er':
            ba['drop_path_rate'] = drop_path_rate
            ba['se_kwargs'] = self.se_kwargs
            _log_info_if('  EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            block = EdgeResidual(**ba)
        elif bt == 'cn':
@ -285,8 +319,8 @@ class EfficientNetBuilder:
            block = ConvBnAct(**ba)
        else:
            assert False, 'Unknown block type (%s) while building model.' % bt
        self.in_chs = ba['out_chs']  # update in_chs for arg of next block

        self.in_chs = ba['out_chs']  # update in_chs for arg of next block
        return block

    def __call__(self, in_chs, model_block_args):
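Taken together, a rough sketch of how the refactored builder is driven, using only the constructor kwargs and `__call__(in_chs, model_block_args)` signature visible above (the arch_def, activation choice, and module paths are illustrative assumptions, not from the commit):

from functools import partial

import torch.nn as nn

from timm.models.efficientnet_blocks import SqueezeExcite
from timm.models.efficientnet_builder import (
    EfficientNetBuilder, decode_arch_def, round_channels)

# two toy stages: one depthwise-separable, one inverted-residual stage
arch_def = [['ds_r1_k3_s1_e1_c16_se0.25'], ['ir_r2_k3_s2_e6_c24_se0.25']]
builder = EfficientNetBuilder(
    output_stride=32, pad_type='', round_chs_fn=partial(round_channels, multiplier=1.0),
    act_layer=nn.SiLU, norm_layer=nn.BatchNorm2d, se_layer=SqueezeExcite)
# __call__ returns the built stage modules; models wrap them in a Sequential
blocks = nn.Sequential(*builder(32, decode_arch_def(arch_def)))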
@ -13,8 +13,8 @@ import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .layers import SelectAdaptivePool2d, Linear, hard_sigmoid
from .efficientnet_blocks import SqueezeExcite, ConvBnAct, make_divisible
from .layers import SelectAdaptivePool2d, Linear, hard_sigmoid, make_divisible
from .efficientnet_blocks import SqueezeExcite, ConvBnAct
from .helpers import build_model_with_cfg
from .registry import register_model

@ -110,7 +110,6 @@ class GhostBottleneck(nn.Module):
                nn.BatchNorm2d(out_chs),
            )

    def forward(self, x):
        shortcut = x

@ -1,10 +1,14 @@
from functools import partial

import torch.nn as nn
from .efficientnet_builder import decode_arch_def, resolve_bn_args
from .mobilenetv3 import MobileNetV3, MobileNetV3Features, build_model_with_cfg, default_cfg_for_features
from .layers import hard_sigmoid
from .efficientnet_blocks import resolve_act_layer
from .registry import register_model

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .efficientnet_blocks import SqueezeExcite
from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args
from .helpers import build_model_with_cfg, default_cfg_for_features
from .layers import get_act_fn
from .mobilenetv3 import MobileNetV3, MobileNetV3Features
from .registry import register_model


def _cfg(url='', **kwargs):
@ -35,15 +39,15 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):

    """
    num_features = 1280

    se_layer = partial(
        SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, reduce_from_block=False, divisor=8)
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
        num_features=num_features,
        stem_size=32,
        channel_multiplier=1,
        norm_kwargs=resolve_bn_args(kwargs),
        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
        act_layer=resolve_act_layer(kwargs, 'hard_swish'),
        se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8),
        se_layer=se_layer,
        **kwargs,
    )

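This hunk shows the migration pattern that recurs throughout the commit: a loose `se_kwargs` dict is replaced by a pre-bound `se_layer` callable, so the builder can instantiate squeeze-excite modules without knowing their keyword conventions. A minimal sketch of the idea with a hypothetical stand-in SE module (not timm's `SqueezeExcite`), assuming the module takes the block's channel count as its first argument:

from functools import partial

import torch.nn as nn

class TinySqueezeExcite(nn.Module):
    # stand-in SE module for illustration only
    def __init__(self, chs, rd_ratio=0.25, gate_fn=nn.Sigmoid):
        super().__init__()
        rd_chs = max(1, int(chs * rd_ratio))
        self.fc1 = nn.Conv2d(chs, rd_chs, 1)
        self.act = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(rd_chs, chs, 1)
        self.gate = gate_fn()

    def forward(self, x):
        se = x.mean((2, 3), keepdim=True)  # global average pool
        se = self.fc2(self.act(self.fc1(se)))
        return x * self.gate(se)

# the caller binds the SE policy once; blocks then just call se_layer(chs)
se_layer = partial(TinySqueezeExcite, rd_ratio=0.25, gate_fn=nn.Hardsigmoid)
se = se_layer(64)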
@ -22,10 +22,10 @@ to_4tuple = _ntuple(4)
to_ntuple = _ntuple


def make_divisible(v, divisor=8, min_value=None):
def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
    if new_v < round_limit * v:
        new_v += divisor
    return new_v
    return new_v
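A few worked values showing what the new `round_limit` parameter controls (the threshold below which a rounded-down result is bumped back up by one divisor), computed directly from the function body above:

from timm.models.layers import make_divisible

print(make_divisible(30))              # int(30 + 4) // 8 * 8 = 32; 32 >= 0.9 * 30, kept
print(make_divisible(20, divisor=16))  # rounds to 16, but 16 < 0.9 * 20 = 18, bumped to 32
print(make_divisible(20, divisor=16, round_limit=0.5))  # 16 >= 0.5 * 20 = 10, stays 16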
@ -5,23 +5,25 @@ A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl.

Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244

Hacked together by / Copyright 2020 Ross Wightman
Hacked together by / Copyright 2021 Ross Wightman
"""
from functools import partial
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import List

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
from .efficientnet_blocks import SqueezeExcite
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\
    round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .features import FeatureInfo, FeatureHooks
from .helpers import build_model_with_cfg, default_cfg_for_features
from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid
from .registry import register_model

__all__ = ['MobileNetV3']
__all__ = ['MobileNetV3', 'MobileNetV3Features']

def _cfg(url='', **kwargs):
@ -47,9 +49,11 @@ default_cfgs = {
        url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mobilenetv3_large_100_in21k_miil.pth', num_classes=11221),
    'mobilenetv3_small_075': _cfg(url=''),
    'mobilenetv3_small_100': _cfg(url=''),

    'mobilenetv3_rw': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
        interpolation='bicubic'),

    'tf_mobilenetv3_large_075': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
@ -70,8 +74,6 @@ default_cfgs = {
        mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
}

_DEBUG = False

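`_cfg` itself (its body is not shown in this hunk) follows timm's usual pattern of merging per-model overrides into a dict of pretrained-config defaults. A hedged sketch of that pattern; the exact default values here are illustrative, not copied from the commit:

def _cfg_sketch(url='', **kwargs):
    # illustrative stand-in for the _cfg helper, not its exact body
    return dict(
        url=url, num_classes=1000, input_size=(3, 224, 224), pool_size=(7, 7),
        crop_pct=0.875, interpolation='bilinear',
        mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
        **kwargs)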
class MobileNetV3(nn.Module):
    """ MobileNet-V3
@ -84,24 +86,26 @@ class MobileNetV3(nn.Module):
    """

    def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
                 channel_multiplier=1.0, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
                 se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
                 pad_type='', act_layer=None, norm_layer=None, se_layer=None,
                 round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'):
        super(MobileNetV3, self).__init__()

        act_layer = act_layer or nn.ReLU
        norm_layer = norm_layer or nn.BatchNorm2d
        se_layer = se_layer or SqueezeExcite
        self.num_classes = num_classes
        self.num_features = num_features
        self.drop_rate = drop_rate

        # Stem
        stem_size = round_channels(stem_size, channel_multiplier)
        stem_size = round_chs_fn(stem_size)
        self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
        self.bn1 = norm_layer(stem_size, **norm_kwargs)
        self.bn1 = norm_layer(stem_size)
        self.act1 = act_layer(inplace=True)

        # Middle stages (IR/ER/DS Blocks)
        builder = EfficientNetBuilder(
            channel_multiplier, 8, None, 32, pad_type, act_layer, se_kwargs,
            norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
            output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn,
            act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate)
        self.blocks = nn.Sequential(*builder(stem_size, block_args))
        self.feature_info = builder.features
        head_chs = builder.in_chs
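A sketch of constructing the refactored model directly, using only kwargs visible in this hunk; the arch_def is a truncated illustration, not a real MobileNetV3 definition, and the activation choice is an assumption:

from functools import partial

import torch
import torch.nn as nn

from timm.models.efficientnet_builder import decode_arch_def, round_channels
from timm.models.mobilenetv3 import MobileNetV3

arch_def = [['ds_r1_k3_s1_e1_c16'], ['ir_r1_k3_s2_e4_c24']]  # truncated, illustrative
model = MobileNetV3(
    block_args=decode_arch_def(arch_def),
    num_classes=1000,
    round_chs_fn=partial(round_channels, multiplier=1.0),
    act_layer=nn.Hardswish,  # se_layer/norm_layer fall back to SqueezeExcite/BatchNorm2d
)
logits = model(torch.randn(1, 3, 224, 224))  # -> shape (1, 1000)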
@ -158,23 +162,25 @@ class MobileNetV3Features(nn.Module):
    """

    def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
                 in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='',
                 act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None):
                 in_chans=3, stem_size=16, output_stride=32, pad_type='', round_chs_fn=round_channels,
                 act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
        super(MobileNetV3Features, self).__init__()
        norm_kwargs = norm_kwargs or {}
        act_layer = act_layer or nn.ReLU
        norm_layer = norm_layer or nn.BatchNorm2d
        se_layer = se_layer or SqueezeExcite
        self.drop_rate = drop_rate

        # Stem
        stem_size = round_channels(stem_size, channel_multiplier)
        stem_size = round_chs_fn(stem_size)
        self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
        self.bn1 = norm_layer(stem_size, **norm_kwargs)
        self.bn1 = norm_layer(stem_size)
        self.act1 = act_layer(inplace=True)

        # Middle stages (IR/ER/DS Blocks)
        builder = EfficientNetBuilder(
            channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
            norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
            output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn,
            act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer,
            drop_path_rate=drop_path_rate, feature_location=feature_location)
        self.blocks = nn.Sequential(*builder(stem_size, block_args))
        self.feature_info = FeatureInfo(builder.features, out_indices)
        self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
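The feature-extraction variant is normally reached through `features_only` rather than constructed directly; a usage sketch (model name taken from the default_cfgs above):

import timm
import torch

# features_only=True returns the MobileNetV3Features wrapper; the forward
# pass yields one tensor per entry in out_indices (5 by default)
model = timm.create_model('mobilenetv3_large_100', features_only=True, pretrained=False)
feats = model(torch.randn(1, 3, 224, 224))
print([f.shape[1] for f in feats])  # channel count at each feature level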
@ -253,10 +259,10 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
        head_bias=False,
        channel_multiplier=channel_multiplier,
        norm_kwargs=resolve_bn_args(kwargs),
        round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
        act_layer=resolve_act_layer(kwargs, 'hard_swish'),
        se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=1),
        se_layer=partial(SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), reduce_from_block=False),
        **kwargs,
    )
    model = _create_mnv3(variant, pretrained, **model_kwargs)
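The `norm_kwargs` to `norm_layer` change follows the same binding idea as `se_layer`: BN overrides pulled out of `**kwargs` are baked into a partial once, so downstream code just calls `norm_layer(chs)`. A sketch with a hypothetical stand-in for `resolve_bn_args`, assuming it extracts keys like `bn_eps`/`bn_momentum` as in the pre-refactor helper:

from functools import partial

import torch.nn as nn

def resolve_bn_args_sketch(kwargs):
    # illustrative stand-in: pop BN overrides out of the model kwargs
    bn_args = {}
    if 'bn_eps' in kwargs:
        bn_args['eps'] = kwargs.pop('bn_eps')
    if 'bn_momentum' in kwargs:
        bn_args['momentum'] = kwargs.pop('bn_momentum')
    return bn_args

kwargs = dict(bn_eps=1e-3, num_classes=1000)
norm_layer = partial(nn.BatchNorm2d, **resolve_bn_args_sketch(kwargs))
bn = norm_layer(32)  # BatchNorm2d(32, eps=0.001)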
@ -344,15 +350,16 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
        # stage 6, 7x7 in
        ['cn_r1_k1_s1_c960'],  # hard-swish
    ]

    se_layer = partial(
        SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, reduce_from_block=False, divisor=8)
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
        num_features=num_features,
        stem_size=16,
        channel_multiplier=channel_multiplier,
        norm_kwargs=resolve_bn_args(kwargs),
        round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
        act_layer=act_layer,
        se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8),
        se_layer=se_layer,
        **kwargs,
    )
    model = _create_mnv3(variant, pretrained, **model_kwargs)