Add AttentionExtract helper module

Repository: https://github.com/huggingface/pytorch-image-models.git
commit 07535f408a (parent 45b7ae8029)
@@ -80,7 +80,7 @@ from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrained
     set_pretrained_download_progress, set_pretrained_check_hash
 from ._factory import create_model, parse_model_name, safe_model_name
 from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet
-from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
+from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, get_graph_node_names, \
     register_notrace_module, is_notrace_module, get_notrace_modules, \
     register_notrace_function, is_notrace_function, get_notrace_functions
 from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint
@@ -158,7 +158,7 @@ class FeatureHooks:

     def __init__(
             self,
-            hooks: Sequence[str],
+            hooks: Sequence[Union[str, Dict]],
             named_modules: dict,
             out_map: Sequence[Union[int, str]] = None,
             default_hook_type: str = 'forward',
@@ -168,11 +168,13 @@ class FeatureHooks:
         self._handles = []
         modules = {k: v for k, v in named_modules}
         for i, h in enumerate(hooks):
-            hook_name = h['module']
+            hook_name = h if isinstance(h, str) else h['module']
             m = modules[hook_name]
             hook_id = out_map[i] if out_map else hook_name
             hook_fn = partial(self._collect_output_hook, hook_id)
-            hook_type = h.get('hook_type', default_hook_type)
+            hook_type = default_hook_type
+            if isinstance(h, dict):
+                hook_type = h.get('hook_type', default_hook_type)
             if hook_type == 'forward_pre':
                 handle = m.register_forward_pre_hook(hook_fn)
             elif hook_type == 'forward':
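
With this change, FeatureHooks entries may be plain module-name strings or dicts carrying an explicit 'hook_type'. A minimal usage sketch of both forms (the model and module names below are illustrative, not part of this commit):

import torch
import timm
from timm.models._features import FeatureHooks

model = timm.create_model('resnet18')
hooks = [
    'layer1.0.act2',                                        # plain string, uses default_hook_type
    {'module': 'layer2.0.act2', 'hook_type': 'forward'},    # dict form with explicit hook type
]
fh = FeatureHooks(hooks, model.named_modules())
_ = model(torch.randn(1, 3, 224, 224))
out = fh.get_output(device='cpu')   # OrderedDict keyed by module name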
@@ -9,7 +9,9 @@ from torch import nn
 from ._features import _get_feature_info, _get_return_layers

 try:
+    # NOTE we wrap torchvision fns to use timm leaf / no trace definitions
     from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
+    from torchvision.models.feature_extraction import get_graph_node_names as _get_graph_node_names
     has_fx_feature_extraction = True
 except ImportError:
     has_fx_feature_extraction = False
@@ -30,7 +32,7 @@ from timm.layers.norm_act import (

 __all__ = ['register_notrace_module', 'is_notrace_module', 'get_notrace_modules',
            'register_notrace_function', 'is_notrace_function', 'get_notrace_functions',
-           'create_feature_extractor', 'FeatureGraphNet', 'GraphExtractNet']
+           'create_feature_extractor', 'get_graph_node_names', 'FeatureGraphNet', 'GraphExtractNet']


 # NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
@@ -92,6 +94,13 @@ def get_notrace_functions():
     return list(_autowrap_functions)


+def get_graph_node_names(model: nn.Module) -> Tuple[List[str], List[str]]:
+    return _get_graph_node_names(
+        model,
+        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
+    )
+
+
 def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
     assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
     return _create_feature_extractor(
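
The new get_graph_node_names wrapper returns torchvision's (train, eval) node-name lists while honoring timm's registered leaf modules and autowrapped functions. A usage sketch (the model name and wildcard pattern are illustrative):

import fnmatch
import timm
from timm.models import get_graph_node_names   # exported via the __init__ change above

model = timm.create_model('vit_small_patch16_224')
train_nodes, eval_nodes = get_graph_node_names(model)
attn_nodes = fnmatch.filter(eval_nodes, '*attn.softmax')   # same matching AttentionExtract performs internally
print(attn_nodes[:3])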
@@ -1,4 +1,5 @@
 from .agc import adaptive_clip_grad
+from .attention_extract import AttentionExtract
 from .checkpoint_saver import CheckpointSaver
 from .clip_grad import dispatch_clip_grad
 from .cuda import ApexScaler, NativeScaler
timm/utils/attention_extract.py (new file, 79 lines)
@@ -0,0 +1,79 @@
+import fnmatch
+from collections import OrderedDict
+from typing import Union, Optional, List
+
+import torch
+
+
+class AttentionExtract(torch.nn.Module):
+    # defaults should cover a significant number of timm models with attention maps.
+    default_node_names = ['*attn.softmax']
+    default_module_names = ['*attn_drop']
+
+    def __init__(
+            self,
+            model: Union[torch.nn.Module],
+            names: Optional[List[str]] = None,
+            mode: str = 'eval',
+            method: str = 'fx',
+            hook_type: str = 'forward',
+    ):
+        """ Extract attention maps (or other activations) from a model by name.
+
+        Args:
+            model: Instantiated model to extract from.
+            names: List of concrete or wildcard names to extract. Names are nodes for fx and modules for hooks.
+            mode: 'train' or 'eval' model mode.
+            method: 'fx' or 'hook' extraction method.
+            hook_type: 'forward' or 'forward_pre' hooks used.
+        """
+        super().__init__()
+        assert mode in ('train', 'eval')
+        if mode == 'train':
+            model = model.train()
+        else:
+            model = model.eval()
+
+        assert method in ('fx', 'hook')
+        if method == 'fx':
+            # names are activation node names
+            from timm.models._features_fx import get_graph_node_names, GraphExtractNet
+
+            node_names = get_graph_node_names(model)[0 if mode == 'train' else 1]
+            matched = []
+            names = names or self.default_node_names
+            for n in names:
+                matched.extend(fnmatch.filter(node_names, n))
+            if not matched:
+                raise RuntimeError(f'No node names found matching {names}.')
+
+            self.model = GraphExtractNet(model, matched)
+            self.hooks = None
+        else:
+            # names are module names
+            assert hook_type in ('forward', 'forward_pre')
+            from timm.models._features import FeatureHooks
+
+            module_names = [n for n, m in model.named_modules()]
+            matched = []
+            names = names or self.default_module_names
+            for n in names:
+                matched.extend(fnmatch.filter(module_names, n))
+            if not matched:
+                raise RuntimeError(f'No module names found matching {names}.')
+
+            self.model = model
+            self.hooks = FeatureHooks(matched, model.named_modules(), default_hook_type=hook_type)
+
+        self.names = matched
+        self.mode = mode
+        self.method = method
+
+    def forward(self, x):
+        if self.hooks is not None:
+            self.model(x)
+            output = self.hooks.get_output(device=x.device)
+        else:
+            output = self.model(x)
+            output = OrderedDict(zip(self.names, output))
+        return output
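
Usage sketch for the new helper (not part of the diff; the model name and input size are illustrative, and the default '*attn.softmax' pattern assumes the traced model exposes explicit softmax nodes):

import torch
import timm
from timm.utils import AttentionExtract

model = timm.create_model('vit_base_patch16_224')
extractor = AttentionExtract(model, method='fx')       # or method='hook' to match '*attn_drop' modules

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    attn_maps = extractor(x)                           # OrderedDict: matched name -> attention tensor
for name, attn in attn_maps.items():
    print(name, tuple(attn.shape))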