mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update swin_v2 attn_mask buffer change in #1790 to apply to updated checkpoints in hub
This commit is contained in:
parent
1a1aca0cee
commit
80b247d843
@ -614,15 +614,16 @@ class SwinTransformerV2(nn.Module):
|
||||
def checkpoint_filter_fn(state_dict, model):
    """Filter and remap a checkpoint state dict before it is loaded.

    The checkpoint is first unwrapped from an optional ``'model'`` or
    ``'state_dict'`` container key. Entries whose name contains
    ``relative_position_index``, ``relative_coords_table`` or ``attn_mask``
    are always dropped, for both original and already-updated checkpoints.

    A checkpoint that already contains ``'head.fc.weight'`` is treated as a
    native (already remapped) checkpoint and its remaining keys are kept
    unchanged. Otherwise each key is remapped:

    * ``layers.<i>.downsample`` becomes ``layers.<i + 1>.downsample``
    * ``head.`` becomes ``head.fc.``

    Args:
        state_dict: Mapping of parameter names to values, optionally nested
            under a ``'model'`` or ``'state_dict'`` key.
        model: Unused by this function; kept for interface compatibility.

    Returns:
        A new dict with the filtered (and, if needed, remapped) entries.
    """
    import re

    state_dict = state_dict.get('model', state_dict)
    state_dict = state_dict.get('state_dict', state_dict)
    native_checkpoint = 'head.fc.weight' in state_dict

    skipped_buffers = ('relative_position_index', 'relative_coords_table', 'attn_mask')
    # Compiled once; the pattern text is unchanged from the original.
    downsample_pattern = re.compile(r'layers.(\d+).downsample')

    out_dict = {}
    for k, v in state_dict.items():
        if any(n in k for n in skipped_buffers):
            continue  # skip buffers that should not be persistent
        if not native_checkpoint:
            # skip layer remapping for updated checkpoints
            k = downsample_pattern.sub(
                lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
            k = k.replace('head.', 'head.fc.')
        out_dict[k] = v

    return out_dict
|
||||
|
Loading…
x
Reference in New Issue
Block a user