mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #1641 from rwightman/maxxvit_hub
MaxxViT weights on hub, new 12k FT 1k weights, convnext 384x384 12k FT 1k, and more
This commit is contained in:
commit
3aa31f537d
53
README.md
53
README.md
@ -24,6 +24,59 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
|
||||
* ❗Updates after Oct 10, 2022 are available in 0.8.x pre-releases (`pip install --pre timm`) or cloning main❗
|
||||
* Stable releases are 0.6.x and available by normal pip install or clone from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch.
|
||||
|
||||
### Jan 20, 2023
|
||||
* Add two convnext 12k -> 1k fine-tunes at 384x384
|
||||
* `convnext_tiny.in12k_ft_in1k_384` - 85.1 @ 384
|
||||
* `convnext_small.in12k_ft_in1k_384` - 86.2 @ 384
|
||||
|
||||
* Push all MaxxViT weights to HF hub, and add new ImageNet-12k -> 1k fine-tunes for `rw` base MaxViT and CoAtNet 1/2 models
|
||||
|
||||
|model |top1 |top5 |samples / sec |Params (M) |GMAC |Act (M)|
|
||||
|------------------------------------------------------------------------------------------------------------------------|----:|----:|--------------:|--------------:|-----:|------:|
|
||||
|[maxvit_xlarge_tf_512.in21k_ft_in1k](https://huggingface.co/timm/maxvit_xlarge_tf_512.in21k_ft_in1k) |88.53|98.64| 21.76| 475.77|534.14|1413.22|
|
||||
|[maxvit_xlarge_tf_384.in21k_ft_in1k](https://huggingface.co/timm/maxvit_xlarge_tf_384.in21k_ft_in1k) |88.32|98.54| 42.53| 475.32|292.78| 668.76|
|
||||
|[maxvit_base_tf_512.in21k_ft_in1k](https://huggingface.co/timm/maxvit_base_tf_512.in21k_ft_in1k) |88.20|98.53| 50.87| 119.88|138.02| 703.99|
|
||||
|[maxvit_large_tf_512.in21k_ft_in1k](https://huggingface.co/timm/maxvit_large_tf_512.in21k_ft_in1k) |88.04|98.40| 36.42| 212.33|244.75| 942.15|
|
||||
|[maxvit_large_tf_384.in21k_ft_in1k](https://huggingface.co/timm/maxvit_large_tf_384.in21k_ft_in1k) |87.98|98.56| 71.75| 212.03|132.55| 445.84|
|
||||
|[maxvit_base_tf_384.in21k_ft_in1k](https://huggingface.co/timm/maxvit_base_tf_384.in21k_ft_in1k) |87.92|98.54| 104.71| 119.65| 73.80| 332.90|
|
||||
|[maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k](https://huggingface.co/timm/maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k) |87.81|98.37| 106.55| 116.14| 70.97| 318.95|
|
||||
|[maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k](https://huggingface.co/timm/maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k) |87.47|98.37| 149.49| 116.09| 72.98| 213.74|
|
||||
|[coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k) |87.39|98.31| 160.80| 73.88| 47.69| 209.43|
|
||||
|[maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k) |86.89|98.02| 375.86| 116.14| 23.15| 92.64|
|
||||
|[maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k) |86.64|98.02| 501.03| 116.09| 24.20| 62.77|
|
||||
|[maxvit_base_tf_512.in1k](https://huggingface.co/timm/maxvit_base_tf_512.in1k) |86.60|97.92| 50.75| 119.88|138.02| 703.99|
|
||||
|[coatnet_2_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_2_rw_224.sw_in12k_ft_in1k) |86.57|97.89| 631.88| 73.87| 15.09| 49.22|
|
||||
|[maxvit_large_tf_512.in1k](https://huggingface.co/timm/maxvit_large_tf_512.in1k) |86.52|97.88| 36.04| 212.33|244.75| 942.15|
|
||||
|[coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k) |86.49|97.90| 620.58| 73.88| 15.18| 54.78|
|
||||
|[maxvit_base_tf_384.in1k](https://huggingface.co/timm/maxvit_base_tf_384.in1k) |86.29|97.80| 101.09| 119.65| 73.80| 332.90|
|
||||
|[maxvit_large_tf_384.in1k](https://huggingface.co/timm/maxvit_large_tf_384.in1k) |86.23|97.69| 70.56| 212.03|132.55| 445.84|
|
||||
|[maxvit_small_tf_512.in1k](https://huggingface.co/timm/maxvit_small_tf_512.in1k) |86.10|97.76| 88.63| 69.13| 67.26| 383.77|
|
||||
|[maxvit_tiny_tf_512.in1k](https://huggingface.co/timm/maxvit_tiny_tf_512.in1k) |85.67|97.58| 144.25| 31.05| 33.49| 257.59|
|
||||
|[maxvit_small_tf_384.in1k](https://huggingface.co/timm/maxvit_small_tf_384.in1k) |85.54|97.46| 188.35| 69.02| 35.87| 183.65|
|
||||
|[maxvit_tiny_tf_384.in1k](https://huggingface.co/timm/maxvit_tiny_tf_384.in1k) |85.11|97.38| 293.46| 30.98| 17.53| 123.42|
|
||||
|[maxvit_large_tf_224.in1k](https://huggingface.co/timm/maxvit_large_tf_224.in1k) |84.93|96.97| 247.71| 211.79| 43.68| 127.35|
|
||||
|[coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k) |84.90|96.96| 1025.45| 41.72| 8.11| 40.13|
|
||||
|[maxvit_base_tf_224.in1k](https://huggingface.co/timm/maxvit_base_tf_224.in1k) |84.85|96.99| 358.25| 119.47| 24.04| 95.01|
|
||||
|[maxxvit_rmlp_small_rw_256.sw_in1k](https://huggingface.co/timm/maxxvit_rmlp_small_rw_256.sw_in1k) |84.63|97.06| 575.53| 66.01| 14.67| 58.38|
|
||||
|[coatnet_rmlp_2_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_rmlp_2_rw_224.sw_in1k) |84.61|96.74| 625.81| 73.88| 15.18| 54.78|
|
||||
|[maxvit_rmlp_small_rw_224.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_small_rw_224.sw_in1k) |84.49|96.76| 693.82| 64.90| 10.75| 49.30|
|
||||
|[maxvit_small_tf_224.in1k](https://huggingface.co/timm/maxvit_small_tf_224.in1k) |84.43|96.83| 647.96| 68.93| 11.66| 53.17|
|
||||
|[maxvit_rmlp_tiny_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_tiny_rw_256.sw_in1k) |84.23|96.78| 807.21| 29.15| 6.77| 46.92|
|
||||
|[coatnet_1_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_1_rw_224.sw_in1k) |83.62|96.38| 989.59| 41.72| 8.04| 34.60|
|
||||
|[maxvit_tiny_rw_224.sw_in1k](https://huggingface.co/timm/maxvit_tiny_rw_224.sw_in1k) |83.50|96.50| 1100.53| 29.06| 5.11| 33.11|
|
||||
|[maxvit_tiny_tf_224.in1k](https://huggingface.co/timm/maxvit_tiny_tf_224.in1k) |83.41|96.59| 1004.94| 30.92| 5.60| 35.78|
|
||||
|[coatnet_rmlp_1_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_rmlp_1_rw_224.sw_in1k) |83.36|96.45| 1093.03| 41.69| 7.85| 35.47|
|
||||
|[maxxvitv2_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxxvitv2_nano_rw_256.sw_in1k) |83.11|96.33| 1276.88| 23.70| 6.26| 23.05|
|
||||
|[maxxvit_rmlp_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxxvit_rmlp_nano_rw_256.sw_in1k) |83.03|96.34| 1341.24| 16.78| 4.37| 26.05|
|
||||
|[maxvit_rmlp_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_nano_rw_256.sw_in1k) |82.96|96.26| 1283.24| 15.50| 4.47| 31.92|
|
||||
|[maxvit_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_nano_rw_256.sw_in1k) |82.93|96.23| 1218.17| 15.45| 4.46| 30.28|
|
||||
|[coatnet_bn_0_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_bn_0_rw_224.sw_in1k) |82.39|96.19| 1600.14| 27.44| 4.67| 22.04|
|
||||
|[coatnet_0_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_0_rw_224.sw_in1k) |82.39|95.84| 1831.21| 27.44| 4.43| 18.73|
|
||||
|[coatnet_rmlp_nano_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_rmlp_nano_rw_224.sw_in1k) |82.05|95.87| 2109.09| 15.15| 2.62| 20.34|
|
||||
|[coatnext_nano_rw_224.sw_in1k](https://huggingface.co/timm/coatnext_nano_rw_224.sw_in1k) |81.95|95.92| 2525.52| 14.70| 2.47| 12.80|
|
||||
|[coatnet_nano_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_nano_rw_224.sw_in1k) |81.70|95.64| 2344.52| 15.14| 2.41| 15.41|
|
||||
|[maxvit_rmlp_pico_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_pico_rw_256.sw_in1k) |80.53|95.21| 1594.71| 7.52| 1.85| 24.86|
|
||||
|
||||
### Jan 11, 2023
|
||||
* Update ConvNeXt ImageNet-12k pretrain series w/ two new fine-tuned weights (and pre FT `.in12k` tags)
|
||||
* `convnext_nano.in12k_ft_in1k` - 82.3 @ 224, 82.9 @ 288 (previously released)
|
||||
|
@ -27,8 +27,9 @@ NON_STD_FILTERS = [
|
||||
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
|
||||
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
|
||||
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
|
||||
'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*', 'flexivit*'
|
||||
'eva_*', 'flexivit*'
|
||||
]
|
||||
#'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', '
|
||||
NUM_NON_STD = len(NON_STD_FILTERS)
|
||||
|
||||
# exclude models that cause specific test failures
|
||||
@ -38,7 +39,7 @@ if 'GITHUB_ACTIONS' in os.environ:
|
||||
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
|
||||
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
|
||||
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
|
||||
'swin*giant*', 'convnextv2_huge*']
|
||||
'swin*giant*', 'convnextv2_huge*', 'maxvit_xlarge*']
|
||||
NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*']
|
||||
else:
|
||||
EXCLUDE_FILTERS = []
|
||||
@ -53,7 +54,7 @@ MAX_JIT_SIZE = 320
|
||||
TARGET_FFEAT_SIZE = 96
|
||||
MAX_FFEAT_SIZE = 256
|
||||
TARGET_FWD_FX_SIZE = 128
|
||||
MAX_FWD_FX_SIZE = 224
|
||||
MAX_FWD_FX_SIZE = 256
|
||||
TARGET_BWD_FX_SIZE = 128
|
||||
MAX_BWD_FX_SIZE = 224
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
|
||||
rand_augment_transform, auto_augment_transform
|
||||
from .config import resolve_data_config
|
||||
from .config import resolve_data_config, resolve_model_data_config
|
||||
from .constants import *
|
||||
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
|
||||
from .dataset_factory import create_dataset
|
||||
|
@ -6,16 +6,18 @@ _logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def resolve_data_config(
|
||||
args,
|
||||
default_cfg=None,
|
||||
args=None,
|
||||
pretrained_cfg=None,
|
||||
model=None,
|
||||
use_test_size=False,
|
||||
verbose=False
|
||||
):
|
||||
new_config = {}
|
||||
default_cfg = default_cfg or {}
|
||||
if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
|
||||
default_cfg = model.default_cfg
|
||||
assert model or args or pretrained_cfg, "At least one of model, args, or pretrained_cfg required for data config."
|
||||
args = args or {}
|
||||
pretrained_cfg = pretrained_cfg or {}
|
||||
if not pretrained_cfg and model is not None and hasattr(model, 'pretrained_cfg'):
|
||||
pretrained_cfg = model.pretrained_cfg
|
||||
data_config = {}
|
||||
|
||||
# Resolve input/image size
|
||||
in_chans = 3
|
||||
@ -32,65 +34,94 @@ def resolve_data_config(
|
||||
assert isinstance(args['img_size'], int)
|
||||
input_size = (in_chans, args['img_size'], args['img_size'])
|
||||
else:
|
||||
if use_test_size and default_cfg.get('test_input_size', None) is not None:
|
||||
input_size = default_cfg['test_input_size']
|
||||
elif default_cfg.get('input_size', None) is not None:
|
||||
input_size = default_cfg['input_size']
|
||||
new_config['input_size'] = input_size
|
||||
if use_test_size and pretrained_cfg.get('test_input_size', None) is not None:
|
||||
input_size = pretrained_cfg['test_input_size']
|
||||
elif pretrained_cfg.get('input_size', None) is not None:
|
||||
input_size = pretrained_cfg['input_size']
|
||||
data_config['input_size'] = input_size
|
||||
|
||||
# resolve interpolation method
|
||||
new_config['interpolation'] = 'bicubic'
|
||||
data_config['interpolation'] = 'bicubic'
|
||||
if args.get('interpolation', None):
|
||||
new_config['interpolation'] = args['interpolation']
|
||||
elif default_cfg.get('interpolation', None):
|
||||
new_config['interpolation'] = default_cfg['interpolation']
|
||||
data_config['interpolation'] = args['interpolation']
|
||||
elif pretrained_cfg.get('interpolation', None):
|
||||
data_config['interpolation'] = pretrained_cfg['interpolation']
|
||||
|
||||
# resolve dataset + model mean for normalization
|
||||
new_config['mean'] = IMAGENET_DEFAULT_MEAN
|
||||
data_config['mean'] = IMAGENET_DEFAULT_MEAN
|
||||
if args.get('mean', None) is not None:
|
||||
mean = tuple(args['mean'])
|
||||
if len(mean) == 1:
|
||||
mean = tuple(list(mean) * in_chans)
|
||||
else:
|
||||
assert len(mean) == in_chans
|
||||
new_config['mean'] = mean
|
||||
elif default_cfg.get('mean', None):
|
||||
new_config['mean'] = default_cfg['mean']
|
||||
data_config['mean'] = mean
|
||||
elif pretrained_cfg.get('mean', None):
|
||||
data_config['mean'] = pretrained_cfg['mean']
|
||||
|
||||
# resolve dataset + model std deviation for normalization
|
||||
new_config['std'] = IMAGENET_DEFAULT_STD
|
||||
data_config['std'] = IMAGENET_DEFAULT_STD
|
||||
if args.get('std', None) is not None:
|
||||
std = tuple(args['std'])
|
||||
if len(std) == 1:
|
||||
std = tuple(list(std) * in_chans)
|
||||
else:
|
||||
assert len(std) == in_chans
|
||||
new_config['std'] = std
|
||||
elif default_cfg.get('std', None):
|
||||
new_config['std'] = default_cfg['std']
|
||||
data_config['std'] = std
|
||||
elif pretrained_cfg.get('std', None):
|
||||
data_config['std'] = pretrained_cfg['std']
|
||||
|
||||
# resolve default inference crop
|
||||
crop_pct = DEFAULT_CROP_PCT
|
||||
if args.get('crop_pct', None):
|
||||
crop_pct = args['crop_pct']
|
||||
else:
|
||||
if use_test_size and default_cfg.get('test_crop_pct', None):
|
||||
crop_pct = default_cfg['test_crop_pct']
|
||||
elif default_cfg.get('crop_pct', None):
|
||||
crop_pct = default_cfg['crop_pct']
|
||||
new_config['crop_pct'] = crop_pct
|
||||
if use_test_size and pretrained_cfg.get('test_crop_pct', None):
|
||||
crop_pct = pretrained_cfg['test_crop_pct']
|
||||
elif pretrained_cfg.get('crop_pct', None):
|
||||
crop_pct = pretrained_cfg['crop_pct']
|
||||
data_config['crop_pct'] = crop_pct
|
||||
|
||||
# resolve default crop percentage
|
||||
crop_mode = DEFAULT_CROP_MODE
|
||||
if args.get('crop_mode', None):
|
||||
crop_mode = args['crop_mode']
|
||||
elif default_cfg.get('crop_mode', None):
|
||||
crop_mode = default_cfg['crop_mode']
|
||||
new_config['crop_mode'] = crop_mode
|
||||
elif pretrained_cfg.get('crop_mode', None):
|
||||
crop_mode = pretrained_cfg['crop_mode']
|
||||
data_config['crop_mode'] = crop_mode
|
||||
|
||||
if verbose:
|
||||
_logger.info('Data processing configuration for current model + dataset:')
|
||||
for n, v in new_config.items():
|
||||
for n, v in data_config.items():
|
||||
_logger.info('\t%s: %s' % (n, str(v)))
|
||||
|
||||
return new_config
|
||||
return data_config
|
||||
|
||||
|
||||
def resolve_model_data_config(
|
||||
model,
|
||||
args=None,
|
||||
pretrained_cfg=None,
|
||||
use_test_size=False,
|
||||
verbose=False,
|
||||
):
|
||||
""" Resolve Model Data Config
|
||||
This is equivalent to resolve_data_config() but with arguments re-ordered to put model first.
|
||||
|
||||
Args:
|
||||
model (nn.Module): the model instance
|
||||
args (dict): command line arguments / configuration in dict form (overrides pretrained_cfg)
|
||||
pretrained_cfg (dict): pretrained model config (overrides pretrained_cfg attached to model)
|
||||
use_test_size (bool): use the test time input resolution (if one exists) instead of default train resolution
|
||||
verbose (bool): enable extra logging of resolved values
|
||||
|
||||
Returns:
|
||||
dictionary of config
|
||||
"""
|
||||
return resolve_data_config(
|
||||
args=args,
|
||||
pretrained_cfg=pretrained_cfg,
|
||||
model=model,
|
||||
use_test_size=use_test_size,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
@ -38,13 +38,24 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False
|
||||
class ClassifierHead(nn.Module):
|
||||
"""Classifier head w/ configurable global pooling and dropout."""
|
||||
|
||||
def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
|
||||
def __init__(self, in_features, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
|
||||
super(ClassifierHead, self).__init__()
|
||||
self.drop_rate = drop_rate
|
||||
self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
|
||||
self.in_features = in_features
|
||||
self.use_conv = use_conv
|
||||
|
||||
self.global_pool, num_pooled_features = _create_pool(in_features, num_classes, pool_type, use_conv=use_conv)
|
||||
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
|
||||
self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
|
||||
|
||||
def reset(self, num_classes, global_pool=None):
|
||||
if global_pool is not None:
|
||||
if global_pool != self.global_pool.pool_type:
|
||||
self.global_pool, _ = _create_pool(self.in_features, num_classes, global_pool, use_conv=self.use_conv)
|
||||
self.flatten = nn.Flatten(1) if self.use_conv and global_pool else nn.Identity()
|
||||
num_pooled_features = self.in_features * self.global_pool.feat_mult()
|
||||
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=self.use_conv)
|
||||
|
||||
def forward(self, x, pre_logits: bool = False):
|
||||
x = self.global_pool(x)
|
||||
if self.drop_rate:
|
||||
|
@ -179,11 +179,11 @@ def load_pretrained(
|
||||
return
|
||||
|
||||
if filter_fn is not None:
|
||||
# for backwards compat with filter fn that take one arg, try one first, the two
|
||||
try:
|
||||
state_dict = filter_fn(state_dict)
|
||||
except TypeError:
|
||||
state_dict = filter_fn(state_dict, model)
|
||||
except TypeError as e:
|
||||
# for backwards compat with filter fn that take one arg
|
||||
state_dict = filter_fn(state_dict)
|
||||
|
||||
input_convs = pretrained_cfg.get('first_conv', None)
|
||||
if input_convs is not None and in_chans != 3:
|
||||
|
@ -236,20 +236,7 @@ def push_to_hf_hub(
|
||||
model_card = model_card or {}
|
||||
model_name = repo_id.split('/')[-1]
|
||||
readme_path = Path(tmpdir) / "README.md"
|
||||
readme_text = "---\n"
|
||||
readme_text += "tags:\n- image-classification\n- timm\n"
|
||||
readme_text += "library_tag: timm\n"
|
||||
readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n"
|
||||
readme_text += "---\n"
|
||||
readme_text += f"# Model card for {model_name}\n"
|
||||
if 'description' in model_card:
|
||||
readme_text += f"\n{model_card['description']}\n"
|
||||
if 'details' in model_card:
|
||||
readme_text += f"\n## Model Details\n"
|
||||
for k, v in model_card['details'].items():
|
||||
readme_text += f"- **{k}:** {v}\n"
|
||||
if 'citation' in model_card:
|
||||
readme_text += f"\n## Citation\n```\n{model_card['citation']}```\n"
|
||||
readme_text = generate_readme(model_card, model_name)
|
||||
readme_path.write_text(readme_text)
|
||||
|
||||
# Upload model and return
|
||||
@ -260,3 +247,51 @@ def push_to_hf_hub(
|
||||
create_pr=create_pr,
|
||||
commit_message=commit_message,
|
||||
)
|
||||
|
||||
|
||||
def generate_readme(model_card, model_name):
|
||||
readme_text = "---\n"
|
||||
readme_text += "tags:\n- image-classification\n- timm\n"
|
||||
readme_text += "library_tag: timm\n"
|
||||
readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n"
|
||||
if 'details' in model_card and 'Dataset' in model_card['details']:
|
||||
readme_text += 'datasets:\n'
|
||||
readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
|
||||
if 'Pretrain Dataset' in model_card['details']:
|
||||
readme_text += f"- {model_card['details']['Pretrain Dataset'].lower()}\n"
|
||||
readme_text += "---\n"
|
||||
readme_text += f"# Model card for {model_name}\n"
|
||||
if 'description' in model_card:
|
||||
readme_text += f"\n{model_card['description']}\n"
|
||||
if 'details' in model_card:
|
||||
readme_text += f"\n## Model Details\n"
|
||||
for k, v in model_card['details'].items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
readme_text += f"- **{k}:**\n"
|
||||
for vi in v:
|
||||
readme_text += f" - {vi}\n"
|
||||
elif isinstance(v, dict):
|
||||
readme_text += f"- **{k}:**\n"
|
||||
for ki, vi in v.items():
|
||||
readme_text += f" - {ki}: {vi}\n"
|
||||
else:
|
||||
readme_text += f"- **{k}:** {v}\n"
|
||||
if 'usage' in model_card:
|
||||
readme_text += f"\n## Model Usage\n"
|
||||
readme_text += model_card['usage']
|
||||
readme_text += '\n'
|
||||
|
||||
if 'comparison' in model_card:
|
||||
readme_text += f"\n## Model Comparison\n"
|
||||
readme_text += model_card['comparison']
|
||||
readme_text += '\n'
|
||||
|
||||
if 'citation' in model_card:
|
||||
readme_text += f"\n## Citation\n"
|
||||
if not isinstance(model_card['citation'], (list, tuple)):
|
||||
citations = [model_card['citation']]
|
||||
else:
|
||||
citations = model_card['citation']
|
||||
for c in citations:
|
||||
readme_text += f"```bibtex\n{c}\n```\n"
|
||||
return readme_text
|
||||
|
@ -500,6 +500,13 @@ default_cfgs = generate_default_cfgs({
|
||||
hf_hub_id='timm/',
|
||||
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
|
||||
'convnext_tiny.in12k_ft_in1k_384': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
||||
'convnext_small.in12k_ft_in1k_384': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
||||
|
||||
'convnext_nano.in12k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
crop_pct=0.95, num_classes=11821),
|
||||
@ -706,27 +713,27 @@ default_cfgs = generate_default_cfgs({
|
||||
hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K',
|
||||
hf_hub_filename='open_clip_pytorch_model.bin',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
||||
input_size=(3, 256, 256), crop_pct=1.0, num_classes=640),
|
||||
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
|
||||
'convnext_base.clip_laion2b_augreg': _cfg(
|
||||
hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg',
|
||||
hf_hub_filename='open_clip_pytorch_model.bin',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
||||
input_size=(3, 256, 256), crop_pct=1.0, num_classes=640),
|
||||
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
|
||||
'convnext_base.clip_laiona': _cfg(
|
||||
hf_hub_id='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K',
|
||||
hf_hub_filename='open_clip_pytorch_model.bin',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
||||
input_size=(3, 256, 256), crop_pct=1.0, num_classes=640),
|
||||
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
|
||||
'convnext_base.clip_laiona_320': _cfg(
|
||||
hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K',
|
||||
hf_hub_filename='open_clip_pytorch_model.bin',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
||||
input_size=(3, 320, 320), crop_pct=1.0, num_classes=640),
|
||||
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
|
||||
'convnext_base.clip_laiona_augreg_320': _cfg(
|
||||
hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg',
|
||||
hf_hub_filename='open_clip_pytorch_model.bin',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
||||
input_size=(3, 320, 320), crop_pct=1.0, num_classes=640),
|
||||
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
|
||||
})
|
||||
|
||||
|
||||
|
@ -913,7 +913,7 @@ class CspNet(nn.Module):
|
||||
# Construct the head
|
||||
self.num_features = prev_chs
|
||||
self.head = ClassifierHead(
|
||||
in_chs=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
|
||||
in_features=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
|
||||
|
||||
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
|
||||
|
||||
|
@ -12,9 +12,6 @@ These configs work well and appear to be a bit faster / lower resource than the
|
||||
The models without extra prefix / suffix' (coatnet_0_224, maxvit_tiny_224, etc), are intended to
|
||||
match paper, BUT, without any official pretrained weights it's difficult to confirm a 100% match.
|
||||
|
||||
# FIXME / WARNING
|
||||
This impl remains a WIP, some configs and models may vanish or change...
|
||||
|
||||
Papers:
|
||||
|
||||
MaxViT: Multi-Axis Vision Transformer - https://arxiv.org/abs/2204.01697
|
||||
@ -76,6 +73,8 @@ class MaxxVitTransformerCfg:
|
||||
partition_ratio: int = 32
|
||||
window_size: Optional[Tuple[int, int]] = None
|
||||
grid_size: Optional[Tuple[int, int]] = None
|
||||
no_block_attn: bool = False # disable window block attention for maxvit (ie only grid)
|
||||
use_nchw_attn: bool = False # for MaxViT variants (not used for CoAt), keep tensors in NCHW order
|
||||
init_values: Optional[float] = None
|
||||
act_layer: str = 'gelu'
|
||||
norm_layer: str = 'layernorm2d'
|
||||
@ -889,19 +888,17 @@ class MaxxVitBlock(nn.Module):
|
||||
stride: int = 1,
|
||||
conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
|
||||
transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
|
||||
use_nchw_attn: bool = False, # FIXME move to cfg? True is ~20-30% faster on TPU, 5-10% slower on GPU
|
||||
use_block_attn: bool = True, # FIXME for testing ConvNeXt conv w/o block attention
|
||||
drop_path: float = 0.,
|
||||
):
|
||||
super().__init__()
|
||||
self.nchw_attn = transformer_cfg.use_nchw_attn
|
||||
|
||||
conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
|
||||
self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)
|
||||
|
||||
attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
|
||||
partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttentionCl
|
||||
self.nchw_attn = use_nchw_attn
|
||||
self.attn_block = partition_layer(**attn_kwargs) if use_block_attn else None
|
||||
partition_layer = PartitionAttention2d if self.nchw_attn else PartitionAttentionCl
|
||||
self.attn_block = None if transformer_cfg.no_block_attn else partition_layer(**attn_kwargs)
|
||||
self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs)
|
||||
|
||||
def init_weights(self, scheme=''):
|
||||
@ -1084,26 +1081,48 @@ class NormMlpHead(nn.Module):
|
||||
hidden_size=None,
|
||||
pool_type='avg',
|
||||
drop_rate=0.,
|
||||
norm_layer=nn.LayerNorm,
|
||||
act_layer=nn.Tanh,
|
||||
norm_layer='layernorm2d',
|
||||
act_layer='tanh',
|
||||
):
|
||||
super().__init__()
|
||||
self.drop_rate = drop_rate
|
||||
self.in_features = in_features
|
||||
self.hidden_size = hidden_size
|
||||
self.num_features = in_features
|
||||
self.use_conv = not pool_type
|
||||
norm_layer = get_norm_layer(norm_layer)
|
||||
act_layer = get_act_layer(act_layer)
|
||||
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
|
||||
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
|
||||
self.norm = norm_layer(in_features)
|
||||
self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
|
||||
if hidden_size:
|
||||
self.pre_logits = nn.Sequential(OrderedDict([
|
||||
('fc', nn.Linear(in_features, hidden_size)),
|
||||
('fc', linear_layer(in_features, hidden_size)),
|
||||
('act', act_layer()),
|
||||
]))
|
||||
self.num_features = hidden_size
|
||||
else:
|
||||
self.pre_logits = nn.Identity()
|
||||
self.drop = nn.Dropout(self.drop_rate)
|
||||
self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def reset(self, num_classes, global_pool=None):
|
||||
if global_pool is not None:
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
|
||||
self.use_conv = self.global_pool.is_identity()
|
||||
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
|
||||
if self.hidden_size:
|
||||
if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or
|
||||
(isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)):
|
||||
with torch.no_grad():
|
||||
new_fc = linear_layer(self.in_features, self.hidden_size)
|
||||
new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape))
|
||||
new_fc.bias.copy_(self.pre_logits.fc.bias)
|
||||
self.pre_logits.fc = new_fc
|
||||
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward(self, x, pre_logits: bool = False):
|
||||
x = self.global_pool(x)
|
||||
@ -1163,6 +1182,7 @@ class MaxxVit(nn.Module):
|
||||
self.num_features = self.embed_dim = cfg.embed_dim[-1]
|
||||
self.drop_rate = drop_rate
|
||||
self.grad_checkpointing = False
|
||||
self.feature_info = []
|
||||
|
||||
self.stem = Stem(
|
||||
in_chs=in_chans,
|
||||
@ -1173,8 +1193,8 @@ class MaxxVit(nn.Module):
|
||||
norm_layer=cfg.conv_cfg.norm_layer,
|
||||
norm_eps=cfg.conv_cfg.norm_eps,
|
||||
)
|
||||
|
||||
stride = self.stem.stride
|
||||
self.feature_info += [dict(num_chs=self.stem.out_chs, reduction=2, module='stem')]
|
||||
feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))])
|
||||
|
||||
num_stages = len(cfg.embed_dim)
|
||||
@ -1198,15 +1218,17 @@ class MaxxVit(nn.Module):
|
||||
)]
|
||||
stride *= stage_stride
|
||||
in_chs = out_chs
|
||||
self.feature_info += [dict(num_chs=out_chs, reduction=stride, module=f'stages.{i}')]
|
||||
self.stages = nn.Sequential(*stages)
|
||||
|
||||
final_norm_layer = partial(get_norm_layer(cfg.transformer_cfg.norm_layer), eps=cfg.transformer_cfg.norm_eps)
|
||||
if cfg.head_hidden_size:
|
||||
self.head_hidden_size = cfg.head_hidden_size
|
||||
if self.head_hidden_size:
|
||||
self.norm = nn.Identity()
|
||||
self.head = NormMlpHead(
|
||||
self.num_features,
|
||||
num_classes,
|
||||
hidden_size=cfg.head_hidden_size,
|
||||
hidden_size=self.head_hidden_size,
|
||||
pool_type=global_pool,
|
||||
drop_rate=drop_rate,
|
||||
norm_layer=final_norm_layer,
|
||||
@ -1253,9 +1275,7 @@ class MaxxVit(nn.Module):
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is None:
|
||||
global_pool = self.head.global_pool.pool_type
|
||||
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
@ -1376,6 +1396,7 @@ def _next_cfg(
|
||||
transformer_norm_layer='layernorm2d',
|
||||
transformer_norm_layer_cl='layernorm',
|
||||
window_size=None,
|
||||
no_block_attn=False,
|
||||
init_values=1e-6,
|
||||
rel_pos_type='mlp', # MLP by default for maxxvit
|
||||
rel_pos_dim=512,
|
||||
@ -1396,6 +1417,7 @@ def _next_cfg(
|
||||
expand_first=False,
|
||||
pool_type=pool_type,
|
||||
window_size=window_size,
|
||||
no_block_attn=no_block_attn, # enabled for MaxxViT-V2
|
||||
init_values=init_values[1],
|
||||
norm_layer=transformer_norm_layer,
|
||||
norm_layer_cl=transformer_norm_layer_cl,
|
||||
@ -1422,8 +1444,8 @@ def _tf_cfg():
|
||||
|
||||
|
||||
model_cfgs = dict(
|
||||
# Fiddling with configs / defaults / still pretraining
|
||||
coatnet_pico_rw_224=MaxxVitCfg(
|
||||
# timm specific CoAtNet configs
|
||||
coatnet_pico_rw=MaxxVitCfg(
|
||||
embed_dim=(64, 128, 256, 512),
|
||||
depths=(2, 3, 5, 2),
|
||||
stem_width=(32, 64),
|
||||
@ -1432,7 +1454,7 @@ model_cfgs = dict(
|
||||
conv_attn_ratio=0.25,
|
||||
),
|
||||
),
|
||||
coatnet_nano_rw_224=MaxxVitCfg(
|
||||
coatnet_nano_rw=MaxxVitCfg(
|
||||
embed_dim=(64, 128, 256, 512),
|
||||
depths=(3, 4, 6, 3),
|
||||
stem_width=(32, 64),
|
||||
@ -1442,7 +1464,7 @@ model_cfgs = dict(
|
||||
conv_attn_ratio=0.25,
|
||||
),
|
||||
),
|
||||
coatnet_0_rw_224=MaxxVitCfg(
|
||||
coatnet_0_rw=MaxxVitCfg(
|
||||
embed_dim=(96, 192, 384, 768),
|
||||
depths=(2, 3, 7, 2), # deeper than paper '0' model
|
||||
stem_width=(32, 64),
|
||||
@ -1451,7 +1473,7 @@ model_cfgs = dict(
|
||||
transformer_shortcut_bias=False,
|
||||
),
|
||||
),
|
||||
coatnet_1_rw_224=MaxxVitCfg(
|
||||
coatnet_1_rw=MaxxVitCfg(
|
||||
embed_dim=(96, 192, 384, 768),
|
||||
depths=(2, 6, 14, 2),
|
||||
stem_width=(32, 64),
|
||||
@ -1461,7 +1483,7 @@ model_cfgs = dict(
|
||||
transformer_shortcut_bias=False,
|
||||
)
|
||||
),
|
||||
coatnet_2_rw_224=MaxxVitCfg(
|
||||
coatnet_2_rw=MaxxVitCfg(
|
||||
embed_dim=(128, 256, 512, 1024),
|
||||
depths=(2, 6, 14, 2),
|
||||
stem_width=(64, 128),
|
||||
@ -1471,7 +1493,7 @@ model_cfgs = dict(
|
||||
#init_values=1e-6,
|
||||
),
|
||||
),
|
||||
coatnet_3_rw_224=MaxxVitCfg(
|
||||
coatnet_3_rw=MaxxVitCfg(
|
||||
embed_dim=(192, 384, 768, 1536),
|
||||
depths=(2, 6, 14, 2),
|
||||
stem_width=(96, 192),
|
||||
@ -1482,8 +1504,8 @@ model_cfgs = dict(
|
||||
),
|
||||
),
|
||||
|
||||
# Highly experimental configs
|
||||
coatnet_bn_0_rw_224=MaxxVitCfg(
|
||||
# Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
|
||||
coatnet_bn_0_rw=MaxxVitCfg(
|
||||
embed_dim=(96, 192, 384, 768),
|
||||
depths=(2, 3, 7, 2), # deeper than paper '0' model
|
||||
stem_width=(32, 64),
|
||||
@ -1494,7 +1516,7 @@ model_cfgs = dict(
|
||||
transformer_norm_layer='batchnorm2d',
|
||||
)
|
||||
),
|
||||
coatnet_rmlp_nano_rw_224=MaxxVitCfg(
|
||||
coatnet_rmlp_nano_rw=MaxxVitCfg(
|
||||
embed_dim=(64, 128, 256, 512),
|
||||
depths=(3, 4, 6, 3),
|
||||
stem_width=(32, 64),
|
||||
@ -1505,7 +1527,7 @@ model_cfgs = dict(
|
||||
rel_pos_dim=384,
|
||||
),
|
||||
),
|
||||
coatnet_rmlp_0_rw_224=MaxxVitCfg(
|
||||
coatnet_rmlp_0_rw=MaxxVitCfg(
|
||||
embed_dim=(96, 192, 384, 768),
|
||||
depths=(2, 3, 7, 2), # deeper than paper '0' model
|
||||
stem_width=(32, 64),
|
||||
@ -1514,7 +1536,7 @@ model_cfgs = dict(
|
||||
rel_pos_type='mlp',
|
||||
),
|
||||
),
|
||||
coatnet_rmlp_1_rw_224=MaxxVitCfg(
|
||||
coatnet_rmlp_1_rw=MaxxVitCfg(
|
||||
embed_dim=(96, 192, 384, 768),
|
||||
depths=(2, 6, 14, 2),
|
||||
stem_width=(32, 64),
|
||||
@ -1526,7 +1548,7 @@ model_cfgs = dict(
|
||||
rel_pos_dim=384, # was supposed to be 512, woops
|
||||
),
|
||||
),
|
||||
coatnet_rmlp_1_rw2_224=MaxxVitCfg(
|
||||
coatnet_rmlp_1_rw2=MaxxVitCfg(
|
||||
embed_dim=(96, 192, 384, 768),
|
||||
depths=(2, 6, 14, 2),
|
||||
stem_width=(32, 64),
|
||||
@ -1536,7 +1558,7 @@ model_cfgs = dict(
|
||||
rel_pos_dim=512, # was supposed to be 512, woops
|
||||
),
|
||||
),
|
||||
coatnet_rmlp_2_rw_224=MaxxVitCfg(
|
||||
coatnet_rmlp_2_rw=MaxxVitCfg(
|
||||
embed_dim=(128, 256, 512, 1024),
|
||||
depths=(2, 6, 14, 2),
|
||||
stem_width=(64, 128),
|
||||
@ -1547,7 +1569,7 @@ model_cfgs = dict(
|
||||
rel_pos_type='mlp'
|
||||
),
|
||||
),
|
||||
coatnet_rmlp_3_rw_224=MaxxVitCfg(
|
||||
coatnet_rmlp_3_rw=MaxxVitCfg(
|
||||
embed_dim=(192, 384, 768, 1536),
|
||||
depths=(2, 6, 14, 2),
|
||||
stem_width=(96, 192),
|
||||
@ -1559,14 +1581,14 @@ model_cfgs = dict(
|
||||
),
|
||||
),
|
||||
|
||||
coatnet_nano_cc_224=MaxxVitCfg(
|
||||
coatnet_nano_cc=MaxxVitCfg(
|
||||
embed_dim=(64, 128, 256, 512),
|
||||
depths=(3, 4, 6, 3),
|
||||
stem_width=(32, 64),
|
||||
block_type=('C', 'C', ('C', 'T'), ('C', 'T')),
|
||||
**_rw_coat_cfg(),
|
||||
),
|
||||
coatnext_nano_rw_224=MaxxVitCfg(
|
||||
coatnext_nano_rw=MaxxVitCfg(
|
||||
embed_dim=(64, 128, 256, 512),
|
||||
depths=(3, 4, 6, 3),
|
||||
stem_width=(32, 64),
|
||||
@ -1578,130 +1600,66 @@ model_cfgs = dict(
|
||||
),
|
||||
|
||||
# Trying to be like the CoAtNet paper configs
|
||||
coatnet_0_224=MaxxVitCfg(
|
||||
coatnet_0=MaxxVitCfg(
|
||||
embed_dim=(96, 192, 384, 768),
|
||||
depths=(2, 3, 5, 2),
|
||||
stem_width=64,
|
||||
head_hidden_size=768,
|
||||
),
|
||||
coatnet_1_224=MaxxVitCfg(
|
||||
coatnet_1=MaxxVitCfg(
|
||||
embed_dim=(96, 192, 384, 768),
|
||||
depths=(2, 6, 14, 2),
|
||||
stem_width=64,
|
||||
head_hidden_size=768,
|
||||
),
|
||||
coatnet_2_224=MaxxVitCfg(
|
||||
coatnet_2=MaxxVitCfg(
|
||||
embed_dim=(128, 256, 512, 1024),
|
||||
depths=(2, 6, 14, 2),
|
||||
stem_width=128,
|
||||
head_hidden_size=1024,
|
||||
),
|
||||
coatnet_3_224=MaxxVitCfg(
|
||||
coatnet_3=MaxxVitCfg(
|
||||
embed_dim=(192, 384, 768, 1536),
|
||||
depths=(2, 6, 14, 2),
|
||||
stem_width=192,
|
||||
head_hidden_size=1536,
|
||||
),
|
||||
coatnet_4_224=MaxxVitCfg(
|
||||
coatnet_4=MaxxVitCfg(
|
||||
embed_dim=(192, 384, 768, 1536),
|
||||
depths=(2, 12, 28, 2),
|
||||
stem_width=192,
|
||||
head_hidden_size=1536,
|
||||
),
|
||||
coatnet_5_224=MaxxVitCfg(
|
||||
coatnet_5=MaxxVitCfg(
|
||||
embed_dim=(256, 512, 1280, 2048),
|
||||
depths=(2, 12, 28, 2),
|
||||
stem_width=192,
|
||||
head_hidden_size=2048,
|
||||
),
|
||||
|
||||
# Experimental MaxVit configs
|
||||
maxvit_pico_rw_256=MaxxVitCfg(
|
||||
maxvit_pico_rw=MaxxVitCfg(
|
||||
embed_dim=(32, 64, 128, 256),
|
||||
depths=(2, 2, 5, 2),
|
||||
block_type=('M',) * 4,
|
||||
stem_width=(24, 32),
|
||||
**_rw_max_cfg(),
|
||||
),
|
||||
maxvit_nano_rw_256=MaxxVitCfg(
|
||||
maxvit_nano_rw=MaxxVitCfg(
|
||||
embed_dim=(64, 128, 256, 512),
|
||||
depths=(1, 2, 3, 1),
|
||||
block_type=('M',) * 4,
|
||||
stem_width=(32, 64),
|
||||
**_rw_max_cfg(),
|
||||
),
|
||||
maxvit_tiny_rw_224=MaxxVitCfg(
|
||||
maxvit_tiny_rw=MaxxVitCfg(
|
||||
embed_dim=(64, 128, 256, 512),
|
||||
depths=(2, 2, 5, 2),
|
||||
block_type=('M',) * 4,
|
||||
stem_width=(32, 64),
|
||||
**_rw_max_cfg(),
|
||||
),
|
||||
maxvit_tiny_rw_256=MaxxVitCfg(
|
||||
embed_dim=(64, 128, 256, 512),
|
||||
depths=(2, 2, 5, 2),
|
||||
block_type=('M',) * 4,
|
||||
stem_width=(32, 64),
|
||||
**_rw_max_cfg(),
|
||||
),
|
||||
|
||||
maxvit_rmlp_pico_rw_256=MaxxVitCfg(
|
||||
embed_dim=(32, 64, 128, 256),
|
||||
depths=(2, 2, 5, 2),
|
||||
block_type=('M',) * 4,
|
||||
stem_width=(24, 32),
|
||||
**_rw_max_cfg(rel_pos_type='mlp'),
|
||||
),
|
||||
maxvit_rmlp_nano_rw_256=MaxxVitCfg(
|
||||
embed_dim=(64, 128, 256, 512),
|
||||
depths=(1, 2, 3, 1),
|
||||
block_type=('M',) * 4,
|
||||
stem_width=(32, 64),
|
||||
**_rw_max_cfg(rel_pos_type='mlp'),
|
||||
),
|
||||
maxvit_rmlp_tiny_rw_256=MaxxVitCfg(
|
||||
embed_dim=(64, 128, 256, 512),
|
||||
depths=(2, 2, 5, 2),
|
||||
block_type=('M',) * 4,
|
||||
stem_width=(32, 64),
|
||||
**_rw_max_cfg(rel_pos_type='mlp'),
|
||||
),
|
||||
maxvit_rmlp_small_rw_224=MaxxVitCfg(
|
||||
embed_dim=(96, 192, 384, 768),
|
||||
depths=(2, 2, 5, 2),
|
||||
block_type=('M',) * 4,
|
||||
stem_width=(32, 64),
|
||||
**_rw_max_cfg(
|
||||
rel_pos_type='mlp',
|
||||
init_values=1e-6,
|
||||
),
|
||||
),
|
||||
maxvit_rmlp_small_rw_256=MaxxVitCfg(
|
||||
embed_dim=(96, 192, 384, 768),
|
||||
depths=(2, 2, 5, 2),
|
||||
block_type=('M',) * 4,
|
||||
stem_width=(32, 64),
|
||||
**_rw_max_cfg(
|
||||
rel_pos_type='mlp',
|
||||
init_values=1e-6,
|
||||
),
|
||||
),
|
||||
maxvit_rmlp_base_rw_224=MaxxVitCfg(
|
||||
embed_dim=(96, 192, 384, 768),
|
||||
depths=(2, 6, 14, 2),
|
||||
block_type=('M',) * 4,
|
||||
stem_width=(32, 64),
|
||||
head_hidden_size=768,
|
||||
**_rw_max_cfg(
|
||||
rel_pos_type='mlp',
|
||||
),
|
||||
),
|
||||
maxvit_rmlp_base_rw_384=MaxxVitCfg(
|
||||
embed_dim=(96, 192, 384, 768),
|
||||
depths=(2, 6, 14, 2),
|
||||
block_type=('M',) * 4,
|
||||
stem_width=(32, 64),
|
||||
head_hidden_size=768,
|
||||
**_rw_max_cfg(
|
||||
rel_pos_type='mlp',
|
||||
),
|
||||
),
|
||||
|
||||
maxvit_tiny_pm_256=MaxxVitCfg(
|
||||
maxvit_tiny_pm=MaxxVitCfg(
|
||||
embed_dim=(64, 128, 256, 512),
|
||||
depths=(2, 2, 5, 2),
|
||||
block_type=('PM',) * 4,
|
||||
@ -1709,7 +1667,49 @@ model_cfgs = dict(
|
||||
**_rw_max_cfg(),
|
||||
),
|
||||
|
||||
maxxvit_rmlp_nano_rw_256=MaxxVitCfg(
|
||||
maxvit_rmlp_pico_rw=MaxxVitCfg(
|
||||
embed_dim=(32, 64, 128, 256),
|
||||
depths=(2, 2, 5, 2),
|
||||
block_type=('M',) * 4,
|
||||
stem_width=(24, 32),
|
||||
**_rw_max_cfg(rel_pos_type='mlp'),
|
||||
),
|
||||
maxvit_rmlp_nano_rw=MaxxVitCfg(
|
||||
embed_dim=(64, 128, 256, 512),
|
||||
depths=(1, 2, 3, 1),
|
||||
block_type=('M',) * 4,
|
||||
stem_width=(32, 64),
|
||||
**_rw_max_cfg(rel_pos_type='mlp'),
|
||||
),
|
||||
maxvit_rmlp_tiny_rw=MaxxVitCfg(
|
||||
embed_dim=(64, 128, 256, 512),
|
||||
depths=(2, 2, 5, 2),
|
||||
block_type=('M',) * 4,
|
||||
stem_width=(32, 64),
|
||||
**_rw_max_cfg(rel_pos_type='mlp'),
|
||||
),
|
||||
maxvit_rmlp_small_rw=MaxxVitCfg(
|
||||
embed_dim=(96, 192, 384, 768),
|
||||
depths=(2, 2, 5, 2),
|
||||
block_type=('M',) * 4,
|
||||
stem_width=(32, 64),
|
||||
**_rw_max_cfg(
|
||||
rel_pos_type='mlp',
|
||||
init_values=1e-6,
|
||||
),
|
||||
),
|
||||
maxvit_rmlp_base_rw=MaxxVitCfg(
|
||||
embed_dim=(96, 192, 384, 768),
|
||||
depths=(2, 6, 14, 2),
|
||||
block_type=('M',) * 4,
|
||||
stem_width=(32, 64),
|
||||
head_hidden_size=768,
|
||||
**_rw_max_cfg(
|
||||
rel_pos_type='mlp',
|
||||
),
|
||||
),
|
||||
|
||||
maxxvit_rmlp_nano_rw=MaxxVitCfg(
|
||||
embed_dim=(64, 128, 256, 512),
|
||||
depths=(1, 2, 3, 1),
|
||||
block_type=('M',) * 4,
|
||||
@ -1717,33 +1717,50 @@ model_cfgs = dict(
|
||||
weight_init='normal',
|
||||
**_next_cfg(),
|
||||
),
|
||||
maxxvit_rmlp_tiny_rw_256=MaxxVitCfg(
|
||||
maxxvit_rmlp_tiny_rw=MaxxVitCfg(
|
||||
embed_dim=(64, 128, 256, 512),
|
||||
depths=(2, 2, 5, 2),
|
||||
block_type=('M',) * 4,
|
||||
stem_width=(32, 64),
|
||||
**_next_cfg(),
|
||||
),
|
||||
maxxvit_rmlp_small_rw_256=MaxxVitCfg(
|
||||
maxxvit_rmlp_small_rw=MaxxVitCfg(
|
||||
embed_dim=(96, 192, 384, 768),
|
||||
depths=(2, 2, 5, 2),
|
||||
block_type=('M',) * 4,
|
||||
stem_width=(48, 96),
|
||||
**_next_cfg(),
|
||||
),
|
||||
maxxvit_rmlp_base_rw_224=MaxxVitCfg(
|
||||
|
||||
maxxvitv2_nano_rw=MaxxVitCfg(
|
||||
embed_dim=(96, 192, 384, 768),
|
||||
depths=(2, 6, 14, 2),
|
||||
depths=(1, 2, 3, 1),
|
||||
block_type=('M',) * 4,
|
||||
stem_width=(48, 96),
|
||||
**_next_cfg(),
|
||||
weight_init='normal',
|
||||
**_next_cfg(
|
||||
no_block_attn=True,
|
||||
rel_pos_type='bias',
|
||||
),
|
||||
),
|
||||
maxxvit_rmlp_large_rw_224=MaxxVitCfg(
|
||||
maxxvitv2_rmlp_base_rw=MaxxVitCfg(
|
||||
embed_dim=(128, 256, 512, 1024),
|
||||
depths=(2, 6, 12, 2),
|
||||
block_type=('M',) * 4,
|
||||
stem_width=(64, 128),
|
||||
**_next_cfg(),
|
||||
**_next_cfg(
|
||||
no_block_attn=True,
|
||||
),
|
||||
),
|
||||
maxxvitv2_rmlp_large_rw=MaxxVitCfg(
|
||||
embed_dim=(160, 320, 640, 1280),
|
||||
depths=(2, 6, 16, 2),
|
||||
block_type=('M',) * 4,
|
||||
stem_width=(80, 160),
|
||||
head_hidden_size=1280,
|
||||
**_next_cfg(
|
||||
no_block_attn=True,
|
||||
),
|
||||
),
|
||||
|
||||
# Trying to be like the MaxViT paper configs
|
||||
@ -1795,11 +1812,29 @@ model_cfgs = dict(
|
||||
)
|
||||
|
||||
|
||||
def checkpoint_filter_fn(state_dict, model: nn.Module):
|
||||
model_state_dict = model.state_dict()
|
||||
out_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
if k in model_state_dict and v.ndim != model_state_dict[k].ndim and v.numel() == model_state_dict[k].numel():
|
||||
# adapt between conv2d / linear layers
|
||||
assert v.ndim in (2, 4)
|
||||
v = v.reshape(model_state_dict[k].shape)
|
||||
out_dict[k] = v
|
||||
return out_dict
|
||||
|
||||
|
||||
def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs):
|
||||
if cfg_variant is None:
|
||||
if variant in model_cfgs:
|
||||
cfg_variant = variant
|
||||
else:
|
||||
cfg_variant = '_'.join(variant.split('_')[:-1])
|
||||
return build_model_with_cfg(
|
||||
MaxxVit, variant, pretrained,
|
||||
model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
|
||||
model_cfg=model_cfgs[cfg_variant],
|
||||
feature_cfg=dict(flatten_sequential=True),
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
**kwargs)
|
||||
|
||||
|
||||
@ -1815,155 +1850,218 @@ def _cfg(url='', **kwargs):
|
||||
|
||||
|
||||
default_cfgs = generate_default_cfgs({
|
||||
# Fiddling with configs / defaults / still pretraining
|
||||
'coatnet_pico_rw_224': _cfg(url=''),
|
||||
'coatnet_nano_rw_224': _cfg(
|
||||
# timm specific CoAtNet configs, ImageNet-1k pretrain, fixed rel-pos
|
||||
'coatnet_pico_rw_224.untrained': _cfg(url=''),
|
||||
'coatnet_nano_rw_224.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth',
|
||||
crop_pct=0.9),
|
||||
'coatnet_0_rw_224': _cfg(
|
||||
'coatnet_0_rw_224.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'),
|
||||
'coatnet_1_rw_224': _cfg(
|
||||
'coatnet_1_rw_224.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth'
|
||||
),
|
||||
'coatnet_2_rw_224': _cfg(url=''),
|
||||
'coatnet_3_rw_224': _cfg(url=''),
|
||||
|
||||
# Highly experimental configs
|
||||
'coatnet_bn_0_rw_224': _cfg(
|
||||
# timm specific CoAtNet configs, ImageNet-12k pretrain w/ 1k fine-tune, fixed rel-pos
|
||||
'coatnet_2_rw_224.sw_in12k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/'),
|
||||
#'coatnet_3_rw_224.untrained': _cfg(url=''),
|
||||
|
||||
# Experimental CoAtNet configs w/ ImageNet-12k pretrain -> 1k fine-tune (different norm layers, MLP rel-pos)
|
||||
'coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/'),
|
||||
'coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/'),
|
||||
'coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
||||
|
||||
# Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
|
||||
'coatnet_bn_0_rw_224.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
|
||||
crop_pct=0.95),
|
||||
'coatnet_rmlp_nano_rw_224': _cfg(
|
||||
'coatnet_rmlp_nano_rw_224.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth',
|
||||
crop_pct=0.9),
|
||||
'coatnet_rmlp_0_rw_224': _cfg(url=''),
|
||||
'coatnet_rmlp_1_rw_224': _cfg(
|
||||
'coatnet_rmlp_0_rw_224.untrained': _cfg(url=''),
|
||||
'coatnet_rmlp_1_rw_224.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
|
||||
'coatnet_rmlp_1_rw2_224': _cfg(url=''),
|
||||
'coatnet_rmlp_2_rw_224': _cfg(
|
||||
'coatnet_rmlp_2_rw_224.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'),
|
||||
'coatnet_rmlp_3_rw_224': _cfg(url=''),
|
||||
'coatnet_nano_cc_224': _cfg(url=''),
|
||||
'coatnext_nano_rw_224': _cfg(
|
||||
'coatnet_rmlp_3_rw_224.untrained': _cfg(url=''),
|
||||
'coatnet_nano_cc_224.untrained': _cfg(url=''),
|
||||
'coatnext_nano_rw_224.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth',
|
||||
crop_pct=0.9),
|
||||
|
||||
# Trying to be like the CoAtNet paper configs
|
||||
'coatnet_0_224': _cfg(url=''),
|
||||
'coatnet_1_224': _cfg(url=''),
|
||||
'coatnet_2_224': _cfg(url=''),
|
||||
'coatnet_3_224': _cfg(url=''),
|
||||
'coatnet_4_224': _cfg(url=''),
|
||||
'coatnet_5_224': _cfg(url=''),
|
||||
# ImagenNet-12k pretrain CoAtNet
|
||||
'coatnet_2_rw_224.sw_in12k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
num_classes=11821),
|
||||
'coatnet_3_rw_224.sw_in12k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
num_classes=11821),
|
||||
'coatnet_rmlp_1_rw2_224.sw_in12k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
num_classes=11821),
|
||||
'coatnet_rmlp_2_rw_224.sw_in12k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
num_classes=11821),
|
||||
|
||||
# Experimental configs
|
||||
'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
'maxvit_nano_rw_256': _cfg(
|
||||
# Trying to be like the CoAtNet paper configs (will adapt if 'tf' weights are ever released)
|
||||
'coatnet_0_224.untrained': _cfg(url=''),
|
||||
'coatnet_1_224.untrained': _cfg(url=''),
|
||||
'coatnet_2_224.untrained': _cfg(url=''),
|
||||
'coatnet_3_224.untrained': _cfg(url=''),
|
||||
'coatnet_4_224.untrained': _cfg(url=''),
|
||||
'coatnet_5_224.untrained': _cfg(url=''),
|
||||
|
||||
# timm specific MaxVit configs, ImageNet-1k pretrain or untrained
|
||||
'maxvit_pico_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
'maxvit_nano_rw_256.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth',
|
||||
input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
'maxvit_tiny_rw_224': _cfg(
|
||||
'maxvit_tiny_rw_224.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'),
|
||||
'maxvit_tiny_rw_256': _cfg(
|
||||
'maxvit_tiny_rw_256.untrained': _cfg(
|
||||
url='',
|
||||
input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
'maxvit_rmlp_pico_rw_256': _cfg(
|
||||
'maxvit_tiny_pm_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
|
||||
# timm specific MaxVit w/ MLP rel-pos, ImageNet-1k pretrain
|
||||
'maxvit_rmlp_pico_rw_256.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth',
|
||||
input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
'maxvit_rmlp_nano_rw_256': _cfg(
|
||||
'maxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
|
||||
input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
'maxvit_rmlp_tiny_rw_256': _cfg(
|
||||
'maxvit_rmlp_tiny_rw_256.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth',
|
||||
input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
'maxvit_rmlp_small_rw_224': _cfg(
|
||||
'maxvit_rmlp_small_rw_224.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth',
|
||||
crop_pct=0.9,
|
||||
),
|
||||
'maxvit_rmlp_small_rw_256': _cfg(
|
||||
'maxvit_rmlp_small_rw_256.untrained': _cfg(
|
||||
url='',
|
||||
input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
'maxvit_rmlp_base_rw_224': _cfg(
|
||||
url='',
|
||||
|
||||
# timm specific MaxVit w/ ImageNet-12k pretrain and 1k fine-tune
|
||||
'maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
),
|
||||
'maxvit_rmlp_base_rw_384': _cfg(
|
||||
url='',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12)),
|
||||
'maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
||||
|
||||
'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
# timm specific MaxVit w/ ImageNet-12k pretrain
|
||||
'maxvit_rmlp_base_rw_224.sw_in12k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
num_classes=11821,
|
||||
),
|
||||
|
||||
'maxxvit_rmlp_nano_rw_256': _cfg(
|
||||
# timm MaxxViT configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks)
|
||||
'maxxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth',
|
||||
input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
'maxxvit_rmlp_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
'maxxvit_rmlp_small_rw_256': _cfg(
|
||||
'maxxvit_rmlp_tiny_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
'maxxvit_rmlp_small_rw_256.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth',
|
||||
input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
'maxxvit_rmlp_base_rw_224': _cfg(url=''),
|
||||
'maxxvit_rmlp_large_rw_224': _cfg(url=''),
|
||||
|
||||
# timm MaxxViT-V2 configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks, more width, no block attn)
|
||||
'maxxvitv2_nano_rw_256.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||
'maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/'),
|
||||
'maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
||||
'maxxvitv2_rmlp_large_rw_224.untrained': _cfg(url=''),
|
||||
|
||||
'maxxvitv2_rmlp_base_rw_224.sw_in12k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
num_classes=11821),
|
||||
|
||||
# MaxViT models ported from official Tensorflow impl
|
||||
'maxvit_tiny_tf_224.in1k': _cfg(
|
||||
hf_hub_id='timm/maxvit_tiny_tf_224.in1k',
|
||||
hf_hub_id='timm/',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
'maxvit_tiny_tf_384.in1k': _cfg(
|
||||
hf_hub_id='timm/maxvit_tiny_tf_384.in1k',
|
||||
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
||||
'maxvit_tiny_tf_512.in1k': _cfg(
|
||||
hf_hub_id='timm/maxvit_tiny_tf_512.in1k',
|
||||
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
|
||||
'maxvit_small_tf_224.in1k': _cfg(
|
||||
hf_hub_id='timm/maxvit_small_tf_224.in1k',
|
||||
hf_hub_id='timm/',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
'maxvit_small_tf_384.in1k': _cfg(
|
||||
hf_hub_id='timm/maxvit_small_tf_384.in1k',
|
||||
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
||||
'maxvit_small_tf_512.in1k': _cfg(
|
||||
hf_hub_id='timm/maxvit_small_tf_512.in1k',
|
||||
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
|
||||
'maxvit_base_tf_224.in1k': _cfg(
|
||||
hf_hub_id='timm/maxvit_base_tf_224.in1k',
|
||||
hf_hub_id='timm/',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
'maxvit_base_tf_384.in1k': _cfg(
|
||||
hf_hub_id='timm/maxvit_base_tf_384.in1k',
|
||||
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
||||
'maxvit_base_tf_512.in1k': _cfg(
|
||||
hf_hub_id='timm/maxvit_base_tf_512.in1k',
|
||||
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
|
||||
'maxvit_large_tf_224.in1k': _cfg(
|
||||
hf_hub_id='timm/maxvit_large_tf_224.in1k',
|
||||
hf_hub_id='timm/',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
'maxvit_large_tf_384.in1k': _cfg(
|
||||
hf_hub_id='timm/maxvit_large_tf_384.in1k',
|
||||
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
||||
'maxvit_large_tf_512.in1k': _cfg(
|
||||
hf_hub_id='timm/maxvit_large_tf_512.in1k',
|
||||
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
|
||||
|
||||
'maxvit_base_tf_224.in21k': _cfg(
|
||||
url=''),
|
||||
'maxvit_base_tf_384.in21k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/maxvit_base_tf_384.in21k_ft_in1k',
|
||||
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
||||
'maxvit_base_tf_512.in21k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/maxvit_base_tf_512.in21k_ft_in1k',
|
||||
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
|
||||
'maxvit_large_tf_224.in21k': _cfg(
|
||||
url=''),
|
||||
'maxvit_large_tf_384.in21k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/maxvit_large_tf_384.in21k_ft_in1k',
|
||||
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
||||
'maxvit_large_tf_512.in21k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/maxvit_large_tf_512.in21k_ft_in1k',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
|
||||
'maxvit_xlarge_tf_224.in21k': _cfg(
|
||||
url=''),
|
||||
'maxvit_xlarge_tf_384.in21k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/maxvit_xlarge_tf_384.in21k_ft_in1k',
|
||||
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
||||
'maxvit_xlarge_tf_512.in21k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/maxvit_xlarge_tf_512.in21k_ft_in1k',
|
||||
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
|
||||
})
|
||||
|
||||
|
||||
@ -2027,6 +2125,11 @@ def coatnet_rmlp_2_rw_224(pretrained=False, **kwargs):
|
||||
return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def coatnet_rmlp_2_rw_384(pretrained=False, **kwargs):
|
||||
return _create_maxxvit('coatnet_rmlp_2_rw_384', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def coatnet_rmlp_3_rw_224(pretrained=False, **kwargs):
|
||||
return _create_maxxvit('coatnet_rmlp_3_rw_224', pretrained=pretrained, **kwargs)
|
||||
@ -2148,13 +2251,23 @@ def maxxvit_rmlp_small_rw_256(pretrained=False, **kwargs):
|
||||
|
||||
|
||||
@register_model
|
||||
def maxxvit_rmlp_base_rw_224(pretrained=False, **kwargs):
|
||||
return _create_maxxvit('maxxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs)
|
||||
def maxxvitv2_nano_rw_256(pretrained=False, **kwargs):
|
||||
return _create_maxxvit('maxxvitv2_nano_rw_256', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def maxxvit_rmlp_large_rw_224(pretrained=False, **kwargs):
|
||||
return _create_maxxvit('maxxvit_rmlp_large_rw_224', pretrained=pretrained, **kwargs)
|
||||
def maxxvitv2_rmlp_base_rw_224(pretrained=False, **kwargs):
|
||||
return _create_maxxvit('maxxvitv2_rmlp_base_rw_224', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def maxxvitv2_rmlp_base_rw_384(pretrained=False, **kwargs):
|
||||
return _create_maxxvit('maxxvitv2_rmlp_base_rw_384', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def maxxvitv2_rmlp_large_rw_224(pretrained=False, **kwargs):
|
||||
return _create_maxxvit('maxxvitv2_rmlp_large_rw_224', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
|
@ -496,7 +496,7 @@ class RegNet(nn.Module):
|
||||
self.final_conv = get_act_layer(cfg.act_layer)() if final_act else nn.Identity()
|
||||
self.num_features = prev_width
|
||||
self.head = ClassifierHead(
|
||||
in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
|
||||
in_features=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
|
||||
|
||||
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
|
||||
|
||||
|
@ -216,7 +216,7 @@ class XceptionAligned(nn.Module):
|
||||
num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))]
|
||||
self.act = act_layer(inplace=True) if preact else nn.Identity()
|
||||
self.head = ClassifierHead(
|
||||
in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
|
||||
in_features=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
|
||||
|
||||
@torch.jit.ignore
|
||||
def group_matcher(self, coarse=False):
|
||||
|
@ -1 +1 @@
|
||||
__version__ = '0.8.6dev0'
|
||||
__version__ = '0.8.7dev0'
|
||||
|
Loading…
x
Reference in New Issue
Block a user