mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add regex matching support to AttentionExtract. Add return_dict support to graph extractors and use returned output in AttentionExtractor
This commit is contained in:
parent
44f72c04b3
commit
e748805be3
@ -118,6 +118,7 @@ class FeatureGraphNet(nn.Module):
|
||||
out_indices: Tuple[int, ...],
|
||||
out_map: Optional[Dict] = None,
|
||||
output_fmt: str = 'NCHW',
|
||||
return_dict: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
|
||||
@ -127,9 +128,13 @@ class FeatureGraphNet(nn.Module):
|
||||
self.output_fmt = Format(output_fmt)
|
||||
return_nodes = _get_return_layers(self.feature_info, out_map)
|
||||
self.graph_module = create_feature_extractor(model, return_nodes)
|
||||
self.return_dict = return_dict
|
||||
|
||||
def forward(self, x):
|
||||
return list(self.graph_module(x).values())
|
||||
out = self.graph_module(x)
|
||||
if self.return_dict:
|
||||
return out
|
||||
return list(out.values())
|
||||
|
||||
|
||||
class GraphExtractNet(nn.Module):
|
||||
@ -144,19 +149,23 @@ class GraphExtractNet(nn.Module):
|
||||
model: model to extract features from
|
||||
return_nodes: node names to return features from (dict or list)
|
||||
squeeze_out: if only one output, and output in list format, flatten to single tensor
|
||||
return_dict: return as dictionary from extractor with node names as keys, ignores squeeze_out arg
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
return_nodes: Union[Dict[str, str], List[str]],
|
||||
squeeze_out: bool = True,
|
||||
return_dict: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.squeeze_out = squeeze_out
|
||||
self.graph_module = create_feature_extractor(model, return_nodes)
|
||||
self.return_dict = return_dict
|
||||
|
||||
def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
|
||||
out = list(self.graph_module(x).values())
|
||||
if self.squeeze_out and len(out) == 1:
|
||||
return out[0]
|
||||
return out
|
||||
out = self.graph_module(x)
|
||||
if self.return_dict:
|
||||
return out
|
||||
out = list(out.values())
|
||||
return out[0] if self.squeeze_out and len(out) == 1 else out
|
||||
|
@ -1,4 +1,5 @@
|
||||
import fnmatch
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from typing import Union, Optional, List
|
||||
|
||||
@ -17,6 +18,7 @@ class AttentionExtract(torch.nn.Module):
|
||||
mode: str = 'eval',
|
||||
method: str = 'fx',
|
||||
hook_type: str = 'forward',
|
||||
use_regex: bool = False,
|
||||
):
|
||||
""" Extract attention maps (or other activations) from a model by name.
|
||||
|
||||
@ -26,6 +28,7 @@ class AttentionExtract(torch.nn.Module):
|
||||
mode: 'train' or 'eval' model mode.
|
||||
method: 'fx' or 'hook' extraction method.
|
||||
hook_type: 'forward' or 'forward_pre' hooks used.
|
||||
use_regex: Use regex instead of fnmatch
|
||||
"""
|
||||
super().__init__()
|
||||
assert mode in ('train', 'eval')
|
||||
@ -41,11 +44,15 @@ class AttentionExtract(torch.nn.Module):
|
||||
|
||||
node_names = get_graph_node_names(model)[0 if mode == 'train' else 1]
|
||||
names = names or self.default_node_names
|
||||
matched = [g for g in node_names if any([fnmatch.fnmatch(g, n) for n in names])]
|
||||
if use_regex:
|
||||
regexes = [re.compile(r) for r in names]
|
||||
matched = [g for g in node_names if any([r.match(g) for r in regexes])]
|
||||
else:
|
||||
matched = [g for g in node_names if any([fnmatch.fnmatch(g, n) for n in names])]
|
||||
if not matched:
|
||||
raise RuntimeError(f'No node names found matching {names}.')
|
||||
|
||||
self.model = GraphExtractNet(model, matched)
|
||||
self.model = GraphExtractNet(model, matched, return_dict=True)
|
||||
self.hooks = None
|
||||
else:
|
||||
# names are module names
|
||||
@ -54,7 +61,11 @@ class AttentionExtract(torch.nn.Module):
|
||||
|
||||
module_names = [n for n, m in model.named_modules()]
|
||||
names = names or self.default_module_names
|
||||
matched = [m for m in module_names if any([fnmatch.fnmatch(m, n) for n in names])]
|
||||
if use_regex:
|
||||
regexes = [re.compile(r) for r in names]
|
||||
matched = [m for m in module_names if any([r.match(m) for r in regexes])]
|
||||
else:
|
||||
matched = [m for m in module_names if any([fnmatch.fnmatch(m, n) for n in names])]
|
||||
if not matched:
|
||||
raise RuntimeError(f'No module names found matching {names}.')
|
||||
|
||||
@ -71,5 +82,4 @@ class AttentionExtract(torch.nn.Module):
|
||||
output = self.hooks.get_output(device=x.device)
|
||||
else:
|
||||
output = self.model(x)
|
||||
output = OrderedDict(zip(self.names, output))
|
||||
return output
|
||||
|
Loading…
x
Reference in New Issue
Block a user