mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
the dataclass init needs to use the default factory pattern, according to Ross
This commit is contained in:
parent
99d4c7d202
commit
df304ffbf2
@ -21,7 +21,7 @@ https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision
|
||||
|
||||
from functools import partial
|
||||
from typing import List, Tuple
|
||||
from dataclasses import dataclass, replace
|
||||
from dataclasses import dataclass, replace, field
|
||||
from typing import Callable, Optional, Union, Tuple, List, Sequence
|
||||
import math, time
|
||||
from torch.jit import Final
|
||||
@ -29,16 +29,17 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import timm
|
||||
from timm.layers import to_2tuple
|
||||
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from timm.models.layers import create_attn, get_norm_layer, get_norm_act_layer, create_conv2d, make_divisible, trunc_normal_tf_
|
||||
|
||||
from timm.models._registry import register_model
|
||||
from timm.layers import to_2tuple
|
||||
from timm.layers import DropPath
|
||||
from timm.layers.norm_act import _create_act
|
||||
|
||||
from timm.models._manipulate import named_apply, checkpoint_seq
|
||||
from timm.models._builder import build_model_with_cfg
|
||||
from timm.models._registry import register_model
|
||||
from timm.models.vision_transformer import VisionTransformer, checkpoint_filter_fn
|
||||
from timm.models.vision_transformer_hybrid import HybridEmbed
|
||||
|
||||
@ -54,37 +55,19 @@ class VitConvCfg:
|
||||
pool_type: str = 'avg2'
|
||||
downsample_pool_type: str = 'avg2'
|
||||
act_layer: str = 'gelu' # stem & stage 1234
|
||||
act_layer1: str = 'gelu' # stage 1234
|
||||
act_layer2: str = 'gelu' # stage 1234
|
||||
norm_layer: str = ''
|
||||
norm_layer_cl: str = ''
|
||||
norm_eps: Optional[float] = None
|
||||
norm_eps: float = 1e-5
|
||||
down_shortcut: Optional[bool] = True
|
||||
mlp: str = 'mlp'
|
||||
|
||||
def __post_init__(self):
|
||||
# mbconv vs convnext blocks have different defaults, set in post_init to avoid explicit config args
|
||||
use_mbconv = True
|
||||
if not self.norm_layer:
|
||||
self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d'
|
||||
if not self.norm_layer_cl and not use_mbconv:
|
||||
self.norm_layer_cl = 'layernorm'
|
||||
if self.norm_eps is None:
|
||||
self.norm_eps = 1e-5 if use_mbconv else 1e-6
|
||||
self.downsample_pool_type = self.downsample_pool_type or self.pool_type
|
||||
|
||||
@dataclass
|
||||
class VitCfg:
|
||||
# embed_dim: Tuple[int, ...] = (96, 192, 384, 768)
|
||||
embed_dim: Tuple[Union[int, Tuple[int, ...]], ...] = (96, 192, 384, 768)
|
||||
depths: Tuple[Union[int, Tuple[int, ...]], ...] = (2, 3, 5, 2)
|
||||
stem_width: int = 64
|
||||
conv_cfg: VitConvCfg = VitConvCfg()
|
||||
weight_init: str = 'vit_eff'
|
||||
conv_cfg: VitConvCfg = field(default_factory=VitConvCfg)
|
||||
head_type: str = ""
|
||||
stem_type: str = "stem"
|
||||
ln2d_permute: bool = True
|
||||
# memory_format: str=""
|
||||
|
||||
|
||||
def _init_conv(module, name, scheme=''):
|
||||
@ -95,6 +78,7 @@ def _init_conv(module, name, scheme=''):
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
|
||||
|
||||
class Stem(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -126,6 +110,7 @@ class Stem(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Downsample2d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -158,12 +143,10 @@ class StridedConv(nn.Module):
|
||||
stride=2,
|
||||
padding=1,
|
||||
in_chans=3,
|
||||
embed_dim=768,
|
||||
ln2d_permute=True
|
||||
embed_dim=768
|
||||
):
|
||||
super().__init__()
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.permute = ln2d_permute # TODO: disable
|
||||
norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6)
|
||||
self.norm = norm_layer(in_chans) # affine over C
|
||||
|
||||
@ -354,6 +337,7 @@ class HybridEmbed(nn.Module):
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
def _create_vision_transformer(variant, pretrained=False, **kwargs):
|
||||
if kwargs.get('features_only', None):
|
||||
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
||||
@ -434,6 +418,7 @@ def vitamin_large(pretrained=False, **kwargs) -> VisionTransformer:
|
||||
'vitamin_large', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
|
||||
backbone = MbConvStages(cfg=VitCfg(
|
||||
@ -452,6 +437,7 @@ def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
|
||||
'vitamin_large_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
|
||||
backbone = MbConvStages(cfg=VitCfg(
|
||||
@ -470,6 +456,7 @@ def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
|
||||
'vitamin_large_336', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
|
||||
backbone = MbConvStages(cfg=VitCfg(
|
||||
@ -488,6 +475,7 @@ def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
|
||||
'vitamin_large_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
|
||||
backbone = MbConvStages(cfg=VitCfg(
|
||||
@ -506,6 +494,7 @@ def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
|
||||
'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
|
||||
backbone = MbConvStages(cfg=VitCfg(
|
||||
@ -524,6 +513,7 @@ def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
|
||||
'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
|
||||
backbone = MbConvStages(cfg=VitCfg(
|
||||
@ -540,22 +530,4 @@ def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
|
||||
model_args = dict(img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
|
||||
model = _create_vision_transformer_hybrid(
|
||||
'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
def count_params(model: nn.Module):
|
||||
return sum([m.numel() for m in model.parameters()])
|
||||
|
||||
def count_stage_params(model: nn.Module, prefix='none'):
|
||||
collections = []
|
||||
for name, m in model.named_parameters():
|
||||
print(name)
|
||||
if name.startswith(prefix):
|
||||
collections.append(m.numel())
|
||||
return sum(collections)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = timm.create_model('vitamin_large', num_classes=10).cuda()
|
||||
# x = torch.rand([2,3,224,224]).cuda()
|
||||
check_keys(model)
|
||||
return model
|
Loading…
x
Reference in New Issue
Block a user