Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
checkpoint filter fns with consistent name, add mobileclip-b pretrained cfgs
commit 88a1006e02 (parent 7d4ada6d16)
timm/models/beit.py
@@ -591,7 +591,7 @@ default_cfgs = generate_default_cfgs({
 })


-def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True):
+def checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True):
     state_dict = state_dict.get('model', state_dict)
     state_dict = state_dict.get('module', state_dict)
     # beit v2 didn't strip module
@@ -637,7 +637,7 @@ def _create_beit(variant, pretrained=False, **kwargs):
     out_indices = kwargs.pop('out_indices', 3)
     model = build_model_with_cfg(
         Beit, variant, pretrained,
-        pretrained_filter_fn=_beit_checkpoint_filter_fn,
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )
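With every model module now exporting its filter under the same public name, an original-format BEiT checkpoint can also be remapped by hand rather than only through build_model_with_cfg. A minimal sketch, assuming a locally downloaded original checkpoint (the file name is hypothetical):

import torch
import timm
from timm.models.beit import checkpoint_filter_fn

model = timm.create_model('beit_base_patch16_224', pretrained=False)
state_dict = torch.load('original_beit_checkpoint.pth', map_location='cpu')  # hypothetical local file
state_dict = checkpoint_filter_fn(state_dict, model)  # unwraps nested 'model'/'module' dicts, remaps to timm's layout
model.load_state_dict(state_dict)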
timm/models/efficientformer.py
@@ -556,7 +556,7 @@ class EfficientFormer(nn.Module):
         return x


-def _checkpoint_filter_fn(state_dict, model):
+def checkpoint_filter_fn(state_dict, model):
     """ Remap original checkpoints -> timm """
     if 'stem.0.weight' in state_dict:
         return state_dict  # non-original checkpoint, no remapping needed
@@ -611,7 +611,7 @@ def _create_efficientformer(variant, pretrained=False, **kwargs):
     out_indices = kwargs.pop('out_indices', 4)
     model = build_model_with_cfg(
         EfficientFormer, variant, pretrained,
-        pretrained_filter_fn=_checkpoint_filter_fn,
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )
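The filter plugs into build_model_with_cfg the same way in every model file: the builder hands the raw pretrained state_dict plus the instantiated model to pretrained_filter_fn before loading. A conceptual sketch of that flow (not timm's actual loader code):

def load_pretrained_sketch(model, state_dict, filter_fn=None):
    # approximate idea only: remap the raw checkpoint first, then load as usual
    if filter_fn is not None:
        state_dict = filter_fn(state_dict, model)
    model.load_state_dict(state_dict)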
timm/models/fastvit.py
@@ -1414,7 +1414,7 @@ default_cfgs = generate_default_cfgs({
 })


-def _checkpoint_filter_fn(state_dict, model):
+def checkpoint_filter_fn(state_dict, model):
     """ Remap original checkpoints -> timm """
     if 'stem.0.conv_kxk.0.conv.weight' in state_dict:
         return state_dict  # non-original checkpoint, no remapping needed
@@ -1493,7 +1493,7 @@ def _create_fastvit(variant, pretrained=False, **kwargs):
         FastVit,
         variant,
         pretrained,
-        pretrained_filter_fn=_checkpoint_filter_fn,
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         **kwargs
     )
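All of these filters use the same guard: if a key that only exists in the timm layout is already present, the checkpoint is returned unchanged, so applying the filter to an already-converted state_dict is a no-op. A toy illustration of the pattern (the remapping line is a placeholder, not FastViT's real conversion):

def remap_if_needed(state_dict):
    if 'stem.0.conv_kxk.0.conv.weight' in state_dict:
        return state_dict  # already timm-style, nothing to do
    return {k.replace('orig_prefix.', ''): v for k, v in state_dict.items()}  # placeholder remap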
timm/models/pvt_v2.py
@@ -403,7 +403,7 @@ class PyramidVisionTransformerV2(nn.Module):
         return x


-def _checkpoint_filter_fn(state_dict, model):
+def checkpoint_filter_fn(state_dict, model):
     """ Remap original checkpoints -> timm """
     if 'patch_embed.proj.weight' in state_dict:
         return state_dict  # non-original checkpoint, no remapping needed
@@ -430,7 +430,7 @@ def _create_pvt2(variant, pretrained=False, **kwargs):
         PyramidVisionTransformerV2,
         variant,
         pretrained,
-        pretrained_filter_fn=_checkpoint_filter_fn,
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         **kwargs,
     )
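The feature_cfg=dict(..., out_indices=...) argument that each of these builders forwards to build_model_with_cfg is what backs features_only extraction. For example, with PVTv2 as a multi-scale backbone (shapes shown only to indicate intent):

import torch
import timm

model = timm.create_model('pvt_v2_b0', pretrained=False, features_only=True, out_indices=(0, 1, 2, 3))
features = model(torch.randn(1, 3, 224, 224))
print([f.shape for f in features])  # one feature map per requested stage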
timm/models/vision_transformer_hybrid.py
@@ -15,7 +15,7 @@ Hacked together by / Copyright 2020, Ross Wightman
 """
 import math
 from functools import partial
-from typing import List, Optional, Tuple, Type, Union
+from typing import Dict, List, Optional, Tuple, Type, Union

 import torch
 import torch.nn as nn
@@ -24,10 +24,11 @@ import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, Format, nchw_to

+from ._builder import build_model_with_cfg
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 from .resnet import resnet26d, resnet50d
 from .resnetv2 import ResNetV2, create_resnetv2_stem
-from .vision_transformer import _create_vision_transformer, VisionTransformer
+from .vision_transformer import VisionTransformer


 class HybridEmbed(nn.Module):
@@ -159,22 +160,26 @@ class HybridEmbedWithSize(nn.Module):
     """
     def __init__(
             self,
-            backbone,
-            img_size=224,
-            patch_size=1,
-            feature_size=None,
-            in_chans=3,
-            embed_dim=768,
+            backbone: nn.Module,
+            img_size: Union[int, Tuple[int, int]] = 224,
+            patch_size: Union[int, Tuple[int, int]] = 1,
+            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
+            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
+            in_chans: int = 3,
+            embed_dim: int = 768,
             bias=True,
+            proj=True,
     ):
         super().__init__(
             backbone=backbone,
             img_size=img_size,
             patch_size=patch_size,
             feature_size=feature_size,
+            feature_ratio=feature_ratio,
             in_chans=in_chans,
             embed_dim=embed_dim,
             bias=bias,
+            proj=proj,
         )

     @torch.jit.ignore
@@ -206,12 +211,8 @@ class ConvStem(nn.Sequential):
     ):
         super().__init__()
         if isinstance(channels, int):
-            if depth == 4:
-                channels = (channels // 8, channels // 4, channels // 2, channels)
-            elif depth == 3:
-                channels = (channels // 4, channels // 2, channels)
-            else:
-                channels = to_ntuple(depth)(channels)
+            # a default tiered channel strategy
+            channels = tuple([channels // 2**i for i in range(depth)][::-1])

         kernel_size = to_ntuple(depth)(kernel_size)
         padding = to_ntuple(depth)(padding)
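The new one-liner reproduces the old depth-3 and depth-4 special cases and generalizes the tiering to any depth; note that for other depths the removed else branch used to repeat the integer instead. A quick check:

channels = 64
for depth in (3, 4, 5):
    print(depth, tuple([channels // 2**i for i in range(depth)][::-1]))
# 3 (16, 32, 64)        == old (channels // 4, channels // 2, channels)
# 4 (8, 16, 32, 64)     == old (channels // 8, channels // 4, channels // 2, channels)
# 5 (4, 8, 16, 32, 64)  the removed else branch would have produced (64, 64, 64, 64, 64)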
@@ -235,13 +236,6 @@ class ConvStem(nn.Sequential):
             in_chs = channels[i]


-def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs):
-    embed_args = embed_args or {}
-    embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args)
-    kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set
-    return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs)
-
-
 def _resnetv2(layers=(3, 4, 9), **kwargs):
     """ ResNet-V2 backbone helper"""
     padding_same = kwargs.get('padding_same', True)
@@ -257,6 +251,66 @@ def _resnetv2(layers=(3, 4, 9), **kwargs):
     return backbone


+def _convert_mobileclip(state_dict, model, prefix='image_encoder.model.'):
+    out = {}
+    for k, v in state_dict.items():
+        if not k.startswith(prefix):
+            continue
+        k = k.replace(prefix, '')
+        k = k.replace('patch_emb.', 'patch_embed.backbone.')
+        k = k.replace('block.conv', 'conv')
+        k = k.replace('block.norm', 'bn')
+        k = k.replace('post_transformer_norm.', 'norm.')
+        k = k.replace('pre_norm_mha.0', 'norm1')
+        k = k.replace('pre_norm_mha.1', 'attn')
+        k = k.replace('pre_norm_ffn.0', 'norm2')
+        k = k.replace('pre_norm_ffn.1', 'mlp.fc1')
+        k = k.replace('pre_norm_ffn.4', 'mlp.fc2')
+        k = k.replace('qkv_proj.', 'qkv.')
+        k = k.replace('out_proj.', 'proj.')
+        k = k.replace('transformer.', 'blocks.')
+        if k == 'pos_embed.pos_embed.pos_embed':
+            k = 'pos_embed'
+            v = v.squeeze(0)
+        if 'classifier.proj' in k:
+            bias_k = k.replace('classifier.proj', 'head.bias')
+            k = k.replace('classifier.proj', 'head.weight')
+            v = v.T
+            out[bias_k] = torch.zeros(v.shape[0])
+        out[k] = v
+    return out
+
+
+def checkpoint_filter_fn(
+        state_dict: Dict[str, torch.Tensor],
+        model: VisionTransformer,
+        interpolation: str = 'bicubic',
+        antialias: bool = True,
+) -> Dict[str, torch.Tensor]:
+    from .vision_transformer import checkpoint_filter_fn as _filter_fn
+
+    if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:
+        state_dict = _convert_mobileclip(state_dict, model)
+
+    return _filter_fn(state_dict, model, interpolation=interpolation, antialias=antialias)
+
+
+def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs):
+    out_indices = kwargs.pop('out_indices', 3)
+    embed_args = embed_args or {}
+    embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args)
+    kwargs.setdefault('embed_layer', embed_layer)
+    kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set
+    return build_model_with_cfg(
+        VisionTransformer,
+        variant,
+        pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
+        **kwargs,
+    )
+
+
 def _cfg(url='', **kwargs):
     return {
         'url': url,
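The MobileCLIP conversion is pure key-string surgery on the image tower; keys outside the image_encoder.model. prefix (e.g. the text encoder) are dropped. Tracing one illustrative key through the replace chain above (the key name is hypothetical, and importing the private helper is for demonstration only):

import torch
from timm.models.vision_transformer_hybrid import _convert_mobileclip

sd = {'image_encoder.model.transformer.0.pre_norm_mha.1.qkv_proj.weight': torch.zeros(2304, 768)}
print(list(_convert_mobileclip(sd, model=None)))
# ['blocks.0.attn.qkv.weight']  (pre_norm_mha.1 -> attn, qkv_proj -> qkv, transformer -> blocks)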
@@ -331,6 +385,17 @@ default_cfgs = generate_default_cfgs({
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
     'vit_base_resnet50d_224.untrained': _cfg(
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
+
+    'vit_base_mci_224.apple_mclip': _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt',
+        num_classes=512,
+        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.conv1.0',
+    ),
+    'vit_base_mci_224.apple_mclip_lt': _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt',
+        num_classes=512,
+        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.conv1.0',
+    ),
 })

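With the new pretrained cfgs, the MobileCLIP-B image tower can be pulled in like any other timm model. num_classes=512 means the head emits a 512-dim output, which here serves as the CLIP image embedding, and the 0/1 mean/std indicate the weights expect un-normalized [0, 1] inputs. A usage sketch:

import torch
import timm

model = timm.create_model('vit_base_mci_224.apple_mclip', pretrained=True).eval()
with torch.no_grad():
    emb = model(torch.randn(1, 3, 224, 224))
print(emb.shape)  # torch.Size([1, 512]), per num_classes=512 in the cfg above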
@@ -491,7 +556,7 @@ def vit_base_mci_224(pretrained=False, **kwargs) -> VisionTransformer:
     )
     model_args = dict(embed_dim=768, depth=12, num_heads=12, no_embed_class=True)
     model = _create_vision_transformer_hybrid(
-        'vit_base_resnet50d_224', backbone=backbone, embed_args=dict(proj=False),
+        'vit_base_mci_224', backbone=backbone, embed_args=dict(proj=False),
         pretrained=pretrained, **dict(model_args, **kwargs)
     )
     return model