diff --git a/tests/test_models.py b/tests/test_models.py
index df21d039..6a467597 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -4,8 +4,14 @@
 import platform
 import os
 import fnmatch
 
+import timm
 from timm import list_models, create_model, set_scriptable
 
+if hasattr(torch._C, '_jit_set_profiling_executor'):
+    # legacy executor is too slow to compile large models for unit tests
+    # no need for the fusion performance here
+    torch._C._jit_set_profiling_executor(True)
+    torch._C._jit_set_profiling_mode(False)
 if 'GITHUB_ACTIONS' in os.environ:  # and 'Linux' in platform.system():
     # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
@@ -78,10 +84,28 @@ def test_model_default_cfgs(model_name, batch_size):
 
     if all([x <= MAX_FWD_FEAT_SIZE for x in input_size]) and \
             not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]):
-        # pool size only checked if default res <= 448 * 448 to keep resource down
+        # output sizes only checked if default res <= 448 * 448 to keep resource down
         input_size = tuple([min(x, MAX_FWD_FEAT_SIZE) for x in input_size])
-        outputs = model.forward_features(torch.randn((batch_size, *input_size)))
+        input_tensor = torch.randn((batch_size, *input_size))
+
+        # test forward_features (always unpooled)
+        outputs = model.forward_features(input_tensor)
         assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
+
+        # test forward after deleting the classifier, output should be pooled, size(-1) == model.num_features
+        model.reset_classifier(0)
+        outputs = model.forward(input_tensor)
+        assert len(outputs.shape) == 2
+        assert outputs.shape[-1] == model.num_features
+
+        # test model forward without pooling and classifier
+        if not isinstance(model, timm.models.MobileNetV3):
+            model.reset_classifier(0, '')  # reset classifier and set global pooling to pass-through
+            outputs = model.forward(input_tensor)
+            assert len(outputs.shape) == 4
+            assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
+
 
+    # check classifier and first convolution names match those in default_cfg
     assert any([k.startswith(classifier) for k in state_dict.keys()]), f'{classifier} not in model params'
     assert any([k.startswith(first_conv) for k in state_dict.keys()]), f'{first_conv} not in model params'
diff --git a/timm/models/densenet.py b/timm/models/densenet.py
index 5c8d6af8..e4e20564 100644
--- a/timm/models/densenet.py
+++ b/timm/models/densenet.py
@@ -14,7 +14,7 @@ from torch.jit.annotations import List
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act, BlurPool2d
+from .layers import BatchNormAct2d, create_norm_act, BlurPool2d, create_classifier
 from .registry import register_model
 
 __all__ = ['DenseNet']
@@ -236,8 +236,8 @@ class DenseNet(nn.Module):
         self.num_features = num_features
 
         # Linear layer
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
+        self.global_pool, self.classifier = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool)
 
         # Official init from torch repo.
         for m in self.modules():
@@ -254,19 +254,15 @@ class DenseNet(nn.Module):
 
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        if num_classes:
-            num_features = self.num_features * self.global_pool.feat_mult()
-            self.classifier = nn.Linear(num_features, num_classes)
-        else:
-            self.classifier = nn.Identity()
+        self.global_pool, self.classifier = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool)
 
     def forward_features(self, x):
         return self.features(x)
 
     def forward(self, x):
         x = self.forward_features(x)
-        x = self.global_pool(x).flatten(1)
+        x = self.global_pool(x)
         # both classifier and block drop?
         # if self.drop_rate > 0.:
         #     x = F.dropout(x, p=self.drop_rate, training=self.training)
diff --git a/timm/models/dla.py b/timm/models/dla.py
index 212150e6..a41ec326 100644
--- a/timm/models/dla.py
+++ b/timm/models/dla.py
@@ -13,7 +13,7 @@ import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import SelectAdaptivePool2d
+from .layers import create_classifier
 from .registry import register_model
 
 __all__ = ['DLA']
@@ -286,9 +286,8 @@ class DLA(nn.Module):
         ]
         self.num_features = channels[-1]
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.fc = nn.Conv2d(self.num_features * self.global_pool.feat_mult(), num_classes, 1, bias=True)
-
+        self.global_pool, self.fc = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
@@ -313,12 +312,8 @@ class DLA(nn.Module):
 
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        if num_classes:
-            num_features = self.num_features * self.global_pool.feat_mult()
-            self.fc = nn.Conv2d(num_features, num_classes, kernel_size=1, bias=True)
-        else:
-            self.fc = nn.Identity()
+        self.global_pool, self.fc = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
 
     def forward_features(self, x):
         x = self.base_layer(x)
@@ -336,7 +331,9 @@ class DLA(nn.Module):
         if self.drop_rate > 0.:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
         x = self.fc(x)
-        return x.flatten(1)
+        if not self.global_pool.is_identity():
+            x = x.flatten(1)  # conv classifier, flatten if pooling isn't pass-through (disabled)
+        return x
 
 
 def _create_dla(variant, pretrained=False, **kwargs):
diff --git a/timm/models/dpn.py b/timm/models/dpn.py
index a0a77ab5..61ce6a0e 100644
--- a/timm/models/dpn.py
+++ b/timm/models/dpn.py
@@ -19,7 +19,7 @@ import torch.nn.functional as F
 
 from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_conv2d, ConvBnAct
+from .layers import BatchNormAct2d, ConvBnAct, create_conv2d, create_classifier
 from .registry import register_model
 
 __all__ = ['DPN']
@@ -237,21 +237,16 @@ class DPN(nn.Module):
         self.features = nn.Sequential(blocks)
 
         # Using 1x1 conv for the FC layer to allow the extra pooling scheme
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        num_features = self.num_features * self.global_pool.feat_mult()
-        self.classifier = nn.Conv2d(num_features, num_classes, kernel_size=1, bias=True)
+        self.global_pool, self.classifier = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
 
     def get_classifier(self):
         return self.classifier
 
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        if num_classes:
-            num_features = self.num_features * self.global_pool.feat_mult()
-            self.classifier = nn.Conv2d(num_features, num_classes, kernel_size=1, bias=True)
-        else:
-            self.classifier = nn.Identity()
+        self.global_pool, self.classifier = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
 
     def forward_features(self, x):
         return self.features(x)
@@ -261,8 +256,10 @@ class DPN(nn.Module):
         x = self.global_pool(x)
         if self.drop_rate > 0.:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
-        out = self.classifier(x)
-        return out.flatten(1)
+        x = self.classifier(x)
+        if not self.global_pool.is_identity():
+            x = x.flatten(1)  # conv classifier, flatten if pooling isn't pass-through (disabled)
+        return x
 
 
 def _create_dpn(variant, pretrained=False, **kwargs):
diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py
index 02e4797f..2e64d7e1 100644
--- a/timm/models/efficientnet.py
+++ b/timm/models/efficientnet.py
@@ -35,7 +35,7 @@ from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_la
 from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
 from .features import FeatureInfo, FeatureHooks
 from .helpers import build_model_with_cfg
-from .layers import SelectAdaptivePool2d, create_conv2d
+from .layers import create_conv2d, create_classifier
 from .registry import register_model
 
 __all__ = ['EfficientNet']
@@ -336,32 +336,28 @@ class EfficientNet(nn.Module):
         self.num_classes = num_classes
         self.num_features = num_features
         self.drop_rate = drop_rate
-        self._in_chs = in_chans
 
         # Stem
         if not fix_stem:
             stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
-        self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
+        self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
         self.bn1 = norm_layer(stem_size, **norm_kwargs)
         self.act1 = act_layer(inplace=True)
-        self._in_chs = stem_size
 
         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
             channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
             norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
-        self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
+        self.blocks = nn.Sequential(*builder(stem_size, block_args))
         self.feature_info = builder.features
-        self._in_chs = builder.in_chs
+        head_chs = builder.in_chs
 
         # Head + Pooling
-        self.conv_head = create_conv2d(self._in_chs, self.num_features, 1, padding=pad_type)
+        self.conv_head = create_conv2d(head_chs, self.num_features, 1, padding=pad_type)
         self.bn2 = norm_layer(self.num_features, **norm_kwargs)
         self.act2 = act_layer(inplace=True)
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-
-        # Classifier
-        self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), self.num_classes)
+        self.global_pool, self.classifier = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool)
 
         efficientnet_init_weights(self)
 
@@ -369,7 +365,7 @@ def as_sequential(self):
         layers = [self.conv_stem, self.bn1, self.act1]
         layers.extend(self.blocks)
         layers.extend([self.conv_head, self.bn2, self.act2, self.global_pool])
-        layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
+        layers.extend([nn.Dropout(self.drop_rate), self.classifier])
         return nn.Sequential(*layers)
 
     def get_classifier(self):
@@ -377,12 +373,8 @@ class EfficientNet(nn.Module):
 
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        if num_classes:
-            num_features = self.num_features * self.global_pool.feat_mult()
-            self.classifier = nn.Linear(num_features, num_classes)
-        else:
-            self.classifier = nn.Identity()
+        self.global_pool, self.classifier = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool)
 
     def forward_features(self, x):
         x = self.conv_stem(x)
@@ -397,7 +389,6 @@ class EfficientNet(nn.Module):
     def forward(self, x):
         x = self.forward_features(x)
         x = self.global_pool(x)
-        x = x.flatten(1)
         if self.drop_rate > 0.:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
         return self.classifier(x)
@@ -417,24 +408,21 @@ class EfficientNetFeatures(nn.Module):
         super(EfficientNetFeatures, self).__init__()
         norm_kwargs = norm_kwargs or {}
         self.drop_rate = drop_rate
-        self._in_chs = in_chans
 
         # Stem
         if not fix_stem:
             stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
-        self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
+        self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
         self.bn1 = norm_layer(stem_size, **norm_kwargs)
         self.act1 = act_layer(inplace=True)
-        self._in_chs = stem_size
 
         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
             channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
             norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
-        self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
+        self.blocks = nn.Sequential(*builder(stem_size, block_args))
         self.feature_info = FeatureInfo(builder.features, out_indices)
         self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
-        self._in_chs = builder.in_chs
 
         efficientnet_init_weights(self)
diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py
index 8e7eb99f..3782c500 100644
--- a/timm/models/gluon_xception.py
+++ b/timm/models/gluon_xception.py
@@ -13,7 +13,7 @@ import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import SelectAdaptivePool2d, get_padding
+from .layers import create_classifier, get_padding
 from .registry import register_model
 
 __all__ = ['Xception65']
@@ -192,16 +192,14 @@ class Xception65(nn.Module):
             dict(num_chs=2048, reduction=32, module='act5'),
         ]
 
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
+        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
 
     def get_classifier(self):
         return self.fc
 
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
+        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
 
     def forward_features(self, x):
         # Entry flow
@@ -242,7 +240,7 @@ class Xception65(nn.Module):
 
     def forward(self, x):
         x = self.forward_features(x)
-        x = self.global_pool(x).flatten(1)
+        x = self.global_pool(x)
         if self.drop_rate:
             F.dropout(x, self.drop_rate, training=self.training)
         x = self.fc(x)
diff --git a/timm/models/helpers.py b/timm/models/helpers.py
index 20247f49..60888205 100644
--- a/timm/models/helpers.py
+++ b/timm/models/helpers.py
@@ -187,10 +187,13 @@ def adapt_model_from_string(parent_module, model_string):
                 affine=old_module.affine, track_running_stats=True)
             set_layer(new_module, n, new_bn)
         if isinstance(old_module, nn.Linear):
+            # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
+            num_features = state_dict[n + '.weight'][1]
             new_fc = nn.Linear(
-                in_features=state_dict[n + '.weight'][1], out_features=old_module.out_features,
-                bias=old_module.bias is not None)
+                in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
             set_layer(new_module, n, new_fc)
+            if hasattr(new_module, 'num_features'):
+                new_module.num_features = num_features
     new_module.eval()
     parent_module.eval()
diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py
index 61c051b1..ad865887 100644
--- a/timm/models/hrnet.py
+++ b/timm/models/hrnet.py
@@ -18,7 +18,7 @@ import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .features import FeatureInfo
 from .helpers import build_model_with_cfg
-from .layers import SelectAdaptivePool2d
+from .layers import create_classifier
 from .registry import register_model
 from .resnet import BasicBlock, Bottleneck  # leveraging ResNet blocks w/ additional features like SE
@@ -553,8 +553,8 @@ class HighResolutionNet(nn.Module):
             # Classification Head
             self.num_features = 2048
             self.incre_modules, self.downsamp_modules, self.final_layer = self._make_head(pre_stage_channels)
-            self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-            self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
+            self.global_pool, self.classifier = create_classifier(
+                self.num_features, self.num_classes, pool_type=global_pool)
         elif head == 'incre':
             self.num_features = 2048
             self.incre_modules, _, _ = self._make_head(pre_stage_channels, True)
@@ -685,12 +685,8 @@ class HighResolutionNet(nn.Module):
 
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        num_features = self.num_features * self.global_pool.feat_mult()
-        if num_classes:
-            self.classifier = nn.Linear(num_features, num_classes)
-        else:
-            self.classifier = nn.Identity()
+        self.global_pool, self.classifier = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool)
 
     def stages(self, x) -> List[torch.Tensor]:
         x = self.layer1(x)
@@ -726,7 +722,7 @@ class HighResolutionNet(nn.Module):
 
     def forward(self, x):
         x = self.forward_features(x)
-        x = self.global_pool(x).flatten(1)
+        x = self.global_pool(x)
         if self.drop_rate > 0.:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
         x = self.classifier(x)
diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py
index c438898e..a5efa330 100644
--- a/timm/models/inception_resnet_v2.py
+++ b/timm/models/inception_resnet_v2.py
@@ -8,7 +8,7 @@ import torch.nn.functional as F
 
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from .helpers import build_model_with_cfg
-from .layers import SelectAdaptivePool2d
+from .layers import create_classifier
 from .registry import register_model
 
 __all__ = ['InceptionResnetV2']
@@ -296,21 +296,14 @@ class InceptionResnetV2(nn.Module):
         self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1)
         self.feature_info += [dict(num_chs=self.num_features, reduction=32, module='conv2d_7b')]
 
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        # NOTE some variants/checkpoints for this model may have 'last_linear' as the name for the FC
-        self.classif = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
+        self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
 
     def get_classifier(self):
         return self.classif
 
     def reset_classifier(self, num_classes, global_pool='avg'):
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.num_classes = num_classes
-        if num_classes:
-            num_features = self.num_features * self.global_pool.feat_mult()
-            self.classif = nn.Linear(num_features, num_classes)
-        else:
-            self.classif = nn.Identity()
+        self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
 
     def forward_features(self, x):
         x = self.conv2d_1a(x)
@@ -332,7 +325,7 @@ class InceptionResnetV2(nn.Module):
 
     def forward(self, x):
         x = self.forward_features(x)
-        x = self.global_pool(x).flatten(1)
+        x = self.global_pool(x)
         if self.drop_rate > 0:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
         x = self.classif(x)
diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py
index fd7852bd..6634d4b3 100644
--- a/timm/models/inception_v3.py
+++ b/timm/models/inception_v3.py
@@ -10,7 +10,7 @@ import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from .helpers import build_model_with_cfg
 from .registry import register_model
-from .layers import trunc_normal_, SelectAdaptivePool2d
+from .layers import trunc_normal_, create_classifier
 
 
 def _cfg(url='', **kwargs):
@@ -326,8 +326,7 @@ class InceptionV3(nn.Module):
         ]
 
         self.num_features = 2048
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.fc = nn.Linear(2048, num_classes)
+        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
@@ -389,16 +388,12 @@ class InceptionV3(nn.Module):
         return self.fc
 
     def reset_classifier(self, num_classes, global_pool='avg'):
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.num_classes = num_classes
-        if self.num_classes > 0:
-            self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
-        else:
-            self.fc = nn.Identity()
+        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
 
     def forward(self, x):
         x = self.forward_features(x)
-        x = self.global_pool(x).flatten(1)
+        x = self.global_pool(x)
         if self.drop_rate > 0:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
         x = self.fc(x)
@@ -421,7 +416,7 @@ class InceptionV3Aux(InceptionV3):
 
     def forward(self, x):
         x, aux = self.forward_features(x)
-        x = self.global_pool(x).flatten(1)
+        x = self.global_pool(x)
         if self.drop_rate > 0:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
         x = self.fc(x)
diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py
index d74354bd..40a0f291 100644
--- a/timm/models/inception_v4.py
+++ b/timm/models/inception_v4.py
@@ -8,7 +8,7 @@ import torch.nn.functional as F
 
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from .helpers import build_model_with_cfg
-from .layers import SelectAdaptivePool2d
+from .layers import create_classifier
 from .registry import register_model
 
 __all__ = ['InceptionV4']
@@ -279,27 +279,23 @@ class InceptionV4(nn.Module):
             dict(num_chs=1024, reduction=16, module='features.17'),
             dict(num_chs=1536, reduction=32, module='features.21'),
         ]
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
+        self.global_pool, self.last_linear = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool)
 
     def get_classifier(self):
         return self.last_linear
 
    def reset_classifier(self, num_classes, global_pool='avg'):
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.num_classes = num_classes
-        if num_classes:
-            num_features = self.num_features * self.global_pool.feat_mult()
-            self.last_linear = nn.Linear(num_features, num_classes)
-        else:
-            self.last_linear = nn.Identity()
+        self.global_pool, self.last_linear = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool)
 
     def forward_features(self, x):
         return self.features(x)
 
     def forward(self, x):
         x = self.forward_features(x)
-        x = self.global_pool(x).flatten(1)
+        x = self.global_pool(x)
         if self.drop_rate > 0:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
         x = self.last_linear(x)
diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py
index e8efa3dc..4d5f8a69 100644
--- a/timm/models/layers/__init__.py
+++ b/timm/models/layers/__init__.py
@@ -3,7 +3,7 @@ from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
 from .anti_aliasing import AntiAliasDownsampleLayer
 from .blur_pool import BlurPool2d
-from .classifier import ClassifierHead
+from .classifier import ClassifierHead, create_classifier
 from .cond_conv2d import CondConv2d, get_condconv_initializer
 from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
     set_layer_config
diff --git a/timm/models/layers/adaptive_avgmax_pool.py b/timm/models/layers/adaptive_avgmax_pool.py
index c3d823e1..482c0c01 100644
--- a/timm/models/layers/adaptive_avgmax_pool.py
+++ b/timm/models/layers/adaptive_avgmax_pool.py
@@ -72,19 +72,23 @@ class SelectAdaptivePool2d(nn.Module):
     """
     def __init__(self, output_size=1, pool_type='avg', flatten=False):
         super(SelectAdaptivePool2d, self).__init__()
-        self.output_size = output_size
-        self.pool_type = pool_type
+        self.pool_type = pool_type or ''  # convert other falsy values to empty string for consistent TS typing
         self.flatten = flatten
-        if pool_type == 'avgmax':
+        if pool_type == '':
+            self.pool = nn.Identity()  # pass through
+        elif pool_type == 'avg':
+            self.pool = nn.AdaptiveAvgPool2d(output_size)
+        elif pool_type == 'avgmax':
             self.pool = AdaptiveAvgMaxPool2d(output_size)
         elif pool_type == 'catavgmax':
             self.pool = AdaptiveCatAvgMaxPool2d(output_size)
         elif pool_type == 'max':
             self.pool = nn.AdaptiveMaxPool2d(output_size)
         else:
-            if pool_type != 'avg':
-                assert False, 'Invalid pool type: %s' % pool_type
-            self.pool = nn.AdaptiveAvgPool2d(output_size)
+            assert False, 'Invalid pool type: %s' % pool_type
+
+    def is_identity(self):
+        return self.pool_type == ''
 
     def forward(self, x):
         x = self.pool(x)
@@ -97,5 +101,6 @@ class SelectAdaptivePool2d(nn.Module):
 
     def __repr__(self):
         return self.__class__.__name__ + ' (' \
-               + 'output_size=' + str(self.output_size) \
-               + ', pool_type=' + self.pool_type + ')'
+               + 'pool_type=' + self.pool_type \
+               + ', flatten=' + str(self.flatten) + ')'
+
diff --git a/timm/models/layers/classifier.py b/timm/models/layers/classifier.py
index 29960c44..e9194f05 100644
--- a/timm/models/layers/classifier.py
+++ b/timm/models/layers/classifier.py
@@ -1,23 +1,40 @@
+""" Classifier head and layer factory
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
 from torch import nn as nn
 from torch.nn import functional as F
 
 from .adaptive_avgmax_pool import SelectAdaptivePool2d
 
 
+def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
+    flatten = not use_conv  # flatten when we use a Linear layer after pooling
+    if not pool_type:
+        assert num_classes == 0 or use_conv,\
+            'Pooling can only be disabled if classifier is also removed or conv classifier is used'
+        flatten = False  # disable flattening if pooling is pass-through (no pooling)
+    global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten)
+    num_pooled_features = num_features * global_pool.feat_mult()
+    if num_classes <= 0:
+        fc = nn.Identity()  # pass-through (no classifier)
+    elif use_conv:
+        fc = nn.Conv2d(num_pooled_features, num_classes, 1, bias=True)
+    else:
+        fc = nn.Linear(num_pooled_features, num_classes, bias=True)
+    return global_pool, fc
+
+
 class ClassifierHead(nn.Module):
-    """Classifier Head w/ configurable global pooling and dropout."""
+    """Classifier head w/ configurable global pooling and dropout."""
 
     def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.):
         super(ClassifierHead, self).__init__()
         self.drop_rate = drop_rate
-        self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
-        if num_classes > 0:
-            self.fc = nn.Linear(in_chs * self.global_pool.feat_mult(), num_classes, bias=True)
-        else:
-            self.fc = nn.Identity()
+        self.global_pool, self.fc = create_classifier(in_chs, num_classes, pool_type=pool_type)
 
     def forward(self, x):
-        x = self.global_pool(x).flatten(1)
+        x = self.global_pool(x)
         if self.drop_rate:
             x = F.dropout(x, p=float(self.drop_rate), training=self.training)
         x = self.fc(x)
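# --- Reviewer note, not part of the patch: a minimal usage sketch of the create_classifier
# factory introduced in timm/models/layers/classifier.py above. Feature width (2048) and shapes
# are arbitrary; the import path assumes this diff's layers/__init__.py change is applied.
import torch
from timm.models.layers import create_classifier

x = torch.randn(2, 2048, 7, 7)

# default Linear classifier: pooling flattens, so fc sees (N, 2048)
global_pool, fc = create_classifier(2048, num_classes=1000, pool_type='avg')
print(fc(global_pool(x)).shape)   # torch.Size([2, 1000])

# conv classifier (DLA/DPN style): pooled features stay 4D, models flatten after the 1x1 conv
global_pool, fc = create_classifier(2048, num_classes=1000, pool_type='avg', use_conv=True)
print(fc(global_pool(x)).shape)   # torch.Size([2, 1000, 1, 1])

# num_classes=0 -> nn.Identity classifier; pool_type='' -> pass-through pooling (unpooled features)
global_pool, fc = create_classifier(2048, num_classes=0, pool_type='')
print(fc(global_pool(x)).shape)   # torch.Size([2, 2048, 7, 7])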
diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py
index 258046bd..e0ad7c95 100644
--- a/timm/models/mobilenetv3.py
+++ b/timm/models/mobilenetv3.py
@@ -85,30 +85,27 @@ class MobileNetV3(nn.Module):
         self.num_classes = num_classes
         self.num_features = num_features
         self.drop_rate = drop_rate
-        self._in_chs = in_chans
 
         # Stem
         stem_size = round_channels(stem_size, channel_multiplier)
-        self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
+        self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
         self.bn1 = norm_layer(stem_size, **norm_kwargs)
         self.act1 = act_layer(inplace=True)
-        self._in_chs = stem_size
 
         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
             channel_multiplier, 8, None, 32, pad_type, act_layer, se_kwargs,
             norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
-        self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
+        self.blocks = nn.Sequential(*builder(stem_size, block_args))
         self.feature_info = builder.features
-        self._in_chs = builder.in_chs
+        head_chs = builder.in_chs
 
         # Head + Pooling
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.conv_head = create_conv2d(self._in_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) if global_pool else nn.Identity()
+        num_pooled_chs = head_chs * self.global_pool.feat_mult()
+        self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
         self.act2 = act_layer(inplace=True)
-
-        # Classifier
-        self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), self.num_classes)
+        self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
         efficientnet_init_weights(self)
 
@@ -123,13 +120,10 @@ class MobileNetV3(nn.Module):
         return self.classifier
 
     def reset_classifier(self, num_classes, global_pool='avg'):
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.num_classes = num_classes
-        if num_classes:
-            num_features = self.num_features * self.global_pool.feat_mult()
-            self.classifier = nn.Linear(num_features, num_classes)
-        else:
-            self.classifier = nn.Identity()
+        # cannot meaningfully change pooling of efficient head after creation
+        assert global_pool == self.global_pool.pool_type
+        self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
     def forward_features(self, x):
         x = self.conv_stem(x)
@@ -142,8 +136,7 @@ class MobileNetV3(nn.Module):
         return x
 
     def forward(self, x):
-        x = self.forward_features(x)
-        x = x.flatten(1)
+        x = self.forward_features(x).flatten(1)
         if self.drop_rate > 0.:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
         return self.classifier(x)
@@ -163,23 +156,20 @@ class MobileNetV3Features(nn.Module):
         super(MobileNetV3Features, self).__init__()
         norm_kwargs = norm_kwargs or {}
         self.drop_rate = drop_rate
-        self._in_chs = in_chans
 
         # Stem
         stem_size = round_channels(stem_size, channel_multiplier)
-        self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
+        self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
         self.bn1 = norm_layer(stem_size, **norm_kwargs)
         self.act1 = act_layer(inplace=True)
-        self._in_chs = stem_size
 
         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
             channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
             norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
-        self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
+        self.blocks = nn.Sequential(*builder(stem_size, block_args))
         self.feature_info = FeatureInfo(builder.features, out_indices)
         self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
-        self._in_chs = builder.in_chs
 
         efficientnet_init_weights(self)
diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py
index d682b46b..18b3725f 100644
--- a/timm/models/nasnet.py
+++ b/timm/models/nasnet.py
@@ -6,7 +6,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from .helpers import build_model_with_cfg
-from .layers import SelectAdaptivePool2d, ConvBnAct, create_conv2d, create_pool2d
+from .layers import ConvBnAct, create_conv2d, create_pool2d, create_classifier
 from .registry import register_model
 
 __all__ = ['NASNetALarge']
@@ -496,20 +496,16 @@ class NASNetALarge(nn.Module):
             dict(num_chs=4032, reduction=32, module='act'),
         ]
 
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
+        self.global_pool, self.last_linear = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool)
 
     def get_classifier(self):
         return self.last_linear
 
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        if num_classes:
-            num_features = self.num_features * self.global_pool.feat_mult()
-            self.last_linear = nn.Linear(num_features, num_classes)
-        else:
-            self.last_linear = nn.Identity()
+        self.global_pool, self.last_linear = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool)
 
     def forward_features(self, x):
         x_conv0 = self.conv0(x)
@@ -544,7 +540,7 @@ class NASNetALarge(nn.Module):
 
     def forward(self, x):
         x = self.forward_features(x)
-        x = self.global_pool(x).flatten(1)
+        x = self.global_pool(x)
         if self.drop_rate > 0:
             x = F.dropout(x, self.drop_rate, training=self.training)
         x = self.last_linear(x)
diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py
index 5a283ba9..5f1e177f 100644
--- a/timm/models/pnasnet.py
+++ b/timm/models/pnasnet.py
@@ -12,7 +12,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from .helpers import build_model_with_cfg
-from .layers import SelectAdaptivePool2d, ConvBnAct, create_conv2d, create_pool2d
+from .layers import ConvBnAct, create_conv2d, create_pool2d, create_classifier
 from .registry import register_model
 
 __all__ = ['PNASNet5Large']
@@ -291,20 +291,16 @@ class PNASNet5Large(nn.Module):
             dict(num_chs=4320, reduction=32, module='act'),
         ]
 
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
+        self.global_pool, self.last_linear = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool)
 
     def get_classifier(self):
         return self.last_linear
 
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        if num_classes:
-            num_features = self.num_features * self.global_pool.feat_mult()
-            self.last_linear = nn.Linear(num_features, num_classes)
-        else:
-            self.last_linear = nn.Identity()
+        self.global_pool, self.last_linear = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool)
 
     def forward_features(self, x):
         x_conv_0 = self.conv_0(x)
@@ -327,7 +323,7 @@ class PNASNet5Large(nn.Module):
 
     def forward(self, x):
         x = self.forward_features(x)
-        x = self.global_pool(x).flatten(1)
+        x = self.global_pool(x)
         if self.drop_rate > 0:
             x = F.dropout(x, self.drop_rate, training=self.training)
         x = self.last_linear(x)
diff --git a/timm/models/resnet.py b/timm/models/resnet.py
index 37b5e82e..e2799d82 100644
--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -15,7 +15,7 @@ import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d
+from .layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, create_attn, create_classifier
 from .registry import register_model
 
 __all__ = ['ResNet', 'BasicBlock', 'Bottleneck']  # model_registry will add each entrypoint fn to this
@@ -542,9 +542,8 @@ class ResNet(nn.Module):
         self.feature_info.extend(stage_feature_info)
 
         # Head (Pooling and Classifier)
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.num_features = 512 * block.expansion
-        self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
+        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
 
         for n, m in self.named_modules():
             if isinstance(m, nn.Conv2d):
@@ -561,13 +560,8 @@ class ResNet(nn.Module):
         return self.fc
 
     def reset_classifier(self, num_classes, global_pool='avg'):
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.num_classes = num_classes
-        if num_classes:
-            num_features = self.num_features * self.global_pool.feat_mult()
-            self.fc = nn.Linear(num_features, num_classes)
-        else:
-            self.fc = nn.Identity()
+        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
 
     def forward_features(self, x):
         x = self.conv1(x)
@@ -583,7 +577,7 @@ class ResNet(nn.Module):
 
     def forward(self, x):
         x = self.forward_features(x)
-        x = self.global_pool(x).flatten(1)
+        x = self.global_pool(x)
         if self.drop_rate:
             x = F.dropout(x, p=float(self.drop_rate), training=self.training)
         x = self.fc(x)
diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py
index 71b8cb62..03e6ee02 100644
--- a/timm/models/rexnet.py
+++ b/timm/models/rexnet.py
@@ -168,6 +168,7 @@ class ReXNetV1(nn.Module):
                  initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, use_se=True,
                  se_rd=12, ch_div=1, drop_rate=0.2, feature_location='bottleneck'):
         super(ReXNetV1, self).__init__()
+        self.drop_rate = drop_rate
 
         assert output_stride == 32  # FIXME support dilation
         stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32
diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py
index 6b541e95..815aec06 100644
--- a/timm/models/selecsls.py
+++ b/timm/models/selecsls.py
@@ -17,7 +17,7 @@ import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import SelectAdaptivePool2d
+from .layers import create_classifier
 from .registry import register_model
 
 __all__ = ['SelecSLS']  # model_registry will add each entrypoint fn to this
@@ -165,8 +165,7 @@ class SelecSLS(nn.Module):
         self.num_features = cfg['num_features']
         self.feature_info = cfg['feature_info']
 
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
+        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
 
         for n, m in self.named_modules():
             if isinstance(m, nn.Conv2d):
@@ -179,13 +178,8 @@ class SelecSLS(nn.Module):
         return self.fc
 
     def reset_classifier(self, num_classes, global_pool='avg'):
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.num_classes = num_classes
-        if num_classes:
-            num_features = self.num_features * self.global_pool.feat_mult()
-            self.fc = nn.Linear(num_features, num_classes)
-        else:
-            self.fc = nn.Identity()
+        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
 
     def forward_features(self, x):
         x = self.stem(x)
@@ -195,7 +189,7 @@ class SelecSLS(nn.Module):
 
     def forward(self, x):
         x = self.forward_features(x)
-        x = self.global_pool(x).flatten(1)
+        x = self.global_pool(x)
         if self.drop_rate > 0.:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
         x = self.fc(x)
diff --git a/timm/models/senet.py b/timm/models/senet.py
index b0cf8de2..96228224 100644
--- a/timm/models/senet.py
+++ b/timm/models/senet.py
@@ -19,7 +19,7 @@ import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import SelectAdaptivePool2d
+from .layers import create_classifier
 from .registry import register_model
 
 __all__ = ['SENet']
@@ -345,8 +345,8 @@ class SENet(nn.Module):
         )
         self.feature_info += [dict(num_chs=512 * block.expansion, reduction=32, module='layer4')]
         self.num_features = 512 * block.expansion
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.last_linear = nn.Linear(self.num_features, num_classes)
+        self.global_pool, self.last_linear = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool)
 
         for m in self.modules():
             _weight_init(m)
@@ -374,12 +374,8 @@ class SENet(nn.Module):
 
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        if num_classes:
-            num_features = self.num_features * self.global_pool.feat_mult()
-            self.last_linear = nn.Linear(num_features, num_classes)
-        else:
-            self.last_linear = nn.Identity()
+        self.global_pool, self.last_linear = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool)
 
     def forward_features(self, x):
         x = self.layer0(x)
@@ -391,7 +387,7 @@ class SENet(nn.Module):
         return x
 
     def logits(self, x):
-        x = self.global_pool(x).flatten(1)
+        x = self.global_pool(x)
         if self.drop_rate > 0.:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
         x = self.last_linear(x)
diff --git a/timm/models/xception.py b/timm/models/xception.py
index db506828..a61548dc 100644
--- a/timm/models/xception.py
+++ b/timm/models/xception.py
@@ -26,7 +26,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from .helpers import build_model_with_cfg
-from .layers import SelectAdaptivePool2d
+from .layers import create_classifier
 from .registry import register_model
 
 __all__ = ['Xception']
@@ -162,8 +162,7 @@ class Xception(nn.Module):
             dict(num_chs=2048, reduction=32, module='act4'),
         ]
 
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
+        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
 
         # #------- init weights --------
         for m in self.modules():
@@ -178,12 +177,7 @@ class Xception(nn.Module):
 
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        if num_classes:
-            num_features = self.num_features * self.global_pool.feat_mult()
-            self.fc = nn.Linear(num_features, num_classes)
-        else:
-            self.fc = nn.Identity()
+        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
 
     def forward_features(self, x):
         x = self.conv1(x)
@@ -218,7 +212,7 @@ class Xception(nn.Module):
 
     def forward(self, x):
         x = self.forward_features(x)
-        x = self.global_pool(x).flatten(1)
+        x = self.global_pool(x)
         if self.drop_rate:
             F.dropout(x, self.drop_rate, training=self.training)
         x = self.fc(x)
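# --- Reviewer note, not part of the patch: the three output modes the updated
# test_model_default_cfgs exercises, sketched on resnet50 (an arbitrary choice; any
# non-MobileNetV3 model behaves the same way once this diff is applied).
import torch
import timm

model = timm.create_model('resnet50', pretrained=False)
model.eval()
x = torch.randn(1, 3, 224, 224)

feats = model.forward_features(x)   # always unpooled: (1, 2048, 7, 7)

model.reset_classifier(0)            # drop the classifier, keep pooling
pooled = model(x)                    # (1, 2048) == (1, model.num_features)

model.reset_classifier(0, '')        # drop the classifier and set pooling to pass-through
unpooled = model(x)                  # (1, 2048, 7, 7) again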