diff --git a/timm/utils/attention_extract.py b/timm/utils/attention_extract.py index 90021018..da0913f0 100644 --- a/timm/utils/attention_extract.py +++ b/timm/utils/attention_extract.py @@ -40,10 +40,8 @@ class AttentionExtract(torch.nn.Module): from timm.models._features_fx import get_graph_node_names, GraphExtractNet node_names = get_graph_node_names(model)[0 if mode == 'train' else 1] - matched = [] names = names or self.default_node_names - for n in names: - matched.extend(fnmatch.filter(node_names, n)) + matched = [g for g in node_names if any([fnmatch.fnmatch(g, n) for n in names])] if not matched: raise RuntimeError(f'No node names found matching {names}.') @@ -55,10 +53,8 @@ class AttentionExtract(torch.nn.Module): from timm.models._features import FeatureHooks module_names = [n for n, m in model.named_modules()] - matched = [] names = names or self.default_module_names - for n in names: - matched.extend(fnmatch.filter(module_names, n)) + matched = [m for m in module_names if any([fnmatch.fnmatch(m, n) for n in names])] if not matched: raise RuntimeError(f'No module names found matching {names}.')