From 44f72c04b365660358c281775e6195d8394cf1e8 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 22 May 2024 13:45:25 -0700 Subject: [PATCH] Change node/module name matching for AttentionExtract so it keeps outputs in order. #1232 --- timm/utils/attention_extract.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/timm/utils/attention_extract.py b/timm/utils/attention_extract.py index 90021018..da0913f0 100644 --- a/timm/utils/attention_extract.py +++ b/timm/utils/attention_extract.py @@ -40,10 +40,8 @@ class AttentionExtract(torch.nn.Module): from timm.models._features_fx import get_graph_node_names, GraphExtractNet node_names = get_graph_node_names(model)[0 if mode == 'train' else 1] - matched = [] names = names or self.default_node_names - for n in names: - matched.extend(fnmatch.filter(node_names, n)) + matched = [g for g in node_names if any([fnmatch.fnmatch(g, n) for n in names])] if not matched: raise RuntimeError(f'No node names found matching {names}.') @@ -55,10 +53,8 @@ class AttentionExtract(torch.nn.Module): from timm.models._features import FeatureHooks module_names = [n for n, m in model.named_modules()] - matched = [] names = names or self.default_module_names - for n in names: - matched.extend(fnmatch.filter(module_names, n)) + matched = [m for m in module_names if any([fnmatch.fnmatch(m, n) for n in names])] if not matched: raise RuntimeError(f'No module names found matching {names}.')