From 99d4c7d2027b67653d8e5fb15aec210283c2551d Mon Sep 17 00:00:00 2001 From: Beckschen Date: Sun, 5 May 2024 02:50:14 -0400 Subject: [PATCH 1/4] add ViTamin models --- timm/models/__init__.py | 1 + timm/models/vitamin.py | 561 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 562 insertions(+) create mode 100644 timm/models/vitamin.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 9d09efac..9c7bee6f 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -70,6 +70,7 @@ from .vision_transformer import * from .vision_transformer_hybrid import * from .vision_transformer_relpos import * from .vision_transformer_sam import * +from .vitamin import * from .volo import * from .vovnet import * from .xception import * diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py new file mode 100644 index 00000000..3eecb8db --- /dev/null +++ b/timm/models/vitamin.py @@ -0,0 +1,561 @@ +""" ViTamin + +Paper: Designing Scalable Vison Models in the Vision-Language Era + +@misc{chen2023designing, + title={Designing Scalable Vison Models in the Vision-Language Era}, + author={Jieneng Chen and Qihang Yu and Xiaohui Shen and Alan Yuille and Liang-Cheih Chen}, + year={2023}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} + +Based on Apache 2.0 licensed code at https://github.com/ViTamin/ViTamin + +Modifications and timm support by Jieneng Chen 2023 + +Reference: +https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py +https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer_hybrid.py +""" + +from functools import partial +from typing import List, Tuple +from dataclasses import dataclass, replace +from typing import Callable, Optional, Union, Tuple, List, Sequence +import math, time +from torch.jit import Final +import torch +import torch.nn as nn +import torch.nn.functional as F +import timm +from timm.layers import to_2tuple +from torch.utils.checkpoint import checkpoint +from timm.models.layers import create_attn, get_norm_layer, get_norm_act_layer, create_conv2d, make_divisible, trunc_normal_tf_ + +from timm.models._registry import register_model +from timm.layers import DropPath +from timm.layers.norm_act import _create_act + +from timm.models._manipulate import named_apply, checkpoint_seq +from timm.models._builder import build_model_with_cfg +from timm.models.vision_transformer import VisionTransformer, checkpoint_filter_fn +from timm.models.vision_transformer_hybrid import HybridEmbed + + +@dataclass +class VitConvCfg: + expand_ratio: float = 4.0 + expand_output: bool = True # calculate expansion channels from output (vs input chs) + kernel_size: int = 3 + group_size: int = 1 # 1 == depthwise + pre_norm_act: bool = False # activation after pre-norm + stride_mode: str = 'dw' # stride done via one of 'pool', '1x1', 'dw' + pool_type: str = 'avg2' + downsample_pool_type: str = 'avg2' + act_layer: str = 'gelu' # stem & stage 1234 + act_layer1: str = 'gelu' # stage 1234 + act_layer2: str = 'gelu' # stage 1234 + norm_layer: str = '' + norm_layer_cl: str = '' + norm_eps: Optional[float] = None + down_shortcut: Optional[bool] = True + mlp: str = 'mlp' + + def __post_init__(self): + # mbconv vs convnext blocks have different defaults, set in post_init to avoid explicit config args + use_mbconv = True + if not self.norm_layer: + self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d' + if not self.norm_layer_cl and not use_mbconv: + self.norm_layer_cl = 'layernorm' + if self.norm_eps is None: + self.norm_eps = 1e-5 if use_mbconv else 1e-6 + self.downsample_pool_type = self.downsample_pool_type or self.pool_type + +@dataclass +class VitCfg: + # embed_dim: Tuple[int, ...] = (96, 192, 384, 768) + embed_dim: Tuple[Union[int, Tuple[int, ...]], ...] = (96, 192, 384, 768) + depths: Tuple[Union[int, Tuple[int, ...]], ...] = (2, 3, 5, 2) + stem_width: int = 64 + conv_cfg: VitConvCfg = VitConvCfg() + weight_init: str = 'vit_eff' + head_type: str = "" + stem_type: str = "stem" + ln2d_permute: bool = True + # memory_format: str="" + + +def _init_conv(module, name, scheme=''): + if isinstance(module, nn.Conv2d): + fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels + fan_out //= module.groups + nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out)) + if module.bias is not None: + nn.init.zeros_(module.bias) + +class Stem(nn.Module): + def __init__( + self, + in_chs: int, + out_chs: int, + act_layer: str = 'gelu', + norm_layer: str = 'layernorm2d', + norm_eps: float = 1e-6, + bias: bool = True, + ): + super().__init__() + self.grad_checkpointing=False + norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps) + self.out_chs = out_chs + self.conv1 = create_conv2d(in_chs, out_chs, 3, stride=2, bias=bias) + self.norm1 = norm_act_layer(out_chs) + self.conv2 = create_conv2d(out_chs, out_chs, 3, stride=1, bias=bias) + named_apply(_init_conv, self) + + def forward(self, x): + if self.grad_checkpointing: + x = checkpoint(self.conv1, x) + x = self.norm1(x) + x = checkpoint(self.conv2, x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.conv2(x) + + return x + +class Downsample2d(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + pool_type: str = 'avg2', + bias: bool = True, + ): + super().__init__() + self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False) + + + if dim != dim_out: + self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias) # 1x1 conv + else: + self.expand = nn.Identity() + + def forward(self, x): + x = self.pool(x) # spatial downsample + x = self.expand(x) # expand chs + return x + + +class StridedConv(nn.Module): + """ downsample 2d as well + """ + def __init__( + self, + kernel_size=3, + stride=2, + padding=1, + in_chans=3, + embed_dim=768, + ln2d_permute=True + ): + super().__init__() + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) + self.permute = ln2d_permute # TODO: disable + norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6) + self.norm = norm_layer(in_chans) # affine over C + + def forward(self, x): + x = self.norm(x) + x = self.proj(x) + return x + + +class MbConvLNBlock(nn.Module): + """ Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand) + """ + def __init__( + self, + in_chs: int, + out_chs: int, + stride: int = 1, + drop_path: float = 0., + kernel_size: int = 3, + norm_layer: str = 'layernorm2d', + norm_eps: float = 1e-6, + act_layer: str = 'gelu', + expand_ratio: float = 4.0, + ): + super(MbConvLNBlock, self).__init__() + self.stride, self.in_chs, self.out_chs = stride, in_chs, out_chs + mid_chs = make_divisible(out_chs * expand_ratio) + prenorm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps) + + if stride == 2: + self.shortcut = Downsample2d(in_chs, out_chs, pool_type='avg', bias=True) + elif in_chs != out_chs: + self.shortcut = nn.Conv2d(in_chs, out_chs, 1, bias=True) + else: + self.shortcut = nn.Identity() + + self.pre_norm = prenorm_act_layer(in_chs, apply_act=False) + self.down = nn.Identity() + self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=1, bias=True) + self.act1 = _create_act(act_layer, inplace=True) + self.act2 = _create_act(act_layer, inplace=True) + + self.conv2_kxk = create_conv2d(mid_chs, mid_chs, kernel_size, stride=stride, dilation=1, groups=mid_chs, bias=True) + self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=True) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + + def init_weights(self, scheme=''): + named_apply(partial(_init_conv, scheme=scheme), self) + + def forward(self, x): + shortcut = self.shortcut(x) + + x = self.pre_norm(x) + x = self.down(x) # nn.Identity() + + # 1x1 expansion conv & act + x = self.conv1_1x1(x) + x = self.act1(x) + + # (strided) depthwise 3x3 conv & act + x = self.conv2_kxk(x) + x = self.act2(x) + + # 1x1 linear projection to output width + x = self.conv3_1x1(x) + x = self.drop_path(x) + shortcut + + return x + + +class MbConvStages(nn.Module): + """ MobileConv for stage 1 and stage 2 of ViTamin + """ + def __init__( + self, + cfg: VitCfg, + img_size: Union[int, Tuple[int, int]] = 224, # place holder + in_chans: int = 3, + ): + super().__init__() + self.grad_checkpointing = False + self.stem = Stem( + in_chs=in_chans, + out_chs=cfg.stem_width, + ) + stages = [] + self.num_stages = len(cfg.embed_dim) + for s, dim in enumerate(cfg.embed_dim[:2]): # stage + blocks = [] + stage_in_chs = cfg.embed_dim[s-1] if s>0 else cfg.stem_width + for d in range(cfg.depths[s]): + blocks += [MbConvLNBlock( + in_chs = stage_in_chs if d==0 else dim, + out_chs = dim, + stride = 2 if d == 0 else 1, + # cfg = cfg.conv_cfg, + )] + blocks = nn.Sequential(*blocks) + stages += [blocks] + + self.stages = nn.ModuleList(stages) + self.pool = StridedConv( + stride=2, + in_chans=cfg.embed_dim[1], + embed_dim=cfg.embed_dim[2] + ) + + def forward(self, x): + x = self.stem(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + for stage in self.stages: + x = checkpoint_seq(stage, x) + x = checkpoint(self.pool, x) + else: + for stage in self.stages: + x = stage(x) + x = self.pool(x) + + return x + +class GeGluMlp(nn.Module): + def __init__( + self, + in_features, + hidden_features, + act_layer = None, + drop = 0.0, + ): + super().__init__() + norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6) + self.norm = norm_layer(in_features) + self.act = nn.GELU() + self.w0 = nn.Linear(in_features, hidden_features) + self.w1 = nn.Linear(in_features, hidden_features) + self.w2 = nn.Linear(hidden_features, in_features) + + def forward(self, x): + x = self.norm(x) + x = self.act(self.w0(x)) * self.w1(x) + x = self.w2(x) + return x + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + def __init__( + self, + backbone, + img_size=224, + patch_size=1, + feature_size=None, + in_chans=3, + embed_dim=1024, + bias=True, + dynamic_img_pad=False, + ): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.backbone = backbone + with torch.no_grad(): + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) + if isinstance(o, (list, tuple)): + o = o[-1] # last feature if backbone outputs list/tuple of features + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + + assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 + self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.proj = nn.Identity() + + def forward(self, x): + x = self.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + x = self.proj(x) + x = x.flatten(2).transpose(1, 2) + return x + +def _create_vision_transformer(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + if 'flexi' in variant: + # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed + # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation. + _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False) + else: + _filter_fn = checkpoint_filter_fn + + return build_model_with_cfg( + VisionTransformer, + variant, + pretrained, + pretrained_filter_fn=_filter_fn, + **kwargs, + ) + + +def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs): + embed_layer = partial(HybridEmbed, backbone=backbone) + kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set + return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs) + + +@register_model +def vitamin_small(pretrained=False, **kwargs) -> VisionTransformer: + stage_1_2 = MbConvStages(cfg=VitCfg( + embed_dim=(64, 128, 384), + depths=(2, 4, 1), + stem_width=64, + conv_cfg = VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ), + ) + stage3_args = dict(embed_dim=384, depth=14, num_heads=6, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') + model = _create_vision_transformer_hybrid('vitamin_small', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs)) + return model + + +@register_model +def vitamin_base(pretrained=False, **kwargs) -> VisionTransformer: + stage_1_2 = MbConvStages(cfg=VitCfg( + embed_dim=(128, 256, 768), + depths=(2, 4, 1), + stem_width=128, + conv_cfg = VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ), + ) + stage3_args = dict(embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') + model = _create_vision_transformer_hybrid('vitamin_base', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs)) + return model + + +@register_model +def vitamin_large(pretrained=False, **kwargs) -> VisionTransformer: + stage_1_2 = MbConvStages(cfg=VitCfg( + embed_dim=(160, 320, 1024), + depths=(2, 4, 1), + stem_width=160, + conv_cfg = VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ), + ) + stage3_args = dict(embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') + model = _create_vision_transformer_hybrid( + 'vitamin_large', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs)) + return model + +@register_model +def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer: + backbone = MbConvStages(cfg=VitCfg( + embed_dim=(160, 320, 1024), + depths=(2, 4, 1), + stem_width=160, + conv_cfg = VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ), + ) + model_args = dict(img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') + model = _create_vision_transformer_hybrid( + 'vitamin_large_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) + return model + +@register_model +def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer: + backbone = MbConvStages(cfg=VitCfg( + embed_dim=(160, 320, 1024), + depths=(2, 4, 1), + stem_width=160, + conv_cfg = VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ), + ) + model_args = dict(img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') + model = _create_vision_transformer_hybrid( + 'vitamin_large_336', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) + return model + +@register_model +def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer: + backbone = MbConvStages(cfg=VitCfg( + embed_dim=(160, 320, 1024), + depths=(2, 4, 1), + stem_width=160, + conv_cfg = VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ), + ) + model_args = dict(img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') + model = _create_vision_transformer_hybrid( + 'vitamin_large_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) + return model + +@register_model +def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer: + backbone = MbConvStages(cfg=VitCfg( + embed_dim=(192, 384, 1152), + depths=(2, 4, 1), + stem_width=192, + conv_cfg = VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ), + ) + model_args = dict(img_size=256, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') + model = _create_vision_transformer_hybrid( + 'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) + return model + +@register_model +def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer: + backbone = MbConvStages(cfg=VitCfg( + embed_dim=(192, 384, 1152), + depths=(2, 4, 1), + stem_width=192, + conv_cfg = VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ), + ) + model_args = dict(img_size=336, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') + model = _create_vision_transformer_hybrid( + 'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) + return model + +@register_model +def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer: + backbone = MbConvStages(cfg=VitCfg( + embed_dim=(192, 384, 1152), + depths=(2, 4, 1), + stem_width=192, + conv_cfg = VitConvCfg( + norm_layer='layernorm2d', + norm_eps=1e-6, + ), + head_type='1d', + ), + ) + model_args = dict(img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') + model = _create_vision_transformer_hybrid( + 'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +def count_params(model: nn.Module): + return sum([m.numel() for m in model.parameters()]) + +def count_stage_params(model: nn.Module, prefix='none'): + collections = [] + for name, m in model.named_parameters(): + print(name) + if name.startswith(prefix): + collections.append(m.numel()) + return sum(collections) + + +if __name__ == "__main__": + model = timm.create_model('vitamin_large', num_classes=10).cuda() + # x = torch.rand([2,3,224,224]).cuda() + check_keys(model) From df304ffbf24114c2faf62eb0e6faae4c18320256 Mon Sep 17 00:00:00 2001 From: Beckschen Date: Tue, 14 May 2024 15:10:05 -0400 Subject: [PATCH 2/4] the dataclass init needs to use the default factory pattern, according to Ross --- timm/models/vitamin.py | 62 ++++++++++++------------------------------ 1 file changed, 17 insertions(+), 45 deletions(-) diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py index 3eecb8db..ad1b6883 100644 --- a/timm/models/vitamin.py +++ b/timm/models/vitamin.py @@ -21,7 +21,7 @@ https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision from functools import partial from typing import List, Tuple -from dataclasses import dataclass, replace +from dataclasses import dataclass, replace, field from typing import Callable, Optional, Union, Tuple, List, Sequence import math, time from torch.jit import Final @@ -29,16 +29,17 @@ import torch import torch.nn as nn import torch.nn.functional as F import timm -from timm.layers import to_2tuple + from torch.utils.checkpoint import checkpoint from timm.models.layers import create_attn, get_norm_layer, get_norm_act_layer, create_conv2d, make_divisible, trunc_normal_tf_ -from timm.models._registry import register_model +from timm.layers import to_2tuple from timm.layers import DropPath from timm.layers.norm_act import _create_act from timm.models._manipulate import named_apply, checkpoint_seq from timm.models._builder import build_model_with_cfg +from timm.models._registry import register_model from timm.models.vision_transformer import VisionTransformer, checkpoint_filter_fn from timm.models.vision_transformer_hybrid import HybridEmbed @@ -54,37 +55,19 @@ class VitConvCfg: pool_type: str = 'avg2' downsample_pool_type: str = 'avg2' act_layer: str = 'gelu' # stem & stage 1234 - act_layer1: str = 'gelu' # stage 1234 - act_layer2: str = 'gelu' # stage 1234 norm_layer: str = '' - norm_layer_cl: str = '' - norm_eps: Optional[float] = None + norm_eps: float = 1e-5 down_shortcut: Optional[bool] = True mlp: str = 'mlp' - def __post_init__(self): - # mbconv vs convnext blocks have different defaults, set in post_init to avoid explicit config args - use_mbconv = True - if not self.norm_layer: - self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d' - if not self.norm_layer_cl and not use_mbconv: - self.norm_layer_cl = 'layernorm' - if self.norm_eps is None: - self.norm_eps = 1e-5 if use_mbconv else 1e-6 - self.downsample_pool_type = self.downsample_pool_type or self.pool_type @dataclass class VitCfg: - # embed_dim: Tuple[int, ...] = (96, 192, 384, 768) embed_dim: Tuple[Union[int, Tuple[int, ...]], ...] = (96, 192, 384, 768) depths: Tuple[Union[int, Tuple[int, ...]], ...] = (2, 3, 5, 2) stem_width: int = 64 - conv_cfg: VitConvCfg = VitConvCfg() - weight_init: str = 'vit_eff' + conv_cfg: VitConvCfg = field(default_factory=VitConvCfg) head_type: str = "" - stem_type: str = "stem" - ln2d_permute: bool = True - # memory_format: str="" def _init_conv(module, name, scheme=''): @@ -95,6 +78,7 @@ def _init_conv(module, name, scheme=''): if module.bias is not None: nn.init.zeros_(module.bias) + class Stem(nn.Module): def __init__( self, @@ -126,6 +110,7 @@ class Stem(nn.Module): return x + class Downsample2d(nn.Module): def __init__( self, @@ -158,12 +143,10 @@ class StridedConv(nn.Module): stride=2, padding=1, in_chans=3, - embed_dim=768, - ln2d_permute=True + embed_dim=768 ): super().__init__() self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) - self.permute = ln2d_permute # TODO: disable norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6) self.norm = norm_layer(in_chans) # affine over C @@ -354,6 +337,7 @@ class HybridEmbed(nn.Module): x = x.flatten(2).transpose(1, 2) return x + def _create_vision_transformer(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') @@ -434,6 +418,7 @@ def vitamin_large(pretrained=False, **kwargs) -> VisionTransformer: 'vitamin_large', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs)) return model + @register_model def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer: backbone = MbConvStages(cfg=VitCfg( @@ -452,6 +437,7 @@ def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer: 'vitamin_large_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) return model + @register_model def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer: backbone = MbConvStages(cfg=VitCfg( @@ -470,6 +456,7 @@ def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer: 'vitamin_large_336', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) return model + @register_model def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer: backbone = MbConvStages(cfg=VitCfg( @@ -488,6 +475,7 @@ def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer: 'vitamin_large_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) return model + @register_model def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer: backbone = MbConvStages(cfg=VitCfg( @@ -506,6 +494,7 @@ def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer: 'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) return model + @register_model def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer: backbone = MbConvStages(cfg=VitCfg( @@ -524,6 +513,7 @@ def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer: 'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) return model + @register_model def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer: backbone = MbConvStages(cfg=VitCfg( @@ -540,22 +530,4 @@ def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer: model_args = dict(img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') model = _create_vision_transformer_hybrid( 'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) - return model - - -def count_params(model: nn.Module): - return sum([m.numel() for m in model.parameters()]) - -def count_stage_params(model: nn.Module, prefix='none'): - collections = [] - for name, m in model.named_parameters(): - print(name) - if name.startswith(prefix): - collections.append(m.numel()) - return sum(collections) - - -if __name__ == "__main__": - model = timm.create_model('vitamin_large', num_classes=10).cuda() - # x = torch.rand([2,3,224,224]).cuda() - check_keys(model) + return model \ No newline at end of file From 530fb49e7e96fc90d4620baf5fc7de3c6edd12c9 Mon Sep 17 00:00:00 2001 From: Beckschen Date: Fri, 17 May 2024 06:48:59 -0400 Subject: [PATCH 3/4] Add link to model weights on Hugging Face --- timm/models/vitamin.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py index ad1b6883..75022c5f 100644 --- a/timm/models/vitamin.py +++ b/timm/models/vitamin.py @@ -1,18 +1,18 @@ """ ViTamin Paper: Designing Scalable Vison Models in the Vision-Language Era +Model Weights on Huggingface: https://huggingface.co/collections/jienengchen/vitamin-family-661048126b72debdaca060bf -@misc{chen2023designing, - title={Designing Scalable Vison Models in the Vision-Language Era}, - author={Jieneng Chen and Qihang Yu and Xiaohui Shen and Alan Yuille and Liang-Cheih Chen}, - year={2023}, - archivePrefix={arXiv}, - primaryClass={cs.CV} +@inproceedings{chen2024vitamin, + title={ViTamin: Designing Scalable Vision Models in the Vision-language Era}, + author={Chen, Jieneng and Yu, Qihang and Shen, Xiaohui and Yuille, Alan and Chen, Liang-Chieh}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + year={2024} } Based on Apache 2.0 licensed code at https://github.com/ViTamin/ViTamin -Modifications and timm support by Jieneng Chen 2023 +Modifications and timm support by Jieneng Chen 2024 Reference: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py @@ -122,7 +122,6 @@ class Downsample2d(nn.Module): super().__init__() self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False) - if dim != dim_out: self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias) # 1x1 conv else: @@ -530,4 +529,11 @@ def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer: model_args = dict(img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') model = _create_vision_transformer_hybrid( 'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) - return model \ No newline at end of file + return model + + +if __name__ == "__main__": + model = timm.create_model('vitamin_large', num_classes=10).cuda() + x = torch.rand([2,3,224,224]).cuda() + y = model(x) + print(y.shape) \ No newline at end of file From 7a2ad6bce1a4a3e230bd5fa7d27431a644030b4c Mon Sep 17 00:00:00 2001 From: Beckschen Date: Fri, 17 May 2024 06:51:35 -0400 Subject: [PATCH 4/4] Add link to model weights on Hugging Face --- timm/models/vitamin.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py index 75022c5f..f84a59d6 100644 --- a/timm/models/vitamin.py +++ b/timm/models/vitamin.py @@ -1,7 +1,7 @@ """ ViTamin Paper: Designing Scalable Vison Models in the Vision-Language Era -Model Weights on Huggingface: https://huggingface.co/collections/jienengchen/vitamin-family-661048126b72debdaca060bf +A family of model weights on Huggingface: https://huggingface.co/collections/jienengchen/vitamin-family-661048126b72debdaca060bf @inproceedings{chen2024vitamin, title={ViTamin: Designing Scalable Vision Models in the Vision-language Era}, @@ -529,11 +529,4 @@ def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer: model_args = dict(img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg') model = _create_vision_transformer_hybrid( 'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs)) - return model - - -if __name__ == "__main__": - model = timm.create_model('vitamin_large', num_classes=10).cuda() - x = torch.rand([2,3,224,224]).cuda() - y = model(x) - print(y.shape) \ No newline at end of file + return model \ No newline at end of file