diff --git a/timm/models/_builder.py b/timm/models/_builder.py
index e6150b9a..c1ad5c2d 100644
--- a/timm/models/_builder.py
+++ b/timm/models/_builder.py
@@ -7,7 +7,7 @@ from typing import Optional, Dict, Callable, Any, Tuple
 from torch import nn as nn
 from torch.hub import load_state_dict_from_url
 
-from timm.models._features import FeatureListNet, FeatureHookNet
+from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet
 from timm.models._features_fx import FeatureGraphNet
 from timm.models._helpers import load_state_dict
 from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
@@ -428,8 +428,12 @@ def build_model_with_cfg(
         feature_cls = feature_cls.lower()
         if 'hook' in feature_cls:
             feature_cls = FeatureHookNet
+        elif feature_cls == 'dict':
+            feature_cls = FeatureDictNet
         elif feature_cls == 'fx':
             feature_cls = FeatureGraphNet
+        elif feature_cls == 'getter':
+            feature_cls = FeatureGetterNet
         else:
             assert False, f'Unknown feature class {feature_cls}'
         model = feature_cls(model, **feature_cfg)
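Note: this dispatch only triggers when a model's `feature_cfg` (or the caller) supplies `feature_cls` as a string; when nothing is set, `build_model_with_cfg` falls back to `FeatureListNet` as before, so the existing user-facing path is unchanged. A quick sketch of that default path (shapes assume a 224x224 input):

```python
import torch
import timm

# Default features_only path: no feature_cls set, so FeatureListNet wraps the model.
model = timm.create_model('resnet50', features_only=True, out_indices=(2, 3, 4))
feats = model(torch.randn(2, 3, 224, 224))
print(model.feature_info.channels())   # [512, 1024, 2048]
print([f.shape for f in feats])        # strides 8/16/32 -> 28x28, 14x14, 7x7 maps
```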
diff --git a/timm/models/_features.py b/timm/models/_features.py
index 7ef51809..cc4068d4 100644
--- a/timm/models/_features.py
+++ b/timm/models/_features.py
@@ -11,7 +11,7 @@ Hacked together by / Copyright 2020 Ross Wightman
 from collections import OrderedDict, defaultdict
 from copy import deepcopy
 from functools import partial
-from typing import Dict, List, Sequence, Tuple, Union
+from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -23,9 +23,24 @@ from timm.layers import Format
 __all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
 
 
+def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
+    if isinstance(x, int):
+        # if indices is an int, take last N features
+        return tuple(range(-x, 0))
+    return tuple(x)
+
+
+OutIndicesT = Union[int, Tuple[int, ...]]
+
+
 class FeatureInfo:
 
-    def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
+    def __init__(
+            self,
+            feature_info: List[Dict],
+            out_indices: OutIndicesT,
+    ):
+        out_indices = _out_indices_as_tuple(out_indices)
         prev_reduction = 1
         for i, fi in enumerate(feature_info):
             # sanity check the mandatory fields, there may be additional fields depending on the model
@@ -37,14 +52,15 @@
         self.out_indices = out_indices
         self.info = feature_info
 
-    def from_other(self, out_indices: Tuple[int]):
+    def from_other(self, out_indices: OutIndicesT):
+        out_indices = _out_indices_as_tuple(out_indices)
         return FeatureInfo(deepcopy(self.info), out_indices)
 
-    def get(self, key, idx=None):
+    def get(self, key: str, idx: Optional[Union[int, List[int]]] = None):
         """ Get value by key at specified index (indices)
         if idx == None, returns value for key at each output index
         if idx is an integer, return value for that feature module index (ignoring output indices)
-        if idx is a list/tupple, return value for each module index (ignoring output indices)
+        if idx is a list/tuple, return value for each module index (ignoring output indices)
         """
         if idx is None:
             return [self.info[i][key] for i in self.out_indices]
@@ -53,7 +69,7 @@
         else:
             return self.info[idx][key]
 
-    def get_dicts(self, keys=None, idx=None):
+    def get_dicts(self, keys: Optional[List[str]] = None, idx: Optional[Union[int, List[int]]] = None):
         """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
         """
         if idx is None:
@@ -66,17 +82,17 @@ class FeatureInfo:
         else:
             return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
 
-    def channels(self, idx=None):
+    def channels(self, idx: Optional[Union[int, List[int]]] = None):
         """ feature channels accessor
         """
         return self.get('num_chs', idx)
 
-    def reduction(self, idx=None):
+    def reduction(self, idx: Optional[Union[int, List[int]]] = None):
         """ feature reduction (output stride) accessor
         """
         return self.get('reduction', idx)
 
-    def module_name(self, idx=None):
+    def module_name(self, idx: Optional[Union[int, List[int]]] = None):
         """ feature module name accessor
         """
         return self.get('module', idx)
@@ -146,7 +162,7 @@ def _module_list(module, flatten_sequential=False):
     return ml
 
 
-def _get_feature_info(net, out_indices):
+def _get_feature_info(net, out_indices: OutIndicesT):
     feature_info = getattr(net, 'feature_info')
     if isinstance(feature_info, FeatureInfo):
         return feature_info.from_other(out_indices)
@@ -182,7 +198,7 @@ class FeatureDictNet(nn.ModuleDict):
     def __init__(
             self,
             model: nn.Module,
-            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+            out_indices: OutIndicesT = (0, 1, 2, 3, 4),
             out_map: Sequence[Union[int, str]] = None,
             output_fmt: str = 'NCHW',
             feature_concat: bool = False,
@@ -257,7 +273,7 @@ class FeatureListNet(FeatureDictNet):
     def __init__(
             self,
             model: nn.Module,
-            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+            out_indices: OutIndicesT = (0, 1, 2, 3, 4),
             output_fmt: str = 'NCHW',
             feature_concat: bool = False,
             flatten_sequential: bool = False,
@@ -298,8 +314,8 @@ class FeatureHookNet(nn.ModuleDict):
     def __init__(
             self,
             model: nn.Module,
-            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
-            out_map: Sequence[Union[int, str]] = None,
+            out_indices: OutIndicesT = (0, 1, 2, 3, 4),
+            out_map: Optional[Sequence[Union[int, str]]] = None,
             return_dict: bool = False,
             output_fmt: str = 'NCHW',
             no_rewrite: bool = False,
@@ -366,3 +382,44 @@
             x = module(x)
         out = self.hooks.get_output(x.device)
         return out if self.return_dict else list(out.values())
+
+
+class FeatureGetterNet(nn.ModuleDict):
+    """ FeatureGetterNet
+
+    Wrap models with a feature getter method, like 'get_intermediate_layers'
+
+    """
+    def __init__(
+            self,
+            model: nn.Module,
+            out_indices: OutIndicesT = 4,
+            out_map: Optional[Sequence[Union[int, str]]] = None,
+            return_dict: bool = False,
+            output_fmt: str = 'NCHW',
+    ):
+        super().__init__()
+        self.model = model
+        self.feature_info = _get_feature_info(model, out_indices)
+        self.out_indices = out_indices
+        self.out_map = out_map
+        self.return_dict = return_dict
+        self.output_fmt = output_fmt
+
+    def forward(self, *args, **kwargs):
+        """
+        def get_intermediate_layers(
+                self,
+                x: torch.Tensor,
+                n: Union[int, Sequence] = 1,
+                reshape: bool = False,
+                return_prefix_tokens: bool = False,
+                norm: bool = False,
+        """
+        out = self.model.get_intermediate_layers(
+            *args,
+            n=self.out_indices,
+            reshape=True,
+            **kwargs,
+        )
+        return out
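Note: the key semantic change is that an int `out_indices` now means "take the last N features", expressed as negative indices that `FeatureInfo` consumes directly (Python's negative list indexing does the rest). A quick illustration of the helper's behavior; also worth flagging that `FeatureGetterNet` stores `out_map`, `return_dict`, and `output_fmt` but `forward()` doesn't consume them yet — it always calls `get_intermediate_layers(..., reshape=True)` and returns the raw tuple.

```python
# Behavior of _out_indices_as_tuple as added above (illustration only):
assert _out_indices_as_tuple(3) == (-3, -2, -1)        # int N -> last N features
assert _out_indices_as_tuple([0, 2, 4]) == (0, 2, 4)   # sequences pass through as tuples
# FeatureInfo then indexes self.info with these values, so -1 resolves to the
# deepest feature module, -2 to the one before it, and so on.
```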
diff --git a/timm/models/_features_fx.py b/timm/models/_features_fx.py
index c48c13b7..e67d1f25 100644
--- a/timm/models/_features_fx.py
+++ b/timm/models/_features_fx.py
@@ -1,7 +1,7 @@
 """ PyTorch FX Based Feature Extraction Helpers
 Using https://pytorch.org/vision/stable/feature_extraction.html
 """
-from typing import Callable, List, Dict, Union, Type
+from typing import Callable, Dict, List, Optional, Union, Tuple, Type
 
 import torch
 from torch import nn
@@ -103,7 +103,12 @@ def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str
 class FeatureGraphNet(nn.Module):
     """ A FX Graph based feature extractor that works with the model feature_info metadata
     """
-    def __init__(self, model, out_indices, out_map=None):
+    def __init__(
+            self,
+            model: nn.Module,
+            out_indices: Tuple[int, ...],
+            out_map: Optional[Dict] = None,
+    ):
         super().__init__()
         assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
         self.feature_info = _get_feature_info(model, out_indices)
@@ -129,7 +134,12 @@ class GraphExtractNet(nn.Module):
         return_nodes: node names to return features from (dict or list)
         squeeze_out: if only one output, and output in list format, flatten to single tensor
     """
-    def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True):
+    def __init__(
+            self,
+            model: nn.Module,
+            return_nodes: Union[Dict[str, str], List[str]],
+            squeeze_out: bool = True,
+    ):
         super().__init__()
         self.squeeze_out = squeeze_out
         self.graph_module = create_feature_extractor(model, return_nodes)
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index ce65ee4a..b57104ac 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -473,6 +473,7 @@ class VisionTransformer(nn.Module):
         self.no_embed_class = no_embed_class  # don't embed prefix positions (includes reg)
         self.dynamic_img_size = dynamic_img_size
         self.grad_checkpointing = False
+        self.feature_info = []
 
         embed_args = {}
         if dynamic_img_size:
@@ -520,6 +521,8 @@
                 mlp_layer=mlp_layer,
             )
             for i in range(depth)])
+        self.feature_info = [
+            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)]
         self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
 
         # Classifier Head
@@ -1770,9 +1773,7 @@ default_cfgs = generate_default_cfgs(default_cfgs)
 
 
 def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs) -> VisionTransformer:
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Vision Transformer models.')
-
+    out_indices = kwargs.pop('out_indices', 3)
     if 'flexi' in variant:
         # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed
         # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation.
@@ -1791,6 +1792,7 @@ def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs)
         pretrained,
         pretrained_filter_fn=_filter_fn,
         pretrained_strict=strict,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )
 
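With the builder dispatch, wrapper class, and ViT `feature_info` pieces in place, end-to-end usage should work along these lines (a sketch, not tested here; `out_indices=3` is the new default popped in `_create_vision_transformer`, i.e. the last three blocks, and shapes assume patch16 at 224x224):

```python
import torch
import timm

# features_only on a ViT now routes through FeatureGetterNet ('getter' feature_cls),
# which delegates to VisionTransformer.get_intermediate_layers(reshape=True).
model = timm.create_model('vit_base_patch16_224', features_only=True)
feats = model(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])         # 3 x torch.Size([1, 768, 14, 14]) (NCHW)
print(model.feature_info.reduction())   # [16, 16, 16], from the new feature_info
```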