Change node/module name matching for AttentionExtract so it keeps outputs in order. #1232

This commit is contained in:
Ross Wightman 2024-05-22 13:45:25 -07:00
parent 84cb225ecb
commit 44f72c04b3

View File

@ -40,10 +40,8 @@ class AttentionExtract(torch.nn.Module):
from timm.models._features_fx import get_graph_node_names, GraphExtractNet
node_names = get_graph_node_names(model)[0 if mode == 'train' else 1]
matched = []
names = names or self.default_node_names
for n in names:
matched.extend(fnmatch.filter(node_names, n))
matched = [g for g in node_names if any([fnmatch.fnmatch(g, n) for n in names])]
if not matched:
raise RuntimeError(f'No node names found matching {names}.')
@ -55,10 +53,8 @@ class AttentionExtract(torch.nn.Module):
from timm.models._features import FeatureHooks
module_names = [n for n, m in model.named_modules()]
matched = []
names = names or self.default_module_names
for n in names:
matched.extend(fnmatch.filter(module_names, n))
matched = [m for m in module_names if any([fnmatch.fnmatch(m, n) for n in names])]
if not matched:
raise RuntimeError(f'No module names found matching {names}.')