Add cs3darknet_x, cs3sedarknet_l, and darknetaa53 weights from TPU sessions. Move SE btwn conv1 & conv2 in DarkBlock. Improve SE/attn handling in Csp/DarkNet. Fix leaky_relu bug on older csp models.
parent
4283c0c478
commit
05313940e2
|
@ -23,7 +23,7 @@ import torch.nn.functional as F
|
|||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg, named_apply, MATCH_PREV_GROUP
|
||||
from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, create_attn, create_act_layer, make_divisible
|
||||
from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible
|
||||
from .registry import register_model
|
||||
|
||||
|
||||
|
@ -57,9 +57,10 @@ default_cfgs = {
|
|||
'sedarknet21': _cfg(url=''),
|
||||
'darknet53': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknet53_256_c2ns-3aeff817.pth',
|
||||
interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0,
|
||||
),
|
||||
'darknetaa53': _cfg(url=''),
|
||||
interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'darknetaa53': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknetaa53_c2ns-5c28ec8a.pth',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
|
||||
'cs3darknet_s': _cfg(
|
||||
url='', interpolation='bicubic'),
|
||||
|
@ -71,7 +72,8 @@ default_cfgs = {
|
|||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_l_c2ns-16220c5d.pth',
|
||||
interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
|
||||
'cs3darknet_x': _cfg(
|
||||
url=''),
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_x_c2ns-4e4490aa.pth',
|
||||
interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
|
||||
'cs3darknet_focus_s': _cfg(
|
||||
url='', interpolation='bicubic'),
|
||||
|
@ -84,6 +86,10 @@ default_cfgs = {
|
|||
'cs3darknet_focus_x': _cfg(
|
||||
url='', interpolation='bicubic'),
|
||||
|
||||
'cs3sedarknet_l': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3sedarknet_l_c2ns-e8d1dc13.pth',
|
||||
interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
|
||||
|
||||
'cs3sedarknet_xdw': _cfg(
|
||||
url='', interpolation='bicubic'),
|
||||
}
|
||||
|
@ -119,6 +125,7 @@ class CspStagesCfg:
|
|||
bottle_ratio: Union[float, Tuple[float, ...]] = 1. # bottleneck-ratio of blocks in stage
|
||||
avg_down: Union[bool, Tuple[bool, ...]] = False
|
||||
attn_layer: Optional[Union[str, Tuple[str, ...]]] = None
|
||||
attn_kwargs: Optional[Union[Dict, Tuple[Dict]]] = None
|
||||
stage_type: Union[str, Tuple[str]] = 'csp' # stage type ('csp', 'cs2', 'dark')
|
||||
block_type: Union[str, Tuple[str]] = 'bottle' # blocks type for stages ('bottle', 'dark')
|
||||
|
||||
|
@ -136,6 +143,7 @@ class CspStagesCfg:
|
|||
self.bottle_ratio = _pad_arg(self.bottle_ratio, n)
|
||||
self.avg_down = _pad_arg(self.avg_down, n)
|
||||
self.attn_layer = _pad_arg(self.attn_layer, n)
|
||||
self.attn_kwargs = _pad_arg(self.attn_kwargs, n)
|
||||
self.stage_type = _pad_arg(self.stage_type, n)
|
||||
self.block_type = _pad_arg(self.block_type, n)
|
||||
|
||||
|
@ -149,12 +157,20 @@ class CspModelCfg:
|
|||
stem: CspStemCfg
|
||||
stages: CspStagesCfg
|
||||
zero_init_last: bool = True # zero init last weight (usually bn) in residual path
|
||||
act_layer: str = 'relu'
|
||||
act_layer: str = 'leaky_relu'
|
||||
norm_layer: str = 'batchnorm'
|
||||
aa_layer: Optional[str] = None # FIXME support string factory for this
|
||||
|
||||
|
||||
def _cs3darknet_cfg(width_multiplier=1.0, depth_multiplier=1.0, avg_down=False, act_layer='silu', focus=False):
|
||||
def _cs3darknet_cfg(
|
||||
width_multiplier=1.0,
|
||||
depth_multiplier=1.0,
|
||||
avg_down=False,
|
||||
act_layer='silu',
|
||||
focus=False,
|
||||
attn_layer=None,
|
||||
attn_kwargs=None,
|
||||
):
|
||||
if focus:
|
||||
stem_cfg = CspStemCfg(
|
||||
out_chs=make_divisible(64 * width_multiplier),
|
||||
|
@ -172,6 +188,8 @@ def _cs3darknet_cfg(width_multiplier=1.0, depth_multiplier=1.0, avg_down=False,
|
|||
bottle_ratio=1.,
|
||||
block_ratio=0.5,
|
||||
avg_down=avg_down,
|
||||
attn_layer=attn_layer,
|
||||
attn_kwargs=attn_kwargs,
|
||||
stage_type='cs3',
|
||||
block_type='dark',
|
||||
),
|
||||
|
@ -201,7 +219,7 @@ model_cfgs = dict(
|
|||
bottle_ratio=0.5,
|
||||
block_ratio=1.,
|
||||
cross_linear=True,
|
||||
)
|
||||
),
|
||||
),
|
||||
cspresnet50w=CspModelCfg(
|
||||
stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'),
|
||||
|
@ -213,7 +231,7 @@ model_cfgs = dict(
|
|||
bottle_ratio=0.25,
|
||||
block_ratio=0.5,
|
||||
cross_linear=True,
|
||||
)
|
||||
),
|
||||
),
|
||||
cspresnext50=CspModelCfg(
|
||||
stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'),
|
||||
|
@ -226,7 +244,7 @@ model_cfgs = dict(
|
|||
bottle_ratio=1.,
|
||||
block_ratio=0.5,
|
||||
cross_linear=True,
|
||||
)
|
||||
),
|
||||
),
|
||||
cspdarknet53=CspModelCfg(
|
||||
stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
|
||||
|
@ -240,7 +258,6 @@ model_cfgs = dict(
|
|||
down_growth=True,
|
||||
block_type='dark',
|
||||
),
|
||||
act_layer='leaky_relu',
|
||||
),
|
||||
darknet17=CspModelCfg(
|
||||
stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
|
||||
|
@ -253,7 +270,6 @@ model_cfgs = dict(
|
|||
stage_type='dark',
|
||||
block_type='dark',
|
||||
),
|
||||
act_layer='leaky_relu',
|
||||
),
|
||||
darknet21=CspModelCfg(
|
||||
stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
|
||||
|
@ -267,7 +283,6 @@ model_cfgs = dict(
|
|||
block_type='dark',
|
||||
|
||||
),
|
||||
act_layer='leaky_relu',
|
||||
),
|
||||
sedarknet21=CspModelCfg(
|
||||
stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
|
||||
|
@ -282,7 +297,6 @@ model_cfgs = dict(
|
|||
block_type='dark',
|
||||
|
||||
),
|
||||
act_layer='leaky_relu',
|
||||
),
|
||||
darknet53=CspModelCfg(
|
||||
stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
|
||||
|
@ -295,7 +309,6 @@ model_cfgs = dict(
|
|||
stage_type='dark',
|
||||
block_type='dark',
|
||||
),
|
||||
act_layer='leaky_relu',
|
||||
),
|
||||
darknetaa53=CspModelCfg(
|
||||
stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
|
||||
|
@ -309,7 +322,6 @@ model_cfgs = dict(
|
|||
stage_type='dark',
|
||||
block_type='dark',
|
||||
),
|
||||
act_layer='leaky_relu',
|
||||
),
|
||||
|
||||
cs3darknet_s=_cs3darknet_cfg(width_multiplier=0.5, depth_multiplier=0.5),
|
||||
|
@ -322,6 +334,8 @@ model_cfgs = dict(
|
|||
cs3darknet_focus_l=_cs3darknet_cfg(focus=True),
|
||||
cs3darknet_focus_x=_cs3darknet_cfg(width_multiplier=1.25, depth_multiplier=1.33, focus=True),
|
||||
|
||||
cs3sedarknet_l=_cs3darknet_cfg(attn_layer='se', attn_kwargs=dict(rd_ratio=.25)),
|
||||
|
||||
cs3sedarknet_xdw=CspModelCfg(
|
||||
stem=CspStemCfg(out_chs=(32, 64), kernel_size=3, stride=2, pool=''),
|
||||
stages=CspStagesCfg(
|
||||
|
@ -333,6 +347,7 @@ model_cfgs = dict(
|
|||
block_ratio=0.5,
|
||||
attn_layer='se',
|
||||
),
|
||||
act_layer='silu',
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -359,14 +374,16 @@ class BottleneckBlock(nn.Module):
|
|||
super(BottleneckBlock, self).__init__()
|
||||
mid_chs = int(round(out_chs * bottle_ratio))
|
||||
ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
|
||||
attn_last = attn_layer is not None and attn_last
|
||||
attn_first = attn_layer is not None and not attn_last
|
||||
|
||||
self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **ckwargs)
|
||||
self.conv2 = ConvNormActAa(
|
||||
mid_chs, mid_chs, kernel_size=3, dilation=dilation, groups=groups,
|
||||
aa_layer=aa_layer, drop_layer=drop_block, **ckwargs)
|
||||
self.attn2 = create_attn(attn_layer, channels=mid_chs) if not attn_last else None
|
||||
self.attn2 = attn_layer(mid_chs, act_layer=act_layer) if attn_first else nn.Identity()
|
||||
self.conv3 = ConvNormAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs)
|
||||
self.attn3 = create_attn(attn_layer, channels=out_chs) if attn_last else None
|
||||
self.attn3 = attn_layer(out_chs, act_layer=act_layer) if attn_last else nn.Identity()
|
||||
self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
|
||||
self.act3 = create_act_layer(act_layer)
|
||||
|
||||
|
@ -377,11 +394,9 @@ class BottleneckBlock(nn.Module):
|
|||
shortcut = x
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
if self.attn2 is not None:
|
||||
x = self.attn2(x)
|
||||
x = self.attn2(x)
|
||||
x = self.conv3(x)
|
||||
if self.attn3 is not None:
|
||||
x = self.attn3(x)
|
||||
x = self.attn3(x)
|
||||
x = self.drop_path(x) + shortcut
|
||||
# FIXME partial shortcut needed if first block handled as per original, not used for my current impl
|
||||
#x[:, :shortcut.size(1)] += shortcut
|
||||
|
@ -410,11 +425,12 @@ class DarkBlock(nn.Module):
|
|||
super(DarkBlock, self).__init__()
|
||||
mid_chs = int(round(out_chs * bottle_ratio))
|
||||
ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
|
||||
|
||||
self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **ckwargs)
|
||||
self.attn = attn_layer(mid_chs, act_layer=act_layer) if attn_layer is not None else nn.Identity()
|
||||
self.conv2 = ConvNormActAa(
|
||||
mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups,
|
||||
aa_layer=aa_layer, drop_layer=drop_block, **ckwargs)
|
||||
self.attn = create_attn(attn_layer, channels=out_chs, act_layer=act_layer)
|
||||
self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
|
||||
|
||||
def zero_init_last(self):
|
||||
|
@ -423,9 +439,8 @@ class DarkBlock(nn.Module):
|
|||
def forward(self, x):
|
||||
shortcut = x
|
||||
x = self.conv1(x)
|
||||
x = self.attn(x)
|
||||
x = self.conv2(x)
|
||||
if self.attn is not None:
|
||||
x = self.attn(x)
|
||||
x = self.drop_path(x) + shortcut
|
||||
return x
|
||||
|
||||
|
@ -688,7 +703,8 @@ def create_csp_stem(
|
|||
return stem, feature_info
|
||||
|
||||
|
||||
def _get_stage_fn(stage_type: str, stage_args):
|
||||
def _get_stage_fn(stage_args):
|
||||
stage_type = stage_args.pop('stage_type')
|
||||
assert stage_type in ('dark', 'csp', 'cs3')
|
||||
if stage_type == 'dark':
|
||||
stage_args.pop('expand_ratio', None)
|
||||
|
@ -702,14 +718,25 @@ def _get_stage_fn(stage_type: str, stage_args):
|
|||
return stage_fn, stage_args
|
||||
|
||||
|
||||
def _get_block_fn(stage_type: str, stage_args):
|
||||
assert stage_type in ('dark', 'bottle')
|
||||
if stage_type == 'dark':
|
||||
def _get_block_fn(stage_args):
|
||||
block_type = stage_args.pop('block_type')
|
||||
assert block_type in ('dark', 'bottle')
|
||||
if block_type == 'dark':
|
||||
return DarkBlock, stage_args
|
||||
else:
|
||||
return BottleneckBlock, stage_args
|
||||
|
||||
|
||||
def _get_attn_fn(stage_args):
|
||||
attn_layer = stage_args.pop('attn_layer')
|
||||
attn_kwargs = stage_args.pop('attn_kwargs', None) or {}
|
||||
if attn_layer is not None:
|
||||
attn_layer = get_attn(attn_layer)
|
||||
if attn_kwargs:
|
||||
attn_layer = partial(attn_layer, **attn_kwargs)
|
||||
return attn_layer, stage_args
|
||||
|
||||
|
||||
def create_csp_stages(
|
||||
cfg: CspModelCfg,
|
||||
drop_path_rate: float,
|
||||
|
@ -734,8 +761,9 @@ def create_csp_stages(
|
|||
feature_info = []
|
||||
stages = []
|
||||
for stage_idx, stage_args in enumerate(stage_args):
|
||||
stage_fn, stage_args = _get_stage_fn(stage_args.pop('stage_type'), stage_args)
|
||||
block_fn, stage_args = _get_block_fn(stage_args.pop('block_type'), stage_args)
|
||||
stage_fn, stage_args = _get_stage_fn(stage_args)
|
||||
block_fn, stage_args = _get_block_fn(stage_args)
|
||||
attn_fn, stage_args = _get_attn_fn(stage_args)
|
||||
stride = stage_args.pop('stride')
|
||||
if stride != 1 and prev_feat:
|
||||
feature_info.append(prev_feat)
|
||||
|
@ -752,6 +780,7 @@ def create_csp_stages(
|
|||
first_dilation=first_dilation,
|
||||
dilation=dilation,
|
||||
block_fn=block_fn,
|
||||
attn_layer=attn_fn, # will be passed through stage as block_kwargs
|
||||
**block_kwargs,
|
||||
)]
|
||||
prev_chs = stage_args['out_chs']
|
||||
|
@ -968,6 +997,11 @@ def cs3darknet_focus_x(pretrained=False, **kwargs):
|
|||
return _create_cspnet('cs3darknet_focus_x', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def cs3sedarknet_l(pretrained=False, **kwargs):
|
||||
return _create_cspnet('cs3sedarknet_l', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def cs3sedarknet_xdw(pretrained=False, **kwargs):
|
||||
return _create_cspnet('cs3sedarknet_xdw', pretrained=pretrained, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue