checkpoint filter fns with consistent name, add mobileclip-b pretrained cfgs

Ross Wightman 2024-06-06 12:38:52 -07:00
parent 7d4ada6d16
commit 88a1006e02
5 changed files with 95 additions and 30 deletions

timm/models/beit.py

@@ -591,7 +591,7 @@ default_cfgs = generate_default_cfgs({
 })
 
 
-def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True):
+def checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True):
     state_dict = state_dict.get('model', state_dict)
     state_dict = state_dict.get('module', state_dict)
     # beit v2 didn't strip module
@@ -637,7 +637,7 @@ def _create_beit(variant, pretrained=False, **kwargs):
     out_indices = kwargs.pop('out_indices', 3)
     model = build_model_with_cfg(
         Beit, variant, pretrained,
-        pretrained_filter_fn=_beit_checkpoint_filter_fn,
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )

timm/models/efficientformer.py

@@ -556,7 +556,7 @@ class EfficientFormer(nn.Module):
         return x
 
 
-def _checkpoint_filter_fn(state_dict, model):
+def checkpoint_filter_fn(state_dict, model):
     """ Remap original checkpoints -> timm """
     if 'stem.0.weight' in state_dict:
         return state_dict  # non-original checkpoint, no remapping needed
@@ -611,7 +611,7 @@ def _create_efficientformer(variant, pretrained=False, **kwargs):
     out_indices = kwargs.pop('out_indices', 4)
     model = build_model_with_cfg(
         EfficientFormer, variant, pretrained,
-        pretrained_filter_fn=_checkpoint_filter_fn,
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )

timm/models/fastvit.py

@@ -1414,7 +1414,7 @@ default_cfgs = generate_default_cfgs({
 })
 
 
-def _checkpoint_filter_fn(state_dict, model):
+def checkpoint_filter_fn(state_dict, model):
     """ Remap original checkpoints -> timm """
     if 'stem.0.conv_kxk.0.conv.weight' in state_dict:
         return state_dict  # non-original checkpoint, no remapping needed
@@ -1493,7 +1493,7 @@ def _create_fastvit(variant, pretrained=False, **kwargs):
         FastVit,
         variant,
         pretrained,
-        pretrained_filter_fn=_checkpoint_filter_fn,
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         **kwargs
     )

timm/models/pvt_v2.py

@@ -403,7 +403,7 @@ class PyramidVisionTransformerV2(nn.Module):
         return x
 
 
-def _checkpoint_filter_fn(state_dict, model):
+def checkpoint_filter_fn(state_dict, model):
     """ Remap original checkpoints -> timm """
     if 'patch_embed.proj.weight' in state_dict:
         return state_dict  # non-original checkpoint, no remapping needed
@@ -430,7 +430,7 @@ def _create_pvt2(variant, pretrained=False, **kwargs):
         PyramidVisionTransformerV2,
         variant,
         pretrained,
-        pretrained_filter_fn=_checkpoint_filter_fn,
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         **kwargs,
     )

timm/models/vision_transformer_hybrid.py

@@ -15,7 +15,7 @@ Hacked together by / Copyright 2020, Ross Wightman
 """
 import math
 from functools import partial
-from typing import List, Optional, Tuple, Type, Union
+from typing import Dict, List, Optional, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
@@ -24,10 +24,11 @@ import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, Format, nchw_to
+from ._builder import build_model_with_cfg
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 from .resnet import resnet26d, resnet50d
 from .resnetv2 import ResNetV2, create_resnetv2_stem
-from .vision_transformer import _create_vision_transformer, VisionTransformer
+from .vision_transformer import VisionTransformer
 
 
 class HybridEmbed(nn.Module):
@@ -159,22 +160,26 @@ class HybridEmbedWithSize(nn.Module):
     """
     def __init__(
             self,
-            backbone,
-            img_size=224,
-            patch_size=1,
-            feature_size=None,
-            in_chans=3,
-            embed_dim=768,
+            backbone: nn.Module,
+            img_size: Union[int, Tuple[int, int]] = 224,
+            patch_size: Union[int, Tuple[int, int]] = 1,
+            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
+            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
+            in_chans: int = 3,
+            embed_dim: int = 768,
             bias=True,
             proj=True,
     ):
         super().__init__(
             backbone=backbone,
             img_size=img_size,
             patch_size=patch_size,
             feature_size=feature_size,
+            feature_ratio=feature_ratio,
             in_chans=in_chans,
             embed_dim=embed_dim,
             bias=bias,
             proj=proj,
         )
 
     @torch.jit.ignore
@@ -206,12 +211,8 @@ class ConvStem(nn.Sequential):
     ):
         super().__init__()
         if isinstance(channels, int):
-            if depth == 4:
-                channels = (channels // 8, channels // 4, channels // 2, channels)
-            elif depth == 3:
-                channels = (channels // 4, channels // 2, channels)
-            else:
-                channels = to_ntuple(depth)(channels)
+            # a default tiered channel strategy
+            channels = tuple([channels // 2**i for i in range(depth)][::-1])
 
         kernel_size = to_ntuple(depth)(kernel_size)
         padding = to_ntuple(depth)(padding)
@@ -235,13 +236,6 @@
             in_chs = channels[i]
 
 
-def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs):
-    embed_args = embed_args or {}
-    embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args)
-    kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set
-    return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs)
-
-
 def _resnetv2(layers=(3, 4, 9), **kwargs):
     """ ResNet-V2 backbone helper"""
     padding_same = kwargs.get('padding_same', True)
@@ -257,6 +251,66 @@ def _resnetv2(layers=(3, 4, 9), **kwargs):
     return backbone
 
 
+def _convert_mobileclip(state_dict, model, prefix='image_encoder.model.'):
+    out = {}
+    for k, v in state_dict.items():
+        if not k.startswith(prefix):
+            continue
+        k = k.replace(prefix, '')
+        k = k.replace('patch_emb.', 'patch_embed.backbone.')
+        k = k.replace('block.conv', 'conv')
+        k = k.replace('block.norm', 'bn')
+        k = k.replace('post_transformer_norm.', 'norm.')
+        k = k.replace('pre_norm_mha.0', 'norm1')
+        k = k.replace('pre_norm_mha.1', 'attn')
+        k = k.replace('pre_norm_ffn.0', 'norm2')
+        k = k.replace('pre_norm_ffn.1', 'mlp.fc1')
+        k = k.replace('pre_norm_ffn.4', 'mlp.fc2')
+        k = k.replace('qkv_proj.', 'qkv.')
+        k = k.replace('out_proj.', 'proj.')
+        k = k.replace('transformer.', 'blocks.')
+        if k == 'pos_embed.pos_embed.pos_embed':
+            k = 'pos_embed'
+            v = v.squeeze(0)
+        if 'classifier.proj' in k:
+            bias_k = k.replace('classifier.proj', 'head.bias')
+            k = k.replace('classifier.proj', 'head.weight')
+            v = v.T
+            out[bias_k] = torch.zeros(v.shape[0])
+        out[k] = v
+    return out
+
+
+def checkpoint_filter_fn(
+        state_dict: Dict[str, torch.Tensor],
+        model: VisionTransformer,
+        interpolation: str = 'bicubic',
+        antialias: bool = True,
+) -> Dict[str, torch.Tensor]:
+    from .vision_transformer import checkpoint_filter_fn as _filter_fn
+
+    if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:
+        state_dict = _convert_mobileclip(state_dict, model)
+
+    return _filter_fn(state_dict, model, interpolation=interpolation, antialias=antialias)
+
+
+def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs):
+    out_indices = kwargs.pop('out_indices', 3)
+    embed_args = embed_args or {}
+    embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args)
+    kwargs.setdefault('embed_layer', embed_layer)
+    kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set
+    return build_model_with_cfg(
+        VisionTransformer,
+        variant,
+        pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
+        **kwargs,
+    )
+
+
 def _cfg(url='', **kwargs):
     return {
         'url': url,
@@ -331,6 +385,17 @@ default_cfgs = generate_default_cfgs({
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
     'vit_base_resnet50d_224.untrained': _cfg(
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
 
+    'vit_base_mci_224.apple_mclip': _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt',
+        num_classes=512,
+        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.conv1.0',
+    ),
+    'vit_base_mci_224.apple_mclip_lt': _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt',
+        num_classes=512,
+        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.conv1.0',
+    ),
+
 
 })
@@ -491,7 +556,7 @@ def vit_base_mci_224(pretrained=False, **kwargs) -> VisionTransformer:
     )
     model_args = dict(embed_dim=768, depth=12, num_heads=12, no_embed_class=True)
     model = _create_vision_transformer_hybrid(
-        'vit_base_resnet50d_224', backbone=backbone, embed_args=dict(proj=False),
+        'vit_base_mci_224', backbone=backbone, embed_args=dict(proj=False),
        pretrained=pretrained, **dict(model_args, **kwargs)
    )
    return model
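
With these cfgs registered, the MobileCLIP-B image tower loads like any other timm model; the renamed module-level checkpoint_filter_fn remaps the original Apple weights on the fly. A minimal usage sketch, not part of the diff, assuming a timm install at or after this commit and network access to the checkpoint URL:

import timm
import torch

# pretrained=True downloads mobileclip_b.pt and passes it through
# checkpoint_filter_fn / _convert_mobileclip before loading into the hybrid ViT.
model = timm.create_model('vit_base_mci_224.apple_mclip', pretrained=True)
model = model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # num_classes=512 in the cfg, so expect torch.Size([1, 512])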