mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
checkpoint filter fns with consistent name, add mobileclip-b pretrained cfgs
parent 7d4ada6d16
commit 88a1006e02
@@ -591,7 +591,7 @@ default_cfgs = generate_default_cfgs({
 })
 
 
-def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True):
+def checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True):
     state_dict = state_dict.get('model', state_dict)
     state_dict = state_dict.get('module', state_dict)
     # beit v2 didn't strip module
@@ -637,7 +637,7 @@ def _create_beit(variant, pretrained=False, **kwargs):
     out_indices = kwargs.pop('out_indices', 3)
     model = build_model_with_cfg(
         Beit, variant, pretrained,
-        pretrained_filter_fn=_beit_checkpoint_filter_fn,
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )
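As context for the rename: the filter fn now has the same public name across model files, so it can also be applied by hand. A minimal sketch, assuming a hypothetical local checkpoint path (the normal `pretrained=True` path invokes it automatically via `pretrained_filter_fn`):

```python
# Minimal sketch, not part of the commit: applying the consistently named
# checkpoint_filter_fn by hand. The local .pth path is hypothetical; loading
# with pretrained=True calls this function for you via pretrained_filter_fn.
import torch
import timm
from timm.models.beit import checkpoint_filter_fn

model = timm.create_model('beit_base_patch16_224')
raw = torch.load('some_beit_checkpoint.pth', map_location='cpu')
cleaned = checkpoint_filter_fn(raw, model)  # unwraps 'model'/'module', remaps keys
model.load_state_dict(cleaned)
```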
@@ -556,7 +556,7 @@ class EfficientFormer(nn.Module):
         return x
 
 
-def _checkpoint_filter_fn(state_dict, model):
+def checkpoint_filter_fn(state_dict, model):
     """ Remap original checkpoints -> timm """
     if 'stem.0.weight' in state_dict:
         return state_dict  # non-original checkpoint, no remapping needed
@@ -611,7 +611,7 @@ def _create_efficientformer(variant, pretrained=False, **kwargs):
     out_indices = kwargs.pop('out_indices', 4)
     model = build_model_with_cfg(
         EfficientFormer, variant, pretrained,
-        pretrained_filter_fn=_checkpoint_filter_fn,
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )
@@ -1414,7 +1414,7 @@ default_cfgs = generate_default_cfgs({
 })
 
 
-def _checkpoint_filter_fn(state_dict, model):
+def checkpoint_filter_fn(state_dict, model):
     """ Remap original checkpoints -> timm """
     if 'stem.0.conv_kxk.0.conv.weight' in state_dict:
         return state_dict  # non-original checkpoint, no remapping needed
@@ -1493,7 +1493,7 @@ def _create_fastvit(variant, pretrained=False, **kwargs):
         FastVit,
         variant,
         pretrained,
-        pretrained_filter_fn=_checkpoint_filter_fn,
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         **kwargs
     )
@@ -403,7 +403,7 @@ class PyramidVisionTransformerV2(nn.Module):
         return x
 
 
-def _checkpoint_filter_fn(state_dict, model):
+def checkpoint_filter_fn(state_dict, model):
     """ Remap original checkpoints -> timm """
     if 'patch_embed.proj.weight' in state_dict:
         return state_dict  # non-original checkpoint, no remapping needed
@@ -430,7 +430,7 @@ def _create_pvt2(variant, pretrained=False, **kwargs):
         PyramidVisionTransformerV2,
         variant,
         pretrained,
-        pretrained_filter_fn=_checkpoint_filter_fn,
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         **kwargs,
     )
@@ -15,7 +15,7 @@ Hacked together by / Copyright 2020, Ross Wightman
 """
 import math
 from functools import partial
-from typing import List, Optional, Tuple, Type, Union
+from typing import Dict, List, Optional, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
@@ -24,10 +24,11 @@ import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, Format, nchw_to
 
+from ._builder import build_model_with_cfg
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 from .resnet import resnet26d, resnet50d
 from .resnetv2 import ResNetV2, create_resnetv2_stem
-from .vision_transformer import _create_vision_transformer, VisionTransformer
+from .vision_transformer import VisionTransformer
 
 
 class HybridEmbed(nn.Module):
@@ -159,22 +160,26 @@ class HybridEmbedWithSize(nn.Module):
     """
     def __init__(
             self,
-            backbone,
-            img_size=224,
-            patch_size=1,
-            feature_size=None,
-            in_chans=3,
-            embed_dim=768,
+            backbone: nn.Module,
+            img_size: Union[int, Tuple[int, int]] = 224,
+            patch_size: Union[int, Tuple[int, int]] = 1,
+            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
+            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
+            in_chans: int = 3,
+            embed_dim: int = 768,
             bias=True,
             proj=True,
     ):
         super().__init__(
             backbone=backbone,
             img_size=img_size,
             patch_size=patch_size,
             feature_size=feature_size,
+            feature_ratio=feature_ratio,
             in_chans=in_chans,
             embed_dim=embed_dim,
             bias=bias,
             proj=proj,
         )
 
     @torch.jit.ignore
@@ -206,12 +211,8 @@ class ConvStem(nn.Sequential):
     ):
         super().__init__()
         if isinstance(channels, int):
-            if depth == 4:
-                channels = (channels // 8, channels // 4, channels // 2, channels)
-            elif depth == 3:
-                channels = (channels // 4, channels // 2, channels)
-            else:
-                channels = to_ntuple(depth)(channels)
+            # a default tiered channel strategy
+            channels = tuple([channels // 2**i for i in range(depth)][::-1])
 
         kernel_size = to_ntuple(depth)(kernel_size)
         padding = to_ntuple(depth)(padding)
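The collapsed branch reproduces the old depth-3 and depth-4 special cases and extends the same halving ladder to other depths. A small stand-alone check (example values only):

```python
# Stand-alone sketch: the generalized tiered-channel expression matches the
# old depth == 3 / depth == 4 branches; the old fallback simply repeated the
# scalar via to_ntuple, so other depths now get a halving ladder instead.
def tiered(channels: int, depth: int) -> tuple:
    return tuple([channels // 2**i for i in range(depth)][::-1])

assert tiered(64, 4) == (8, 16, 32, 64)   # old: (channels // 8, ..., channels)
assert tiered(64, 3) == (16, 32, 64)      # old: (channels // 4, ..., channels)
print(tiered(64, 2))                      # (32, 64); old fallback gave (64, 64)
```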
@@ -235,13 +236,6 @@ class ConvStem(nn.Sequential):
             in_chs = channels[i]
 
 
-def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs):
-    embed_args = embed_args or {}
-    embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args)
-    kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set
-    return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs)
-
-
 def _resnetv2(layers=(3, 4, 9), **kwargs):
     """ ResNet-V2 backbone helper"""
     padding_same = kwargs.get('padding_same', True)
@@ -257,6 +251,66 @@ def _resnetv2(layers=(3, 4, 9), **kwargs):
     return backbone
 
 
+def _convert_mobileclip(state_dict, model, prefix='image_encoder.model.'):
+    out = {}
+    for k, v in state_dict.items():
+        if not k.startswith(prefix):
+            continue
+        k = k.replace(prefix, '')
+        k = k.replace('patch_emb.', 'patch_embed.backbone.')
+        k = k.replace('block.conv', 'conv')
+        k = k.replace('block.norm', 'bn')
+        k = k.replace('post_transformer_norm.', 'norm.')
+        k = k.replace('pre_norm_mha.0', 'norm1')
+        k = k.replace('pre_norm_mha.1', 'attn')
+        k = k.replace('pre_norm_ffn.0', 'norm2')
+        k = k.replace('pre_norm_ffn.1', 'mlp.fc1')
+        k = k.replace('pre_norm_ffn.4', 'mlp.fc2')
+        k = k.replace('qkv_proj.', 'qkv.')
+        k = k.replace('out_proj.', 'proj.')
+        k = k.replace('transformer.', 'blocks.')
+        if k == 'pos_embed.pos_embed.pos_embed':
+            k = 'pos_embed'
+            v = v.squeeze(0)
+        if 'classifier.proj' in k:
+            bias_k = k.replace('classifier.proj', 'head.bias')
+            k = k.replace('classifier.proj', 'head.weight')
+            v = v.T
+            out[bias_k] = torch.zeros(v.shape[0])
+        out[k] = v
+    return out
+
+
+def checkpoint_filter_fn(
+        state_dict: Dict[str, torch.Tensor],
+        model: VisionTransformer,
+        interpolation: str = 'bicubic',
+        antialias: bool = True,
+) -> Dict[str, torch.Tensor]:
+    from .vision_transformer import checkpoint_filter_fn as _filter_fn
+
+    if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:
+        state_dict = _convert_mobileclip(state_dict, model)
+
+    return _filter_fn(state_dict, model, interpolation=interpolation, antialias=antialias)
+
+
+def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs):
+    out_indices = kwargs.pop('out_indices', 3)
+    embed_args = embed_args or {}
+    embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args)
+    kwargs.setdefault('embed_layer', embed_layer)
+    kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set
+    return build_model_with_cfg(
+        VisionTransformer,
+        variant,
+        pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
+        **kwargs,
+    )
+
+
 def _cfg(url='', **kwargs):
     return {
         'url': url,
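A rough stand-alone illustration of the remapping above, run on a couple of made-up MobileCLIP-style keys (dummy tensors; only image-tower keys survive):

```python
# Illustrative sketch of _convert_mobileclip on made-up MobileCLIP-style keys
# (dummy tensor values; the model argument is not used by the converter itself).
import torch
from timm.models.vision_transformer_hybrid import _convert_mobileclip

toy = {
    'image_encoder.model.patch_emb.0.block.conv.weight': torch.zeros(1),
    'image_encoder.model.transformer.0.pre_norm_mha.1.qkv_proj.weight': torch.zeros(1),
    'text_encoder.embedding.weight': torch.zeros(1),  # non image-tower keys are dropped
}
print(sorted(_convert_mobileclip(toy, model=None)))
# ['blocks.0.attn.qkv.weight', 'patch_embed.backbone.0.conv.weight']
```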
@@ -331,6 +385,17 @@ default_cfgs = generate_default_cfgs({
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
     'vit_base_resnet50d_224.untrained': _cfg(
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
+
+    'vit_base_mci_224.apple_mclip': _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt',
+        num_classes=512,
+        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.conv1.0',
+    ),
+    'vit_base_mci_224.apple_mclip_lt': _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt',
+        num_classes=512,
+        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.conv1.0',
+    ),
 })
 
 
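With these cfgs registered, the MobileCLIP-B image tower loads through timm's standard pretrained mechanism. A minimal usage sketch (requires network access to the Apple URL above):

```python
# Minimal usage sketch for the new pretrained cfgs; the checkpoint is fetched
# from the registered Apple URL on first use.
import timm

model = timm.create_model('vit_base_mci_224.apple_mclip', pretrained=True)
model.eval()

# preprocessing follows the cfg: zero mean / unit std, i.e. plain [0, 1] scaling
data_cfg = timm.data.resolve_model_data_config(model)
transform = timm.data.create_transform(**data_cfg, is_training=False)
```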
@@ -491,7 +556,7 @@ def vit_base_mci_224(pretrained=False, **kwargs) -> VisionTransformer:
     )
     model_args = dict(embed_dim=768, depth=12, num_heads=12, no_embed_class=True)
     model = _create_vision_transformer_hybrid(
-        'vit_base_resnet50d_224', backbone=backbone, embed_args=dict(proj=False),
+        'vit_base_mci_224', backbone=backbone, embed_args=dict(proj=False),
         pretrained=pretrained, **dict(model_args, **kwargs)
     )
     return model
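A side effect of routing the hybrid create fn through build_model_with_cfg with a feature_cfg: features_only extraction should now be available for hybrid ViT variants as well. An untested sketch against one of the existing hybrid model names:

```python
# Untested sketch: feature extraction from a hybrid ViT via the new
# feature_cfg(out_indices=..., feature_cls='getter') path.
import timm
import torch

m = timm.create_model('vit_small_r26_s32_224', features_only=True, out_indices=(0, 1, 2))
feats = m(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])
```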