Add AttentionExtract helper module

This commit is contained in:
Ross Wightman 2024-05-04 14:10:00 -07:00
parent 45b7ae8029
commit 07535f408a
5 changed files with 96 additions and 5 deletions

View File

@ -80,7 +80,7 @@ from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrai
set_pretrained_download_progress, set_pretrained_check_hash
from ._factory import create_model, parse_model_name, safe_model_name
from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet
from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, get_graph_node_names, \
register_notrace_module, is_notrace_module, get_notrace_modules, \
register_notrace_function, is_notrace_function, get_notrace_functions
from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint

View File

@ -158,7 +158,7 @@ class FeatureHooks:
def __init__(
self,
hooks: Sequence[str],
hooks: Sequence[Union[str, Dict]],
named_modules: dict,
out_map: Sequence[Union[int, str]] = None,
default_hook_type: str = 'forward',
@ -168,11 +168,13 @@ class FeatureHooks:
self._handles = []
modules = {k: v for k, v in named_modules}
for i, h in enumerate(hooks):
hook_name = h['module']
hook_name = h if isinstance(h, str) else h['module']
m = modules[hook_name]
hook_id = out_map[i] if out_map else hook_name
hook_fn = partial(self._collect_output_hook, hook_id)
hook_type = h.get('hook_type', default_hook_type)
hook_type = default_hook_type
if isinstance(h, dict):
hook_type = h.get('hook_type', default_hook_type)
if hook_type == 'forward_pre':
handle = m.register_forward_pre_hook(hook_fn)
elif hook_type == 'forward':

View File

@ -9,7 +9,9 @@ from torch import nn
from ._features import _get_feature_info, _get_return_layers
try:
# NOTE we wrap torchvision fns to use timm leaf / no trace definitions
from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
from torchvision.models.feature_extraction import get_graph_node_names as _get_graph_node_names
has_fx_feature_extraction = True
except ImportError:
has_fx_feature_extraction = False
@ -30,7 +32,7 @@ from timm.layers.norm_act import (
__all__ = ['register_notrace_module', 'is_notrace_module', 'get_notrace_modules',
'register_notrace_function', 'is_notrace_function', 'get_notrace_functions',
'create_feature_extractor', 'FeatureGraphNet', 'GraphExtractNet']
'create_feature_extractor', 'get_graph_node_names', 'FeatureGraphNet', 'GraphExtractNet']
# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
@ -92,6 +94,13 @@ def get_notrace_functions():
return list(_autowrap_functions)
def get_graph_node_names(model: nn.Module) -> Tuple[List[str], List[str]]:
return _get_graph_node_names(
model,
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
)
def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
return _create_feature_extractor(

View File

@ -1,4 +1,5 @@
from .agc import adaptive_clip_grad
from .attention_extract import AttentionExtract
from .checkpoint_saver import CheckpointSaver
from .clip_grad import dispatch_clip_grad
from .cuda import ApexScaler, NativeScaler

View File

@ -0,0 +1,79 @@
import fnmatch
from collections import OrderedDict
from typing import Union, Optional, List
import torch
class AttentionExtract(torch.nn.Module):
# defaults should cover a significant number of timm models with attention maps.
default_node_names = ['*attn.softmax']
default_module_names = ['*attn_drop']
def __init__(
self,
model: Union[torch.nn.Module],
names: Optional[List[str]] = None,
mode: str = 'eval',
method: str = 'fx',
hook_type: str = 'forward',
):
""" Extract attention maps (or other activations) from a model by name.
Args:
model: Instantiated model to extract from.
names: List of concrete or wildcard names to extract. Names are nodes for fx and modules for hooks.
mode: 'train' or 'eval' model mode.
method: 'fx' or 'hook' extraction method.
hook_type: 'forward' or 'forward_pre' hooks used.
"""
super().__init__()
assert mode in ('train', 'eval')
if mode == 'train':
model = model.train()
else:
model = model.eval()
assert method in ('fx', 'hook')
if method == 'fx':
# names are activation node names
from timm.models._features_fx import get_graph_node_names, GraphExtractNet
node_names = get_graph_node_names(model)[0 if mode == 'train' else 1]
matched = []
names = names or self.default_node_names
for n in names:
matched.extend(fnmatch.filter(node_names, n))
if not matched:
raise RuntimeError(f'No node names found matching {names}.')
self.model = GraphExtractNet(model, matched)
self.hooks = None
else:
# names are module names
assert hook_type in ('forward', 'forward_pre')
from timm.models._features import FeatureHooks
module_names = [n for n, m in model.named_modules()]
matched = []
names = names or self.default_module_names
for n in names:
matched.extend(fnmatch.filter(module_names, n))
if not matched:
raise RuntimeError(f'No module names found matching {names}.')
self.model = model
self.hooks = FeatureHooks(matched, model.named_modules(), default_hook_type=hook_type)
self.names = matched
self.mode = mode
self.method = method
def forward(self, x):
if self.hooks is not None:
self.model(x)
output = self.hooks.get_output(device=x.device)
else:
output = self.model(x)
output = OrderedDict(zip(self.names, output))
return output