From 3b9004bef935ebcdfe6ab5416e914053b4c8d94b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 17 Jul 2020 17:54:26 -0700 Subject: [PATCH] Lots of changes to model creation helpers, close to finalizing feature extraction / interfaces --- timm/models/csp.py | 61 +++-------- timm/models/densenet.py | 47 +++------ timm/models/dpn.py | 118 ++++++++------------- timm/models/efficientnet.py | 26 ++--- timm/models/feature_hooks.py | 22 ++-- timm/models/features.py | 158 ++++++++++++++++++++--------- timm/models/gluon_resnet.py | 6 +- timm/models/helpers.py | 53 +++++++++- timm/models/hrnet.py | 43 +++----- timm/models/inception_resnet_v2.py | 24 ++--- timm/models/inception_v3.py | 29 +++--- timm/models/layers/__init__.py | 2 + timm/models/layers/activations.py | 2 +- timm/models/layers/classifier.py | 24 +++++ timm/models/layers/create_act.py | 4 +- timm/models/layers/helpers.py | 2 +- timm/models/mobilenetv3.py | 6 +- timm/models/regnet.py | 93 ++++++----------- timm/models/res2net.py | 8 +- timm/models/resnest.py | 10 +- timm/models/resnet.py | 30 +----- timm/models/selecsls.py | 42 ++------ timm/models/sknet.py | 8 +- timm/models/tresnet.py | 105 ++++++++----------- timm/models/vovnet.py | 67 ++++-------- timm/models/xception.py | 25 +---- timm/models/xception_aligned.py | 71 +++---------- 27 files changed, 454 insertions(+), 632 deletions(-) create mode 100644 timm/models/layers/classifier.py diff --git a/timm/models/csp.py b/timm/models/csp.py index ca3d17d3..98f0dceb 100644 --- a/timm/models/csp.py +++ b/timm/models/csp.py @@ -17,9 +17,8 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .features import FeatureNet -from .helpers import load_pretrained -from .layers import SelectAdaptivePool2d, ConvBnAct, DropPath, create_attn, get_norm_act_layer +from .helpers import build_model_with_cfg +from .layers import ClassifierHead, ConvBnAct, DropPath, create_attn, get_norm_act_layer from .registry import register_model @@ -294,26 +293,6 @@ class DarkStage(nn.Module): return x -class ClassifierHead(nn.Module): - """Head.""" - - def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.): - super(ClassifierHead, self).__init__() - self.drop_rate = drop_rate - self.global_pool = SelectAdaptivePool2d(pool_type=pool_type) - if num_classes > 0: - self.fc = nn.Linear(in_chs, num_classes, bias=True) - else: - self.fc = nn.Identity() - - def forward(self, x): - x = self.global_pool(x).flatten(1) - if self.drop_rate: - x = F.dropout(x, p=float(self.drop_rate), training=self.training) - x = self.fc(x) - return x - - def _cfg_to_stage_args(cfg, curr_stride=2, output_stride=32, drop_path_rate=0.): # get per stage args for stage and containing blocks, calculate strides to meet target output_stride num_stages = len(cfg['depth']) @@ -420,62 +399,50 @@ class CspNet(nn.Module): return x -def _cspnet(variant, pretrained=False, **kwargs): - features = False - out_indices = None - if kwargs.pop('features_only', False): - features = True - out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4)) +def _create_cspnet(variant, pretrained=False, **kwargs): cfg_variant = variant.split('_')[0] - cfg = model_cfgs[cfg_variant] - model = CspNet(cfg, **kwargs) - model.default_cfg = default_cfgs[variant] - if pretrained: - load_pretrained( - model, - num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=not features) - if features: - model = FeatureNet(model, out_indices, flatten_sequential=True) - return model + return build_model_with_cfg( + CspNet, variant, pretrained, default_cfg=default_cfgs[variant], + feature_cfg=dict(flatten_sequential=True), model_cfg=model_cfgs[cfg_variant], **kwargs) @register_model def cspresnet50(pretrained=False, **kwargs): - return _cspnet('cspresnet50', pretrained=pretrained, **kwargs) + return _create_cspnet('cspresnet50', pretrained=pretrained, **kwargs) @register_model def cspresnet50d(pretrained=False, **kwargs): - return _cspnet('cspresnet50d', pretrained=pretrained, **kwargs) + return _create_cspnet('cspresnet50d', pretrained=pretrained, **kwargs) @register_model def cspresnet50w(pretrained=False, **kwargs): - return _cspnet('cspresnet50w', pretrained=pretrained, **kwargs) + return _create_cspnet('cspresnet50w', pretrained=pretrained, **kwargs) @register_model def cspresnext50(pretrained=False, **kwargs): - return _cspnet('cspresnext50', pretrained=pretrained, **kwargs) + return _create_cspnet('cspresnext50', pretrained=pretrained, **kwargs) @register_model def cspresnext50_iabn(pretrained=False, **kwargs): norm_layer = get_norm_act_layer('iabn') - return _cspnet('cspresnext50', pretrained=pretrained, norm_layer=norm_layer, **kwargs) + return _create_cspnet('cspresnext50', pretrained=pretrained, norm_layer=norm_layer, **kwargs) @register_model def cspdarknet53(pretrained=False, **kwargs): - return _cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, **kwargs) + return _create_cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, **kwargs) @register_model def cspdarknet53_iabn(pretrained=False, **kwargs): norm_layer = get_norm_act_layer('iabn') - return _cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, norm_layer=norm_layer, **kwargs) + return _create_cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, norm_layer=norm_layer, **kwargs) @register_model def darknet53(pretrained=False, **kwargs): - return _cspnet('darknet53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) + return _create_cspnet('darknet53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) diff --git a/timm/models/densenet.py b/timm/models/densenet.py index 1eeaacee..5c8d6af8 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -13,8 +13,7 @@ import torch.utils.checkpoint as cp from torch.jit.annotations import List from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .features import FeatureNet -from .helpers import load_pretrained +from .helpers import build_model_with_cfg from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act, BlurPool2d from .registry import register_model @@ -288,26 +287,12 @@ def _filter_torchvision_pretrained(state_dict): return state_dict -def _densenet(variant, growth_rate, block_config, pretrained, **kwargs): - features = False - out_indices = None - if kwargs.pop('features_only', False): - features = True - kwargs.pop('num_classes', 0) - out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4)) - default_cfg = default_cfgs[variant] - model = DenseNet(growth_rate=growth_rate, block_config=block_config, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained( - model, default_cfg, - num_classes=kwargs.get('num_classes', 0), - in_chans=kwargs.get('in_chans', 3), - filter_fn=_filter_torchvision_pretrained, - strict=not features) - if features: - model = FeatureNet(model, out_indices, flatten_sequential=True) - return model +def _create_densenet(variant, growth_rate, block_config, pretrained, **kwargs): + kwargs['growth_rate'] = growth_rate + kwargs['block_config'] = block_config + return build_model_with_cfg( + DenseNet, variant, pretrained, default_cfg=default_cfgs[variant], + feature_cfg=dict(flatten_sequential=True), pretrained_filter_fn=_filter_torchvision_pretrained, **kwargs) @register_model @@ -315,7 +300,7 @@ def densenet121(pretrained=False, **kwargs): r"""Densenet-121 model from `"Densely Connected Convolutional Networks" ` """ - model = _densenet( + model = _create_densenet( 'densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, **kwargs) return model @@ -325,7 +310,7 @@ def densenetblur121d(pretrained=False, **kwargs): r"""Densenet-121 model from `"Densely Connected Convolutional Networks" ` """ - model = _densenet( + model = _create_densenet( 'densenetblur121d', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, stem_type='deep', aa_layer=BlurPool2d, **kwargs) return model @@ -336,7 +321,7 @@ def densenet121d(pretrained=False, **kwargs): r"""Densenet-121 model from `"Densely Connected Convolutional Networks" ` """ - model = _densenet( + model = _create_densenet( 'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep', pretrained=pretrained, **kwargs) return model @@ -347,7 +332,7 @@ def densenet169(pretrained=False, **kwargs): r"""Densenet-169 model from `"Densely Connected Convolutional Networks" ` """ - model = _densenet( + model = _create_densenet( 'densenet169', growth_rate=32, block_config=(6, 12, 32, 32), pretrained=pretrained, **kwargs) return model @@ -357,7 +342,7 @@ def densenet201(pretrained=False, **kwargs): r"""Densenet-201 model from `"Densely Connected Convolutional Networks" ` """ - model = _densenet( + model = _create_densenet( 'densenet201', growth_rate=32, block_config=(6, 12, 48, 32), pretrained=pretrained, **kwargs) return model @@ -367,7 +352,7 @@ def densenet161(pretrained=False, **kwargs): r"""Densenet-161 model from `"Densely Connected Convolutional Networks" ` """ - model = _densenet( + model = _create_densenet( 'densenet161', growth_rate=48, block_config=(6, 12, 36, 24), pretrained=pretrained, **kwargs) return model @@ -377,7 +362,7 @@ def densenet264(pretrained=False, **kwargs): r"""Densenet-264 model from `"Densely Connected Convolutional Networks" ` """ - model = _densenet( + model = _create_densenet( 'densenet264', growth_rate=48, block_config=(6, 12, 64, 48), pretrained=pretrained, **kwargs) return model @@ -388,7 +373,7 @@ def densenet264d_iabn(pretrained=False, **kwargs): """ def norm_act_fn(num_features, **kwargs): return create_norm_act('iabn', num_features, **kwargs) - model = _densenet( + model = _create_densenet( 'densenet264d_iabn', growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep', norm_layer=norm_act_fn, pretrained=pretrained, **kwargs) return model @@ -399,6 +384,6 @@ def tv_densenet121(pretrained=False, **kwargs): r"""Densenet-121 model with original Torchvision weights, from `"Densely Connected Convolutional Networks" ` """ - model = _densenet( + model = _create_densenet( 'tv_densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, **kwargs) return model diff --git a/timm/models/dpn.py b/timm/models/dpn.py index d17d1a73..149ffad4 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -17,8 +17,8 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import load_pretrained -from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act, create_conv2d +from .helpers import build_model_with_cfg +from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_conv2d, ConvBnAct from .registry import register_model __all__ = ['DPN'] @@ -82,20 +82,6 @@ class BnActConv2d(nn.Module): return self.conv(self.bn(x)) -class InputBlock(nn.Module): - def __init__(self, num_init_features, kernel_size=7, in_chans=3, norm_layer=BatchNormAct2d): - super(InputBlock, self).__init__() - self.conv = create_conv2d(in_chans, num_init_features, kernel_size=kernel_size, stride=2) - self.bn = norm_layer(num_init_features, eps=0.001) - self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - - def forward(self, x): - x = self.conv(x) - x = self.bn(x) - x = self.pool(x) - return x - - class DualPathBlock(nn.Module): def __init__( self, in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, groups, block_type='normal', b=False): @@ -183,21 +169,21 @@ class DualPathBlock(nn.Module): class DPN(nn.Module): def __init__(self, small=False, num_init_features=64, k_r=96, groups=32, - b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), + b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), output_stride=32, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', fc_act=nn.ELU): super(DPN, self).__init__() self.num_classes = num_classes self.drop_rate = drop_rate self.b = b + assert output_stride == 32 # FIXME look into dilation support bw_factor = 1 if small else 4 - blocks = OrderedDict() # conv1 - if small: - blocks['conv1_1'] = InputBlock(num_init_features, in_chans=in_chans, kernel_size=3) - else: - blocks['conv1_1'] = InputBlock(num_init_features, in_chans=in_chans, kernel_size=7) + blocks['conv1_1'] = ConvBnAct( + in_chans, num_init_features, kernel_size=3 if small else 7, stride=2, norm_kwargs=dict(eps=.001)) + blocks['conv1_pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.feature_info = [dict(num_chs=num_init_features, reduction=2, module='features.conv1_1')] # conv2 bw = 64 * bw_factor @@ -208,6 +194,7 @@ class DPN(nn.Module): for i in range(2, k_sec[0] + 1): blocks['conv2_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) in_chs += inc + self.feature_info += [dict(num_chs=in_chs, reduction=4, module=f'features.conv2_{k_sec[0]}')] # conv3 bw = 128 * bw_factor @@ -218,6 +205,7 @@ class DPN(nn.Module): for i in range(2, k_sec[1] + 1): blocks['conv3_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) in_chs += inc + self.feature_info += [dict(num_chs=in_chs, reduction=8, module=f'features.conv3_{k_sec[1]}')] # conv4 bw = 256 * bw_factor @@ -228,6 +216,7 @@ class DPN(nn.Module): for i in range(2, k_sec[2] + 1): blocks['conv4_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) in_chs += inc + self.feature_info += [dict(num_chs=in_chs, reduction=16, module=f'features.conv4_{k_sec[2]}')] # conv5 bw = 512 * bw_factor @@ -238,6 +227,7 @@ class DPN(nn.Module): for i in range(2, k_sec[3] + 1): blocks['conv5_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) in_chs += inc + self.feature_info += [dict(num_chs=in_chs, reduction=32, module=f'features.conv5_{k_sec[3]}')] def _fc_norm(f, eps): return BatchNormAct2d(f, eps=eps, act_layer=fc_act, inplace=False) blocks['conv5_bn_ac'] = CatBnAct(in_chs, norm_layer=_fc_norm) @@ -274,79 +264,55 @@ class DPN(nn.Module): return out.flatten(1) +def _create_dpn(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + DPN, variant, pretrained, default_cfg=default_cfgs[variant], + feature_cfg=dict(feature_concat=True, flatten_sequential=True), **kwargs) + + @register_model -def dpn68(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - default_cfg = default_cfgs['dpn68'] - model = DPN( +def dpn68(pretrained=False, **kwargs): + model_kwargs = dict( small=True, num_init_features=10, k_r=128, groups=32, - k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs) + return _create_dpn('dpn68', pretrained=pretrained, **model_kwargs) @register_model -def dpn68b(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - default_cfg = default_cfgs['dpn68b'] - model = DPN( +def dpn68b(pretrained=False, **kwargs): + model_kwargs = dict( small=True, num_init_features=10, k_r=128, groups=32, - b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs) + return _create_dpn('dpn68b', pretrained=pretrained, **model_kwargs) @register_model -def dpn92(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - default_cfg = default_cfgs['dpn92'] - model = DPN( +def dpn92(pretrained=False, **kwargs): + model_kwargs = dict( num_init_features=64, k_r=96, groups=32, - k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), **kwargs) + return _create_dpn('dpn92', pretrained=pretrained, **model_kwargs) @register_model -def dpn98(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - default_cfg = default_cfgs['dpn98'] - model = DPN( +def dpn98(pretrained=False, **kwargs): + model_kwargs = dict( num_init_features=96, k_r=160, groups=40, - k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128), - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128), **kwargs) + return _create_dpn('dpn98', pretrained=pretrained, **model_kwargs) @register_model -def dpn131(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - default_cfg = default_cfgs['dpn131'] - model = DPN( +def dpn131(pretrained=False, **kwargs): + model_kwargs = dict( num_init_features=128, k_r=160, groups=40, - k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128), - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128), **kwargs) + return _create_dpn('dpn131', pretrained=pretrained, **model_kwargs) @register_model -def dpn107(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - default_cfg = default_cfgs['dpn107'] - model = DPN( +def dpn107(pretrained=False, **kwargs): + model_kwargs = dict( num_init_features=128, k_r=200, groups=50, - k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128), - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model + k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128), **kwargs) + return _create_dpn('dpn107', pretrained=pretrained, **model_kwargs) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 08b14cb0..4c1e4d3f 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -471,7 +471,7 @@ class EfficientNetFeatures(nn.Module): return self.feature_hooks.get_output(x.device) -def _create_model(model_kwargs, default_cfg, pretrained=False): +def _create_effnet(model_kwargs, default_cfg, pretrained=False): if model_kwargs.pop('features_only', False): load_strict = False model_kwargs.pop('num_classes', 0) @@ -528,7 +528,7 @@ def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs) norm_kwargs=resolve_bn_args(kwargs), **kwargs ) - model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) return model @@ -564,7 +564,7 @@ def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs) norm_kwargs=resolve_bn_args(kwargs), **kwargs ) - model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) return model @@ -593,7 +593,7 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar norm_kwargs=resolve_bn_args(kwargs), **kwargs ) - model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) return model @@ -622,7 +622,7 @@ def _gen_mobilenet_v2( act_layer=resolve_act_layer(kwargs, 'relu6'), **kwargs ) - model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) return model @@ -652,7 +652,7 @@ def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs): norm_kwargs=resolve_bn_args(kwargs), **kwargs ) - model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) return model @@ -687,7 +687,7 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): norm_kwargs=resolve_bn_args(kwargs), **kwargs ) - model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) return model @@ -734,7 +734,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre variant=variant, **kwargs, ) - model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) return model @@ -763,7 +763,7 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0 act_layer=resolve_act_layer(kwargs, 'relu'), **kwargs, ) - model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) return model @@ -793,7 +793,7 @@ def _gen_efficientnet_condconv( act_layer=resolve_act_layer(kwargs, 'swish'), **kwargs, ) - model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) return model @@ -834,7 +834,7 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0 norm_kwargs=resolve_bn_args(kwargs), **kwargs, ) - model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) return model @@ -867,7 +867,7 @@ def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs): norm_kwargs=resolve_bn_args(kwargs), **kwargs ) - model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) return model @@ -900,7 +900,7 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai norm_kwargs=resolve_bn_args(kwargs), **kwargs ) - model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) return model diff --git a/timm/models/feature_hooks.py b/timm/models/feature_hooks.py index 7c3f6f4b..b489b6f5 100644 --- a/timm/models/feature_hooks.py +++ b/timm/models/feature_hooks.py @@ -7,36 +7,38 @@ Hacked together by Ross Wightman import torch from collections import defaultdict, OrderedDict -from functools import partial +from functools import partial, partialmethod from typing import List class FeatureHooks: - def __init__(self, hooks, named_modules, output_as_dict=False): + def __init__(self, hooks, named_modules, out_as_dict=False, out_map=None, default_hook_type='forward'): # setup feature hooks modules = {k: v for k, v in named_modules} - for h in hooks: + for i, h in enumerate(hooks): hook_name = h['module'] m = modules[hook_name] - hook_fn = partial(self._collect_output_hook, hook_name) - if h['hook_type'] == 'forward_pre': + hook_id = out_map[i] if out_map else hook_name + hook_fn = partial(self._collect_output_hook, hook_id) + hook_type = h['hook_type'] if 'hook_type' in h else default_hook_type + if hook_type == 'forward_pre': m.register_forward_pre_hook(hook_fn) - elif h['hook_type'] == 'forward': + elif hook_type == 'forward': m.register_forward_hook(hook_fn) else: assert False, "Unsupported hook type" self._feature_outputs = defaultdict(OrderedDict) - self.output_as_dict = output_as_dict + self.out_as_dict = out_as_dict - def _collect_output_hook(self, name, *args): + def _collect_output_hook(self, hook_id, *args): x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre if isinstance(x, tuple): x = x[0] # unwrap input tuple - self._feature_outputs[x.device][name] = x + self._feature_outputs[x.device][hook_id] = x def get_output(self, device) -> List[torch.tensor]: - if self.output_as_dict: + if self.out_as_dict: output = self._feature_outputs[device] else: output = list(self._feature_outputs[device].values()) diff --git a/timm/models/features.py b/timm/models/features.py index e4c19755..2c210734 100644 --- a/timm/models/features.py +++ b/timm/models/features.py @@ -11,7 +11,8 @@ from copy import deepcopy import torch import torch.nn as nn -import torch.nn.functional as F + +from .feature_hooks import FeatureHooks class FeatureInfo: @@ -24,11 +25,11 @@ class FeatureInfo: assert 'reduction' in fi and fi['reduction'] >= prev_reduction prev_reduction = fi['reduction'] assert 'module' in fi - self._out_indices = out_indices - self._info = feature_info + self.out_indices = out_indices + self.info = feature_info def from_other(self, out_indices: Tuple[int]): - return FeatureInfo(deepcopy(self._info), out_indices) + return FeatureInfo(deepcopy(self.info), out_indices) def channels(self, idx=None): """ feature channels accessor @@ -36,8 +37,8 @@ class FeatureInfo: if idx is an integer, return feature channel count for that feature module index """ if isinstance(idx, int): - return self._info[idx]['num_chs'] - return [self._info[i]['num_chs'] for i in self._out_indices] + return self.info[idx]['num_chs'] + return [self.info[i]['num_chs'] for i in self.out_indices] def reduction(self, idx=None): """ feature reduction (output stride) accessor @@ -45,8 +46,8 @@ class FeatureInfo: if idx is an integer, return feature channel count at that feature module index """ if isinstance(idx, int): - return self._info[idx]['reduction'] - return [self._info[i]['reduction'] for i in self._out_indices] + return self.info[idx]['reduction'] + return [self.info[i]['reduction'] for i in self.out_indices] def module_name(self, idx=None): """ feature module name accessor @@ -54,24 +55,24 @@ class FeatureInfo: if idx is an integer, return feature module name at that feature module index """ if isinstance(idx, int): - return self._info[idx]['module'] - return [self._info[i]['module'] for i in self._out_indices] + return self.info[idx]['module'] + return [self.info[i]['module'] for i in self.out_indices] def get_by_key(self, idx=None, keys=None): """ return info dicts for specified keys (or all if None) at specified idx (or out_indices if None) """ if isinstance(idx, int): - return self._info[idx] if keys is None else {k: self._info[idx][k] for k in keys} + return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys} if keys is None: - return [self._info[i] for i in self._out_indices] + return [self.info[i] for i in self.out_indices] else: - return [{k: self._info[i][k] for k in keys} for i in self._out_indices] + return [{k: self.info[i][k] for k in keys} for i in self.out_indices] def __getitem__(self, item): - return self._info[item] + return self.info[item] def __len__(self): - return len(self._info) + return len(self.info) def _module_list(module, flatten_sequential=False): @@ -81,30 +82,47 @@ def _module_list(module, flatten_sequential=False): if flatten_sequential and isinstance(module, nn.Sequential): # first level of Sequential containers is flattened into containing model for child_name, child_module in module.named_children(): - ml.append(('_'.join([name, child_name]), child_module)) + combined = [name, child_name] + ml.append(('_'.join(combined), '.'.join(combined), child_module)) else: - ml.append((name, module)) + ml.append((name, name, module)) return ml -def _check_return_layers(input_return_layers, modules): - return_layers = {} - for k, v in input_return_layers.items(): - ks = k.split('.') - assert 0 < len(ks) <= 2 - return_layers['_'.join(ks)] = v - return_set = set(return_layers.keys()) - sdiff = return_set - {name for name, _ in modules} - if sdiff: - raise ValueError(f'return_layers {sdiff} are not present in model') - return return_layers, return_set +class LayerGetterHooks(nn.ModuleDict): + """ LayerGetterHooks + TODO + """ + + def __init__(self, model, feature_info, flatten_sequential=False, out_as_dict=False, out_map=None, + default_hook_type='forward'): + modules = _module_list(model, flatten_sequential=flatten_sequential) + remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type for f in feature_info} + layers = OrderedDict() + hooks = [] + for new_name, old_name, module in modules: + layers[new_name] = module + for fn, fm in module.named_modules(prefix=old_name): + if fn in remaining: + hooks.append(dict(module=fn, hook_type=remaining[fn])) + del remaining[fn] + if not remaining: + break + assert not remaining, f'Return layers ({remaining}) are not present in model' + super(LayerGetterHooks, self).__init__(layers) + self.hooks = FeatureHooks(hooks, model.named_modules(), out_as_dict=out_as_dict, out_map=out_map) + + def forward(self, x) -> Dict[Any, torch.Tensor]: + for name, module in self.items(): + x = module(x) + return self.hooks.get_output(x.device) class LayerGetterDict(nn.ModuleDict): """ Module wrapper that returns intermediate layers from a model as a dictionary - Originally based on IntermediateLayerGetter at + Originally based on concepts from IntermediateLayerGetter at https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py It has a strong assumption that the modules have been registered into the model in the same @@ -131,16 +149,20 @@ class LayerGetterDict(nn.ModuleDict): """ def __init__(self, model, return_layers, concat=False, flatten_sequential=False): - modules = _module_list(model, flatten_sequential=flatten_sequential) - self.return_layers, remaining = _check_return_layers(return_layers, modules) - layers = OrderedDict() + self.return_layers = {} self.concat = concat - for name, module in modules: - layers[name] = module - if name in remaining: - remaining.remove(name) + modules = _module_list(model, flatten_sequential=flatten_sequential) + remaining = set(return_layers.keys()) + layers = OrderedDict() + for new_name, old_name, module in modules: + layers[new_name] = module + if old_name in remaining: + self.return_layers[new_name] = return_layers[old_name] + remaining.remove(old_name) if not remaining: break + assert not remaining and len(self.return_layers) == len(return_layers), \ + f'Return layers ({remaining}) are not present in model' super(LayerGetterDict, self).__init__(layers) def forward(self, x) -> Dict[Any, torch.Tensor]: @@ -162,7 +184,7 @@ class LayerGetterList(nn.Sequential): """ Module wrapper that returns intermediate layers from a model as a list - Originally based on IntermediateLayerGetter at + Originally based on concepts from IntermediateLayerGetter at https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py It has a strong assumption that the modules have been registered into the model in the same @@ -190,15 +212,19 @@ class LayerGetterList(nn.Sequential): def __init__(self, model, return_layers, concat=False, flatten_sequential=False): super(LayerGetterList, self).__init__() - modules = _module_list(model, flatten_sequential=flatten_sequential) - self.return_layers, remaining = _check_return_layers(return_layers, modules) + self.return_layers = {} self.concat = concat - for name, module in modules: - self.add_module(name, module) - if name in remaining: - remaining.remove(name) + modules = _module_list(model, flatten_sequential=flatten_sequential) + remaining = set(return_layers.keys()) + for new_name, orig_name, module in modules: + self.add_module(new_name, module) + if orig_name in remaining: + self.return_layers[new_name] = return_layers[orig_name] + remaining.remove(orig_name) if not remaining: break + assert not remaining and len(self.return_layers) == len(return_layers), \ + f'Return layers ({remaining}) are not present in model' def forward(self, x) -> List[torch.Tensor]: out = [] @@ -225,6 +251,14 @@ def _resolve_feature_info(net, out_indices, feature_info=None): assert False, "Provided feature_info is not valid" +def _get_return_layers(feature_info, out_map): + module_names = feature_info.module_name() + return_layers = {} + for i, name in enumerate(module_names): + return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i] + return return_layers + + class FeatureNet(nn.Module): """ FeatureNet @@ -235,17 +269,41 @@ class FeatureNet(nn.Module): """ def __init__( self, net, - out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, + out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, use_hooks=False, feature_info=None, feature_concat=False, flatten_sequential=False): super(FeatureNet, self).__init__() self.feature_info = _resolve_feature_info(net, out_indices, feature_info) - module_names = self.feature_info.module_name() - return_layers = {} - for i in range(len(out_indices)): - return_layers[module_names[i]] = out_map[i] if out_map is not None else out_indices[i] - lg_args = dict(return_layers=return_layers, concat=feature_concat, flatten_sequential=flatten_sequential) - self.body = LayerGetterDict(net, **lg_args) if out_as_dict else LayerGetterList(net, **lg_args) + if use_hooks: + self.body = LayerGetterHooks(net, self.feature_info, out_as_dict=out_as_dict, out_map=out_map) + else: + return_layers = _get_return_layers(self.feature_info, out_map) + lg_args = dict(return_layers=return_layers, concat=feature_concat, flatten_sequential=flatten_sequential) + self.body = LayerGetterDict(net, **lg_args) if out_as_dict else LayerGetterList(net, **lg_args) def forward(self, x): output = self.body(x) return output + + +class FeatureHookNet(nn.Module): + """ FeatureHookNet + + Wrap a model and extract features specified by the out indices. + + Features are extracted via hooks without modifying the underlying network in any way. If only + part of the model is used it is up to the caller to remove unneeded layers as this wrapper + does not rewrite and remove unused top-level modules like FeatureNet with LayerGetter. + """ + def __init__( + self, net, + out_indices=(0, 1, 2, 3, 4), out_as_dict=False, out_map=None, + feature_info=None, feature_concat=False): + super(FeatureHookNet, self).__init__() + self.feature_info = _resolve_feature_info(net, out_indices, feature_info) + self.body = net + self.hooks = FeatureHooks( + self.feature_info, self.body.named_modules(), out_as_dict=out_as_dict, out_map=out_map) + + def forward(self, x): + self.body(x) + return self.hooks.get_output(x.device) diff --git a/timm/models/gluon_resnet.py b/timm/models/gluon_resnet.py index 2c1b0fe9..25385c32 100644 --- a/timm/models/gluon_resnet.py +++ b/timm/models/gluon_resnet.py @@ -5,9 +5,10 @@ by Ross Wightman """ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg from .layers import SEModule from .registry import register_model -from .resnet import _create_resnet_with_cfg, Bottleneck, BasicBlock +from .resnet import ResNet, Bottleneck, BasicBlock def _cfg(url='', **kwargs): @@ -47,8 +48,7 @@ default_cfgs = { def _create_resnet(variant, pretrained=False, **kwargs): - default_cfg = default_cfgs[variant] - return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs) + return build_model_with_cfg(ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs) @register_model diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 48c254d4..b27dceb6 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -1,11 +1,15 @@ +import logging +import os +from collections import OrderedDict +from copy import deepcopy +from typing import Callable + import torch import torch.nn as nn -from copy import deepcopy import torch.utils.model_zoo as model_zoo -import os -import logging -from collections import OrderedDict -from timm.models.layers.conv2d_same import Conv2dSame + +from .features import FeatureNet +from .layers import Conv2dSame def load_state_dict(checkpoint_path, use_ema=False): @@ -194,3 +198,42 @@ def adapt_model_from_file(parent_module, model_variant): adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt') with open(adapt_file, 'r') as f: return adapt_model_from_string(parent_module, f.read().strip()) + + +def build_model_with_cfg( + model_cls: Callable, + variant: str, + pretrained: bool, + default_cfg: dict, + model_cfg: dict = None, + feature_cfg: dict = None, + pretrained_filter_fn: Callable = None, + **kwargs): + pruned = kwargs.pop('pruned', False) + features = False + feature_cfg = feature_cfg or {} + + if kwargs.pop('features_only', False): + features = True + feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4)) + if 'out_indices' in kwargs: + feature_cfg['out_indices'] = kwargs.pop('out_indices') + + model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs) + model.default_cfg = deepcopy(default_cfg) + + if pruned: + model = adapt_model_from_file(model, variant) + + if pretrained: + load_pretrained( + model, + num_classes=kwargs.get('num_classes', 0), + in_chans=kwargs.get('in_chans', 3), + filter_fn=pretrained_filter_fn) + + if features: + feature_cls = feature_cfg.pop('feature_cls', FeatureNet) + model = feature_cls(model, **feature_cfg) + + return model diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index ac4824bb..23836d3b 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -19,7 +19,7 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import load_pretrained +from .helpers import build_model_with_cfg from .layers import SelectAdaptivePool2d from .registry import register_model from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE @@ -734,67 +734,52 @@ class HighResolutionNet(nn.Module): return x -def _create_model(variant, pretrained, model_kwargs): - if model_kwargs.pop('features_only', False): - assert False, 'Not Implemented' # TODO - load_strict = False - model_kwargs.pop('num_classes', 0) - model_class = HighResolutionNet - else: - load_strict = True - model_class = HighResolutionNet - - model = model_class(cfg_cls[variant], **model_kwargs) - model.default_cfg = default_cfgs[variant] - if pretrained: - load_pretrained( - model, - num_classes=model_kwargs.get('num_classes', 0), - in_chans=model_kwargs.get('in_chans', 3), - strict=load_strict) - return model +def _create_hrnet(variant, pretrained, **model_kwargs): + return build_model_with_cfg( + HighResolutionNet, variant, pretrained, default_cfg=default_cfgs[variant], + model_cfg=cfg_cls[variant], **model_kwargs) @register_model def hrnet_w18_small(pretrained=True, **kwargs): - return _create_model('hrnet_w18_small', pretrained, kwargs) + return _create_hrnet('hrnet_w18_small', pretrained, **kwargs) @register_model def hrnet_w18_small_v2(pretrained=True, **kwargs): - return _create_model('hrnet_w18_small_v2', pretrained, kwargs) + return _create_hrnet('hrnet_w18_small_v2', pretrained, **kwargs) @register_model def hrnet_w18(pretrained=True, **kwargs): - return _create_model('hrnet_w18', pretrained, kwargs) + return _create_hrnet('hrnet_w18', pretrained, **kwargs) @register_model def hrnet_w30(pretrained=True, **kwargs): - return _create_model('hrnet_w30', pretrained, kwargs) + return _create_hrnet('hrnet_w30', pretrained, **kwargs) @register_model def hrnet_w32(pretrained=True, **kwargs): - return _create_model('hrnet_w32', pretrained, kwargs) + return _create_hrnet('hrnet_w32', pretrained, **kwargs) @register_model def hrnet_w40(pretrained=True, **kwargs): - return _create_model('hrnet_w40', pretrained, kwargs) + return _create_hrnet('hrnet_w40', pretrained, **kwargs) @register_model def hrnet_w44(pretrained=True, **kwargs): - return _create_model('hrnet_w44', pretrained, kwargs) + return _create_hrnet('hrnet_w44', pretrained, **kwargs) @register_model def hrnet_w48(pretrained=True, **kwargs): - return _create_model('hrnet_w48', pretrained, kwargs) + return _create_hrnet('hrnet_w48', pretrained, **kwargs) @register_model def hrnet_w64(pretrained=True, **kwargs): - return _create_model('hrnet_w64', pretrained, kwargs) + return _create_hrnet('hrnet_w64', pretrained, **kwargs) diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py index 85c00486..c438898e 100644 --- a/timm/models/inception_resnet_v2.py +++ b/timm/models/inception_resnet_v2.py @@ -7,8 +7,7 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .features import FeatureNet -from .helpers import load_pretrained +from .helpers import build_model_with_cfg from .layers import SelectAdaptivePool2d from .registry import register_model @@ -340,20 +339,9 @@ class InceptionResnetV2(nn.Module): return x -def _inception_resnet_v2(variant, pretrained=False, **kwargs): - features, out_indices = False, None - if kwargs.pop('features_only', False): - features = True - out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4)) - model = InceptionResnetV2(**kwargs) - model.default_cfg = default_cfgs[variant] - if pretrained: - load_pretrained( - model, - num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=not features) - if features: - model = FeatureNet(model, out_indices) - return model +def _create_inception_resnet_v2(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + InceptionResnetV2, variant, pretrained, default_cfg=default_cfgs[variant], **kwargs) @register_model @@ -361,7 +349,7 @@ def inception_resnet_v2(pretrained=False, **kwargs): r"""InceptionResnetV2 model architecture from the `"InceptionV4, Inception-ResNet..." ` paper. """ - return _inception_resnet_v2('inception_resnet_v2', pretrained=pretrained, **kwargs) + return _create_inception_resnet_v2('inception_resnet_v2', pretrained=pretrained, **kwargs) @register_model @@ -370,4 +358,4 @@ def ens_adv_inception_resnet_v2(pretrained=False, **kwargs): As per https://arxiv.org/abs/1705.07204 and https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models. """ - return _inception_resnet_v2('ens_adv_inception_resnet_v2', pretrained=pretrained, **kwargs) + return _create_inception_resnet_v2('ens_adv_inception_resnet_v2', pretrained=pretrained, **kwargs) diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index 67d33155..8a425f4c 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -504,21 +504,16 @@ class BasicConv2d(nn.Module): return F.relu(x, inplace=True) -def _inception_v3(variant, pretrained=False, **kwargs): +def _create_inception_v3(variant, pretrained=False, **kwargs): + assert not kwargs.pop('features_only', False) default_cfg = default_cfgs[variant] - if kwargs.pop('features_only', False): - assert False, 'Not Implemented' # TODO - load_strict = False - model_kwargs.pop('num_classes', 0) - model_class = InceptionV3 + aux_logits = kwargs.pop('aux_logits', False) + if aux_logits: + model_class = InceptionV3Aux + load_strict = default_cfg['has_aux'] else: - aux_logits = kwargs.pop('aux_logits', False) - if aux_logits: - model_class = InceptionV3Aux - load_strict = default_cfg['has_aux'] - else: - model_class = InceptionV3 - load_strict = not default_cfg['has_aux'] + model_class = InceptionV3 + load_strict = not default_cfg['has_aux'] model = model_class(**kwargs) model.default_cfg = default_cfg @@ -534,14 +529,14 @@ def _inception_v3(variant, pretrained=False, **kwargs): @register_model def inception_v3(pretrained=False, **kwargs): # original PyTorch weights, ported from Tensorflow but modified - model = _inception_v3('inception_v3', pretrained=pretrained, **kwargs) + model = _create_inception_v3('inception_v3', pretrained=pretrained, **kwargs) return model @register_model def tf_inception_v3(pretrained=False, **kwargs): # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz) - model = _inception_v3('tf_inception_v3', pretrained=pretrained, **kwargs) + model = _create_inception_v3('tf_inception_v3', pretrained=pretrained, **kwargs) return model @@ -549,7 +544,7 @@ def tf_inception_v3(pretrained=False, **kwargs): def adv_inception_v3(pretrained=False, **kwargs): # my port of Tensorflow adversarially trained Inception V3 from # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz - model = _inception_v3('adv_inception_v3', pretrained=pretrained, **kwargs) + model = _create_inception_v3('adv_inception_v3', pretrained=pretrained, **kwargs) return model @@ -557,5 +552,5 @@ def adv_inception_v3(pretrained=False, **kwargs): def gluon_inception_v3(pretrained=False, **kwargs): # from gluon pretrained models, best performing in terms of accuracy/loss metrics # https://gluon-cv.mxnet.io/model_zoo/classification.html - model = _inception_v3('gluon_inception_v3', pretrained=pretrained, **kwargs) + model = _create_inception_v3('gluon_inception_v3', pretrained=pretrained, **kwargs) return model diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 1ebc4be0..e8efa3dc 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -3,6 +3,7 @@ from .adaptive_avgmax_pool import \ adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d from .anti_aliasing import AntiAliasDownsampleLayer from .blur_pool import BlurPool2d +from .classifier import ClassifierHead from .cond_conv2d import CondConv2d, get_condconv_initializer from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ set_layer_config @@ -24,6 +25,7 @@ from .se import SEModule from .selective_kernel import SelectiveKernelConv from .separable_conv import SeparableConv2d, SeparableConvBnAct from .space_to_depth import SpaceToDepthModule +from .split_attn import SplitAttnConv2d from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model from .test_time_pool import TestTimePoolHead, apply_test_time_pool from .weight_init import trunc_normal_ diff --git a/timm/models/layers/activations.py b/timm/models/layers/activations.py index 71904935..c1066b7b 100644 --- a/timm/models/layers/activations.py +++ b/timm/models/layers/activations.py @@ -82,7 +82,7 @@ class HardSwish(nn.Module): self.inplace = inplace def forward(self, x): - return hard_swish(x, self.inplace) + return F.hardswish(x) #hard_swish(x, self.inplace) def hard_sigmoid(x, inplace: bool = False): diff --git a/timm/models/layers/classifier.py b/timm/models/layers/classifier.py new file mode 100644 index 00000000..29960c44 --- /dev/null +++ b/timm/models/layers/classifier.py @@ -0,0 +1,24 @@ +from torch import nn as nn +from torch.nn import functional as F + +from .adaptive_avgmax_pool import SelectAdaptivePool2d + + +class ClassifierHead(nn.Module): + """Classifier Head w/ configurable global pooling and dropout.""" + + def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.): + super(ClassifierHead, self).__init__() + self.drop_rate = drop_rate + self.global_pool = SelectAdaptivePool2d(pool_type=pool_type) + if num_classes > 0: + self.fc = nn.Linear(in_chs * self.global_pool.feat_mult(), num_classes, bias=True) + else: + self.fc = nn.Identity() + + def forward(self, x): + x = self.global_pool(x).flatten(1) + if self.drop_rate: + x = F.dropout(x, p=float(self.drop_rate), training=self.training) + x = self.fc(x) + return x diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py index 6404d62f..bf4ad119 100644 --- a/timm/models/layers/create_act.py +++ b/timm/models/layers/create_act.py @@ -39,7 +39,7 @@ _ACT_FN_ME = dict( ) _ACT_LAYER_DEFAULT = dict( - swish=Swish, + swish=Swish, #nn.SiLU, # mish=Mish, relu=nn.ReLU, relu6=nn.ReLU6, @@ -56,7 +56,7 @@ _ACT_LAYER_DEFAULT = dict( ) _ACT_LAYER_JIT = dict( - swish=SwishJit, + #swish=SwishJit, mish=MishJit, hard_sigmoid=HardSigmoidJit, hard_swish=HardSwishJit, diff --git a/timm/models/layers/helpers.py b/timm/models/layers/helpers.py index 967c2f4c..d86f7bec 100644 --- a/timm/models/layers/helpers.py +++ b/timm/models/layers/helpers.py @@ -19,7 +19,7 @@ tup_single = _ntuple(1) tup_pair = _ntuple(2) tup_triple = _ntuple(3) tup_quadruple = _ntuple(4) - +ntup = _ntuple diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 28af4a10..48288223 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -215,7 +215,7 @@ class MobileNetV3Features(nn.Module): return self.feature_hooks.get_output(x.device) -def _create_model(model_kwargs, default_cfg, pretrained=False): +def _create_mnv3(model_kwargs, default_cfg, pretrained=False): if model_kwargs.pop('features_only', False): load_strict = False model_kwargs.pop('num_classes', 0) @@ -272,7 +272,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=1), **kwargs, ) - model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + model = _create_mnv3(model_kwargs, default_cfgs[variant], pretrained) return model @@ -368,7 +368,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8), **kwargs, ) - model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + model = _create_mnv3(model_kwargs, default_cfgs[variant], pretrained) return model diff --git a/timm/models/regnet.py b/timm/models/regnet.py index d934c2a5..c0926554 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -14,12 +14,10 @@ Weights from original impl have been modified """ import numpy as np import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .features import FeatureNet -from .helpers import load_pretrained -from .layers import SelectAdaptivePool2d, AvgPool2dSame, ConvBnAct, SEModule +from .helpers import build_model_with_cfg +from .layers import ClassifierHead, AvgPool2dSame, ConvBnAct, SEModule from .registry import register_model @@ -222,26 +220,6 @@ class RegStage(nn.Module): return x -class ClassifierHead(nn.Module): - """Head.""" - - def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.): - super(ClassifierHead, self).__init__() - self.drop_rate = drop_rate - self.global_pool = SelectAdaptivePool2d(pool_type=pool_type) - if num_classes > 0: - self.fc = nn.Linear(in_chs, num_classes, bias=True) - else: - self.fc = nn.Identity() - - def forward(self, x): - x = self.global_pool(x).flatten(1) - if self.drop_rate: - x = F.dropout(x, p=float(self.drop_rate), training=self.training) - x = self.fc(x) - return x - - class RegNet(nn.Module): """RegNet model. @@ -343,163 +321,150 @@ class RegNet(nn.Module): return x -def _regnet(variant, pretrained, **kwargs): - features = False - out_indices = None - if kwargs.pop('features_only', False): - features = True - out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4)) - model_cfg = model_cfgs[variant] - model = RegNet(model_cfg, **kwargs) - model.default_cfg = default_cfgs[variant] - if pretrained: - load_pretrained( - model, - num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=not features) - if features: - model = FeatureNet(model, out_indices=out_indices) - return model +def _create_regnet(variant, pretrained, **kwargs): + return build_model_with_cfg( + RegNet, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=model_cfgs[variant], **kwargs) @register_model def regnetx_002(pretrained=False, **kwargs): """RegNetX-200MF""" - return _regnet('regnetx_002', pretrained, **kwargs) + return _create_regnet('regnetx_002', pretrained, **kwargs) @register_model def regnetx_004(pretrained=False, **kwargs): """RegNetX-400MF""" - return _regnet('regnetx_004', pretrained, **kwargs) + return _create_regnet('regnetx_004', pretrained, **kwargs) @register_model def regnetx_006(pretrained=False, **kwargs): """RegNetX-600MF""" - return _regnet('regnetx_006', pretrained, **kwargs) + return _create_regnet('regnetx_006', pretrained, **kwargs) @register_model def regnetx_008(pretrained=False, **kwargs): """RegNetX-800MF""" - return _regnet('regnetx_008', pretrained, **kwargs) + return _create_regnet('regnetx_008', pretrained, **kwargs) @register_model def regnetx_016(pretrained=False, **kwargs): """RegNetX-1.6GF""" - return _regnet('regnetx_016', pretrained, **kwargs) + return _create_regnet('regnetx_016', pretrained, **kwargs) @register_model def regnetx_032(pretrained=False, **kwargs): """RegNetX-3.2GF""" - return _regnet('regnetx_032', pretrained, **kwargs) + return _create_regnet('regnetx_032', pretrained, **kwargs) @register_model def regnetx_040(pretrained=False, **kwargs): """RegNetX-4.0GF""" - return _regnet('regnetx_040', pretrained, **kwargs) + return _create_regnet('regnetx_040', pretrained, **kwargs) @register_model def regnetx_064(pretrained=False, **kwargs): """RegNetX-6.4GF""" - return _regnet('regnetx_064', pretrained, **kwargs) + return _create_regnet('regnetx_064', pretrained, **kwargs) @register_model def regnetx_080(pretrained=False, **kwargs): """RegNetX-8.0GF""" - return _regnet('regnetx_080', pretrained, **kwargs) + return _create_regnet('regnetx_080', pretrained, **kwargs) @register_model def regnetx_120(pretrained=False, **kwargs): """RegNetX-12GF""" - return _regnet('regnetx_120', pretrained, **kwargs) + return _create_regnet('regnetx_120', pretrained, **kwargs) @register_model def regnetx_160(pretrained=False, **kwargs): """RegNetX-16GF""" - return _regnet('regnetx_160', pretrained, **kwargs) + return _create_regnet('regnetx_160', pretrained, **kwargs) @register_model def regnetx_320(pretrained=False, **kwargs): """RegNetX-32GF""" - return _regnet('regnetx_320', pretrained, **kwargs) + return _create_regnet('regnetx_320', pretrained, **kwargs) @register_model def regnety_002(pretrained=False, **kwargs): """RegNetY-200MF""" - return _regnet('regnety_002', pretrained, **kwargs) + return _create_regnet('regnety_002', pretrained, **kwargs) @register_model def regnety_004(pretrained=False, **kwargs): """RegNetY-400MF""" - return _regnet('regnety_004', pretrained, **kwargs) + return _create_regnet('regnety_004', pretrained, **kwargs) @register_model def regnety_006(pretrained=False, **kwargs): """RegNetY-600MF""" - return _regnet('regnety_006', pretrained, **kwargs) + return _create_regnet('regnety_006', pretrained, **kwargs) @register_model def regnety_008(pretrained=False, **kwargs): """RegNetY-800MF""" - return _regnet('regnety_008', pretrained, **kwargs) + return _create_regnet('regnety_008', pretrained, **kwargs) @register_model def regnety_016(pretrained=False, **kwargs): """RegNetY-1.6GF""" - return _regnet('regnety_016', pretrained, **kwargs) + return _create_regnet('regnety_016', pretrained, **kwargs) @register_model def regnety_032(pretrained=False, **kwargs): """RegNetY-3.2GF""" - return _regnet('regnety_032', pretrained, **kwargs) + return _create_regnet('regnety_032', pretrained, **kwargs) @register_model def regnety_040(pretrained=False, **kwargs): """RegNetY-4.0GF""" - return _regnet('regnety_040', pretrained, **kwargs) + return _create_regnet('regnety_040', pretrained, **kwargs) @register_model def regnety_064(pretrained=False, **kwargs): """RegNetY-6.4GF""" - return _regnet('regnety_064', pretrained, **kwargs) + return _create_regnet('regnety_064', pretrained, **kwargs) @register_model def regnety_080(pretrained=False, **kwargs): """RegNetY-8.0GF""" - return _regnet('regnety_080', pretrained, **kwargs) + return _create_regnet('regnety_080', pretrained, **kwargs) @register_model def regnety_120(pretrained=False, **kwargs): """RegNetY-12GF""" - return _regnet('regnety_120', pretrained, **kwargs) + return _create_regnet('regnety_120', pretrained, **kwargs) @register_model def regnety_160(pretrained=False, **kwargs): """RegNetY-16GF""" - return _regnet('regnety_160', pretrained, **kwargs) + return _create_regnet('regnety_160', pretrained, **kwargs) @register_model def regnety_320(pretrained=False, **kwargs): """RegNetY-32GF""" - return _regnet('regnety_320', pretrained, **kwargs) + return _create_regnet('regnety_320', pretrained, **kwargs) diff --git a/timm/models/res2net.py b/timm/models/res2net.py index 536fd49a..ae753d12 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -8,9 +8,9 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import load_pretrained +from .helpers import build_model_with_cfg from .registry import register_model -from .resnet import _create_resnet_with_cfg +from .resnet import ResNet __all__ = [] @@ -133,8 +133,8 @@ class Bottle2neck(nn.Module): def _create_res2net(variant, pretrained=False, **kwargs): - default_cfg = default_cfgs[variant] - return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs) + return build_model_with_cfg( + ResNet, variant, pretrained, default_cfg=default_cfgs[variant], **kwargs) @register_model diff --git a/timm/models/resnest.py b/timm/models/resnest.py index cf207faa..f5f4d71c 100644 --- a/timm/models/resnest.py +++ b/timm/models/resnest.py @@ -10,10 +10,10 @@ import torch from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.layers import DropBlock2d -from .layers.split_attn import SplitAttnConv2d +from .helpers import build_model_with_cfg +from .layers import SplitAttnConv2d from .registry import register_model -from .resnet import _create_resnet_with_cfg +from .resnet import ResNet def _cfg(url='', **kwargs): @@ -140,8 +140,8 @@ class ResNestBottleneck(nn.Module): def _create_resnest(variant, pretrained=False, **kwargs): - default_cfg = default_cfgs[variant] - return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs) + return build_model_with_cfg( + ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs) @register_model diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 7c243297..4fbc9564 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -13,8 +13,7 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .features import FeatureNet -from .helpers import load_pretrained, adapt_model_from_file +from .helpers import build_model_with_cfg from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d from .registry import register_model @@ -590,32 +589,9 @@ class ResNet(nn.Module): return x -def _create_resnet_with_cfg(variant, default_cfg, pretrained=False, **kwargs): - assert isinstance(default_cfg, dict) - features = False - out_indices = None - if kwargs.pop('features_only', False): - features = True - out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4)) - pruned = kwargs.pop('pruned', False) - - model = ResNet(**kwargs) - model.default_cfg = copy.deepcopy(default_cfg) - - if pruned: - model = adapt_model_from_file(model, variant) - if pretrained: - load_pretrained( - model, - num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=not features) - if features: - model = FeatureNet(model, out_indices=out_indices) - return model - - def _create_resnet(variant, pretrained=False, **kwargs): - default_cfg = default_cfgs[variant] - return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs) + return build_model_with_cfg( + ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs) @register_model diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index 5dddedb5..7161f723 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -16,8 +16,7 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .features import FeatureNet -from .helpers import load_pretrained +from .helpers import build_model_with_cfg from .layers import SelectAdaptivePool2d from .registry import register_model @@ -178,7 +177,7 @@ class SelecSLS(nn.Module): return x -def _create_model(variant, pretrained, model_kwargs): +def _create_selecsls(variant, pretrained, model_kwargs): cfg = {} feature_info = [dict(num_chs=32, reduction=2, module='stem.2')] if variant.startswith('selecsls42'): @@ -299,61 +298,42 @@ def _create_model(variant, pretrained, model_kwargs): else: raise ValueError('Invalid net configuration ' + variant + ' !!!') - load_strict = True - features = False - out_indices = None - if model_kwargs.pop('features_only', False): - load_strict = False - features = True - # this model can do 6 feature levels by default, unlike most others, leave as 0-4 to avoid surprises? - out_indices = model_kwargs.pop('out_indices', (0, 1, 2, 3, 4)) - model_kwargs.pop('num_classes', 0) - - model = SelecSLS(cfg, **model_kwargs) - model.default_cfg = default_cfgs[variant] - model.feature_info = feature_info - if pretrained: - load_pretrained( - model, - num_classes=model_kwargs.get('num_classes', 0), - in_chans=model_kwargs.get('in_chans', 3), - strict=load_strict) - - if features: - model = FeatureNet(model, out_indices, flatten_sequential=True) - return model + # this model can do 6 feature levels by default, unlike most others, leave as 0-4 to avoid surprises? + return build_model_with_cfg( + SelecSLS, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=cfg, + feature_cfg=dict(out_indices=(0, 1, 2, 3, 4), flatten_sequential=True), **model_kwargs) @register_model def selecsls42(pretrained=False, **kwargs): """Constructs a SelecSLS42 model. """ - return _create_model('selecsls42', pretrained, kwargs) + return _create_selecsls('selecsls42', pretrained, kwargs) @register_model def selecsls42b(pretrained=False, **kwargs): """Constructs a SelecSLS42_B model. """ - return _create_model('selecsls42b', pretrained, kwargs) + return _create_selecsls('selecsls42b', pretrained, kwargs) @register_model def selecsls60(pretrained=False, **kwargs): """Constructs a SelecSLS60 model. """ - return _create_model('selecsls60', pretrained, kwargs) + return _create_selecsls('selecsls60', pretrained, kwargs) @register_model def selecsls60b(pretrained=False, **kwargs): """Constructs a SelecSLS60_B model. """ - return _create_model('selecsls60b', pretrained, kwargs) + return _create_selecsls('selecsls60b', pretrained, kwargs) @register_model def selecsls84(pretrained=False, **kwargs): """Constructs a SelecSLS84 model. """ - return _create_model('selecsls84', pretrained, kwargs) + return _create_selecsls('selecsls84', pretrained, kwargs) diff --git a/timm/models/sknet.py b/timm/models/sknet.py index 2bbf9786..fa00eb5f 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -13,10 +13,10 @@ import math from torch import nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import load_pretrained +from .helpers import build_model_with_cfg from .layers import SelectiveKernelConv, ConvBnAct, create_attn from .registry import register_model -from .resnet import _create_resnet_with_cfg +from .resnet import ResNet def _cfg(url='', **kwargs): @@ -139,8 +139,8 @@ class SelectiveKernelBottleneck(nn.Module): def _create_skresnet(variant, pretrained=False, **kwargs): - default_cfg = default_cfgs[variant] - return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs) + return build_model_with_cfg( + ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs) @register_model diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index a4274b2f..27e604b8 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -5,6 +5,7 @@ https://arxiv.org/pdf/2003.13630.pdf Original model: https://github.com/mrT23/TResNet """ +import copy from collections import OrderedDict from functools import partial @@ -12,8 +13,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .helpers import load_pretrained -from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, SelectAdaptivePool2d, InplaceAbn +from .helpers import build_model_with_cfg +from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead from .registry import register_model __all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl'] @@ -220,11 +221,17 @@ class TResNet(nn.Module): ('layer3', layer3), ('layer4', layer4)])) + self.feature_info = [ + dict(num_chs=self.planes, reduction=2, module=''), # Not with S2D? + dict(num_chs=self.planes, reduction=4, module='body.layer1'), + dict(num_chs=self.planes * 2, reduction=8, module='body.layer2'), + dict(num_chs=self.planes * 4, reduction=16, module='body.layer3'), + dict(num_chs=self.planes * 8, reduction=32, module='body.layer4'), + ] + # head self.num_features = (self.planes * 8) * Bottleneck.expansion - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) - self.head = nn.Sequential(OrderedDict([ - ('fc', nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes))])) + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) # model initilization for m in self.modules(): @@ -240,7 +247,8 @@ class TResNet(nn.Module): m.conv2[1].weight = nn.Parameter(torch.zeros_like(m.conv2[1].weight)) # BN to zero if isinstance(m, Bottleneck): m.conv3[1].weight = nn.Parameter(torch.zeros_like(m.conv3[1].weight)) # BN to zero - if isinstance(m, nn.Linear): m.weight.data.normal_(0, 0.01) + if isinstance(m, nn.Linear): + m.weight.data.normal_(0, 0.01) def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=None): downsample = None @@ -266,86 +274,55 @@ class TResNet(nn.Module): return self.head.fc def reset_classifier(self, num_classes, global_pool='avg'): - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) - self.num_classes = num_classes - self.head = None - if num_classes: - num_features = self.num_features * self.global_pool.feat_mult() - self.head = nn.Sequential(OrderedDict([('fc', nn.Linear(num_features, num_classes))])) - else: - self.head = nn.Sequential(OrderedDict([('fc', nn.Identity())])) + self.head = ClassifierHead( + self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) def forward_features(self, x): return self.body(x) def forward(self, x): x = self.forward_features(x) - x = self.global_pool(x) - if self.drop_rate: - x = F.dropout(x, p=float(self.drop_rate), training=self.training) x = self.head(x) return x -@register_model -def tresnet_m(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - default_cfg = default_cfgs['tresnet_m'] - model = TResNet(layers=[3, 4, 11, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model +def _create_tresnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + TResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, + feature_cfg=dict(out_indices=(1, 2, 3, 4), flatten_sequential=True), **kwargs) @register_model -def tresnet_l(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - default_cfg = default_cfgs['tresnet_l'] - model = TResNet( - layers=[4, 5, 18, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.2, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model +def tresnet_m(pretrained=False, **kwargs): + model_kwargs = dict(layers=[3, 4, 11, 3], **kwargs) + return _create_tresnet('tresnet_m', pretrained=pretrained, **model_kwargs) @register_model -def tresnet_xl(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - default_cfg = default_cfgs['tresnet_xl'] - model = TResNet( - layers=[4, 5, 24, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.3, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model +def tresnet_l(pretrained=False, **kwargs): + model_kwargs = dict(layers=[4, 5, 18, 3], width_factor=1.2, **kwargs) + return _create_tresnet('tresnet_l', pretrained=pretrained, **model_kwargs) @register_model -def tresnet_m_448(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - default_cfg = default_cfgs['tresnet_m_448'] - model = TResNet(layers=[3, 4, 11, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model +def tresnet_xl(pretrained=False, **kwargs): + model_kwargs = dict(layers=[4, 5, 24, 3], width_factor=1.3, **kwargs) + return _create_tresnet('tresnet_xl', pretrained=pretrained, **model_kwargs) @register_model -def tresnet_l_448(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - default_cfg = default_cfgs['tresnet_l_448'] - model = TResNet( - layers=[4, 5, 18, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.2, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model +def tresnet_m_448(pretrained=False, **kwargs): + model_kwargs = dict(layers=[3, 4, 11, 3], **kwargs) + return _create_tresnet('tresnet_m_448', pretrained=pretrained, **model_kwargs) @register_model -def tresnet_xl_448(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - default_cfg = default_cfgs['tresnet_xl_448'] - model = TResNet( - layers=[4, 5, 24, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.3, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model +def tresnet_l_448(pretrained=False, **kwargs): + model_kwargs = dict(layers=[4, 5, 18, 3], width_factor=1.2, **kwargs) + return _create_tresnet('tresnet_l_448', pretrained=pretrained, **model_kwargs) + + +@register_model +def tresnet_xl_448(pretrained=False, **kwargs): + model_kwargs = dict(layers=[4, 5, 24, 3], width_factor=1.3, **kwargs) + return _create_tresnet('tresnet_xl_448', pretrained=pretrained, **model_kwargs) diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py index 788a1f89..87e68537 100644 --- a/timm/models/vovnet.py +++ b/timm/models/vovnet.py @@ -19,9 +19,8 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .registry import register_model -from .helpers import load_pretrained -from .features import FeatureNet -from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, SelectAdaptivePool2d, \ +from .helpers import build_model_with_cfg +from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, ClassifierHead, \ create_attn, create_norm_act, get_norm_act_layer @@ -253,26 +252,6 @@ class OsaStage(nn.Module): return x -class ClassifierHead(nn.Module): - """Head.""" - - def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.): - super(ClassifierHead, self).__init__() - self.drop_rate = drop_rate - self.global_pool = SelectAdaptivePool2d(pool_type=pool_type) - if num_classes > 0: - self.fc = nn.Linear(in_chs, num_classes, bias=True) - else: - self.fc = nn.Identity() - - def forward(self, x): - x = self.global_pool(x).flatten(1) - if self.drop_rate: - x = F.dropout(x, p=float(self.drop_rate), training=self.training) - x = self.fc(x) - return x - - class VovNet(nn.Module): def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4, @@ -346,67 +325,55 @@ class VovNet(nn.Module): return self.head(x) -def _vovnet(variant, pretrained=False, **kwargs): - features = False - out_indices = None - if kwargs.pop('features_only', False): - features = True - out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4)) - model_cfg = model_cfgs[variant] - model = VovNet(model_cfg, **kwargs) - model.default_cfg = default_cfgs[variant] - if pretrained: - load_pretrained( - model, - num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=not features) - if features: - model = FeatureNet(model, out_indices, flatten_sequential=True) - return model +def _create_vovnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + VovNet, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=model_cfgs[variant], + feature_cfg=dict(flatten_sequential=True), **kwargs) @register_model def vovnet39a(pretrained=False, **kwargs): - return _vovnet('vovnet39a', pretrained=pretrained, **kwargs) + return _create_vovnet('vovnet39a', pretrained=pretrained, **kwargs) @register_model def vovnet57a(pretrained=False, **kwargs): - return _vovnet('vovnet57a', pretrained=pretrained, **kwargs) + return _create_vovnet('vovnet57a', pretrained=pretrained, **kwargs) @register_model def ese_vovnet19b_slim_dw(pretrained=False, **kwargs): - return _vovnet('ese_vovnet19b_slim_dw', pretrained=pretrained, **kwargs) + return _create_vovnet('ese_vovnet19b_slim_dw', pretrained=pretrained, **kwargs) @register_model def ese_vovnet19b_dw(pretrained=False, **kwargs): - return _vovnet('ese_vovnet19b_dw', pretrained=pretrained, **kwargs) + return _create_vovnet('ese_vovnet19b_dw', pretrained=pretrained, **kwargs) @register_model def ese_vovnet19b_slim(pretrained=False, **kwargs): - return _vovnet('ese_vovnet19b_slim', pretrained=pretrained, **kwargs) + return _create_vovnet('ese_vovnet19b_slim', pretrained=pretrained, **kwargs) @register_model def ese_vovnet39b(pretrained=False, **kwargs): - return _vovnet('ese_vovnet39b', pretrained=pretrained, **kwargs) + return _create_vovnet('ese_vovnet39b', pretrained=pretrained, **kwargs) @register_model def ese_vovnet57b(pretrained=False, **kwargs): - return _vovnet('ese_vovnet57b', pretrained=pretrained, **kwargs) + return _create_vovnet('ese_vovnet57b', pretrained=pretrained, **kwargs) @register_model def ese_vovnet99b(pretrained=False, **kwargs): - return _vovnet('ese_vovnet99b', pretrained=pretrained, **kwargs) + return _create_vovnet('ese_vovnet99b', pretrained=pretrained, **kwargs) @register_model def eca_vovnet39b(pretrained=False, **kwargs): - return _vovnet('eca_vovnet39b', pretrained=pretrained, **kwargs) + return _create_vovnet('eca_vovnet39b', pretrained=pretrained, **kwargs) # Experimental Models @@ -415,11 +382,11 @@ def eca_vovnet39b(pretrained=False, **kwargs): def ese_vovnet39b_evos(pretrained=False, **kwargs): def norm_act_fn(num_features, **nkwargs): return create_norm_act('EvoNormSample', num_features, jit=False, **nkwargs) - return _vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs) + return _create_vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs) @register_model def ese_vovnet99b_iabn(pretrained=False, **kwargs): norm_layer = get_norm_act_layer('iabn') - return _vovnet( + return _create_vovnet( 'ese_vovnet99b_iabn', pretrained=pretrained, norm_layer=norm_layer, act_layer=nn.LeakyReLU, **kwargs) diff --git a/timm/models/xception.py b/timm/models/xception.py index 60241f29..8bf62624 100644 --- a/timm/models/xception.py +++ b/timm/models/xception.py @@ -25,8 +25,7 @@ The resize parameter of the validation transform should be 333, and make sure to import torch.nn as nn import torch.nn.functional as F -from .helpers import load_pretrained -from .features import FeatureNet +from .helpers import build_model_with_cfg from .layers import SelectAdaptivePool2d from .registry import register_model @@ -220,25 +219,9 @@ class Xception(nn.Module): def _xception(variant, pretrained=False, **kwargs): - load_strict = True - features = False - out_indices = None - if kwargs.pop('features_only', False): - load_strict = False - features = True - kwargs.pop('num_classes', 0) - out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4)) - model = Xception(**kwargs) - model.default_cfg = default_cfgs[variant] - if pretrained: - load_pretrained( - model, - num_classes=kwargs.get('num_classes', 0), - in_chans=kwargs.get('in_chans', 3), - strict=load_strict) - if features: - model = FeatureNet(model, out_indices) - return model + return build_model_with_cfg( + Xception, variant, pretrained, default_cfg=default_cfgs[variant], + feature_cfg=dict(), **kwargs) @register_model diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py index c2006173..81334027 100644 --- a/timm/models/xception_aligned.py +++ b/timm/models/xception_aligned.py @@ -10,9 +10,9 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .features import FeatureNet -from .helpers import load_pretrained -from .layers import SelectAdaptivePool2d, ConvBnAct, create_conv2d +from .helpers import build_model_with_cfg +from .layers import ClassifierHead, ConvBnAct, create_conv2d +from .layers.helpers import tup_triple from .registry import register_model __all__ = ['XceptionAligned'] @@ -81,10 +81,7 @@ class XceptionModule(nn.Module): start_with_relu=True, no_skip=False, act_layer=nn.ReLU, norm_layer=None, norm_kwargs=None): super(XceptionModule, self).__init__() norm_kwargs = norm_kwargs if norm_kwargs is not None else {} - if isinstance(out_chs, (list, tuple)): - assert len(out_chs) == 3 - else: - out_chs = (out_chs,) * 3 + out_chs = tup_triple(out_chs) self.in_channels = in_chs self.out_channels = out_chs[-1] self.no_skip = no_skip @@ -115,26 +112,6 @@ class XceptionModule(nn.Module): return x -class ClassifierHead(nn.Module): - """Head.""" - - def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.): - super(ClassifierHead, self).__init__() - self.drop_rate = drop_rate - self.global_pool = SelectAdaptivePool2d(pool_type=pool_type) - if num_classes > 0: - self.fc = nn.Linear(in_chs, num_classes, bias=True) - else: - self.fc = nn.Identity() - - def forward(self, x): - x = self.global_pool(x).flatten(1) - if self.drop_rate: - x = F.dropout(x, p=float(self.drop_rate), training=self.training) - x = self.fc(x) - return x - - class XceptionAligned(nn.Module): """Modified Aligned Xception """ @@ -147,32 +124,29 @@ class XceptionAligned(nn.Module): assert output_stride in (8, 16, 32) norm_kwargs = norm_kwargs if norm_kwargs is not None else {} - xtra_args = dict(act_layer=act_layer, norm_layer=norm_layer, norm_kwargs=norm_kwargs) + layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, norm_kwargs=norm_kwargs) self.stem = nn.Sequential(*[ - ConvBnAct(in_chans, 32, kernel_size=3, stride=2, **xtra_args), - ConvBnAct(32, 64, kernel_size=3, stride=1, **xtra_args) + ConvBnAct(in_chans, 32, kernel_size=3, stride=2, **layer_args), + ConvBnAct(32, 64, kernel_size=3, stride=1, **layer_args) ]) + curr_dilation = 1 curr_stride = 2 - self.feature_info = [dict(num_chs=64, reduction=curr_stride, module='stem.1')] - + self.feature_info = [] self.blocks = nn.Sequential() for i, b in enumerate(block_cfg): - feature_extract = False b['dilation'] = curr_dilation if b['stride'] > 1: - feature_extract = True + self.feature_info += [dict( + num_chs=tup_triple(b['out_chs'])[-2], reduction=curr_stride, module=f'blocks.{i}.stack.act3')] next_stride = curr_stride * b['stride'] if next_stride > output_stride: curr_dilation *= b['stride'] b['stride'] = 1 else: curr_stride = next_stride - self.blocks.add_module(str(i), XceptionModule(**b, **xtra_args)) + self.blocks.add_module(str(i), XceptionModule(**b, **layer_args)) self.num_features = self.blocks[-1].out_channels - if feature_extract: - self.feature_info += [dict( - num_chs=self.num_features, reduction=curr_stride, module=f'blocks.{i}.stack.act2')] self.feature_info += [dict( num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))] @@ -198,24 +172,9 @@ class XceptionAligned(nn.Module): def _xception(variant, pretrained=False, **kwargs): - features = False - out_indices = None - if kwargs.pop('features_only', False): - features = True - kwargs.pop('num_classes', 0) - out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4)) - model = XceptionAligned(**kwargs) - model.default_cfg = default_cfgs[variant] - if pretrained: - load_pretrained( - model, - num_classes=kwargs.get('num_classes', 0), - in_chans=kwargs.get('in_chans', 3), - strict=not features) - if features: - model = FeatureNet(model, out_indices) - return model - + return build_model_with_cfg( + XceptionAligned, variant, pretrained, default_cfg=default_cfgs[variant], + feature_cfg=dict(flatten_sequential=True, use_hooks=True), **kwargs) @register_model