Lots of changes to model creation helpers, close to finalizing feature extraction / interfaces

This commit is contained in:
Ross Wightman 2020-07-17 17:54:26 -07:00
parent e2cc481310
commit 3b9004bef9
27 changed files with 454 additions and 632 deletions

View File

@@ -17,9 +17,8 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .features import FeatureNet
-from .helpers import load_pretrained
-from .layers import SelectAdaptivePool2d, ConvBnAct, DropPath, create_attn, get_norm_act_layer
+from .helpers import build_model_with_cfg
+from .layers import ClassifierHead, ConvBnAct, DropPath, create_attn, get_norm_act_layer
from .registry import register_model
@@ -294,26 +293,6 @@ class DarkStage(nn.Module):
return x
-class ClassifierHead(nn.Module):
-"""Head."""
-def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.):
-super(ClassifierHead, self).__init__()
-self.drop_rate = drop_rate
-self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
-if num_classes > 0:
-self.fc = nn.Linear(in_chs, num_classes, bias=True)
-else:
-self.fc = nn.Identity()
-def forward(self, x):
-x = self.global_pool(x).flatten(1)
-if self.drop_rate:
-x = F.dropout(x, p=float(self.drop_rate), training=self.training)
-x = self.fc(x)
-return x
def _cfg_to_stage_args(cfg, curr_stride=2, output_stride=32, drop_path_rate=0.):
# get per stage args for stage and containing blocks, calculate strides to meet target output_stride
num_stages = len(cfg['depth'])
@@ -420,62 +399,50 @@ class CspNet(nn.Module):
return x
-def _cspnet(variant, pretrained=False, **kwargs):
+def _create_cspnet(variant, pretrained=False, **kwargs):
-features = False
-out_indices = None
-if kwargs.pop('features_only', False):
-features = True
-out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
cfg_variant = variant.split('_')[0]
-cfg = model_cfgs[cfg_variant]
-model = CspNet(cfg, **kwargs)
-model.default_cfg = default_cfgs[variant]
-if pretrained:
-load_pretrained(
-model,
-num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=not features)
-if features:
-model = FeatureNet(model, out_indices, flatten_sequential=True)
-return model
+return build_model_with_cfg(
+CspNet, variant, pretrained, default_cfg=default_cfgs[variant],
+feature_cfg=dict(flatten_sequential=True), model_cfg=model_cfgs[cfg_variant], **kwargs)
@register_model
def cspresnet50(pretrained=False, **kwargs):
-return _cspnet('cspresnet50', pretrained=pretrained, **kwargs)
+return _create_cspnet('cspresnet50', pretrained=pretrained, **kwargs)
@register_model
def cspresnet50d(pretrained=False, **kwargs):
-return _cspnet('cspresnet50d', pretrained=pretrained, **kwargs)
+return _create_cspnet('cspresnet50d', pretrained=pretrained, **kwargs)
@register_model
def cspresnet50w(pretrained=False, **kwargs):
-return _cspnet('cspresnet50w', pretrained=pretrained, **kwargs)
+return _create_cspnet('cspresnet50w', pretrained=pretrained, **kwargs)
@register_model
def cspresnext50(pretrained=False, **kwargs):
-return _cspnet('cspresnext50', pretrained=pretrained, **kwargs)
+return _create_cspnet('cspresnext50', pretrained=pretrained, **kwargs)
@register_model
def cspresnext50_iabn(pretrained=False, **kwargs):
norm_layer = get_norm_act_layer('iabn')
-return _cspnet('cspresnext50', pretrained=pretrained, norm_layer=norm_layer, **kwargs)
+return _create_cspnet('cspresnext50', pretrained=pretrained, norm_layer=norm_layer, **kwargs)
@register_model
def cspdarknet53(pretrained=False, **kwargs):
-return _cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, **kwargs)
+return _create_cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, **kwargs)
@register_model
def cspdarknet53_iabn(pretrained=False, **kwargs):
norm_layer = get_norm_act_layer('iabn')
-return _cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, norm_layer=norm_layer, **kwargs)
+return _create_cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, norm_layer=norm_layer, **kwargs)
@register_model
def darknet53(pretrained=False, **kwargs):
-return _cspnet('darknet53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs)
+return _create_cspnet('darknet53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs)
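
With the per-file boilerplate gone, feature extraction is driven entirely by build_model_with_cfg. A minimal usage sketch, assuming the top-level timm.create_model entrypoint simply forwards features_only / out_indices into the helper shown further down:

import timm
model = timm.create_model('cspresnet50')                                          # standard classification model
fx = timm.create_model('cspresnet50', features_only=True, out_indices=(2, 3, 4))
# fx is a FeatureNet wrapper built with flatten_sequential=True; calling it returns a list of three feature maps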

View File

@@ -13,8 +13,7 @@ import torch.utils.checkpoint as cp
from torch.jit.annotations import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .features import FeatureNet
-from .helpers import load_pretrained
+from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act, BlurPool2d
from .registry import register_model
@@ -288,26 +287,12 @@ def _filter_torchvision_pretrained(state_dict):
return state_dict
-def _densenet(variant, growth_rate, block_config, pretrained, **kwargs):
+def _create_densenet(variant, growth_rate, block_config, pretrained, **kwargs):
-features = False
-out_indices = None
-if kwargs.pop('features_only', False):
-features = True
-kwargs.pop('num_classes', 0)
-out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
-default_cfg = default_cfgs[variant]
-model = DenseNet(growth_rate=growth_rate, block_config=block_config, **kwargs)
-model.default_cfg = default_cfg
-if pretrained:
-load_pretrained(
-model, default_cfg,
-num_classes=kwargs.get('num_classes', 0),
-in_chans=kwargs.get('in_chans', 3),
-filter_fn=_filter_torchvision_pretrained,
-strict=not features)
-if features:
-model = FeatureNet(model, out_indices, flatten_sequential=True)
-return model
+kwargs['growth_rate'] = growth_rate
+kwargs['block_config'] = block_config
+return build_model_with_cfg(
+DenseNet, variant, pretrained, default_cfg=default_cfgs[variant],
+feature_cfg=dict(flatten_sequential=True), pretrained_filter_fn=_filter_torchvision_pretrained, **kwargs)
@register_model
@@ -315,7 +300,7 @@ def densenet121(pretrained=False, **kwargs):
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
-model = _densenet(
+model = _create_densenet(
'densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, **kwargs)
return model
@@ -325,7 +310,7 @@ def densenetblur121d(pretrained=False, **kwargs):
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
-model = _densenet(
+model = _create_densenet(
'densenetblur121d', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, stem_type='deep',
aa_layer=BlurPool2d, **kwargs)
return model
@@ -336,7 +321,7 @@ def densenet121d(pretrained=False, **kwargs):
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
-model = _densenet(
+model = _create_densenet(
'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
pretrained=pretrained, **kwargs)
return model
@@ -347,7 +332,7 @@ def densenet169(pretrained=False, **kwargs):
r"""Densenet-169 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
-model = _densenet(
+model = _create_densenet(
'densenet169', growth_rate=32, block_config=(6, 12, 32, 32), pretrained=pretrained, **kwargs)
return model
@@ -357,7 +342,7 @@ def densenet201(pretrained=False, **kwargs):
r"""Densenet-201 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
-model = _densenet(
+model = _create_densenet(
'densenet201', growth_rate=32, block_config=(6, 12, 48, 32), pretrained=pretrained, **kwargs)
return model
@@ -367,7 +352,7 @@ def densenet161(pretrained=False, **kwargs):
r"""Densenet-161 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
-model = _densenet(
+model = _create_densenet(
'densenet161', growth_rate=48, block_config=(6, 12, 36, 24), pretrained=pretrained, **kwargs)
return model
@@ -377,7 +362,7 @@ def densenet264(pretrained=False, **kwargs):
r"""Densenet-264 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
-model = _densenet(
+model = _create_densenet(
'densenet264', growth_rate=48, block_config=(6, 12, 64, 48), pretrained=pretrained, **kwargs)
return model
@@ -388,7 +373,7 @@ def densenet264d_iabn(pretrained=False, **kwargs):
"""
def norm_act_fn(num_features, **kwargs):
return create_norm_act('iabn', num_features, **kwargs)
-model = _densenet(
+model = _create_densenet(
'densenet264d_iabn', growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep',
norm_layer=norm_act_fn, pretrained=pretrained, **kwargs)
return model
@@ -399,6 +384,6 @@ def tv_densenet121(pretrained=False, **kwargs):
r"""Densenet-121 model with original Torchvision weights, from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
-model = _densenet(
+model = _create_densenet(
'tv_densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, **kwargs)
return model
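
The pretrained_filter_fn argument is how checkpoint key remapping survives the move to the shared helper: build_model_with_cfg hands it to load_pretrained so the torchvision DenseNet weights can be renamed before loading. A hedged sketch of the general shape of such a filter; the function and prefix below are hypothetical, not the actual torchvision remapping (which rewrites 'norm.1' / 'conv.2' style keys):

def _strip_module_prefix(state_dict):
    # hypothetical filter_fn: drop a 'module.' prefix left behind by DataParallel checkpoints
    return {k[7:] if k.startswith('module.') else k: v for k, v in state_dict.items()}

# wired up the same way as _filter_torchvision_pretrained above:
# build_model_with_cfg(MyNet, variant, pretrained, default_cfg=..., pretrained_filter_fn=_strip_module_prefix, **kwargs)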

View File

@@ -17,8 +17,8 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import load_pretrained
-from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act, create_conv2d
+from .helpers import build_model_with_cfg
+from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_conv2d, ConvBnAct
from .registry import register_model
__all__ = ['DPN']
@@ -82,20 +82,6 @@ class BnActConv2d(nn.Module):
return self.conv(self.bn(x))
-class InputBlock(nn.Module):
-def __init__(self, num_init_features, kernel_size=7, in_chans=3, norm_layer=BatchNormAct2d):
-super(InputBlock, self).__init__()
-self.conv = create_conv2d(in_chans, num_init_features, kernel_size=kernel_size, stride=2)
-self.bn = norm_layer(num_init_features, eps=0.001)
-self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-def forward(self, x):
-x = self.conv(x)
-x = self.bn(x)
-x = self.pool(x)
-return x
class DualPathBlock(nn.Module):
def __init__(
self, in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, groups, block_type='normal', b=False):
@@ -183,21 +169,21 @@ class DualPathBlock(nn.Module):
class DPN(nn.Module):
def __init__(self, small=False, num_init_features=64, k_r=96, groups=32,
-b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128),
+b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), output_stride=32,
num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', fc_act=nn.ELU):
super(DPN, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
self.b = b
+assert output_stride == 32  # FIXME look into dilation support
bw_factor = 1 if small else 4
blocks = OrderedDict()
# conv1
-if small:
-blocks['conv1_1'] = InputBlock(num_init_features, in_chans=in_chans, kernel_size=3)
-else:
-blocks['conv1_1'] = InputBlock(num_init_features, in_chans=in_chans, kernel_size=7)
+blocks['conv1_1'] = ConvBnAct(
+in_chans, num_init_features, kernel_size=3 if small else 7, stride=2, norm_kwargs=dict(eps=.001))
+blocks['conv1_pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+self.feature_info = [dict(num_chs=num_init_features, reduction=2, module='features.conv1_1')]
# conv2
bw = 64 * bw_factor
@@ -208,6 +194,7 @@ class DPN(nn.Module):
for i in range(2, k_sec[0] + 1):
blocks['conv2_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b)
in_chs += inc
+self.feature_info += [dict(num_chs=in_chs, reduction=4, module=f'features.conv2_{k_sec[0]}')]
# conv3
bw = 128 * bw_factor
@@ -218,6 +205,7 @@ class DPN(nn.Module):
for i in range(2, k_sec[1] + 1):
blocks['conv3_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b)
in_chs += inc
+self.feature_info += [dict(num_chs=in_chs, reduction=8, module=f'features.conv3_{k_sec[1]}')]
# conv4
bw = 256 * bw_factor
@@ -228,6 +216,7 @@ class DPN(nn.Module):
for i in range(2, k_sec[2] + 1):
blocks['conv4_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b)
in_chs += inc
+self.feature_info += [dict(num_chs=in_chs, reduction=16, module=f'features.conv4_{k_sec[2]}')]
# conv5
bw = 512 * bw_factor
@@ -238,6 +227,7 @@ class DPN(nn.Module):
for i in range(2, k_sec[3] + 1):
blocks['conv5_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b)
in_chs += inc
+self.feature_info += [dict(num_chs=in_chs, reduction=32, module=f'features.conv5_{k_sec[3]}')]
def _fc_norm(f, eps): return BatchNormAct2d(f, eps=eps, act_layer=fc_act, inplace=False)
blocks['conv5_bn_ac'] = CatBnAct(in_chs, norm_layer=_fc_norm)
@@ -274,79 +264,55 @@ class DPN(nn.Module):
return out.flatten(1)
+def _create_dpn(variant, pretrained=False, **kwargs):
+return build_model_with_cfg(
+DPN, variant, pretrained, default_cfg=default_cfgs[variant],
+feature_cfg=dict(feature_concat=True, flatten_sequential=True), **kwargs)
@register_model
-def dpn68(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
-default_cfg = default_cfgs['dpn68']
-model = DPN(
-small=True, num_init_features=10, k_r=128, groups=32,
-k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64),
-num_classes=num_classes, in_chans=in_chans, **kwargs)
-model.default_cfg = default_cfg
-if pretrained:
-load_pretrained(model, default_cfg, num_classes, in_chans)
-return model
+def dpn68(pretrained=False, **kwargs):
+model_kwargs = dict(
+small=True, num_init_features=10, k_r=128, groups=32,
+k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs)
+return _create_dpn('dpn68', pretrained=pretrained, **model_kwargs)
@register_model
-def dpn68b(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
-default_cfg = default_cfgs['dpn68b']
-model = DPN(
-small=True, num_init_features=10, k_r=128, groups=32,
-b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64),
-num_classes=num_classes, in_chans=in_chans, **kwargs)
-model.default_cfg = default_cfg
-if pretrained:
-load_pretrained(model, default_cfg, num_classes, in_chans)
-return model
+def dpn68b(pretrained=False, **kwargs):
+model_kwargs = dict(
+small=True, num_init_features=10, k_r=128, groups=32,
+b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs)
+return _create_dpn('dpn68b', pretrained=pretrained, **model_kwargs)
@register_model
-def dpn92(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
-default_cfg = default_cfgs['dpn92']
-model = DPN(
-num_init_features=64, k_r=96, groups=32,
-k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128),
-num_classes=num_classes, in_chans=in_chans, **kwargs)
-model.default_cfg = default_cfg
-if pretrained:
-load_pretrained(model, default_cfg, num_classes, in_chans)
-return model
+def dpn92(pretrained=False, **kwargs):
+model_kwargs = dict(
+num_init_features=64, k_r=96, groups=32,
+k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), **kwargs)
+return _create_dpn('dpn92', pretrained=pretrained, **model_kwargs)
@register_model
-def dpn98(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
-default_cfg = default_cfgs['dpn98']
-model = DPN(
-num_init_features=96, k_r=160, groups=40,
-k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128),
-num_classes=num_classes, in_chans=in_chans, **kwargs)
-model.default_cfg = default_cfg
-if pretrained:
-load_pretrained(model, default_cfg, num_classes, in_chans)
-return model
+def dpn98(pretrained=False, **kwargs):
+model_kwargs = dict(
+num_init_features=96, k_r=160, groups=40,
+k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128), **kwargs)
+return _create_dpn('dpn98', pretrained=pretrained, **model_kwargs)
@register_model
-def dpn131(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
-default_cfg = default_cfgs['dpn131']
-model = DPN(
-num_init_features=128, k_r=160, groups=40,
-k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128),
-num_classes=num_classes, in_chans=in_chans, **kwargs)
-model.default_cfg = default_cfg
-if pretrained:
-load_pretrained(model, default_cfg, num_classes, in_chans)
-return model
+def dpn131(pretrained=False, **kwargs):
+model_kwargs = dict(
+num_init_features=128, k_r=160, groups=40,
+k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128), **kwargs)
+return _create_dpn('dpn131', pretrained=pretrained, **model_kwargs)
@register_model
-def dpn107(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
-default_cfg = default_cfgs['dpn107']
-model = DPN(
-num_init_features=128, k_r=200, groups=50,
-k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128),
-num_classes=num_classes, in_chans=in_chans, **kwargs)
-model.default_cfg = default_cfg
-if pretrained:
-load_pretrained(model, default_cfg, num_classes, in_chans)
-return model
+def dpn107(pretrained=False, **kwargs):
+model_kwargs = dict(
+num_init_features=128, k_r=200, groups=50,
+k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128), **kwargs)
+return _create_dpn('dpn107', pretrained=pretrained, **model_kwargs)
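
The self.feature_info entries registered in DPN.__init__ are the contract the feature wrappers consume: each dict records a channel count, the cumulative reduction (output stride) and the qualified module whose output is the feature. A hedged sketch of how that surfaces once DPN goes through build_model_with_cfg (assumes create_model forwards features_only as in the other converted models):

import torch, timm
fx = timm.create_model('dpn68', features_only=True)
print(fx.feature_info.reduction())    # [2, 4, 8, 16, 32], straight from the dicts above
print(fx.feature_info.module_name())  # ['features.conv1_1', 'features.conv2_3', ...]
feats = fx(torch.randn(1, 3, 224, 224))  # 5 maps; dual-path tuples concatenated per feature_concat=True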

View File

@@ -471,7 +471,7 @@ class EfficientNetFeatures(nn.Module):
return self.feature_hooks.get_output(x.device)
-def _create_model(model_kwargs, default_cfg, pretrained=False):
+def _create_effnet(model_kwargs, default_cfg, pretrained=False):
if model_kwargs.pop('features_only', False):
load_strict = False
model_kwargs.pop('num_classes', 0)
@@ -528,7 +528,7 @@ def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
norm_kwargs=resolve_bn_args(kwargs),
**kwargs
)
-model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
+model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
return model
@@ -564,7 +564,7 @@ def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
norm_kwargs=resolve_bn_args(kwargs),
**kwargs
)
-model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
+model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
return model
@@ -593,7 +593,7 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
norm_kwargs=resolve_bn_args(kwargs),
**kwargs
)
-model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
+model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
return model
@@ -622,7 +622,7 @@ def _gen_mobilenet_v2(
act_layer=resolve_act_layer(kwargs, 'relu6'),
**kwargs
)
-model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
+model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
return model
@@ -652,7 +652,7 @@ def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
norm_kwargs=resolve_bn_args(kwargs),
**kwargs
)
-model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
+model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
return model
@@ -687,7 +687,7 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
norm_kwargs=resolve_bn_args(kwargs),
**kwargs
)
-model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
+model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
return model
@@ -734,7 +734,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
variant=variant,
**kwargs,
)
-model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
+model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
return model
@@ -763,7 +763,7 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0
act_layer=resolve_act_layer(kwargs, 'relu'),
**kwargs,
)
-model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
+model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
return model
@@ -793,7 +793,7 @@ def _gen_efficientnet_condconv(
act_layer=resolve_act_layer(kwargs, 'swish'),
**kwargs,
)
-model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
+model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
return model
@@ -834,7 +834,7 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0
norm_kwargs=resolve_bn_args(kwargs),
**kwargs,
)
-model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
+model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
return model
@@ -867,7 +867,7 @@ def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
norm_kwargs=resolve_bn_args(kwargs),
**kwargs
)
-model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
+model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
return model
@@ -900,7 +900,7 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai
norm_kwargs=resolve_bn_args(kwargs),
**kwargs
)
-model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
+model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
return model

View File

@@ -7,36 +7,38 @@ Hacked together by Ross Wightman
import torch
from collections import defaultdict, OrderedDict
-from functools import partial
+from functools import partial, partialmethod
from typing import List
class FeatureHooks:
-def __init__(self, hooks, named_modules, output_as_dict=False):
+def __init__(self, hooks, named_modules, out_as_dict=False, out_map=None, default_hook_type='forward'):
# setup feature hooks
modules = {k: v for k, v in named_modules}
-for h in hooks:
+for i, h in enumerate(hooks):
hook_name = h['module']
m = modules[hook_name]
-hook_fn = partial(self._collect_output_hook, hook_name)
-if h['hook_type'] == 'forward_pre':
+hook_id = out_map[i] if out_map else hook_name
+hook_fn = partial(self._collect_output_hook, hook_id)
+hook_type = h['hook_type'] if 'hook_type' in h else default_hook_type
+if hook_type == 'forward_pre':
m.register_forward_pre_hook(hook_fn)
-elif h['hook_type'] == 'forward':
+elif hook_type == 'forward':
m.register_forward_hook(hook_fn)
else:
assert False, "Unsupported hook type"
self._feature_outputs = defaultdict(OrderedDict)
-self.output_as_dict = output_as_dict
+self.out_as_dict = out_as_dict
-def _collect_output_hook(self, name, *args):
+def _collect_output_hook(self, hook_id, *args):
x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
if isinstance(x, tuple):
x = x[0]  # unwrap input tuple
-self._feature_outputs[x.device][name] = x
+self._feature_outputs[x.device][hook_id] = x
def get_output(self, device) -> List[torch.tensor]:
-if self.output_as_dict:
+if self.out_as_dict:
output = self._feature_outputs[device]
else:
output = list(self._feature_outputs[device].values())
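
A minimal sketch of how a backbone wires itself up with the extended FeatureHooks interface (the tiny module below is illustrative; EfficientNetFeatures / MobileNetV3Features in this repo follow the same pattern):

import torch
import torch.nn as nn
from timm.models.feature_hooks import FeatureHooks  # module location as of this commit

class TinyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, 3, stride=2, padding=1)
        self.block1 = nn.Conv2d(8, 16, 3, stride=2, padding=1)
        self.block2 = nn.Conv2d(16, 32, 3, stride=2, padding=1)
        # one dict per feature; 'hook_type' is optional and falls back to default_hook_type
        hooks = [dict(module='block1'), dict(module='block2', hook_type='forward')]
        self.feature_hooks = FeatureHooks(hooks, self.named_modules())

    def forward(self, x):
        x = self.block2(self.block1(self.stem(x)))
        return self.feature_hooks.get_output(x.device)  # [block1 output, block2 output]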

View File

@@ -11,7 +11,8 @@ from copy import deepcopy
import torch
import torch.nn as nn
-import torch.nn.functional as F
+from .feature_hooks import FeatureHooks
class FeatureInfo:
@@ -24,11 +25,11 @@ class FeatureInfo:
assert 'reduction' in fi and fi['reduction'] >= prev_reduction
prev_reduction = fi['reduction']
assert 'module' in fi
-self._out_indices = out_indices
-self._info = feature_info
+self.out_indices = out_indices
+self.info = feature_info
def from_other(self, out_indices: Tuple[int]):
-return FeatureInfo(deepcopy(self._info), out_indices)
+return FeatureInfo(deepcopy(self.info), out_indices)
def channels(self, idx=None):
""" feature channels accessor
@@ -36,8 +37,8 @@ class FeatureInfo:
if idx is an integer, return feature channel count for that feature module index
"""
if isinstance(idx, int):
-return self._info[idx]['num_chs']
-return [self._info[i]['num_chs'] for i in self._out_indices]
+return self.info[idx]['num_chs']
+return [self.info[i]['num_chs'] for i in self.out_indices]
def reduction(self, idx=None):
""" feature reduction (output stride) accessor
@@ -45,8 +46,8 @@ class FeatureInfo:
if idx is an integer, return feature channel count at that feature module index
"""
if isinstance(idx, int):
-return self._info[idx]['reduction']
-return [self._info[i]['reduction'] for i in self._out_indices]
+return self.info[idx]['reduction']
+return [self.info[i]['reduction'] for i in self.out_indices]
def module_name(self, idx=None):
""" feature module name accessor
@@ -54,24 +55,24 @@ class FeatureInfo:
if idx is an integer, return feature module name at that feature module index
"""
if isinstance(idx, int):
-return self._info[idx]['module']
-return [self._info[i]['module'] for i in self._out_indices]
+return self.info[idx]['module']
+return [self.info[i]['module'] for i in self.out_indices]
def get_by_key(self, idx=None, keys=None):
""" return info dicts for specified keys (or all if None) at specified idx (or out_indices if None)
"""
if isinstance(idx, int):
-return self._info[idx] if keys is None else {k: self._info[idx][k] for k in keys}
+return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
if keys is None:
-return [self._info[i] for i in self._out_indices]
+return [self.info[i] for i in self.out_indices]
else:
-return [{k: self._info[i][k] for k in keys} for i in self._out_indices]
+return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
def __getitem__(self, item):
-return self._info[item]
+return self.info[item]
def __len__(self):
-return len(self._info)
+return len(self.info)
def _module_list(module, flatten_sequential=False):
@@ -81,30 +82,47 @@ def _module_list(module, flatten_sequential=False):
if flatten_sequential and isinstance(module, nn.Sequential):
# first level of Sequential containers is flattened into containing model
for child_name, child_module in module.named_children():
-ml.append(('_'.join([name, child_name]), child_module))
+combined = [name, child_name]
+ml.append(('_'.join(combined), '.'.join(combined), child_module))
else:
-ml.append((name, module))
+ml.append((name, name, module))
return ml
-def _check_return_layers(input_return_layers, modules):
-return_layers = {}
-for k, v in input_return_layers.items():
-ks = k.split('.')
-assert 0 < len(ks) <= 2
-return_layers['_'.join(ks)] = v
-return_set = set(return_layers.keys())
-sdiff = return_set - {name for name, _ in modules}
-if sdiff:
-raise ValueError(f'return_layers {sdiff} are not present in model')
-return return_layers, return_set
+class LayerGetterHooks(nn.ModuleDict):
+""" LayerGetterHooks
+TODO
+"""
+def __init__(self, model, feature_info, flatten_sequential=False, out_as_dict=False, out_map=None,
+default_hook_type='forward'):
+modules = _module_list(model, flatten_sequential=flatten_sequential)
+remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type for f in feature_info}
+layers = OrderedDict()
+hooks = []
+for new_name, old_name, module in modules:
+layers[new_name] = module
+for fn, fm in module.named_modules(prefix=old_name):
+if fn in remaining:
+hooks.append(dict(module=fn, hook_type=remaining[fn]))
+del remaining[fn]
+if not remaining:
+break
+assert not remaining, f'Return layers ({remaining}) are not present in model'
+super(LayerGetterHooks, self).__init__(layers)
+self.hooks = FeatureHooks(hooks, model.named_modules(), out_as_dict=out_as_dict, out_map=out_map)
+def forward(self, x) -> Dict[Any, torch.Tensor]:
+for name, module in self.items():
+x = module(x)
+return self.hooks.get_output(x.device)
class LayerGetterDict(nn.ModuleDict):
"""
Module wrapper that returns intermediate layers from a model as a dictionary
-Originally based on IntermediateLayerGetter at
+Originally based on concepts from IntermediateLayerGetter at
https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
It has a strong assumption that the modules have been registered into the model in the same
@@ -131,16 +149,20 @@ class LayerGetterDict(nn.ModuleDict):
"""
def __init__(self, model, return_layers, concat=False, flatten_sequential=False):
-modules = _module_list(model, flatten_sequential=flatten_sequential)
-self.return_layers, remaining = _check_return_layers(return_layers, modules)
-layers = OrderedDict()
+self.return_layers = {}
self.concat = concat
-for name, module in modules:
-layers[name] = module
-if name in remaining:
-remaining.remove(name)
+modules = _module_list(model, flatten_sequential=flatten_sequential)
+remaining = set(return_layers.keys())
+layers = OrderedDict()
+for new_name, old_name, module in modules:
+layers[new_name] = module
+if old_name in remaining:
+self.return_layers[new_name] = return_layers[old_name]
+remaining.remove(old_name)
if not remaining:
break
+assert not remaining and len(self.return_layers) == len(return_layers), \
+f'Return layers ({remaining}) are not present in model'
super(LayerGetterDict, self).__init__(layers)
def forward(self, x) -> Dict[Any, torch.Tensor]:
@@ -162,7 +184,7 @@ class LayerGetterList(nn.Sequential):
"""
Module wrapper that returns intermediate layers from a model as a list
-Originally based on IntermediateLayerGetter at
+Originally based on concepts from IntermediateLayerGetter at
https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
It has a strong assumption that the modules have been registered into the model in the same
@@ -190,15 +212,19 @@ class LayerGetterList(nn.Sequential):
def __init__(self, model, return_layers, concat=False, flatten_sequential=False):
super(LayerGetterList, self).__init__()
-modules = _module_list(model, flatten_sequential=flatten_sequential)
-self.return_layers, remaining = _check_return_layers(return_layers, modules)
+self.return_layers = {}
self.concat = concat
-for name, module in modules:
-self.add_module(name, module)
-if name in remaining:
-remaining.remove(name)
+modules = _module_list(model, flatten_sequential=flatten_sequential)
+remaining = set(return_layers.keys())
+for new_name, orig_name, module in modules:
+self.add_module(new_name, module)
+if orig_name in remaining:
+self.return_layers[new_name] = return_layers[orig_name]
+remaining.remove(orig_name)
if not remaining:
break
+assert not remaining and len(self.return_layers) == len(return_layers), \
+f'Return layers ({remaining}) are not present in model'
def forward(self, x) -> List[torch.Tensor]:
out = []
@@ -225,6 +251,14 @@ def _resolve_feature_info(net, out_indices, feature_info=None):
assert False, "Provided feature_info is not valid"
+def _get_return_layers(feature_info, out_map):
+module_names = feature_info.module_name()
+return_layers = {}
+for i, name in enumerate(module_names):
+return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
+return return_layers
class FeatureNet(nn.Module):
""" FeatureNet
@@ -235,17 +269,41 @@ class FeatureNet(nn.Module):
"""
def __init__(
self, net,
-out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False,
+out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, use_hooks=False,
feature_info=None, feature_concat=False, flatten_sequential=False):
super(FeatureNet, self).__init__()
self.feature_info = _resolve_feature_info(net, out_indices, feature_info)
-module_names = self.feature_info.module_name()
-return_layers = {}
-for i in range(len(out_indices)):
-return_layers[module_names[i]] = out_map[i] if out_map is not None else out_indices[i]
-lg_args = dict(return_layers=return_layers, concat=feature_concat, flatten_sequential=flatten_sequential)
-self.body = LayerGetterDict(net, **lg_args) if out_as_dict else LayerGetterList(net, **lg_args)
+if use_hooks:
+self.body = LayerGetterHooks(net, self.feature_info, out_as_dict=out_as_dict, out_map=out_map)
+else:
+return_layers = _get_return_layers(self.feature_info, out_map)
+lg_args = dict(return_layers=return_layers, concat=feature_concat, flatten_sequential=flatten_sequential)
+self.body = LayerGetterDict(net, **lg_args) if out_as_dict else LayerGetterList(net, **lg_args)
def forward(self, x):
output = self.body(x)
return output
+class FeatureHookNet(nn.Module):
+""" FeatureHookNet
+Wrap a model and extract features specified by the out indices.
+Features are extracted via hooks without modifying the underlying network in any way. If only
+part of the model is used it is up to the caller to remove unneeded layers as this wrapper
+does not rewrite and remove unused top-level modules like FeatureNet with LayerGetter.
+"""
+def __init__(
+self, net,
+out_indices=(0, 1, 2, 3, 4), out_as_dict=False, out_map=None,
+feature_info=None, feature_concat=False):
+super(FeatureHookNet, self).__init__()
+self.feature_info = _resolve_feature_info(net, out_indices, feature_info)
+self.body = net
+self.hooks = FeatureHooks(
+self.feature_info, self.body.named_modules(), out_as_dict=out_as_dict, out_map=out_map)
+def forward(self, x):
+self.body(x)
+return self.hooks.get_output(x.device)
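
A hedged sketch of the two wrapping styles from the caller's side, assuming a backbone that exposes a feature_info list the way the models converted in this commit do (module path and the resnet50 choice are illustrative):

import torch
import timm
from timm.models.features import FeatureNet, FeatureHookNet  # locations as of this commit

net = timm.create_model('resnet50')
# LayerGetter path: unused top-level modules are dropped, features come back as a list
fx = FeatureNet(net, out_indices=(1, 2, 3, 4), flatten_sequential=True)
feats = fx(torch.randn(1, 3, 224, 224))  # 4 maps at strides 4, 8, 16, 32
# hook path: the wrapped model is left intact and outputs are captured via forward hooks
fx_hooks = FeatureHookNet(timm.create_model('resnet50'), out_indices=(1, 2, 3, 4))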

View File

@@ -5,9 +5,10 @@ by Ross Wightman
"""
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg
from .layers import SEModule
from .registry import register_model
-from .resnet import _create_resnet_with_cfg, Bottleneck, BasicBlock
+from .resnet import ResNet, Bottleneck, BasicBlock
def _cfg(url='', **kwargs):
@@ -47,8 +48,7 @@ default_cfgs = {
def _create_resnet(variant, pretrained=False, **kwargs):
-default_cfg = default_cfgs[variant]
-return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs)
+return build_model_with_cfg(ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)
@register_model

View File

@@ -1,11 +1,15 @@
+import logging
+import os
+from collections import OrderedDict
+from copy import deepcopy
+from typing import Callable
import torch
import torch.nn as nn
-from copy import deepcopy
import torch.utils.model_zoo as model_zoo
-import os
-import logging
-from collections import OrderedDict
-from timm.models.layers.conv2d_same import Conv2dSame
+from .features import FeatureNet
+from .layers import Conv2dSame
def load_state_dict(checkpoint_path, use_ema=False):
@@ -194,3 +198,42 @@ def adapt_model_from_file(parent_module, model_variant):
adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt')
with open(adapt_file, 'r') as f:
return adapt_model_from_string(parent_module, f.read().strip())
def build_model_with_cfg(
model_cls: Callable,
variant: str,
pretrained: bool,
default_cfg: dict,
model_cfg: dict = None,
feature_cfg: dict = None,
pretrained_filter_fn: Callable = None,
**kwargs):
pruned = kwargs.pop('pruned', False)
features = False
feature_cfg = feature_cfg or {}
if kwargs.pop('features_only', False):
features = True
feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
if 'out_indices' in kwargs:
feature_cfg['out_indices'] = kwargs.pop('out_indices')
model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
model.default_cfg = deepcopy(default_cfg)
if pruned:
model = adapt_model_from_file(model, variant)
if pretrained:
load_pretrained(
model,
num_classes=kwargs.get('num_classes', 0),
in_chans=kwargs.get('in_chans', 3),
filter_fn=pretrained_filter_fn)
if features:
feature_cls = feature_cfg.pop('feature_cls', FeatureNet)
model = feature_cls(model, **feature_cfg)
return model
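
Every converted architecture file now reduces to the same two pieces: a _create_* function that calls this helper and thin @register_model entrypoints, mirroring the cspnet / densenet / dpn conversions above. A sketch with hypothetical names:

def _create_mynet(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
        MyNet, variant, pretrained, default_cfg=default_cfgs[variant],
        model_cfg=model_cfgs[variant],              # optional; passed as MyNet(cfg=...) when present
        feature_cfg=dict(flatten_sequential=True),  # forwarded to FeatureNet when features_only=True
        pretrained_filter_fn=None, **kwargs)

@register_model
def mynet_50(pretrained=False, **kwargs):
    return _create_mynet('mynet_50', pretrained=pretrained, **kwargs)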

View File

@@ -19,7 +19,7 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import load_pretrained
+from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d
from .registry import register_model
from .resnet import BasicBlock, Bottleneck  # leveraging ResNet blocks w/ additional features like SE
@@ -734,67 +734,52 @@ class HighResolutionNet(nn.Module):
return x
-def _create_model(variant, pretrained, model_kwargs):
-if model_kwargs.pop('features_only', False):
-assert False, 'Not Implemented'  # TODO
-load_strict = False
-model_kwargs.pop('num_classes', 0)
-model_class = HighResolutionNet
-else:
-load_strict = True
-model_class = HighResolutionNet
-model = model_class(cfg_cls[variant], **model_kwargs)
-model.default_cfg = default_cfgs[variant]
-if pretrained:
-load_pretrained(
-model,
-num_classes=model_kwargs.get('num_classes', 0),
-in_chans=model_kwargs.get('in_chans', 3),
-strict=load_strict)
-return model
+def _create_hrnet(variant, pretrained, **model_kwargs):
+return build_model_with_cfg(
+HighResolutionNet, variant, pretrained, default_cfg=default_cfgs[variant],
+model_cfg=cfg_cls[variant], **model_kwargs)
@register_model
def hrnet_w18_small(pretrained=True, **kwargs):
-return _create_model('hrnet_w18_small', pretrained, kwargs)
+return _create_hrnet('hrnet_w18_small', pretrained, **kwargs)
@register_model
def hrnet_w18_small_v2(pretrained=True, **kwargs):
-return _create_model('hrnet_w18_small_v2', pretrained, kwargs)
+return _create_hrnet('hrnet_w18_small_v2', pretrained, **kwargs)
@register_model
def hrnet_w18(pretrained=True, **kwargs):
-return _create_model('hrnet_w18', pretrained, kwargs)
+return _create_hrnet('hrnet_w18', pretrained, **kwargs)
@register_model
def hrnet_w30(pretrained=True, **kwargs):
-return _create_model('hrnet_w30', pretrained, kwargs)
+return _create_hrnet('hrnet_w30', pretrained, **kwargs)
@register_model
def hrnet_w32(pretrained=True, **kwargs):
-return _create_model('hrnet_w32', pretrained, kwargs)
+return _create_hrnet('hrnet_w32', pretrained, **kwargs)
@register_model
def hrnet_w40(pretrained=True, **kwargs):
-return _create_model('hrnet_w40', pretrained, kwargs)
+return _create_hrnet('hrnet_w40', pretrained, **kwargs)
@register_model
def hrnet_w44(pretrained=True, **kwargs):
-return _create_model('hrnet_w44', pretrained, kwargs)
+return _create_hrnet('hrnet_w44', pretrained, **kwargs)
@register_model
def hrnet_w48(pretrained=True, **kwargs):
-return _create_model('hrnet_w48', pretrained, kwargs)
+return _create_hrnet('hrnet_w48', pretrained, **kwargs)
@register_model
def hrnet_w64(pretrained=True, **kwargs):
-return _create_model('hrnet_w64', pretrained, kwargs)
+return _create_hrnet('hrnet_w64', pretrained, **kwargs)

View File

@@ -7,8 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from .features import FeatureNet
-from .helpers import load_pretrained
+from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d
from .registry import register_model
@@ -340,20 +339,9 @@ class InceptionResnetV2(nn.Module):
return x
-def _inception_resnet_v2(variant, pretrained=False, **kwargs):
-features, out_indices = False, None
-if kwargs.pop('features_only', False):
-features = True
-out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
-model = InceptionResnetV2(**kwargs)
-model.default_cfg = default_cfgs[variant]
-if pretrained:
-load_pretrained(
-model,
-num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=not features)
-if features:
-model = FeatureNet(model, out_indices)
-return model
+def _create_inception_resnet_v2(variant, pretrained=False, **kwargs):
+return build_model_with_cfg(
+InceptionResnetV2, variant, pretrained, default_cfg=default_cfgs[variant], **kwargs)
@register_model
@@ -361,7 +349,7 @@ def inception_resnet_v2(pretrained=False, **kwargs):
r"""InceptionResnetV2 model architecture from the
`"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>` paper.
"""
-return _inception_resnet_v2('inception_resnet_v2', pretrained=pretrained, **kwargs)
+return _create_inception_resnet_v2('inception_resnet_v2', pretrained=pretrained, **kwargs)
@register_model
@@ -370,4 +358,4 @@ def ens_adv_inception_resnet_v2(pretrained=False, **kwargs):
As per https://arxiv.org/abs/1705.07204 and
https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models.
"""
-return _inception_resnet_v2('ens_adv_inception_resnet_v2', pretrained=pretrained, **kwargs)
+return _create_inception_resnet_v2('ens_adv_inception_resnet_v2', pretrained=pretrained, **kwargs)

View File

@@ -504,14 +504,9 @@ class BasicConv2d(nn.Module):
return F.relu(x, inplace=True)
-def _inception_v3(variant, pretrained=False, **kwargs):
+def _create_inception_v3(variant, pretrained=False, **kwargs):
+assert not kwargs.pop('features_only', False)
default_cfg = default_cfgs[variant]
-if kwargs.pop('features_only', False):
-assert False, 'Not Implemented'  # TODO
-load_strict = False
-model_kwargs.pop('num_classes', 0)
-model_class = InceptionV3
-else:
aux_logits = kwargs.pop('aux_logits', False)
if aux_logits:
model_class = InceptionV3Aux
@@ -534,14 +529,14 @@ def _inception_v3(variant, pretrained=False, **kwargs):
@register_model
def inception_v3(pretrained=False, **kwargs):
# original PyTorch weights, ported from Tensorflow but modified
-model = _inception_v3('inception_v3', pretrained=pretrained, **kwargs)
+model = _create_inception_v3('inception_v3', pretrained=pretrained, **kwargs)
return model
@register_model
def tf_inception_v3(pretrained=False, **kwargs):
# my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
-model = _inception_v3('tf_inception_v3', pretrained=pretrained, **kwargs)
+model = _create_inception_v3('tf_inception_v3', pretrained=pretrained, **kwargs)
return model
@@ -549,7 +544,7 @@ def tf_inception_v3(pretrained=False, **kwargs):
def adv_inception_v3(pretrained=False, **kwargs):
# my port of Tensorflow adversarially trained Inception V3 from
# http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
-model = _inception_v3('adv_inception_v3', pretrained=pretrained, **kwargs)
+model = _create_inception_v3('adv_inception_v3', pretrained=pretrained, **kwargs)
return model
@@ -557,5 +552,5 @@ def adv_inception_v3(pretrained=False, **kwargs):
def gluon_inception_v3(pretrained=False, **kwargs):
# from gluon pretrained models, best performing in terms of accuracy/loss metrics
# https://gluon-cv.mxnet.io/model_zoo/classification.html
-model = _inception_v3('gluon_inception_v3', pretrained=pretrained, **kwargs)
+model = _create_inception_v3('gluon_inception_v3', pretrained=pretrained, **kwargs)
return model
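
The inception_v3 create function keeps one variant-specific branch: aux_logits selects the InceptionV3Aux class so the auxiliary classifier head is constructed. A hedged usage sketch, assuming the kwarg is forwarded untouched through create_model:

import timm
m = timm.create_model('inception_v3', aux_logits=True)  # InceptionV3Aux, auxiliary head constructed
m2 = timm.create_model('inception_v3')                   # plain InceptionV3, single logits output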

View File

@@ -3,6 +3,7 @@ from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .anti_aliasing import AntiAliasDownsampleLayer
from .blur_pool import BlurPool2d
+from .classifier import ClassifierHead
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
set_layer_config
@@ -24,6 +25,7 @@ from .se import SEModule
from .selective_kernel import SelectiveKernelConv
from .separable_conv import SeparableConv2d, SeparableConvBnAct
from .space_to_depth import SpaceToDepthModule
+from .split_attn import SplitAttnConv2d
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .weight_init import trunc_normal_

View File

@@ -82,7 +82,7 @@ class HardSwish(nn.Module):
self.inplace = inplace
def forward(self, x):
-return hard_swish(x, self.inplace)
+return F.hardswish(x)  #hard_swish(x, self.inplace)
def hard_sigmoid(x, inplace: bool = False):
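
The swap to F.hardswish leans on it computing the same function as the hand-rolled version, x * relu6(x + 3) / 6; note that F.hardswish only exists in PyTorch 1.6+. A quick equivalence check:

import torch
import torch.nn.functional as F

x = torch.randn(16)
assert torch.allclose(F.hardswish(x), x * F.relu6(x + 3.) / 6.)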

View File

@@ -0,0 +1,24 @@
from torch import nn as nn
from torch.nn import functional as F
from .adaptive_avgmax_pool import SelectAdaptivePool2d
class ClassifierHead(nn.Module):
"""Classifier Head w/ configurable global pooling and dropout."""
def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.):
super(ClassifierHead, self).__init__()
self.drop_rate = drop_rate
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
if num_classes > 0:
self.fc = nn.Linear(in_chs * self.global_pool.feat_mult(), num_classes, bias=True)
else:
self.fc = nn.Identity()
def forward(self, x):
x = self.global_pool(x).flatten(1)
if self.drop_rate:
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
x = self.fc(x)
return x
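
A minimal sketch of dropping the shared head onto a backbone's final feature map (channel counts illustrative; with the default 'avg' pooling feat_mult() is 1, so the fc layer expects in_chs inputs):

import torch
from timm.models.layers import ClassifierHead  # export added in this commit

head = ClassifierHead(in_chs=2048, num_classes=1000, pool_type='avg', drop_rate=0.2)
logits = head(torch.randn(2, 2048, 7, 7))  # -> shape (2, 1000)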

View File

@@ -39,7 +39,7 @@ _ACT_FN_ME = dict(
)
_ACT_LAYER_DEFAULT = dict(
-swish=Swish,
+swish=Swish,  #nn.SiLU, #
mish=Mish,
relu=nn.ReLU,
relu6=nn.ReLU6,
@@ -56,7 +56,7 @@ _ACT_LAYER_DEFAULT = dict(
)
_ACT_LAYER_JIT = dict(
-swish=SwishJit,
+#swish=SwishJit,
mish=MishJit,
hard_sigmoid=HardSigmoidJit,
hard_swish=HardSwishJit,

View File

@@ -19,7 +19,7 @@ tup_single = _ntuple(1)
tup_pair = _ntuple(2)
tup_triple = _ntuple(3)
tup_quadruple = _ntuple(4)
+ntup = _ntuple
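
ntup exposes the raw _ntuple factory alongside the fixed-arity aliases; like the torch.nn.modules.utils helper it mirrors, a scalar is repeated into an n-tuple while an iterable passes through unchanged:

to_2tuple = ntup(2)   # same as tup_pair
to_2tuple(3)          # -> (3, 3)
to_2tuple((3, 5))     # -> (3, 5)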

View File

@@ -215,7 +215,7 @@ class MobileNetV3Features(nn.Module):
return self.feature_hooks.get_output(x.device)
-def _create_model(model_kwargs, default_cfg, pretrained=False):
+def _create_mnv3(model_kwargs, default_cfg, pretrained=False):
if model_kwargs.pop('features_only', False):
load_strict = False
model_kwargs.pop('num_classes', 0)
@@ -272,7 +272,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw
se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=1),
**kwargs,
)
-model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
+model = _create_mnv3(model_kwargs, default_cfgs[variant], pretrained)
return model
@@ -368,7 +368,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8),
**kwargs,
)
-model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
+model = _create_mnv3(model_kwargs, default_cfgs[variant], pretrained)
return model

View File

@ -14,12 +14,10 @@ Weights from original impl have been modified
""" """
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .features import FeatureNet from .helpers import build_model_with_cfg
from .helpers import load_pretrained from .layers import ClassifierHead, AvgPool2dSame, ConvBnAct, SEModule
from .layers import SelectAdaptivePool2d, AvgPool2dSame, ConvBnAct, SEModule
from .registry import register_model from .registry import register_model
@ -222,26 +220,6 @@ class RegStage(nn.Module):
return x return x
class ClassifierHead(nn.Module):
"""Head."""
def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.):
super(ClassifierHead, self).__init__()
self.drop_rate = drop_rate
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
if num_classes > 0:
self.fc = nn.Linear(in_chs, num_classes, bias=True)
else:
self.fc = nn.Identity()
def forward(self, x):
x = self.global_pool(x).flatten(1)
if self.drop_rate:
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
x = self.fc(x)
return x
class RegNet(nn.Module): class RegNet(nn.Module):
"""RegNet model. """RegNet model.
@ -343,163 +321,150 @@ class RegNet(nn.Module):
return x return x
def _regnet(variant, pretrained, **kwargs): def _create_regnet(variant, pretrained, **kwargs):
features = False return build_model_with_cfg(
out_indices = None RegNet, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=model_cfgs[variant], **kwargs)
if kwargs.pop('features_only', False):
features = True
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
model_cfg = model_cfgs[variant]
model = RegNet(model_cfg, **kwargs)
model.default_cfg = default_cfgs[variant]
if pretrained:
load_pretrained(
model,
num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=not features)
if features:
model = FeatureNet(model, out_indices=out_indices)
return model
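The collapse from the old hand-rolled _regnet to the three-line _create_regnet is the core pattern of this commit: the features_only / out_indices kwargs the old code popped are now assumed to be handled inside build_model_with_cfg, along with default_cfg attachment and pretrained loading. From the caller side that presumably looks like the sketch below (an assumption about the new builder, with the feature model returning one tensor per selected index):

import torch
import timm  # assuming the registered entrypoints below are reachable via timm.create_model

clf = timm.create_model('regnetx_032', num_classes=10)                                  # plain classifier
feat = timm.create_model('regnetx_032', features_only=True, out_indices=(1, 2, 3, 4))   # feature backbone
outs = feat(torch.randn(1, 3, 224, 224))
print([o.shape[1] for o in outs])  # channel count of each tapped stage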
@register_model @register_model
def regnetx_002(pretrained=False, **kwargs): def regnetx_002(pretrained=False, **kwargs):
"""RegNetX-200MF""" """RegNetX-200MF"""
return _regnet('regnetx_002', pretrained, **kwargs) return _create_regnet('regnetx_002', pretrained, **kwargs)
@register_model @register_model
def regnetx_004(pretrained=False, **kwargs): def regnetx_004(pretrained=False, **kwargs):
"""RegNetX-400MF""" """RegNetX-400MF"""
return _regnet('regnetx_004', pretrained, **kwargs) return _create_regnet('regnetx_004', pretrained, **kwargs)
@register_model @register_model
def regnetx_006(pretrained=False, **kwargs): def regnetx_006(pretrained=False, **kwargs):
"""RegNetX-600MF""" """RegNetX-600MF"""
return _regnet('regnetx_006', pretrained, **kwargs) return _create_regnet('regnetx_006', pretrained, **kwargs)
@register_model @register_model
def regnetx_008(pretrained=False, **kwargs): def regnetx_008(pretrained=False, **kwargs):
"""RegNetX-800MF""" """RegNetX-800MF"""
return _regnet('regnetx_008', pretrained, **kwargs) return _create_regnet('regnetx_008', pretrained, **kwargs)
@register_model @register_model
def regnetx_016(pretrained=False, **kwargs): def regnetx_016(pretrained=False, **kwargs):
"""RegNetX-1.6GF""" """RegNetX-1.6GF"""
return _regnet('regnetx_016', pretrained, **kwargs) return _create_regnet('regnetx_016', pretrained, **kwargs)
@register_model @register_model
def regnetx_032(pretrained=False, **kwargs): def regnetx_032(pretrained=False, **kwargs):
"""RegNetX-3.2GF""" """RegNetX-3.2GF"""
return _regnet('regnetx_032', pretrained, **kwargs) return _create_regnet('regnetx_032', pretrained, **kwargs)
@register_model @register_model
def regnetx_040(pretrained=False, **kwargs): def regnetx_040(pretrained=False, **kwargs):
"""RegNetX-4.0GF""" """RegNetX-4.0GF"""
return _regnet('regnetx_040', pretrained, **kwargs) return _create_regnet('regnetx_040', pretrained, **kwargs)
@register_model @register_model
def regnetx_064(pretrained=False, **kwargs): def regnetx_064(pretrained=False, **kwargs):
"""RegNetX-6.4GF""" """RegNetX-6.4GF"""
return _regnet('regnetx_064', pretrained, **kwargs) return _create_regnet('regnetx_064', pretrained, **kwargs)
@register_model @register_model
def regnetx_080(pretrained=False, **kwargs): def regnetx_080(pretrained=False, **kwargs):
"""RegNetX-8.0GF""" """RegNetX-8.0GF"""
return _regnet('regnetx_080', pretrained, **kwargs) return _create_regnet('regnetx_080', pretrained, **kwargs)
@register_model @register_model
def regnetx_120(pretrained=False, **kwargs): def regnetx_120(pretrained=False, **kwargs):
"""RegNetX-12GF""" """RegNetX-12GF"""
return _regnet('regnetx_120', pretrained, **kwargs) return _create_regnet('regnetx_120', pretrained, **kwargs)
@register_model @register_model
def regnetx_160(pretrained=False, **kwargs): def regnetx_160(pretrained=False, **kwargs):
"""RegNetX-16GF""" """RegNetX-16GF"""
return _regnet('regnetx_160', pretrained, **kwargs) return _create_regnet('regnetx_160', pretrained, **kwargs)
@register_model @register_model
def regnetx_320(pretrained=False, **kwargs): def regnetx_320(pretrained=False, **kwargs):
"""RegNetX-32GF""" """RegNetX-32GF"""
return _regnet('regnetx_320', pretrained, **kwargs) return _create_regnet('regnetx_320', pretrained, **kwargs)
@register_model @register_model
def regnety_002(pretrained=False, **kwargs): def regnety_002(pretrained=False, **kwargs):
"""RegNetY-200MF""" """RegNetY-200MF"""
return _regnet('regnety_002', pretrained, **kwargs) return _create_regnet('regnety_002', pretrained, **kwargs)
@register_model @register_model
def regnety_004(pretrained=False, **kwargs): def regnety_004(pretrained=False, **kwargs):
"""RegNetY-400MF""" """RegNetY-400MF"""
return _regnet('regnety_004', pretrained, **kwargs) return _create_regnet('regnety_004', pretrained, **kwargs)
@register_model @register_model
def regnety_006(pretrained=False, **kwargs): def regnety_006(pretrained=False, **kwargs):
"""RegNetY-600MF""" """RegNetY-600MF"""
return _regnet('regnety_006', pretrained, **kwargs) return _create_regnet('regnety_006', pretrained, **kwargs)
@register_model @register_model
def regnety_008(pretrained=False, **kwargs): def regnety_008(pretrained=False, **kwargs):
"""RegNetY-800MF""" """RegNetY-800MF"""
return _regnet('regnety_008', pretrained, **kwargs) return _create_regnet('regnety_008', pretrained, **kwargs)
@register_model @register_model
def regnety_016(pretrained=False, **kwargs): def regnety_016(pretrained=False, **kwargs):
"""RegNetY-1.6GF""" """RegNetY-1.6GF"""
return _regnet('regnety_016', pretrained, **kwargs) return _create_regnet('regnety_016', pretrained, **kwargs)
@register_model @register_model
def regnety_032(pretrained=False, **kwargs): def regnety_032(pretrained=False, **kwargs):
"""RegNetY-3.2GF""" """RegNetY-3.2GF"""
return _regnet('regnety_032', pretrained, **kwargs) return _create_regnet('regnety_032', pretrained, **kwargs)
@register_model @register_model
def regnety_040(pretrained=False, **kwargs): def regnety_040(pretrained=False, **kwargs):
"""RegNetY-4.0GF""" """RegNetY-4.0GF"""
return _regnet('regnety_040', pretrained, **kwargs) return _create_regnet('regnety_040', pretrained, **kwargs)
@register_model @register_model
def regnety_064(pretrained=False, **kwargs): def regnety_064(pretrained=False, **kwargs):
"""RegNetY-6.4GF""" """RegNetY-6.4GF"""
return _regnet('regnety_064', pretrained, **kwargs) return _create_regnet('regnety_064', pretrained, **kwargs)
@register_model @register_model
def regnety_080(pretrained=False, **kwargs): def regnety_080(pretrained=False, **kwargs):
"""RegNetY-8.0GF""" """RegNetY-8.0GF"""
return _regnet('regnety_080', pretrained, **kwargs) return _create_regnet('regnety_080', pretrained, **kwargs)
@register_model @register_model
def regnety_120(pretrained=False, **kwargs): def regnety_120(pretrained=False, **kwargs):
"""RegNetY-12GF""" """RegNetY-12GF"""
return _regnet('regnety_120', pretrained, **kwargs) return _create_regnet('regnety_120', pretrained, **kwargs)
@register_model @register_model
def regnety_160(pretrained=False, **kwargs): def regnety_160(pretrained=False, **kwargs):
"""RegNetY-16GF""" """RegNetY-16GF"""
return _regnet('regnety_160', pretrained, **kwargs) return _create_regnet('regnety_160', pretrained, **kwargs)
@register_model @register_model
def regnety_320(pretrained=False, **kwargs): def regnety_320(pretrained=False, **kwargs):
"""RegNetY-32GF""" """RegNetY-32GF"""
return _regnet('regnety_320', pretrained, **kwargs) return _create_regnet('regnety_320', pretrained, **kwargs)

View File

@ -8,9 +8,9 @@ import torch
import torch.nn as nn import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained from .helpers import build_model_with_cfg
from .registry import register_model from .registry import register_model
from .resnet import _create_resnet_with_cfg from .resnet import ResNet
__all__ = [] __all__ = []
@ -133,8 +133,8 @@ class Bottle2neck(nn.Module):
def _create_res2net(variant, pretrained=False, **kwargs): def _create_res2net(variant, pretrained=False, **kwargs):
default_cfg = default_cfgs[variant] return build_model_with_cfg(
return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs) ResNet, variant, pretrained, default_cfg=default_cfgs[variant], **kwargs)
@register_model @register_model

View File

@ -10,10 +10,10 @@ import torch
from torch import nn from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropBlock2d from .helpers import build_model_with_cfg
from .layers.split_attn import SplitAttnConv2d from .layers import SplitAttnConv2d
from .registry import register_model from .registry import register_model
from .resnet import _create_resnet_with_cfg from .resnet import ResNet
def _cfg(url='', **kwargs): def _cfg(url='', **kwargs):
@ -140,8 +140,8 @@ class ResNestBottleneck(nn.Module):
def _create_resnest(variant, pretrained=False, **kwargs): def _create_resnest(variant, pretrained=False, **kwargs):
default_cfg = default_cfgs[variant] return build_model_with_cfg(
return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs) ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)
@register_model @register_model

View File

@ -13,8 +13,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .features import FeatureNet from .helpers import build_model_with_cfg
from .helpers import load_pretrained, adapt_model_from_file
from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d
from .registry import register_model from .registry import register_model
@ -590,32 +589,9 @@ class ResNet(nn.Module):
return x return x
def _create_resnet_with_cfg(variant, default_cfg, pretrained=False, **kwargs):
assert isinstance(default_cfg, dict)
features = False
out_indices = None
if kwargs.pop('features_only', False):
features = True
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
pruned = kwargs.pop('pruned', False)
model = ResNet(**kwargs)
model.default_cfg = copy.deepcopy(default_cfg)
if pruned:
model = adapt_model_from_file(model, variant)
if pretrained:
load_pretrained(
model,
num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=not features)
if features:
model = FeatureNet(model, out_indices=out_indices)
return model
def _create_resnet(variant, pretrained=False, **kwargs): def _create_resnet(variant, pretrained=False, **kwargs):
default_cfg = default_cfgs[variant] return build_model_with_cfg(
return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs) ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)
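With _create_resnet_with_cfg gone, the ResNet-derived files (res2net, resnest, sknet) all call the common builder with ResNet as the class, and only the block type and stage arguments differ per registered entrypoint. The removed block above also handled the pruned kwarg via adapt_model_from_file, which is assumed to have moved into build_model_with_cfg as well. The shape of a registered variant under this scheme, sketched with illustrative arguments rather than the file's real (elided) defaults:

# Illustrative sketch only -- the block/layers values are placeholders, not this file's defaults.
@register_model
def my_resnet_variant(pretrained=False, **kwargs):
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
    return _create_resnet('my_resnet_variant', pretrained, **model_args)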
@register_model @register_model

View File

@ -16,8 +16,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .features import FeatureNet from .helpers import build_model_with_cfg
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d from .layers import SelectAdaptivePool2d
from .registry import register_model from .registry import register_model
@ -178,7 +177,7 @@ class SelecSLS(nn.Module):
return x return x
def _create_model(variant, pretrained, model_kwargs): def _create_selecsls(variant, pretrained, model_kwargs):
cfg = {} cfg = {}
feature_info = [dict(num_chs=32, reduction=2, module='stem.2')] feature_info = [dict(num_chs=32, reduction=2, module='stem.2')]
if variant.startswith('selecsls42'): if variant.startswith('selecsls42'):
@ -299,61 +298,42 @@ def _create_model(variant, pretrained, model_kwargs):
else: else:
raise ValueError('Invalid net configuration ' + variant + ' !!!') raise ValueError('Invalid net configuration ' + variant + ' !!!')
load_strict = True
features = False
out_indices = None
if model_kwargs.pop('features_only', False):
load_strict = False
features = True
# this model can do 6 feature levels by default, unlike most others, leave as 0-4 to avoid surprises? # this model can do 6 feature levels by default, unlike most others, leave as 0-4 to avoid surprises?
out_indices = model_kwargs.pop('out_indices', (0, 1, 2, 3, 4)) return build_model_with_cfg(
model_kwargs.pop('num_classes', 0) SelecSLS, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=cfg,
feature_cfg=dict(out_indices=(0, 1, 2, 3, 4), flatten_sequential=True), **model_kwargs)
model = SelecSLS(cfg, **model_kwargs)
model.default_cfg = default_cfgs[variant]
model.feature_info = feature_info
if pretrained:
load_pretrained(
model,
num_classes=model_kwargs.get('num_classes', 0),
in_chans=model_kwargs.get('in_chans', 3),
strict=load_strict)
if features:
model = FeatureNet(model, out_indices, flatten_sequential=True)
return model
@register_model @register_model
def selecsls42(pretrained=False, **kwargs): def selecsls42(pretrained=False, **kwargs):
"""Constructs a SelecSLS42 model. """Constructs a SelecSLS42 model.
""" """
return _create_model('selecsls42', pretrained, kwargs) return _create_selecsls('selecsls42', pretrained, kwargs)
@register_model @register_model
def selecsls42b(pretrained=False, **kwargs): def selecsls42b(pretrained=False, **kwargs):
"""Constructs a SelecSLS42_B model. """Constructs a SelecSLS42_B model.
""" """
return _create_model('selecsls42b', pretrained, kwargs) return _create_selecsls('selecsls42b', pretrained, kwargs)
@register_model @register_model
def selecsls60(pretrained=False, **kwargs): def selecsls60(pretrained=False, **kwargs):
"""Constructs a SelecSLS60 model. """Constructs a SelecSLS60 model.
""" """
return _create_model('selecsls60', pretrained, kwargs) return _create_selecsls('selecsls60', pretrained, kwargs)
@register_model @register_model
def selecsls60b(pretrained=False, **kwargs): def selecsls60b(pretrained=False, **kwargs):
"""Constructs a SelecSLS60_B model. """Constructs a SelecSLS60_B model.
""" """
return _create_model('selecsls60b', pretrained, kwargs) return _create_selecsls('selecsls60b', pretrained, kwargs)
@register_model @register_model
def selecsls84(pretrained=False, **kwargs): def selecsls84(pretrained=False, **kwargs):
"""Constructs a SelecSLS84 model. """Constructs a SelecSLS84 model.
""" """
return _create_model('selecsls84', pretrained, kwargs) return _create_selecsls('selecsls84', pretrained, kwargs)

View File

@ -13,10 +13,10 @@ import math
from torch import nn as nn from torch import nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained from .helpers import build_model_with_cfg
from .layers import SelectiveKernelConv, ConvBnAct, create_attn from .layers import SelectiveKernelConv, ConvBnAct, create_attn
from .registry import register_model from .registry import register_model
from .resnet import _create_resnet_with_cfg from .resnet import ResNet
def _cfg(url='', **kwargs): def _cfg(url='', **kwargs):
@ -139,8 +139,8 @@ class SelectiveKernelBottleneck(nn.Module):
def _create_skresnet(variant, pretrained=False, **kwargs): def _create_skresnet(variant, pretrained=False, **kwargs):
default_cfg = default_cfgs[variant] return build_model_with_cfg(
return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs) ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)
@register_model @register_model

View File

@ -5,6 +5,7 @@ https://arxiv.org/pdf/2003.13630.pdf
Original model: https://github.com/mrT23/TResNet Original model: https://github.com/mrT23/TResNet
""" """
import copy
from collections import OrderedDict from collections import OrderedDict
from functools import partial from functools import partial
@ -12,8 +13,8 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .helpers import load_pretrained from .helpers import build_model_with_cfg
from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, SelectAdaptivePool2d, InplaceAbn from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead
from .registry import register_model from .registry import register_model
__all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl'] __all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl']
@ -220,11 +221,17 @@ class TResNet(nn.Module):
('layer3', layer3), ('layer3', layer3),
('layer4', layer4)])) ('layer4', layer4)]))
self.feature_info = [
dict(num_chs=self.planes, reduction=2, module=''), # Not with S2D?
dict(num_chs=self.planes, reduction=4, module='body.layer1'),
dict(num_chs=self.planes * 2, reduction=8, module='body.layer2'),
dict(num_chs=self.planes * 4, reduction=16, module='body.layer3'),
dict(num_chs=self.planes * 8, reduction=32, module='body.layer4'),
]
# head # head
self.num_features = (self.planes * 8) * Bottleneck.expansion self.num_features = (self.planes * 8) * Bottleneck.expansion
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
self.head = nn.Sequential(OrderedDict([
('fc', nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes))]))
# model initialization # model initialization
for m in self.modules(): for m in self.modules():
@ -240,7 +247,8 @@ class TResNet(nn.Module):
m.conv2[1].weight = nn.Parameter(torch.zeros_like(m.conv2[1].weight)) # BN to zero m.conv2[1].weight = nn.Parameter(torch.zeros_like(m.conv2[1].weight)) # BN to zero
if isinstance(m, Bottleneck): if isinstance(m, Bottleneck):
m.conv3[1].weight = nn.Parameter(torch.zeros_like(m.conv3[1].weight)) # BN to zero m.conv3[1].weight = nn.Parameter(torch.zeros_like(m.conv3[1].weight)) # BN to zero
if isinstance(m, nn.Linear): m.weight.data.normal_(0, 0.01) if isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.01)
def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=None): def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=None):
downsample = None downsample = None
@ -266,86 +274,55 @@ class TResNet(nn.Module):
return self.head.fc return self.head.fc
def reset_classifier(self, num_classes, global_pool='avg'): def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) self.head = ClassifierHead(
self.num_classes = num_classes self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
self.head = None
if num_classes:
num_features = self.num_features * self.global_pool.feat_mult()
self.head = nn.Sequential(OrderedDict([('fc', nn.Linear(num_features, num_classes))]))
else:
self.head = nn.Sequential(OrderedDict([('fc', nn.Identity())]))
def forward_features(self, x): def forward_features(self, x):
return self.body(x) return self.body(x)
def forward(self, x): def forward(self, x):
x = self.forward_features(x) x = self.forward_features(x)
x = self.global_pool(x)
if self.drop_rate:
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
x = self.head(x) x = self.head(x)
return x return x
@register_model def _create_tresnet(variant, pretrained=False, **kwargs):
def tresnet_m(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return build_model_with_cfg(
default_cfg = default_cfgs['tresnet_m'] TResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained,
model = TResNet(layers=[3, 4, 11, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) feature_cfg=dict(out_indices=(1, 2, 3, 4), flatten_sequential=True), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model @register_model
def tresnet_l(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def tresnet_m(pretrained=False, **kwargs):
default_cfg = default_cfgs['tresnet_l'] model_kwargs = dict(layers=[3, 4, 11, 3], **kwargs)
model = TResNet( return _create_tresnet('tresnet_m', pretrained=pretrained, **model_kwargs)
layers=[4, 5, 18, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.2, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model @register_model
def tresnet_xl(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def tresnet_l(pretrained=False, **kwargs):
default_cfg = default_cfgs['tresnet_xl'] model_kwargs = dict(layers=[4, 5, 18, 3], width_factor=1.2, **kwargs)
model = TResNet( return _create_tresnet('tresnet_l', pretrained=pretrained, **model_kwargs)
layers=[4, 5, 24, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.3, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model @register_model
def tresnet_m_448(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def tresnet_xl(pretrained=False, **kwargs):
default_cfg = default_cfgs['tresnet_m_448'] model_kwargs = dict(layers=[4, 5, 24, 3], width_factor=1.3, **kwargs)
model = TResNet(layers=[3, 4, 11, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) return _create_tresnet('tresnet_xl', pretrained=pretrained, **model_kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model @register_model
def tresnet_l_448(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def tresnet_m_448(pretrained=False, **kwargs):
default_cfg = default_cfgs['tresnet_l_448'] model_kwargs = dict(layers=[3, 4, 11, 3], **kwargs)
model = TResNet( return _create_tresnet('tresnet_m_448', pretrained=pretrained, **model_kwargs)
layers=[4, 5, 18, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.2, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model @register_model
def tresnet_xl_448(pretrained=False, num_classes=1000, in_chans=3, **kwargs): def tresnet_l_448(pretrained=False, **kwargs):
default_cfg = default_cfgs['tresnet_xl_448'] model_kwargs = dict(layers=[4, 5, 18, 3], width_factor=1.2, **kwargs)
model = TResNet( return _create_tresnet('tresnet_l_448', pretrained=pretrained, **model_kwargs)
layers=[4, 5, 24, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.3, **kwargs)
model.default_cfg = default_cfg
if pretrained: @register_model
load_pretrained(model, default_cfg, num_classes, in_chans) def tresnet_xl_448(pretrained=False, **kwargs):
return model model_kwargs = dict(layers=[4, 5, 24, 3], width_factor=1.3, **kwargs)
return _create_tresnet('tresnet_xl_448', pretrained=pretrained, **model_kwargs)
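TResNet now follows the same interfaces as the rest of the refactor: per-stage feature_info metadata, a ClassifierHead-based head, and creation through the common builder. A small usage sketch, hedged since it assumes the entrypoints above resolve via timm.create_model, that feature_info stays a plain list on the classification model, and that the inplace_abn dependency required by the InplaceAbn layers is installed:

import timm

m = timm.create_model('tresnet_m', num_classes=1000)
for fi in m.feature_info:      # entries of num_chs / reduction / module, as built above
    print(fi)
m.reset_classifier(10)         # swaps in a fresh ClassifierHead for a 10-class task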

View File

@ -19,9 +19,8 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .registry import register_model from .registry import register_model
from .helpers import load_pretrained from .helpers import build_model_with_cfg
from .features import FeatureNet from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, ClassifierHead, \
from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, SelectAdaptivePool2d, \
create_attn, create_norm_act, get_norm_act_layer create_attn, create_norm_act, get_norm_act_layer
@ -253,26 +252,6 @@ class OsaStage(nn.Module):
return x return x
class ClassifierHead(nn.Module):
"""Head."""
def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.):
super(ClassifierHead, self).__init__()
self.drop_rate = drop_rate
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
if num_classes > 0:
self.fc = nn.Linear(in_chs, num_classes, bias=True)
else:
self.fc = nn.Identity()
def forward(self, x):
x = self.global_pool(x).flatten(1)
if self.drop_rate:
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
x = self.fc(x)
return x
class VovNet(nn.Module): class VovNet(nn.Module):
def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4, def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4,
@ -346,67 +325,55 @@ class VovNet(nn.Module):
return self.head(x) return self.head(x)
def _vovnet(variant, pretrained=False, **kwargs): def _create_vovnet(variant, pretrained=False, **kwargs):
features = False return build_model_with_cfg(
out_indices = None VovNet, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=model_cfgs[variant],
if kwargs.pop('features_only', False): feature_cfg=dict(flatten_sequential=True), **kwargs)
features = True
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
model_cfg = model_cfgs[variant]
model = VovNet(model_cfg, **kwargs)
model.default_cfg = default_cfgs[variant]
if pretrained:
load_pretrained(
model,
num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=not features)
if features:
model = FeatureNet(model, out_indices, flatten_sequential=True)
return model
@register_model @register_model
def vovnet39a(pretrained=False, **kwargs): def vovnet39a(pretrained=False, **kwargs):
return _vovnet('vovnet39a', pretrained=pretrained, **kwargs) return _create_vovnet('vovnet39a', pretrained=pretrained, **kwargs)
@register_model @register_model
def vovnet57a(pretrained=False, **kwargs): def vovnet57a(pretrained=False, **kwargs):
return _vovnet('vovnet57a', pretrained=pretrained, **kwargs) return _create_vovnet('vovnet57a', pretrained=pretrained, **kwargs)
@register_model @register_model
def ese_vovnet19b_slim_dw(pretrained=False, **kwargs): def ese_vovnet19b_slim_dw(pretrained=False, **kwargs):
return _vovnet('ese_vovnet19b_slim_dw', pretrained=pretrained, **kwargs) return _create_vovnet('ese_vovnet19b_slim_dw', pretrained=pretrained, **kwargs)
@register_model @register_model
def ese_vovnet19b_dw(pretrained=False, **kwargs): def ese_vovnet19b_dw(pretrained=False, **kwargs):
return _vovnet('ese_vovnet19b_dw', pretrained=pretrained, **kwargs) return _create_vovnet('ese_vovnet19b_dw', pretrained=pretrained, **kwargs)
@register_model @register_model
def ese_vovnet19b_slim(pretrained=False, **kwargs): def ese_vovnet19b_slim(pretrained=False, **kwargs):
return _vovnet('ese_vovnet19b_slim', pretrained=pretrained, **kwargs) return _create_vovnet('ese_vovnet19b_slim', pretrained=pretrained, **kwargs)
@register_model @register_model
def ese_vovnet39b(pretrained=False, **kwargs): def ese_vovnet39b(pretrained=False, **kwargs):
return _vovnet('ese_vovnet39b', pretrained=pretrained, **kwargs) return _create_vovnet('ese_vovnet39b', pretrained=pretrained, **kwargs)
@register_model @register_model
def ese_vovnet57b(pretrained=False, **kwargs): def ese_vovnet57b(pretrained=False, **kwargs):
return _vovnet('ese_vovnet57b', pretrained=pretrained, **kwargs) return _create_vovnet('ese_vovnet57b', pretrained=pretrained, **kwargs)
@register_model @register_model
def ese_vovnet99b(pretrained=False, **kwargs): def ese_vovnet99b(pretrained=False, **kwargs):
return _vovnet('ese_vovnet99b', pretrained=pretrained, **kwargs) return _create_vovnet('ese_vovnet99b', pretrained=pretrained, **kwargs)
@register_model @register_model
def eca_vovnet39b(pretrained=False, **kwargs): def eca_vovnet39b(pretrained=False, **kwargs):
return _vovnet('eca_vovnet39b', pretrained=pretrained, **kwargs) return _create_vovnet('eca_vovnet39b', pretrained=pretrained, **kwargs)
# Experimental Models # Experimental Models
@ -415,11 +382,11 @@ def eca_vovnet39b(pretrained=False, **kwargs):
def ese_vovnet39b_evos(pretrained=False, **kwargs): def ese_vovnet39b_evos(pretrained=False, **kwargs):
def norm_act_fn(num_features, **nkwargs): def norm_act_fn(num_features, **nkwargs):
return create_norm_act('EvoNormSample', num_features, jit=False, **nkwargs) return create_norm_act('EvoNormSample', num_features, jit=False, **nkwargs)
return _vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs) return _create_vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs)
@register_model @register_model
def ese_vovnet99b_iabn(pretrained=False, **kwargs): def ese_vovnet99b_iabn(pretrained=False, **kwargs):
norm_layer = get_norm_act_layer('iabn') norm_layer = get_norm_act_layer('iabn')
return _vovnet( return _create_vovnet(
'ese_vovnet99b_iabn', pretrained=pretrained, norm_layer=norm_layer, act_layer=nn.LeakyReLU, **kwargs) 'ese_vovnet99b_iabn', pretrained=pretrained, norm_layer=norm_layer, act_layer=nn.LeakyReLU, **kwargs)
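The experimental variants above show that layer overrides (norm_layer, act_layer) now simply ride along as constructor kwargs through the common builder. The same override from the caller side, written as an assumption about kwarg forwarding rather than a documented API:

import torch.nn as nn
import timm

# act_layer is a VovNet constructor argument (as ese_vovnet99b_iabn above demonstrates);
# passing it through create_model is assumed to reach the constructor via build_model_with_cfg.
model = timm.create_model('ese_vovnet39b', act_layer=nn.LeakyReLU)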

View File

@ -25,8 +25,7 @@ The resize parameter of the validation transform should be 333, and make sure to
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .helpers import load_pretrained from .helpers import build_model_with_cfg
from .features import FeatureNet
from .layers import SelectAdaptivePool2d from .layers import SelectAdaptivePool2d
from .registry import register_model from .registry import register_model
@ -220,25 +219,9 @@ class Xception(nn.Module):
def _xception(variant, pretrained=False, **kwargs): def _xception(variant, pretrained=False, **kwargs):
load_strict = True return build_model_with_cfg(
features = False Xception, variant, pretrained, default_cfg=default_cfgs[variant],
out_indices = None feature_cfg=dict(), **kwargs)
if kwargs.pop('features_only', False):
load_strict = False
features = True
kwargs.pop('num_classes', 0)
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
model = Xception(**kwargs)
model.default_cfg = default_cfgs[variant]
if pretrained:
load_pretrained(
model,
num_classes=kwargs.get('num_classes', 0),
in_chans=kwargs.get('in_chans', 3),
strict=load_strict)
if features:
model = FeatureNet(model, out_indices)
return model
@register_model @register_model

View File

@ -10,9 +10,9 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .features import FeatureNet from .helpers import build_model_with_cfg
from .helpers import load_pretrained from .layers import ClassifierHead, ConvBnAct, create_conv2d
from .layers import SelectAdaptivePool2d, ConvBnAct, create_conv2d from .layers.helpers import tup_triple
from .registry import register_model from .registry import register_model
__all__ = ['XceptionAligned'] __all__ = ['XceptionAligned']
@ -81,10 +81,7 @@ class XceptionModule(nn.Module):
start_with_relu=True, no_skip=False, act_layer=nn.ReLU, norm_layer=None, norm_kwargs=None): start_with_relu=True, no_skip=False, act_layer=nn.ReLU, norm_layer=None, norm_kwargs=None):
super(XceptionModule, self).__init__() super(XceptionModule, self).__init__()
norm_kwargs = norm_kwargs if norm_kwargs is not None else {} norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
if isinstance(out_chs, (list, tuple)): out_chs = tup_triple(out_chs)
assert len(out_chs) == 3
else:
out_chs = (out_chs,) * 3
self.in_channels = in_chs self.in_channels = in_chs
self.out_channels = out_chs[-1] self.out_channels = out_chs[-1]
self.no_skip = no_skip self.no_skip = no_skip
@ -115,26 +112,6 @@ class XceptionModule(nn.Module):
return x return x
class ClassifierHead(nn.Module):
"""Head."""
def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.):
super(ClassifierHead, self).__init__()
self.drop_rate = drop_rate
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
if num_classes > 0:
self.fc = nn.Linear(in_chs, num_classes, bias=True)
else:
self.fc = nn.Identity()
def forward(self, x):
x = self.global_pool(x).flatten(1)
if self.drop_rate:
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
x = self.fc(x)
return x
class XceptionAligned(nn.Module): class XceptionAligned(nn.Module):
"""Modified Aligned Xception """Modified Aligned Xception
""" """
@ -147,32 +124,29 @@ class XceptionAligned(nn.Module):
assert output_stride in (8, 16, 32) assert output_stride in (8, 16, 32)
norm_kwargs = norm_kwargs if norm_kwargs is not None else {} norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
xtra_args = dict(act_layer=act_layer, norm_layer=norm_layer, norm_kwargs=norm_kwargs) layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
self.stem = nn.Sequential(*[ self.stem = nn.Sequential(*[
ConvBnAct(in_chans, 32, kernel_size=3, stride=2, **xtra_args), ConvBnAct(in_chans, 32, kernel_size=3, stride=2, **layer_args),
ConvBnAct(32, 64, kernel_size=3, stride=1, **xtra_args) ConvBnAct(32, 64, kernel_size=3, stride=1, **layer_args)
]) ])
curr_dilation = 1 curr_dilation = 1
curr_stride = 2 curr_stride = 2
self.feature_info = [dict(num_chs=64, reduction=curr_stride, module='stem.1')] self.feature_info = []
self.blocks = nn.Sequential() self.blocks = nn.Sequential()
for i, b in enumerate(block_cfg): for i, b in enumerate(block_cfg):
feature_extract = False
b['dilation'] = curr_dilation b['dilation'] = curr_dilation
if b['stride'] > 1: if b['stride'] > 1:
feature_extract = True self.feature_info += [dict(
num_chs=tup_triple(b['out_chs'])[-2], reduction=curr_stride, module=f'blocks.{i}.stack.act3')]
next_stride = curr_stride * b['stride'] next_stride = curr_stride * b['stride']
if next_stride > output_stride: if next_stride > output_stride:
curr_dilation *= b['stride'] curr_dilation *= b['stride']
b['stride'] = 1 b['stride'] = 1
else: else:
curr_stride = next_stride curr_stride = next_stride
self.blocks.add_module(str(i), XceptionModule(**b, **xtra_args)) self.blocks.add_module(str(i), XceptionModule(**b, **layer_args))
self.num_features = self.blocks[-1].out_channels self.num_features = self.blocks[-1].out_channels
if feature_extract:
self.feature_info += [dict(
num_chs=self.num_features, reduction=curr_stride, module=f'blocks.{i}.stack.act2')]
self.feature_info += [dict( self.feature_info += [dict(
num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))] num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))]
@ -198,24 +172,9 @@ class XceptionAligned(nn.Module):
def _xception(variant, pretrained=False, **kwargs): def _xception(variant, pretrained=False, **kwargs):
features = False return build_model_with_cfg(
out_indices = None XceptionAligned, variant, pretrained, default_cfg=default_cfgs[variant],
if kwargs.pop('features_only', False): feature_cfg=dict(flatten_sequential=True, use_hooks=True), **kwargs)
features = True
kwargs.pop('num_classes', 0)
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
model = XceptionAligned(**kwargs)
model.default_cfg = default_cfgs[variant]
if pretrained:
load_pretrained(
model,
num_classes=kwargs.get('num_classes', 0),
in_chans=kwargs.get('in_chans', 3),
strict=not features)
if features:
model = FeatureNet(model, out_indices)
return model
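feature_cfg=dict(flatten_sequential=True, use_hooks=True) flags this model for hook-based feature extraction: the module names recorded in feature_info above ('blocks.{i}.stack.act3' and the final block) are presumably tapped with forward hooks rather than by slicing the Sequential, since the interesting activations sit inside the blocks. Not the library's implementation, just an illustration of the idea:

import torch

def collect_features(model, x, module_names):
    # Illustrative: grab intermediate activations by module name via forward hooks.
    feats, hooks = {}, []
    for name, mod in model.named_modules():
        if name in module_names:
            hooks.append(mod.register_forward_hook(
                lambda _m, _in, out, name=name: feats.update({name: out})))
    with torch.no_grad():
        model(x)
    for h in hooks:
        h.remove()
    return [feats[n] for n in module_names]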
@register_model @register_model