mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
the dataclass init needs to use the default factory pattern, according to Ross
This commit is contained in:
parent
99d4c7d202
commit
df304ffbf2
@ -21,7 +21,7 @@ https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision
|
|||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
from dataclasses import dataclass, replace
|
from dataclasses import dataclass, replace, field
|
||||||
from typing import Callable, Optional, Union, Tuple, List, Sequence
|
from typing import Callable, Optional, Union, Tuple, List, Sequence
|
||||||
import math, time
|
import math, time
|
||||||
from torch.jit import Final
|
from torch.jit import Final
|
||||||
@ -29,16 +29,17 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import timm
|
import timm
|
||||||
from timm.layers import to_2tuple
|
|
||||||
from torch.utils.checkpoint import checkpoint
|
from torch.utils.checkpoint import checkpoint
|
||||||
from timm.models.layers import create_attn, get_norm_layer, get_norm_act_layer, create_conv2d, make_divisible, trunc_normal_tf_
|
from timm.models.layers import create_attn, get_norm_layer, get_norm_act_layer, create_conv2d, make_divisible, trunc_normal_tf_
|
||||||
|
|
||||||
from timm.models._registry import register_model
|
from timm.layers import to_2tuple
|
||||||
from timm.layers import DropPath
|
from timm.layers import DropPath
|
||||||
from timm.layers.norm_act import _create_act
|
from timm.layers.norm_act import _create_act
|
||||||
|
|
||||||
from timm.models._manipulate import named_apply, checkpoint_seq
|
from timm.models._manipulate import named_apply, checkpoint_seq
|
||||||
from timm.models._builder import build_model_with_cfg
|
from timm.models._builder import build_model_with_cfg
|
||||||
|
from timm.models._registry import register_model
|
||||||
from timm.models.vision_transformer import VisionTransformer, checkpoint_filter_fn
|
from timm.models.vision_transformer import VisionTransformer, checkpoint_filter_fn
|
||||||
from timm.models.vision_transformer_hybrid import HybridEmbed
|
from timm.models.vision_transformer_hybrid import HybridEmbed
|
||||||
|
|
||||||
@ -54,37 +55,19 @@ class VitConvCfg:
|
|||||||
pool_type: str = 'avg2'
|
pool_type: str = 'avg2'
|
||||||
downsample_pool_type: str = 'avg2'
|
downsample_pool_type: str = 'avg2'
|
||||||
act_layer: str = 'gelu' # stem & stage 1234
|
act_layer: str = 'gelu' # stem & stage 1234
|
||||||
act_layer1: str = 'gelu' # stage 1234
|
|
||||||
act_layer2: str = 'gelu' # stage 1234
|
|
||||||
norm_layer: str = ''
|
norm_layer: str = ''
|
||||||
norm_layer_cl: str = ''
|
norm_eps: float = 1e-5
|
||||||
norm_eps: Optional[float] = None
|
|
||||||
down_shortcut: Optional[bool] = True
|
down_shortcut: Optional[bool] = True
|
||||||
mlp: str = 'mlp'
|
mlp: str = 'mlp'
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
# mbconv vs convnext blocks have different defaults, set in post_init to avoid explicit config args
|
|
||||||
use_mbconv = True
|
|
||||||
if not self.norm_layer:
|
|
||||||
self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d'
|
|
||||||
if not self.norm_layer_cl and not use_mbconv:
|
|
||||||
self.norm_layer_cl = 'layernorm'
|
|
||||||
if self.norm_eps is None:
|
|
||||||
self.norm_eps = 1e-5 if use_mbconv else 1e-6
|
|
||||||
self.downsample_pool_type = self.downsample_pool_type or self.pool_type
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class VitCfg:
|
class VitCfg:
|
||||||
# embed_dim: Tuple[int, ...] = (96, 192, 384, 768)
|
|
||||||
embed_dim: Tuple[Union[int, Tuple[int, ...]], ...] = (96, 192, 384, 768)
|
embed_dim: Tuple[Union[int, Tuple[int, ...]], ...] = (96, 192, 384, 768)
|
||||||
depths: Tuple[Union[int, Tuple[int, ...]], ...] = (2, 3, 5, 2)
|
depths: Tuple[Union[int, Tuple[int, ...]], ...] = (2, 3, 5, 2)
|
||||||
stem_width: int = 64
|
stem_width: int = 64
|
||||||
conv_cfg: VitConvCfg = VitConvCfg()
|
conv_cfg: VitConvCfg = field(default_factory=VitConvCfg)
|
||||||
weight_init: str = 'vit_eff'
|
|
||||||
head_type: str = ""
|
head_type: str = ""
|
||||||
stem_type: str = "stem"
|
|
||||||
ln2d_permute: bool = True
|
|
||||||
# memory_format: str=""
|
|
||||||
|
|
||||||
|
|
||||||
def _init_conv(module, name, scheme=''):
|
def _init_conv(module, name, scheme=''):
|
||||||
@ -95,6 +78,7 @@ def _init_conv(module, name, scheme=''):
|
|||||||
if module.bias is not None:
|
if module.bias is not None:
|
||||||
nn.init.zeros_(module.bias)
|
nn.init.zeros_(module.bias)
|
||||||
|
|
||||||
|
|
||||||
class Stem(nn.Module):
|
class Stem(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -126,6 +110,7 @@ class Stem(nn.Module):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class Downsample2d(nn.Module):
|
class Downsample2d(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -158,12 +143,10 @@ class StridedConv(nn.Module):
|
|||||||
stride=2,
|
stride=2,
|
||||||
padding=1,
|
padding=1,
|
||||||
in_chans=3,
|
in_chans=3,
|
||||||
embed_dim=768,
|
embed_dim=768
|
||||||
ln2d_permute=True
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
|
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||||
self.permute = ln2d_permute # TODO: disable
|
|
||||||
norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6)
|
norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6)
|
||||||
self.norm = norm_layer(in_chans) # affine over C
|
self.norm = norm_layer(in_chans) # affine over C
|
||||||
|
|
||||||
@ -354,6 +337,7 @@ class HybridEmbed(nn.Module):
|
|||||||
x = x.flatten(2).transpose(1, 2)
|
x = x.flatten(2).transpose(1, 2)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def _create_vision_transformer(variant, pretrained=False, **kwargs):
|
def _create_vision_transformer(variant, pretrained=False, **kwargs):
|
||||||
if kwargs.get('features_only', None):
|
if kwargs.get('features_only', None):
|
||||||
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
||||||
@ -434,6 +418,7 @@ def vitamin_large(pretrained=False, **kwargs) -> VisionTransformer:
|
|||||||
'vitamin_large', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
|
'vitamin_large', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
|
def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
|
||||||
backbone = MbConvStages(cfg=VitCfg(
|
backbone = MbConvStages(cfg=VitCfg(
|
||||||
@ -452,6 +437,7 @@ def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
|
|||||||
'vitamin_large_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
'vitamin_large_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
|
def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
|
||||||
backbone = MbConvStages(cfg=VitCfg(
|
backbone = MbConvStages(cfg=VitCfg(
|
||||||
@ -470,6 +456,7 @@ def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
|
|||||||
'vitamin_large_336', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
'vitamin_large_336', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
|
def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
|
||||||
backbone = MbConvStages(cfg=VitCfg(
|
backbone = MbConvStages(cfg=VitCfg(
|
||||||
@ -488,6 +475,7 @@ def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
|
|||||||
'vitamin_large_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
'vitamin_large_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
|
def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
|
||||||
backbone = MbConvStages(cfg=VitCfg(
|
backbone = MbConvStages(cfg=VitCfg(
|
||||||
@ -506,6 +494,7 @@ def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
|
|||||||
'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
|
def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
|
||||||
backbone = MbConvStages(cfg=VitCfg(
|
backbone = MbConvStages(cfg=VitCfg(
|
||||||
@ -524,6 +513,7 @@ def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
|
|||||||
'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
|
def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
|
||||||
backbone = MbConvStages(cfg=VitCfg(
|
backbone = MbConvStages(cfg=VitCfg(
|
||||||
@ -541,21 +531,3 @@ def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
|
|||||||
model = _create_vision_transformer_hybrid(
|
model = _create_vision_transformer_hybrid(
|
||||||
'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def count_params(model: nn.Module):
|
|
||||||
return sum([m.numel() for m in model.parameters()])
|
|
||||||
|
|
||||||
def count_stage_params(model: nn.Module, prefix='none'):
|
|
||||||
collections = []
|
|
||||||
for name, m in model.named_parameters():
|
|
||||||
print(name)
|
|
||||||
if name.startswith(prefix):
|
|
||||||
collections.append(m.numel())
|
|
||||||
return sum(collections)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
model = timm.create_model('vitamin_large', num_classes=10).cuda()
|
|
||||||
# x = torch.rand([2,3,224,224]).cuda()
|
|
||||||
check_keys(model)
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user