diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 377bcd14..6fd03b35 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -620,6 +620,8 @@ def checkpoint_filter_fn(state_dict, model): k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k) k = k.replace('stages.0.downsample', 'patch_embed') k = k.replace('patch_embed', 'stem') + k = k.replace('post_norm', 'norm') + k = k.replace('pre_norm', 'norm') k = re.sub(r'^head', 'head.fc', k) k = re.sub(r'^norm', 'head.norm', k) out_dict[k] = v