Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Add AttentionExtract helper module

commit 07535f408a (parent 45b7ae8029)
timm/models/__init__.py
@@ -80,7 +80,7 @@ from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrai
     set_pretrained_download_progress, set_pretrained_check_hash
 from ._factory import create_model, parse_model_name, safe_model_name
 from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet
-from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
+from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, get_graph_node_names, \
     register_notrace_module, is_notrace_module, get_notrace_modules, \
     register_notrace_function, is_notrace_function, get_notrace_functions
 from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint
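With `get_graph_node_names` added to this import block, the new wrapper is re-exported from the public `timm.models` namespace alongside `create_feature_extractor`. A one-line sketch of the assumed import path once this commit is applied:

```python
# Assumes this commit is applied; callers no longer need to reach into the private module.
from timm.models import get_graph_node_names, create_feature_extractor
```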
timm/models/_features.py
@@ -158,7 +158,7 @@ class FeatureHooks:

     def __init__(
             self,
-            hooks: Sequence[str],
+            hooks: Sequence[Union[str, Dict]],
             named_modules: dict,
             out_map: Sequence[Union[int, str]] = None,
             default_hook_type: str = 'forward',
@@ -168,11 +168,13 @@ class FeatureHooks:
         self._handles = []
         modules = {k: v for k, v in named_modules}
         for i, h in enumerate(hooks):
-            hook_name = h['module']
+            hook_name = h if isinstance(h, str) else h['module']
             m = modules[hook_name]
             hook_id = out_map[i] if out_map else hook_name
             hook_fn = partial(self._collect_output_hook, hook_id)
-            hook_type = h.get('hook_type', default_hook_type)
+            hook_type = default_hook_type
+            if isinstance(h, dict):
+                hook_type = h.get('hook_type', default_hook_type)
             if hook_type == 'forward_pre':
                 handle = m.register_forward_pre_hook(hook_fn)
             elif hook_type == 'forward':
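After this change, `FeatureHooks` accepts plain module-name strings as well as the original dict entries of the form `{'module': ..., 'hook_type': ...}`. A minimal sketch of the string form, assuming a ViT-style model whose attention blocks expose an `attn.attn_drop` submodule; the model name, hook pattern, and the `set_fused_attn(False)` call are illustrative assumptions, not part of this commit:

```python
import torch
import timm
from timm.layers import set_fused_attn
from timm.models._features import FeatureHooks

set_fused_attn(False)  # assumption: fused SDPA would otherwise skip the attn_drop modules
model = timm.create_model('vit_base_patch16_224', pretrained=False).eval()

# Plain strings now work; previously each entry had to be a dict such as
# {'module': 'blocks.0.attn.attn_drop', 'hook_type': 'forward'}.
hook_names = [n for n, _ in model.named_modules() if n.endswith('attn.attn_drop')]
hooks = FeatureHooks(hook_names, model.named_modules(), default_hook_type='forward')

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    model(x)
outputs = hooks.get_output(device=x.device)  # dict keyed by module name
```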
timm/models/_features_fx.py
@@ -9,7 +9,9 @@ from torch import nn
 from ._features import _get_feature_info, _get_return_layers

 try:
+    # NOTE we wrap torchvision fns to use timm leaf / no trace definitions
     from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
+    from torchvision.models.feature_extraction import get_graph_node_names as _get_graph_node_names
     has_fx_feature_extraction = True
 except ImportError:
     has_fx_feature_extraction = False
@@ -30,7 +32,7 @@ from timm.layers.norm_act import (

 __all__ = ['register_notrace_module', 'is_notrace_module', 'get_notrace_modules',
            'register_notrace_function', 'is_notrace_function', 'get_notrace_functions',
-           'create_feature_extractor', 'FeatureGraphNet', 'GraphExtractNet']
+           'create_feature_extractor', 'get_graph_node_names', 'FeatureGraphNet', 'GraphExtractNet']


 # NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
@@ -92,6 +94,13 @@ def get_notrace_functions():
     return list(_autowrap_functions)


+def get_graph_node_names(model: nn.Module) -> Tuple[List[str], List[str]]:
+    return _get_graph_node_names(
+        model,
+        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
+    )
+
+
 def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
     assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
     return _create_feature_extractor(
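The new wrapper behaves like torchvision's `get_graph_node_names` but threads timm's leaf-module and autowrap-function definitions into the tracer, so timm models trace without hitting non-traceable ops. A rough usage sketch; the model name, the `'*attn.softmax'` pattern, and the `set_fused_attn(False)` call are assumptions about typical use:

```python
import fnmatch
import timm
from timm.layers import set_fused_attn
from timm.models import get_graph_node_names

set_fused_attn(False)  # assumption: fused SDPA would otherwise hide the softmax node
model = timm.create_model('vit_small_patch16_224', pretrained=False).eval()

# Returns (train-mode node names, eval-mode node names), mirroring torchvision.
train_nodes, eval_nodes = get_graph_node_names(model)
print(fnmatch.filter(eval_nodes, '*attn.softmax')[:3])
```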
timm/utils/__init__.py
@@ -1,4 +1,5 @@
 from .agc import adaptive_clip_grad
+from .attention_extract import AttentionExtract
 from .checkpoint_saver import CheckpointSaver
 from .clip_grad import dispatch_clip_grad
 from .cuda import ApexScaler, NativeScaler
timm/utils/attention_extract.py (new file, 79 lines)
@@ -0,0 +1,79 @@
+import fnmatch
+from collections import OrderedDict
+from typing import Union, Optional, List
+
+import torch
+
+
+class AttentionExtract(torch.nn.Module):
+    # defaults should cover a significant number of timm models with attention maps.
+    default_node_names = ['*attn.softmax']
+    default_module_names = ['*attn_drop']
+
+    def __init__(
+            self,
+            model: Union[torch.nn.Module],
+            names: Optional[List[str]] = None,
+            mode: str = 'eval',
+            method: str = 'fx',
+            hook_type: str = 'forward',
+    ):
+        """ Extract attention maps (or other activations) from a model by name.
+
+        Args:
+            model: Instantiated model to extract from.
+            names: List of concrete or wildcard names to extract. Names are nodes for fx and modules for hooks.
+            mode: 'train' or 'eval' model mode.
+            method: 'fx' or 'hook' extraction method.
+            hook_type: 'forward' or 'forward_pre' hooks used.
+        """
+        super().__init__()
+        assert mode in ('train', 'eval')
+        if mode == 'train':
+            model = model.train()
+        else:
+            model = model.eval()
+
+        assert method in ('fx', 'hook')
+        if method == 'fx':
+            # names are activation node names
+            from timm.models._features_fx import get_graph_node_names, GraphExtractNet
+
+            node_names = get_graph_node_names(model)[0 if mode == 'train' else 1]
+            matched = []
+            names = names or self.default_node_names
+            for n in names:
+                matched.extend(fnmatch.filter(node_names, n))
+            if not matched:
+                raise RuntimeError(f'No node names found matching {names}.')
+
+            self.model = GraphExtractNet(model, matched)
+            self.hooks = None
+        else:
+            # names are module names
+            assert hook_type in ('forward', 'forward_pre')
+            from timm.models._features import FeatureHooks
+
+            module_names = [n for n, m in model.named_modules()]
+            matched = []
+            names = names or self.default_module_names
+            for n in names:
+                matched.extend(fnmatch.filter(module_names, n))
+            if not matched:
+                raise RuntimeError(f'No module names found matching {names}.')
+
+            self.model = model
+            self.hooks = FeatureHooks(matched, model.named_modules(), default_hook_type=hook_type)
+
+        self.names = matched
+        self.mode = mode
+        self.method = method
+
+    def forward(self, x):
+        if self.hooks is not None:
+            self.model(x)
+            output = self.hooks.get_output(device=x.device)
+        else:
+            output = self.model(x)
+            output = OrderedDict(zip(self.names, output))
+        return output
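A minimal end-to-end sketch of the new helper; the model name, input size, and the `set_fused_attn(False)` call are assumptions about typical use, not part of the commit:

```python
import torch
import timm
from timm.layers import set_fused_attn
from timm.utils import AttentionExtract

set_fused_attn(False)  # assumption: ensures '*attn.softmax' nodes exist to match against
model = timm.create_model('vit_base_patch16_224', pretrained=False)

extractor = AttentionExtract(model, method='fx')       # graph-node matching, default '*attn.softmax'
# extractor = AttentionExtract(model, method='hook')   # module-name matching, default '*attn_drop'

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    maps = extractor(x)  # OrderedDict: matched name -> attention tensor

for name, attn in maps.items():
    print(name, tuple(attn.shape))  # e.g. blocks.0.attn.softmax (1, heads, tokens, tokens)
```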