From 4d5c395160a6610d6c4f09f12aa08dac34f3405a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 21 Nov 2022 22:14:12 -0800 Subject: [PATCH] MaxVit, ViT, ConvNeXt, and EfficientNet-v2 updates * Add support for TF weights and modelling specifics to MaxVit (testing ported weights) * More fine-tuned CLIP ViT configs * ConvNeXt and MaxVit updated to new pretrained cfgs use * EfficientNetV2, MaxVit and ConvNeXt high res models use squash crop/resize --- timm/models/convnext.py | 303 ++-- timm/models/efficientnet.py | 16 +- timm/models/layers/activations.py | 14 + timm/models/layers/create_act.py | 2 + timm/models/maxxvit.py | 2663 ++++++++++++++++------------- timm/models/vision_transformer.py | 22 +- 6 files changed, 1658 insertions(+), 1362 deletions(-) diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 15000b40..e64bd0ef 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -21,111 +21,13 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import named_apply, build_model_with_cfg, checkpoint_seq from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \ create_conv2d, get_act_layer, make_divisible, to_ntuple +from ._pretrained import generate_defaults from .registry import register_model __all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bicubic', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'stem.0', 'classifier': 'head.fc', - **kwargs - } - - -default_cfgs = dict( - # timm specific variants - convnext_atto=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth', - test_input_size=(3, 288, 288), test_crop_pct=0.95), - convnext_atto_ols=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth', - test_input_size=(3, 288, 288), test_crop_pct=0.95), - convnext_femto=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth', - test_input_size=(3, 288, 288), test_crop_pct=0.95), - convnext_femto_ols=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth', - test_input_size=(3, 288, 288), test_crop_pct=0.95), - convnext_pico=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth', - test_input_size=(3, 288, 288), test_crop_pct=0.95), - convnext_pico_ols=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth', - crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), - convnext_nano=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth', - crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), - convnext_nano_ols=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth', - crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), - convnext_tiny_hnf=_cfg( - 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth', - crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), - - convnext_tiny=_cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", - test_input_size=(3, 288, 288), test_crop_pct=1.0), - convnext_small=_cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", - test_input_size=(3, 288, 288), test_crop_pct=1.0), - convnext_base=_cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", - test_input_size=(3, 288, 288), test_crop_pct=1.0), - convnext_large=_cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", - test_input_size=(3, 288, 288), test_crop_pct=1.0), - - convnext_tiny_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth', - test_input_size=(3, 288, 288), test_crop_pct=1.0), - convnext_small_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth', - test_input_size=(3, 288, 288), test_crop_pct=1.0), - convnext_base_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth', - test_input_size=(3, 288, 288), test_crop_pct=1.0), - convnext_large_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth', - test_input_size=(3, 288, 288), test_crop_pct=1.0), - convnext_xlarge_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth', - test_input_size=(3, 288, 288), test_crop_pct=1.0), - - convnext_tiny_384_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth', - input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), - convnext_small_384_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth', - input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), - convnext_base_384_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth', - input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), - convnext_large_384_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth', - input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), - convnext_xlarge_384_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth', - input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), - - convnext_tiny_in22k=_cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", num_classes=21841), - convnext_small_in22k=_cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", num_classes=21841), - convnext_base_in22k=_cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841), - convnext_large_in22k=_cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841), - convnext_xlarge_in22k=_cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841), -) - - class ConvNeXtBlock(nn.Module): """ ConvNeXt Block There are two equivalent implementations: @@ -459,6 +361,107 @@ def _create_convnext(variant, pretrained=False, **kwargs): return model + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': 
IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = generate_defaults({ + # timm specific variants + 'convnext_atto.timm_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + 'convnext_atto_ols.timm_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + 'convnext_femto.timm_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + 'convnext_femto_ols.timm_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + 'convnext_pico.timm_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + 'convnext_pico_ols.timm_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth', + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_nano.timm_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth', + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_nano_ols.timm_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth', + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_tiny_hnf.timm_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth', + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + + 'convnext_tiny.fb_in1k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_small.fb_in1k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_base.fb_in1k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_large.fb_in1k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_xlarge.untrained': _cfg(), + + 'convnext_tiny.fb_in22k_ft_in1k': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_small.fb_in22k_ft_in1k': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_base.fb_in22k_ft_in1k': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_large.fb_in22k_ft_in1k': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth', + 
test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'convnext_xlarge.fb_in22k_ft_in1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+
+    'convnext_tiny.fb_in22k_ft_in1k_384': _cfg(
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+    'convnext_small.fb_in22k_ft_in1k_384': _cfg(
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+    'convnext_base.fb_in22k_ft_in1k_384': _cfg(
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+    'convnext_large.fb_in22k_ft_in1k_384': _cfg(
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+    'convnext_xlarge.fb_in22k_ft_in1k_384': _cfg(
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+
+    'convnext_tiny.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", num_classes=21841),
+    'convnext_small.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", num_classes=21841),
+    'convnext_base.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841),
+    'convnext_large.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841),
+    'convnext_xlarge.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841),
+})
+
+
 @register_model
 def convnext_atto(pretrained=False, **kwargs):
     # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
@@ -569,105 +572,7 @@ def convnext_large(pretrained=False, **kwargs):
 @register_model
-def convnext_tiny_in22ft1k(pretrained=False, **kwargs):
-    model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)
-    model = _create_convnext('convnext_tiny_in22ft1k', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def convnext_small_in22ft1k(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
-    model = _create_convnext('convnext_small_in22ft1k', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def convnext_base_in22ft1k(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
-    model = _create_convnext('convnext_base_in22ft1k', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def convnext_large_in22ft1k(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
-    model = _create_convnext('convnext_large_in22ft1k', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def convnext_xlarge_in22ft1k(pretrained=False, **kwargs):
+def convnext_xlarge(pretrained=False, **kwargs):
    model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
-    model = _create_convnext('convnext_xlarge_in22ft1k', pretrained=pretrained, **model_args)
- 
return model - - -@register_model -def convnext_tiny_384_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs) - model = _create_convnext('convnext_tiny_384_in22ft1k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_small_384_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) - model = _create_convnext('convnext_small_384_in22ft1k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_base_384_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) - model = _create_convnext('convnext_base_384_in22ft1k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_large_384_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) - model = _create_convnext('convnext_large_384_in22ft1k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_xlarge_384_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) - model = _create_convnext('convnext_xlarge_384_in22ft1k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_tiny_in22k(pretrained=False, **kwargs): - model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs) - model = _create_convnext('convnext_tiny_in22k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_small_in22k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) - model = _create_convnext('convnext_small_in22k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_base_in22k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) - model = _create_convnext('convnext_base_in22k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_large_in22k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) - model = _create_convnext('convnext_large_in22k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_xlarge_in22k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) - model = _create_convnext('convnext_xlarge_in22k', pretrained=pretrained, **model_args) + model = _create_convnext('convnext_xlarge', pretrained=pretrained, **model_args) return model diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 51c683c0..3c0efc96 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -366,11 +366,11 @@ default_cfgs = { 'tf_efficientnetv2_m': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'tf_efficientnetv2_l': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 
0.5), - input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'tf_efficientnetv2_s_in21ft1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', @@ -379,15 +379,15 @@ default_cfgs = { 'tf_efficientnetv2_m_in21ft1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21ft1k-bf41664a.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'tf_efficientnetv2_l_in21ft1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21ft1k-60127a9d.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'tf_efficientnetv2_xl_in21ft1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_xl_in21ft1k-06c35c48.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0), + input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'tf_efficientnetv2_s_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21k-6337ad01.pth', @@ -396,15 +396,15 @@ default_cfgs = { 'tf_efficientnetv2_m_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21k-361418a2.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, - input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'tf_efficientnetv2_l_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21k-91a19ec9.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, - input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'tf_efficientnetv2_xl_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_xl_in21k-fd7e8abf.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, - input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0), + input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'tf_efficientnetv2_b0': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b0-c7cc451f.pth', diff --git a/timm/models/layers/activations.py b/timm/models/layers/activations.py index e16b3bd3..2f5476c0 100644 --- a/timm/models/layers/activations.py +++ 
b/timm/models/layers/activations.py @@ -143,3 +143,17 @@ class GELU(nn.Module): def forward(self, input: torch.Tensor) -> torch.Tensor: return F.gelu(input) + + +def gelu_tanh(x: torch.Tensor, inplace: bool = False) -> torch.Tensor: + return F.gelu(x, approximate='tanh') + + +class GELUTanh(nn.Module): + """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg) + """ + def __init__(self, inplace: bool = False): + super(GELUTanh, self).__init__() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.gelu(input, approximate='tanh') diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py index a3044a3d..0b02398d 100644 --- a/timm/models/layers/create_act.py +++ b/timm/models/layers/create_act.py @@ -28,6 +28,7 @@ _ACT_FN_DEFAULT = dict( celu=F.celu, selu=F.selu, gelu=gelu, + gelu_tanh=gelu_tanh, sigmoid=sigmoid, tanh=tanh, hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid, @@ -71,6 +72,7 @@ _ACT_LAYER_DEFAULT = dict( celu=nn.CELU, selu=nn.SELU, gelu=GELU, + gelu_tanh=GELUTanh, sigmoid=Sigmoid, tanh=Tanh, hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid, diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index bd529245..13fd7abf 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -52,114 +52,19 @@ from .helpers import build_model_with_cfg, checkpoint_seq, named_apply from .fx_features import register_notrace_function from .layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, LayerNorm2d, LayerNorm from .layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d +from .layers import SelectAdaptivePool2d, create_pool2d from .layers import to_2tuple, extend_tuple, make_divisible, _assert +from ._pretrained import generate_defaults from .registry import register_model from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move these to common location __all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit'] -def _cfg(url='', **kwargs): - return { - 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.95, 'interpolation': 'bicubic', - 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), - 'first_conv': 'stem.conv1', 'classifier': 'head.fc', - 'fixed_input_size': True, - **kwargs - } - - -default_cfgs = { - # Fiddling with configs / defaults / still pretraining - 'coatnet_pico_rw_224': _cfg(url=''), - 'coatnet_nano_rw_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth', - crop_pct=0.9), - 'coatnet_0_rw_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'), - 'coatnet_1_rw_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth' - ), - 'coatnet_2_rw_224': _cfg(url=''), - 'coatnet_3_rw_224': _cfg(url=''), - - # Highly experimental configs - 'coatnet_bn_0_rw_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, - crop_pct=0.95), - 'coatnet_rmlp_nano_rw_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth', - crop_pct=0.9), - 'coatnet_rmlp_0_rw_224': _cfg(url=''), - 
'coatnet_rmlp_1_rw_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'), - 'coatnet_rmlp_2_rw_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'), - 'coatnet_rmlp_3_rw_224': _cfg(url=''), - 'coatnet_nano_cc_224': _cfg(url=''), - 'coatnext_nano_rw_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth', - crop_pct=0.9), - - # Trying to be like the CoAtNet paper configs - 'coatnet_0_224': _cfg(url=''), - 'coatnet_1_224': _cfg(url=''), - 'coatnet_2_224': _cfg(url=''), - 'coatnet_3_224': _cfg(url=''), - 'coatnet_4_224': _cfg(url=''), - 'coatnet_5_224': _cfg(url=''), - - # Experimental configs - 'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_nano_rw_256': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth', - input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_tiny_rw_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'), - 'maxvit_tiny_rw_256': _cfg( - url='', - input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_rmlp_pico_rw_256': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth', - input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_rmlp_nano_rw_256': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth', - input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_rmlp_tiny_rw_256': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth', - input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_rmlp_small_rw_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth', - crop_pct=0.9, - ), - 'maxvit_rmlp_small_rw_256': _cfg( - url='', - input_size=(3, 256, 256), pool_size=(8, 8)), - - 'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), - - 'maxxvit_rmlp_nano_rw_256': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth', - input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxxvit_rmlp_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxxvit_rmlp_small_rw_256': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth', - input_size=(3, 256, 256), pool_size=(8, 8)), - - # Trying to be like the MaxViT paper configs - 'maxvit_tiny_224': _cfg(url=''), - 'maxvit_small_224': _cfg(url=''), - 'maxvit_base_224': _cfg(url=''), - 'maxvit_large_224': _cfg(url=''), - 'maxvit_xlarge_224': _cfg(url=''), -} - - @dataclass class MaxxVitTransformerCfg: dim_head: int = 32 + head_first: bool = True # head ordering in qkv channel dim expand_ratio: float = 4.0 expand_first: bool = True shortcut_bias: bool = True @@ -199,6 +104,7 @@ class MaxxVitConvCfg: stride_mode: str = 'dw' # stride done via one of 'pool', '1x1', 'dw' pool_type: str = 'avg2' downsample_pool_type: str = 
'avg2' + padding: str = '' attn_early: bool = False # apply attn between conv2 and norm2, instead of after norm2 attn_layer: str = 'se' attn_act_layer: str = 'silu' @@ -228,12 +134,1209 @@ class MaxxVitCfg: depths: Tuple[int, ...] = (2, 3, 5, 2) block_type: Tuple[Union[str, Tuple[str, ...]], ...] = ('C', 'C', 'T', 'T') stem_width: Union[int, Tuple[int, int]] = 64 - stem_bias: bool = True + stem_bias: bool = False conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg() transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg() + head_hidden_size: int = None weight_init: str = 'vit_eff' +class Attention2d(nn.Module): + """ multi-head attention for 2D NCHW tensors""" + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + dim_head: int = 32, + bias: bool = True, + expand_first: bool = True, + head_first: bool = True, + rel_pos_cls: Callable = None, + attn_drop: float = 0., + proj_drop: float = 0. + ): + super().__init__() + dim_out = dim_out or dim + dim_attn = dim_out if expand_first else dim + self.num_heads = dim_attn // dim_head + self.dim_head = dim_head + self.head_first = head_first + self.scale = dim_head ** -0.5 + + self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias) + self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + B, C, H, W = x.shape + + if self.head_first: + q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2) + else: + q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1) + + attn = (q.transpose(-2, -1) @ k) * self.scale + if self.rel_pos is not None: + attn = self.rel_pos(attn) + elif shared_rel_pos is not None: + attn = attn + shared_rel_pos + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class AttentionCl(nn.Module): + """ Channels-last multi-head attention (B, ..., C) """ + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + dim_head: int = 32, + bias: bool = True, + expand_first: bool = True, + head_first: bool = True, + rel_pos_cls: Callable = None, + attn_drop: float = 0., + proj_drop: float = 0. 
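+            # NOTE: head_first selects how the fused qkv projection channels are packed. With
+            # head_first=True the channels are grouped per head as (q, k, v) triples of dim_head each;
+            # with head_first=False they are grouped as all-head q, then k, then v (see forward()).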
+ ): + super().__init__() + dim_out = dim_out or dim + dim_attn = dim_out if expand_first and dim_out > dim else dim + assert dim_attn % dim_head == 0, 'attn dim should be divisible by head_dim' + self.num_heads = dim_attn // dim_head + self.dim_head = dim_head + self.head_first = head_first + self.scale = dim_head ** -0.5 + + self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias) + self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim_attn, dim_out, bias=bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + B = x.shape[0] + restore_shape = x.shape[:-1] + + if self.head_first: + q, k, v = self.qkv(x).view(B, -1, self.num_heads, self.dim_head * 3).transpose(1, 2).chunk(3, dim=3) + else: + q, k, v = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.dim_head).transpose(1, 3).unbind(2) + + attn = (q @ k.transpose(-2, -1)) * self.scale + if self.rel_pos is not None: + attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos) + elif shared_rel_pos is not None: + attn = attn + shared_rel_pos + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(restore_shape + (-1,)) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + gamma = self.gamma + return x.mul_(gamma) if self.inplace else x * gamma + + +class LayerScale2d(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + gamma = self.gamma.view(1, -1, 1, 1) + return x.mul_(gamma) if self.inplace else x * gamma + + +class Downsample2d(nn.Module): + """ A downsample pooling module supporting several maxpool and avgpool modes + * 'max' - MaxPool2d w/ kernel_size 3, stride 2, padding 1 + * 'max2' - MaxPool2d w/ kernel_size = stride = 2 + * 'avg' - AvgPool2d w/ kernel_size 3, stride 2, padding 1 + * 'avg2' - AvgPool2d w/ kernel_size = stride = 2 + """ + + def __init__( + self, + dim: int, + dim_out: int, + pool_type: str = 'avg2', + padding: str = '', + bias: bool = True, + ): + super().__init__() + assert pool_type in ('max', 'max2', 'avg', 'avg2') + if pool_type == 'max': + self.pool = create_pool2d('max', kernel_size=3, stride=2, padding=padding or 1) + elif pool_type == 'max2': + self.pool = create_pool2d('max', 2, padding=padding or 0) # kernel_size == stride == 2 + elif pool_type == 'avg': + self.pool = create_pool2d( + 'avg', kernel_size=3, stride=2, count_include_pad=False, padding=padding or 1) + else: + self.pool = create_pool2d('avg', 2, padding=padding or 0) + + if dim != dim_out: + self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias) + else: + self.expand = nn.Identity() + + def forward(self, x): + x = self.pool(x) # spatial downsample + x = self.expand(x) # expand chs + return x + + +def _init_transformer(module, name, scheme=''): + if isinstance(module, (nn.Conv2d, nn.Linear)): + if scheme == 'normal': + nn.init.normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif scheme == 'trunc_normal': + trunc_normal_tf_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif scheme == 
'xavier_normal': + nn.init.xavier_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + else: + # vit like + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) + else: + nn.init.zeros_(module.bias) + + +class TransformerBlock2d(nn.Module): + """ Transformer block with 2D downsampling + '2D' NCHW tensor layout + + Some gains can be seen on GPU using a 1D / CL block, BUT w/ the need to switch back/forth to NCHW + for spatial pooling, the benefit is minimal so ended up using just this variant for CoAt configs. + + This impl was faster on TPU w/ PT XLA than the 1D experiment. + """ + + def __init__( + self, + dim: int, + dim_out: int, + stride: int = 1, + rel_pos_cls: Callable = None, + cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + drop_path: float = 0., + ): + super().__init__() + norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) + act_layer = get_act_layer(cfg.act_layer) + + if stride == 2: + self.shortcut = Downsample2d(dim, dim_out, pool_type=cfg.pool_type, bias=cfg.shortcut_bias) + self.norm1 = nn.Sequential(OrderedDict([ + ('norm', norm_layer(dim)), + ('down', Downsample2d(dim, dim, pool_type=cfg.pool_type)), + ])) + else: + assert dim == dim_out + self.shortcut = nn.Identity() + self.norm1 = norm_layer(dim) + + self.attn = Attention2d( + dim, + dim_out, + dim_head=cfg.dim_head, + expand_first=cfg.expand_first, + bias=cfg.attn_bias, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop + ) + self.ls1 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim_out) + self.mlp = ConvMlp( + in_features=dim_out, + hidden_features=int(dim_out * cfg.expand_ratio), + act_layer=act_layer, + drop=cfg.proj_drop) + self.ls2 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def init_weights(self, scheme=''): + named_apply(partial(_init_transformer, scheme=scheme), self) + + def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + x = self.shortcut(x) + self.drop_path1(self.ls1(self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +def _init_conv(module, name, scheme=''): + if isinstance(module, nn.Conv2d): + if scheme == 'normal': + nn.init.normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif scheme == 'trunc_normal': + trunc_normal_tf_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif scheme == 'xavier_normal': + nn.init.xavier_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + else: + # efficientnet like + fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels + fan_out //= module.groups + nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out)) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def num_groups(group_size, channels): + if not group_size: # 0 or None + return 1 # normal conv with 1 group + else: + # NOTE group_size == 1 -> depthwise conv + assert channels % group_size == 0 + return channels // group_size + + +class MbConvBlock(nn.Module): + """ Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand) + """ + def __init__( + self, + in_chs: int, + out_chs: int, + stride: int = 1, + dilation: Tuple[int, int] = (1, 1), + cfg: MaxxVitConvCfg = MaxxVitConvCfg(), + drop_path: float = 0. + ): + super(MbConvBlock, self).__init__() + norm_act_layer = partial(get_norm_act_layer(cfg.norm_layer, cfg.act_layer), eps=cfg.norm_eps) + mid_chs = make_divisible((out_chs if cfg.expand_output else in_chs) * cfg.expand_ratio) + groups = num_groups(cfg.group_size, mid_chs) + + if stride == 2: + self.shortcut = Downsample2d( + in_chs, out_chs, pool_type=cfg.pool_type, bias=cfg.output_bias, padding=cfg.padding) + else: + self.shortcut = nn.Identity() + + assert cfg.stride_mode in ('pool', '1x1', 'dw') + stride_pool, stride_1, stride_2 = 1, 1, 1 + if cfg.stride_mode == 'pool': + # NOTE this is not described in paper, experiment to find faster option that doesn't stride in 1x1 + stride_pool, dilation_2 = stride, dilation[1] + # FIXME handle dilation of avg pool + elif cfg.stride_mode == '1x1': + # NOTE I don't like this option described in paper, 1x1 w/ stride throws info away + stride_1, dilation_2 = stride, dilation[1] + else: + stride_2, dilation_2 = stride, dilation[0] + + self.pre_norm = norm_act_layer(in_chs, apply_act=cfg.pre_norm_act) + if stride_pool > 1: + self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type, padding=cfg.padding) + else: + self.down = nn.Identity() + self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=stride_1) + self.norm1 = norm_act_layer(mid_chs) + + self.conv2_kxk = create_conv2d( + mid_chs, mid_chs, cfg.kernel_size, + stride=stride_2, dilation=dilation_2, groups=groups, padding=cfg.padding) + + attn_kwargs = {} + if isinstance(cfg.attn_layer, str): + if cfg.attn_layer == 'se' or cfg.attn_layer == 'eca': + attn_kwargs['act_layer'] = cfg.attn_act_layer + attn_kwargs['rd_channels'] = int(cfg.attn_ratio * (out_chs if cfg.expand_output else mid_chs)) + + # two different orderings for SE and norm2 (due to some weights and trials using SE before norm2) + if cfg.attn_early: + self.se_early = create_attn(cfg.attn_layer, mid_chs, 
**attn_kwargs) + self.norm2 = norm_act_layer(mid_chs) + self.se = None + else: + self.se_early = None + self.norm2 = norm_act_layer(mid_chs) + self.se = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs) + + self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=cfg.output_bias) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def init_weights(self, scheme=''): + named_apply(partial(_init_conv, scheme=scheme), self) + + def forward(self, x): + shortcut = self.shortcut(x) + x = self.pre_norm(x) + x = self.down(x) + + # 1x1 expansion conv & norm-act + x = self.conv1_1x1(x) + x = self.norm1(x) + + # depthwise / grouped 3x3 conv w/ SE (or other) channel attention & norm-act + x = self.conv2_kxk(x) + if self.se_early is not None: + x = self.se_early(x) + x = self.norm2(x) + if self.se is not None: + x = self.se(x) + + # 1x1 linear projection to output width + x = self.conv3_1x1(x) + x = self.drop_path(x) + shortcut + return x + + +class ConvNeXtBlock(nn.Module): + """ ConvNeXt Block + """ + + def __init__( + self, + in_chs: int, + out_chs: Optional[int] = None, + kernel_size: int = 7, + stride: int = 1, + dilation: Tuple[int, int] = (1, 1), + cfg: MaxxVitConvCfg = MaxxVitConvCfg(), + conv_mlp: bool = True, + drop_path: float = 0. + ): + super().__init__() + out_chs = out_chs or in_chs + act_layer = get_act_layer(cfg.act_layer) + if conv_mlp: + norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) + mlp_layer = ConvMlp + else: + assert 'layernorm' in cfg.norm_layer + norm_layer = LayerNorm + mlp_layer = Mlp + self.use_conv_mlp = conv_mlp + + if stride == 2: + self.shortcut = Downsample2d(in_chs, out_chs) + elif in_chs != out_chs: + self.shortcut = nn.Conv2d(in_chs, out_chs, kernel_size=1, bias=cfg.output_bias) + else: + self.shortcut = nn.Identity() + + assert cfg.stride_mode in ('pool', 'dw') + stride_pool, stride_dw = 1, 1 + # FIXME handle dilation? + if cfg.stride_mode == 'pool': + stride_pool = stride + else: + stride_dw = stride + + if stride_pool == 2: + self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type) + else: + self.down = nn.Identity() + + self.conv_dw = create_conv2d( + in_chs, out_chs, kernel_size=kernel_size, stride=stride_dw, dilation=dilation[1], + depthwise=True, bias=cfg.output_bias) + self.norm = norm_layer(out_chs) + self.mlp = mlp_layer(out_chs, int(cfg.expand_ratio * out_chs), bias=cfg.output_bias, act_layer=act_layer) + if conv_mlp: + self.ls = LayerScale2d(out_chs, cfg.init_values) if cfg.init_values else nn.Identity() + else: + self.ls = LayerScale(out_chs, cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, x): + shortcut = self.shortcut(x) + x = self.down(x) + x = self.conv_dw(x) + if self.use_conv_mlp: + x = self.norm(x) + x = self.mlp(x) + x = self.ls(x) + else: + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + x = self.mlp(x) + x = self.ls(x) + x = x.permute(0, 3, 1, 2) + + x = self.drop_path(x) + shortcut + return x + + +def window_partition(x, window_size: List[int]): + B, H, W, C = x.shape + _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})') + _assert(W % window_size[1] == 0, '') + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def window_reverse(windows, window_size: List[int], img_size: List[int]): + H, W = img_size + C = windows.shape[-1] + x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) + return x + + +def grid_partition(x, grid_size: List[int]): + B, H, W, C = x.shape + _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}') + _assert(W % grid_size[1] == 0, '') + x = x.view(B, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1], C) + windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, grid_size[0], grid_size[1], C) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def grid_reverse(windows, grid_size: List[int], img_size: List[int]): + H, W = img_size + C = windows.shape[-1] + x = windows.view(-1, H // grid_size[0], W // grid_size[1], grid_size[0], grid_size[1], C) + x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, H, W, C) + return x + + +def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size): + rel_pos_cls = None + if cfg.rel_pos_type == 'mlp': + rel_pos_cls = partial(RelPosMlp, window_size=window_size, hidden_dim=cfg.rel_pos_dim) + elif cfg.rel_pos_type == 'bias': + rel_pos_cls = partial(RelPosBias, window_size=window_size) + elif cfg.rel_pos_type == 'bias_tf': + rel_pos_cls = partial(RelPosBiasTf, window_size=window_size) + return rel_pos_cls + + +class PartitionAttentionCl(nn.Module): + """ Grid or Block partition + Attn + FFN. + NxC 'channels last' tensor layout. + """ + + def __init__( + self, + dim: int, + partition_type: str = 'block', + cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + drop_path: float = 0., + ): + super().__init__() + norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps) # NOTE this block is channels-last + act_layer = get_act_layer(cfg.act_layer) + + self.partition_block = partition_type == 'block' + self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size) + rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) + + self.norm1 = norm_layer(dim) + self.attn = AttentionCl( + dim, + dim, + dim_head=cfg.dim_head, + bias=cfg.attn_bias, + head_first=cfg.head_first, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop, + ) + self.ls1 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * cfg.expand_ratio), + act_layer=act_layer, + drop=cfg.proj_drop) + self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def _partition_attn(self, x): + img_size = x.shape[1:3] + if self.partition_block: + partitioned = window_partition(x, self.partition_size) + else: + partitioned = grid_partition(x, self.partition_size) + + partitioned = self.attn(partitioned) + + if self.partition_block: + x = window_reverse(partitioned, self.partition_size, img_size) + else: + x = grid_reverse(partitioned, self.partition_size, img_size) + return x + + def forward(self, x): + x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class ParallelPartitionAttention(nn.Module): + """ Experimental. Grid and Block partition + single FFN + NxC tensor layout. + """ + + def __init__( + self, + dim: int, + cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + drop_path: float = 0., + ): + super().__init__() + assert dim % 2 == 0 + norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps) # NOTE this block is channels-last + act_layer = get_act_layer(cfg.act_layer) + + assert cfg.window_size == cfg.grid_size + self.partition_size = to_2tuple(cfg.window_size) + rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) + + self.norm1 = norm_layer(dim) + self.attn_block = AttentionCl( + dim, + dim // 2, + dim_head=cfg.dim_head, + bias=cfg.attn_bias, + head_first=cfg.head_first, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop, + ) + self.attn_grid = AttentionCl( + dim, + dim // 2, + dim_head=cfg.dim_head, + bias=cfg.attn_bias, + head_first=cfg.head_first, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop, + ) + self.ls1 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * cfg.expand_ratio), + out_features=dim, + act_layer=act_layer, + drop=cfg.proj_drop) + self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def _partition_attn(self, x): + img_size = x.shape[1:3] + + partitioned_block = window_partition(x, self.partition_size) + partitioned_block = self.attn_block(partitioned_block) + x_window = window_reverse(partitioned_block, self.partition_size, img_size) + + partitioned_grid = grid_partition(x, self.partition_size) + partitioned_grid = self.attn_grid(partitioned_grid) + x_grid = grid_reverse(partitioned_grid, self.partition_size, img_size) + + return torch.cat([x_window, x_grid], dim=-1) + + def forward(self, x): + x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +def window_partition_nchw(x, window_size: List[int]): + B, C, H, W = x.shape + _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})') + _assert(W % window_size[1] == 0, '') + x = x.view(B, C, H // window_size[0], window_size[0], W // window_size[1], window_size[1]) + windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size[0], window_size[1]) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def window_reverse_nchw(windows, window_size: List[int], img_size: List[int]): + H, W = img_size + C = windows.shape[1] + x = windows.view(-1, H // window_size[0], W // window_size[1], C, window_size[0], window_size[1]) + x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, C, H, W) + return x + + +def grid_partition_nchw(x, grid_size: List[int]): + B, C, H, W = x.shape + _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}') + _assert(W % grid_size[1] == 0, '') + x = x.view(B, C, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1]) + windows = x.permute(0, 3, 5, 1, 2, 4).contiguous().view(-1, C, grid_size[0], grid_size[1]) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def grid_reverse_nchw(windows, grid_size: List[int], img_size: List[int]): + H, W = img_size + C = windows.shape[1] + x = windows.view(-1, H // grid_size[0], W // grid_size[1], C, grid_size[0], grid_size[1]) + x = x.permute(0, 3, 4, 1, 5, 2).contiguous().view(-1, C, H, W) + return x + + +class PartitionAttention2d(nn.Module): + """ Grid or Block partition + Attn + FFN + + '2D' NCHW tensor layout. + """ + + def __init__( + self, + dim: int, + partition_type: str = 'block', + cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + drop_path: float = 0., + ): + super().__init__() + norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) # NOTE this block is channels-last + act_layer = get_act_layer(cfg.act_layer) + + self.partition_block = partition_type == 'block' + self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size) + rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) + + self.norm1 = norm_layer(dim) + self.attn = Attention2d( + dim, + dim, + dim_head=cfg.dim_head, + bias=cfg.attn_bias, + head_first=cfg.head_first, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop, + ) + self.ls1 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = ConvMlp( + in_features=dim, + hidden_features=int(dim * cfg.expand_ratio), + act_layer=act_layer, + drop=cfg.proj_drop) + self.ls2 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def _partition_attn(self, x): + img_size = x.shape[-2:] + if self.partition_block: + partitioned = window_partition_nchw(x, self.partition_size) + else: + partitioned = grid_partition_nchw(x, self.partition_size) + + partitioned = self.attn(partitioned) + + if self.partition_block: + x = window_reverse_nchw(partitioned, self.partition_size, img_size) + else: + x = grid_reverse_nchw(partitioned, self.partition_size, img_size) + return x + + def forward(self, x): + x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class MaxxVitBlock(nn.Module): + """ MaxVit conv, window partition + FFN , grid partition + FFN + """ + + def __init__( + self, + dim: int, + dim_out: int, + stride: int = 1, + conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), + transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + use_nchw_attn: bool = False, # FIXME move to cfg? True is ~20-30% faster on TPU, 5-10% slower on GPU + use_block_attn: bool = True, # FIXME for testing ConvNeXt conv w/o block attention + drop_path: float = 0., + ): + super().__init__() + + conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock + self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) + + attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) + partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttentionCl + self.nchw_attn = use_nchw_attn + self.attn_block = partition_layer(**attn_kwargs) if use_block_attn else None + self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs) + + def init_weights(self, scheme=''): + if self.attn_block is not None: + named_apply(partial(_init_transformer, scheme=scheme), self.attn_block) + named_apply(partial(_init_transformer, scheme=scheme), self.attn_grid) + named_apply(partial(_init_conv, scheme=scheme), self.conv) + + def forward(self, x): + # NCHW format + x = self.conv(x) + + if not self.nchw_attn: + x = x.permute(0, 2, 3, 1) # to NHWC (channels-last) + if self.attn_block is not None: + x = self.attn_block(x) + x = self.attn_grid(x) + if not self.nchw_attn: + x = x.permute(0, 3, 1, 2) # back to NCHW + return x + + +class ParallelMaxxVitBlock(nn.Module): + """ MaxVit block with parallel cat(window + grid), one FF + Experimental timm block. 
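+    The conv block(s) run first, then window (block) and grid attention are applied in parallel to the
+    same normalized input, each producing half of the channels; the halves are concatenated and fed
+    through a single shared MLP (see ParallelPartitionAttention).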
+ """ + + def __init__( + self, + dim, + dim_out, + stride=1, + num_conv=2, + conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), + transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + drop_path=0., + ): + super().__init__() + + conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock + if num_conv > 1: + convs = [conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)] + convs += [conv_cls(dim_out, dim_out, cfg=conv_cfg, drop_path=drop_path)] * (num_conv - 1) + self.conv = nn.Sequential(*convs) + else: + self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) + self.attn = ParallelPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) + + def init_weights(self, scheme=''): + named_apply(partial(_init_transformer, scheme=scheme), self.attn) + named_apply(partial(_init_conv, scheme=scheme), self.conv) + + def forward(self, x): + x = self.conv(x) + x = x.permute(0, 2, 3, 1) + x = self.attn(x) + x = x.permute(0, 3, 1, 2) + return x + + +class MaxxVitStage(nn.Module): + def __init__( + self, + in_chs: int, + out_chs: int, + stride: int = 2, + depth: int = 4, + feat_size: Tuple[int, int] = (14, 14), + block_types: Union[str, Tuple[str]] = 'C', + transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), + conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), + drop_path: Union[float, List[float]] = 0., + ): + super().__init__() + self.grad_checkpointing = False + + block_types = extend_tuple(block_types, depth) + blocks = [] + for i, t in enumerate(block_types): + block_stride = stride if i == 0 else 1 + assert t in ('C', 'T', 'M', 'PM') + if t == 'C': + conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock + blocks += [conv_cls( + in_chs, + out_chs, + stride=block_stride, + cfg=conv_cfg, + drop_path=drop_path[i], + )] + elif t == 'T': + rel_pos_cls = get_rel_pos_cls(transformer_cfg, feat_size) + blocks += [TransformerBlock2d( + in_chs, + out_chs, + stride=block_stride, + rel_pos_cls=rel_pos_cls, + cfg=transformer_cfg, + drop_path=drop_path[i], + )] + elif t == 'M': + blocks += [MaxxVitBlock( + in_chs, + out_chs, + stride=block_stride, + conv_cfg=conv_cfg, + transformer_cfg=transformer_cfg, + drop_path=drop_path[i], + )] + elif t == 'PM': + blocks += [ParallelMaxxVitBlock( + in_chs, + out_chs, + stride=block_stride, + conv_cfg=conv_cfg, + transformer_cfg=transformer_cfg, + drop_path=drop_path[i], + )] + in_chs = out_chs + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + +class Stem(nn.Module): + + def __init__( + self, + in_chs: int, + out_chs: int, + kernel_size: int = 3, + padding: str = '', + bias: bool = False, + act_layer: str = 'gelu', + norm_layer: str = 'batchnorm2d', + norm_eps: float = 1e-5, + ): + super().__init__() + if not isinstance(out_chs, (list, tuple)): + out_chs = to_2tuple(out_chs) + + norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps) + self.out_chs = out_chs[-1] + self.stride = 2 + + self.conv1 = create_conv2d(in_chs, out_chs[0], kernel_size, stride=2, padding=padding, bias=bias) + self.norm1 = norm_act_layer(out_chs[0]) + self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1, padding=padding, bias=bias) + + def init_weights(self, scheme=''): + named_apply(partial(_init_conv, scheme=scheme), self) + + def forward(self, x): + x = self.conv1(x) + x = 
self.norm1(x) + x = self.conv2(x) + return x + + +def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]): + if cfg.window_size is not None: + assert cfg.grid_size + return cfg + partition_size = img_size[0] // cfg.partition_ratio, img_size[1] // cfg.partition_ratio + cfg = replace(cfg, window_size=partition_size, grid_size=partition_size) + return cfg + + +def generate_lookup_tensor( + length: int, + max_relative_position: Optional[int] = None, +): + """Generate a one_hot lookup tensor to reindex embeddings along one dimension. + Args: + length: the length to reindex to. + max_relative_position: the maximum relative position to consider. + Relative position embeddings for distances above this threshold + are zeroed out. + Returns: + a lookup Tensor of size [length, length, vocab_size] that satisfies + ret[n,m,v] = 1{m - n + max_relative_position = v}. + """ + if max_relative_position is None: + max_relative_position = length - 1 + # Return the cached lookup tensor, otherwise compute it and cache it. + vocab_size = 2 * max_relative_position + 1 + ret = torch.zeros(length, length, vocab_size) + for i in range(length): + for x in range(length): + v = x - i + max_relative_position + if abs(x - i) > max_relative_position: + continue + ret[i, x, v] = 1 + return ret + + +def reindex_2d_einsum_lookup( + relative_position_tensor, + height: int, + width: int, + height_lookup: torch.Tensor, + width_lookup: torch.Tensor, +) -> torch.Tensor: + """Reindex 2d relative position bias with 2 independent einsum lookups. + Args: + relative_position_tensor: tensor of shape + [..., vocab_height, vocab_width, ...]. + height: height to reindex to. + width: width to reindex to. + height_lookup: one-hot height lookup + width_lookup: one-hot width lookup + Returns: + reindexed_tensor: a Tensor of shape + [..., height * width, height * width, ...] + """ + reindexed_tensor = torch.einsum('nhw,ixh->nixw', relative_position_tensor, height_lookup) + reindexed_tensor = torch.einsum('nixw,jyw->nijxy', reindexed_tensor, width_lookup) + area = height * width + return reindexed_tensor.reshape(relative_position_tensor.shape[0], area, area) + + +class RelPosBiasTf(nn.Module): + + def __init__(self, window_size, num_heads, prefix_tokens=0): + super().__init__() + assert prefix_tokens <= 1 + self.window_size = window_size + self.window_area = window_size[0] * window_size[1] + self.num_heads = num_heads + + vocab_height = 2 * window_size[0] - 1 + vocab_width = 2 * window_size[1] - 1 + self.bias_shape = (self.num_heads, vocab_height, vocab_width) + self.relative_position_bias_table = nn.Parameter(torch.zeros(self.bias_shape)) + self.register_buffer('height_lookup', generate_lookup_tensor(window_size[0]), persistent=False) + self.register_buffer('width_lookup', generate_lookup_tensor(window_size[1]), persistent=False) + self.init_weights() + + def init_weights(self): + nn.init.normal_(self.relative_position_bias_table, std=.02) + + def get_bias(self) -> torch.Tensor: + # FIXME change to not use one-hot/einsum? 
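+        # Expands the (num_heads, 2*Wh-1, 2*Ww-1) bias table to a
+        # (num_heads, Wh*Ww, Wh*Ww) attention bias via two one-hot einsum
+        # lookups, matching the TF MaxViT relative position handling.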
+ return reindex_2d_einsum_lookup( + self.relative_position_bias_table, + self.window_size[0], + self.window_size[1], + self.height_lookup, + self.width_lookup + ) + + def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None): + return attn + self.get_bias() + + +class NormMlpHead(nn.Module): + + def __init__( + self, + in_features, + num_classes, + hidden_size=None, + pool_type='avg', + drop_rate=0., + norm_layer=nn.LayerNorm, + act_layer=nn.Tanh, + ): + super().__init__() + self.drop_rate = drop_rate + self.num_features = in_features + + self.global_pool = SelectAdaptivePool2d(pool_type=pool_type) + self.norm = norm_layer(in_features) + self.flatten = nn.Flatten(1) if pool_type else nn.Identity() + if hidden_size: + self.pre_logits = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(in_features, hidden_size)), + ('act', act_layer()), + ])) + self.num_features = hidden_size + else: + self.pre_logits = nn.Identity() + self.drop = nn.Dropout(self.drop_rate) + self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward(self, x, pre_logits: bool = False): + x = self.global_pool(x) + x = self.norm(x) + x = self.flatten(x) + x = self.pre_logits(x) + if pre_logits: + return x + x = self.fc(x) + return x + + +class MaxxVit(nn.Module): + """ CoaTNet + MaxVit base model. + + Highly configurable for different block compositions, tensor layouts, pooling types. + """ + + def __init__( + self, + cfg: MaxxVitCfg, + img_size: Union[int, Tuple[int, int]] = 224, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + drop_rate: float = 0., + drop_path_rate: float = 0. + ): + super().__init__() + img_size = to_2tuple(img_size) + transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size) + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.embed_dim = cfg.embed_dim[-1] + self.drop_rate = drop_rate + self.grad_checkpointing = False + + self.stem = Stem( + in_chs=in_chans, + out_chs=cfg.stem_width, + padding=cfg.conv_cfg.padding, + bias=cfg.stem_bias, + act_layer=cfg.conv_cfg.act_layer, + norm_layer=cfg.conv_cfg.norm_layer, + norm_eps=cfg.conv_cfg.norm_eps, + ) + + stride = self.stem.stride + feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))]) + + num_stages = len(cfg.embed_dim) + assert len(cfg.depths) == num_stages + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] + in_chs = self.stem.out_chs + stages = [] + for i in range(num_stages): + stage_stride = 2 + out_chs = cfg.embed_dim[i] + feat_size = tuple([(r - 1) // stage_stride + 1 for r in feat_size]) + stages += [MaxxVitStage( + in_chs, + out_chs, + depth=cfg.depths[i], + block_types=cfg.block_type[i], + conv_cfg=cfg.conv_cfg, + transformer_cfg=transformer_cfg, + feat_size=feat_size, + drop_path=dpr[i], + )] + stride *= stage_stride + in_chs = out_chs + self.stages = nn.Sequential(*stages) + + final_norm_layer = partial(get_norm_layer(cfg.transformer_cfg.norm_layer), eps=cfg.transformer_cfg.norm_eps) + if cfg.head_hidden_size: + self.norm = nn.Identity() + self.head = NormMlpHead( + self.num_features, + num_classes, + hidden_size=cfg.head_hidden_size, + pool_type=global_pool, + drop_rate=drop_rate, + norm_layer=final_norm_layer, + ) + else: + # standard classifier head w/ norm, pooling, fc classifier + self.norm = final_norm_layer(self.num_features) + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + + # Weight 
init (default PyTorch init works well for AdamW if scheme not set) + assert cfg.weight_init in ('', 'normal', 'trunc_normal', 'xavier_normal', 'vit_eff') + if cfg.weight_init: + named_apply(partial(self._init_weights, scheme=cfg.weight_init), self) + + def _init_weights(self, module, name, scheme=''): + if hasattr(module, 'init_weights'): + try: + module.init_weights(scheme=scheme) + except TypeError: + module.init_weights() + + @torch.jit.ignore + def no_weight_decay(self): + return { + k for k, _ in self.named_parameters() + if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^stem', # stem and embed + blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is None: + global_pool = self.head.global_pool.pool_type + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + def _rw_coat_cfg( stride_mode='pool', pool_type='avg2', @@ -367,6 +1470,22 @@ def _next_cfg( ) +def _tf_cfg(): + return dict( + conv_cfg=MaxxVitConvCfg( + norm_eps=1e-3, + act_layer='gelu_tanh', + padding='same', + ), + transformer_cfg=MaxxVitTransformerCfg( + norm_eps=1e-5, + act_layer='gelu_tanh', + head_first=False, # heads are interleaved (q_nh, q_hdim, k_nh, q_hdim, ....) 
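+            # 'bias_tf' selects the einsum-based TF-style relative position bias (RelPosBiasTf)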
+ rel_pos_type='bias_tf', + ), + ) + + model_cfgs = dict( # Fiddling with configs / defaults / still pretraining coatnet_pico_rw_224=MaxxVitCfg( @@ -414,7 +1533,7 @@ model_cfgs = dict( **_rw_coat_cfg( stride_mode='dw', conv_attn_act_layer='silu', - init_values=1e-6, + #init_values=1e-6, ), ), coatnet_3_rw_224=MaxxVitCfg( @@ -472,6 +1591,16 @@ model_cfgs = dict( rel_pos_dim=384, # was supposed to be 512, woops ), ), + coatnet_rmlp_1_rw2_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + stem_width=(32, 64), + **_rw_coat_cfg( + stride_mode='dw', + rel_pos_type='mlp', + rel_pos_dim=512, # was supposed to be 512, woops + ), + ), coatnet_rmlp_2_rw_224=MaxxVitCfg( embed_dim=(128, 256, 512, 1024), depths=(2, 6, 14, 2), @@ -647,1074 +1776,70 @@ model_cfgs = dict( stem_width=(48, 96), **_next_cfg(), ), + maxxvit_rmlp_base_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + block_type=('M',) * 4, + stem_width=(48, 96), + **_next_cfg(), + ), + maxxvit_rmlp_large_rw_224=MaxxVitCfg( + embed_dim=(128, 256, 512, 1024), + depths=(2, 6, 12, 2), + block_type=('M',) * 4, + stem_width=(64, 128), + **_next_cfg(), + ), # Trying to be like the MaxViT paper configs - maxvit_tiny_224=MaxxVitCfg( + maxvit_tiny_tf=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(2, 2, 5, 2), block_type=('M',) * 4, stem_width=64, + stem_bias=True, + head_hidden_size=512, + **_tf_cfg(), ), - maxvit_small_224=MaxxVitCfg( + maxvit_small_tf=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 2, 5, 2), block_type=('M',) * 4, stem_width=64, + stem_bias=True, + head_hidden_size=768, + **_tf_cfg(), ), - maxvit_base_224=MaxxVitCfg( + maxvit_base_tf=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 6, 14, 2), block_type=('M',) * 4, stem_width=64, + stem_bias=True, + head_hidden_size=768, + **_tf_cfg(), ), - maxvit_large_224=MaxxVitCfg( + maxvit_large_tf=MaxxVitCfg( embed_dim=(128, 256, 512, 1024), depths=(2, 6, 14, 2), block_type=('M',) * 4, stem_width=128, + stem_bias=True, + head_hidden_size=1024, + **_tf_cfg(), ), - maxvit_xlarge_224=MaxxVitCfg( + maxvit_xlarge_tf=MaxxVitCfg( embed_dim=(192, 384, 768, 1536), depths=(2, 6, 14, 2), block_type=('M',) * 4, stem_width=192, + stem_bias=True, + head_hidden_size=1536, + **_tf_cfg(), ), - ) -class Attention2d(nn.Module): - """ multi-head attention for 2D NCHW tensors""" - def __init__( - self, - dim: int, - dim_out: Optional[int] = None, - dim_head: int = 32, - bias: bool = True, - expand_first: bool = True, - rel_pos_cls: Callable = None, - attn_drop: float = 0., - proj_drop: float = 0. 
- ): - super().__init__() - dim_out = dim_out or dim - dim_attn = dim_out if expand_first else dim - self.num_heads = dim_attn // dim_head - self.dim_head = dim_head - self.scale = dim_head ** -0.5 - - self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias) - self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): - B, C, H, W = x.shape - - q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2) - - attn = (q.transpose(-2, -1) @ k) * self.scale - if self.rel_pos is not None: - attn = self.rel_pos(attn) - elif shared_rel_pos is not None: - attn = attn + shared_rel_pos - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class AttentionCl(nn.Module): - """ Channels-last multi-head attention (B, ..., C) """ - def __init__( - self, - dim: int, - dim_out: Optional[int] = None, - dim_head: int = 32, - bias: bool = True, - expand_first: bool = True, - rel_pos_cls: Callable = None, - attn_drop: float = 0., - proj_drop: float = 0. - ): - super().__init__() - dim_out = dim_out or dim - dim_attn = dim_out if expand_first and dim_out > dim else dim - assert dim_attn % dim_head == 0, 'attn dim should be divisible by head_dim' - self.num_heads = dim_attn // dim_head - self.dim_head = dim_head - self.scale = dim_head ** -0.5 - - self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias) - self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim_attn, dim_out, bias=bias) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): - B = x.shape[0] - restore_shape = x.shape[:-1] - - q, k, v = self.qkv(x).view(B, -1, self.num_heads, self.dim_head * 3).transpose(1, 2).chunk(3, dim=3) - - attn = (q @ k.transpose(-2, -1)) * self.scale - if self.rel_pos is not None: - attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos) - elif shared_rel_pos is not None: - attn = attn + shared_rel_pos - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(restore_shape + (-1,)) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class LayerScale(nn.Module): - def __init__(self, dim, init_values=1e-5, inplace=False): - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x): - gamma = self.gamma - return x.mul_(gamma) if self.inplace else x * gamma - - -class LayerScale2d(nn.Module): - def __init__(self, dim, init_values=1e-5, inplace=False): - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x): - gamma = self.gamma.view(1, -1, 1, 1) - return x.mul_(gamma) if self.inplace else x * gamma - - -class Downsample2d(nn.Module): - """ A downsample pooling module supporting several maxpool and avgpool modes - * 'max' - MaxPool2d w/ kernel_size 3, stride 2, padding 1 - * 'max2' - MaxPool2d w/ kernel_size = stride = 2 - * 'avg' - AvgPool2d w/ kernel_size 3, stride 2, padding 1 - * 'avg2' - AvgPool2d w/ kernel_size = stride = 2 - """ - - def __init__( - self, - dim: int, - dim_out: int, - pool_type: str = 'avg2', - bias: bool = 
True, - ): - super().__init__() - assert pool_type in ('max', 'max2', 'avg', 'avg2') - if pool_type == 'max': - self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - elif pool_type == 'max2': - self.pool = nn.MaxPool2d(2) # kernel_size == stride == 2 - elif pool_type == 'avg': - self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False) - else: - self.pool = nn.AvgPool2d(2) # kernel_size == stride == 2 - - if dim != dim_out: - self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias) - else: - self.expand = nn.Identity() - - def forward(self, x): - x = self.pool(x) # spatial downsample - x = self.expand(x) # expand chs - return x - - -def _init_transformer(module, name, scheme=''): - if isinstance(module, (nn.Conv2d, nn.Linear)): - if scheme == 'normal': - nn.init.normal_(module.weight, std=.02) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif scheme == 'trunc_normal': - trunc_normal_tf_(module.weight, std=.02) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif scheme == 'xavier_normal': - nn.init.xavier_normal_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - else: - # vit like - nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - if 'mlp' in name: - nn.init.normal_(module.bias, std=1e-6) - else: - nn.init.zeros_(module.bias) - - -class TransformerBlock2d(nn.Module): - """ Transformer block with 2D downsampling - '2D' NCHW tensor layout - - Some gains can be seen on GPU using a 1D / CL block, BUT w/ the need to switch back/forth to NCHW - for spatial pooling, the benefit is minimal so ended up using just this variant for CoAt configs. - - This impl was faster on TPU w/ PT XLA than the 1D experiment. - """ - - def __init__( - self, - dim: int, - dim_out: int, - stride: int = 1, - rel_pos_cls: Callable = None, - cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), - drop_path: float = 0., - ): - super().__init__() - norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) - act_layer = get_act_layer(cfg.act_layer) - - if stride == 2: - self.shortcut = Downsample2d(dim, dim_out, pool_type=cfg.pool_type, bias=cfg.shortcut_bias) - self.norm1 = nn.Sequential(OrderedDict([ - ('norm', norm_layer(dim)), - ('down', Downsample2d(dim, dim, pool_type=cfg.pool_type)), - ])) - else: - assert dim == dim_out - self.shortcut = nn.Identity() - self.norm1 = norm_layer(dim) - - self.attn = Attention2d( - dim, - dim_out, - dim_head=cfg.dim_head, - expand_first=cfg.expand_first, - bias=cfg.attn_bias, - rel_pos_cls=rel_pos_cls, - attn_drop=cfg.attn_drop, - proj_drop=cfg.proj_drop - ) - self.ls1 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity() - self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - - self.norm2 = norm_layer(dim_out) - self.mlp = ConvMlp( - in_features=dim_out, - hidden_features=int(dim_out * cfg.expand_ratio), - act_layer=act_layer, - drop=cfg.proj_drop) - self.ls2 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity() - self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - - def init_weights(self, scheme=''): - named_apply(partial(_init_transformer, scheme=scheme), self) - - def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): - x = self.shortcut(x) + self.drop_path1(self.ls1(self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos))) - x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) - return x - - -def _init_conv(module, name, scheme=''): - if isinstance(module, nn.Conv2d): - if scheme == 'normal': - nn.init.normal_(module.weight, std=.02) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif scheme == 'trunc_normal': - trunc_normal_tf_(module.weight, std=.02) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif scheme == 'xavier_normal': - nn.init.xavier_normal_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - else: - # efficientnet like - fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels - fan_out //= module.groups - nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out)) - if module.bias is not None: - nn.init.zeros_(module.bias) - - -def num_groups(group_size, channels): - if not group_size: # 0 or None - return 1 # normal conv with 1 group - else: - # NOTE group_size == 1 -> depthwise conv - assert channels % group_size == 0 - return channels // group_size - - -class MbConvBlock(nn.Module): - """ Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand) - """ - def __init__( - self, - in_chs: int, - out_chs: int, - stride: int = 1, - dilation: Tuple[int, int] = (1, 1), - cfg: MaxxVitConvCfg = MaxxVitConvCfg(), - drop_path: float = 0. - ): - super(MbConvBlock, self).__init__() - norm_act_layer = partial(get_norm_act_layer(cfg.norm_layer, cfg.act_layer), eps=cfg.norm_eps) - mid_chs = make_divisible((out_chs if cfg.expand_output else in_chs) * cfg.expand_ratio) - groups = num_groups(cfg.group_size, mid_chs) - - if stride == 2: - self.shortcut = Downsample2d(in_chs, out_chs, pool_type=cfg.pool_type, bias=cfg.output_bias) - else: - self.shortcut = nn.Identity() - - assert cfg.stride_mode in ('pool', '1x1', 'dw') - stride_pool, stride_1, stride_2 = 1, 1, 1 - if cfg.stride_mode == 'pool': - # NOTE this is not described in paper, experiment to find faster option that doesn't stride in 1x1 - stride_pool, dilation_2 = stride, dilation[1] - # FIXME handle dilation of avg pool - elif cfg.stride_mode == '1x1': - # NOTE I don't like this option described in paper, 1x1 w/ stride throws info away - stride_1, dilation_2 = stride, dilation[1] - else: - stride_2, dilation_2 = stride, dilation[0] - - self.pre_norm = norm_act_layer(in_chs, apply_act=cfg.pre_norm_act) - if stride_pool > 1: - self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type) - else: - self.down = nn.Identity() - self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=stride_1) - self.norm1 = norm_act_layer(mid_chs) - - self.conv2_kxk = create_conv2d( - mid_chs, mid_chs, cfg.kernel_size, stride=stride_2, dilation=dilation_2, groups=groups) - - attn_kwargs = {} - if isinstance(cfg.attn_layer, str): - if cfg.attn_layer == 'se' or cfg.attn_layer == 'eca': - attn_kwargs['act_layer'] = cfg.attn_act_layer - attn_kwargs['rd_channels'] = int(cfg.attn_ratio * (out_chs if cfg.expand_output else mid_chs)) - - # two different orderings for SE and norm2 (due to some weights and trials using SE before norm2) - if cfg.attn_early: - self.se_early = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs) - self.norm2 = norm_act_layer(mid_chs) - self.se = 
None - else: - self.se_early = None - self.norm2 = norm_act_layer(mid_chs) - self.se = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs) - - self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=cfg.output_bias) - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - - def init_weights(self, scheme=''): - named_apply(partial(_init_conv, scheme=scheme), self) - - def forward(self, x): - shortcut = self.shortcut(x) - x = self.pre_norm(x) - x = self.down(x) - - # 1x1 expansion conv & norm-act - x = self.conv1_1x1(x) - x = self.norm1(x) - - # depthwise / grouped 3x3 conv w/ SE (or other) channel attention & norm-act - x = self.conv2_kxk(x) - if self.se_early is not None: - x = self.se_early(x) - x = self.norm2(x) - if self.se is not None: - x = self.se(x) - - # 1x1 linear projection to output width - x = self.conv3_1x1(x) - x = self.drop_path(x) + shortcut - return x - - -class ConvNeXtBlock(nn.Module): - """ ConvNeXt Block - """ - - def __init__( - self, - in_chs: int, - out_chs: Optional[int] = None, - kernel_size: int = 7, - stride: int = 1, - dilation: Tuple[int, int] = (1, 1), - cfg: MaxxVitConvCfg = MaxxVitConvCfg(), - conv_mlp: bool = True, - drop_path: float = 0. - ): - super().__init__() - out_chs = out_chs or in_chs - act_layer = get_act_layer(cfg.act_layer) - if conv_mlp: - norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) - mlp_layer = ConvMlp - else: - assert 'layernorm' in cfg.norm_layer - norm_layer = LayerNorm - mlp_layer = Mlp - self.use_conv_mlp = conv_mlp - - if stride == 2: - self.shortcut = Downsample2d(in_chs, out_chs) - elif in_chs != out_chs: - self.shortcut = nn.Conv2d(in_chs, out_chs, kernel_size=1, bias=cfg.output_bias) - else: - self.shortcut = nn.Identity() - - assert cfg.stride_mode in ('pool', 'dw') - stride_pool, stride_dw = 1, 1 - # FIXME handle dilation? - if cfg.stride_mode == 'pool': - stride_pool = stride - else: - stride_dw = stride - - if stride_pool == 2: - self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type) - else: - self.down = nn.Identity() - - self.conv_dw = create_conv2d( - in_chs, out_chs, kernel_size=kernel_size, stride=stride_dw, dilation=dilation[1], - depthwise=True, bias=cfg.output_bias) - self.norm = norm_layer(out_chs) - self.mlp = mlp_layer(out_chs, int(cfg.expand_ratio * out_chs), bias=cfg.output_bias, act_layer=act_layer) - if conv_mlp: - self.ls = LayerScale2d(out_chs, cfg.init_values) if cfg.init_values else nn.Identity() - else: - self.ls = LayerScale(out_chs, cfg.init_values) if cfg.init_values else nn.Identity() - self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - - def forward(self, x): - shortcut = self.shortcut(x) - x = self.down(x) - x = self.conv_dw(x) - if self.use_conv_mlp: - x = self.norm(x) - x = self.mlp(x) - x = self.ls(x) - else: - x = x.permute(0, 2, 3, 1) - x = self.norm(x) - x = self.mlp(x) - x = self.ls(x) - x = x.permute(0, 3, 1, 2) - - x = self.drop_path(x) + shortcut - return x - - -def window_partition(x, window_size: List[int]): - B, H, W, C = x.shape - _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})') - _assert(W % window_size[1] == 0, '') - x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) - return windows - - -@register_notrace_function # reason: int argument is a Proxy -def window_reverse(windows, window_size: List[int], img_size: List[int]): - H, W = img_size - C = windows.shape[-1] - x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) - return x - - -def grid_partition(x, grid_size: List[int]): - B, H, W, C = x.shape - _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}') - _assert(W % grid_size[1] == 0, '') - x = x.view(B, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1], C) - windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, grid_size[0], grid_size[1], C) - return windows - - -@register_notrace_function # reason: int argument is a Proxy -def grid_reverse(windows, grid_size: List[int], img_size: List[int]): - H, W = img_size - C = windows.shape[-1] - x = windows.view(-1, H // grid_size[0], W // grid_size[1], grid_size[0], grid_size[1], C) - x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, H, W, C) - return x - - -def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size): - rel_pos_cls = None - if cfg.rel_pos_type == 'mlp': - rel_pos_cls = partial(RelPosMlp, window_size=window_size, hidden_dim=cfg.rel_pos_dim) - elif cfg.rel_pos_type == 'bias': - rel_pos_cls = partial(RelPosBias, window_size=window_size) - return rel_pos_cls - - -class PartitionAttentionCl(nn.Module): - """ Grid or Block partition + Attn + FFN. - NxC 'channels last' tensor layout. - """ - - def __init__( - self, - dim: int, - partition_type: str = 'block', - cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), - drop_path: float = 0., - ): - super().__init__() - norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps) # NOTE this block is channels-last - act_layer = get_act_layer(cfg.act_layer) - - self.partition_block = partition_type == 'block' - self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size) - rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) - - self.norm1 = norm_layer(dim) - self.attn = AttentionCl( - dim, - dim, - dim_head=cfg.dim_head, - bias=cfg.attn_bias, - rel_pos_cls=rel_pos_cls, - attn_drop=cfg.attn_drop, - proj_drop=cfg.proj_drop, - ) - self.ls1 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() - self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - - self.norm2 = norm_layer(dim) - self.mlp = Mlp( - in_features=dim, - hidden_features=int(dim * cfg.expand_ratio), - act_layer=act_layer, - drop=cfg.proj_drop) - self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() - self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - - def _partition_attn(self, x): - img_size = x.shape[1:3] - if self.partition_block: - partitioned = window_partition(x, self.partition_size) - else: - partitioned = grid_partition(x, self.partition_size) - - partitioned = self.attn(partitioned) - - if self.partition_block: - x = window_reverse(partitioned, self.partition_size, img_size) - else: - x = grid_reverse(partitioned, self.partition_size, img_size) - return x - - def forward(self, x): - x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) - x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) - return x - - -class ParallelPartitionAttention(nn.Module): - """ Experimental. Grid and Block partition + single FFN - NxC tensor layout. - """ - - def __init__( - self, - dim: int, - cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), - drop_path: float = 0., - ): - super().__init__() - assert dim % 2 == 0 - norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps) # NOTE this block is channels-last - act_layer = get_act_layer(cfg.act_layer) - - assert cfg.window_size == cfg.grid_size - self.partition_size = to_2tuple(cfg.window_size) - rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) - - self.norm1 = norm_layer(dim) - self.attn_block = AttentionCl( - dim, - dim // 2, - dim_head=cfg.dim_head, - bias=cfg.attn_bias, - rel_pos_cls=rel_pos_cls, - attn_drop=cfg.attn_drop, - proj_drop=cfg.proj_drop, - ) - self.attn_grid = AttentionCl( - dim, - dim // 2, - dim_head=cfg.dim_head, - bias=cfg.attn_bias, - rel_pos_cls=rel_pos_cls, - attn_drop=cfg.attn_drop, - proj_drop=cfg.proj_drop, - ) - self.ls1 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() - self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - - self.norm2 = norm_layer(dim) - self.mlp = Mlp( - in_features=dim, - hidden_features=int(dim * cfg.expand_ratio), - out_features=dim, - act_layer=act_layer, - drop=cfg.proj_drop) - self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() - self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - - def _partition_attn(self, x): - img_size = x.shape[1:3] - - partitioned_block = window_partition(x, self.partition_size) - partitioned_block = self.attn_block(partitioned_block) - x_window = window_reverse(partitioned_block, self.partition_size, img_size) - - partitioned_grid = grid_partition(x, self.partition_size) - partitioned_grid = self.attn_grid(partitioned_grid) - x_grid = grid_reverse(partitioned_grid, self.partition_size, img_size) - - return torch.cat([x_window, x_grid], dim=-1) - - def forward(self, x): - x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) - x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) - return x - - -def window_partition_nchw(x, window_size: List[int]): - B, C, H, W = x.shape - _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})') - _assert(W % window_size[1] == 0, '') - x = x.view(B, C, H // window_size[0], window_size[0], W // window_size[1], window_size[1]) - windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size[0], window_size[1]) - return windows - - -@register_notrace_function # reason: int argument is a Proxy -def window_reverse_nchw(windows, window_size: List[int], img_size: List[int]): - H, W = img_size - C = windows.shape[1] - x = windows.view(-1, H // window_size[0], W // window_size[1], C, window_size[0], window_size[1]) - x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, C, H, W) - return x - - -def grid_partition_nchw(x, grid_size: List[int]): - B, C, H, W = x.shape - _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}') - _assert(W % grid_size[1] == 0, '') - x = x.view(B, C, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1]) - windows = x.permute(0, 3, 5, 1, 2, 4).contiguous().view(-1, C, grid_size[0], grid_size[1]) - return windows - - -@register_notrace_function # reason: int argument is a Proxy -def grid_reverse_nchw(windows, grid_size: List[int], img_size: List[int]): - H, W = img_size - C = windows.shape[1] - x = windows.view(-1, H // grid_size[0], W // grid_size[1], C, grid_size[0], grid_size[1]) - x = x.permute(0, 3, 4, 1, 5, 2).contiguous().view(-1, C, H, W) - return x - - -class PartitionAttention2d(nn.Module): - """ Grid or Block partition + Attn + FFN - - '2D' NCHW tensor layout. - """ - - def __init__( - self, - dim: int, - partition_type: str = 'block', - cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), - drop_path: float = 0., - ): - super().__init__() - norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) # NOTE this block is channels-last - act_layer = get_act_layer(cfg.act_layer) - - self.partition_block = partition_type == 'block' - self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size) - rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) - - self.norm1 = norm_layer(dim) - self.attn = Attention2d( - dim, - dim, - dim_head=cfg.dim_head, - bias=cfg.attn_bias, - rel_pos_cls=rel_pos_cls, - attn_drop=cfg.attn_drop, - proj_drop=cfg.proj_drop, - ) - self.ls1 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() - self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - - self.norm2 = norm_layer(dim) - self.mlp = ConvMlp( - in_features=dim, - hidden_features=int(dim * cfg.expand_ratio), - act_layer=act_layer, - drop=cfg.proj_drop) - self.ls2 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() - self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - - def _partition_attn(self, x): - img_size = x.shape[-2:] - if self.partition_block: - partitioned = window_partition_nchw(x, self.partition_size) - else: - partitioned = grid_partition_nchw(x, self.partition_size) - - partitioned = self.attn(partitioned) - - if self.partition_block: - x = window_reverse_nchw(partitioned, self.partition_size, img_size) - else: - x = grid_reverse_nchw(partitioned, self.partition_size, img_size) - return x - - def forward(self, x): - x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) - x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) - return x - - -class MaxxVitBlock(nn.Module): - """ MaxVit conv, window partition + FFN , grid partition + FFN - """ - - def __init__( - self, - dim: int, - dim_out: int, - stride: int = 1, - conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), - transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), - use_nchw_attn: bool = False, # FIXME move to cfg? True is ~20-30% faster on TPU, 5-10% slower on GPU - drop_path: float = 0., - ): - super().__init__() - - conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock - self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) - - attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) - partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttentionCl - self.nchw_attn = use_nchw_attn - self.attn_block = partition_layer(**attn_kwargs) - self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs) - - def init_weights(self, scheme=''): - named_apply(partial(_init_transformer, scheme=scheme), self.attn_block) - named_apply(partial(_init_transformer, scheme=scheme), self.attn_grid) - named_apply(partial(_init_conv, scheme=scheme), self.conv) - - def forward(self, x): - # NCHW format - x = self.conv(x) - - if not self.nchw_attn: - x = x.permute(0, 2, 3, 1) # to NHWC (channels-last) - x = self.attn_block(x) - x = self.attn_grid(x) - if not self.nchw_attn: - x = x.permute(0, 3, 1, 2) # back to NCHW - return x - - -class ParallelMaxxVitBlock(nn.Module): - """ MaxVit block with parallel cat(window + grid), one FF - Experimental timm block. 
- """ - - def __init__( - self, - dim, - dim_out, - stride=1, - num_conv=2, - conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), - transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), - drop_path=0., - ): - super().__init__() - - conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock - if num_conv > 1: - convs = [conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)] - convs += [conv_cls(dim_out, dim_out, cfg=conv_cfg, drop_path=drop_path)] * (num_conv - 1) - self.conv = nn.Sequential(*convs) - else: - self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) - self.attn = ParallelPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) - - def init_weights(self, scheme=''): - named_apply(partial(_init_transformer, scheme=scheme), self.attn) - named_apply(partial(_init_conv, scheme=scheme), self.conv) - - def forward(self, x): - x = self.conv(x) - x = x.permute(0, 2, 3, 1) - x = self.attn(x) - x = x.permute(0, 3, 1, 2) - return x - - -class MaxxVitStage(nn.Module): - def __init__( - self, - in_chs: int, - out_chs: int, - stride: int = 2, - depth: int = 4, - feat_size: Tuple[int, int] = (14, 14), - block_types: Union[str, Tuple[str]] = 'C', - transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), - conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), - drop_path: Union[float, List[float]] = 0., - ): - super().__init__() - self.grad_checkpointing = False - - block_types = extend_tuple(block_types, depth) - blocks = [] - for i, t in enumerate(block_types): - block_stride = stride if i == 0 else 1 - assert t in ('C', 'T', 'M', 'PM') - if t == 'C': - conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock - blocks += [conv_cls( - in_chs, - out_chs, - stride=block_stride, - cfg=conv_cfg, - drop_path=drop_path[i], - )] - elif t == 'T': - rel_pos_cls = get_rel_pos_cls(transformer_cfg, feat_size) - blocks += [TransformerBlock2d( - in_chs, - out_chs, - stride=block_stride, - rel_pos_cls=rel_pos_cls, - cfg=transformer_cfg, - drop_path=drop_path[i], - )] - elif t == 'M': - blocks += [MaxxVitBlock( - in_chs, - out_chs, - stride=block_stride, - conv_cfg=conv_cfg, - transformer_cfg=transformer_cfg, - drop_path=drop_path[i], - )] - elif t == 'PM': - blocks += [ParallelMaxxVitBlock( - in_chs, - out_chs, - stride=block_stride, - conv_cfg=conv_cfg, - transformer_cfg=transformer_cfg, - drop_path=drop_path[i], - )] - in_chs = out_chs - self.blocks = nn.Sequential(*blocks) - - def forward(self, x): - if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(self.blocks, x) - else: - x = self.blocks(x) - return x - - -class Stem(nn.Module): - - def __init__( - self, - in_chs: int, - out_chs: int, - kernel_size: int = 3, - act_layer: str = 'gelu', - norm_layer: str = 'batchnorm2d', - norm_eps: float = 1e-5, - ): - super().__init__() - if not isinstance(out_chs, (list, tuple)): - out_chs = to_2tuple(out_chs) - - norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps) - self.out_chs = out_chs[-1] - self.stride = 2 - - self.conv1 = create_conv2d(in_chs, out_chs[0], kernel_size, stride=2) - self.norm1 = norm_act_layer(out_chs[0]) - self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1) - - def init_weights(self, scheme=''): - named_apply(partial(_init_conv, scheme=scheme), self) - - def forward(self, x): - x = self.conv1(x) - x = self.norm1(x) - x = self.conv2(x) - return x - - -def cfg_window_size(cfg: MaxxVitTransformerCfg, 
img_size: Tuple[int, int]): - if cfg.window_size is not None: - assert cfg.grid_size - return cfg - partition_size = img_size[0] // cfg.partition_ratio, img_size[1] // cfg.partition_ratio - cfg = replace(cfg, window_size=partition_size, grid_size=partition_size) - return cfg - - -class MaxxVit(nn.Module): - """ CoaTNet + MaxVit base model. - - Highly configurable for different block compositions, tensor layouts, pooling types. - """ - - def __init__( - self, - cfg: MaxxVitCfg, - img_size: Union[int, Tuple[int, int]] = 224, - in_chans: int = 3, - num_classes: int = 1000, - global_pool: str = 'avg', - drop_rate: float = 0., - drop_path_rate: float = 0. - ): - super().__init__() - img_size = to_2tuple(img_size) - transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size) - self.num_classes = num_classes - self.global_pool = global_pool - self.num_features = cfg.embed_dim[-1] - self.embed_dim = cfg.embed_dim - self.drop_rate = drop_rate - self.grad_checkpointing = False - - self.stem = Stem( - in_chs=in_chans, - out_chs=cfg.stem_width, - act_layer=cfg.conv_cfg.act_layer, - norm_layer=cfg.conv_cfg.norm_layer, - norm_eps=cfg.conv_cfg.norm_eps, - ) - - stride = self.stem.stride - feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))]) - - num_stages = len(cfg.embed_dim) - assert len(cfg.depths) == num_stages - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] - in_chs = self.stem.out_chs - stages = [] - for i in range(num_stages): - stage_stride = 2 - out_chs = cfg.embed_dim[i] - feat_size = tuple([(r - 1) // stage_stride + 1 for r in feat_size]) - stages += [MaxxVitStage( - in_chs, - out_chs, - depth=cfg.depths[i], - block_types=cfg.block_type[i], - conv_cfg=cfg.conv_cfg, - transformer_cfg=transformer_cfg, - feat_size=feat_size, - drop_path=dpr[i], - )] - stride *= stage_stride - in_chs = out_chs - self.stages = nn.Sequential(*stages) - - final_norm_layer = get_norm_layer(cfg.transformer_cfg.norm_layer) - self.norm = final_norm_layer(self.num_features, eps=cfg.transformer_cfg.norm_eps) - - # Classifier head - self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) - - # Weight init (default PyTorch init works well for AdamW if scheme not set) - assert cfg.weight_init in ('', 'normal', 'trunc_normal', 'xavier_normal', 'vit_eff') - if cfg.weight_init: - named_apply(partial(self._init_weights, scheme=cfg.weight_init), self) - - def _init_weights(self, module, name, scheme=''): - if hasattr(module, 'init_weights'): - try: - module.init_weights(scheme=scheme) - except TypeError: - module.init_weights() - - @torch.jit.ignore - def no_weight_decay(self): - return { - k for k, _ in self.named_parameters() - if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])} - - @torch.jit.ignore - def group_matcher(self, coarse=False): - matcher = dict( - stem=r'^stem', # stem and embed - blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))] - ) - return matcher - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - for s in self.stages: - s.grad_checkpointing = enable - - @torch.jit.ignore - def get_classifier(self): - return self.head.fc - - def reset_classifier(self, num_classes, global_pool=None): - self.num_classes = num_classes - if global_pool is None: - global_pool = self.head.global_pool.pool_type - self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) - - def forward_features(self, x): - x = 
self.stem(x) - x = self.stages(x) - x = self.norm(x) - return x - - def forward_head(self, x, pre_logits: bool = False): - return self.head(x, pre_logits=pre_logits) - - def forward(self, x): - x = self.forward_features(x) - x = self.forward_head(x) - return x - - def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs): return build_model_with_cfg( MaxxVit, variant, pretrained, @@ -1723,6 +1848,183 @@ def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs): **kwargs) +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.95, 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + 'first_conv': 'stem.conv1', 'classifier': 'head.fc', + 'fixed_input_size': True, + **kwargs + } + + +default_cfgs = generate_defaults({ + # Fiddling with configs / defaults / still pretraining + 'coatnet_pico_rw_224': _cfg(url=''), + 'coatnet_nano_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth', + crop_pct=0.9), + 'coatnet_0_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'), + 'coatnet_1_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth' + ), + 'coatnet_2_rw_224': _cfg(url=''), + 'coatnet_3_rw_224': _cfg(url=''), + + # Highly experimental configs + 'coatnet_bn_0_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + crop_pct=0.95), + 'coatnet_rmlp_nano_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth', + crop_pct=0.9), + 'coatnet_rmlp_0_rw_224': _cfg(url=''), + 'coatnet_rmlp_1_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'), + 'coatnet_rmlp_1_rw2_224': _cfg(url=''), + 'coatnet_rmlp_2_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'), + 'coatnet_rmlp_3_rw_224': _cfg(url=''), + 'coatnet_nano_cc_224': _cfg(url=''), + 'coatnext_nano_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth', + crop_pct=0.9), + + # Trying to be like the CoAtNet paper configs + 'coatnet_0_224': _cfg(url=''), + 'coatnet_1_224': _cfg(url=''), + 'coatnet_2_224': _cfg(url=''), + 'coatnet_3_224': _cfg(url=''), + 'coatnet_4_224': _cfg(url=''), + 'coatnet_5_224': _cfg(url=''), + + # Experimental configs + 'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_nano_rw_256': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_tiny_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'), + 'maxvit_tiny_rw_256': _cfg( + url='', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_rmlp_pico_rw_256': _cfg( + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_rmlp_nano_rw_256': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_rmlp_tiny_rw_256': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_rmlp_small_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth', + crop_pct=0.9, + ), + 'maxvit_rmlp_small_rw_256': _cfg( + url='', + input_size=(3, 256, 256), pool_size=(8, 8)), + + 'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + + 'maxxvit_rmlp_nano_rw_256': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxxvit_rmlp_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxxvit_rmlp_small_rw_256': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxxvit_rmlp_base_rw_224': _cfg(url=''), + 'maxxvit_rmlp_large_rw_224': _cfg(url=''), + + + # Trying to be like the MaxViT paper configs + 'maxvit_tiny_tf_224.in1k': _cfg( + url='', + #file='maxvit_tiny_tf_224_in1k.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'maxvit_tiny_tf_384.in1k': _cfg( + url='', + #file='maxvit_tiny_tf_384_in1k.pth', + input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + 'maxvit_tiny_tf_512.in1k': _cfg( + url='', + #file='maxvit_tiny_tf_512_in1k.pth', + input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + 'maxvit_small_tf_224.in1k': _cfg( + url='', + #file='maxvit_small_tf_224_in1k.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'maxvit_small_tf_384.in1k': _cfg( + url='', + #file='maxvit_small_tf_384_in1k.pth', + input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + 'maxvit_small_tf_512.in1k': _cfg( + url='', + #file='maxvit_small_tf_512_in1k.pth', + input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + 'maxvit_base_tf_224.in1k': _cfg( + url='', + #file='maxvit_base_tf_224_in1k.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'maxvit_base_tf_384.in1k': _cfg( + url='', + #file='maxvit_base_tf_384_in1k.pth', + input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + 'maxvit_base_tf_512.in1k': _cfg( + url='', + #file='maxvit_base_tf_512_in1k.pth', + input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + 'maxvit_large_tf_224.in1k': _cfg( + url='', + #file='maxvit_large_tf_224_in1k.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'maxvit_large_tf_384.in1k': _cfg( + url='', + #file='maxvit_large_tf_384_in1k.pth', + input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + 'maxvit_large_tf_512.in1k': _cfg( + url='', + #file='maxvit_large_tf_512_in1k.pth', + input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + + 'maxvit_base_tf_224.in21k': _cfg( + url='', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'maxvit_base_tf_384.in21k_ft1k': _cfg( + 
url='', + #file='maxvit_base_tf_384_in21k_ft_in1k.pth', + input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + 'maxvit_base_tf_512.in21k_ft1k': _cfg( + url='', + #file='maxvit_base_tf_512_in21k_ft_in1k.pth', + input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + 'maxvit_large_tf_224.in21k': _cfg( + url=''), + 'maxvit_large_tf_384.in21k_ft1k': _cfg( + url='', + #file='maxvit_large_tf_384_in21k_ft_in1k.pth', + input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + 'maxvit_large_tf_512.in21k_ft1k': _cfg( + url='', + #file='maxvit_large_tf_512_in21k_ft_in1k.pth', + input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + 'maxvit_xlarge_tf_224.in21k': _cfg( + url=''), + 'maxvit_xlarge_tf_384.in21k_ft1k': _cfg( + url='', + #file='maxvit_xlarge_tf_384_in21k_ft_in1k.pth', + input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + 'maxvit_xlarge_tf_512.in21k_ft1k': _cfg( + url='', + #file='maxvit_xlarge_tf_512_in21k_ft_in1k.pth', + input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), +}) + + @register_model def coatnet_pico_rw_224(pretrained=False, **kwargs): return _create_maxxvit('coatnet_pico_rw_224', pretrained=pretrained, **kwargs) @@ -1773,6 +2075,11 @@ def coatnet_rmlp_1_rw_224(pretrained=False, **kwargs): return _create_maxxvit('coatnet_rmlp_1_rw_224', pretrained=pretrained, **kwargs) +@register_model +def coatnet_rmlp_1_rw2_224(pretrained=False, **kwargs): + return _create_maxxvit('coatnet_rmlp_1_rw2_224', pretrained=pretrained, **kwargs) + + @register_model def coatnet_rmlp_2_rw_224(pretrained=False, **kwargs): return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs) @@ -1889,25 +2196,85 @@ def maxxvit_rmlp_small_rw_256(pretrained=False, **kwargs): @register_model -def maxvit_tiny_224(pretrained=False, **kwargs): - return _create_maxxvit('maxvit_tiny_224', pretrained=pretrained, **kwargs) +def maxxvit_rmlp_base_rw_224(pretrained=False, **kwargs): + return _create_maxxvit('maxxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs) @register_model -def maxvit_small_224(pretrained=False, **kwargs): - return _create_maxxvit('maxvit_small_224', pretrained=pretrained, **kwargs) +def maxxvit_rmlp_large_rw_224(pretrained=False, **kwargs): + return _create_maxxvit('maxxvit_rmlp_large_rw_224', pretrained=pretrained, **kwargs) @register_model -def maxvit_base_224(pretrained=False, **kwargs): - return _create_maxxvit('maxvit_base_224', pretrained=pretrained, **kwargs) +def maxvit_tiny_tf_224(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_tiny_tf_224', 'maxvit_tiny_tf', pretrained=pretrained, **kwargs) @register_model -def maxvit_large_224(pretrained=False, **kwargs): - return _create_maxxvit('maxvit_large_224', pretrained=pretrained, **kwargs) +def maxvit_tiny_tf_384(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_tiny_tf_384', 'maxvit_tiny_tf', pretrained=pretrained, **kwargs) @register_model -def maxvit_xlarge_224(pretrained=False, **kwargs): - return _create_maxxvit('maxvit_xlarge_224', pretrained=pretrained, **kwargs) \ No newline at end of file +def maxvit_tiny_tf_512(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_tiny_tf_512', 'maxvit_tiny_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_small_tf_224(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_small_tf_224', 'maxvit_small_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_small_tf_384(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_small_tf_384', 
'maxvit_small_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_small_tf_512(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_small_tf_512', 'maxvit_small_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_base_tf_224(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_base_tf_224', 'maxvit_base_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_base_tf_384(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_base_tf_384', 'maxvit_base_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_base_tf_512(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_base_tf_512', 'maxvit_base_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_large_tf_224(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_large_tf_224', 'maxvit_large_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_large_tf_384(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_large_tf_384', 'maxvit_large_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_large_tf_512(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_large_tf_512', 'maxvit_large_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_xlarge_tf_224(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_xlarge_tf_224', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_xlarge_tf_384(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_xlarge_tf_384', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_xlarge_tf_512(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_xlarge_tf_512', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs) \ No newline at end of file diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index f29216c9..cde0018b 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -32,7 +32,7 @@ import torch.utils.checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD,\ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq +from .helpers import build_model_with_cfg, named_apply, adapt_input_conv, checkpoint_seq from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ from ._pretrained import generate_defaults from .registry import register_model @@ -795,13 +795,15 @@ default_cfgs = generate_defaults({ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0), 'vit_large_patch14_clip_336.laion2b_ft_in1k': _cfg( hf_hub_id='timm/vit_large_patch14_clip_336.laion2b_ft_in1k', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, input_size=(3, 336, 336)), + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_huge_patch14_clip_224.laion2b_ft_in1k': _cfg( hf_hub_id='timm/vit_huge_patch14_clip_224.laion2b_ft_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), 'vit_huge_patch14_clip_336.laion2b_ft_in1k': _cfg( hf_hub_id='', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 336, 336)), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_base_patch32_clip_224.laion2b_ft_in12k_in1k': _cfg( 
hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k_in1k', @@ -823,13 +825,15 @@ default_cfgs = generate_defaults({ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0), 'vit_large_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_large_patch14_clip_336.laion2b_ft_in12k_in1k', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, input_size=(3, 336, 336)), + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), 'vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 336, 336)), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_base_patch32_clip_224.laion2b_ft_in12k': _cfg( hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k', @@ -879,12 +883,16 @@ default_cfgs = generate_defaults({ 'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in12k_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), + 'vit_large_patch14_clip_336.openai_ft_in12k_in1k': _cfg( + hf_hub_id='timm/vit_large_patch14_clip_336.openai_ft_in12k_in1k', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_base_patch32_clip_224.openai_ft_in12k': _cfg( - #hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k', + hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), 'vit_base_patch16_clip_224.openai_ft_in12k': _cfg( - #hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in12k', + hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in12k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), 'vit_large_patch14_clip_224.openai_ft_in12k': _cfg( hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in12k',