From df304ffbf24114c2faf62eb0e6faae4c18320256 Mon Sep 17 00:00:00 2001
From: Beckschen
Date: Tue, 14 May 2024 15:10:05 -0400
Subject: [PATCH] Use the default factory pattern for dataclass field
 defaults, per Ross's review

A class-level instance default like `conv_cfg: VitConvCfg = VitConvCfg()`
is evaluated once and shared by every VitCfg, and Python 3.11+ rejects it
outright ("ValueError: mutable default ... use default_factory"). Switch
the field to field(default_factory=VitConvCfg) so each config gets a
fresh VitConvCfg. Also drop unused config fields and the __post_init__
defaulting (its non-mbconv branches were dead code), delete the ad-hoc
param-counting helpers and __main__ debug block, and fix
vitamin_xlarge_336 registering under the 'vitamin_xlarge_256' variant
name.
---
 timm/models/vitamin.py | 64 ++++++++++++------------------------------
 1 file changed, 18 insertions(+), 46 deletions(-)

diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py
index 3eecb8db..ad1b6883 100644
--- a/timm/models/vitamin.py
+++ b/timm/models/vitamin.py
@@ -21,7 +21,7 @@ https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision
 
 from functools import partial
 from typing import List, Tuple
-from dataclasses import dataclass, replace
+from dataclasses import dataclass, replace, field
 from typing import Callable, Optional, Union, Tuple, List, Sequence
 import math, time
 from torch.jit import Final
@@ -29,16 +29,17 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import timm
-from timm.layers import to_2tuple
+
 
 from torch.utils.checkpoint import checkpoint
 from timm.models.layers import create_attn, get_norm_layer, get_norm_act_layer, create_conv2d, make_divisible, trunc_normal_tf_
-from timm.models._registry import register_model
 
+from timm.layers import to_2tuple
 from timm.layers import DropPath
 from timm.layers.norm_act import _create_act
 from timm.models._manipulate import named_apply, checkpoint_seq
 from timm.models._builder import build_model_with_cfg
+from timm.models._registry import register_model
 from timm.models.vision_transformer import VisionTransformer, checkpoint_filter_fn
 from timm.models.vision_transformer_hybrid import HybridEmbed
 
@@ -54,37 +55,19 @@ class VitConvCfg:
     pool_type: str = 'avg2'
     downsample_pool_type: str = 'avg2'
     act_layer: str = 'gelu' # stem & stage 1234
-    act_layer1: str = 'gelu' # stage 1234
-    act_layer2: str = 'gelu' # stage 1234
     norm_layer: str = ''
-    norm_layer_cl: str = ''
-    norm_eps: Optional[float] = None
+    norm_eps: float = 1e-5
     down_shortcut: Optional[bool] = True
     mlp: str = 'mlp'
 
-    def __post_init__(self):
-        # mbconv vs convnext blocks have different defaults, set in post_init to avoid explicit config args
-        use_mbconv = True
-        if not self.norm_layer:
-            self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d'
-        if not self.norm_layer_cl and not use_mbconv:
-            self.norm_layer_cl = 'layernorm'
-        if self.norm_eps is None:
-            self.norm_eps = 1e-5 if use_mbconv else 1e-6
-        self.downsample_pool_type = self.downsample_pool_type or self.pool_type
 
 @dataclass
 class VitCfg:
-    # embed_dim: Tuple[int, ...] = (96, 192, 384, 768)
     embed_dim: Tuple[Union[int, Tuple[int, ...]], ...] = (96, 192, 384, 768)
     depths: Tuple[Union[int, Tuple[int, ...]], ...] = (2, 3, 5, 2)
     stem_width: int = 64
-    conv_cfg: VitConvCfg = VitConvCfg()
-    weight_init: str = 'vit_eff'
+    conv_cfg: VitConvCfg = field(default_factory=VitConvCfg)
     head_type: str = ""
-    stem_type: str = "stem"
-    ln2d_permute: bool = True
-    # memory_format: str=""
 
 
 def _init_conv(module, name, scheme=''):
@@ -95,6 +78,7 @@ def _init_conv(module, name, scheme=''):
     if module.bias is not None:
         nn.init.zeros_(module.bias)
 
+
 class Stem(nn.Module):
     def __init__(
         self,
@@ -126,6 +110,7 @@ class Stem(nn.Module):
         return x
 
+
 class Downsample2d(nn.Module):
     def __init__(
         self,
@@ -158,12 +143,10 @@ class StridedConv(nn.Module):
         kernel_size=3,
         stride=2,
         padding=1,
         in_chans=3,
-        embed_dim=768,
-        ln2d_permute=True
+        embed_dim=768
     ):
         super().__init__()
         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
-        self.permute = ln2d_permute # TODO: disable
         norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6)
         self.norm = norm_layer(in_chans) # affine over C
@@ -354,6 +337,7 @@ class HybridEmbed(nn.Module):
         x = x.flatten(2).transpose(1, 2)
         return x
 
+
 def _create_vision_transformer(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')
@@ -434,6 +418,7 @@ def vitamin_large(pretrained=False, **kwargs) -> VisionTransformer:
         'vitamin_large', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
     return model
 
+
 @register_model
 def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
     backbone = MbConvStages(cfg=VitCfg(
@@ -452,6 +437,7 @@ def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
         'vitamin_large_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
+
 @register_model
 def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
     backbone = MbConvStages(cfg=VitCfg(
@@ -470,6 +456,7 @@ def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
         'vitamin_large_336', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
+
 @register_model
 def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
     backbone = MbConvStages(cfg=VitCfg(
@@ -488,6 +475,7 @@ def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
         'vitamin_large_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
+
 @register_model
 def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
     backbone = MbConvStages(cfg=VitCfg(
@@ -506,6 +494,7 @@ def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
         'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
+
 @register_model
 def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
     backbone = MbConvStages(cfg=VitCfg(
@@ -524,6 +513,7 @@ def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
-        'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
+        'vitamin_xlarge_336', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
+
 @register_model
 def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
     backbone = MbConvStages(cfg=VitCfg(
@@ -540,22 +530,4 @@ def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
     model_args = dict(img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
     model = _create_vision_transformer_hybrid(
         'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
-
-def count_params(model: nn.Module):
-    return sum([m.numel() for m in model.parameters()])
-
-def count_stage_params(model: nn.Module, prefix='none'):
-    collections = []
-    for name, m in model.named_parameters():
-        print(name)
-        if name.startswith(prefix):
-            collections.append(m.numel())
-    return sum(collections)
-
-
-if __name__ == "__main__":
-    model = timm.create_model('vitamin_large', num_classes=10).cuda()
-    # x = torch.rand([2,3,224,224]).cuda()
-    check_keys(model)
+    return model
\ No newline at end of file
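
--
For reference, a minimal standalone sketch of the failure mode the
default-factory change addresses; `ConvCfg` and `Cfg` below are
hypothetical stand-ins for VitConvCfg/VitCfg, not timm code:

    from dataclasses import dataclass, field

    @dataclass
    class ConvCfg:
        norm_layer: str = ''

    @dataclass
    class Cfg:
        # `conv_cfg: ConvCfg = ConvCfg()` would be evaluated once at class
        # definition time, so every Cfg() would share one ConvCfg instance,
        # and Python 3.11+ rejects it with "ValueError: mutable default ...
        # use default_factory". default_factory constructs a fresh ConvCfg
        # for each Cfg instance instead.
        conv_cfg: ConvCfg = field(default_factory=ConvCfg)

    a, b = Cfg(), Cfg()
    a.conv_cfg.norm_layer = 'batchnorm2d'
    assert b.conv_cfg.norm_layer == ''  # b keeps its own, unshared sub-config

The same reasoning applies to any dataclass-typed field default, which is
why VitCfg.conv_cfg now uses field(default_factory=VitConvCfg).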