mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add more experimental darknet and 'cs2' darknet variants (different cross stage setup, closer to newer YOLO backbones) for train trials.
This commit is contained in:
parent
a050fde5cd
commit
82c311d082
@ -16,6 +16,7 @@ from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg, named_apply, MATCH_PREV_GROUP
|
||||
@ -46,11 +47,21 @@ default_cfgs = {
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnext50_ra_224-648b4713.pth',
|
||||
input_size=(3, 224, 224), pool_size=(7, 7), crop_pct=0.875 # FIXME I trained this at 224x224, not 256 like ref impl
|
||||
),
|
||||
'cspresnext50_iabn': _cfg(url=''),
|
||||
'cspdarknet53': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspdarknet53_ra_256-d05c7c21.pth'),
|
||||
'cspdarknet53_iabn': _cfg(url=''),
|
||||
|
||||
'darknet17': _cfg(url=''),
|
||||
'darknet21': _cfg(url=''),
|
||||
'darknet53': _cfg(url=''),
|
||||
|
||||
'cs2darknet_m': _cfg(
|
||||
url=''),
|
||||
'cs2darknet_l': _cfg(
|
||||
url=''),
|
||||
'cs2darknet_f_m': _cfg(
|
||||
url=''),
|
||||
'cs2darknet_f_l': _cfg(
|
||||
url=''),
|
||||
}
|
||||
|
||||
|
||||
@ -116,6 +127,37 @@ model_cfgs = dict(
|
||||
down_growth=True,
|
||||
)
|
||||
),
|
||||
darknet17=dict(
|
||||
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''),
|
||||
stage=dict(
|
||||
out_chs=(64, 128, 256, 512, 1024),
|
||||
depth=(1,) * 5,
|
||||
stride=(2,) * 5,
|
||||
bottle_ratio=(0.5,) * 5,
|
||||
block_ratio=(1.,) * 5,
|
||||
)
|
||||
),
|
||||
darknet21=dict(
|
||||
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''),
|
||||
stage=dict(
|
||||
out_chs=(64, 128, 256, 512, 1024),
|
||||
depth=(1, 1, 1, 2, 2),
|
||||
stride=(2,) * 5,
|
||||
bottle_ratio=(0.5,) * 5,
|
||||
block_ratio=(1.,) * 5,
|
||||
)
|
||||
),
|
||||
sedarknet21=dict(
|
||||
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''),
|
||||
stage=dict(
|
||||
out_chs=(64, 128, 256, 512, 1024),
|
||||
depth=(1, 1, 1, 2, 2),
|
||||
stride=(2,) * 5,
|
||||
bottle_ratio=(0.5,) * 5,
|
||||
block_ratio=(1.,) * 5,
|
||||
attn_layer=('se',) * 5,
|
||||
)
|
||||
),
|
||||
darknet53=dict(
|
||||
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''),
|
||||
stage=dict(
|
||||
@ -125,13 +167,81 @@ model_cfgs = dict(
|
||||
bottle_ratio=(0.5,) * 5,
|
||||
block_ratio=(1.,) * 5,
|
||||
)
|
||||
),
|
||||
|
||||
darknetaa53=dict(
|
||||
stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''),
|
||||
stage=dict(
|
||||
out_chs=(64, 128, 256, 512, 1024),
|
||||
depth=(1, 2, 8, 8, 4),
|
||||
stride=(2,) * 5,
|
||||
bottle_ratio=(0.5,) * 5,
|
||||
block_ratio=(1.,) * 5,
|
||||
avg_down=True,
|
||||
),
|
||||
),
|
||||
|
||||
cs2darknet_m=dict(
|
||||
stem=dict(out_chs=(24, 48), kernel_size=3, stride=2, pool=''),
|
||||
stage=dict(
|
||||
out_chs=(96, 192, 384, 768),
|
||||
depth=(2, 4, 6, 2),
|
||||
stride=(2,) * 4,
|
||||
bottle_ratio=(1.,) * 4,
|
||||
block_ratio=(0.5,) * 4,
|
||||
avg_down=False,
|
||||
),
|
||||
),
|
||||
|
||||
cs2darknet_f_m=dict(
|
||||
stem=dict(out_chs=48, kernel_size=6, stride=2, padding=2, pool=''),
|
||||
stage=dict(
|
||||
out_chs=(96, 192, 384, 768),
|
||||
depth=(2, 4, 6, 2),
|
||||
stride=(2,) * 4,
|
||||
bottle_ratio=(1.,) * 4,
|
||||
block_ratio=(0.5,) * 4,
|
||||
avg_down=False,
|
||||
),
|
||||
),
|
||||
|
||||
cs2darknet_l=dict(
|
||||
stem=dict(out_chs=(32, 64), kernel_size=3, stride=2, pool=''),
|
||||
stage=dict(
|
||||
out_chs=(128, 256, 512, 1024),
|
||||
depth=(3, 6, 9, 3),
|
||||
stride=(2,) * 4,
|
||||
bottle_ratio=(1.,) * 4,
|
||||
block_ratio=(0.5,) * 4,
|
||||
avg_down=False,
|
||||
),
|
||||
),
|
||||
|
||||
cs2darknet_f_l=dict(
|
||||
stem=dict(out_chs=64, kernel_size=6, stride=2, padding=2, pool=''),
|
||||
stage=dict(
|
||||
out_chs=(128, 256, 512, 1024),
|
||||
depth=(3, 6, 9, 3),
|
||||
stride=(2,) * 4,
|
||||
bottle_ratio=(1.,) * 4,
|
||||
block_ratio=(0.5,) * 4,
|
||||
avg_down=False,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def create_stem(
|
||||
in_chans=3, out_chs=32, kernel_size=3, stride=2, pool='',
|
||||
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None):
|
||||
in_chans=3,
|
||||
out_chs=32,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
pool='',
|
||||
padding='',
|
||||
act_layer=nn.ReLU,
|
||||
norm_layer=nn.BatchNorm2d,
|
||||
aa_layer=None
|
||||
):
|
||||
stem = nn.Sequential()
|
||||
if not isinstance(out_chs, (tuple, list)):
|
||||
out_chs = [out_chs]
|
||||
@ -140,8 +250,12 @@ def create_stem(
|
||||
for i, out_c in enumerate(out_chs):
|
||||
conv_name = f'conv{i + 1}'
|
||||
stem.add_module(conv_name, ConvNormAct(
|
||||
in_c, out_c, kernel_size, stride=stride if i == 0 else 1,
|
||||
act_layer=act_layer, norm_layer=norm_layer))
|
||||
in_c, out_c, kernel_size,
|
||||
stride=stride if i == 0 else 1,
|
||||
padding=padding if i == 0 else '',
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer
|
||||
))
|
||||
in_c = out_c
|
||||
last_conv = conv_name
|
||||
if pool:
|
||||
@ -158,9 +272,20 @@ class ResBottleneck(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_chs, out_chs, dilation=1, bottle_ratio=0.25, groups=1,
|
||||
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_last=False,
|
||||
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
|
||||
self,
|
||||
in_chs,
|
||||
out_chs,
|
||||
dilation=1,
|
||||
bottle_ratio=0.25,
|
||||
groups=1,
|
||||
act_layer=nn.ReLU,
|
||||
norm_layer=nn.BatchNorm2d,
|
||||
attn_last=False,
|
||||
attn_layer=None,
|
||||
aa_layer=None,
|
||||
drop_block=None,
|
||||
drop_path=None
|
||||
):
|
||||
super(ResBottleneck, self).__init__()
|
||||
mid_chs = int(round(out_chs * bottle_ratio))
|
||||
ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
|
||||
@ -173,7 +298,7 @@ class ResBottleneck(nn.Module):
|
||||
self.conv3 = ConvNormAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs)
|
||||
self.attn3 = create_attn(attn_layer, channels=out_chs) if attn_last else None
|
||||
self.drop_path = drop_path
|
||||
self.act3 = act_layer(inplace=True)
|
||||
self.act3 = act_layer()
|
||||
|
||||
def zero_init_last(self):
|
||||
nn.init.zeros_(self.conv3.bn.weight)
|
||||
@ -201,9 +326,19 @@ class DarkBlock(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_chs, out_chs, dilation=1, bottle_ratio=0.5, groups=1,
|
||||
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None,
|
||||
drop_block=None, drop_path=None):
|
||||
self,
|
||||
in_chs,
|
||||
out_chs,
|
||||
dilation=1,
|
||||
bottle_ratio=0.5,
|
||||
groups=1,
|
||||
act_layer=nn.ReLU,
|
||||
norm_layer=nn.BatchNorm2d,
|
||||
attn_layer=None,
|
||||
aa_layer=None,
|
||||
drop_block=None,
|
||||
drop_path=None
|
||||
):
|
||||
super(DarkBlock, self).__init__()
|
||||
mid_chs = int(round(out_chs * bottle_ratio))
|
||||
ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
|
||||
@ -211,7 +346,7 @@ class DarkBlock(nn.Module):
|
||||
self.conv2 = ConvNormActAa(
|
||||
mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups,
|
||||
aa_layer=aa_layer, drop_layer=drop_block, **ckwargs)
|
||||
self.attn = create_attn(attn_layer, channels=out_chs)
|
||||
self.attn = create_attn(attn_layer, channels=out_chs, act_layer=act_layer)
|
||||
self.drop_path = drop_path
|
||||
|
||||
def zero_init_last(self):
|
||||
@ -232,23 +367,44 @@ class DarkBlock(nn.Module):
|
||||
class CrossStage(nn.Module):
|
||||
"""Cross Stage."""
|
||||
def __init__(
|
||||
self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., exp_ratio=1.,
|
||||
groups=1, first_dilation=None, down_growth=False, cross_linear=False, block_dpr=None,
|
||||
block_fn=ResBottleneck, **block_kwargs):
|
||||
self,
|
||||
in_chs,
|
||||
out_chs,
|
||||
stride,
|
||||
dilation,
|
||||
depth,
|
||||
block_ratio=1.,
|
||||
bottle_ratio=1.,
|
||||
exp_ratio=1.,
|
||||
groups=1,
|
||||
first_dilation=None,
|
||||
avg_down=False,
|
||||
down_growth=False,
|
||||
cross_linear=False,
|
||||
block_dpr=None,
|
||||
block_fn=ResBottleneck,
|
||||
**block_kwargs
|
||||
):
|
||||
super(CrossStage, self).__init__()
|
||||
first_dilation = first_dilation or dilation
|
||||
down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels
|
||||
exp_chs = int(round(out_chs * exp_ratio))
|
||||
self.exp_chs = exp_chs = int(round(out_chs * exp_ratio))
|
||||
block_out_chs = int(round(out_chs * block_ratio))
|
||||
conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
|
||||
|
||||
if stride != 1 or first_dilation != dilation:
|
||||
self.conv_down = ConvNormActAa(
|
||||
in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
|
||||
aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs)
|
||||
if avg_down:
|
||||
self.conv_down = nn.Sequential(
|
||||
nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling
|
||||
ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
|
||||
)
|
||||
else:
|
||||
self.conv_down = ConvNormActAa(
|
||||
in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
|
||||
aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs)
|
||||
prev_chs = down_chs
|
||||
else:
|
||||
self.conv_down = None
|
||||
self.conv_down = nn.Identity()
|
||||
prev_chs = in_chs
|
||||
|
||||
# FIXME this 1x1 expansion is pushed down into the cross and block paths in the darknet cfgs. Also,
|
||||
@ -269,30 +425,115 @@ class CrossStage(nn.Module):
|
||||
self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
if self.conv_down is not None:
|
||||
x = self.conv_down(x)
|
||||
x = self.conv_down(x)
|
||||
x = self.conv_exp(x)
|
||||
split = x.shape[1] // 2
|
||||
xs, xb = x[:, :split], x[:, split:]
|
||||
xs, xb = x.split(self.exp_chs // 2, dim=1)
|
||||
xb = self.blocks(xb)
|
||||
xb = self.conv_transition_b(xb).contiguous()
|
||||
out = self.conv_transition(torch.cat([xs, xb], dim=1))
|
||||
return out
|
||||
|
||||
|
||||
class CrossStage2(nn.Module):
|
||||
"""Cross Stage v2.
|
||||
Similar to CrossStage, but with one transition conv for the concat output.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_chs,
|
||||
out_chs,
|
||||
stride,
|
||||
dilation,
|
||||
depth,
|
||||
block_ratio=1.,
|
||||
bottle_ratio=1.,
|
||||
exp_ratio=1.,
|
||||
groups=1,
|
||||
first_dilation=None,
|
||||
avg_down=False,
|
||||
down_growth=False,
|
||||
cross_linear=False,
|
||||
block_dpr=None,
|
||||
block_fn=ResBottleneck,
|
||||
**block_kwargs
|
||||
):
|
||||
super(CrossStage2, self).__init__()
|
||||
first_dilation = first_dilation or dilation
|
||||
down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels
|
||||
self.exp_chs = exp_chs = int(round(out_chs * exp_ratio))
|
||||
block_out_chs = int(round(out_chs * block_ratio))
|
||||
conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
|
||||
|
||||
if stride != 1 or first_dilation != dilation:
|
||||
if avg_down:
|
||||
self.conv_down = nn.Sequential(
|
||||
nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling
|
||||
ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
|
||||
)
|
||||
else:
|
||||
self.conv_down = ConvNormActAa(
|
||||
in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
|
||||
aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs)
|
||||
prev_chs = down_chs
|
||||
else:
|
||||
self.conv_down = None
|
||||
prev_chs = in_chs
|
||||
|
||||
# expansion conv
|
||||
self.conv_exp = ConvNormAct(prev_chs, exp_chs, kernel_size=1, apply_act=not cross_linear, **conv_kwargs)
|
||||
prev_chs = exp_chs // 2 # expanded output is split in 2 for blocks and cross stage
|
||||
|
||||
self.blocks = nn.Sequential()
|
||||
for i in range(depth):
|
||||
drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None
|
||||
self.blocks.add_module(str(i), block_fn(
|
||||
prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs))
|
||||
prev_chs = block_out_chs
|
||||
|
||||
# transition convs
|
||||
self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_down(x)
|
||||
x = self.conv_exp(x)
|
||||
x1, x2 = x.split(self.exp_chs // 2, dim=1)
|
||||
x1 = self.blocks(x1)
|
||||
out = self.conv_transition(torch.cat([x1, x2], dim=1))
|
||||
return out
|
||||
|
||||
|
||||
class DarkStage(nn.Module):
|
||||
"""DarkNet stage."""
|
||||
|
||||
def __init__(
|
||||
self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., groups=1,
|
||||
first_dilation=None, block_fn=ResBottleneck, block_dpr=None, **block_kwargs):
|
||||
self,
|
||||
in_chs,
|
||||
out_chs,
|
||||
stride,
|
||||
dilation,
|
||||
depth,
|
||||
block_ratio=1.,
|
||||
bottle_ratio=1.,
|
||||
groups=1,
|
||||
first_dilation=None,
|
||||
avg_down=False,
|
||||
block_fn=ResBottleneck,
|
||||
block_dpr=None,
|
||||
**block_kwargs
|
||||
):
|
||||
super(DarkStage, self).__init__()
|
||||
first_dilation = first_dilation or dilation
|
||||
conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
|
||||
|
||||
self.conv_down = ConvNormActAa(
|
||||
in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
|
||||
act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'),
|
||||
aa_layer=block_kwargs.get('aa_layer', None))
|
||||
if avg_down:
|
||||
self.conv_down = nn.Sequential(
|
||||
nn.AvgPool2d(3, 2, 1) if stride == 2 else nn.Identity(), # FIXME dilation handling
|
||||
ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
|
||||
)
|
||||
else:
|
||||
self.conv_down = ConvNormActAa(
|
||||
in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
|
||||
aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs)
|
||||
|
||||
prev_chs = out_chs
|
||||
block_out_chs = int(round(out_chs * block_ratio))
|
||||
@ -318,6 +559,8 @@ def _cfg_to_stage_args(cfg, curr_stride=2, output_stride=32, drop_path_rate=0.):
|
||||
cfg['down_growth'] = (cfg['down_growth'],) * num_stages
|
||||
if 'cross_linear' in cfg and not isinstance(cfg['cross_linear'], (list, tuple)):
|
||||
cfg['cross_linear'] = (cfg['cross_linear'],) * num_stages
|
||||
if 'avg_down' in cfg and not isinstance(cfg['avg_down'], (list, tuple)):
|
||||
cfg['avg_down'] = (cfg['avg_down'],) * num_stages
|
||||
cfg['block_dpr'] = [None] * num_stages if not drop_path_rate else \
|
||||
[x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg['depth'])).split(cfg['depth'])]
|
||||
stage_strides = []
|
||||
@ -352,9 +595,20 @@ class CspNet(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0.,
|
||||
act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_path_rate=0.,
|
||||
zero_init_last=True, stage_fn=CrossStage, block_fn=ResBottleneck):
|
||||
self,
|
||||
cfg,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
output_stride=32,
|
||||
global_pool='avg',
|
||||
act_layer=nn.LeakyReLU,
|
||||
norm_layer=nn.BatchNorm2d,
|
||||
aa_layer=None,
|
||||
drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
zero_init_last=True,
|
||||
stage_fn=CrossStage,
|
||||
block_fn=ResBottleneck):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.drop_rate = drop_rate
|
||||
@ -427,23 +681,22 @@ class CspNet(nn.Module):
|
||||
def _init_weights(module, name, zero_init_last=False):
|
||||
if isinstance(module, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(module, nn.BatchNorm2d):
|
||||
nn.init.ones_(module.weight)
|
||||
nn.init.zeros_(module.bias)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Linear):
|
||||
nn.init.normal_(module.weight, mean=0.0, std=0.01)
|
||||
nn.init.zeros_(module.bias)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif zero_init_last and hasattr(module, 'zero_init_last'):
|
||||
module.zero_init_last()
|
||||
|
||||
|
||||
def _create_cspnet(variant, pretrained=False, **kwargs):
|
||||
cfg_variant = variant.split('_')[0]
|
||||
# NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5]
|
||||
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4, 5) if 'darknet' in variant else (0, 1, 2, 3, 4))
|
||||
return build_model_with_cfg(
|
||||
CspNet, variant, pretrained,
|
||||
model_cfg=model_cfgs[cfg_variant],
|
||||
model_cfg=model_cfgs[variant],
|
||||
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
|
||||
**kwargs)
|
||||
|
||||
@ -468,23 +721,56 @@ def cspresnext50(pretrained=False, **kwargs):
|
||||
return _create_cspnet('cspresnext50', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def cspresnext50_iabn(pretrained=False, **kwargs):
|
||||
norm_layer = get_norm_act_layer('iabn', act_layer='leaky_relu')
|
||||
return _create_cspnet('cspresnext50_iabn', pretrained=pretrained, norm_layer=norm_layer, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def cspdarknet53(pretrained=False, **kwargs):
|
||||
return _create_cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def cspdarknet53_iabn(pretrained=False, **kwargs):
|
||||
norm_layer = get_norm_act_layer('iabn', act_layer='leaky_relu')
|
||||
return _create_cspnet('cspdarknet53_iabn', pretrained=pretrained, block_fn=DarkBlock, norm_layer=norm_layer, **kwargs)
|
||||
def darknet17(pretrained=False, **kwargs):
|
||||
return _create_cspnet('darknet17', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def darknet21(pretrained=False, **kwargs):
|
||||
return _create_cspnet('darknet21', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def sedarknet21(pretrained=False, **kwargs):
|
||||
return _create_cspnet('sedarknet21', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def darknet53(pretrained=False, **kwargs):
|
||||
return _create_cspnet('darknet53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def darknetaa53(pretrained=False, **kwargs):
|
||||
return _create_cspnet(
|
||||
'darknetaa53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def cs2darknet_m(pretrained=False, **kwargs):
|
||||
return _create_cspnet(
|
||||
'cs2darknet_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage2, act_layer='silu', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def cs2darknet_l(pretrained=False, **kwargs):
|
||||
return _create_cspnet(
|
||||
'cs2darknet_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage2, act_layer='silu', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def cs2darknet_f_m(pretrained=False, **kwargs):
|
||||
return _create_cspnet(
|
||||
'cs2darknet_f_m', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage2, act_layer='silu', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def cs2darknet_f_l(pretrained=False, **kwargs):
|
||||
return _create_cspnet(
|
||||
'cs2darknet_f_l', pretrained=pretrained, block_fn=DarkBlock, stage_fn=CrossStage2, act_layer='silu', **kwargs)
|
@ -2,6 +2,7 @@
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import functools
|
||||
from torch import nn as nn
|
||||
|
||||
from .create_conv2d import create_conv2d
|
||||
@ -40,12 +41,26 @@ class ConvNormAct(nn.Module):
|
||||
ConvBnAct = ConvNormAct
|
||||
|
||||
|
||||
def create_aa(aa_layer, channels, stride=2, enable=True):
|
||||
if not aa_layer or not enable:
|
||||
return nn.Identity()
|
||||
if isinstance(aa_layer, functools.partial):
|
||||
if issubclass(aa_layer.func, nn.AvgPool2d):
|
||||
return aa_layer()
|
||||
else:
|
||||
return aa_layer(channels)
|
||||
elif issubclass(aa_layer, nn.AvgPool2d):
|
||||
return aa_layer(stride)
|
||||
else:
|
||||
return aa_layer(channels=channels, stride=stride)
|
||||
|
||||
|
||||
class ConvNormActAa(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
|
||||
bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None):
|
||||
super(ConvNormActAa, self).__init__()
|
||||
use_aa = aa_layer is not None
|
||||
use_aa = aa_layer is not None and stride == 2
|
||||
|
||||
self.conv = create_conv2d(
|
||||
in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
|
||||
@ -56,7 +71,7 @@ class ConvNormActAa(nn.Module):
|
||||
# NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
|
||||
norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
|
||||
self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
|
||||
self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else nn.Identity()
|
||||
self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)
|
||||
|
||||
@property
|
||||
def in_channels(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user