Mirror of https://github.com/huggingface/pytorch-image-models.git (last synced 2025-06-03 15:01:08 +08:00).
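Fix the layer-scale gamma remap that broke BEiT checkpoint loading (the remap is now applied only when `adapt_layer_scale=True`, which the DeiT model factory sets); bump version to 0.6.4.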
This commit is contained in:
parent
1ccce50d48
commit
a8e34051c1
@ -10,6 +10,8 @@ Modifications copyright 2021, Ross Wightman
|
|||||||
"""
|
"""
|
||||||
# Copyright (c) 2015-present, Facebook, Inc.
|
# Copyright (c) 2015-present, Facebook, Inc.
|
||||||
# All rights reserved.
|
# All rights reserved.
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
@ -177,7 +179,7 @@ def _create_deit(variant, pretrained=False, distilled=False, **kwargs):
|
|||||||
model_cls = VisionTransformerDistilled if distilled else VisionTransformer
|
model_cls = VisionTransformerDistilled if distilled else VisionTransformer
|
||||||
model = build_model_with_cfg(
|
model = build_model_with_cfg(
|
||||||
model_cls, variant, pretrained,
|
model_cls, variant, pretrained,
|
||||||
pretrained_filter_fn=checkpoint_filter_fn,
|
pretrained_filter_fn=partial(checkpoint_filter_fn, adapt_layer_scale=True),
|
||||||
**kwargs)
|
**kwargs)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@ -626,7 +626,7 @@ def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
|
|||||||
return posemb
|
return posemb
|
||||||
|
|
||||||
|
|
||||||
def checkpoint_filter_fn(state_dict, model):
|
def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
|
||||||
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
||||||
import re
|
import re
|
||||||
out_dict = {}
|
out_dict = {}
|
||||||
@ -647,7 +647,7 @@ def checkpoint_filter_fn(state_dict, model):
|
|||||||
getattr(model, 'num_prefix_tokens', 1),
|
getattr(model, 'num_prefix_tokens', 1),
|
||||||
model.patch_embed.grid_size
|
model.patch_embed.grid_size
|
||||||
)
|
)
|
||||||
elif 'gamma_' in k:
|
elif adapt_layer_scale and 'gamma_' in k:
|
||||||
# remap layer-scale gamma into sub-module (deit3 models)
|
# remap layer-scale gamma into sub-module (deit3 models)
|
||||||
k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
|
k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
|
||||||
elif 'pre_logits' in k:
|
elif 'pre_logits' in k:
|
||||||
|
@ -1 +1 @@
|
|||||||
__version__ = '0.6.3.dev0'
|
__version__ = '0.6.4'
|
||||||
|
Loading…
x
Reference in New Issue
Block a user