diff --git a/tests/test_models.py b/tests/test_models.py index 25d514ea..2158c89d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -49,8 +49,9 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): # models with forward_intermediates() and support for FeatureGetterNet features_only wrapper FEAT_INTER_FILTERS = [ - 'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*', - 'cait_*', 'xcit_*', 'volo_*', 'swin*', 'max*vit_*', 'coatne*t_*' + 'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos', + 'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2', + 'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3' ] # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output. @@ -388,7 +389,7 @@ def test_model_forward_features(model_name, batch_size): @pytest.mark.features @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS)) +@pytest.mark.parametrize('model_name', list_models(module=FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS + ['*pruned*'])) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_intermediates_features(model_name, batch_size): """Run a single forward pass with each model in feature extraction mode""" @@ -419,7 +420,7 @@ def test_model_forward_intermediates_features(model_name, batch_size): @pytest.mark.features @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS)) +@pytest.mark.parametrize('model_name', list_models(module=FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS + ['*pruned*'])) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_intermediates(model_name, batch_size): """Run a single forward pass with each model in feature extraction mode""" diff --git a/timm/models/_registry.py b/timm/models/_registry.py index a129e1af..fde8bac7 100644 --- a/timm/models/_registry.py +++ b/timm/models/_registry.py @@ -184,7 +184,7 @@ def _expand_filter(filter: str): def list_models( filter: Union[str, List[str]] = '', - module: str = '', + module: Union[str, List[str]] = '', pretrained: bool = False, exclude_filters: Union[str, List[str]] = '', name_matches_cfg: bool = False, @@ -217,7 +217,16 @@ def list_models( # FIXME should this be default behaviour? or default to include_tags=True? include_tags = pretrained - all_models: Set[str] = _module_to_models[module] if module else set(_model_entrypoints.keys()) + if not module: + all_models: Set[str] = set(_model_entrypoints.keys()) + else: + if isinstance(module, str): + all_models: Set[str] = _module_to_models[module] + else: + assert isinstance(module, Sequence) + all_models: Set[str] = set() + for m in module: + all_models.update(_module_to_models[m]) all_models = all_models - _deprecated_models.keys() # remove deprecated models from listings if include_tags: diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 6e61d1bf..25f3607a 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -36,7 +36,7 @@ the models and weights open source! Hacked together by / Copyright 2019, Ross Wightman """ from functools import partial -from typing import List +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -49,7 +49,7 @@ from ._builder import build_model_with_cfg, pretrained_cfg_for_features from ._efficientnet_blocks import SqueezeExcite from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT -from ._features import FeatureInfo, FeatureHooks +from ._features import FeatureInfo, FeatureHooks, feature_take_indices from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model, register_model_deprecations @@ -118,6 +118,7 @@ class EfficientNet(nn.Module): ) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = builder.features + self.stage_ends = [f['stage'] for f in self.feature_info] head_chs = builder.in_chs # Head + Pooling @@ -158,6 +159,86 @@ class EfficientNet(nn.Module): self.global_pool, self.classifier = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) + def forward_intermediates( + self, + x: torch.Tensor, + *, + indices: Union[int, List[int], Tuple[int]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + extra_blocks: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + extra_blocks: Include outputs of all blocks and head conv in output, does not align with feature_info + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + if stop_early: + assert intermediates_only, 'Must use intermediates_only for early stopping.' + intermediates = [] + if extra_blocks: + take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices) + else: + take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) + take_indices = [self.stage_ends[i] for i in take_indices] + max_index = self.stage_ends[max_index] + # forward pass + feat_idx = 0 # stem is index 0 + x = self.conv_stem(x) + x = self.bn1(x) + if feat_idx in take_indices: + intermediates.append(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.blocks + else: + blocks = self.blocks[:max_index] + for blk in blocks: + feat_idx += 1 + x = blk(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + x = self.conv_head(x) + x = self.bn2(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + extra_blocks: bool = False, + ): + """ Prune layers not required for specified intermediates. + """ + if extra_blocks: + take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices) + else: + take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) + max_index = self.stage_ends[max_index] + self.blocks = self.blocks[:max_index] # truncate blocks w/ stem as idx 0 + if prune_norm or max_index < len(self.blocks): + self.conv_head = nn.Identity() + self.bn2 = nn.Identity() + if prune_head: + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x): x = self.conv_stem(x) x = self.bn1(x) @@ -272,7 +353,7 @@ def _create_effnet(variant, pretrained=False, **kwargs): model_cls = EfficientNet kwargs_filter = None if kwargs.pop('features_only', False): - if 'feature_cfg' in kwargs: + if 'feature_cfg' in kwargs or 'feature_cls' in kwargs: features_mode = 'cfg' else: kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool') diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 2d197a9d..aba4d354 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -7,7 +7,7 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244 Hacked together by / Copyright 2019, Ross Wightman """ from functools import partial -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -20,7 +20,7 @@ from ._builder import build_model_with_cfg, pretrained_cfg_for_features from ._efficientnet_blocks import SqueezeExcite from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT -from ._features import FeatureInfo, FeatureHooks +from ._features import FeatureInfo, FeatureHooks, feature_take_indices from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model, register_model_deprecations @@ -109,6 +109,7 @@ class MobileNetV3(nn.Module): ) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = builder.features + self.stage_ends = [f['stage'] for f in self.feature_info] head_chs = builder.in_chs # Head + Pooling @@ -150,6 +151,84 @@ class MobileNetV3(nn.Module): self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + def forward_intermediates( + self, + x: torch.Tensor, + *, + indices: Union[int, List[int], Tuple[int]] = None, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + extra_blocks: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + norm: Apply norm layer to compatible intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + extra_blocks: Include outputs of all blocks and head conv in output, does not align with feature_info + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' + if stop_early: + assert intermediates_only, 'Must use intermediates_only for early stopping.' + intermediates = [] + if extra_blocks: + take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices) + else: + take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) + print(take_indices, self.stage_ends) + take_indices = [self.stage_ends[i] for i in take_indices] + max_index = self.stage_ends[max_index] + # forward pass + feat_idx = 0 # stem is index 0 + x = self.conv_stem(x) + x = self.bn1(x) + if feat_idx in take_indices: + intermediates.append(x) + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.blocks + else: + blocks = self.blocks[:max_index] + for blk in blocks: + feat_idx += 1 + x = blk(x) + if feat_idx in take_indices: + intermediates.append(x) + + if intermediates_only: + return intermediates + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int], Tuple[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + extra_blocks: bool = False, + ): + """ Prune layers not required for specified intermediates. + """ + if extra_blocks: + take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices) + else: + take_indices, max_index = feature_take_indices(len(self.stage_ends), indices) + max_index = self.stage_ends[max_index] + self.blocks = self.blocks[:max_index] # truncate blocks w/ stem as idx 0 + if max_index < len(self.blocks): + self.conv_head = nn.Identity() + if prune_head: + self.conv_head = nn.Identity() + self.reset_classifier(0, '') + return take_indices + def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.conv_stem(x) x = self.bn1(x) @@ -288,7 +367,7 @@ def _create_mnv3(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV model_cls = MobileNetV3 kwargs_filter = None if kwargs.pop('features_only', False): - if 'feature_cfg' in kwargs: + if 'feature_cfg' in kwargs or 'feature_cls' in kwargs: features_mode = 'cfg' else: kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')