mirror of https://github.com/huggingface/pytorch-image-models.git
Exploring vit features_only using get_intermediate_layers() as per #2131
parent 59b3d86c1d
commit 5fdc0b4e93
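For context, a rough sketch of the usage this commit is exploring (the model name and shapes are illustrative, not taken from the commit):

# Sketch only: assumes this commit's 'getter' feature path is active.
import torch
import timm

# features_only on a ViT now routes through FeatureGetterNet and
# get_intermediate_layers(); an int out_indices selects the last N blocks.
model = timm.create_model('vit_base_patch16_224', features_only=True, out_indices=3)
feats = model(torch.randn(2, 3, 224, 224))
for f in feats:
    print(f.shape)  # reshaped NCHW maps, e.g. (2, 768, 14, 14) for patch16 at 224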
timm/models/_builder.py
@@ -7,7 +7,7 @@ from typing import Optional, Dict, Callable, Any, Tuple
 from torch import nn as nn
 from torch.hub import load_state_dict_from_url

-from timm.models._features import FeatureListNet, FeatureHookNet
+from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet
 from timm.models._features_fx import FeatureGraphNet
 from timm.models._helpers import load_state_dict
 from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
@@ -428,8 +428,12 @@ def build_model_with_cfg(
                 feature_cls = feature_cls.lower()
                 if 'hook' in feature_cls:
                     feature_cls = FeatureHookNet
+                elif feature_cls == 'dict':
+                    feature_cls = FeatureDictNet
                 elif feature_cls == 'fx':
                     feature_cls = FeatureGraphNet
+                elif feature_cls == 'getter':
+                    feature_cls = FeatureGetterNet
                 else:
                     assert False, f'Unknown feature class {feature_cls}'
         model = feature_cls(model, **feature_cfg)
timm/models/_features.py
@@ -11,7 +11,7 @@ Hacked together by / Copyright 2020 Ross Wightman
 from collections import OrderedDict, defaultdict
 from copy import deepcopy
 from functools import partial
-from typing import Dict, List, Sequence, Tuple, Union
+from typing import Dict, List, Optional, Sequence, Tuple, Union

 import torch
 import torch.nn as nn
@@ -23,9 +23,24 @@ from timm.layers import Format
 __all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']


+def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
+    if isinstance(x, int):
+        # if indices is an int, take last N features
+        return tuple(range(-x, 0))
+    return tuple(x)
+
+
+OutIndicesT = Union[int, Tuple[int, ...]]
+
+
 class FeatureInfo:

-    def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
+    def __init__(
+            self,
+            feature_info: List[Dict],
+            out_indices: OutIndicesT,
+    ):
+        out_indices = _out_indices_as_tuple(out_indices)
         prev_reduction = 1
         for i, fi in enumerate(feature_info):
             # sanity check the mandatory fields, there may be additional fields depending on the model
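The _out_indices_as_tuple helper above normalizes an int to "last N" negative indices, e.g.:

# Behavior implied by the diff above:
assert _out_indices_as_tuple(3) == (-3, -2, -1)       # int -> last 3 features
assert _out_indices_as_tuple((0, 2, 4)) == (0, 2, 4)  # tuple passes through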
@@ -37,14 +52,15 @@ class FeatureInfo:
         self.out_indices = out_indices
         self.info = feature_info

-    def from_other(self, out_indices: Tuple[int]):
+    def from_other(self, out_indices: OutIndicesT):
+        out_indices = _out_indices_as_tuple(out_indices)
         return FeatureInfo(deepcopy(self.info), out_indices)

-    def get(self, key, idx=None):
+    def get(self, key: str, idx: Optional[Union[int, List[int]]] = None):
         """ Get value by key at specified index (indices)
         if idx == None, returns value for key at each output index
         if idx is an integer, return value for that feature module index (ignoring output indices)
-        if idx is a list/tupple, return value for each module index (ignoring output indices)
+        if idx is a list/tuple, return value for each module index (ignoring output indices)
         """
         if idx is None:
             return [self.info[i][key] for i in self.out_indices]
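A hypothetical two-entry example of the idx rules documented in get() above:

# Hypothetical metadata; field names follow the mandatory keys noted earlier.
fi = FeatureInfo(
    [dict(module='blocks.0', num_chs=96, reduction=4),
     dict(module='blocks.1', num_chs=192, reduction=8)],
    out_indices=2,  # int -> (-2, -1): both entries here
)
print(fi.get('num_chs'))          # [96, 192] (one value per out_index)
print(fi.get('num_chs', 1))       # 192 (single module index)
print(fi.get('num_chs', [0, 1]))  # [96, 192] (list of module indices)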
@@ -53,7 +69,7 @@ class FeatureInfo:
         else:
             return self.info[idx][key]

-    def get_dicts(self, keys=None, idx=None):
+    def get_dicts(self, keys: Optional[List[str]] = None, idx: Optional[Union[int, List[int]]] = None):
         """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
         """
         if idx is None:
@@ -66,17 +82,17 @@ class FeatureInfo:
         else:
             return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}

-    def channels(self, idx=None):
+    def channels(self, idx: Optional[Union[int, List[int]]] = None):
         """ feature channels accessor
         """
         return self.get('num_chs', idx)

-    def reduction(self, idx=None):
+    def reduction(self, idx: Optional[Union[int, List[int]]] = None):
         """ feature reduction (output stride) accessor
         """
         return self.get('reduction', idx)

-    def module_name(self, idx=None):
+    def module_name(self, idx: Optional[Union[int, List[int]]] = None):
         """ feature module name accessor
         """
         return self.get('module', idx)
@@ -146,7 +162,7 @@ def _module_list(module, flatten_sequential=False):
     return ml


-def _get_feature_info(net, out_indices):
+def _get_feature_info(net, out_indices: OutIndicesT):
     feature_info = getattr(net, 'feature_info')
     if isinstance(feature_info, FeatureInfo):
         return feature_info.from_other(out_indices)
@@ -182,7 +198,7 @@ class FeatureDictNet(nn.ModuleDict):
     def __init__(
             self,
             model: nn.Module,
-            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+            out_indices: OutIndicesT = (0, 1, 2, 3, 4),
             out_map: Sequence[Union[int, str]] = None,
             output_fmt: str = 'NCHW',
             feature_concat: bool = False,
@@ -257,7 +273,7 @@ class FeatureListNet(FeatureDictNet):
     def __init__(
             self,
             model: nn.Module,
-            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+            out_indices: OutIndicesT = (0, 1, 2, 3, 4),
             output_fmt: str = 'NCHW',
             feature_concat: bool = False,
             flatten_sequential: bool = False,
@@ -298,8 +314,8 @@ class FeatureHookNet(nn.ModuleDict):
     def __init__(
             self,
             model: nn.Module,
-            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
-            out_map: Sequence[Union[int, str]] = None,
+            out_indices: OutIndicesT = (0, 1, 2, 3, 4),
+            out_map: Optional[Sequence[Union[int, str]]] = None,
             return_dict: bool = False,
             output_fmt: str = 'NCHW',
             no_rewrite: bool = False,
@@ -366,3 +382,44 @@ class FeatureHookNet(nn.ModuleDict):
                 x = module(x)
         out = self.hooks.get_output(x.device)
         return out if self.return_dict else list(out.values())
+
+
+class FeatureGetterNet(nn.ModuleDict):
+    """ FeatureGetterNet
+
+    Wrap models with a feature getter method, like 'get_intermediate_layers'
+
+    """
+    def __init__(
+            self,
+            model: nn.Module,
+            out_indices: OutIndicesT = 4,
+            out_map: Optional[Sequence[Union[int, str]]] = None,
+            return_dict: bool = False,
+            output_fmt: str = 'NCHW',
+    ):
+        super().__init__()
+        self.model = model
+        self.feature_info = _get_feature_info(model, out_indices)
+        self.out_indices = out_indices
+        self.out_map = out_map
+        self.return_dict = return_dict
+        self.output_fmt = output_fmt
+
+    def forward(self, *args, **kwargs):
+        """
+        def get_intermediate_layers(
+            self,
+            x: torch.Tensor,
+            n: Union[int, Sequence] = 1,
+            reshape: bool = False,
+            return_prefix_tokens: bool = False,
+            norm: bool = False,
+        """
+        out = self.model.get_intermediate_layers(
+            *args,
+            n=self.out_indices,
+            reshape=True,
+            **kwargs,
+        )
+        return out
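A minimal sketch of driving the new wrapper directly (assumes a timm ViT that exposes get_intermediate_layers and the feature_info metadata added later in this commit):

import torch
import timm
from timm.models._features import FeatureGetterNet

vit = timm.create_model('vit_small_patch16_224')
getter = FeatureGetterNet(vit, out_indices=2)  # last 2 blocks
feats = getter(torch.randn(1, 3, 224, 224))    # forwards to get_intermediate_layers
print([f.shape for f in feats])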
timm/models/_features_fx.py
@@ -1,7 +1,7 @@
 """ PyTorch FX Based Feature Extraction Helpers
 Using https://pytorch.org/vision/stable/feature_extraction.html
 """
-from typing import Callable, List, Dict, Union, Type
+from typing import Callable, Dict, List, Optional, Union, Tuple, Type

 import torch
 from torch import nn
@@ -103,7 +103,12 @@ def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str
 class FeatureGraphNet(nn.Module):
     """ A FX Graph based feature extractor that works with the model feature_info metadata
     """
-    def __init__(self, model, out_indices, out_map=None):
+    def __init__(
+            self,
+            model: nn.Module,
+            out_indices: Tuple[int, ...],
+            out_map: Optional[Dict] = None,
+    ):
         super().__init__()
         assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
         self.feature_info = _get_feature_info(model, out_indices)
@@ -129,7 +134,12 @@ class GraphExtractNet(nn.Module):
         return_nodes: node names to return features from (dict or list)
         squeeze_out: if only one output, and output in list format, flatten to single tensor
     """
-    def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True):
+    def __init__(
+            self,
+            model: nn.Module,
+            return_nodes: Union[Dict[str, str], List[str]],
+            squeeze_out: bool = True,
+    ):
         super().__init__()
         self.squeeze_out = squeeze_out
         self.graph_module = create_feature_extractor(model, return_nodes)
timm/models/vision_transformer.py
@@ -473,6 +473,7 @@ class VisionTransformer(nn.Module):
         self.no_embed_class = no_embed_class  # don't embed prefix positions (includes reg)
         self.dynamic_img_size = dynamic_img_size
         self.grad_checkpointing = False
+        self.feature_info = []

         embed_args = {}
         if dynamic_img_size:
@@ -520,6 +521,8 @@ class VisionTransformer(nn.Module):
                 mlp_layer=mlp_layer,
             )
             for i in range(depth)])
+        self.feature_info = [
+            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size) for i in range(depth)]
         self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

         # Classifier Head
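A sketch of the metadata this populates for a hypothetical vit_base-style config (embed_dim=768, patch_size=16, depth=12):

embed_dim, patch_size, depth = 768, 16, 12
feature_info = [
    dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=patch_size)
    for i in range(depth)]
print(feature_info[-1])  # {'module': 'blocks.11', 'num_chs': 768, 'reduction': 16}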
@@ -1770,9 +1773,7 @@ default_cfgs = generate_default_cfgs(default_cfgs)


 def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs) -> VisionTransformer:
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Vision Transformer models.')
-
+    out_indices = kwargs.pop('out_indices', 3)
     if 'flexi' in variant:
         # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed
         # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation.
@@ -1791,6 +1792,7 @@ def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs)
         pretrained,
         pretrained_filter_fn=_filter_fn,
         pretrained_strict=strict,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )
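End to end, out_indices flows from the create_model kwargs into feature_cfg and on to get_intermediate_layers via the getter wrapper; a hedged example with explicit block indices (values illustrative):

import torch
import timm

model = timm.create_model(
    'vit_base_patch16_224', features_only=True, out_indices=(2, 5, 8, 11))
feats = model(torch.randn(1, 3, 224, 224))
print(model.feature_info.channels())   # e.g. [768, 768, 768, 768]
print(model.feature_info.reduction())  # e.g. [16, 16, 16, 16]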