mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Change node/module name matching for AttentionExtract so it keeps outputs in order. #1232
This commit is contained in:
parent
84cb225ecb
commit
44f72c04b3
@ -40,10 +40,8 @@ class AttentionExtract(torch.nn.Module):
|
||||
from timm.models._features_fx import get_graph_node_names, GraphExtractNet
|
||||
|
||||
node_names = get_graph_node_names(model)[0 if mode == 'train' else 1]
|
||||
matched = []
|
||||
names = names or self.default_node_names
|
||||
for n in names:
|
||||
matched.extend(fnmatch.filter(node_names, n))
|
||||
matched = [g for g in node_names if any([fnmatch.fnmatch(g, n) for n in names])]
|
||||
if not matched:
|
||||
raise RuntimeError(f'No node names found matching {names}.')
|
||||
|
||||
@ -55,10 +53,8 @@ class AttentionExtract(torch.nn.Module):
|
||||
from timm.models._features import FeatureHooks
|
||||
|
||||
module_names = [n for n, m in model.named_modules()]
|
||||
matched = []
|
||||
names = names or self.default_module_names
|
||||
for n in names:
|
||||
matched.extend(fnmatch.filter(module_names, n))
|
||||
matched = [m for m in module_names if any([fnmatch.fnmatch(m, n) for n in names])]
|
||||
if not matched:
|
||||
raise RuntimeError(f'No module names found matching {names}.')
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user