Add quickgelu vit clip variants, simplify get_norm_layer and allow string args in vit norm/act. Add metaclip CLIP weights
parent c55bc41a42
commit a2e4a4c148
timm/layers/activations.py

@@ -157,3 +157,17 @@ class GELUTanh(nn.Module):

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return F.gelu(input, approximate='tanh')
+
+
+def quick_gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
+    return x * torch.sigmoid(1.702 * x)
+
+
+class QuickGELU(nn.Module):
+    """Applies the quick (sigmoid-approximated) GELU function, x * sigmoid(1.702 * x) (w/ dummy inplace arg)
+    """
+    def __init__(self, inplace: bool = False):
+        super(QuickGELU, self).__init__()
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return quick_gelu(input)
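For context, quick_gelu is the sigmoid-based GELU approximation used by the original OpenAI CLIP image towers, which is why the QuickGELU ViT variants below need it. A minimal standalone sketch comparing it to PyTorch's exact GELU (not part of the diff itself):

    import torch
    import torch.nn.functional as F

    def quick_gelu(x: torch.Tensor) -> torch.Tensor:
        # same formula as the new timm helper: x * sigmoid(1.702 * x)
        return x * torch.sigmoid(1.702 * x)

    x = torch.linspace(-3, 3, steps=7)
    print(quick_gelu(x))                             # sigmoid approximation
    print(F.gelu(x))                                 # exact GELU
    print((quick_gelu(x) - F.gelu(x)).abs().max())   # small but nonzero difference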
timm/layers/create_act.py

@@ -29,6 +29,7 @@ _ACT_FN_DEFAULT = dict(
     selu=F.selu,
     gelu=gelu,
     gelu_tanh=gelu_tanh,
+    quick_gelu=quick_gelu,
     sigmoid=sigmoid,
     tanh=tanh,
     hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid,
@@ -42,7 +43,7 @@ _ACT_FN_JIT = dict(
     mish=F.mish if _has_mish else mish_jit,
     hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,
     hard_swish=F.hardswish if _has_hardswish else hard_swish_jit,
-    hard_mish=hard_mish_jit
+    hard_mish=hard_mish_jit,
 )

 _ACT_FN_ME = dict(
@@ -73,6 +74,7 @@ _ACT_LAYER_DEFAULT = dict(
     selu=nn.SELU,
     gelu=GELU,
     gelu_tanh=GELUTanh,
+    quick_gelu=QuickGELU,
     sigmoid=Sigmoid,
     tanh=Tanh,
     hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid,
@@ -87,7 +89,7 @@ _ACT_LAYER_JIT = dict(
     mish=nn.Mish if _has_mish else MishJit,
     hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,
     hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit,
-    hard_mish=HardMishJit
+    hard_mish=HardMishJit,
 )

 _ACT_LAYER_ME = dict(
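With quick_gelu/QuickGELU registered in the lookup tables above, the activation can be resolved by name through timm's existing activation factory helpers. A small sketch, assuming a timm build that includes this change:

    import torch
    from timm.layers import get_act_layer, create_act_layer

    act_cls = get_act_layer('quick_gelu')    # -> QuickGELU class
    act = create_act_layer('quick_gelu')     # -> QuickGELU() instance
    print(act_cls.__name__, act(torch.randn(2, 4)).shape)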
timm/layers/create_norm.py

@@ -4,12 +4,14 @@ Create norm modules by string (to mirror create_act and creat_norm-act fns)

 Copyright 2022 Ross Wightman
 """
-import types
 import functools
+import types
+from typing import Type

 import torch.nn as nn

-from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
 from torchvision.ops import FrozenBatchNorm2d

 _NORM_MAP = dict(
     batchnorm=nn.BatchNorm2d,
@@ -19,6 +21,8 @@ _NORM_MAP = dict(
     groupnorm1=GroupNorm1,
     layernorm=LayerNorm,
     layernorm2d=LayerNorm2d,
+    rmsnorm=RmsNorm,
+    frozenbatchnorm2d=FrozenBatchNorm2d,
 )
 _NORM_TYPES = {m for n, m in _NORM_MAP.items()}

@@ -30,7 +34,10 @@ def create_norm_layer(layer_name, num_features, **kwargs):


 def get_norm_layer(norm_layer):
-    assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
+    if not norm_layer:
+        # None or '' should return None
+        return None
+    assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
     norm_kwargs = {}

     # unbind partial fn, so args can be rebound later
@@ -40,16 +47,9 @@ def get_norm_layer(norm_layer):

     if isinstance(norm_layer, str):
         layer_name = norm_layer.replace('_', '')
-        norm_layer = _NORM_MAP.get(layer_name, None)
-    elif norm_layer in _NORM_TYPES:
-        norm_layer = norm_layer
-    elif isinstance(norm_layer, types.FunctionType):
-        # if function type, assume it is a lambda/fn that creates a norm layer
-        norm_layer = norm_layer
+        norm_layer = _NORM_MAP[layer_name]
     else:
         type_name = norm_layer.__name__.lower().replace('_', '')
         norm_layer = _NORM_MAP.get(type_name, None)
         assert norm_layer is not None, f"No equivalent norm layer for {type_name}"
-        norm_layer = norm_layer

     if norm_kwargs:
         norm_layer = functools.partial(norm_layer, **norm_kwargs)  # bind/rebind args
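The simplified get_norm_layer now returns None for None/'' input, resolves strings directly against _NORM_MAP (including the newly mapped 'rmsnorm'), and still rebinds functools.partial keyword args onto the resolved type. A usage sketch, assuming this commit:

    import functools
    import torch.nn as nn
    from timm.layers import get_norm_layer

    print(get_norm_layer(None))        # None/'' pass straight through as None
    print(get_norm_layer('rmsnorm'))   # string name -> RmsNorm class
    ln = get_norm_layer(functools.partial(nn.LayerNorm, eps=1e-6))
    print(ln)                          # eps kwarg rebound onto the mapped timm LayerNorm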
timm/models/vision_transformer.py

@@ -27,7 +27,7 @@ import logging
 import math
 from collections import OrderedDict
 from functools import partial
-from typing import Callable, List, Optional, Sequence, Tuple, Union
+from typing import Callable, List, Optional, Sequence, Tuple, Type, Union

 import torch
 import torch.nn as nn
@@ -35,10 +35,12 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.jit import Final

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
-    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn
+    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
+    get_act_layer, get_norm_layer, LayerType
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -414,10 +416,10 @@ class VisionTransformer(nn.Module):
             drop_path_rate: float = 0.,
             weight_init: str = '',
             embed_layer: Callable = PatchEmbed,
-            norm_layer: Optional[Callable] = None,
-            act_layer: Optional[Callable] = None,
-            block_fn: Callable = Block,
-            mlp_layer: Callable = Mlp,
+            norm_layer: Optional[LayerType] = None,
+            act_layer: Optional[LayerType] = None,
+            block_fn: Type[nn.Module] = Block,
+            mlp_layer: Type[nn.Module] = Mlp,
     ):
         """
         Args:
@@ -450,8 +452,8 @@ class VisionTransformer(nn.Module):
         assert global_pool in ('', 'avg', 'token', 'map')
         assert class_token or global_pool != 'token'
         use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
-        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
-        act_layer = act_layer or nn.GELU
+        norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
+        act_layer = get_act_layer(act_layer) or nn.GELU

         self.num_classes = num_classes
         self.global_pool = global_pool
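With the constructor routing norm_layer/act_layer through get_norm_layer/get_act_layer, the ViT now accepts layer specs as strings as well as callables/partials. A small sketch with an illustrative tiny config (the sizes here are not from the diff):

    import torch
    from timm.models.vision_transformer import VisionTransformer

    model = VisionTransformer(
        img_size=224, patch_size=16, embed_dim=192, depth=2, num_heads=3,
        norm_layer='layernorm', act_layer='quick_gelu', num_classes=10)
    out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 10])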
timm/models/vision_transformer.py

@@ -1415,38 +1417,14 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='laion/CLIP-ViT-B-16-laion2B-s34B-b88K',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
-    'vit_base_patch16_clip_224.datacompxl': _cfg(
-        hf_hub_id='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K',
-        hf_hub_filename='open_clip_pytorch_model.bin',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
-    'vit_base_patch16_clip_224.dfn2b': _cfg(
-        hf_hub_id='apple/DFN2B-CLIP-ViT-B-16',
-        hf_hub_filename='open_clip_pytorch_model.bin',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
     'vit_large_patch14_clip_224.laion2b': _cfg(
         hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=768),
-    'vit_large_patch14_clip_224.datacompxl': _cfg(
-        hf_hub_id='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K',
-        hf_hub_filename='open_clip_pytorch_model.bin',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
-    'vit_large_patch14_clip_224.dfn2b': _cfg(
-        hf_hub_id='apple/DFN2B-CLIP-ViT-L-14',
-        hf_hub_filename='open_clip_pytorch_model.bin',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
     'vit_huge_patch14_clip_224.laion2b': _cfg(
         hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
-    'vit_huge_patch14_clip_224.dfn5b': _cfg(
-        hf_hub_id='apple/DFN5B-CLIP-ViT-H-14',
-        hf_hub_filename='open_clip_pytorch_model.bin',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
-    'vit_huge_patch14_clip_378.dfn5b': _cfg(
-        hf_hub_id='apple/DFN5B-CLIP-ViT-H-14-378',
-        hf_hub_filename='open_clip_pytorch_model.bin',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
     'vit_giant_patch14_clip_224.laion2b': _cfg(
         hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
         hf_hub_filename='open_clip_pytorch_model.bin',
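The cfg entries removed above are not dropped; the next hunk re-adds them grouped by pretraining source (DataComp-XL, DFN, MetaCLIP), so existing pretrained tags keep resolving. A quick check, assuming a timm build with this commit:

    import timm

    # tags registered for the B/16 CLIP tower after the reshuffle
    print(timm.list_pretrained('vit_base_patch16_clip_224.*'))
    # expected to include .laion2b, .datacompxl, .dfn2b, .metaclip_2pt5b, .openai, ...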
timm/models/vision_transformer.py

@@ -1456,6 +1434,59 @@ default_cfgs = generate_default_cfgs({
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),
+
+    'vit_base_patch32_clip_224.datacompxl': _cfg(
+        hf_hub_id='laion/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
+    'vit_base_patch32_clip_256.datacompxl': _cfg(
+        hf_hub_id='laion/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=1.0, input_size=(3, 256, 256), num_classes=512),
+    'vit_base_patch16_clip_224.datacompxl': _cfg(
+        hf_hub_id='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
+    'vit_large_patch14_clip_224.datacompxl': _cfg(
+        hf_hub_id='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
+
+    'vit_base_patch16_clip_224.dfn2b': _cfg(
+        hf_hub_id='apple/DFN2B-CLIP-ViT-B-16',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
+    'vit_large_patch14_clip_224.dfn2b': _cfg(
+        hf_hub_id='apple/DFN2B-CLIP-ViT-L-14',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
+    'vit_huge_patch14_clip_224.dfn5b': _cfg(
+        hf_hub_id='apple/DFN5B-CLIP-ViT-H-14',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
+    'vit_huge_patch14_clip_378.dfn5b': _cfg(
+        hf_hub_id='apple/DFN5B-CLIP-ViT-H-14-378',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024),
+
+    'vit_base_patch32_clip_224.metaclip_2pt5b': _cfg(
+        hf_hub_id='facebook/metaclip-b32-fullcc2.5b',
+        hf_hub_filename='metaclip_b32_fullcc2.5b.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
+    'vit_base_patch16_clip_224.metaclip_2pt5b': _cfg(
+        hf_hub_id='facebook/metaclip-b16-fullcc2.5b',
+        hf_hub_filename='metaclip_b16_fullcc2.5b.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
+    'vit_large_patch14_clip_224.metaclip_2pt5b': _cfg(
+        hf_hub_id='facebook/metaclip-l14-fullcc2.5b',
+        hf_hub_filename='metaclip_l14_fullcc2.5b.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
+    'vit_huge_patch14_clip_224.metaclip_2pt5b': _cfg(
+        hf_hub_id='facebook/metaclip-h14-fullcc2.5b',
+        hf_hub_filename='metaclip_h14_fullcc2.5b.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),

     'vit_base_patch32_clip_224.openai': _cfg(
         hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
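Each new entry is an HF-hub-backed pretrained config on an existing CLIP tower name, so it is created the usual way; note that num_classes in these cfgs is the width of the CLIP image projection, not a classifier. A sketch (pretrained=True downloads weights from the HF hub):

    import timm
    import torch

    model = timm.create_model(
        'vit_base_patch16_clip_224.datacompxl', pretrained=True).eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 512]) -- CLIP image projection

    # num_classes=0 removes the projection head and returns pooled features
    backbone = timm.create_model(
        'vit_base_patch16_clip_224.datacompxl', pretrained=True, num_classes=0)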
timm/models/vision_transformer.py

@@ -2078,6 +2109,80 @@ def vit_gigantic_patch14_clip_224(pretrained=False, **kwargs) -> VisionTransformer:
         'vit_gigantic_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


+@register_model
+def vit_base_patch32_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-B/32 CLIP image tower @ 224x224 w/ QuickGELU act
+    """
+    model_args = dict(
+        patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True,
+        norm_layer=nn.LayerNorm, act_layer='quick_gelu')
+    model = _create_vision_transformer(
+        'vit_base_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch16_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-B/16 CLIP image tower w/ QuickGELU act
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True,
+        norm_layer=nn.LayerNorm, act_layer='quick_gelu')
+    model = _create_vision_transformer(
+        'vit_base_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch14_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Large model (ViT-L/14) CLIP image tower w/ QuickGELU act
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True,
+        norm_layer=nn.LayerNorm, act_layer='quick_gelu')
+    model = _create_vision_transformer(
+        'vit_large_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch14_clip_quickgelu_336(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336 w/ QuickGELU act
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True,
+        norm_layer=nn.LayerNorm, act_layer='quick_gelu')
+    model = _create_vision_transformer(
+        'vit_large_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_huge_patch14_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/14) CLIP image tower w/ QuickGELU act.
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True,
+        norm_layer=nn.LayerNorm, act_layer='quick_gelu')
+    model = _create_vision_transformer(
+        'vit_huge_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_huge_patch14_clip_quickgelu_378(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378 w/ QuickGELU act
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True,
+        norm_layer=nn.LayerNorm, act_layer='quick_gelu')
+    model = _create_vision_transformer(
+        'vit_huge_patch14_clip_378', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 # Experimental models below

 @register_model
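As registered, the quickgelu entrypoints pass the corresponding non-quickgelu variant name to _create_vision_transformer, so they share the pretrained cfgs added above; the intent appears to be that QuickGELU-trained weights such as MetaCLIP are loaded through these variants via the usual name.tag syntax. A hedged sketch (requires HF hub access):

    import timm
    import torch

    # MetaCLIP B/32 weights on the QuickGELU tower variant registered above
    model = timm.create_model(
        'vit_base_patch32_clip_quickgelu_224.metaclip_2pt5b', pretrained=True).eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 512])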