the dataclass init needs to use the default factory pattern, according to Ross

This commit is contained in:
Beckschen 2024-05-14 15:10:05 -04:00
parent 99d4c7d202
commit df304ffbf2

View File

@ -21,7 +21,7 @@ https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision
from functools import partial
from typing import List, Tuple
from dataclasses import dataclass, replace
from dataclasses import dataclass, replace, field
from typing import Callable, Optional, Union, Tuple, List, Sequence
import math, time
from torch.jit import Final
@ -29,16 +29,17 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from timm.layers import to_2tuple
from torch.utils.checkpoint import checkpoint
from timm.models.layers import create_attn, get_norm_layer, get_norm_act_layer, create_conv2d, make_divisible, trunc_normal_tf_
from timm.models._registry import register_model
from timm.layers import to_2tuple
from timm.layers import DropPath
from timm.layers.norm_act import _create_act
from timm.models._manipulate import named_apply, checkpoint_seq
from timm.models._builder import build_model_with_cfg
from timm.models._registry import register_model
from timm.models.vision_transformer import VisionTransformer, checkpoint_filter_fn
from timm.models.vision_transformer_hybrid import HybridEmbed
@ -54,37 +55,19 @@ class VitConvCfg:
pool_type: str = 'avg2'
downsample_pool_type: str = 'avg2'
act_layer: str = 'gelu' # stem & stage 1234
act_layer1: str = 'gelu' # stage 1234
act_layer2: str = 'gelu' # stage 1234
norm_layer: str = ''
norm_layer_cl: str = ''
norm_eps: Optional[float] = None
norm_eps: float = 1e-5
down_shortcut: Optional[bool] = True
mlp: str = 'mlp'
def __post_init__(self):
# mbconv vs convnext blocks have different defaults, set in post_init to avoid explicit config args
use_mbconv = True
if not self.norm_layer:
self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d'
if not self.norm_layer_cl and not use_mbconv:
self.norm_layer_cl = 'layernorm'
if self.norm_eps is None:
self.norm_eps = 1e-5 if use_mbconv else 1e-6
self.downsample_pool_type = self.downsample_pool_type or self.pool_type
@dataclass
class VitCfg:
# embed_dim: Tuple[int, ...] = (96, 192, 384, 768)
embed_dim: Tuple[Union[int, Tuple[int, ...]], ...] = (96, 192, 384, 768)
depths: Tuple[Union[int, Tuple[int, ...]], ...] = (2, 3, 5, 2)
stem_width: int = 64
conv_cfg: VitConvCfg = VitConvCfg()
weight_init: str = 'vit_eff'
conv_cfg: VitConvCfg = field(default_factory=VitConvCfg)
head_type: str = ""
stem_type: str = "stem"
ln2d_permute: bool = True
# memory_format: str=""
def _init_conv(module, name, scheme=''):
@ -95,6 +78,7 @@ def _init_conv(module, name, scheme=''):
if module.bias is not None:
nn.init.zeros_(module.bias)
class Stem(nn.Module):
def __init__(
self,
@ -126,6 +110,7 @@ class Stem(nn.Module):
return x
class Downsample2d(nn.Module):
def __init__(
self,
@ -158,12 +143,10 @@ class StridedConv(nn.Module):
stride=2,
padding=1,
in_chans=3,
embed_dim=768,
ln2d_permute=True
embed_dim=768
):
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
self.permute = ln2d_permute # TODO: disable
norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6)
self.norm = norm_layer(in_chans) # affine over C
@ -354,6 +337,7 @@ class HybridEmbed(nn.Module):
x = x.flatten(2).transpose(1, 2)
return x
def _create_vision_transformer(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
@ -434,6 +418,7 @@ def vitamin_large(pretrained=False, **kwargs) -> VisionTransformer:
'vitamin_large', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
return model
@register_model
def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
backbone = MbConvStages(cfg=VitCfg(
@ -452,6 +437,7 @@ def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
'vitamin_large_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
backbone = MbConvStages(cfg=VitCfg(
@ -470,6 +456,7 @@ def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
'vitamin_large_336', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
backbone = MbConvStages(cfg=VitCfg(
@ -488,6 +475,7 @@ def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
'vitamin_large_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
backbone = MbConvStages(cfg=VitCfg(
@ -506,6 +494,7 @@ def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
backbone = MbConvStages(cfg=VitCfg(
@ -524,6 +513,7 @@ def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
backbone = MbConvStages(cfg=VitCfg(
@ -541,21 +531,3 @@ def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
model = _create_vision_transformer_hybrid(
'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
return model
def count_params(model: nn.Module):
return sum([m.numel() for m in model.parameters()])
def count_stage_params(model: nn.Module, prefix='none'):
collections = []
for name, m in model.named_parameters():
print(name)
if name.startswith(prefix):
collections.append(m.numel())
return sum(collections)
if __name__ == "__main__":
model = timm.create_model('vitamin_large', num_classes=10).cuda()
# x = torch.rand([2,3,224,224]).cuda()
check_keys(model)