mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Create model functions
This commit is contained in:
parent
87b4d7a29a
commit
2a4f6c13dd
@ -12,6 +12,8 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Written by Christoph Reich
|
||||
# --------------------------------------------------------
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from typing import Tuple, Optional, List, Union, Any, Type
|
||||
|
||||
import torch
|
||||
@ -19,7 +21,81 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
|
||||
from .layers import DropPath, Mlp
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
# from .helpers import build_model_with_cfg, overlay_external_default_cfg
|
||||
# from .vision_transformer import checkpoint_filter_fn
|
||||
# from .registry import register_model
|
||||
# from .layers import DropPath, Mlp
|
||||
|
||||
from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg
|
||||
from timm.models.vision_transformer import checkpoint_filter_fn
|
||||
from timm.models.registry import register_model
|
||||
from timm.models.layers import DropPath, Mlp
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
# patch models (my experiments)
|
||||
'swin_v2_cr_tiny_patch4_window12_384': _cfg(
|
||||
url="",
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
|
||||
'swin_v2_cr_tiny_patch4_window7_224': _cfg(
|
||||
url="",
|
||||
input_size=(3, 224, 224), crop_pct=1.0),
|
||||
|
||||
'swin_v2_cr_small_patch4_window12_384': _cfg(
|
||||
url="",
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
|
||||
'swin_v2_cr_small_patch4_window7_224': _cfg(
|
||||
url="",
|
||||
input_size=(3, 224, 224), crop_pct=1.0),
|
||||
|
||||
'swin_v2_cr_base_patch4_window12_384': _cfg(
|
||||
url="",
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
|
||||
'swin_v2_cr_base_patch4_window7_224': _cfg(
|
||||
url="",
|
||||
input_size=(3, 224, 224), crop_pct=1.0),
|
||||
|
||||
'swin_v2_cr_large_patch4_window12_384': _cfg(
|
||||
url="",
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
|
||||
'swin_v2_cr_large_patch4_window7_224': _cfg(
|
||||
url="",
|
||||
input_size=(3, 224, 224), crop_pct=1.0),
|
||||
|
||||
'swin_v2_cr_huge_patch4_window12_384': _cfg(
|
||||
url="",
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
|
||||
'swin_v2_cr_huge_patch4_window7_224': _cfg(
|
||||
url="",
|
||||
input_size=(3, 224, 224), crop_pct=1.0),
|
||||
|
||||
'swin_v2_cr_giant_patch4_window12_384': _cfg(
|
||||
url="",
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
|
||||
'swin_v2_cr_giant_patch4_window7_224': _cfg(
|
||||
url="",
|
||||
input_size=(3, 224, 224), crop_pct=1.0),
|
||||
}
|
||||
|
||||
|
||||
def bchw_to_bhwc(input: torch.Tensor) -> torch.Tensor:
|
||||
@ -958,3 +1034,148 @@ class SwinTransformerV2CR(nn.Module):
|
||||
# Predict classification
|
||||
classification: torch.Tensor = self.head(output)
|
||||
return classification
|
||||
|
||||
|
||||
def _create_swin_transformer_v2_cr(variant, pretrained=False, default_cfg=None, **kwargs):
|
||||
if default_cfg is None:
|
||||
default_cfg = deepcopy(default_cfgs[variant])
|
||||
overlay_external_default_cfg(default_cfg, kwargs)
|
||||
default_num_classes = default_cfg['num_classes']
|
||||
default_img_size = default_cfg['input_size'][-2:]
|
||||
|
||||
num_classes = kwargs.pop('num_classes', default_num_classes)
|
||||
img_size = kwargs.pop('img_size', default_img_size)
|
||||
if kwargs.get('features_only', None):
|
||||
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
||||
|
||||
model = build_model_with_cfg(
|
||||
SwinTransformerV2CR, variant, pretrained,
|
||||
default_cfg=default_cfg,
|
||||
img_size=img_size,
|
||||
num_classes=num_classes,
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
**kwargs)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_v2_cr_tiny_patch4_window12_384(pretrained=False, **kwargs):
|
||||
""" Swin-T V2 CR @ 384x384, trained ImageNet-1k
|
||||
"""
|
||||
model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=96, depths=(2, 2, 6, 2),
|
||||
num_heads=(3, 6, 12, 24), **kwargs)
|
||||
return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_patch4_window12_384', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_v2_cr_tiny_patch4_window7_224(pretrained=False, **kwargs):
|
||||
""" Swin-T V2 CR @ 224x224, trained ImageNet-1k
|
||||
"""
|
||||
model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2),
|
||||
num_heads=(3, 6, 12, 24), **kwargs)
|
||||
return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_v2_cr_small_patch4_window12_384(pretrained=False, **kwargs):
|
||||
""" Swin-S V2 CR @ 384x384, trained ImageNet-1k
|
||||
"""
|
||||
model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2),
|
||||
num_heads=(3, 6, 12, 24), **kwargs)
|
||||
return _create_swin_transformer_v2_cr('swin_v2_cr_small_patch4_window12_384', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_v2_cr_small_patch4_window7_224(pretrained=False, **kwargs):
|
||||
""" Swin-S V2 CR @ 224x224, trained ImageNet-1k
|
||||
"""
|
||||
model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2),
|
||||
num_heads=(3, 6, 12, 24), **kwargs)
|
||||
return _create_swin_transformer_v2_cr('swin_v2_cr_small_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_v2_cr_base_patch4_window12_384(pretrained=False, **kwargs):
|
||||
""" Swin-B V2 CR @ 384x384, trained ImageNet-1k
|
||||
"""
|
||||
model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2),
|
||||
num_heads=(4, 8, 16, 32), **kwargs)
|
||||
return _create_swin_transformer_v2_cr('swin_v2_cr_base_patch4_window12_384', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_v2_cr_base_patch4_window7_224(pretrained=False, **kwargs):
|
||||
""" Swin-B V2 CR @ 224x224, trained ImageNet-1k
|
||||
"""
|
||||
model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2),
|
||||
num_heads=(4, 8, 16, 32), **kwargs)
|
||||
return _create_swin_transformer_v2_cr('swin_v2_cr_base_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_v2_cr_large_patch4_window12_384(pretrained=False, **kwargs):
|
||||
""" Swin-L V2 CR @ 384x384, trained ImageNet-1k
|
||||
"""
|
||||
model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2),
|
||||
num_heads=(6, 12, 24, 48), **kwargs)
|
||||
return _create_swin_transformer_v2_cr('swin_v2_cr_large_patch4_window12_384', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_v2_cr_large_patch4_window7_224(pretrained=False, **kwargs):
|
||||
""" Swin-L V2 CR @ 224x224, trained ImageNet-1k
|
||||
"""
|
||||
model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2),
|
||||
num_heads=(6, 12, 24, 48), **kwargs)
|
||||
return _create_swin_transformer_v2_cr('swin_v2_cr_large_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_v2_cr_huge_patch4_window12_384(pretrained=False, **kwargs):
|
||||
""" Swin-H V2 CR @ 384x384, trained ImageNet-1k
|
||||
"""
|
||||
model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=352, depths=(2, 2, 18, 2),
|
||||
num_heads=(6, 12, 24, 48), **kwargs)
|
||||
return _create_swin_transformer_v2_cr('swin_v2_cr_huge_patch4_window12_384', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_v2_cr_huge_patch4_window7_224(pretrained=False, **kwargs):
|
||||
""" Swin-H V2 CR @ 224x224, trained ImageNet-1k
|
||||
"""
|
||||
model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=352, depths=(2, 2, 18, 2),
|
||||
num_heads=(11, 22, 44, 88), **kwargs)
|
||||
return _create_swin_transformer_v2_cr('swin_v2_cr_huge_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_v2_cr_giant_patch4_window12_384(pretrained=False, **kwargs):
|
||||
""" Swin-G V2 CR @ 384x384, trained ImageNet-1k
|
||||
"""
|
||||
model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=512, depths=(2, 2, 18, 2),
|
||||
num_heads=(16, 32, 64, 128), **kwargs)
|
||||
return _create_swin_transformer_v2_cr('swin_v2_cr_giant_patch4_window12_384', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_v2_cr_giant_patch4_window7_224(pretrained=False, **kwargs):
|
||||
""" Swin-G V2 CR @ 224x224, trained ImageNet-1k
|
||||
"""
|
||||
model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=512, depths=(2, 2, 42, 2),
|
||||
num_heads=(16, 32, 64, 128), **kwargs)
|
||||
return _create_swin_transformer_v2_cr('swin_v2_cr_giant_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = swin_v2_cr_tiny_patch4_window12_384(pretrained=False)
|
||||
model = swin_v2_cr_tiny_patch4_window7_224(pretrained=False)
|
||||
|
||||
model = swin_v2_cr_small_patch4_window12_384(pretrained=False)
|
||||
model = swin_v2_cr_small_patch4_window7_224(pretrained=False)
|
||||
|
||||
model = swin_v2_cr_base_patch4_window12_384(pretrained=False)
|
||||
model = swin_v2_cr_base_patch4_window7_224(pretrained=False)
|
||||
|
||||
model = swin_v2_cr_large_patch4_window12_384(pretrained=False)
|
||||
model = swin_v2_cr_large_patch4_window7_224(pretrained=False)
|
||||
|
Loading…
x
Reference in New Issue
Block a user