mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add regex matching support to AttentionExtract. Add return_dict support to graph extractors and use returned output in AttentionExtractor
This commit is contained in:
parent
44f72c04b3
commit
e748805be3
@ -118,6 +118,7 @@ class FeatureGraphNet(nn.Module):
|
|||||||
out_indices: Tuple[int, ...],
|
out_indices: Tuple[int, ...],
|
||||||
out_map: Optional[Dict] = None,
|
out_map: Optional[Dict] = None,
|
||||||
output_fmt: str = 'NCHW',
|
output_fmt: str = 'NCHW',
|
||||||
|
return_dict: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
|
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
|
||||||
@ -127,9 +128,13 @@ class FeatureGraphNet(nn.Module):
|
|||||||
self.output_fmt = Format(output_fmt)
|
self.output_fmt = Format(output_fmt)
|
||||||
return_nodes = _get_return_layers(self.feature_info, out_map)
|
return_nodes = _get_return_layers(self.feature_info, out_map)
|
||||||
self.graph_module = create_feature_extractor(model, return_nodes)
|
self.graph_module = create_feature_extractor(model, return_nodes)
|
||||||
|
self.return_dict = return_dict
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return list(self.graph_module(x).values())
|
out = self.graph_module(x)
|
||||||
|
if self.return_dict:
|
||||||
|
return out
|
||||||
|
return list(out.values())
|
||||||
|
|
||||||
|
|
||||||
class GraphExtractNet(nn.Module):
|
class GraphExtractNet(nn.Module):
|
||||||
@ -144,19 +149,23 @@ class GraphExtractNet(nn.Module):
|
|||||||
model: model to extract features from
|
model: model to extract features from
|
||||||
return_nodes: node names to return features from (dict or list)
|
return_nodes: node names to return features from (dict or list)
|
||||||
squeeze_out: if only one output, and output in list format, flatten to single tensor
|
squeeze_out: if only one output, and output in list format, flatten to single tensor
|
||||||
|
return_dict: return as dictionary from extractor with node names as keys, ignores squeeze_out arg
|
||||||
"""
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
return_nodes: Union[Dict[str, str], List[str]],
|
return_nodes: Union[Dict[str, str], List[str]],
|
||||||
squeeze_out: bool = True,
|
squeeze_out: bool = True,
|
||||||
|
return_dict: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.squeeze_out = squeeze_out
|
self.squeeze_out = squeeze_out
|
||||||
self.graph_module = create_feature_extractor(model, return_nodes)
|
self.graph_module = create_feature_extractor(model, return_nodes)
|
||||||
|
self.return_dict = return_dict
|
||||||
|
|
||||||
def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
|
def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
|
||||||
out = list(self.graph_module(x).values())
|
out = self.graph_module(x)
|
||||||
if self.squeeze_out and len(out) == 1:
|
if self.return_dict:
|
||||||
return out[0]
|
return out
|
||||||
return out
|
out = list(out.values())
|
||||||
|
return out[0] if self.squeeze_out and len(out) == 1 else out
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import fnmatch
|
import fnmatch
|
||||||
|
import re
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Union, Optional, List
|
from typing import Union, Optional, List
|
||||||
|
|
||||||
@ -17,6 +18,7 @@ class AttentionExtract(torch.nn.Module):
|
|||||||
mode: str = 'eval',
|
mode: str = 'eval',
|
||||||
method: str = 'fx',
|
method: str = 'fx',
|
||||||
hook_type: str = 'forward',
|
hook_type: str = 'forward',
|
||||||
|
use_regex: bool = False,
|
||||||
):
|
):
|
||||||
""" Extract attention maps (or other activations) from a model by name.
|
""" Extract attention maps (or other activations) from a model by name.
|
||||||
|
|
||||||
@ -26,6 +28,7 @@ class AttentionExtract(torch.nn.Module):
|
|||||||
mode: 'train' or 'eval' model mode.
|
mode: 'train' or 'eval' model mode.
|
||||||
method: 'fx' or 'hook' extraction method.
|
method: 'fx' or 'hook' extraction method.
|
||||||
hook_type: 'forward' or 'forward_pre' hooks used.
|
hook_type: 'forward' or 'forward_pre' hooks used.
|
||||||
|
use_regex: Use regex instead of fnmatch
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert mode in ('train', 'eval')
|
assert mode in ('train', 'eval')
|
||||||
@ -41,11 +44,15 @@ class AttentionExtract(torch.nn.Module):
|
|||||||
|
|
||||||
node_names = get_graph_node_names(model)[0 if mode == 'train' else 1]
|
node_names = get_graph_node_names(model)[0 if mode == 'train' else 1]
|
||||||
names = names or self.default_node_names
|
names = names or self.default_node_names
|
||||||
matched = [g for g in node_names if any([fnmatch.fnmatch(g, n) for n in names])]
|
if use_regex:
|
||||||
|
regexes = [re.compile(r) for r in names]
|
||||||
|
matched = [g for g in node_names if any([r.match(g) for r in regexes])]
|
||||||
|
else:
|
||||||
|
matched = [g for g in node_names if any([fnmatch.fnmatch(g, n) for n in names])]
|
||||||
if not matched:
|
if not matched:
|
||||||
raise RuntimeError(f'No node names found matching {names}.')
|
raise RuntimeError(f'No node names found matching {names}.')
|
||||||
|
|
||||||
self.model = GraphExtractNet(model, matched)
|
self.model = GraphExtractNet(model, matched, return_dict=True)
|
||||||
self.hooks = None
|
self.hooks = None
|
||||||
else:
|
else:
|
||||||
# names are module names
|
# names are module names
|
||||||
@ -54,7 +61,11 @@ class AttentionExtract(torch.nn.Module):
|
|||||||
|
|
||||||
module_names = [n for n, m in model.named_modules()]
|
module_names = [n for n, m in model.named_modules()]
|
||||||
names = names or self.default_module_names
|
names = names or self.default_module_names
|
||||||
matched = [m for m in module_names if any([fnmatch.fnmatch(m, n) for n in names])]
|
if use_regex:
|
||||||
|
regexes = [re.compile(r) for r in names]
|
||||||
|
matched = [m for m in module_names if any([r.match(m) for r in regexes])]
|
||||||
|
else:
|
||||||
|
matched = [m for m in module_names if any([fnmatch.fnmatch(m, n) for n in names])]
|
||||||
if not matched:
|
if not matched:
|
||||||
raise RuntimeError(f'No module names found matching {names}.')
|
raise RuntimeError(f'No module names found matching {names}.')
|
||||||
|
|
||||||
@ -71,5 +82,4 @@ class AttentionExtract(torch.nn.Module):
|
|||||||
output = self.hooks.get_output(device=x.device)
|
output = self.hooks.get_output(device=x.device)
|
||||||
else:
|
else:
|
||||||
output = self.model(x)
|
output = self.model(x)
|
||||||
output = OrderedDict(zip(self.names, output))
|
|
||||||
return output
|
return output
|
||||||
|
Loading…
x
Reference in New Issue
Block a user