diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index a4acf5d2..861676e9 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -614,15 +614,16 @@ class SwinTransformerV2(nn.Module): def checkpoint_filter_fn(state_dict, model): state_dict = state_dict.get('model', state_dict) state_dict = state_dict.get('state_dict', state_dict) - if 'head.fc.weight' in state_dict: - return state_dict + native_checkpoint = 'head.fc.weight' in state_dict out_dict = {} import re for k, v in state_dict.items(): if any([n in k for n in ('relative_position_index', 'relative_coords_table', 'attn_mask')]): continue # skip buffers that should not be persistent - k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k) - k = k.replace('head.', 'head.fc.') + if not native_checkpoint: + # skip layer remapping for updated checkpoints + k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k) + k = k.replace('head.', 'head.fc.') out_dict[k] = v return out_dict