Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Unbreak gamma remap impacting beit checkpoint load, version bump to 0.6.4
commit a8e34051c1
parent 1ccce50d48
timm/models/deit.py
@@ -10,6 +10,8 @@ Modifications copyright 2021, Ross Wightman
 """
 # Copyright (c) 2015-present, Facebook, Inc.
 # All rights reserved.
+from functools import partial
+
 import torch
 from torch import nn as nn

@@ -177,7 +179,7 @@ def _create_deit(variant, pretrained=False, distilled=False, **kwargs):
     model_cls = VisionTransformerDistilled if distilled else VisionTransformer
     model = build_model_with_cfg(
         model_cls, variant, pretrained,
-        pretrained_filter_fn=checkpoint_filter_fn,
+        pretrained_filter_fn=partial(checkpoint_filter_fn, adapt_layer_scale=True),
         **kwargs)
     return model

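Note: a minimal sketch (hypothetical names, not the timm API) of the partial() pattern used above. The loader invokes pretrained_filter_fn with a fixed (state_dict, model) calling convention, so the extra adapt_layer_scale=True flag has to be pre-bound:

from functools import partial

def filter_fn(state_dict, model, adapt_layer_scale=False):
    # Toy stand-in for checkpoint_filter_fn; just reports the flag.
    print(f'adapt_layer_scale={adapt_layer_scale}')
    return state_dict

def load_pretrained(filter_fn, state_dict, model):
    # Toy loader: it only knows the (state_dict, model) signature.
    return filter_fn(state_dict, model)

load_pretrained(filter_fn, {}, None)                                   # adapt_layer_scale=False
load_pretrained(partial(filter_fn, adapt_layer_scale=True), {}, None)  # adapt_layer_scale=True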
timm/models/vision_transformer.py
@@ -626,7 +626,7 @@ def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
     return posemb


-def checkpoint_filter_fn(state_dict, model):
+def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
     """ convert patch embedding weight from manual patchify + linear proj to conv"""
     import re
     out_dict = {}
@@ -647,7 +647,7 @@ def checkpoint_filter_fn(state_dict, model):
                 getattr(model, 'num_prefix_tokens', 1),
                 model.patch_embed.grid_size
             )
-        elif 'gamma_' in k:
+        elif adapt_layer_scale and 'gamma_' in k:
             # remap layer-scale gamma into sub-module (deit3 models)
             k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
         elif 'pre_logits' in k:
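Note: the commit title suggests BEiT checkpoints, which also carry gamma_-prefixed keys, were being caught by this remap in the shared filter; gating it behind adapt_layer_scale (default False, enabled only by _create_deit) restores pass-through for them. A small runnable sketch of the remap itself, with illustrative key names:

import re

def remap_key(k, adapt_layer_scale=False):
    # Same substitution as the hunk above: gamma_1 -> ls1.gamma,
    # gamma_2 -> ls2.gamma, applied only when the caller opts in.
    if adapt_layer_scale and 'gamma_' in k:
        k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
    return k

print(remap_key('blocks.0.gamma_1', adapt_layer_scale=True))  # -> blocks.0.ls1.gamma
print(remap_key('blocks.0.gamma_1'))                          # -> blocks.0.gamma_1 (unchanged)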
|
timm/version.py
@@ -1 +1 @@
-__version__ = '0.6.3.dev0'
+__version__ = '0.6.4'
|