Add regex matching support to AttentionExtract. Add return_dict support to graph extractors and use the returned dict in AttentionExtract

Ross Wightman 2024-05-22 14:33:39 -07:00
parent 44f72c04b3
commit e748805be3
2 changed files with 28 additions and 9 deletions


@@ -118,6 +118,7 @@ class FeatureGraphNet(nn.Module):
             out_indices: Tuple[int, ...],
             out_map: Optional[Dict] = None,
             output_fmt: str = 'NCHW',
+            return_dict: bool = False,
     ):
         super().__init__()
         assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
@@ -127,9 +128,13 @@ class FeatureGraphNet(nn.Module):
         self.output_fmt = Format(output_fmt)
         return_nodes = _get_return_layers(self.feature_info, out_map)
         self.graph_module = create_feature_extractor(model, return_nodes)
+        self.return_dict = return_dict
 
     def forward(self, x):
-        return list(self.graph_module(x).values())
+        out = self.graph_module(x)
+        if self.return_dict:
+            return out
+        return list(out.values())
 
 
 class GraphExtractNet(nn.Module):
@@ -144,19 +149,23 @@ class GraphExtractNet(nn.Module):
         model: model to extract features from
         return_nodes: node names to return features from (dict or list)
         squeeze_out: if only one output, and output in list format, flatten to single tensor
+        return_dict: return as dictionary from extractor with node names as keys, ignores squeeze_out arg
     """
     def __init__(
             self,
             model: nn.Module,
             return_nodes: Union[Dict[str, str], List[str]],
             squeeze_out: bool = True,
+            return_dict: bool = False,
     ):
         super().__init__()
         self.squeeze_out = squeeze_out
         self.graph_module = create_feature_extractor(model, return_nodes)
+        self.return_dict = return_dict
 
     def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
-        out = list(self.graph_module(x).values())
-        if self.squeeze_out and len(out) == 1:
-            return out[0]
-        return out
+        out = self.graph_module(x)
+        if self.return_dict:
+            return out
+        out = list(out.values())
+        return out[0] if self.squeeze_out and len(out) == 1 else out
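
For context, a minimal sketch of the new return_dict behavior on GraphExtractNet (the model, node names, and import path below are illustrative assumptions, not part of this commit):

import torch
import timm
from timm.models import GraphExtractNet  # assumed import path

# Any FX-traceable timm model works; 'layer2'/'layer3' are illustrative node
# names. Valid names can be listed with torchvision's get_graph_node_names().
model = timm.create_model('resnet18', pretrained=False)

extractor = GraphExtractNet(model, ['layer2', 'layer3'], return_dict=True)
out = extractor(torch.randn(1, 3, 224, 224))  # dict of {node_name: tensor}
for name, feat in out.items():
    print(name, tuple(feat.shape))

# With return_dict=False (the default) the same call returns a list of tensors,
# squeezed to a single tensor when squeeze_out=True and only one node matched.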


@@ -1,4 +1,5 @@
 import fnmatch
+import re
 from collections import OrderedDict
 from typing import Union, Optional, List
 
@@ -17,6 +18,7 @@ class AttentionExtract(torch.nn.Module):
             mode: str = 'eval',
             method: str = 'fx',
             hook_type: str = 'forward',
+            use_regex: bool = False,
     ):
         """ Extract attention maps (or other activations) from a model by name.
 
@@ -26,6 +28,7 @@ class AttentionExtract(torch.nn.Module):
             mode: 'train' or 'eval' model mode.
             method: 'fx' or 'hook' extraction method.
             hook_type: 'forward' or 'forward_pre' hooks used.
+            use_regex: Use regex instead of fnmatch
         """
         super().__init__()
         assert mode in ('train', 'eval')
@@ -41,11 +44,15 @@ class AttentionExtract(torch.nn.Module):
             node_names = get_graph_node_names(model)[0 if mode == 'train' else 1]
 
             names = names or self.default_node_names
-            matched = [g for g in node_names if any([fnmatch.fnmatch(g, n) for n in names])]
+            if use_regex:
+                regexes = [re.compile(r) for r in names]
+                matched = [g for g in node_names if any([r.match(g) for r in regexes])]
+            else:
+                matched = [g for g in node_names if any([fnmatch.fnmatch(g, n) for n in names])]
             if not matched:
                 raise RuntimeError(f'No node names found matching {names}.')
 
-            self.model = GraphExtractNet(model, matched)
+            self.model = GraphExtractNet(model, matched, return_dict=True)
             self.hooks = None
         else:
             # names are module names
@@ -54,7 +61,11 @@ class AttentionExtract(torch.nn.Module):
             module_names = [n for n, m in model.named_modules()]
             names = names or self.default_module_names
-            matched = [m for m in module_names if any([fnmatch.fnmatch(m, n) for n in names])]
+            if use_regex:
+                regexes = [re.compile(r) for r in names]
+                matched = [m for m in module_names if any([r.match(m) for r in regexes])]
+            else:
+                matched = [m for m in module_names if any([fnmatch.fnmatch(m, n) for n in names])]
             if not matched:
                 raise RuntimeError(f'No module names found matching {names}.')
 
@@ -71,5 +82,4 @@ class AttentionExtract(torch.nn.Module):
             output = self.hooks.get_output(device=x.device)
         else:
             output = self.model(x)
-            output = OrderedDict(zip(self.names, output))
         return output
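
For context, a minimal sketch of the new use_regex path in AttentionExtract (the model, pattern, import path, and set_fused_attn helper are illustrative assumptions, not part of this commit):

import torch
import timm
from timm.utils import AttentionExtract  # assumed import path

# With use_regex=True, names are compiled via re.compile and tested with
# re.match, so patterns anchor at the start of each node/module name;
# fnmatch-style globs remain the default matching mode.
timm.layers.set_fused_attn(False)  # assumed helper; softmax nodes only exist
                                   # in the graph when fused attention is off
model = timm.create_model('vit_tiny_patch16_224', pretrained=False)

extractor = AttentionExtract(
    model,
    names=[r'blocks\.\d+\.attn\.softmax'],  # illustrative regex pattern
    use_regex=True,
)
attn = extractor(torch.randn(1, 3, 224, 224))
for name, t in attn.items():
    print(name, tuple(t.shape))

Note that with the 'fx' method the output dict now comes straight from GraphExtractNet(..., return_dict=True), keyed by node name, which is what allowed the OrderedDict(zip(self.names, output)) step to be removed above.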