Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Unify drop connect vs drop path under 'drop path' name, switch all EfficientNet/MobilenetV3 refs to 'drop_path'. Update factory to handle new drop args.
parent f1d5f8a6c4
commit 43225d110c
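Background: the quantity being renamed is per-sample stochastic depth — randomly dropping a block's entire residual branch during training. The EfficientNet reference code called this 'drop connect', but DropConnect (Wan et al.) is a different technique that drops individual weights, while 'drop path' (FractalNet) matches what the code actually does. A minimal sketch of the unified `drop_path` helper the blocks below now import from `.layers`, assuming it keeps the math of the removed `drop_connect` and the argument order used at the new call sites:

import torch

def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Stochastic depth: zero a sample's entire residual branch."""
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # one Bernoulli draw per sample, broadcast over (C, H, W)
    random_tensor = keep_prob + torch.rand(
        (x.shape[0], 1, 1, 1), dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    return x.div(keep_prob) * random_tensor  # rescale survivors by 1/keep_prob

The division by keep_prob keeps the expected activation unchanged, so nothing needs rescaling at inference time.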
timm/models/efficientnet.py

@@ -253,7 +253,7 @@ class EfficientNet(nn.Module):

     def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32,
                  channel_multiplier=1.0, channel_divisor=8, channel_min=None,
-                 output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
+                 output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
                  se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
         super(EfficientNet, self).__init__()
         norm_kwargs = norm_kwargs or {}
@@ -273,7 +273,7 @@ class EfficientNet(nn.Module):
         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
             channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
-            norm_layer, norm_kwargs, drop_connect_rate, verbose=_DEBUG)
+            norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
         self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
         self.feature_info = builder.features
         self._in_chs = builder.in_chs
@@ -333,7 +333,7 @@ class EfficientNetFeatures(nn.Module):

     def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
                  in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
-                 output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
+                 output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
                  se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
         super(EfficientNetFeatures, self).__init__()
         norm_kwargs = norm_kwargs or {}
@@ -355,7 +355,7 @@ class EfficientNetFeatures(nn.Module):
         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
             channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
-            norm_layer, norm_kwargs, drop_connect_rate, feature_location=feature_location, verbose=_DEBUG)
+            norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
         self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
         self.feature_info = builder.features  # builder provides info about feature channels for each block
         self._in_chs = builder.in_chs
@@ -875,7 +875,7 @@ def spnasnet_100(pretrained=False, **kwargs):
 @register_model
 def efficientnet_b0(pretrained=False, **kwargs):
     """ EfficientNet-B0 """
-    # NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2
+    # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
     model = _gen_efficientnet(
         'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
     return model
@@ -884,7 +884,7 @@ def efficientnet_b0(pretrained=False, **kwargs):
 @register_model
 def efficientnet_b1(pretrained=False, **kwargs):
     """ EfficientNet-B1 """
-    # NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2
+    # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
     model = _gen_efficientnet(
         'efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
     return model
@@ -893,7 +893,7 @@ def efficientnet_b1(pretrained=False, **kwargs):
 @register_model
 def efficientnet_b2(pretrained=False, **kwargs):
     """ EfficientNet-B2 """
-    # NOTE for train, drop_rate should be 0.3, drop_connect_rate should be 0.2
+    # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
     model = _gen_efficientnet(
         'efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
     return model
@@ -902,7 +902,7 @@ def efficientnet_b2(pretrained=False, **kwargs):
 @register_model
 def efficientnet_b2a(pretrained=False, **kwargs):
     """ EfficientNet-B2 @ 288x288 w/ 1.0 test crop"""
-    # NOTE for train, drop_rate should be 0.3, drop_connect_rate should be 0.2
+    # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
     model = _gen_efficientnet(
         'efficientnet_b2a', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
     return model
@@ -911,7 +911,7 @@ def efficientnet_b2a(pretrained=False, **kwargs):
 @register_model
 def efficientnet_b3(pretrained=False, **kwargs):
     """ EfficientNet-B3 """
-    # NOTE for train, drop_rate should be 0.3, drop_connect_rate should be 0.2
+    # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
     model = _gen_efficientnet(
         'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
     return model
@@ -920,7 +920,7 @@ def efficientnet_b3(pretrained=False, **kwargs):
 @register_model
 def efficientnet_b3a(pretrained=False, **kwargs):
     """ EfficientNet-B3 @ 320x320 w/ 1.0 test crop-pct """
-    # NOTE for train, drop_rate should be 0.3, drop_connect_rate should be 0.2
+    # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
     model = _gen_efficientnet(
         'efficientnet_b3a', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
     return model
@@ -929,7 +929,7 @@ def efficientnet_b3a(pretrained=False, **kwargs):
 @register_model
 def efficientnet_b4(pretrained=False, **kwargs):
     """ EfficientNet-B4 """
-    # NOTE for train, drop_rate should be 0.4, drop_connect_rate should be 0.2
+    # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2
     model = _gen_efficientnet(
         'efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
     return model
@@ -938,7 +938,7 @@ def efficientnet_b4(pretrained=False, **kwargs):
 @register_model
 def efficientnet_b5(pretrained=False, **kwargs):
     """ EfficientNet-B5 """
-    # NOTE for train, drop_rate should be 0.4, drop_connect_rate should be 0.2
+    # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2
     model = _gen_efficientnet(
         'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
     return model
@@ -947,7 +947,7 @@ def efficientnet_b5(pretrained=False, **kwargs):
 @register_model
 def efficientnet_b6(pretrained=False, **kwargs):
     """ EfficientNet-B6 """
-    # NOTE for train, drop_rate should be 0.5, drop_connect_rate should be 0.2
+    # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2
     model = _gen_efficientnet(
         'efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
     return model
@@ -956,7 +956,7 @@ def efficientnet_b6(pretrained=False, **kwargs):
 @register_model
 def efficientnet_b7(pretrained=False, **kwargs):
     """ EfficientNet-B7 """
-    # NOTE for train, drop_rate should be 0.5, drop_connect_rate should be 0.2
+    # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2
     model = _gen_efficientnet(
         'efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
     return model
@@ -965,7 +965,7 @@ def efficientnet_b7(pretrained=False, **kwargs):
 @register_model
 def efficientnet_b8(pretrained=False, **kwargs):
     """ EfficientNet-B8 """
-    # NOTE for train, drop_rate should be 0.5, drop_connect_rate should be 0.2
+    # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2
     model = _gen_efficientnet(
         'efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs)
     return model
@@ -974,7 +974,7 @@ def efficientnet_b8(pretrained=False, **kwargs):
 @register_model
 def efficientnet_l2(pretrained=False, **kwargs):
     """ EfficientNet-L2."""
-    # NOTE for train, drop_rate should be 0.5, drop_connect_rate should be 0.2
+    # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2
     model = _gen_efficientnet(
         'efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
     return model
@@ -1007,7 +1007,7 @@ def efficientnet_el(pretrained=False, **kwargs):
 @register_model
 def efficientnet_cc_b0_4e(pretrained=False, **kwargs):
     """ EfficientNet-CondConv-B0 w/ 8 Experts """
-    # NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2
+    # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
     model = _gen_efficientnet_condconv(
         'efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
     return model
@@ -1016,7 +1016,7 @@ def efficientnet_cc_b0_4e(pretrained=False, **kwargs):
 @register_model
 def efficientnet_cc_b0_8e(pretrained=False, **kwargs):
     """ EfficientNet-CondConv-B0 w/ 8 Experts """
-    # NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2
+    # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
     model = _gen_efficientnet_condconv(
         'efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2,
         pretrained=pretrained, **kwargs)
@@ -1025,7 +1025,7 @@ def efficientnet_cc_b0_8e(pretrained=False, **kwargs):
 @register_model
 def efficientnet_cc_b1_8e(pretrained=False, **kwargs):
     """ EfficientNet-CondConv-B1 w/ 8 Experts """
-    # NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2
+    # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
     model = _gen_efficientnet_condconv(
         'efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2,
         pretrained=pretrained, **kwargs)
@@ -1355,7 +1355,7 @@ def tf_efficientnet_el(pretrained=False, **kwargs):
 @register_model
 def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs):
     """ EfficientNet-CondConv-B0 w/ 4 Experts. Tensorflow compatible variant """
-    # NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2
+    # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
     kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
     kwargs['pad_type'] = 'same'
     model = _gen_efficientnet_condconv(
@@ -1366,7 +1366,7 @@ def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs):
 @register_model
 def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs):
     """ EfficientNet-CondConv-B0 w/ 8 Experts. Tensorflow compatible variant """
-    # NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2
+    # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
     kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
     kwargs['pad_type'] = 'same'
     model = _gen_efficientnet_condconv(
@@ -1377,7 +1377,7 @@ def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs):
 @register_model
 def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs):
     """ EfficientNet-CondConv-B1 w/ 8 Experts. Tensorflow compatible variant """
-    # NOTE for train, drop_rate should be 0.2, drop_connect_rate should be 0.2
+    # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
     kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
     kwargs['pad_type'] = 'same'
     model = _gen_efficientnet_condconv(
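The NOTE comments above double as the intended training settings for each variant. A usage sketch, assuming the standard timm entrypoint and that extra kwargs pass through to the model constructor as in the factory diff further below:

import timm

# EfficientNet-B2 with the rates recommended in its NOTE above
model = timm.create_model('efficientnet_b2', drop_rate=0.3, drop_path_rate=0.2)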
timm/models/efficientnet_blocks.py

@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 from torch.nn import functional as F
 from .layers.activations import sigmoid
-from .layers import create_conv2d
+from .layers import create_conv2d, drop_path


 # Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
@@ -69,19 +69,6 @@ def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
     return make_divisible(channels, divisor, channel_min)


-def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.):
-    """Apply drop connect."""
-    if not training:
-        return inputs
-
-    keep_prob = 1 - drop_connect_rate
-    random_tensor = keep_prob + torch.rand(
-        (inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device)
-    random_tensor.floor_()  # binarize
-    output = inputs.div(keep_prob) * random_tensor
-    return output
-
-
 class ChannelShuffle(nn.Module):
     # FIXME haven't used yet
     def __init__(self, groups):
@@ -154,13 +141,13 @@ class DepthwiseSeparableConv(nn.Module):
     def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                  stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                  pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None,
-                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
+                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0.):
         super(DepthwiseSeparableConv, self).__init__()
         norm_kwargs = norm_kwargs or {}
         has_se = se_ratio is not None and se_ratio > 0.
         self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
         self.has_pw_act = pw_act  # activation after point-wise conv
-        self.drop_connect_rate = drop_connect_rate
+        self.drop_path_rate = drop_path_rate

         self.conv_dw = create_conv2d(
             in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
@@ -200,8 +187,8 @@ class DepthwiseSeparableConv(nn.Module):
         x = self.act2(x)

         if self.has_residual:
-            if self.drop_connect_rate > 0.:
-                x = drop_connect(x, self.training, self.drop_connect_rate)
+            if self.drop_path_rate > 0.:
+                x = drop_path(x, self.drop_path_rate, self.training)
             x += residual
         return x

@@ -213,14 +200,14 @@ class InvertedResidual(nn.Module):
                  stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                  exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
                  se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
-                 conv_kwargs=None, drop_connect_rate=0.):
+                 conv_kwargs=None, drop_path_rate=0.):
         super(InvertedResidual, self).__init__()
         norm_kwargs = norm_kwargs or {}
         conv_kwargs = conv_kwargs or {}
         mid_chs = make_divisible(in_chs * exp_ratio)
         has_se = se_ratio is not None and se_ratio > 0.
         self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
-        self.drop_connect_rate = drop_connect_rate
+        self.drop_path_rate = drop_path_rate

         # Point-wise expansion
         self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
@@ -278,8 +265,8 @@ class InvertedResidual(nn.Module):
         x = self.bn3(x)

         if self.has_residual:
-            if self.drop_connect_rate > 0.:
-                x = drop_connect(x, self.training, self.drop_connect_rate)
+            if self.drop_path_rate > 0.:
+                x = drop_path(x, self.drop_path_rate, self.training)
             x += residual

         return x
@@ -292,7 +279,7 @@ class CondConvResidual(InvertedResidual):
                  stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                  exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
                  se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
-                 num_experts=0, drop_connect_rate=0.):
+                 num_experts=0, drop_path_rate=0.):

         self.num_experts = num_experts
         conv_kwargs = dict(num_experts=self.num_experts)
@@ -302,7 +289,7 @@ class CondConvResidual(InvertedResidual):
             act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
             pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs,
             norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs,
-            drop_connect_rate=drop_connect_rate)
+            drop_path_rate=drop_path_rate)

         self.routing_fn = nn.Linear(in_chs, self.num_experts)

@@ -332,8 +319,8 @@ class CondConvResidual(InvertedResidual):
         x = self.bn3(x)

         if self.has_residual:
-            if self.drop_connect_rate > 0.:
-                x = drop_connect(x, self.training, self.drop_connect_rate)
+            if self.drop_path_rate > 0.:
+                x = drop_path(x, self.drop_path_rate, self.training)
             x += residual
         return x

@@ -344,7 +331,7 @@ class EdgeResidual(nn.Module):
     def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
                  stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
                  se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
-                 drop_connect_rate=0.):
+                 drop_path_rate=0.):
         super(EdgeResidual, self).__init__()
         norm_kwargs = norm_kwargs or {}
         if fake_in_chs > 0:
@@ -353,7 +340,7 @@ class EdgeResidual(nn.Module):
         mid_chs = make_divisible(in_chs * exp_ratio)
         has_se = se_ratio is not None and se_ratio > 0.
         self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
-        self.drop_connect_rate = drop_connect_rate
+        self.drop_path_rate = drop_path_rate

         # Expansion convolution
         self.conv_exp = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
@@ -400,8 +387,8 @@ class EdgeResidual(nn.Module):
         x = self.bn2(x)

         if self.has_residual:
-            if self.drop_connect_rate > 0.:
-                x = drop_connect(x, self.training, self.drop_connect_rate)
+            if self.drop_path_rate > 0.:
+                x = drop_path(x, self.drop_path_rate, self.training)
             x += residual

         return x
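Every residual block above now ends with the same pattern: regularize the branch output with drop_path, then add the skip. A quick behavioral check, assuming drop_path is importable from timm.models.layers per the new import at the top of this file:

import torch
from timm.models.layers import drop_path

x = torch.ones(8, 16, 4, 4)
assert torch.equal(drop_path(x, 0.2, training=False), x)  # eval: identity

y = drop_path(x, 0.2, training=True)
# each sample is either all zeros (dropped) or scaled by 1/0.8 (kept)
print(y.flatten(1).sum(1))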
timm/models/efficientnet_builder.py

@@ -202,7 +202,7 @@ class EfficientNetBuilder:
     """
     def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                  output_stride=32, pad_type='', act_layer=None, se_kwargs=None,
-                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0., feature_location='',
+                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0., feature_location='',
                  verbose=False):
         self.channel_multiplier = channel_multiplier
         self.channel_divisor = channel_divisor
@@ -213,7 +213,7 @@ class EfficientNetBuilder:
         self.se_kwargs = se_kwargs
         self.norm_layer = norm_layer
         self.norm_kwargs = norm_kwargs
-        self.drop_connect_rate = drop_connect_rate
+        self.drop_path_rate = drop_path_rate
         self.feature_location = feature_location
         assert feature_location in ('pre_pwl', 'post_exp', '')
         self.verbose = verbose
@@ -226,7 +226,7 @@ class EfficientNetBuilder:
         return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)

     def _make_block(self, ba, block_idx, block_count):
-        drop_connect_rate = self.drop_connect_rate * block_idx / block_count
+        drop_path_rate = self.drop_path_rate * block_idx / block_count
         bt = ba.pop('block_type')
         ba['in_chs'] = self.in_chs
         ba['out_chs'] = self._round_channels(ba['out_chs'])
@@ -240,7 +240,7 @@ class EfficientNetBuilder:
         ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
         assert ba['act_layer'] is not None
         if bt == 'ir':
-            ba['drop_connect_rate'] = drop_connect_rate
+            ba['drop_path_rate'] = drop_path_rate
             ba['se_kwargs'] = self.se_kwargs
             if self.verbose:
                 logging.info('  InvertedResidual {}, Args: {}'.format(block_idx, str(ba)))
@@ -249,13 +249,13 @@ class EfficientNetBuilder:
             else:
                 block = InvertedResidual(**ba)
         elif bt == 'ds' or bt == 'dsa':
-            ba['drop_connect_rate'] = drop_connect_rate
+            ba['drop_path_rate'] = drop_path_rate
             ba['se_kwargs'] = self.se_kwargs
             if self.verbose:
                 logging.info('  DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)))
             block = DepthwiseSeparableConv(**ba)
         elif bt == 'er':
-            ba['drop_connect_rate'] = drop_connect_rate
+            ba['drop_path_rate'] = drop_path_rate
             ba['se_kwargs'] = self.se_kwargs
             if self.verbose:
                 logging.info('  EdgeResidual {}, Args: {}'.format(block_idx, str(ba)))
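_make_block scales the global rate linearly with block position, so the first block is never dropped and the rate approaches the configured maximum toward the end of the network. The schedule in isolation:

# linear stochastic-depth schedule, as computed in _make_block
drop_path_rate, block_count = 0.2, 16
rates = [drop_path_rate * idx / block_count for idx in range(block_count)]
print(rates[0], rates[-1])  # 0.0 for the first block, 0.1875 for the last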
timm/models/factory.py

@@ -31,7 +31,21 @@ def create_model(
     kwargs.pop('bn_tf', None)
     kwargs.pop('bn_momentum', None)
     kwargs.pop('bn_eps', None)
-    kwargs.pop('drop_connect_rate', None)
+
+    # Parameters that aren't supported by all models should default to None in command line args,
+    # remove them if they are present and not set so that non-supporting models don't break.
+    if kwargs.get('drop_block_rate', None) is None:
+        kwargs.pop('drop_block_rate', None)
+
+    # handle backwards compat with drop_connect -> drop_path change
+    drop_connect_rate = kwargs.pop('drop_connect_rate', None)
+    if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None:
+        print("WARNING: 'drop_connect' as an argument is deprecated, please use 'drop_path'."
+              " Setting drop_path to %f." % drop_connect_rate)
+        kwargs['drop_path_rate'] = drop_connect_rate
+
+    if kwargs.get('drop_path_rate', None) is None:
+        kwargs.pop('drop_path_rate', None)

     if is_model(model_name):
         create_fn = model_entrypoint(model_name)
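The shim above means the old kwarg keeps working at the factory level; both calls below configure the same model, the first with a deprecation warning:

import timm

model = timm.create_model('efficientnet_b0', drop_connect_rate=0.2)  # DEPRECATED spelling
model = timm.create_model('efficientnet_b0', drop_path_rate=0.2)     # preferred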
timm/models/mobilenetv3.py

@@ -71,7 +71,7 @@ class MobileNetV3(nn.Module):
     """

     def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
-                 channel_multiplier=1.0, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
+                 channel_multiplier=1.0, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
                  se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
         super(MobileNetV3, self).__init__()

@@ -90,7 +90,7 @@ class MobileNetV3(nn.Module):
         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
             channel_multiplier, 8, None, 32, pad_type, act_layer, se_kwargs,
-            norm_layer, norm_kwargs, drop_connect_rate, verbose=_DEBUG)
+            norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
         self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
         self.feature_info = builder.features
         self._in_chs = builder.in_chs
@@ -148,7 +148,7 @@ class MobileNetV3Features(nn.Module):

     def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
                  in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='',
-                 act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., se_kwargs=None,
+                 act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None,
                  norm_layer=nn.BatchNorm2d, norm_kwargs=None):
         super(MobileNetV3Features, self).__init__()
         norm_kwargs = norm_kwargs or {}
@@ -170,7 +170,7 @@ class MobileNetV3Features(nn.Module):
         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
             channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
-            norm_layer, norm_kwargs, drop_connect_rate, feature_location=feature_location, verbose=_DEBUG)
+            norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
         self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
         self.feature_info = builder.features  # builder provides info about feature channels for each block
         self._in_chs = builder.in_chs
train.py

@@ -81,10 +81,14 @@ parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
                     help='input batch size for training (default: 32)')
 parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
                     help='ratio of validation batch size to training batch size (default: 1)')
-parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
+parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                     help='Dropout rate (default: 0.)')
-parser.add_argument('--drop-connect', type=float, default=0.0, metavar='DROP',
-                    help='Drop connect rate (default: 0.)')
+parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
+                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
+parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
+                    help='Drop path rate (default: None)')
+parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
+                    help='Drop block rate (default: None)')
 parser.add_argument('--jsd', action='store_true', default=False,
                     help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
 # Optimizer parameters
@@ -242,7 +246,9 @@ def main():
         pretrained=args.pretrained,
         num_classes=args.num_classes,
         drop_rate=args.drop,
-        drop_connect_rate=args.drop_connect,
+        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
+        drop_path_rate=args.drop_path,
+        drop_block_rate=args.drop_block,
         global_pool=args.gp,
         bn_tf=args.bn_tf,
         bn_momentum=args.bn_momentum,
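On the command line, --drop-path replaces --drop-connect, and the new flags default to None so models that don't support them are left untouched by the factory. An example invocation (dataset path and remaining flags elided): ./train.py /data/imagenet --model efficientnet_b0 --drop 0.2 --drop-path 0.2. A script still passing --drop-connect 0.2 resolves to the same drop_path_rate via the factory shim above, with a deprecation warning.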