From 88a1006e025c1a4e39fb2b4db7f8ad8cb85ae88f Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 6 Jun 2024 12:38:52 -0700
Subject: [PATCH] checkpoint filter fns with consistent name, add mobileclip-b pretrained cfgs

---
 timm/models/beit.py                      |   4 +-
 timm/models/efficientformer.py           |   4 +-
 timm/models/fastvit.py                   |   4 +-
 timm/models/pvt_v2.py                    |   4 +-
 timm/models/vision_transformer_hybrid.py | 109 ++++++++++++++++++-----
 5 files changed, 95 insertions(+), 30 deletions(-)

diff --git a/timm/models/beit.py b/timm/models/beit.py
index 63b6db54..922d15e7 100644
--- a/timm/models/beit.py
+++ b/timm/models/beit.py
@@ -591,7 +591,7 @@ default_cfgs = generate_default_cfgs({
 })
 
 
-def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True):
+def checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True):
     state_dict = state_dict.get('model', state_dict)
     state_dict = state_dict.get('module', state_dict)  # beit v2 didn't strip module
 
@@ -637,7 +637,7 @@ def _create_beit(variant, pretrained=False, **kwargs):
     out_indices = kwargs.pop('out_indices', 3)
     model = build_model_with_cfg(
         Beit, variant, pretrained,
-        pretrained_filter_fn=_beit_checkpoint_filter_fn,
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )
diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py
index c28538bc..32630683 100644
--- a/timm/models/efficientformer.py
+++ b/timm/models/efficientformer.py
@@ -556,7 +556,7 @@ class EfficientFormer(nn.Module):
         return x
 
 
-def _checkpoint_filter_fn(state_dict, model):
+def checkpoint_filter_fn(state_dict, model):
     """ Remap original checkpoints -> timm """
     if 'stem.0.weight' in state_dict:
         return state_dict  # non-original checkpoint, no remapping needed
@@ -611,7 +611,7 @@ def _create_efficientformer(variant, pretrained=False, **kwargs):
     out_indices = kwargs.pop('out_indices', 4)
     model = build_model_with_cfg(
         EfficientFormer, variant, pretrained,
-        pretrained_filter_fn=_checkpoint_filter_fn,
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )
diff --git a/timm/models/fastvit.py b/timm/models/fastvit.py
index 66142105..ef7ec3c9 100644
--- a/timm/models/fastvit.py
+++ b/timm/models/fastvit.py
@@ -1414,7 +1414,7 @@ default_cfgs = generate_default_cfgs({
 })
 
 
-def _checkpoint_filter_fn(state_dict, model):
+def checkpoint_filter_fn(state_dict, model):
     """ Remap original checkpoints -> timm """
     if 'stem.0.conv_kxk.0.conv.weight' in state_dict:
         return state_dict  # non-original checkpoint, no remapping needed
@@ -1493,7 +1493,7 @@ def _create_fastvit(variant, pretrained=False, **kwargs):
         FastVit,
         variant,
         pretrained,
-        pretrained_filter_fn=_checkpoint_filter_fn,
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         **kwargs
     )
diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py
index 1d9c6842..90ebfe7a 100644
--- a/timm/models/pvt_v2.py
+++ b/timm/models/pvt_v2.py
@@ -403,7 +403,7 @@ class PyramidVisionTransformerV2(nn.Module):
         return x
 
 
-def _checkpoint_filter_fn(state_dict, model):
+def checkpoint_filter_fn(state_dict, model):
     """ Remap original checkpoints -> timm """
    if 'patch_embed.proj.weight' in state_dict:
        return state_dict  # non-original checkpoint, no remapping needed
@@ -430,7 +430,7 @@ def _create_pvt2(variant, pretrained=False, **kwargs):
         PyramidVisionTransformerV2,
         variant,
         pretrained,
-        pretrained_filter_fn=_checkpoint_filter_fn,
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         **kwargs,
     )
diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py
index af51fa98..3501565c 100644
--- a/timm/models/vision_transformer_hybrid.py
+++ b/timm/models/vision_transformer_hybrid.py
@@ -15,7 +15,7 @@ Hacked together by / Copyright 2020, Ross Wightman
 """
 import math
 from functools import partial
-from typing import List, Optional, Tuple, Type, Union
+from typing import Dict, List, Optional, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
@@ -24,10 +24,11 @@ import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, Format, nchw_to
+from ._builder import build_model_with_cfg
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 from .resnet import resnet26d, resnet50d
 from .resnetv2 import ResNetV2, create_resnetv2_stem
-from .vision_transformer import _create_vision_transformer, VisionTransformer
+from .vision_transformer import VisionTransformer
 
 
 class HybridEmbed(nn.Module):
@@ -159,22 +160,26 @@ class HybridEmbedWithSize(nn.Module):
     """
     def __init__(
             self,
-            backbone,
-            img_size=224,
-            patch_size=1,
-            feature_size=None,
-            in_chans=3,
-            embed_dim=768,
+            backbone: nn.Module,
+            img_size: Union[int, Tuple[int, int]] = 224,
+            patch_size: Union[int, Tuple[int, int]] = 1,
+            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
+            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
+            in_chans: int = 3,
+            embed_dim: int = 768,
             bias=True,
+            proj=True,
     ):
         super().__init__(
             backbone=backbone,
             img_size=img_size,
             patch_size=patch_size,
             feature_size=feature_size,
+            feature_ratio=feature_ratio,
             in_chans=in_chans,
             embed_dim=embed_dim,
             bias=bias,
+            proj=proj,
         )
 
     @torch.jit.ignore
@@ -206,12 +211,8 @@ class ConvStem(nn.Sequential):
     ):
         super().__init__()
         if isinstance(channels, int):
-            if depth == 4:
-                channels = (channels // 8, channels // 4, channels // 2, channels)
-            elif depth == 3:
-                channels = (channels // 4, channels // 2, channels)
-            else:
-                channels = to_ntuple(depth)(channels)
+            # a default tiered channel strategy
+            channels = tuple([channels // 2**i for i in range(depth)][::-1])
 
         kernel_size = to_ntuple(depth)(kernel_size)
         padding = to_ntuple(depth)(padding)
@@ -235,13 +236,6 @@ class ConvStem(nn.Sequential):
             in_chs = channels[i]
 
 
-def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs):
-    embed_args = embed_args or {}
-    embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args)
-    kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set
-    return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs)
-
-
 def _resnetv2(layers=(3, 4, 9), **kwargs):
     """ ResNet-V2 backbone helper"""
     padding_same = kwargs.get('padding_same', True)
@@ -257,6 +251,66 @@ def _resnetv2(layers=(3, 4, 9), **kwargs):
     return backbone
 
 
+def _convert_mobileclip(state_dict, model, prefix='image_encoder.model.'):
+    out = {}
+    for k, v in state_dict.items():
+        if not k.startswith(prefix):
+            continue
+        k = k.replace(prefix, '')
+        k = k.replace('patch_emb.', 'patch_embed.backbone.')
+        k = k.replace('block.conv', 'conv')
+        k = k.replace('block.norm', 'bn')
+        k = k.replace('post_transformer_norm.', 'norm.')
+        k = k.replace('pre_norm_mha.0', 'norm1')
+        k = k.replace('pre_norm_mha.1', 'attn')
+        k = k.replace('pre_norm_ffn.0', 'norm2')
+        k = k.replace('pre_norm_ffn.1', 'mlp.fc1')
+        k = k.replace('pre_norm_ffn.4', 'mlp.fc2')
+        k = k.replace('qkv_proj.', 'qkv.')
+        k = k.replace('out_proj.', 'proj.')
+        k = k.replace('transformer.', 'blocks.')
+        if k == 'pos_embed.pos_embed.pos_embed':
+            k = 'pos_embed'
+            v = v.squeeze(0)
+        if 'classifier.proj' in k:
+            bias_k = k.replace('classifier.proj', 'head.bias')
+            k = k.replace('classifier.proj', 'head.weight')
+            v = v.T
+            out[bias_k] = torch.zeros(v.shape[0])
+        out[k] = v
+    return out
+
+
+def checkpoint_filter_fn(
+        state_dict: Dict[str, torch.Tensor],
+        model: VisionTransformer,
+        interpolation: str = 'bicubic',
+        antialias: bool = True,
+) -> Dict[str, torch.Tensor]:
+    from .vision_transformer import checkpoint_filter_fn as _filter_fn
+
+    if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:
+        state_dict = _convert_mobileclip(state_dict, model)
+
+    return _filter_fn(state_dict, model, interpolation=interpolation, antialias=antialias)
+
+
+def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs):
+    out_indices = kwargs.pop('out_indices', 3)
+    embed_args = embed_args or {}
+    embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args)
+    kwargs.setdefault('embed_layer', embed_layer)
+    kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set
+    return build_model_with_cfg(
+        VisionTransformer,
+        variant,
+        pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
+        **kwargs,
+    )
+
+
 def _cfg(url='', **kwargs):
     return {
         'url': url,
@@ -331,6 +385,17 @@ default_cfgs = generate_default_cfgs({
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
     'vit_base_resnet50d_224.untrained': _cfg(
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
+
+    'vit_base_mci_224.apple_mclip': _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt',
+        num_classes=512,
+        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.conv1.0',
+    ),
+    'vit_base_mci_224.apple_mclip_lt': _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt',
+        num_classes=512,
+        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.conv1.0',
+    ),
 })
 
 
@@ -491,7 +556,7 @@ def vit_base_mci_224(pretrained=False, **kwargs) -> VisionTransformer:
     )
     model_args = dict(embed_dim=768, depth=12, num_heads=12, no_embed_class=True)
     model = _create_vision_transformer_hybrid(
-        'vit_base_resnet50d_224', backbone=backbone, embed_args=dict(proj=False),
+        'vit_base_mci_224', backbone=backbone, embed_args=dict(proj=False),
         pretrained=pretrained, **dict(model_args, **kwargs)
     )
     return model
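
Usage sketch (not part of the diff; assumes the patch is applied and network access for the weight download): the `vit_base_mci_224.apple_mclip*` tags added to `default_cfgs` above should load through the standard `timm.create_model()` factory, with `checkpoint_filter_fn` detecting the original Apple `image_encoder.model.*` keys and `_convert_mobileclip` remapping them, including transposing `classifier.proj` into a 512-dim `head`.

    import timm
    import torch

    # MobileCLIP-B image tower; pretrained=True pulls the checkpoint from the
    # apple_mclip URL in default_cfgs and remaps it via checkpoint_filter_fn
    model = timm.create_model('vit_base_mci_224.apple_mclip', pretrained=True)
    model = model.eval()

    # cfg mean/std are (0., 0., 0.) / (1., 1., 1.), i.e. inputs scaled to [0, 1]
    x = torch.rand(1, 3, 224, 224)
    with torch.no_grad():
        emb = model(x)
    print(emb.shape)  # torch.Size([1, 512]) -- CLIP image embedding via the remapped head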