Merge pull request #637 from rwightman/levit_visformer_rednet

LeVit, Visformer, RedNet/Involution models and layers
This commit is contained in:
Ross Wightman 2021-05-25 14:27:06 -07:00 committed by GitHub
commit 07d952c7a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 1511 additions and 245 deletions

View File

@ -17,8 +17,8 @@ jobs:
matrix: matrix:
os: [ubuntu-latest, macOS-latest] os: [ubuntu-latest, macOS-latest]
python: ['3.8'] python: ['3.8']
torch: ['1.8.0'] torch: ['1.8.1']
torchvision: ['0.9.0'] torchvision: ['0.9.1']
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:

View File

@ -23,6 +23,14 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
## What's New ## What's New
### May 25, 2021
* Add LeViT, Visformer, ConViT (PR by Aman Arora), Twins (PR by paper authors) transformer models
* Add ResMLP and gMLP MLP vision models to the existing MLP Mixer impl
* Fix a number of torchscript issues with various vision transformer models
* Cleanup input_size/img_size override handling and improve testing / test coverage for all vision transformer and MLP models
* More flexible pos embedding resize (non-square) for ViT and TnT. Thanks [Alexander Soare](https://github.com/alexander-soare)
* Add `efficientnetv2_rw_m` model and weights (started training before official code). 84.8 top-1, 53M params.
### May 14, 2021 ### May 14, 2021
* Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl. * Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl.
* 1k trained variants: `tf_efficientnetv2_s/m/l` * 1k trained variants: `tf_efficientnetv2_s/m/l`
@ -166,30 +174,6 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
* Misc fixes for SiLU ONNX export, default_cfg missing from Feature extraction models, Linear layer w/ AMP + torchscript * Misc fixes for SiLU ONNX export, default_cfg missing from Feature extraction models, Linear layer w/ AMP + torchscript
* PyPi release @ 0.3.2 (needed by EfficientDet) * PyPi release @ 0.3.2 (needed by EfficientDet)
### Oct 30, 2020
* Test with PyTorch 1.7 and fix a small top-n metric view vs reshape issue.
* Convert newly added 224x224 Vision Transformer weights from official JAX repo. 81.8 top-1 for B/16, 83.1 L/16.
* Support PyTorch 1.7 optimized, native SiLU (aka Swish) activation. Add mapping to 'silu' name, custom swish will eventually be deprecated.
* Fix regression for loading pretrained classifier via direct model entrypoint functions. Didn't impact create_model() factory usage.
* PyPi release @ 0.3.0 version!
### Oct 26, 2020
* Update Vision Transformer models to be compatible with official code release at https://github.com/google-research/vision_transformer
* Add Vision Transformer weights (ImageNet-21k pretrain) for 384x384 base and large models converted from official jax impl
* ViT-B/16 - 84.2
* ViT-B/32 - 81.7
* ViT-L/16 - 85.2
* ViT-L/32 - 81.5
### Oct 21, 2020
* Weights added for Vision Transformer (ViT) models. 77.86 top-1 for 'small' and 79.35 for 'base'. Thanks to [Christof](https://www.kaggle.com/christofhenkel) for training the base model w/ lots of GPUs.
### Oct 13, 2020
* Initial impl of Vision Transformer models. Both patch and hybrid (CNN backbone) variants. Currently trying to train...
* Adafactor and AdaHessian (FP32 only, no AMP) optimizers
* EdgeTPU-M (`efficientnet_em`) model trained in PyTorch, 79.3 top-1
* Pip release, doc updates pending a few more changes...
## Introduction ## Introduction
@ -207,6 +191,7 @@ A full version of the list below with source links can be found in the [document
* Bottleneck Transformers - https://arxiv.org/abs/2101.11605 * Bottleneck Transformers - https://arxiv.org/abs/2101.11605
* CaiT (Class-Attention in Image Transformers) - https://arxiv.org/abs/2103.17239 * CaiT (Class-Attention in Image Transformers) - https://arxiv.org/abs/2103.17239
* CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399 * CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399
* ConViT (Soft Convolutional Inductive Biases Vision Transformers)- https://arxiv.org/abs/2103.10697
* CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929 * CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929
* DeiT (Vision Transformer) - https://arxiv.org/abs/2012.12877 * DeiT (Vision Transformer) - https://arxiv.org/abs/2012.12877
* DenseNet - https://arxiv.org/abs/1608.06993 * DenseNet - https://arxiv.org/abs/1608.06993
@ -224,6 +209,7 @@ A full version of the list below with source links can be found in the [document
* MobileNet-V2 - https://arxiv.org/abs/1801.04381 * MobileNet-V2 - https://arxiv.org/abs/1801.04381
* Single-Path NAS - https://arxiv.org/abs/1904.02877 * Single-Path NAS - https://arxiv.org/abs/1904.02877
* GhostNet - https://arxiv.org/abs/1911.11907 * GhostNet - https://arxiv.org/abs/1911.11907
* gMLP - https://arxiv.org/abs/2105.08050
* GPU-Efficient Networks - https://arxiv.org/abs/2006.14090 * GPU-Efficient Networks - https://arxiv.org/abs/2006.14090
* Halo Nets - https://arxiv.org/abs/2103.12731 * Halo Nets - https://arxiv.org/abs/2103.12731
* HardCoRe-NAS - https://arxiv.org/abs/2102.11646 * HardCoRe-NAS - https://arxiv.org/abs/2102.11646
@ -231,6 +217,7 @@ A full version of the list below with source links can be found in the [document
* Inception-V3 - https://arxiv.org/abs/1512.00567 * Inception-V3 - https://arxiv.org/abs/1512.00567
* Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261 * Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261
* Lambda Networks - https://arxiv.org/abs/2102.08602 * Lambda Networks - https://arxiv.org/abs/2102.08602
* LeViT (Vision Transformer in ConvNet's Clothing) - https://arxiv.org/abs/2104.01136
* MLP-Mixer - https://arxiv.org/abs/2105.01601 * MLP-Mixer - https://arxiv.org/abs/2105.01601
* MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244 * MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244
* NASNet-A - https://arxiv.org/abs/1707.07012 * NASNet-A - https://arxiv.org/abs/1707.07012
@ -240,6 +227,7 @@ A full version of the list below with source links can be found in the [document
* Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302 * Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302
* RegNet - https://arxiv.org/abs/2003.13678 * RegNet - https://arxiv.org/abs/2003.13678
* RepVGG - https://arxiv.org/abs/2101.03697 * RepVGG - https://arxiv.org/abs/2101.03697
* ResMLP - https://arxiv.org/abs/2105.03404
* ResNet/ResNeXt * ResNet/ResNeXt
* ResNet (v1b/v1.5) - https://arxiv.org/abs/1512.03385 * ResNet (v1b/v1.5) - https://arxiv.org/abs/1512.03385
* ResNeXt - https://arxiv.org/abs/1611.05431 * ResNeXt - https://arxiv.org/abs/1611.05431
@ -257,6 +245,7 @@ A full version of the list below with source links can be found in the [document
* Swin Transformer - https://arxiv.org/abs/2103.14030 * Swin Transformer - https://arxiv.org/abs/2103.14030
* Transformer-iN-Transformer (TNT) - https://arxiv.org/abs/2103.00112 * Transformer-iN-Transformer (TNT) - https://arxiv.org/abs/2103.00112
* TResNet - https://arxiv.org/abs/2003.13630 * TResNet - https://arxiv.org/abs/2003.13630
* Twins (Spatial Attention in Vision Transformers) - https://arxiv.org/pdf/2104.13840.pdf
* Vision Transformer - https://arxiv.org/abs/2010.11929 * Vision Transformer - https://arxiv.org/abs/2010.11929
* VovNet V2 and V1 - https://arxiv.org/abs/1911.06667 * VovNet V2 and V1 - https://arxiv.org/abs/1911.06667
* Xception - https://arxiv.org/abs/1610.02357 * Xception - https://arxiv.org/abs/1610.02357

View File

@ -1,5 +1,29 @@
# Archived Changes # Archived Changes
### Oct 30, 2020
* Test with PyTorch 1.7 and fix a small top-n metric view vs reshape issue.
* Convert newly added 224x224 Vision Transformer weights from official JAX repo. 81.8 top-1 for B/16, 83.1 L/16.
* Support PyTorch 1.7 optimized, native SiLU (aka Swish) activation. Add mapping to 'silu' name, custom swish will eventually be deprecated.
* Fix regression for loading pretrained classifier via direct model entrypoint functions. Didn't impact create_model() factory usage.
* PyPi release @ 0.3.0 version!
### Oct 26, 2020
* Update Vision Transformer models to be compatible with official code release at https://github.com/google-research/vision_transformer
* Add Vision Transformer weights (ImageNet-21k pretrain) for 384x384 base and large models converted from official jax impl
* ViT-B/16 - 84.2
* ViT-B/32 - 81.7
* ViT-L/16 - 85.2
* ViT-L/32 - 81.5
### Oct 21, 2020
* Weights added for Vision Transformer (ViT) models. 77.86 top-1 for 'small' and 79.35 for 'base'. Thanks to [Christof](https://www.kaggle.com/christofhenkel) for training the base model w/ lots of GPUs.
### Oct 13, 2020
* Initial impl of Vision Transformer models. Both patch and hybrid (CNN backbone) variants. Currently trying to train...
* Adafactor and AdaHessian (FP32 only, no AMP) optimizers
* EdgeTPU-M (`efficientnet_em`) model trained in PyTorch, 79.3 top-1
* Pip release, doc updates pending a few more changes...
### Sept 18, 2020 ### Sept 18, 2020
* New ResNet 'D' weights. 72.7 (top-1) ResNet-18-D, 77.1 ResNet-34-D, 80.5 ResNet-50-D * New ResNet 'D' weights. 72.7 (top-1) ResNet-18-D, 77.1 ResNet-34-D, 80.5 ResNet-50-D
* Added a few untrained defs for other ResNet models (66D, 101D, 152D, 200/200D) * Added a few untrained defs for other ResNet models (66D, 101D, 152D, 200/200D)

View File

@ -1,5 +1,33 @@
# Recent Changes # Recent Changes
### May 25, 2021
* Add LeViT, Visformer, Convit (PR by Aman Arora), Twins (PR by paper authors) transformer models
* Cleanup input_size/img_size override handling and testing for all vision transformer models
* Add `efficientnetv2_rw_m` model and weights (started training before official code). 84.8 top-1, 53M params.
### May 14, 2021
* Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl.
* 1k trained variants: `tf_efficientnetv2_s/m/l`
* 21k trained variants: `tf_efficientnetv2_s/m/l_in21k`
* 21k pretrained -> 1k fine-tuned: `tf_efficientnetv2_s/m/l_in21ft1k`
* v2 models w/ v1 scaling: `tf_efficientnetv2_b0` through `b3`
* Rename my prev V2 guess `efficientnet_v2s` -> `efficientnetv2_rw_s`
* Some blank `efficientnetv2_*` models in-place for future native PyTorch training
### May 5, 2021
* Add MLP-Mixer models and port pretrained weights from [Google JAX impl](https://github.com/google-research/vision_transformer/tree/linen)
* Add CaiT models and pretrained weights from [FB](https://github.com/facebookresearch/deit)
* Add ResNet-RS models and weights from [TF](https://github.com/tensorflow/tpu/tree/master/models/official/resnet/resnet_rs). Thanks [Aman Arora](https://github.com/amaarora)
* Add CoaT models and weights. Thanks [Mohammed Rizin](https://github.com/morizin)
* Add new ImageNet-21k weights & finetuned weights for TResNet, MobileNet-V3, ViT models. Thanks [mrT](https://github.com/mrT23)
* Add GhostNet models and weights. Thanks [Kai Han](https://github.com/iamhankai)
* Update ByoaNet attention modles
* Improve SA module inits
* Hack together experimental stand-alone Swin based attn module and `swinnet`
* Consistent '26t' model defs for experiments.
* Add improved Efficientnet-V2S (prelim model def) weights. 83.8 top-1.
* WandB logging support
### April 13, 2021 ### April 13, 2021
* Add Swin Transformer models and weights from https://github.com/microsoft/Swin-Transformer * Add Swin Transformer models and weights from https://github.com/microsoft/Swin-Transformer

View File

@ -16,7 +16,8 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
# transformer models don't support many of the spatial / feature based model functionalities # transformer models don't support many of the spatial / feature based model functionalities
NON_STD_FILTERS = [ NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*'] 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*']
NUM_NON_STD = len(NON_STD_FILTERS) NUM_NON_STD = len(NON_STD_FILTERS)
# exclude models that cause specific test failures # exclude models that cause specific test failures
@ -25,29 +26,56 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
EXCLUDE_FILTERS = [ EXCLUDE_FILTERS = [
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm',
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*',
'*resnetrs350*', '*resnetrs420*'] + NON_STD_FILTERS '*resnetrs350*', '*resnetrs420*']
else: else:
EXCLUDE_FILTERS = NON_STD_FILTERS EXCLUDE_FILTERS = []
MAX_FWD_SIZE = 384 TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
MAX_BWD_SIZE = 128 TARGET_BWD_SIZE = 128
MAX_FWD_FEAT_SIZE = 448 MAX_BWD_SIZE = 320
MAX_FWD_OUT_SIZE = 448
TARGET_JIT_SIZE = 128
MAX_JIT_SIZE = 320
TARGET_FFEAT_SIZE = 96
MAX_FFEAT_SIZE = 256
def _get_input_size(model=None, model_name='', target=None):
if model is None:
assert model_name, "One of model or model_name must be provided"
input_size = get_model_default_value(model_name, 'input_size')
fixed_input_size = get_model_default_value(model_name, 'fixed_input_size')
min_input_size = get_model_default_value(model_name, 'min_input_size')
else:
default_cfg = model.default_cfg
input_size = default_cfg['input_size']
fixed_input_size = default_cfg.get('fixed_input_size', None)
min_input_size = default_cfg.get('min_input_size', None)
assert input_size is not None
if fixed_input_size:
return input_size
if min_input_size:
if target and max(input_size) > target:
input_size = min_input_size
else:
if target and max(input_size) > target:
input_size = tuple([min(x, target) for x in input_size])
return input_size
@pytest.mark.timeout(120) @pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-NUM_NON_STD])) @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('batch_size', [1])
def test_model_forward(model_name, batch_size): def test_model_forward(model_name, batch_size):
"""Run a single forward pass with each model""" """Run a single forward pass with each model"""
model = create_model(model_name, pretrained=False) model = create_model(model_name, pretrained=False)
model.eval() model.eval()
input_size = model.default_cfg['input_size'] input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
if any([x > MAX_FWD_SIZE for x in input_size]): if max(input_size) > MAX_FWD_SIZE:
if is_model_default_key(model_name, 'fixed_input_size'): pytest.skip("Fixed input size model > limit.")
pytest.skip("Fixed input size model > limit.")
# cap forward test at max res 384 * 384 to keep resource down
input_size = tuple([min(x, MAX_FWD_SIZE) for x in input_size])
inputs = torch.randn((batch_size, *input_size)) inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs) outputs = model(inputs)
@ -56,26 +84,22 @@ def test_model_forward(model_name, batch_size):
@pytest.mark.timeout(120) @pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2]) @pytest.mark.parametrize('batch_size', [2])
def test_model_backward(model_name, batch_size): def test_model_backward(model_name, batch_size):
"""Run a single forward pass with each model""" """Run a single forward pass with each model"""
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE)
if max(input_size) > MAX_BWD_SIZE:
pytest.skip("Fixed input size model > limit.")
model = create_model(model_name, pretrained=False, num_classes=42) model = create_model(model_name, pretrained=False, num_classes=42)
num_params = sum([x.numel() for x in model.parameters()]) num_params = sum([x.numel() for x in model.parameters()])
model.eval() model.train()
input_size = model.default_cfg['input_size']
if not is_model_default_key(model_name, 'fixed_input_size'):
min_input_size = get_model_default_value(model_name, 'min_input_size')
if min_input_size is not None:
input_size = min_input_size
else:
if any([x > MAX_BWD_SIZE for x in input_size]):
# cap backward test at 128 * 128 to keep resource usage down
input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
inputs = torch.randn((batch_size, *input_size)) inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs) outputs = model(inputs)
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
outputs.mean().backward() outputs.mean().backward()
for n, x in model.named_parameters(): for n, x in model.named_parameters():
assert x.grad is not None, f'No gradient for {n}' assert x.grad is not None, f'No gradient for {n}'
@ -100,10 +124,10 @@ def test_model_default_cfgs(model_name, batch_size):
pool_size = cfg['pool_size'] pool_size = cfg['pool_size']
input_size = model.default_cfg['input_size'] input_size = model.default_cfg['input_size']
if all([x <= MAX_FWD_FEAT_SIZE for x in input_size]) and \ if all([x <= MAX_FWD_OUT_SIZE for x in input_size]) and \
not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]): not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]):
# output sizes only checked if default res <= 448 * 448 to keep resource down # output sizes only checked if default res <= 448 * 448 to keep resource down
input_size = tuple([min(x, MAX_FWD_FEAT_SIZE) for x in input_size]) input_size = tuple([min(x, MAX_FWD_OUT_SIZE) for x in input_size])
input_tensor = torch.randn((batch_size, *input_size)) input_tensor = torch.randn((batch_size, *input_size))
# test forward_features (always unpooled) # test forward_features (always unpooled)
@ -154,26 +178,25 @@ if 'GITHUB_ACTIONS' not in os.environ:
EXCLUDE_JIT_FILTERS = [ EXCLUDE_JIT_FILTERS = [
'*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable '*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable
'dla*', 'hrnet*', 'ghostnet*', # hopefully fix at some point 'dla*', 'hrnet*', 'ghostnet*', # hopefully fix at some point
'vit_large_*', 'vit_huge_*',
] ]
@pytest.mark.timeout(120) @pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS)) @pytest.mark.parametrize(
'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('batch_size', [1])
def test_model_forward_torchscript(model_name, batch_size): def test_model_forward_torchscript(model_name, batch_size):
"""Run a single forward pass with each model""" """Run a single forward pass with each model"""
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
if max(input_size) > MAX_JIT_SIZE:
pytest.skip("Fixed input size model > limit.")
with set_scriptable(True): with set_scriptable(True):
model = create_model(model_name, pretrained=False) model = create_model(model_name, pretrained=False)
model.eval() model.eval()
if has_model_default_key(model_name, 'fixed_input_size'):
input_size = get_model_default_value(model_name, 'input_size')
elif has_model_default_key(model_name, 'min_input_size'):
input_size = get_model_default_value(model_name, 'min_input_size')
else:
input_size = (3, 128, 128) # jit compile is already a bit slow and we've tested normal res already...
model = torch.jit.script(model) model = torch.jit.script(model)
outputs = model(torch.randn((batch_size, *input_size))) outputs = model(torch.randn((batch_size, *input_size)))
@ -183,7 +206,7 @@ def test_model_forward_torchscript(model_name, batch_size):
EXCLUDE_FEAT_FILTERS = [ EXCLUDE_FEAT_FILTERS = [
'*pruned*', # hopefully fix at some point '*pruned*', # hopefully fix at some point
] ] + NON_STD_FILTERS
if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
# GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
EXCLUDE_FEAT_FILTERS += ['*resnext101_32x32d', '*resnext101_32x16d'] EXCLUDE_FEAT_FILTERS += ['*resnext101_32x32d', '*resnext101_32x16d']
@ -199,12 +222,9 @@ def test_model_forward_features(model_name, batch_size):
expected_channels = model.feature_info.channels() expected_channels = model.feature_info.channels()
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6 assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
if has_model_default_key(model_name, 'fixed_input_size'): input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
input_size = get_model_default_value(model_name, 'input_size') if max(input_size) > MAX_FFEAT_SIZE:
elif has_model_default_key(model_name, 'min_input_size'): pytest.skip("Fixed input size model > limit.")
input_size = get_model_default_value(model_name, 'min_input_size')
else:
input_size = (3, 96, 96) # jit compile is already a bit slow and we've tested normal res already...
outputs = model(torch.randn((batch_size, *input_size))) outputs = model(torch.randn((batch_size, *input_size)))
assert len(expected_channels) == len(outputs) assert len(expected_channels) == len(outputs)

View File

@ -25,8 +25,8 @@ from .parser import Parser
MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities
SHUFFLE_SIZE = 16834 # samples to shuffle in DS queue SHUFFLE_SIZE = 20480 # samples to shuffle in DS queue
PREFETCH_SIZE = 4096 # samples to prefetch PREFETCH_SIZE = 2048 # samples to prefetch
def even_split_indices(split, n, num_samples): def even_split_indices(split, n, num_samples):
@ -144,14 +144,16 @@ class ParserTfds(Parser):
ds = self.builder.as_dataset( ds = self.builder.as_dataset(
split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config) split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config)
# avoid overloading threading w/ combo fo TF ds threads + PyTorch workers # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers) options = tf.data.Options()
ds.options().experimental_threading.max_intra_op_parallelism = 1 options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
options.experimental_threading.max_intra_op_parallelism = 1
ds = ds.with_options(options)
if self.is_training or self.repeats > 1: if self.is_training or self.repeats > 1:
# to prevent excessive drop_last batch behaviour w/ IterableDatasets # to prevent excessive drop_last batch behaviour w/ IterableDatasets
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
ds = ds.repeat() # allow wrap around and break iteration manually ds = ds.repeat() # allow wrap around and break iteration manually
if self.shuffle: if self.shuffle:
ds = ds.shuffle(min(self.num_samples // self._num_pipelines, SHUFFLE_SIZE), seed=0) ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=0)
ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE)) ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
self.ds = tfds.as_numpy(ds) self.ds = tfds.as_numpy(ds)

View File

@ -16,6 +16,8 @@ from .hrnet import *
from .inception_resnet_v2 import * from .inception_resnet_v2 import *
from .inception_v3 import * from .inception_v3 import *
from .inception_v4 import * from .inception_v4 import *
from .levit import *
#from .levit import *
from .mlp_mixer import * from .mlp_mixer import *
from .mobilenetv3 import * from .mobilenetv3 import *
from .nasnet import * from .nasnet import *
@ -35,6 +37,7 @@ from .swin_transformer import *
from .tnt import * from .tnt import *
from .tresnet import * from .tresnet import *
from .vgg import * from .vgg import *
from .visformer import *
from .vision_transformer import * from .vision_transformer import *
from .vision_transformer_hybrid import * from .vision_transformer_hybrid import *
from .vovnet import * from .vovnet import *

View File

@ -47,17 +47,24 @@ default_cfgs = {
# GPU-Efficient (ResNet) weights # GPU-Efficient (ResNet) weights
'botnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'botnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
'botnet50ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'botnet50ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
'eca_botnext26ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
'halonet26t': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'halonet26t': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
'halonet50ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'halonet50ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
'eca_halonext26ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), 'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
'lambda_resnet50t': _cfg(url='', min_input_size=(3, 128, 128)), 'lambda_resnet50t': _cfg(url='', min_input_size=(3, 128, 128)),
'eca_lambda_resnext26ts': _cfg(url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
'swinnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'swinnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
'swinnet50ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'swinnet50ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
'eca_swinnext26ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
'rednet26t': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)),
'rednet50ts': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)),
} }
@ -126,6 +133,23 @@ model_cfgs = dict(
self_attn_fixed_size=True, self_attn_fixed_size=True,
self_attn_kwargs=dict() self_attn_kwargs=dict()
), ),
eca_botnext26ts=ByoaCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25),
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
act_layer='silu',
attn_layer='eca',
self_attn_layer='bottleneck',
self_attn_fixed_size=True,
self_attn_kwargs=dict()
),
halonet_h1=ByoaCfg( halonet_h1=ByoaCfg(
blocks=( blocks=(
@ -184,6 +208,22 @@ model_cfgs = dict(
self_attn_layer='halo', self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=2) self_attn_kwargs=dict(block_size=8, halo_size=2)
), ),
eca_halonext26ts=ByoaCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
act_layer='silu',
attn_layer='eca',
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res
),
lambda_resnet26t=ByoaCfg( lambda_resnet26t=ByoaCfg(
blocks=( blocks=(
@ -213,6 +253,22 @@ model_cfgs = dict(
self_attn_layer='lambda', self_attn_layer='lambda',
self_attn_kwargs=dict() self_attn_kwargs=dict()
), ),
eca_lambda_resnext26ts=ByoaCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
act_layer='silu',
attn_layer='eca',
self_attn_layer='lambda',
self_attn_kwargs=dict()
),
swinnet26t=ByoaCfg( swinnet26t=ByoaCfg(
blocks=( blocks=(
@ -245,6 +301,56 @@ model_cfgs = dict(
self_attn_fixed_size=True, self_attn_fixed_size=True,
self_attn_kwargs=dict(win_size=8) self_attn_kwargs=dict(win_size=8)
), ),
eca_swinnext26ts=ByoaCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=16, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
act_layer='silu',
attn_layer='eca',
self_attn_layer='swin',
self_attn_fixed_size=True,
self_attn_kwargs=dict(win_size=8)
),
rednet26t=ByoaCfg(
blocks=(
ByoaBlocksCfg(type='self_attn', d=2, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=512, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered', # FIXME RedNet uses involution in middle of stem
stem_pool='maxpool',
num_features=0,
self_attn_layer='involution',
self_attn_fixed_size=False,
self_attn_kwargs=dict()
),
rednet50ts=ByoaCfg(
blocks=(
ByoaBlocksCfg(type='self_attn', d=3, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=4, c=512, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
act_layer='silu',
self_attn_layer='involution',
self_attn_fixed_size=False,
self_attn_kwargs=dict()
),
) )
@ -419,6 +525,14 @@ def botnet50ts_256(pretrained=False, **kwargs):
return _create_byoanet('botnet50ts_256', 'botnet50ts', pretrained=pretrained, **kwargs) return _create_byoanet('botnet50ts_256', 'botnet50ts', pretrained=pretrained, **kwargs)
@register_model
def eca_botnext26ts_256(pretrained=False, **kwargs):
""" Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final stage.
"""
kwargs.setdefault('img_size', 256)
return _create_byoanet('eca_botnext26ts_256', 'eca_botnext26ts', pretrained=pretrained, **kwargs)
@register_model @register_model
def halonet_h1(pretrained=False, **kwargs): def halonet_h1(pretrained=False, **kwargs):
""" HaloNet-H1. Halo attention in all stages as per the paper. """ HaloNet-H1. Halo attention in all stages as per the paper.
@ -449,6 +563,13 @@ def halonet50ts(pretrained=False, **kwargs):
return _create_byoanet('halonet50ts', pretrained=pretrained, **kwargs) return _create_byoanet('halonet50ts', pretrained=pretrained, **kwargs)
@register_model
def eca_halonext26ts(pretrained=False, **kwargs):
""" HaloNet w/ a ResNet26-t backbone, Hallo attention in final stage
"""
return _create_byoanet('eca_halonext26ts', pretrained=pretrained, **kwargs)
@register_model @register_model
def lambda_resnet26t(pretrained=False, **kwargs): def lambda_resnet26t(pretrained=False, **kwargs):
""" Lambda-ResNet-26T. Lambda layers in one C4 stage and all C5. """ Lambda-ResNet-26T. Lambda layers in one C4 stage and all C5.
@ -463,6 +584,13 @@ def lambda_resnet50t(pretrained=False, **kwargs):
return _create_byoanet('lambda_resnet50t', pretrained=pretrained, **kwargs) return _create_byoanet('lambda_resnet50t', pretrained=pretrained, **kwargs)
@register_model
def eca_lambda_resnext26ts(pretrained=False, **kwargs):
""" Lambda-ResNet-26T. Lambda layers in one C4 stage and all C5.
"""
return _create_byoanet('eca_lambda_resnext26ts', pretrained=pretrained, **kwargs)
@register_model @register_model
def swinnet26t_256(pretrained=False, **kwargs): def swinnet26t_256(pretrained=False, **kwargs):
""" """
@ -477,3 +605,25 @@ def swinnet50ts_256(pretrained=False, **kwargs):
""" """
kwargs.setdefault('img_size', 256) kwargs.setdefault('img_size', 256)
return _create_byoanet('swinnet50ts_256', 'swinnet50ts', pretrained=pretrained, **kwargs) return _create_byoanet('swinnet50ts_256', 'swinnet50ts', pretrained=pretrained, **kwargs)
@register_model
def eca_swinnext26ts_256(pretrained=False, **kwargs):
"""
"""
kwargs.setdefault('img_size', 256)
return _create_byoanet('eca_swinnext26ts_256', 'eca_swinnext26ts', pretrained=pretrained, **kwargs)
@register_model
def rednet26t(pretrained=False, **kwargs):
"""
"""
return _create_byoanet('rednet26t', pretrained=pretrained, **kwargs)
@register_model
def rednet50ts(pretrained=False, **kwargs):
"""
"""
return _create_byoanet('rednet50ts', pretrained=pretrained, **kwargs)

View File

@ -98,7 +98,7 @@ class BlocksCfg:
s: int = 2 # stride of stage (first block) s: int = 2 # stride of stage (first block)
gs: Optional[Union[int, Callable]] = None # group-size of blocks in stage, conv is depthwise if gs == 1 gs: Optional[Union[int, Callable]] = None # group-size of blocks in stage, conv is depthwise if gs == 1
br: float = 1. # bottleneck-ratio of blocks in stage br: float = 1. # bottleneck-ratio of blocks in stage
no_attn: bool = True # disable channel attn (ie SE) when layer is set for model no_attn: bool = False # disable channel attn (ie SE) when layer is set for model
@dataclass @dataclass

View File

@ -306,26 +306,15 @@ def checkpoint_filter_fn(state_dict, model=None):
return checkpoint_no_module return checkpoint_no_module
def _create_cait(variant, pretrained=False, default_cfg=None, **kwargs): def _create_cait(variant, pretrained=False, **kwargs):
if default_cfg is None:
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
num_classes = kwargs.pop('num_classes', default_num_classes)
img_size = kwargs.pop('img_size', default_img_size)
if kwargs.get('features_only', None): if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.') raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg( model = build_model_with_cfg(
Cait, variant, pretrained, Cait, variant, pretrained,
default_cfg=default_cfg, default_cfg=default_cfgs[variant],
img_size=img_size,
num_classes=num_classes,
pretrained_filter_fn=checkpoint_filter_fn, pretrained_filter_fn=checkpoint_filter_fn,
**kwargs) **kwargs)
return model return model

View File

@ -7,19 +7,19 @@ Official CoaT code at: https://github.com/mlpc-ucsd/CoaT
Modified from timm/models/vision_transformer.py Modified from timm/models/vision_transformer.py
""" """
from typing import Tuple, Dict, Any, Optional from copy import deepcopy
from functools import partial
from typing import Tuple, List
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import load_pretrained from .helpers import build_model_with_cfg, overlay_external_default_cfg
from timm.models.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model from .registry import register_model
from functools import partial
from torch import nn
__all__ = [ __all__ = [
"coat_tiny", "coat_tiny",
@ -34,7 +34,7 @@ def _cfg_coat(url='', **kwargs):
return { return {
'url': url, 'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed1.proj', 'classifier': 'head', 'first_conv': 'patch_embed1.proj', 'classifier': 'head',
**kwargs **kwargs
@ -42,15 +42,21 @@ def _cfg_coat(url='', **kwargs):
default_cfgs = { default_cfgs = {
'coat_tiny': _cfg_coat(), 'coat_tiny': _cfg_coat(
'coat_mini': _cfg_coat(), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_tiny-473c2a20.pth'
),
'coat_mini': _cfg_coat(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_mini-2c6baf49.pth'
),
'coat_lite_tiny': _cfg_coat( 'coat_lite_tiny': _cfg_coat(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_tiny-461b07a7.pth' url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_tiny-461b07a7.pth'
), ),
'coat_lite_mini': _cfg_coat( 'coat_lite_mini': _cfg_coat(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_mini-d7842000.pth' url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_mini-d7842000.pth'
), ),
'coat_lite_small': _cfg_coat(), 'coat_lite_small': _cfg_coat(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_small-fea1d5a1.pth'
),
} }
@ -120,11 +126,11 @@ class ConvRelPosEnc(nn.Module):
class FactorAtt_ConvRelPosEnc(nn.Module): class FactorAtt_ConvRelPosEnc(nn.Module):
""" Factorized attention with convolutional relative position encoding class. """ """ Factorized attention with convolutional relative position encoding class. """
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., shared_crpe=None): def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., shared_crpe=None):
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
head_dim = dim // num_heads head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5 self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used. self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used.
@ -190,9 +196,8 @@ class ConvPosEnc(nn.Module):
class SerialBlock(nn.Module): class SerialBlock(nn.Module):
""" Serial block class. """ Serial block class.
Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. """ Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. """
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_cpe=None, shared_crpe=None):
shared_cpe=None, shared_crpe=None):
super().__init__() super().__init__()
# Conv-Attention. # Conv-Attention.
@ -200,8 +205,7 @@ class SerialBlock(nn.Module):
self.norm1 = norm_layer(dim) self.norm1 = norm_layer(dim)
self.factoratt_crpe = FactorAtt_ConvRelPosEnc( self.factoratt_crpe = FactorAtt_ConvRelPosEnc(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpe)
shared_crpe=shared_crpe)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
# MLP. # MLP.
@ -226,27 +230,24 @@ class SerialBlock(nn.Module):
class ParallelBlock(nn.Module): class ParallelBlock(nn.Module):
""" Parallel block class. """ """ Parallel block class. """
def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_crpes=None):
shared_cpes=None, shared_crpes=None):
super().__init__() super().__init__()
# Conv-Attention. # Conv-Attention.
self.cpes = shared_cpes
self.norm12 = norm_layer(dims[1]) self.norm12 = norm_layer(dims[1])
self.norm13 = norm_layer(dims[2]) self.norm13 = norm_layer(dims[2])
self.norm14 = norm_layer(dims[3]) self.norm14 = norm_layer(dims[3])
self.factoratt_crpe2 = FactorAtt_ConvRelPosEnc( self.factoratt_crpe2 = FactorAtt_ConvRelPosEnc(
dims[1], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, dims[1], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
shared_crpe=shared_crpes[1] shared_crpe=shared_crpes[1]
) )
self.factoratt_crpe3 = FactorAtt_ConvRelPosEnc( self.factoratt_crpe3 = FactorAtt_ConvRelPosEnc(
dims[2], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, dims[2], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
shared_crpe=shared_crpes[2] shared_crpe=shared_crpes[2]
) )
self.factoratt_crpe4 = FactorAtt_ConvRelPosEnc( self.factoratt_crpe4 = FactorAtt_ConvRelPosEnc(
dims[3], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, dims[3], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
shared_crpe=shared_crpes[3] shared_crpe=shared_crpes[3]
) )
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -262,15 +263,15 @@ class ParallelBlock(nn.Module):
self.mlp2 = self.mlp3 = self.mlp4 = Mlp( self.mlp2 = self.mlp3 = self.mlp4 = Mlp(
in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def upsample(self, x, factor, size): def upsample(self, x, factor: float, size: Tuple[int, int]):
""" Feature map up-sampling. """ """ Feature map up-sampling. """
return self.interpolate(x, scale_factor=factor, size=size) return self.interpolate(x, scale_factor=factor, size=size)
def downsample(self, x, factor, size): def downsample(self, x, factor: float, size: Tuple[int, int]):
""" Feature map down-sampling. """ """ Feature map down-sampling. """
return self.interpolate(x, scale_factor=1.0/factor, size=size) return self.interpolate(x, scale_factor=1.0/factor, size=size)
def interpolate(self, x, scale_factor, size): def interpolate(self, x, scale_factor: float, size: Tuple[int, int]):
""" Feature map interpolation. """ """ Feature map interpolation. """
B, N, C = x.shape B, N, C = x.shape
H, W = size H, W = size
@ -280,33 +281,28 @@ class ParallelBlock(nn.Module):
img_tokens = x[:, 1:, :] img_tokens = x[:, 1:, :]
img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W) img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W)
img_tokens = F.interpolate(img_tokens, scale_factor=scale_factor, mode='bilinear') img_tokens = F.interpolate(
img_tokens, scale_factor=scale_factor, recompute_scale_factor=False, mode='bilinear', align_corners=False)
img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2) img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2)
out = torch.cat((cls_token, img_tokens), dim=1) out = torch.cat((cls_token, img_tokens), dim=1)
return out return out
def forward(self, x1, x2, x3, x4, sizes): def forward(self, x1, x2, x3, x4, sizes: List[Tuple[int, int]]):
_, (H2, W2), (H3, W3), (H4, W4) = sizes _, S2, S3, S4 = sizes
# Conv-Attention.
x2 = self.cpes[1](x2, size=(H2, W2)) # Note: x1 is ignored.
x3 = self.cpes[2](x3, size=(H3, W3))
x4 = self.cpes[3](x4, size=(H4, W4))
cur2 = self.norm12(x2) cur2 = self.norm12(x2)
cur3 = self.norm13(x3) cur3 = self.norm13(x3)
cur4 = self.norm14(x4) cur4 = self.norm14(x4)
cur2 = self.factoratt_crpe2(cur2, size=(H2, W2)) cur2 = self.factoratt_crpe2(cur2, size=S2)
cur3 = self.factoratt_crpe3(cur3, size=(H3, W3)) cur3 = self.factoratt_crpe3(cur3, size=S3)
cur4 = self.factoratt_crpe4(cur4, size=(H4, W4)) cur4 = self.factoratt_crpe4(cur4, size=S4)
upsample3_2 = self.upsample(cur3, factor=2, size=(H3, W3)) upsample3_2 = self.upsample(cur3, factor=2., size=S3)
upsample4_3 = self.upsample(cur4, factor=2, size=(H4, W4)) upsample4_3 = self.upsample(cur4, factor=2., size=S4)
upsample4_2 = self.upsample(cur4, factor=4, size=(H4, W4)) upsample4_2 = self.upsample(cur4, factor=4., size=S4)
downsample2_3 = self.downsample(cur2, factor=2, size=(H2, W2)) downsample2_3 = self.downsample(cur2, factor=2., size=S2)
downsample3_4 = self.downsample(cur3, factor=2, size=(H3, W3)) downsample3_4 = self.downsample(cur3, factor=2., size=S3)
downsample2_4 = self.downsample(cur2, factor=4, size=(H2, W2)) downsample2_4 = self.downsample(cur2, factor=4., size=S2)
cur2 = cur2 + upsample3_2 + upsample4_2 cur2 = cur2 + upsample3_2 + upsample4_2
cur3 = cur3 + upsample4_3 + downsample2_3 cur3 = cur3 + upsample4_3 + downsample2_3
cur4 = cur4 + downsample3_4 + downsample2_4 cur4 = cur4 + downsample3_4 + downsample2_4
@ -330,11 +326,11 @@ class ParallelBlock(nn.Module):
class CoaT(nn.Module): class CoaT(nn.Module):
""" CoaT class. """ """ CoaT class. """
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[0, 0, 0, 0], def __init__(
serial_depths=[0, 0, 0, 0], parallel_depth=0, self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=(0, 0, 0, 0),
num_heads=0, mlp_ratios=[0, 0, 0, 0], qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., serial_depths=(0, 0, 0, 0), parallel_depth=0, num_heads=0, mlp_ratios=(0, 0, 0, 0), qkv_bias=True,
drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
return_interm_layers=False, out_features = None, crpe_window=None, **kwargs): return_interm_layers=False, out_features=None, crpe_window=None, **kwargs):
super().__init__() super().__init__()
crpe_window = crpe_window or {3: 2, 5: 3, 7: 3} crpe_window = crpe_window or {3: 2, 5: 3, 7: 3}
self.return_interm_layers = return_interm_layers self.return_interm_layers = return_interm_layers
@ -342,17 +338,18 @@ class CoaT(nn.Module):
self.num_classes = num_classes self.num_classes = num_classes
# Patch embeddings. # Patch embeddings.
img_size = to_2tuple(img_size)
self.patch_embed1 = PatchEmbed( self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, img_size=img_size, patch_size=patch_size, in_chans=in_chans,
embed_dim=embed_dims[0], norm_layer=nn.LayerNorm) embed_dim=embed_dims[0], norm_layer=nn.LayerNorm)
self.patch_embed2 = PatchEmbed( self.patch_embed2 = PatchEmbed(
img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0], img_size=[x // 4 for x in img_size], patch_size=2, in_chans=embed_dims[0],
embed_dim=embed_dims[1], norm_layer=nn.LayerNorm) embed_dim=embed_dims[1], norm_layer=nn.LayerNorm)
self.patch_embed3 = PatchEmbed( self.patch_embed3 = PatchEmbed(
img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1], img_size=[x // 8 for x in img_size], patch_size=2, in_chans=embed_dims[1],
embed_dim=embed_dims[2], norm_layer=nn.LayerNorm) embed_dim=embed_dims[2], norm_layer=nn.LayerNorm)
self.patch_embed4 = PatchEmbed( self.patch_embed4 = PatchEmbed(
img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], img_size=[x // 16 for x in img_size], patch_size=2, in_chans=embed_dims[2],
embed_dim=embed_dims[3], norm_layer=nn.LayerNorm) embed_dim=embed_dims[3], norm_layer=nn.LayerNorm)
# Class tokens. # Class tokens.
@ -380,7 +377,7 @@ class CoaT(nn.Module):
# Serial blocks 1. # Serial blocks 1.
self.serial_blocks1 = nn.ModuleList([ self.serial_blocks1 = nn.ModuleList([
SerialBlock( SerialBlock(
dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpe=self.cpe1, shared_crpe=self.crpe1 shared_cpe=self.cpe1, shared_crpe=self.crpe1
) )
@ -390,7 +387,7 @@ class CoaT(nn.Module):
# Serial blocks 2. # Serial blocks 2.
self.serial_blocks2 = nn.ModuleList([ self.serial_blocks2 = nn.ModuleList([
SerialBlock( SerialBlock(
dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpe=self.cpe2, shared_crpe=self.crpe2 shared_cpe=self.cpe2, shared_crpe=self.crpe2
) )
@ -400,7 +397,7 @@ class CoaT(nn.Module):
# Serial blocks 3. # Serial blocks 3.
self.serial_blocks3 = nn.ModuleList([ self.serial_blocks3 = nn.ModuleList([
SerialBlock( SerialBlock(
dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpe=self.cpe3, shared_crpe=self.crpe3 shared_cpe=self.cpe3, shared_crpe=self.crpe3
) )
@ -410,7 +407,7 @@ class CoaT(nn.Module):
# Serial blocks 4. # Serial blocks 4.
self.serial_blocks4 = nn.ModuleList([ self.serial_blocks4 = nn.ModuleList([
SerialBlock( SerialBlock(
dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpe=self.cpe4, shared_crpe=self.crpe4 shared_cpe=self.cpe4, shared_crpe=self.crpe4
) )
@ -422,10 +419,9 @@ class CoaT(nn.Module):
if self.parallel_depth > 0: if self.parallel_depth > 0:
self.parallel_blocks = nn.ModuleList([ self.parallel_blocks = nn.ModuleList([
ParallelBlock( ParallelBlock(
dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias, qk_scale=qk_scale, dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpes=[self.cpe1, self.cpe2, self.cpe3, self.cpe4], shared_crpes=(self.crpe1, self.crpe2, self.crpe3, self.crpe4)
shared_crpes=[self.crpe1, self.crpe2, self.crpe3, self.crpe4]
) )
for _ in range(parallel_depth)] for _ in range(parallel_depth)]
) )
@ -434,9 +430,11 @@ class CoaT(nn.Module):
# Classification head(s). # Classification head(s).
if not self.return_interm_layers: if not self.return_interm_layers:
self.norm1 = norm_layer(embed_dims[0]) if self.parallel_blocks is not None:
self.norm2 = norm_layer(embed_dims[1]) self.norm2 = norm_layer(embed_dims[1])
self.norm3 = norm_layer(embed_dims[2]) self.norm3 = norm_layer(embed_dims[2])
else:
self.norm2 = self.norm3 = None
self.norm4 = norm_layer(embed_dims[3]) self.norm4 = norm_layer(embed_dims[3])
if self.parallel_depth > 0: if self.parallel_depth > 0:
@ -546,6 +544,7 @@ class CoaT(nn.Module):
# Parallel blocks. # Parallel blocks.
for blk in self.parallel_blocks: for blk in self.parallel_blocks:
x2, x3, x4 = self.cpe2(x2, (H2, W2)), self.cpe3(x3, (H3, W3)), self.cpe4(x4, (H4, W4))
x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)]) x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)])
if not torch.jit.is_scripting() and self.return_interm_layers: if not torch.jit.is_scripting() and self.return_interm_layers:
@ -590,52 +589,70 @@ class CoaT(nn.Module):
return x return x
def checkpoint_filter_fn(state_dict, model):
out_dict = {}
for k, v in state_dict.items():
# original model had unused norm layers, removing them requires filtering pretrained checkpoints
if k.startswith('norm1') or \
(model.norm2 is None and k.startswith('norm2')) or \
(model.norm3 is None and k.startswith('norm3')):
continue
out_dict[k] = v
return out_dict
def _create_coat(variant, pretrained=False, default_cfg=None, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
CoaT, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model
@register_model @register_model
def coat_tiny(pretrained=False, **kwargs): def coat_tiny(pretrained=False, **kwargs):
model = CoaT( model_cfg = dict(
patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6, patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6,
num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
model.default_cfg = default_cfgs['coat_tiny'] model = _create_coat('coat_tiny', pretrained=pretrained, **model_cfg)
return model return model
@register_model @register_model
def coat_mini(pretrained=False, **kwargs): def coat_mini(pretrained=False, **kwargs):
model = CoaT( model_cfg = dict(
patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6, patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6,
num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
model.default_cfg = default_cfgs['coat_mini'] model = _create_coat('coat_mini', pretrained=pretrained, **model_cfg)
return model return model
@register_model @register_model
def coat_lite_tiny(pretrained=False, **kwargs): def coat_lite_tiny(pretrained=False, **kwargs):
model = CoaT( model_cfg = dict(
patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0, patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0,
num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
# FIXME use builder model = _create_coat('coat_lite_tiny', pretrained=pretrained, **model_cfg)
model.default_cfg = default_cfgs['coat_lite_tiny']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
return model return model
@register_model @register_model
def coat_lite_mini(pretrained=False, **kwargs): def coat_lite_mini(pretrained=False, **kwargs):
model = CoaT( model_cfg = dict(
patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0, patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0,
num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
# FIXME use builder model = _create_coat('coat_lite_mini', pretrained=pretrained, **model_cfg)
model.default_cfg = default_cfgs['coat_lite_mini']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
return model return model
@register_model @register_model
def coat_lite_small(pretrained=False, **kwargs): def coat_lite_small(pretrained=False, **kwargs):
model = CoaT( model_cfg = dict(
patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0, patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0,
num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
model.default_cfg = default_cfgs['coat_lite_small'] model = _create_coat('coat_lite_small', pretrained=pretrained, **model_cfg)
return model return model

View File

@ -39,7 +39,7 @@ def _cfg(url='', **kwargs):
return { return {
'url': url, 'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
'first_conv': 'patch_embed.proj', 'classifier': 'head', 'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs **kwargs
} }
@ -317,6 +317,9 @@ class ConViT(nn.Module):
def _create_convit(variant, pretrained=False, **kwargs): def _create_convit(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
return build_model_with_cfg( return build_model_with_cfg(
ConViT, variant, pretrained, ConViT, variant, pretrained,
default_cfg=default_cfgs[variant], default_cfg=default_cfgs[variant],

View File

@ -44,7 +44,7 @@ def load_state_dict(checkpoint_path, use_ema=False):
raise FileNotFoundError() raise FileNotFoundError()
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): def load_checkpoint(model, checkpoint_path, use_ema=False, strict=False):
state_dict = load_state_dict(checkpoint_path, use_ema) state_dict = load_state_dict(checkpoint_path, use_ema)
model.load_state_dict(state_dict, strict=strict) model.load_state_dict(state_dict, strict=strict)
@ -378,7 +378,11 @@ def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):
# Overlay default cfg values from `external_default_cfg` if it exists in kwargs # Overlay default cfg values from `external_default_cfg` if it exists in kwargs
overlay_external_default_cfg(default_cfg, kwargs) overlay_external_default_cfg(default_cfg, kwargs)
# Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs) # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg) default_kwarg_names = ('num_classes', 'global_pool', 'in_chans')
if default_cfg.get('fixed_input_size', False):
# if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
default_kwarg_names += ('img_size',)
set_default_kwargs(kwargs, names=default_kwarg_names, default_cfg=default_cfg)
# Filter keyword args for task specific model variants (some 'features only' models, etc.) # Filter keyword args for task specific model variants (some 'features only' models, etc.)
filter_kwargs(kwargs, names=kwargs_filter) filter_kwargs(kwargs, names=kwargs_filter)

View File

@ -18,6 +18,7 @@ from .eca import EcaModule, CecaModule
from .evo_norm import EvoNormBatch2d, EvoNormSample2d from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
from .inplace_abn import InplaceAbn from .inplace_abn import InplaceAbn
from .involution import Involution
from .linear import Linear from .linear import Linear
from .mixed_conv2d import MixedConv2d from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp from .mlp import Mlp, GluMlp, GatedMlp

View File

@ -1,5 +1,6 @@
from .bottleneck_attn import BottleneckAttn from .bottleneck_attn import BottleneckAttn
from .halo_attn import HaloAttn from .halo_attn import HaloAttn
from .involution import Involution
from .lambda_layer import LambdaLayer from .lambda_layer import LambdaLayer
from .swin_attn import WindowAttention from .swin_attn import WindowAttention
@ -13,6 +14,8 @@ def get_self_attn(attn_type):
return LambdaLayer return LambdaLayer
elif attn_type == 'swin': elif attn_type == 'swin':
return WindowAttention return WindowAttention
elif attn_type == 'involution':
return Involution
else: else:
assert False, f"Unknown attn type ({attn_type})" assert False, f"Unknown attn type ({attn_type})"

View File

@ -0,0 +1,50 @@
""" PyTorch Involution Layer
Official impl: https://github.com/d-li14/involution/blob/main/cls/mmcls/models/utils/involution_naive.py
Paper: `Involution: Inverting the Inherence of Convolution for Visual Recognition` - https://arxiv.org/abs/2103.06255
"""
import torch.nn as nn
from .conv_bn_act import ConvBnAct
from .create_conv2d import create_conv2d
class Involution(nn.Module):
def __init__(
self,
channels,
kernel_size=3,
stride=1,
group_size=16,
reduction_ratio=4,
norm_layer=nn.BatchNorm2d,
act_layer=nn.ReLU,
):
super(Involution, self).__init__()
self.kernel_size = kernel_size
self.stride = stride
self.channels = channels
self.group_size = group_size
self.groups = self.channels // self.group_size
self.conv1 = ConvBnAct(
in_channels=channels,
out_channels=channels // reduction_ratio,
kernel_size=1,
norm_layer=norm_layer,
act_layer=act_layer)
self.conv2 = self.conv = create_conv2d(
in_channels=channels // reduction_ratio,
out_channels=kernel_size**2 * self.groups,
kernel_size=1,
stride=1)
self.avgpool = nn.AvgPool2d(stride, stride) if stride == 2 else nn.Identity()
self.unfold = nn.Unfold(kernel_size, 1, (kernel_size-1)//2, stride)
def forward(self, x):
weight = self.conv2(self.conv1(self.avgpool(x)))
B, C, H, W = weight.shape
KK = int(self.kernel_size ** 2)
weight = weight.view(B, self.groups, KK, H, W).unsqueeze(2)
out = self.unfold(x).view(B, self.groups, self.group_size, KK, H, W)
out = (weight * out).sum(dim=3).view(B, self.channels, H, W)
return out

View File

@ -15,7 +15,7 @@ from .helpers import to_2tuple
class PatchEmbed(nn.Module): class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding """ 2D Image to Patch Embedding
""" """
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None): def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
super().__init__() super().__init__()
img_size = to_2tuple(img_size) img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size) patch_size = to_2tuple(patch_size)
@ -23,6 +23,7 @@ class PatchEmbed(nn.Module):
self.patch_size = patch_size self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1] self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
@ -31,6 +32,8 @@ class PatchEmbed(nn.Module):
B, C, H, W = x.shape B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \ assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2) x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x) x = self.norm(x)
return x return x

568
timm/models/levit.py Normal file
View File

@ -0,0 +1,568 @@
""" LeViT
Paper: `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference`
- https://arxiv.org/abs/2104.01136
@article{graham2021levit,
title={LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference},
author={Benjamin Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Herv\'e J\'egou and Matthijs Douze},
journal={arXiv preprint arXiv:22104.01136},
year={2021}
}
Adapted from official impl at https://github.com/facebookresearch/LeViT, original copyright bellow.
This version combines both conv/linear models and fixes torchscript compatibility.
Modifications by/coyright Copyright 2021 Ross Wightman
"""
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
# Modified from
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# Copyright 2020 Ross Wightman, Apache-2.0 License
import itertools
from copy import deepcopy
from functools import partial
from typing import Dict
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import to_ntuple
from .vision_transformer import trunc_normal_
from .registry import register_model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.0.c', 'classifier': ('head.l', 'head_dist.l'),
**kwargs
}
default_cfgs = dict(
levit_128s=_cfg(
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'
),
levit_128=_cfg(
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'
),
levit_192=_cfg(
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'
),
levit_256=_cfg(
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'
),
levit_384=_cfg(
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'
),
)
model_cfgs = dict(
levit_128s=dict(
embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 6, 8), depth=(2, 3, 4)),
levit_128=dict(
embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 8, 12), depth=(4, 4, 4)),
levit_192=dict(
embed_dim=(192, 288, 384), key_dim=32, num_heads=(3, 5, 6), depth=(4, 4, 4)),
levit_256=dict(
embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 4, 4)),
levit_384=dict(
embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4)),
)
__all__ = ['Levit']
@register_model
def levit_128s(pretrained=False, fuse=False,distillation=True, use_conv=False, **kwargs):
return create_levit(
'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_128(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
return create_levit(
'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_192(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
return create_levit(
'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_256(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
return create_levit(
'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_384(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
return create_levit(
'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_128s(pretrained=False, fuse=False, distillation=True, use_conv=True,**kwargs):
return create_levit(
'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_128(pretrained=False, fuse=False,distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_192(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_256(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_384(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
class ConvNorm(nn.Sequential):
def __init__(
self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000):
super().__init__()
self.add_module('c', nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
bn = nn.BatchNorm2d(b)
nn.init.constant_(bn.weight, bn_weight_init)
nn.init.constant_(bn.bias, 0)
self.add_module('bn', bn)
@torch.no_grad()
def fuse(self):
c, bn = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = c.weight * w[:, None, None, None]
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
m = nn.Conv2d(
w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
class LinearNorm(nn.Sequential):
def __init__(self, a, b, bn_weight_init=1, resolution=-100000):
super().__init__()
self.add_module('c', nn.Linear(a, b, bias=False))
bn = nn.BatchNorm1d(b)
nn.init.constant_(bn.weight, bn_weight_init)
nn.init.constant_(bn.bias, 0)
self.add_module('bn', bn)
@torch.no_grad()
def fuse(self):
l, bn = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = l.weight * w[:, None]
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
m = nn.Linear(w.size(1), w.size(0))
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
def forward(self, x):
x = self.c(x)
return self.bn(x.flatten(0, 1)).reshape_as(x)
class NormLinear(nn.Sequential):
def __init__(self, a, b, bias=True, std=0.02):
super().__init__()
self.add_module('bn', nn.BatchNorm1d(a))
l = nn.Linear(a, b, bias=bias)
trunc_normal_(l.weight, std=std)
if bias:
nn.init.constant_(l.bias, 0)
self.add_module('l', l)
@torch.no_grad()
def fuse(self):
bn, l = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5
w = l.weight * w[None, :]
if l.bias is None:
b = b @ self.l.weight.T
else:
b = (l.weight @ b[:, None]).view(-1) + self.l.bias
m = nn.Linear(w.size(1), w.size(0))
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
def stem_b16(in_chs, out_chs, activation, resolution=224):
return nn.Sequential(
ConvNorm(in_chs, out_chs // 8, 3, 2, 1, resolution=resolution),
activation(),
ConvNorm(out_chs // 8, out_chs // 4, 3, 2, 1, resolution=resolution // 2),
activation(),
ConvNorm(out_chs // 4, out_chs // 2, 3, 2, 1, resolution=resolution // 4),
activation(),
ConvNorm(out_chs // 2, out_chs, 3, 2, 1, resolution=resolution // 8))
class Residual(nn.Module):
def __init__(self, m, drop):
super().__init__()
self.m = m
self.drop = drop
def forward(self, x):
if self.training and self.drop > 0:
return x + self.m(x) * torch.rand(
x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach()
else:
return x + self.m(x)
class Subsample(nn.Module):
def __init__(self, stride, resolution):
super().__init__()
self.stride = stride
self.resolution = resolution
def forward(self, x):
B, N, C = x.shape
x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride]
return x.reshape(B, -1, C)
class Attention(nn.Module):
ab: Dict[str, torch.Tensor]
def __init__(
self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads
self.attn_ratio = attn_ratio
self.use_conv = use_conv
ln_layer = ConvNorm if self.use_conv else LinearNorm
h = self.dh + nh_kd * 2
self.qkv = ln_layer(dim, h, resolution=resolution)
self.proj = nn.Sequential(
act_layer(),
ln_layer(self.dh, dim, bn_weight_init=0, resolution=resolution))
points = list(itertools.product(range(resolution), range(resolution)))
N = len(points)
attention_offsets = {}
idxs = []
for p1 in points:
for p2 in points:
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N))
self.ab = {}
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and self.ab:
self.ab = {} # clear ab cache
def get_attention_biases(self, device: torch.device) -> torch.Tensor:
if self.training:
return self.attention_biases[:, self.attention_bias_idxs]
else:
device_key = str(device)
if device_key not in self.ab:
self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
return self.ab[device_key]
def forward(self, x): # x (B,C,H,W)
if self.use_conv:
B, C, H, W = x.shape
q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2)
attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
attn = attn.softmax(dim=-1)
x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
else:
B, N, C = x.shape
qkv = self.qkv(x)
q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
x = self.proj(x)
return x
class AttentionSubsample(nn.Module):
ab: Dict[str, torch.Tensor]
def __init__(
self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2,
act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = self.d * self.num_heads
self.attn_ratio = attn_ratio
self.resolution_ = resolution_
self.resolution_2 = resolution_ ** 2
self.use_conv = use_conv
if self.use_conv:
ln_layer = ConvNorm
sub_layer = partial(nn.AvgPool2d, kernel_size=1, padding=0)
else:
ln_layer = LinearNorm
sub_layer = partial(Subsample, resolution=resolution)
h = self.dh + nh_kd
self.kv = ln_layer(in_dim, h, resolution=resolution)
self.q = nn.Sequential(
sub_layer(stride=stride),
ln_layer(in_dim, nh_kd, resolution=resolution_))
self.proj = nn.Sequential(
act_layer(),
ln_layer(self.dh, out_dim, resolution=resolution_))
self.stride = stride
self.resolution = resolution
points = list(itertools.product(range(resolution), range(resolution)))
points_ = list(itertools.product(range(resolution_), range(resolution_)))
N = len(points)
N_ = len(points_)
attention_offsets = {}
idxs = []
for p1 in points_:
for p2 in points:
size = 1
offset = (
abs(p1[0] * stride - p2[0] + (size - 1) / 2),
abs(p1[1] * stride - p2[1] + (size - 1) / 2))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N))
self.ab = {} # per-device attention_biases cache
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and self.ab:
self.ab = {} # clear ab cache
def get_attention_biases(self, device: torch.device) -> torch.Tensor:
if self.training:
return self.attention_biases[:, self.attention_bias_idxs]
else:
device_key = str(device)
if device_key not in self.ab:
self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
return self.ab[device_key]
def forward(self, x):
if self.use_conv:
B, C, H, W = x.shape
k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2)
q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2)
attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
attn = attn.softmax(dim=-1)
x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_)
else:
B, N, C = x.shape
k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3)
k = k.permute(0, 2, 1, 3) # BHNC
v = v.permute(0, 2, 1, 3) # BHNC
q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
x = self.proj(x)
return x
class Levit(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dim=(192,),
key_dim=64,
depth=(12,),
num_heads=(3,),
attn_ratio=2,
mlp_ratio=2,
hybrid_backbone=None,
down_ops=None,
act_layer=nn.Hardswish,
attn_act_layer=nn.Hardswish,
distillation=True,
use_conv=False,
drop_path=0):
super().__init__()
if isinstance(img_size, tuple):
# FIXME origin impl passes single img/res dim through whole hierarchy,
# not sure this model will be used enough to spend time fixing it.
assert img_size[0] == img_size[1]
img_size = img_size[0]
self.num_classes = num_classes
self.num_features = embed_dim[-1]
self.embed_dim = embed_dim
N = len(embed_dim)
assert len(depth) == len(num_heads) == N
key_dim = to_ntuple(N)(key_dim)
attn_ratio = to_ntuple(N)(attn_ratio)
mlp_ratio = to_ntuple(N)(mlp_ratio)
down_ops = down_ops or (
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
('Subsample', key_dim[0], embed_dim[0] // key_dim[0], 4, 2, 2),
('Subsample', key_dim[0], embed_dim[1] // key_dim[1], 4, 2, 2),
('',)
)
self.distillation = distillation
self.use_conv = use_conv
ln_layer = ConvNorm if self.use_conv else LinearNorm
self.patch_embed = hybrid_backbone or stem_b16(in_chans, embed_dim[0], activation=act_layer)
self.blocks = []
resolution = img_size // patch_size
for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)):
for _ in range(dpth):
self.blocks.append(
Residual(
Attention(
ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer,
resolution=resolution, use_conv=use_conv),
drop_path))
if mr > 0:
h = int(ed * mr)
self.blocks.append(
Residual(nn.Sequential(
ln_layer(ed, h, resolution=resolution),
act_layer(),
ln_layer(h, ed, bn_weight_init=0, resolution=resolution),
), drop_path))
if do[0] == 'Subsample':
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
resolution_ = (resolution - 1) // do[5] + 1
self.blocks.append(
AttentionSubsample(
*embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2],
attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5],
resolution=resolution, resolution_=resolution_, use_conv=use_conv))
resolution = resolution_
if do[4] > 0: # mlp_ratio
h = int(embed_dim[i + 1] * do[4])
self.blocks.append(
Residual(nn.Sequential(
ln_layer(embed_dim[i + 1], h, resolution=resolution),
act_layer(),
ln_layer(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution),
), drop_path))
self.blocks = nn.Sequential(*self.blocks)
# Classifier head
self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
if distillation:
self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
else:
self.head_dist = None
@torch.jit.ignore
def no_weight_decay(self):
return {x for x in self.state_dict().keys() if 'attention_biases' in x}
def forward(self, x):
x = self.patch_embed(x)
if not self.use_conv:
x = x.flatten(2).transpose(1, 2)
x = self.blocks(x)
x = x.mean((-2, -1)) if self.use_conv else x.mean(1)
if self.head_dist is not None:
x, x_dist = self.head(x), self.head_dist(x)
if self.training and not torch.jit.is_scripting():
return x, x_dist
else:
# during inference, return the average of both classifier predictions
return (x + x_dist) / 2
else:
x = self.head(x)
return x
def checkpoint_filter_fn(state_dict, model):
if 'model' in state_dict:
# For deit models
state_dict = state_dict['model']
D = model.state_dict()
for k in state_dict.keys():
if D[k].ndim == 4 and state_dict[k].ndim == 2:
state_dict[k] = state_dict[k][:, :, None, None]
return state_dict
def create_levit(variant, pretrained=False, default_cfg=None, fuse=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model_cfg = dict(**model_cfgs[variant], **kwargs)
model = build_model_with_cfg(
Levit, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
**model_cfg)
#if fuse:
# utils.replace_batchnorm(model)
return model

View File

@ -273,25 +273,14 @@ def _init_weights(m, n: str, head_bias: float = 0.):
nn.init.ones_(m.weight) nn.init.ones_(m.weight)
def _create_mixer(variant, pretrained=False, default_cfg=None, **kwargs): def _create_mixer(variant, pretrained=False, **kwargs):
if default_cfg is None:
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
num_classes = kwargs.pop('num_classes', default_num_classes)
img_size = kwargs.pop('img_size', default_img_size)
if kwargs.get('features_only', None): if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for MLP-Mixer models.') raise RuntimeError('features_only not implemented for MLP-Mixer models.')
model = build_model_with_cfg( model = build_model_with_cfg(
MlpMixer, variant, pretrained, MlpMixer, variant, pretrained,
default_cfg=default_cfg, default_cfg=default_cfgs[variant],
img_size=img_size,
num_classes=num_classes,
**kwargs) **kwargs)
return model return model

View File

@ -110,6 +110,12 @@ default_cfgs = dict(
eca_nfnet_l1=_dcfg( eca_nfnet_l1=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l1_ra2-7dce93cd.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l1_ra2-7dce93cd.pth',
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 320, 320), crop_pct=1.0), pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 320, 320), crop_pct=1.0),
eca_nfnet_l2=_dcfg(
url='',
pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 352, 352), crop_pct=1.0),
eca_nfnet_l3=_dcfg(
url='',
pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), crop_pct=1.0),
nf_regnet_b0=_dcfg( nf_regnet_b0=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'), url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'),
@ -244,6 +250,12 @@ model_cfgs = dict(
eca_nfnet_l1=_nfnet_cfg( eca_nfnet_l1=_nfnet_cfg(
depths=(2, 4, 12, 6), feat_mult=2, group_size=64, bottle_ratio=0.25, depths=(2, 4, 12, 6), feat_mult=2, group_size=64, bottle_ratio=0.25,
attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), attn_layer='eca', attn_kwargs=dict(), act_layer='silu'),
eca_nfnet_l2=_nfnet_cfg(
depths=(3, 6, 18, 9), feat_mult=2, group_size=64, bottle_ratio=0.25,
attn_layer='eca', attn_kwargs=dict(), act_layer='silu'),
eca_nfnet_l3=_nfnet_cfg(
depths=(4, 8, 24, 12), feat_mult=2, group_size=64, bottle_ratio=0.25,
attn_layer='eca', attn_kwargs=dict(), act_layer='silu'),
# EffNet influenced RegNet defs. # EffNet influenced RegNet defs.
# NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8. # NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8.
@ -814,6 +826,22 @@ def eca_nfnet_l1(pretrained=False, **kwargs):
return _create_normfreenet('eca_nfnet_l1', pretrained=pretrained, **kwargs) return _create_normfreenet('eca_nfnet_l1', pretrained=pretrained, **kwargs)
@register_model
def eca_nfnet_l2(pretrained=False, **kwargs):
""" ECA-NFNet-L2 w/ SiLU
My experimental 'light' model w/ F2 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn
"""
return _create_normfreenet('eca_nfnet_l2', pretrained=pretrained, **kwargs)
@register_model
def eca_nfnet_l3(pretrained=False, **kwargs):
""" ECA-NFNet-L3 w/ SiLU
My experimental 'light' model w/ F3 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn
"""
return _create_normfreenet('eca_nfnet_l3', pretrained=pretrained, **kwargs)
@register_model @register_model
def nf_regnet_b0(pretrained=False, **kwargs): def nf_regnet_b0(pretrained=False, **kwargs):
""" Normalization-Free RegNet-B0 """ Normalization-Free RegNet-B0

View File

@ -251,24 +251,14 @@ def checkpoint_filter_fn(state_dict, model):
def _create_pit(variant, pretrained=False, **kwargs): def _create_pit(variant, pretrained=False, **kwargs):
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
img_size = kwargs.pop('img_size', default_img_size)
num_classes = kwargs.pop('num_classes', default_num_classes)
if kwargs.get('features_only', None): if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.') raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg( model = build_model_with_cfg(
PoolingVisionTransformer, variant, pretrained, PoolingVisionTransformer, variant, pretrained,
default_cfg=default_cfg, default_cfg=default_cfgs[variant],
img_size=img_size,
num_classes=num_classes,
pretrained_filter_fn=checkpoint_filter_fn, pretrained_filter_fn=checkpoint_filter_fn,
**kwargs) **kwargs)
return model return model

View File

@ -50,7 +50,7 @@ def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def list_models(filter='', module='', pretrained=False, exclude_filters=''): def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False):
""" Return list of available model names, sorted alphabetically """ Return list of available model names, sorted alphabetically
Args: Args:
@ -58,6 +58,7 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''):
module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
pretrained (bool) - Include only models with pretrained weights if True pretrained (bool) - Include only models with pretrained weights if True
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
Example: Example:
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
@ -70,7 +71,7 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''):
if filter: if filter:
models = fnmatch.filter(models, filter) # include these models models = fnmatch.filter(models, filter) # include these models
if exclude_filters: if exclude_filters:
if not isinstance(exclude_filters, list): if not isinstance(exclude_filters, (tuple, list)):
exclude_filters = [exclude_filters] exclude_filters = [exclude_filters]
for xf in exclude_filters: for xf in exclude_filters:
exclude_models = fnmatch.filter(models, xf) # exclude these models exclude_models = fnmatch.filter(models, xf) # exclude these models
@ -78,6 +79,8 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''):
models = set(models).difference(exclude_models) models = set(models).difference(exclude_models)
if pretrained: if pretrained:
models = _model_has_pretrained.intersection(models) models = _model_has_pretrained.intersection(models)
if name_matches_cfg:
models = set(_model_default_cfgs).intersection(models)
return list(sorted(models, key=_natural_key)) return list(sorted(models, key=_natural_key))

View File

@ -12,7 +12,7 @@ import torch.nn as nn
from functools import partial from functools import partial
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import load_pretrained from timm.models.helpers import build_model_with_cfg
from timm.models.layers import Mlp, DropPath, trunc_normal_ from timm.models.layers import Mlp, DropPath, trunc_normal_
from timm.models.layers.helpers import to_2tuple from timm.models.layers.helpers import to_2tuple
from timm.models.registry import register_model from timm.models.registry import register_model
@ -238,24 +238,31 @@ def checkpoint_filter_fn(state_dict, model):
return state_dict return state_dict
def _create_tnt(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
TNT, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model
@register_model @register_model
def tnt_s_patch16_224(pretrained=False, **kwargs): def tnt_s_patch16_224(pretrained=False, **kwargs):
model = TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4, model_cfg = dict(
patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4,
qkv_bias=False, **kwargs) qkv_bias=False, **kwargs)
model.default_cfg = default_cfgs['tnt_s_patch16_224'] model = _create_tnt('tnt_s_patch16_224', pretrained=pretrained, **model_cfg)
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3),
filter_fn=checkpoint_filter_fn)
return model return model
@register_model @register_model
def tnt_b_patch16_224(pretrained=False, **kwargs): def tnt_b_patch16_224(pretrained=False, **kwargs):
model = TNT(patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4, model_cfg = dict(
patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4,
qkv_bias=False, **kwargs) qkv_bias=False, **kwargs)
model.default_cfg = default_cfgs['tnt_b_patch16_224'] model = _create_tnt('tnt_b_patch16_224', pretrained=pretrained, **model_cfg)
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
return model return model

View File

@ -33,7 +33,7 @@ def _cfg(url='', **kwargs):
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head', 'first_conv': 'patch_embeds.0.proj', 'classifier': 'head',
**kwargs **kwargs
} }
@ -361,25 +361,14 @@ class Twins(nn.Module):
return x return x
def _create_twins(variant, pretrained=False, default_cfg=None, **kwargs): def _create_twins(variant, pretrained=False, **kwargs):
if default_cfg is None:
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
num_classes = kwargs.pop('num_classes', default_num_classes)
img_size = kwargs.pop('img_size', default_img_size)
if kwargs.get('features_only', None): if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.') raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg( model = build_model_with_cfg(
Twins, variant, pretrained, Twins, variant, pretrained,
default_cfg=default_cfg, default_cfg=default_cfgs[variant],
img_size=img_size,
num_classes=num_classes,
**kwargs) **kwargs)
return model return model

414
timm/models/visformer.py Normal file
View File

@ -0,0 +1,414 @@
""" Visformer
Paper: Visformer: The Vision-friendly Transformer - https://arxiv.org/abs/2104.12533
From original at https://github.com/danczs/Visformer
"""
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed
from .registry import register_model
__all__ = ['Visformer']
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.0', 'classifier': 'head',
**kwargs
}
default_cfgs = dict(
visformer_tiny=_cfg(),
visformer_small=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/visformer_small-839e1f5b.pth'
),
)
class LayerNormBHWC(nn.LayerNorm):
def __init__(self, dim):
super().__init__(dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.layer_norm(
x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
class SpatialMlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None,
act_layer=nn.GELU, drop=0., group=8, spatial_conv=False):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.in_features = in_features
self.out_features = out_features
self.spatial_conv = spatial_conv
if self.spatial_conv:
if group < 2: # net setting
hidden_features = in_features * 5 // 6
else:
hidden_features = in_features * 2
self.hidden_features = hidden_features
self.group = group
self.drop = nn.Dropout(drop)
self.conv1 = nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0, bias=False)
self.act1 = act_layer()
if self.spatial_conv:
self.conv2 = nn.Conv2d(
hidden_features, hidden_features, 3, stride=1, padding=1, groups=self.group, bias=False)
self.act2 = act_layer()
else:
self.conv2 = None
self.act2 = None
self.conv3 = nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0, bias=False)
def forward(self, x):
x = self.conv1(x)
x = self.act1(x)
x = self.drop(x)
if self.conv2 is not None:
x = self.conv2(x)
x = self.act2(x)
x = self.conv3(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, head_dim_ratio=1., attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.num_heads = num_heads
head_dim = round(dim // num_heads * head_dim_ratio)
self.head_dim = head_dim
self.scale = head_dim ** -0.5
self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Conv2d(self.head_dim * self.num_heads, dim, 1, stride=1, padding=0, bias=False)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, C, H, W = x.shape
x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3)
q, k, v = x[0], x[1], x[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4.,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormBHWC,
group=8, attn_disabled=False, spatial_conv=False):
super().__init__()
self.spatial_conv = spatial_conv
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
if attn_disabled:
self.norm1 = None
self.attn = None
else:
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, head_dim_ratio=head_dim_ratio, attn_drop=attn_drop, proj_drop=drop)
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = SpatialMlp(
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
group=group, spatial_conv=spatial_conv) # new setting
def forward(self, x):
if self.attn is not None:
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class Visformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384,
depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_layer=LayerNormBHWC, attn_stage='111', pos_embed=True, spatial_conv='111',
vit_stem=False, group=8, pool=True, conv_init=False, embed_norm=None):
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim
self.init_channels = init_channels
self.img_size = img_size
self.vit_stem = vit_stem
self.pool = pool
self.conv_init = conv_init
if isinstance(depth, (list, tuple)):
self.stage_num1, self.stage_num2, self.stage_num3 = depth
depth = sum(depth)
else:
self.stage_num1 = self.stage_num3 = depth // 3
self.stage_num2 = depth - self.stage_num1 - self.stage_num3
self.pos_embed = pos_embed
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
# stage 1
if self.vit_stem:
self.stem = None
self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
img_size //= 16
else:
if self.init_channels is None:
self.stem = None
self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans,
embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
img_size //= 8
else:
self.stem = nn.Sequential(
nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(self.init_channels),
nn.ReLU(inplace=True)
)
img_size //= 2
self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels,
embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
img_size //= 4
if self.pos_embed:
if self.vit_stem:
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size))
else:
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, img_size, img_size))
self.pos_drop = nn.Dropout(p=drop_rate)
self.stage1 = nn.ModuleList([
Block(
dim=embed_dim//2, num_heads=num_heads, head_dim_ratio=0.5, mlp_ratio=mlp_ratio,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
group=group, attn_disabled=(attn_stage[0] == '0'), spatial_conv=(spatial_conv[0] == '1')
)
for i in range(self.stage_num1)
])
#stage2
if not self.vit_stem:
self.patch_embed2 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2,
embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
img_size //= 2
if self.pos_embed:
self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size))
self.stage2 = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
group=group, attn_disabled=(attn_stage[1] == '0'), spatial_conv=(spatial_conv[1] == '1')
)
for i in range(self.stage_num1, self.stage_num1+self.stage_num2)
])
# stage 3
if not self.vit_stem:
self.patch_embed3 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim,
embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False)
img_size //= 2
if self.pos_embed:
self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, img_size, img_size))
self.stage3 = nn.ModuleList([
Block(
dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
group=group, attn_disabled=(attn_stage[2] == '0'), spatial_conv=(spatial_conv[2] == '1')
)
for i in range(self.stage_num1+self.stage_num2, depth)
])
# head
if self.pool:
self.global_pooling = nn.AdaptiveAvgPool2d(1)
head_dim = embed_dim if self.vit_stem else embed_dim * 2
self.norm = norm_layer(head_dim)
self.head = nn.Linear(head_dim, num_classes)
# weights init
if self.pos_embed:
trunc_normal_(self.pos_embed1, std=0.02)
if not self.vit_stem:
trunc_normal_(self.pos_embed2, std=0.02)
trunc_normal_(self.pos_embed3, std=0.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
if self.conv_init:
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
else:
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0.)
def forward(self, x):
if self.stem is not None:
x = self.stem(x)
# stage 1
x = self.patch_embed1(x)
if self.pos_embed:
x = x + self.pos_embed1
x = self.pos_drop(x)
for b in self.stage1:
x = b(x)
# stage 2
if not self.vit_stem:
x = self.patch_embed2(x)
if self.pos_embed:
x = x + self.pos_embed2
x = self.pos_drop(x)
for b in self.stage2:
x = b(x)
# stage3
if not self.vit_stem:
x = self.patch_embed3(x)
if self.pos_embed:
x = x + self.pos_embed3
x = self.pos_drop(x)
for b in self.stage3:
x = b(x)
# head
x = self.norm(x)
if self.pool:
x = self.global_pooling(x)
else:
x = x[:, :, 0, 0]
x = self.head(x.view(x.size(0), -1))
return x
def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
Visformer, variant, pretrained,
default_cfg=default_cfgs[variant],
**kwargs)
return model
@register_model
def visformer_tiny(pretrained=False, **kwargs):
model_cfg = dict(
img_size=224, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8,
attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
embed_norm=nn.BatchNorm2d, **kwargs)
model = _create_visformer('visformer_tiny', pretrained=pretrained, **model_cfg)
return model
@register_model
def visformer_small(pretrained=False, **kwargs):
model_cfg = dict(
img_size=224, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8,
attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
embed_norm=nn.BatchNorm2d, **kwargs)
model = _create_visformer('visformer_small', pretrained=pretrained, **model_cfg)
return model
# @register_model
# def visformer_net1(pretrained=False, **kwargs):
# model = Visformer(
# init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111',
# spatial_conv='000', vit_stem=True, conv_init=True, **kwargs)
# model.default_cfg = _cfg()
# return model
#
#
# @register_model
# def visformer_net2(pretrained=False, **kwargs):
# model = Visformer(
# init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111',
# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
# model.default_cfg = _cfg()
# return model
#
#
# @register_model
# def visformer_net3(pretrained=False, **kwargs):
# model = Visformer(
# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111',
# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
# model.default_cfg = _cfg()
# return model
#
#
# @register_model
# def visformer_net4(pretrained=False, **kwargs):
# model = Visformer(
# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111',
# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
# model.default_cfg = _cfg()
# return model
#
#
# @register_model
# def visformer_net5(pretrained=False, **kwargs):
# model = Visformer(
# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111',
# spatial_conv='111', vit_stem=False, conv_init=True, **kwargs)
# model.default_cfg = _cfg()
# return model
#
#
# @register_model
# def visformer_net6(pretrained=False, **kwargs):
# model = Visformer(
# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111',
# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs)
# model.default_cfg = _cfg()
# return model
#
#
# @register_model
# def visformer_net7(pretrained=False, **kwargs):
# model = Visformer(
# init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000',
# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs)
# model.default_cfg = _cfg()
# return model

View File

@ -387,21 +387,20 @@ def checkpoint_filter_fn(state_dict, model):
v = v.reshape(O, -1, H, W) v = v.reshape(O, -1, H, W)
elif k == 'pos_embed' and v.shape != model.pos_embed.shape: elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
# To resize pos embedding when using model at different size from pretrained weights # To resize pos embedding when using model at different size from pretrained weights
v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1), v = resize_pos_embed(
model.patch_embed.grid_size) v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
out_dict[k] = v out_dict[k] = v
return out_dict return out_dict
def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs): def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
if default_cfg is None: default_cfg = default_cfg or default_cfgs[variant]
default_cfg = deepcopy(default_cfgs[variant]) if kwargs.get('features_only', None):
overlay_external_default_cfg(default_cfg, kwargs) raise RuntimeError('features_only not implemented for Vision Transformer models.')
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
num_classes = kwargs.pop('num_classes', default_num_classes) # NOTE this extra code to support handling of repr size for in21k pretrained models
img_size = kwargs.pop('img_size', default_img_size) default_num_classes = default_cfg['num_classes']
num_classes = kwargs.get('num_classes', default_num_classes)
repr_size = kwargs.pop('representation_size', None) repr_size = kwargs.pop('representation_size', None)
if repr_size is not None and num_classes != default_num_classes: if repr_size is not None and num_classes != default_num_classes:
# Remove representation layer if fine-tuning. This may not always be the desired action, # Remove representation layer if fine-tuning. This may not always be the desired action,
@ -409,18 +408,12 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kw
_logger.warning("Removing representation layer for fine-tuning.") _logger.warning("Removing representation layer for fine-tuning.")
repr_size = None repr_size = None
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg( model = build_model_with_cfg(
VisionTransformer, variant, pretrained, VisionTransformer, variant, pretrained,
default_cfg=default_cfg, default_cfg=default_cfg,
img_size=img_size,
num_classes=num_classes,
representation_size=repr_size, representation_size=repr_size,
pretrained_filter_fn=checkpoint_filter_fn, pretrained_filter_fn=checkpoint_filter_fn,
**kwargs) **kwargs)
return model return model

View File

@ -27,7 +27,7 @@ def _cfg(url='', **kwargs):
return { return {
'url': url, 'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head', 'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head',
**kwargs **kwargs
@ -107,11 +107,10 @@ class HybridEmbed(nn.Module):
def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs): def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
default_cfg = deepcopy(default_cfgs[variant])
embed_layer = partial(HybridEmbed, backbone=backbone) embed_layer = partial(HybridEmbed, backbone=backbone)
kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
return _create_vision_transformer( return _create_vision_transformer(
variant, pretrained=pretrained, default_cfg=default_cfg, embed_layer=embed_layer, **kwargs) variant, pretrained=pretrained, embed_layer=embed_layer, default_cfg=default_cfgs[variant], **kwargs)
def _resnetv2(layers=(3, 4, 9), **kwargs): def _resnetv2(layers=(3, 4, 9), **kwargs):

View File

@ -1 +1 @@
__version__ = '0.4.9' __version__ = '0.4.10'