Update CSP ResNets for cross expansion without activation. Fix VovNet IABN compatibility with fixed activation arg.
parent
3b6cce4c95
commit
e2cc481310
|
@ -32,7 +32,7 @@ def _cfg(url='', **kwargs):
|
|||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'conv1', 'classifier': 'fc',
|
||||
'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
@ -59,6 +59,7 @@ model_cfgs = dict(
|
|||
exp_ratio=(2.,) * 4,
|
||||
bottle_ratio=(0.5,) * 4,
|
||||
block_ratio=(1.,) * 4,
|
||||
cross_linear=True,
|
||||
)
|
||||
),
|
||||
cspresnet50d=dict(
|
||||
|
@ -70,6 +71,7 @@ model_cfgs = dict(
|
|||
exp_ratio=(2.,) * 4,
|
||||
bottle_ratio=(0.5,) * 4,
|
||||
block_ratio=(1.,) * 4,
|
||||
cross_linear=True,
|
||||
)
|
||||
),
|
||||
cspresnet50w=dict(
|
||||
|
@ -81,6 +83,7 @@ model_cfgs = dict(
|
|||
exp_ratio=(1.,) * 4,
|
||||
bottle_ratio=(0.25,) * 4,
|
||||
block_ratio=(0.5,) * 4,
|
||||
cross_linear=True,
|
||||
)
|
||||
),
|
||||
cspresnext50=dict(
|
||||
|
@ -93,6 +96,7 @@ model_cfgs = dict(
|
|||
exp_ratio=(1.,) * 4,
|
||||
bottle_ratio=(1.,) * 4,
|
||||
block_ratio=(0.5,) * 4,
|
||||
cross_linear=True,
|
||||
)
|
||||
),
|
||||
cspdarknet53=dict(
|
||||
|
@ -217,7 +221,7 @@ class DarkBlock(nn.Module):
|
|||
class CrossStage(nn.Module):
|
||||
"""Cross Stage."""
|
||||
def __init__(self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., exp_ratio=1.,
|
||||
groups=1, first_dilation=None, down_growth=False, block_dpr=None,
|
||||
groups=1, first_dilation=None, down_growth=False, cross_linear=False, block_dpr=None,
|
||||
block_fn=ResBottleneck, **block_kwargs):
|
||||
super(CrossStage, self).__init__()
|
||||
first_dilation = first_dilation or dilation
|
||||
|
@ -238,7 +242,7 @@ class CrossStage(nn.Module):
|
|||
# FIXME this 1x1 expansion is pushed down into the cross and block paths in the darknet cfgs. Also,
|
||||
# there is also special case for the first stage for some of the model that results in uneven split
|
||||
# across the two paths. I did it this way for simplicity for now.
|
||||
self.conv_exp = ConvBnAct(prev_chs, exp_chs, kernel_size=1, **conv_kwargs)
|
||||
self.conv_exp = ConvBnAct(prev_chs, exp_chs, kernel_size=1, apply_act=not cross_linear, **conv_kwargs)
|
||||
prev_chs = exp_chs // 2 # output of conv_exp is always split in two
|
||||
|
||||
self.blocks = nn.Sequential()
|
||||
|
@ -317,6 +321,8 @@ def _cfg_to_stage_args(cfg, curr_stride=2, output_stride=32, drop_path_rate=0.):
|
|||
cfg['groups'] = (1,) * num_stages
|
||||
if 'down_growth' in cfg and not isinstance(cfg['down_growth'], (list, tuple)):
|
||||
cfg['down_growth'] = (cfg['down_growth'],) * num_stages
|
||||
if 'cross_linear' in cfg and not isinstance(cfg['cross_linear'], (list, tuple)):
|
||||
cfg['cross_linear'] = (cfg['cross_linear'],) * num_stages
|
||||
cfg['block_dpr'] = [None] * num_stages if not drop_path_rate else \
|
||||
[x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg['depth'])).split(cfg['depth'])]
|
||||
stage_strides = []
|
||||
|
|
|
@ -180,32 +180,33 @@ class SequentialAppendList(nn.Sequential):
|
|||
class OsaBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_chs, mid_chs, out_chs, layer_per_block, residual=False,
|
||||
depthwise=False, attn='', norm_layer=BatchNormAct2d):
|
||||
depthwise=False, attn='', norm_layer=BatchNormAct2d, act_layer=nn.ReLU):
|
||||
super(OsaBlock, self).__init__()
|
||||
|
||||
self.residual = residual
|
||||
self.depthwise = depthwise
|
||||
conv_kwargs = dict(norm_layer=norm_layer, act_layer=act_layer)
|
||||
|
||||
next_in_chs = in_chs
|
||||
if self.depthwise and next_in_chs != mid_chs:
|
||||
assert not residual
|
||||
self.conv_reduction = ConvBnAct(next_in_chs, mid_chs, 1, norm_layer=norm_layer)
|
||||
self.conv_reduction = ConvBnAct(next_in_chs, mid_chs, 1, **conv_kwargs)
|
||||
else:
|
||||
self.conv_reduction = None
|
||||
|
||||
mid_convs = []
|
||||
for i in range(layer_per_block):
|
||||
if self.depthwise:
|
||||
conv = SeparableConvBnAct(mid_chs, mid_chs, norm_layer=norm_layer)
|
||||
conv = SeparableConvBnAct(mid_chs, mid_chs, **conv_kwargs)
|
||||
else:
|
||||
conv = ConvBnAct(next_in_chs, mid_chs, 3, norm_layer=norm_layer)
|
||||
conv = ConvBnAct(next_in_chs, mid_chs, 3, **conv_kwargs)
|
||||
next_in_chs = mid_chs
|
||||
mid_convs.append(conv)
|
||||
self.conv_mid = SequentialAppendList(*mid_convs)
|
||||
|
||||
# feature aggregation
|
||||
next_in_chs = in_chs + layer_per_block * mid_chs
|
||||
self.conv_concat = ConvBnAct(next_in_chs, out_chs, norm_layer=norm_layer)
|
||||
self.conv_concat = ConvBnAct(next_in_chs, out_chs, **conv_kwargs)
|
||||
|
||||
if attn:
|
||||
self.attn = create_attn(attn, out_chs)
|
||||
|
@ -227,8 +228,8 @@ class OsaBlock(nn.Module):
|
|||
|
||||
class OsaStage(nn.Module):
|
||||
|
||||
def __init__(self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block,
|
||||
downsample=True, residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d):
|
||||
def __init__(self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block, downsample=True,
|
||||
residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d, act_layer=nn.ReLU):
|
||||
super(OsaStage, self).__init__()
|
||||
|
||||
if downsample:
|
||||
|
@ -241,7 +242,7 @@ class OsaStage(nn.Module):
|
|||
last_block = i == block_per_stage - 1
|
||||
blocks += [OsaBlock(
|
||||
in_chs if i == 0 else out_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0,
|
||||
depthwise=depthwise, attn=attn if last_block else '', norm_layer=norm_layer)
|
||||
depthwise=depthwise, attn=attn if last_block else '', norm_layer=norm_layer, act_layer=act_layer)
|
||||
]
|
||||
self.blocks = nn.Sequential(*blocks)
|
||||
|
||||
|
@ -275,7 +276,7 @@ class ClassifierHead(nn.Module):
|
|||
class VovNet(nn.Module):
|
||||
|
||||
def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4,
|
||||
output_stride=32, norm_layer=BatchNormAct2d):
|
||||
output_stride=32, norm_layer=BatchNormAct2d, act_layer=nn.ReLU):
|
||||
""" VovNet (v2)
|
||||
"""
|
||||
super(VovNet, self).__init__()
|
||||
|
@ -289,14 +290,15 @@ class VovNet(nn.Module):
|
|||
stage_out_chs = cfg["stage_out_chs"]
|
||||
block_per_stage = cfg["block_per_stage"]
|
||||
layer_per_block = cfg["layer_per_block"]
|
||||
conv_kwargs = dict(norm_layer=norm_layer, act_layer=act_layer)
|
||||
|
||||
# Stem module
|
||||
last_stem_stride = stem_stride // 2
|
||||
conv_type = SeparableConvBnAct if cfg["depthwise"] else ConvBnAct
|
||||
self.stem = nn.Sequential(*[
|
||||
ConvBnAct(in_chans, stem_chs[0], 3, stride=2, norm_layer=norm_layer),
|
||||
conv_type(stem_chs[0], stem_chs[1], 3, stride=1, norm_layer=norm_layer),
|
||||
conv_type(stem_chs[1], stem_chs[2], 3, stride=last_stem_stride, norm_layer=norm_layer),
|
||||
ConvBnAct(in_chans, stem_chs[0], 3, stride=2, **conv_kwargs),
|
||||
conv_type(stem_chs[0], stem_chs[1], 3, stride=1, **conv_kwargs),
|
||||
conv_type(stem_chs[1], stem_chs[2], 3, stride=last_stem_stride, **conv_kwargs),
|
||||
])
|
||||
self.feature_info = [dict(
|
||||
num_chs=stem_chs[1], reduction=2, module=f'stem.{1 if stem_stride == 4 else 2}')]
|
||||
|
@ -304,8 +306,7 @@ class VovNet(nn.Module):
|
|||
|
||||
# OSA stages
|
||||
in_ch_list = stem_chs[-1:] + stage_out_chs[:-1]
|
||||
stage_args = dict(
|
||||
residual=cfg["residual"], depthwise=cfg["depthwise"], attn=cfg["attn"], norm_layer=norm_layer)
|
||||
stage_args = dict(residual=cfg["residual"], depthwise=cfg["depthwise"], attn=cfg["attn"], **conv_kwargs)
|
||||
stages = []
|
||||
for i in range(4): # num_stages
|
||||
downsample = stem_stride == 2 or i > 0 # first stage has no stride/downsample if stem_stride is 4
|
||||
|
@ -420,4 +421,5 @@ def ese_vovnet39b_evos(pretrained=False, **kwargs):
|
|||
@register_model
|
||||
def ese_vovnet99b_iabn(pretrained=False, **kwargs):
|
||||
norm_layer = get_norm_act_layer('iabn')
|
||||
return _vovnet('ese_vovnet99b_iabn', pretrained=pretrained, norm_layer=norm_layer, **kwargs)
|
||||
return _vovnet(
|
||||
'ese_vovnet99b_iabn', pretrained=pretrained, norm_layer=norm_layer, act_layer=nn.LeakyReLU, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue