Exploring vit features_only using get_intermediate_layers() as per #2131

Ross Wightman 2024-04-07 11:24:45 -07:00
parent 59b3d86c1d
commit 5fdc0b4e93
4 changed files with 94 additions and 21 deletions

timm/models/_builder.py

@@ -7,7 +7,7 @@ from typing import Optional, Dict, Callable, Any, Tuple
from torch import nn as nn
from torch.hub import load_state_dict_from_url
from timm.models._features import FeatureListNet, FeatureHookNet
from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet
from timm.models._features_fx import FeatureGraphNet
from timm.models._helpers import load_state_dict
from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
@@ -428,8 +428,12 @@ def build_model_with_cfg(
feature_cls = feature_cls.lower()
if 'hook' in feature_cls:
feature_cls = FeatureHookNet
elif feature_cls == 'dict':
feature_cls = FeatureDictNet
elif feature_cls == 'fx':
feature_cls = FeatureGraphNet
elif feature_cls == 'getter':
feature_cls = FeatureGetterNet
else:
assert False, f'Unknown feature class {feature_cls}'
model = feature_cls(model, **feature_cfg)
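
For reference, a standalone restatement of the string-to-wrapper dispatch above (a sketch mirroring this branch's build_model_with_cfg, not part of the commit; the library itself uses an assert rather than a raise):

from timm.models._features import FeatureDictNet, FeatureHookNet, FeatureGetterNet
from timm.models._features_fx import FeatureGraphNet

def _resolve_feature_cls(name: str):
    # Mirror of the elif chain in build_model_with_cfg on this branch.
    name = name.lower()
    if 'hook' in name:
        return FeatureHookNet       # forward-hook based extraction
    elif name == 'dict':
        return FeatureDictNet       # dict-of-tensors output
    elif name == 'fx':
        return FeatureGraphNet      # torch.fx graph-based extraction
    elif name == 'getter':
        return FeatureGetterNet     # new: delegates to model.get_intermediate_layers()
    raise ValueError(f'Unknown feature class {name}')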

timm/models/_features.py

@@ -11,7 +11,7 @@ Hacked together by / Copyright 2020 Ross Wightman
from collections import OrderedDict, defaultdict
from copy import deepcopy
from functools import partial
from typing import Dict, List, Sequence, Tuple, Union
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
@@ -23,9 +23,24 @@ from timm.layers import Format
__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
if isinstance(x, int):
# if indices is an int, take last N features
return tuple(range(-x, 0))
return tuple(x)
OutIndicesT = Union[int, Tuple[int, ...]]
class FeatureInfo:
def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
def __init__(
self,
feature_info: List[Dict],
out_indices: OutIndicesT,
):
out_indices = _out_indices_as_tuple(out_indices)
prev_reduction = 1
for i, fi in enumerate(feature_info):
# sanity check the mandatory fields, there may be additional fields depending on the model
@@ -37,14 +52,15 @@ class FeatureInfo:
self.out_indices = out_indices
self.info = feature_info
def from_other(self, out_indices: Tuple[int]):
def from_other(self, out_indices: OutIndicesT):
out_indices = _out_indices_as_tuple(out_indices)
return FeatureInfo(deepcopy(self.info), out_indices)
def get(self, key, idx=None):
def get(self, key: str, idx: Optional[Union[int, List[int]]] = None):
""" Get value by key at specified index (indices)
if idx == None, returns value for key at each output index
if idx is an integer, return value for that feature module index (ignoring output indices)
if idx is a list/tupple, return value for each module index (ignoring output indices)
if idx is a list/tuple, return value for each module index (ignoring output indices)
"""
if idx is None:
return [self.info[i][key] for i in self.out_indices]
@@ -53,7 +69,7 @@ class FeatureInfo:
else:
return self.info[idx][key]
def get_dicts(self, keys=None, idx=None):
def get_dicts(self, keys: Optional[List[str]] = None, idx: Optional[Union[int, List[int]]] = None):
""" return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
"""
if idx is None:
@@ -66,17 +82,17 @@
else:
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
def channels(self, idx=None):
def channels(self, idx: Optional[Union[int, List[int]]] = None):
""" feature channels accessor
"""
return self.get('num_chs', idx)
def reduction(self, idx=None):
def reduction(self, idx: Optional[Union[int, List[int]]] = None):
""" feature reduction (output stride) accessor
"""
return self.get('reduction', idx)
def module_name(self, idx=None):
def module_name(self, idx: Optional[Union[int, List[int]]] = None):
""" feature module name accessor
"""
return self.get('module', idx)
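
To make the accessors and the new int-valued out_indices concrete, a small illustration (assuming this branch; the metadata values are made up to mirror a ViT-Base/16):

from timm.models._features import FeatureInfo

# 12 hypothetical transformer blocks, all 768 channels at stride 16.
info = FeatureInfo(
    feature_info=[dict(module=f'blocks.{i}', num_chs=768, reduction=16) for i in range(12)],
    out_indices=3,                    # an int selects the last N features: (-3, -2, -1)
)
print(info.out_indices)               # (-3, -2, -1)
print(info.channels())                # [768, 768, 768]
print(info.reduction())               # [16, 16, 16]
print(info.module_name())             # ['blocks.9', 'blocks.10', 'blocks.11']
print(info.get_dicts(keys=['module', 'reduction']))
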
@@ -146,7 +162,7 @@ def _module_list(module, flatten_sequential=False):
return ml
def _get_feature_info(net, out_indices):
def _get_feature_info(net, out_indices: OutIndicesT):
feature_info = getattr(net, 'feature_info')
if isinstance(feature_info, FeatureInfo):
return feature_info.from_other(out_indices)
@@ -182,7 +198,7 @@ class FeatureDictNet(nn.ModuleDict):
def __init__(
self,
model: nn.Module,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
out_indices: OutIndicesT = (0, 1, 2, 3, 4),
out_map: Sequence[Union[int, str]] = None,
output_fmt: str = 'NCHW',
feature_concat: bool = False,
@@ -257,7 +273,7 @@ class FeatureListNet(FeatureDictNet):
def __init__(
self,
model: nn.Module,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
out_indices: OutIndicesT = (0, 1, 2, 3, 4),
output_fmt: str = 'NCHW',
feature_concat: bool = False,
flatten_sequential: bool = False,
@@ -298,8 +314,8 @@ class FeatureHookNet(nn.ModuleDict):
def __init__(
self,
model: nn.Module,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
out_map: Sequence[Union[int, str]] = None,
out_indices: OutIndicesT = (0, 1, 2, 3, 4),
out_map: Optional[Sequence[Union[int, str]]] = None,
return_dict: bool = False,
output_fmt: str = 'NCHW',
no_rewrite: bool = False,
@@ -366,3 +382,44 @@ class FeatureHookNet(nn.ModuleDict):
x = module(x)
out = self.hooks.get_output(x.device)
return out if self.return_dict else list(out.values())
class FeatureGetterNet(nn.ModuleDict):
""" FeatureGetterNet
Wrap models with a feature getter method, like 'get_intermediate_layers'
"""
def __init__(
self,
model: nn.Module,
out_indices: OutIndicesT = 4,
out_map: Optional[Sequence[Union[int, str]]] = None,
return_dict: bool = False,
output_fmt: str = 'NCHW',
):
super().__init__()
self.model = model
self.feature_info = _get_feature_info(model, out_indices)
self.out_indices = out_indices
self.out_map = out_map
self.return_dict = return_dict
self.output_fmt = output_fmt
def forward(self, *args, **kwargs):
"""
def get_intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1,
reshape: bool = False,
return_prefix_tokens: bool = False,
norm: bool = False,
"""
out = self.model.get_intermediate_layers(
*args,
n=self.out_indices,
reshape=True,
**kwargs,
)
return out
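
A quick sanity-check sketch of the new wrapper in isolation (assumes this branch, where VisionTransformer exposes both feature_info and get_intermediate_layers(); the model name is just an example):

import torch
import timm
from timm.models._features import FeatureGetterNet

vit = timm.create_model('vit_small_patch16_224', pretrained=False)
wrapped = FeatureGetterNet(vit, out_indices=3)      # last 3 blocks, NCHW via reshape=True

feats = wrapped(torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)                                   # torch.Size([1, 384, 14, 14]) each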

timm/models/_features_fx.py

@@ -1,7 +1,7 @@
""" PyTorch FX Based Feature Extraction Helpers
Using https://pytorch.org/vision/stable/feature_extraction.html
"""
from typing import Callable, List, Dict, Union, Type
from typing import Callable, Dict, List, Optional, Union, Tuple, Type
import torch
from torch import nn
@@ -103,7 +103,12 @@ def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str
class FeatureGraphNet(nn.Module):
""" A FX Graph based feature extractor that works with the model feature_info metadata
"""
def __init__(self, model, out_indices, out_map=None):
def __init__(
self,
model: nn.Module,
out_indices: Tuple[int, ...],
out_map: Optional[Dict] = None,
):
super().__init__()
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
self.feature_info = _get_feature_info(model, out_indices)
@@ -129,7 +134,12 @@ class GraphExtractNet(nn.Module):
return_nodes: node names to return features from (dict or list)
squeeze_out: if only one output, and output in list format, flatten to single tensor
"""
def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True):
def __init__(
self,
model: nn.Module,
return_nodes: Union[Dict[str, str], List[str]],
squeeze_out: bool = True,
):
super().__init__()
self.squeeze_out = squeeze_out
self.graph_module = create_feature_extractor(model, return_nodes)
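
Not specific to this commit, but for context on the FX-based extractor being annotated here, a brief usage sketch (the model and node names are illustrative):

import torch
import timm
from timm.models._features_fx import GraphExtractNet

backbone = timm.create_model('resnet18', pretrained=False)
extractor = GraphExtractNet(backbone, return_nodes=['layer2', 'layer4'], squeeze_out=True)
outs = extractor(torch.randn(1, 3, 224, 224))
print([o.shape for o in outs])                       # stride-8 and stride-32 feature maps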

timm/models/vision_transformer.py

@@ -473,6 +473,7 @@ class VisionTransformer(nn.Module):
self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg)
self.dynamic_img_size = dynamic_img_size
self.grad_checkpointing = False
self.feature_info = []
embed_args = {}
if dynamic_img_size:
@@ -520,6 +521,8 @@
mlp_layer=mlp_layer,
)
for i in range(depth)])
self.feature_info = [
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)]
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
# Classifier Head
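
For a sense of what this populates, the resulting per-block metadata for a hypothetical ViT-Base/16 (embed_dim=768, patch_size=16, depth=12) would be:

# Illustration only, values correspond to the hypothetical ViT-B/16 above.
feature_info = [dict(module=f'blocks.{i}', num_chs=768, reduction=16) for i in range(12)]
# feature_info[-1] -> {'module': 'blocks.11', 'num_chs': 768, 'reduction': 16}
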
@@ -1770,9 +1773,7 @@ default_cfgs = generate_default_cfgs(default_cfgs)
def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs) -> VisionTransformer:
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
out_indices = kwargs.pop('out_indices', 3)
if 'flexi' in variant:
# FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed
# interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation.
@@ -1791,6 +1792,7 @@ def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs)
pretrained,
pretrained_filter_fn=_filter_fn,
pretrained_strict=strict,
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,
)
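
Putting the four files together, ViT models can now be created with features_only=True on this branch; an end-to-end sketch (model name, input size, and out_indices are examples):

import torch
import timm

model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=False,
    features_only=True,        # now wraps the ViT in FeatureGetterNet instead of raising
    out_indices=3,             # last 3 blocks (the default set above)
)
print(model.feature_info.channels())    # [768, 768, 768]
print(model.feature_info.reduction())   # [16, 16, 16]

feats = model(torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)                       # torch.Size([1, 768, 14, 14]) per selected block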