Add quickgelu vit clip variants, simplify get_norm_layer and allow string args in vit norm/act. Add metaclip CLIP weights

pull/2020/head
Ross Wightman 2023-11-02 17:18:17 -07:00 committed by Ross Wightman
parent c55bc41a42
commit a2e4a4c148
4 changed files with 167 additions and 46 deletions

View File

@ -157,3 +157,17 @@ class GELUTanh(nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.gelu(input, approximate='tanh')
def quick_gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
return x * torch.sigmoid(1.702 * x)
class QuickGELU(nn.Module):
"""Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
"""
def __init__(self, inplace: bool = False):
super(QuickGELU, self).__init__()
def forward(self, input: torch.Tensor) -> torch.Tensor:
return quick_gelu(input)

View File

@ -29,6 +29,7 @@ _ACT_FN_DEFAULT = dict(
selu=F.selu,
gelu=gelu,
gelu_tanh=gelu_tanh,
quick_gelu=quick_gelu,
sigmoid=sigmoid,
tanh=tanh,
hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid,
@ -42,7 +43,7 @@ _ACT_FN_JIT = dict(
mish=F.mish if _has_mish else mish_jit,
hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,
hard_swish=F.hardswish if _has_hardswish else hard_swish_jit,
hard_mish=hard_mish_jit
hard_mish=hard_mish_jit,
)
_ACT_FN_ME = dict(
@ -73,6 +74,7 @@ _ACT_LAYER_DEFAULT = dict(
selu=nn.SELU,
gelu=GELU,
gelu_tanh=GELUTanh,
quick_gelu=QuickGELU,
sigmoid=Sigmoid,
tanh=Tanh,
hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid,
@ -87,7 +89,7 @@ _ACT_LAYER_JIT = dict(
mish=nn.Mish if _has_mish else MishJit,
hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,
hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit,
hard_mish=HardMishJit
hard_mish=HardMishJit,
)
_ACT_LAYER_ME = dict(

View File

@ -4,12 +4,14 @@ Create norm modules by string (to mirror create_act and creat_norm-act fns)
Copyright 2022 Ross Wightman
"""
import types
import functools
import types
from typing import Type
import torch.nn as nn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
from torchvision.ops import FrozenBatchNorm2d
_NORM_MAP = dict(
batchnorm=nn.BatchNorm2d,
@ -19,6 +21,8 @@ _NORM_MAP = dict(
groupnorm1=GroupNorm1,
layernorm=LayerNorm,
layernorm2d=LayerNorm2d,
rmsnorm=RmsNorm,
frozenbatchnorm2d=FrozenBatchNorm2d,
)
_NORM_TYPES = {m for n, m in _NORM_MAP.items()}
@ -30,7 +34,10 @@ def create_norm_layer(layer_name, num_features, **kwargs):
def get_norm_layer(norm_layer):
assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
if not norm_layer:
# None or '' should return None
return None
assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
norm_kwargs = {}
# unbind partial fn, so args can be rebound later
@ -40,16 +47,9 @@ def get_norm_layer(norm_layer):
if isinstance(norm_layer, str):
layer_name = norm_layer.replace('_', '')
norm_layer = _NORM_MAP.get(layer_name, None)
elif norm_layer in _NORM_TYPES:
norm_layer = norm_layer
elif isinstance(norm_layer, types.FunctionType):
# if function type, assume it is a lambda/fn that creates a norm layer
norm_layer = norm_layer
norm_layer = _NORM_MAP[layer_name]
else:
type_name = norm_layer.__name__.lower().replace('_', '')
norm_layer = _NORM_MAP.get(type_name, None)
assert norm_layer is not None, f"No equivalent norm layer for {type_name}"
norm_layer = norm_layer
if norm_kwargs:
norm_layer = functools.partial(norm_layer, **norm_kwargs) # bind/rebind args

View File

@ -27,7 +27,7 @@ import logging
import math
from collections import OrderedDict
from functools import partial
from typing import Callable, List, Optional, Sequence, Tuple, Union
from typing import Callable, List, Optional, Sequence, Tuple, Type, Union
import torch
import torch.nn as nn
@ -35,10 +35,12 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from torch.jit import Final
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn
trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
get_act_layer, get_norm_layer, LayerType
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@ -414,10 +416,10 @@ class VisionTransformer(nn.Module):
drop_path_rate: float = 0.,
weight_init: str = '',
embed_layer: Callable = PatchEmbed,
norm_layer: Optional[Callable] = None,
act_layer: Optional[Callable] = None,
block_fn: Callable = Block,
mlp_layer: Callable = Mlp,
norm_layer: Optional[LayerType] = None,
act_layer: Optional[LayerType] = None,
block_fn: Type[nn.Module] = Block,
mlp_layer: Type[nn.Module] = Mlp,
):
"""
Args:
@ -450,8 +452,8 @@ class VisionTransformer(nn.Module):
assert global_pool in ('', 'avg', 'token', 'map')
assert class_token or global_pool != 'token'
use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
act_layer = get_act_layer(act_layer) or nn.GELU
self.num_classes = num_classes
self.global_pool = global_pool
@ -1415,38 +1417,14 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='laion/CLIP-ViT-B-16-laion2B-s34B-b88K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_base_patch16_clip_224.datacompxl': _cfg(
hf_hub_id='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_base_patch16_clip_224.dfn2b': _cfg(
hf_hub_id='apple/DFN2B-CLIP-ViT-B-16',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_large_patch14_clip_224.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=768),
'vit_large_patch14_clip_224.datacompxl': _cfg(
hf_hub_id='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
'vit_large_patch14_clip_224.dfn2b': _cfg(
hf_hub_id='apple/DFN2B-CLIP-ViT-L-14',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
'vit_huge_patch14_clip_224.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
'vit_huge_patch14_clip_224.dfn5b': _cfg(
hf_hub_id='apple/DFN5B-CLIP-ViT-H-14',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
'vit_huge_patch14_clip_378.dfn5b': _cfg(
hf_hub_id='apple/DFN5B-CLIP-ViT-H-14-378',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
'vit_giant_patch14_clip_224.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
hf_hub_filename='open_clip_pytorch_model.bin',
@ -1456,6 +1434,59 @@ default_cfgs = generate_default_cfgs({
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),
'vit_base_patch32_clip_224.datacompxl': _cfg(
hf_hub_id='laion/',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_base_patch32_clip_256.datacompxl': _cfg(
hf_hub_id='laion/',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=1.0, input_size=(3, 256, 256), num_classes=512),
'vit_base_patch16_clip_224.datacompxl': _cfg(
hf_hub_id='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_large_patch14_clip_224.datacompxl': _cfg(
hf_hub_id='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
'vit_base_patch16_clip_224.dfn2b': _cfg(
hf_hub_id='apple/DFN2B-CLIP-ViT-B-16',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_large_patch14_clip_224.dfn2b': _cfg(
hf_hub_id='apple/DFN2B-CLIP-ViT-L-14',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
'vit_huge_patch14_clip_224.dfn5b': _cfg(
hf_hub_id='apple/DFN5B-CLIP-ViT-H-14',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
'vit_huge_patch14_clip_378.dfn5b': _cfg(
hf_hub_id='apple/DFN5B-CLIP-ViT-H-14-378',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024),
'vit_base_patch32_clip_224.metaclip_2pt5b': _cfg(
hf_hub_id='facebook/metaclip-b32-fullcc2.5b',
hf_hub_filename='metaclip_b32_fullcc2.5b.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_base_patch16_clip_224.metaclip_2pt5b': _cfg(
hf_hub_id='facebook/metaclip-b16-fullcc2.5b',
hf_hub_filename='metaclip_b16_fullcc2.5b.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_large_patch14_clip_224.metaclip_2pt5b': _cfg(
hf_hub_id='facebook/metaclip-l14-fullcc2.5b',
hf_hub_filename='metaclip_l14_fullcc2.5b.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
'vit_huge_patch14_clip_224.metaclip_2pt5b': _cfg(
hf_hub_id='facebook/metaclip-h14-fullcc2.5b',
hf_hub_filename='metaclip_h14_fullcc2.5b.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
'vit_base_patch32_clip_224.openai': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
@ -2078,6 +2109,80 @@ def vit_gigantic_patch14_clip_224(pretrained=False, **kwargs) -> VisionTransform
'vit_gigantic_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_base_patch32_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-B/32 CLIP image tower @ 224x224
"""
model_args = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True,
norm_layer=nn.LayerNorm, act_layer='quick_gelu')
model = _create_vision_transformer(
'vit_base_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_base_patch16_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-B/16 CLIP image tower w/ QuickGELU act
"""
model_args = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True,
norm_layer=nn.LayerNorm, act_layer='quick_gelu')
model = _create_vision_transformer(
'vit_base_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_large_patch14_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-Large model (ViT-L/14) CLIP image tower w/ QuickGELU act
"""
from timm.layers import get_act_layer
model_args = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True,
norm_layer=nn.LayerNorm, act_layer='quick_gelu')
model = _create_vision_transformer(
'vit_large_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_large_patch14_clip_quickgelu_336(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-Large model (ViT-L/14) CLIP image tower @ 336x336 w/ QuickGELU act
"""
model_args = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True,
norm_layer=nn.LayerNorm, act_layer='quick_gelu')
model = _create_vision_transformer(
'vit_large_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_huge_patch14_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-Huge model (ViT-H/14) CLIP image tower w/ QuickGELU act.
"""
model_args = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True,
norm_layer=nn.LayerNorm, act_layer='quick_gelu')
model = _create_vision_transformer(
'vit_huge_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_huge_patch14_clip_quickgelu_378(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378 w/ QuickGELU act
"""
model_args = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True,
norm_layer=nn.LayerNorm, act_layer='quick_gelu')
model = _create_vision_transformer(
'vit_huge_patch14_clip_378', pretrained=pretrained, **dict(model_args, **kwargs))
return model
# Experimental models below
@register_model