The dataclass init needs to use the default factory pattern, according to Ross

Beckschen 2024-05-14 15:10:05 -04:00
parent 99d4c7d202
commit df304ffbf2

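Background on the change: Python's dataclasses reject mutable field defaults. Since Python 3.11 this check covers any unhashable default, so a plain dataclass instance such as `VitConvCfg()` fails at class-definition time with `ValueError: mutable default ... is not allowed: use default_factory`; on older versions it was silently accepted but shared a single instance across every `VitCfg`. A minimal sketch of the pattern this commit adopts (stand-in names, not the full configs from this file):

from dataclasses import dataclass, field

@dataclass
class ConvCfg:              # stand-in for VitConvCfg
    norm_layer: str = ''
    norm_eps: float = 1e-5

@dataclass
class Cfg:                  # stand-in for VitCfg
    # conv_cfg: ConvCfg = ConvCfg()   # ValueError on Python 3.11+ (unhashable default)
    conv_cfg: ConvCfg = field(default_factory=ConvCfg)  # fresh ConvCfg per Cfg() instance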

@@ -21,7 +21,7 @@ https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision

 from functools import partial
 from typing import List, Tuple
-from dataclasses import dataclass, replace
+from dataclasses import dataclass, replace, field
 from typing import Callable, Optional, Union, Tuple, List, Sequence
 import math, time
 from torch.jit import Final
@@ -29,16 +29,17 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import timm

-from timm.layers import to_2tuple
 from torch.utils.checkpoint import checkpoint

 from timm.models.layers import create_attn, get_norm_layer, get_norm_act_layer, create_conv2d, make_divisible, trunc_normal_tf_
-from timm.models._registry import register_model
+from timm.layers import to_2tuple
 from timm.layers import DropPath
 from timm.layers.norm_act import _create_act
 from timm.models._manipulate import named_apply, checkpoint_seq
 from timm.models._builder import build_model_with_cfg
+from timm.models._registry import register_model
+
 from timm.models.vision_transformer import VisionTransformer, checkpoint_filter_fn
 from timm.models.vision_transformer_hybrid import HybridEmbed

@@ -54,37 +55,19 @@ class VitConvCfg:
     pool_type: str = 'avg2'
     downsample_pool_type: str = 'avg2'
     act_layer: str = 'gelu'  # stem & stage 1234
-    act_layer1: str = 'gelu'  # stage 1234
-    act_layer2: str = 'gelu'  # stage 1234
     norm_layer: str = ''
-    norm_layer_cl: str = ''
-    norm_eps: Optional[float] = None
+    norm_eps: float = 1e-5
     down_shortcut: Optional[bool] = True
     mlp: str = 'mlp'

-    def __post_init__(self):
-        # mbconv vs convnext blocks have different defaults, set in post_init to avoid explicit config args
-        use_mbconv = True
-        if not self.norm_layer:
-            self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d'
-        if not self.norm_layer_cl and not use_mbconv:
-            self.norm_layer_cl = 'layernorm'
-        if self.norm_eps is None:
-            self.norm_eps = 1e-5 if use_mbconv else 1e-6
-        self.downsample_pool_type = self.downsample_pool_type or self.pool_type

 @dataclass
 class VitCfg:
-    # embed_dim: Tuple[int, ...] = (96, 192, 384, 768)
     embed_dim: Tuple[Union[int, Tuple[int, ...]], ...] = (96, 192, 384, 768)
     depths: Tuple[Union[int, Tuple[int, ...]], ...] = (2, 3, 5, 2)
     stem_width: int = 64
-    conv_cfg: VitConvCfg = VitConvCfg()
-    weight_init: str = 'vit_eff'
+    conv_cfg: VitConvCfg = field(default_factory=VitConvCfg)
     head_type: str = ""
-    stem_type: str = "stem"
-    ln2d_permute: bool = True
-    # memory_format: str=""


 def _init_conv(module, name, scheme=''):
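The practical effect of the `field(default_factory=VitConvCfg)` default above: every `VitCfg()` now builds its own `VitConvCfg`, so mutating one config cannot bleed into another. An illustrative check (assumes the two dataclasses as defined after this commit):

a, b = VitCfg(), VitCfg()
assert a.conv_cfg is not b.conv_cfg     # default_factory runs once per instance
a.conv_cfg.norm_eps = 1e-6              # change one config...
assert b.conv_cfg.norm_eps == 1e-5      # ...the other keeps its default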
@@ -95,6 +78,7 @@ def _init_conv(module, name, scheme=''):
         if module.bias is not None:
             nn.init.zeros_(module.bias)

+
 class Stem(nn.Module):
     def __init__(
             self,
@@ -126,6 +110,7 @@ class Stem(nn.Module):
         return x

+
 class Downsample2d(nn.Module):
     def __init__(
             self,
@@ -158,12 +143,10 @@ class StridedConv(nn.Module):
             stride=2,
             padding=1,
             in_chans=3,
-            embed_dim=768,
-            ln2d_permute=True
+            embed_dim=768
     ):
         super().__init__()
         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
-        self.permute = ln2d_permute  # TODO: disable

         norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6)
         self.norm = norm_layer(in_chans)  # affine over C
@@ -354,6 +337,7 @@ class HybridEmbed(nn.Module):
         x = x.flatten(2).transpose(1, 2)
         return x

+
 def _create_vision_transformer(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')
@@ -434,6 +418,7 @@ def vitamin_large(pretrained=False, **kwargs) -> VisionTransformer:
         'vitamin_large', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
     return model

+
 @register_model
 def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
     backbone = MbConvStages(cfg=VitCfg(
@@ -452,6 +437,7 @@ def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
         'vitamin_large_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
     return model

+
 @register_model
 def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
     backbone = MbConvStages(cfg=VitCfg(
@@ -470,6 +456,7 @@ def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
         'vitamin_large_336', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
     return model

+
 @register_model
 def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
     backbone = MbConvStages(cfg=VitCfg(
@@ -488,6 +475,7 @@ def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
         'vitamin_large_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
     return model

+
 @register_model
 def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
     backbone = MbConvStages(cfg=VitCfg(
@@ -506,6 +494,7 @@ def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
         'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
     return model

+
 @register_model
 def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
     backbone = MbConvStages(cfg=VitCfg(
@@ -524,6 +513,7 @@ def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
         'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
     return model

+
 @register_model
 def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
     backbone = MbConvStages(cfg=VitCfg(
@@ -541,21 +531,3 @@ def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
     model = _create_vision_transformer_hybrid(
         'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
     return model
-
-
-def count_params(model: nn.Module):
-    return sum([m.numel() for m in model.parameters()])
-
-
-def count_stage_params(model: nn.Module, prefix='none'):
-    collections = []
-    for name, m in model.named_parameters():
-        print(name)
-        if name.startswith(prefix):
-            collections.append(m.numel())
-    return sum(collections)
-
-if __name__ == "__main__":
-    model = timm.create_model('vitamin_large', num_classes=10).cuda()
-    # x = torch.rand([2,3,224,224]).cuda()
-    check_keys(model)
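The deleted smoke-test helpers are easy to recreate at the call site if still wanted; for instance, the removed `count_params` amounts to this one-liner (illustrative, not part of the commit):

n_params = sum(p.numel() for p in model.parameters())  # same result as the removed count_params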