Add quickgelu vit clip variants, simplify get_norm_layer and allow string args in vit norm/act. Add metaclip CLIP weights

Authored by Ross Wightman on 2023-11-02 17:18:17 -07:00; committed by Ross Wightman
parent c55bc41a42
commit a2e4a4c148
4 changed files with 167 additions and 46 deletions

View File

@@ -157,3 +157,17 @@ class GELUTanh(nn.Module):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return F.gelu(input, approximate='tanh')
 
+
+def quick_gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
+    return x * torch.sigmoid(1.702 * x)
+
+
+class QuickGELU(nn.Module):
+    """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
+    """
+
+    def __init__(self, inplace: bool = False):
+        super(QuickGELU, self).__init__()
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return quick_gelu(input)
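Outside the diff: a quick numeric sketch of how quick_gelu relates to the exact GELU it approximates (the sigmoid form is what the original OpenAI CLIP weights were trained with):

    import torch
    import torch.nn.functional as F

    x = torch.linspace(-3, 3, steps=7)
    approx = x * torch.sigmoid(1.702 * x)   # quick_gelu, as defined above
    exact = F.gelu(x)                       # exact GELU
    print((approx - exact).abs().max())     # small but non-zero, on the order of 0.02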

View File

@@ -29,6 +29,7 @@ _ACT_FN_DEFAULT = dict(
     selu=F.selu,
     gelu=gelu,
     gelu_tanh=gelu_tanh,
+    quick_gelu=quick_gelu,
     sigmoid=sigmoid,
     tanh=tanh,
     hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid,
@@ -42,7 +43,7 @@ _ACT_FN_JIT = dict(
     mish=F.mish if _has_mish else mish_jit,
     hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,
     hard_swish=F.hardswish if _has_hardswish else hard_swish_jit,
-    hard_mish=hard_mish_jit
+    hard_mish=hard_mish_jit,
 )
 
 _ACT_FN_ME = dict(
@@ -73,6 +74,7 @@ _ACT_LAYER_DEFAULT = dict(
     selu=nn.SELU,
     gelu=GELU,
     gelu_tanh=GELUTanh,
+    quick_gelu=QuickGELU,
     sigmoid=Sigmoid,
     tanh=Tanh,
     hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid,
@@ -87,7 +89,7 @@ _ACT_LAYER_JIT = dict(
     mish=nn.Mish if _has_mish else MishJit,
     hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,
     hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit,
-    hard_mish=HardMishJit
+    hard_mish=HardMishJit,
 )
 
 _ACT_LAYER_ME = dict(
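Outside the diff, a minimal usage sketch: with the map entries above, the 'quick_gelu' string resolves through the existing activation factory helpers.

    from timm.layers import get_act_layer, create_act_layer

    act_cls = get_act_layer('quick_gelu')   # -> QuickGELU class from _ACT_LAYER_DEFAULT
    act = create_act_layer('quick_gelu')    # -> QuickGELU() module instance
    print(act_cls.__name__, act)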

View File

@@ -4,12 +4,14 @@ Create norm modules by string (to mirror create_act and creat_norm-act fns)
 
 Copyright 2022 Ross Wightman
 """
-import types
 import functools
+import types
+from typing import Type
 
 import torch.nn as nn
 
-from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
+from torchvision.ops import FrozenBatchNorm2d
 
 _NORM_MAP = dict(
     batchnorm=nn.BatchNorm2d,
@@ -19,6 +21,8 @@ _NORM_MAP = dict(
     groupnorm1=GroupNorm1,
     layernorm=LayerNorm,
     layernorm2d=LayerNorm2d,
+    rmsnorm=RmsNorm,
+    frozenbatchnorm2d=FrozenBatchNorm2d,
 )
 
 _NORM_TYPES = {m for n, m in _NORM_MAP.items()}
@@ -30,7 +34,10 @@ def create_norm_layer(layer_name, num_features, **kwargs):
 
 
 def get_norm_layer(norm_layer):
-    assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
+    if not norm_layer:
+        # None or '' should return None
+        return None
+    assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
     norm_kwargs = {}
 
     # unbind partial fn, so args can be rebound later
@@ -40,16 +47,9 @@ def get_norm_layer(norm_layer):
 
     if isinstance(norm_layer, str):
         layer_name = norm_layer.replace('_', '')
-        norm_layer = _NORM_MAP.get(layer_name, None)
-    elif norm_layer in _NORM_TYPES:
-        norm_layer = norm_layer
-    elif isinstance(norm_layer, types.FunctionType):
-        # if function type, assume it is a lambda/fn that creates a norm layer
-        norm_layer = norm_layer
+        norm_layer = _NORM_MAP[layer_name]
     else:
-        type_name = norm_layer.__name__.lower().replace('_', '')
-        norm_layer = _NORM_MAP.get(type_name, None)
-        assert norm_layer is not None, f"No equivalent norm layer for {type_name}"
+        norm_layer = norm_layer
 
     if norm_kwargs:
         norm_layer = functools.partial(norm_layer, **norm_kwargs)  # bind/rebind args
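Outside the diff, a short sketch of the simplified behaviour: get_norm_layer now returns None for None/'' inputs, resolves strings (including the new 'rmsnorm' and 'frozenbatchnorm2d' entries) via _NORM_MAP, and passes classes, functions and partials straight through, rebinding any partial kwargs.

    from functools import partial

    import torch.nn as nn
    from timm.layers import get_norm_layer

    print(get_norm_layer(None))                              # None / '' pass through as None
    print(get_norm_layer('rmsnorm'))                         # string lookup via _NORM_MAP
    print(get_norm_layer(partial(nn.LayerNorm, eps=1e-6)))   # partial is unbound, kwargs rebound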

View File

@@ -27,7 +27,7 @@ import logging
 import math
 from collections import OrderedDict
 from functools import partial
-from typing import Callable, List, Optional, Sequence, Tuple, Union
+from typing import Callable, List, Optional, Sequence, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
@@ -35,10 +35,12 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.jit import Final
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
-    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn
+    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
+    get_act_layer, get_norm_layer, LayerType
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -414,10 +416,10 @@ class VisionTransformer(nn.Module):
             drop_path_rate: float = 0.,
             weight_init: str = '',
             embed_layer: Callable = PatchEmbed,
-            norm_layer: Optional[Callable] = None,
-            act_layer: Optional[Callable] = None,
-            block_fn: Callable = Block,
-            mlp_layer: Callable = Mlp,
+            norm_layer: Optional[LayerType] = None,
+            act_layer: Optional[LayerType] = None,
+            block_fn: Type[nn.Module] = Block,
+            mlp_layer: Type[nn.Module] = Mlp,
     ):
         """
         Args:
@@ -450,8 +452,8 @@ class VisionTransformer(nn.Module):
         assert global_pool in ('', 'avg', 'token', 'map')
         assert class_token or global_pool != 'token'
         use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
-        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
-        act_layer = act_layer or nn.GELU
+        norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
+        act_layer = get_act_layer(act_layer) or nn.GELU
 
         self.num_classes = num_classes
         self.global_pool = global_pool
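Outside the diff, a minimal sketch of what the LayerType hints plus get_norm_layer/get_act_layer resolution enable: norm_layer and act_layer can now be passed as plain strings, which create_model forwards to VisionTransformer.__init__ ('vit_base_patch16_224' is just an example variant here).

    import timm

    model = timm.create_model(
        'vit_base_patch16_224',
        pretrained=False,
        norm_layer='layernorm',     # resolved via get_norm_layer
        act_layer='quick_gelu',     # resolved via get_act_layer
    )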
@@ -1415,38 +1417,14 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='laion/CLIP-ViT-B-16-laion2B-s34B-b88K',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
-    'vit_base_patch16_clip_224.datacompxl': _cfg(
-        hf_hub_id='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K',
-        hf_hub_filename='open_clip_pytorch_model.bin',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
-    'vit_base_patch16_clip_224.dfn2b': _cfg(
-        hf_hub_id='apple/DFN2B-CLIP-ViT-B-16',
-        hf_hub_filename='open_clip_pytorch_model.bin',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
     'vit_large_patch14_clip_224.laion2b': _cfg(
         hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=768),
-    'vit_large_patch14_clip_224.datacompxl': _cfg(
-        hf_hub_id='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K',
-        hf_hub_filename='open_clip_pytorch_model.bin',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
-    'vit_large_patch14_clip_224.dfn2b': _cfg(
-        hf_hub_id='apple/DFN2B-CLIP-ViT-L-14',
-        hf_hub_filename='open_clip_pytorch_model.bin',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
     'vit_huge_patch14_clip_224.laion2b': _cfg(
         hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
-    'vit_huge_patch14_clip_224.dfn5b': _cfg(
-        hf_hub_id='apple/DFN5B-CLIP-ViT-H-14',
-        hf_hub_filename='open_clip_pytorch_model.bin',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
-    'vit_huge_patch14_clip_378.dfn5b': _cfg(
-        hf_hub_id='apple/DFN5B-CLIP-ViT-H-14-378',
-        hf_hub_filename='open_clip_pytorch_model.bin',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
     'vit_giant_patch14_clip_224.laion2b': _cfg(
         hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
         hf_hub_filename='open_clip_pytorch_model.bin',
@@ -1456,6 +1434,59 @@ default_cfgs = generate_default_cfgs({
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),
+    'vit_base_patch32_clip_224.datacompxl': _cfg(
+        hf_hub_id='laion/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
+    'vit_base_patch32_clip_256.datacompxl': _cfg(
+        hf_hub_id='laion/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=1.0, input_size=(3, 256, 256), num_classes=512),
+    'vit_base_patch16_clip_224.datacompxl': _cfg(
+        hf_hub_id='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
+    'vit_large_patch14_clip_224.datacompxl': _cfg(
+        hf_hub_id='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
+    'vit_base_patch16_clip_224.dfn2b': _cfg(
+        hf_hub_id='apple/DFN2B-CLIP-ViT-B-16',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
+    'vit_large_patch14_clip_224.dfn2b': _cfg(
+        hf_hub_id='apple/DFN2B-CLIP-ViT-L-14',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
+    'vit_huge_patch14_clip_224.dfn5b': _cfg(
+        hf_hub_id='apple/DFN5B-CLIP-ViT-H-14',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
+    'vit_huge_patch14_clip_378.dfn5b': _cfg(
+        hf_hub_id='apple/DFN5B-CLIP-ViT-H-14-378',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024),
+    'vit_base_patch32_clip_224.metaclip_2pt5b': _cfg(
+        hf_hub_id='facebook/metaclip-b32-fullcc2.5b',
+        hf_hub_filename='metaclip_b32_fullcc2.5b.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
+    'vit_base_patch16_clip_224.metaclip_2pt5b': _cfg(
+        hf_hub_id='facebook/metaclip-b16-fullcc2.5b',
+        hf_hub_filename='metaclip_b16_fullcc2.5b.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
+    'vit_large_patch14_clip_224.metaclip_2pt5b': _cfg(
+        hf_hub_id='facebook/metaclip-l14-fullcc2.5b',
+        hf_hub_filename='metaclip_l14_fullcc2.5b.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
+    'vit_huge_patch14_clip_224.metaclip_2pt5b': _cfg(
+        hf_hub_id='facebook/metaclip-h14-fullcc2.5b',
+        hf_hub_filename='metaclip_h14_fullcc2.5b.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
+
     'vit_base_patch32_clip_224.openai': _cfg(
         hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
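Outside the diff: the new MetaCLIP tags point at the facebook/metaclip-* checkpoints on the HF hub, so loading one of these image towers as a feature extractor looks roughly like this.

    import timm

    model = timm.create_model(
        'vit_base_patch16_clip_224.metaclip_2pt5b',
        pretrained=True,
        num_classes=0,   # no classifier/projection head, return pooled image features
    )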
@@ -2078,6 +2109,80 @@ def vit_gigantic_patch14_clip_224(pretrained=False, **kwargs) -> VisionTransformer:
         'vit_gigantic_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
+
+
+@register_model
+def vit_base_patch32_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-B/32 CLIP image tower @ 224x224
+    """
+    model_args = dict(
+        patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True,
+        norm_layer=nn.LayerNorm, act_layer='quick_gelu')
+    model = _create_vision_transformer(
+        'vit_base_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch16_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-B/16 CLIP image tower w/ QuickGELU act
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True,
+        norm_layer=nn.LayerNorm, act_layer='quick_gelu')
+    model = _create_vision_transformer(
+        'vit_base_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch14_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Large model (ViT-L/14) CLIP image tower w/ QuickGELU act
+    """
+    from timm.layers import get_act_layer
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True,
+        norm_layer=nn.LayerNorm, act_layer='quick_gelu')
+    model = _create_vision_transformer(
+        'vit_large_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch14_clip_quickgelu_336(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336 w/ QuickGELU act
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True,
+        norm_layer=nn.LayerNorm, act_layer='quick_gelu')
+    model = _create_vision_transformer(
+        'vit_large_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_huge_patch14_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/14) CLIP image tower w/ QuickGELU act.
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True,
+        norm_layer=nn.LayerNorm, act_layer='quick_gelu')
+    model = _create_vision_transformer(
+        'vit_huge_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_huge_patch14_clip_quickgelu_378(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378 w/ QuickGELU act
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True,
+        norm_layer=nn.LayerNorm, act_layer='quick_gelu')
+    model = _create_vision_transformer(
+        'vit_huge_patch14_clip_378', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
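Outside the diff, a minimal sketch of the new variants: the quickgelu registrations reuse the existing CLIP ViT shapes but swap in the QuickGELU activation used by the OpenAI-style weights, so instantiating one (untrained) and checking its MLP activation looks like this.

    import timm

    model = timm.create_model('vit_base_patch16_clip_quickgelu_224', pretrained=False)
    print(model.blocks[0].mlp.act)   # QuickGELU()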
 
 
 # Experimental models below
 
 @register_model