From e748805be31318da1a0e34b61294704666f50397 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 22 May 2024 14:33:39 -0700
Subject: [PATCH] Add regex matching support to AttentionExtract. Add
 return_dict support to graph extractors and use returned output in
 AttentionExtractor

---
 timm/models/_features_fx.py     | 19 ++++++++++++++-----
 timm/utils/attention_extract.py | 18 ++++++++++++++----
 2 files changed, 28 insertions(+), 9 deletions(-)

diff --git a/timm/models/_features_fx.py b/timm/models/_features_fx.py
index b775871c..3a276046 100644
--- a/timm/models/_features_fx.py
+++ b/timm/models/_features_fx.py
@@ -118,6 +118,7 @@ class FeatureGraphNet(nn.Module):
             out_indices: Tuple[int, ...],
             out_map: Optional[Dict] = None,
             output_fmt: str = 'NCHW',
+            return_dict: bool = False,
     ):
         super().__init__()
         assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
@@ -127,9 +128,13 @@ class FeatureGraphNet(nn.Module):
         self.output_fmt = Format(output_fmt)
         return_nodes = _get_return_layers(self.feature_info, out_map)
         self.graph_module = create_feature_extractor(model, return_nodes)
+        self.return_dict = return_dict
 
     def forward(self, x):
-        return list(self.graph_module(x).values())
+        out = self.graph_module(x)
+        if self.return_dict:
+            return out
+        return list(out.values())
 
 
 class GraphExtractNet(nn.Module):
@@ -144,19 +149,23 @@ class GraphExtractNet(nn.Module):
         model: model to extract features from
         return_nodes: node names to return features from (dict or list)
         squeeze_out: if only one output, and output in list format, flatten to single tensor
+        return_dict: return as dictionary from extractor with node names as keys, ignores squeeze_out arg
     """
     def __init__(
             self,
             model: nn.Module,
             return_nodes: Union[Dict[str, str], List[str]],
             squeeze_out: bool = True,
+            return_dict: bool = False,
     ):
         super().__init__()
         self.squeeze_out = squeeze_out
         self.graph_module = create_feature_extractor(model, return_nodes)
+        self.return_dict = return_dict
 
     def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
-        out = list(self.graph_module(x).values())
-        if self.squeeze_out and len(out) == 1:
-            return out[0]
-        return out
+        out = self.graph_module(x)
+        if self.return_dict:
+            return out
+        out = list(out.values())
+        return out[0] if self.squeeze_out and len(out) == 1 else out
diff --git a/timm/utils/attention_extract.py b/timm/utils/attention_extract.py
index da0913f0..e813d42a 100644
--- a/timm/utils/attention_extract.py
+++ b/timm/utils/attention_extract.py
@@ -1,4 +1,5 @@
 import fnmatch
+import re
 from collections import OrderedDict
 from typing import Union, Optional, List
 
@@ -17,6 +18,7 @@ class AttentionExtract(torch.nn.Module):
             mode: str = 'eval',
             method: str = 'fx',
             hook_type: str = 'forward',
+            use_regex: bool = False,
     ):
         """ Extract attention maps (or other activations) from a model by name.
 
@@ -26,6 +28,7 @@ class AttentionExtract(torch.nn.Module):
             mode: 'train' or 'eval' model mode.
             method: 'fx' or 'hook' extraction method.
             hook_type: 'forward' or 'forward_pre' hooks used.
+            use_regex: Use regex instead of fnmatch
         """
         super().__init__()
         assert mode in ('train', 'eval')
@@ -41,11 +44,15 @@ class AttentionExtract(torch.nn.Module):
             node_names = get_graph_node_names(model)[0 if mode == 'train' else 1]
 
             names = names or self.default_node_names
-            matched = [g for g in node_names if any([fnmatch.fnmatch(g, n) for n in names])]
+            if use_regex:
+                regexes = [re.compile(r) for r in names]
+                matched = [g for g in node_names if any([r.match(g) for r in regexes])]
+            else:
+                matched = [g for g in node_names if any([fnmatch.fnmatch(g, n) for n in names])]
             if not matched:
                 raise RuntimeError(f'No node names found matching {names}.')
 
-            self.model = GraphExtractNet(model, matched)
+            self.model = GraphExtractNet(model, matched, return_dict=True)
             self.hooks = None
         else:
             # names are module names
@@ -54,7 +61,11 @@ class AttentionExtract(torch.nn.Module):
             module_names = [n for n, m in model.named_modules()]
 
             names = names or self.default_module_names
-            matched = [m for m in module_names if any([fnmatch.fnmatch(m, n) for n in names])]
+            if use_regex:
+                regexes = [re.compile(r) for r in names]
+                matched = [m for m in module_names if any([r.match(m) for r in regexes])]
+            else:
+                matched = [m for m in module_names if any([fnmatch.fnmatch(m, n) for n in names])]
             if not matched:
                 raise RuntimeError(f'No module names found matching {names}.')
 
@@ -71,5 +82,4 @@ class AttentionExtract(torch.nn.Module):
             output = self.hooks.get_output(device=x.device)
         else:
             output = self.model(x)
-            output = OrderedDict(zip(self.names, output))
         return output
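
Usage sketch (not part of the patch itself): with these changes applied, AttentionExtract can treat names as regular expressions (compiled via re.compile and tested with .match) instead of fnmatch globs, and its fx path receives a dict keyed by node name directly from GraphExtractNet(return_dict=True), which is why the OrderedDict(zip(self.names, output)) re-keying was removed. A minimal example follows, assuming current timm conventions; the model name, the regex pattern, and the set_fused_attn(False) call (to keep an explicit attn.softmax node in the fx graph) are illustrative assumptions, not something the patch prescribes.

    import torch
    import timm
    from timm.layers import set_fused_attn
    from timm.utils import AttentionExtract

    # Assumption: disable fused attention so ViT blocks expose an explicit
    # 'attn.softmax' node that fx extraction can match.
    set_fused_attn(False)

    # Assumed model; any ViT-style timm model should behave similarly.
    model = timm.create_model('vit_base_patch16_224', pretrained=False)

    # use_regex=True compiles each entry in names with re.compile() and
    # matches graph node names via .match(), e.g. 'blocks.0.attn.softmax'.
    # The pattern below is illustrative; actual node names depend on the model.
    extractor = AttentionExtract(
        model,
        names=[r'blocks\.\d+\.attn\.softmax'],
        method='fx',
        use_regex=True,
    )

    x = torch.randn(1, 3, 224, 224)
    # With return_dict=True inside the fx path, the extractor output is
    # already {node_name: attention tensor}; no re-keying is needed.
    attn_maps = extractor(x)
    for name, attn in attn_maps.items():
        print(name, tuple(attn.shape))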