Expand cls model zoo (#55)

1. Expand the cls model zoo
2. Unify load_pretrained handling across backbones
Chen Jiayu 2022-06-01 11:01:29 +08:00 committed by GitHub
parent d76d5d79fc
commit 3ee118f065
125 changed files with 903 additions and 736 deletions
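The change that repeats across the backbone files below is the second point of the commit message: every backbone's `init_weights(self, pretrained=None)` loses its `pretrained` argument and its in-place `load_checkpoint` branch, leaving only random initialization, while pretrained loading is driven uniformly by `model.pretrained` in the configs. A minimal sketch of the resulting pattern, using plain `torch.nn.init` calls in place of the mmcv helpers; `ToyBackbone` is illustrative, not a class from this commit:

```python
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm


class ToyBackbone(nn.Module):
    """Illustrative backbone showing the unified init_weights pattern."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3)
        self.bn = nn.BatchNorm2d(16)

    # Before this commit, backbones accepted `pretrained` here and loaded
    # a checkpoint themselves; now init_weights() only randomizes weights,
    # and checkpoint loading happens once, outside the backbone.
    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_in', nonlinearity='relu')
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
```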

View File

@@ -2,7 +2,6 @@ _base_ = '../../base.py'
 # model settings
 model = dict(
     type='Classification',
-    pretrained=None,
     backbone=dict(
         type='ResNet',
         depth=50,

View File

@@ -2,7 +2,6 @@ _base_ = '../../base.py'
 # model settings
 model = dict(
     type='Classification',
-    pretrained=None,
     backbone=dict(
         type='PytorchImageModelWrapper',
         model_name='swin_tiny_patch4_window7_224',

View File

@@ -7,7 +7,6 @@ log_config = dict(
 # model settings
 model = dict(
     type='Classification',
-    pretrained=None,
     backbone=dict(type='HRNet', arch='w18', multi_scale_output=True),
     neck=dict(type='HRFuseScales', in_channels=(18, 36, 72, 144)),
     head=dict(

View File

@@ -1,3 +1,5 @@
 _base_ = './hrnetw18_b32x8_100e_jpg.py'
 # model settings
-model = dict(backbone=dict(type='HRNet', arch='w30', multi_scale_output=True))
+model = dict(
+    backbone=dict(type='HRNet', arch='w30', multi_scale_output=True),
+    neck=dict(type='HRFuseScales', in_channels=(30, 60, 120, 240)))

View File

@@ -1,3 +1,5 @@
 _base_ = './hrnetw18_b32x8_100e_jpg.py'
 # model settings
-model = dict(backbone=dict(type='HRNet', arch='w32', multi_scale_output=True))
+model = dict(
+    backbone=dict(type='HRNet', arch='w32', multi_scale_output=True),
+    neck=dict(type='HRFuseScales', in_channels=(32, 64, 128, 256)))

View File

@@ -1,3 +1,5 @@
 _base_ = './hrnetw18_b32x8_100e_jpg.py'
 # model settings
-model = dict(backbone=dict(type='HRNet', arch='w40', multi_scale_output=True))
+model = dict(
+    backbone=dict(type='HRNet', arch='w40', multi_scale_output=True),
+    neck=dict(type='HRFuseScales', in_channels=(40, 80, 160, 320)))

View File

@@ -1,3 +1,5 @@
 _base_ = './hrnetw18_b32x8_100e_jpg.py'
 # model settings
-model = dict(backbone=dict(type='HRNet', arch='w44', multi_scale_output=True))
+model = dict(
+    backbone=dict(type='HRNet', arch='w44', multi_scale_output=True),
+    neck=dict(type='HRFuseScales', in_channels=(44, 88, 176, 352)))

View File

@@ -1,4 +1,6 @@
 _base_ = './hrnetw18_b32x8_100e_jpg.py'
 # model settings
-model = dict(backbone=dict(type='HRNet', arch='w48', multi_scale_output=True))
+model = dict(
+    backbone=dict(type='HRNet', arch='w48', multi_scale_output=True),
+    neck=dict(type='HRFuseScales', in_channels=(48, 96, 192, 384)))

View File

@@ -1,3 +1,5 @@
 _base_ = './hrnetw18_b32x8_100e_jpg.py'
 # model settings
-model = dict(backbone=dict(type='HRNet', arch='w64', multi_scale_output=True))
+model = dict(
+    backbone=dict(type='HRNet', arch='w64', multi_scale_output=True),
+    neck=dict(type='HRFuseScales', in_channels=(64, 128, 256, 512)))
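Across the HRNet variants above, the added `HRFuseScales` neck's `in_channels` follow directly from the arch width: a `wN` HRNet emits four branches with N, 2N, 4N and 8N channels. A small illustrative helper (not part of the commit) that reproduces the tuples in these configs:

```python
# Illustrative: HRFuseScales in_channels for HRNet arch 'wN' are
# (N, 2N, 4N, 8N), matching the configs above.
def hrnet_fuse_channels(arch: str) -> tuple:
    width = int(arch.lstrip('w'))  # 'w30' -> 30
    return tuple(width * 2**i for i in range(4))


assert hrnet_fuse_channels('w30') == (30, 60, 120, 240)
assert hrnet_fuse_channels('w64') == (64, 128, 256, 512)
```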

View File

@@ -2,18 +2,8 @@ _base_ = './resnet50_b32x8_100e_jpg.py'
 # model settings
 model = dict(
     type='Classification',
-    pretrained=None,
     backbone=dict(
         type='ResNet',
         depth=101,
         out_indices=[4],  # 0: conv-1, x: stage-x
-        norm_cfg=dict(type='BN')),
-    head=dict(
-        type='ClsHead',
-        with_avg_pool=True,
-        in_channels=2048,
-        loss_config=dict(
-            type='CrossEntropyLossWithLabelSmooth',
-            label_smooth=0,
-        ),
-        num_classes=1000))
+        norm_cfg=dict(type='BN')))

View File

@@ -1,19 +1,8 @@
 _base_ = './resnet50_b32x8_100e_jpg.py'
 # model settings
 model = dict(
-    type='Classification',
-    pretrained=None,
     backbone=dict(
         type='ResNet',
         depth=50,
         out_indices=[4],  # 0: conv-1, x: stage-x
-        norm_cfg=dict(type='BN')),
-    head=dict(
-        type='ClsHead',
-        with_avg_pool=True,
-        in_channels=2048,
-        loss_config=dict(
-            type='CrossEntropyLossWithLabelSmooth',
-            label_smooth=0,
-        ),
-        num_classes=1000))
+        norm_cfg=dict(type='BN')))

View File

@@ -7,7 +7,6 @@ log_config = dict(
 # model settings
 model = dict(
     type='Classification',
-    pretrained=None,
     backbone=dict(
         type='ResNet',
         depth=50,

View File

@@ -7,7 +7,6 @@ log_config = dict(
 # model settings
 model = dict(
     type='Classification',
-    pretrained=None,
     backbone=dict(
         type='ResNeXt',
         depth=50,

View File

@@ -16,7 +16,6 @@ model = dict(
         mode='batch',
         label_smoothing=0.1,
         num_classes=1000),
-    pretrained=None,
     backbone=dict(
         type='PytorchImageModelWrapper',
         model_name='swin_tiny_patch4_window7_224',

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='cait_s24_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='cait_xxs24_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='cait_xxs36_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='coat_mini'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='coat_tiny'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='convit_base'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='convit_small'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='convit_tiny'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='convmixer_1024_20_ks9_p14'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='convmixer_1536_20'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='convmixer_768_32'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='convnext_base'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='convnext_large'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='convnext_small'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='convnext_tiny'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='crossvit_base_240'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='crossvit_small_240'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='crossvit_tiny_240'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='deit_base_distilled_patch16_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='deit_base_patch16_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='gmixer_24_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='gmlp_s16_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='levit_128'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='levit_192'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='levit_256'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='mixer_b16_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='mixer_l16_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='mobilevit_s'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='mobilevit_xs'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='mobilevit_xxs'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='jx_nest_base'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='jx_nest_small'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='jx_nest_tiny'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='pit_b_distilled_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='pit_s_distilled_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='poolformer_m36'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='poolformer_m48'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='poolformer_s12'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='poolformer_s24'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='poolformer_s36'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='resmlp_12_distilled_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='resmlp_24_distilled_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='resmlp_36_distilled_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='resmlp_big_24_distilled_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='sequencer2d_l'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='sequencer2d_m'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='sequencer2d_s'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='shuffletrans_base_p4_w7_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='shuffletrans_small_p4_w7_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='shuffletrans_tiny_p4_w7_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='dynamic_swin_small_p4_w7_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='dynamic_swin_tiny_p4_w7_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='swin_base_patch4_window7_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='swin_large_patch4_window7_224'))

View File

@@ -1,45 +1,48 @@
 _base_ = 'configs/base.py'

 log_config = dict(
-    interval=100,
+    interval=10,
     hooks=[dict(type='TextLoggerHook'),
            dict(type='TensorboardLoggerHook')])

 # model settings
 model = dict(
     type='Classification',
-    pretrained=None,
+    train_preprocess=['mixUp'],
+    pretrained=True,
+    mixup_cfg=dict(
+        mixup_alpha=0.2,
+        prob=1.0,
+        mode='batch',
+        label_smoothing=0.1,
+        num_classes=1000),
     backbone=dict(
         type='PytorchImageModelWrapper',
-        # model_name='pit_xs_distilled_224',
-        # model_name='swin_small_patch4_window7_224',
-        # model_name='swin_tiny_patch4_window7_224',
-        # model_name='swin_base_patch4_window7_224_in22k',
-        model_name='vit_deit_small_distilled_patch16_224',
-        # model_name = 'vit_deit_small_distilled_patch16_224',
-        # model_name = 'resnet50',
+        model_name='vit_base_patch16_224',
         num_classes=1000,
-        pretrained=True,
     ),
     head=dict(
-        # type='ClsHead', with_avg_pool=True, in_channels=384,
-        # type='ClsHead', with_avg_pool=True, in_channels=768,
-        # type='ClsHead', with_avg_pool=True, in_channels=1024,
-        # num_classes=0)
         type='ClsHead',
-        loss_config={
-            'type': 'SoftTargetCrossEntropy',
-        },
         with_fc=False))

-data_train_list = 'data/imagenet_raw/meta/train_labeled.txt'
-data_train_root = 'data/imagenet_raw/train/'
-data_test_list = 'data/imagenet_raw/meta/val_labeled.txt'
-data_test_root = 'data/imagenet_raw/validation/'
-data_all_list = 'data/imagenet_raw/meta/all_labeled.txt'
-data_root = 'data/imagenet_raw/'
+load_from = None
+
+data_train_list = './imagenet_raw/meta/train_labeled.txt'
+data_train_root = './imagenet_raw/train/'
+data_test_list = './imagenet_raw/meta/val_labeled.txt'
+data_test_root = './imagenet_raw/val/'
+data_all_list = './imagenet_raw/meta/all_labeled.txt'
+data_root = './imagenet_raw/'

 dataset_type = 'ClsDataset'
 img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 train_pipeline = [
     dict(type='RandomResizedCrop', size=224),
     dict(type='RandomHorizontalFlip'),
+    dict(type='MMAutoAugment'),
     dict(type='ToTensor'),
     dict(type='Normalize', **img_norm_cfg),
     dict(type='Collect', keys=['img', 'gt_labels'])

@@ -53,7 +56,7 @@ test_pipeline = [
 ]
 data = dict(
-    imgs_per_gpu=32,  # total 256
+    imgs_per_gpu=64,  # total 256
     workers_per_gpu=8,
     train=dict(
         type=dataset_type,

@@ -84,11 +87,25 @@ eval_pipelines = [
 custom_hooks = []
 # optimizer
-optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
+optimizer = dict(
+    type='AdamW',
+    lr=0.003,
+    weight_decay=0.3,
+    paramwise_options={
+        'cls_token': dict(weight_decay=0.),
+        'pos_embed': dict(weight_decay=0.),
+    })
+optimizer_config = dict(grad_clip=dict(max_norm=1.0), update_interval=8)
 # learning policy
-lr_config = dict(policy='step', step=[30, 60, 90])
-checkpoint_config = dict(interval=10)
+lr_config = dict(
+    policy='CosineAnnealing',
+    min_lr=0,
+    warmup='linear',
+    warmup_iters=10000,
+    warmup_ratio=1e-4,
+)
+checkpoint_config = dict(interval=30)
 # runtime settings
 total_epochs = 90
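The per-model files added below are thin overrides of this shared base: each sets `_base_ = '../timm_config.py'` and replaces only `model.backbone.model_name`. A hedged sketch of how such a config composes, assuming the standard mmcv `Config` loader (the exact EasyCV entry point may differ):

```python
# Sketch: resolving one of the model-zoo configs below against this base.
# Assumes the standard mmcv Config loader; EasyCV's tools may wrap it.
from mmcv import Config

cfg = Config.fromfile(
    'configs/classification/imagenet/timm/vit/vit_base_patch16_224.py')
print(cfg.model.backbone.model_name)  # 'vit_base_patch16_224'
print(cfg.optimizer.type)             # 'AdamW', inherited from timm_config.py

# One-off override without editing any file:
cfg.merge_from_dict({'model.backbone.model_name': 'convnext_tiny'})
```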

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='tnt_s_patch16_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='twins_svt_base'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='twins_svt_large'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='twins_svt_small'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='vit_base_patch16_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='vit_large_patch16_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='xcit_large_24_p8_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='xcit_large_24_p8_224_dist'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='xcit_medium_24_p8_224'))

View File

@ -0,0 +1,4 @@
_base_ = '../timm_config.py'
# model settings
model = dict(backbone=dict(model_name='xcit_medium_24_p8_224_dist'))

View File

@@ -15,7 +15,6 @@ model = dict(
         mode='batch',
         label_smoothing=0.1,
         num_classes=1000),
-    pretrained=None,
     backbone=dict(
         type='PytorchImageModelWrapper',
         model_name='vit_base_patch16_224',

View File

@@ -12,7 +12,6 @@ export = dict(export_neck=True)
 # model settings
 model = dict(
     type='Classification',
-    pretrained=None,
     backbone=dict(
         type='ResNet',
         depth=50,

View File

@@ -19,7 +19,6 @@ export = dict(export_neck=True)
 # model settings
 model = dict(
     type='Classification',
-    pretrained=None,
     backbone=dict(
         type='PytorchImageModelWrapper',
         # model_name='pit_xs_distilled_224',

View File

@@ -19,7 +19,6 @@ export = dict(export_neck=True)
 # model settings
 model = dict(
     type='Classification',
-    pretrained=None,
     backbone=dict(
         type='ResNet',
         depth=50,

View File

@@ -44,7 +44,6 @@ work_dir = 'oss://path/to/work_dirs/classification/'
 # model settings
 model = dict(
     type='Classification',
-    # pretrained=None,
     backbone=dict(
         type='PytorchImageModelWrapper',
View File

@@ -49,7 +49,6 @@ work_dir = 'oss://path/to/work_dirs/classification/'
 # model settings
 model = dict(
     type='Classification',
-    # pretrained=None,
     backbone=dict(
         type='PytorchImageModelWrapper',

View File

@@ -7,7 +7,6 @@ log_config = dict(
 # model settings
 model = dict(
     type='Classification',
-    pretrained=None,
     backbone=dict(
         type='ResNet',
         depth=50,

View File

@@ -7,7 +7,6 @@ log_config = dict(
 # model settings
 model = dict(
     type='Classification',
-    pretrained=None,
     backbone=dict(
         type='ResNet',
         depth=50,

View File

@@ -13,9 +13,9 @@ load_from = None
 # model settings
 model = dict(
     type='Classification',
-    train_preprocess=['randomErasing'],
     pretrained=
     'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/timm/swin_base_patch4_window7_224_22k_statedict.pth',
+    train_preprocess=['randomErasing'],
     backbone=dict(
         type='PytorchImageModelWrapper',
         model_name='swin_base_patch4_window7_224_in22k'

View File

@@ -2,22 +2,77 @@
 ## Benchmarks

| Algorithm | Config | Top-1 (%) | Top-5 (%) | gpu memory (MB) | inference time (ms/img) | Download |
| --------- | ------------------------------------------------------------ | --------- | --------- | --------- | --------- | ------------------------------------------------------------ |
| resnet50(raw) | [resnet50(raw)](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py) | 76.454 | 93.084 | 2412 | 8.59 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/resnet50/epoch_100.pth) |
| resnet50(tfrecord) | [resnet50(tfrecord)](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnet/imagenet_rn50_tfrecord.py) | 76.266 | 92.972 | 2412 | 8.59 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/resnet50/epoch_100.pth) |
| resnet101 | [resnet101](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnet/imagenet_resnet101_jpg.py) | 78.152 | 93.922 | 2484 | 16.77 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/resnet101/epoch_100.pth) |
| resnet152 | [resnet152](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnet/imagenet_resnet152_jpg.py) | 78.544 | 94.206 | 2544 | 24.69 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/resnet152/epoch_100.pth) |
| resnext50-32x4d | [resnext50-32x4d](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnext/imagenet_resnext50-32x4d_jpg.py) | 77.604 | 93.856 | 4718 | 12.88 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnet50/epoch_100.pth) |
| resnext101-32x4d | [resnext101-32x4d](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnext/imagenet_resnext101-32x4d_jpg.py) | 78.568 | 94.344 | 4792 | 26.84 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext50-32x4d/epoch_100.pth) |
| resnext101-32x8d | [resnext101-32x8d](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnext/imagenet_resnext101-32x8d_jpg.py) | 79.468 | 94.434 | 9582 | 27.52 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext101-32x8d/epoch_100.pth) |
| resnext152-32x4d | [resnext152-32x4d](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnext/imagenet_resnext152-32x4d_jpg.py) | 78.994 | 94.462 | 4852 | 41.08 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext152-32x4d/epoch_100.pth) |
| hrnetw18 | [hrnetw18](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw18_jpg.py) | 76.258 | 92.976 | 4701 | 54.55 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw18/epoch_100.pth) |
| hrnetw30 | [hrnetw30](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw30_jpg.py) | 77.66 | 93.862 | 4766 | 54.30 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw30/epoch_100.pth) |
| hrnetw32 | [hrnetw32](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw32_jpg.py) | 77.994 | 93.976 | 4780 | 53.48 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw32/epoch_100.pth) |
| hrnetw40 | [hrnetw40](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw40_jpg.py) | 78.142 | 93.956 | 4843 | 54.31 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw40/epoch_100.pth) |
| hrnetw44 | [hrnetw44](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw44_jpg.py) | 79.266 | 94.476 | 4884 | 54.83 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw44/epoch_100.pth) |
| hrnetw48 | [hrnetw48](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw48_jpg.py) | 79.636 | 94.802 | 4916 | 54.14 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw48/epoch_100.pth) |
| hrnetw64 | [hrnetw64](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw64_jpg.py) | 79.884 | 95.04 | 5120 | 54.74 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/hrnetw64/epoch_100.pth) |
| vit-base-patch16 | [vit-base-patch16](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/vit/imagenet_vit_base_patch16_224_jpg.py) | 76.082 | 92.026 | 346 | 8.03 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/vit/vit-base-patch16/epoch_300.pth) |
| swin-tiny-patch4-window7 | [swin-tiny-patch4-window7](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/swint/imagenet_swin_tiny_patch4_window7_224_jpg.py) | 80.528 | 94.822 | 132 | 12.94 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/swint/swin-tiny-patch4-window7/epoch_300.pth) |

(Note: these models were trained with EasyCV; inference input size defaults to 224, the benchmark machine defaults to a V100 16G, and "gpu memory" records GPU peak memory.)
| Algorithm | Config | Top-1 (%) | Top-5 (%) | gpu memory (MB) | inference time (ms/img) | Download |
| --------- | ------------------------------------------------------------ | --------- | --------- | --------- | --------- | ------------------------------------------------------------ |
| vit_base_patch16_224 | [vit_base_patch16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/vit/vit_base_patch16_224.py) | 78.096 | 94.324 | 346 | 8.03 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/vit/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz) |
| vit_large_patch16_224 | [vit_large_patch16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/vit/vit_large_patch16_224.py) | 84.404 | 97.276 | 1171 | 16.30 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/vit/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz) |
| deit_base_patch16_224 | [deit_base_patch16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/deit/deit_base_patch16_224.py) | 81.756 | 95.6 | 346 | 7.98 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/deit/deit_base_patch16_224-b5f2ef4d.pth) |
| deit_base_distilled_patch16_224 | [deit_base_distilled_patch16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/deit/deit_base_distilled_patch16_224.py) | 83.232 | 96.476 | 349 | 8.07 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/deit/deit_base_distilled_patch16_224-df68dfff.pth) |
| xcit_medium_24_p8_224 | [xcit_medium_24_p8_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/xcit/xcit_medium_24_p8_224.py) | 83.348 | 96.21 | 884 | 31.77 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/xcit/xcit_medium_24_p8_224.pth) |
| xcit_medium_24_p8_224_dist | [xcit_medium_24_p8_224_dist](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/xcit/xcit_medium_24_p8_224_dist.py) | 84.876 | 97.164 | 884 | 32.08 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/xcit/xcit_medium_24_p8_224_dist.pth) |
| xcit_large_24_p8_224 | [xcit_large_24_p8_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/xcit/xcit_large_24_p8_224.py) | 83.986 | 96.47 | 1962 | 37.44 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/xcit/xcit_large_24_p8_224.pth) |
| xcit_large_24_p8_224_dist | [xcit_large_24_p8_224_dist](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/xcit/xcit_large_24_p8_224_dist.py) | 85.022 | 97.29 | 1962 | 37.44 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/xcit/xcit_large_24_p8_224_dist.pth) |
| tnt_s_patch16_224 | [tnt_s_patch16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/tnt/tnt_s_patch16_224.py) | 76.934 | 93.388 | 100 | 18.92 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/tnt/tnt_s_patch16_224.pth.tar) |
| convit_tiny | [convit_tiny](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convit/convit_tiny.py) | 72.954 | 91.68 | 31 | 10.79 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convit/convit_tiny.pth) |
| convit_small | [convit_small](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convit/convit_small.py) | 81.342 | 95.784 | 122 | 11.23 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convit/convit_small.pth) |
| convit_base | [convit_base](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convit/convit_base.py) | 82.27 | 95.916 | 358 | 11.26 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convit/convit_base.pth) |
| cait_xxs24_224 | [cait_xxs24_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/cait/cait_xxs24_224.py) | 78.45 | 94.154 | 50 | 22.62 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/cait/XXS24_224.pth) |
| cait_xxs36_224 | [cait_xxs36_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/cait/cait_xxs36_224.py) | 79.788 | 94.87 | 71 | 33.25 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/cait/XXS36_224.pth) |
| cait_s24_224 | [cait_s24_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/cait/cait_s24_224.py) | 83.302 | 96.568 | 190 | 23.74 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/cait/cait_s24_224.pth) |
| levit_128 | [levit_128](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/levit/levit_128.py) | 78.468 | 93.874 | 76 | 15.33 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/levit/LeViT-128-b88c2750.pth) |
| levit_192 | [levit_192](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/levit/levit_192.py) | 79.72 | 94.664 | 128 | 15.17 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/levit/LeViT-192-92712e41.pth) |
| levit_256 | [levit_256](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/levit/levit_256.py) | 81.432 | 95.38 | 222 | 15.27 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/levit/LeViT-256-13b5763e.pth) |
| convnext_tiny | [convnext_tiny](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convnext/convnext_tiny.py) | 81.878 | 95.836 | 128 | 7.17 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convnext/convnext_tiny_1k_224_ema.pth) |
| convnext_small | [convnext_small](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convnext/convnext_small.py) | 82.836 | 96.458 | 213 | 12.89 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convnext/convnext_small_1k_224_ema.pth) |
| convnext_base | [convnext_base](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convnext/convnext_base.py) | 83.73 | 96.692 | 364 | 13.04 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convnext/convnext_base_1k_224_ema.pth) |
| convnext_large | [convnext_large](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convnext/convnext_large.py) | 84.164 | 96.844 | 781 | 13.78 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convnext/convnext_large_1k_224_ema.pth) |
| resmlp_12_distilled_224 | [resmlp_12_distilled_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/resmlp/resmlp_12_distilled_224.py) | 77.876 | 93.532 | 66 | 4.90 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/resmlp/resmlp_12_dist.pth) |
| resmlp_24_distilled_224 | [resmlp_24_distilled_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/resmlp/resmlp_24_distilled_224.py) | 80.548 | 95.204 | 124 | 9.07 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/resmlp/resmlp_24_dist.pth) |
| resmlp_36_distilled_224 | [resmlp_36_distilled_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/resmlp/resmlp_36_distilled_224.py) | 80.944 | 95.416 | 181 | 13.56 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/resmlp/resmlp_36_dist.pth) |
| resmlp_big_24_distilled_224 | [resmlp_big_24_distilled_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/resmlp/resmlp_big_24_distilled_224.py) | 83.45 | 96.65 | 534 | 20.48 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/resmlp/resmlpB_24_dist.pth) |
| coat_tiny | [coat_tiny](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/coat/coat_tiny.py) | 78.112 | 93.972 | 127 | 33.09 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/coat/coat_tiny-473c2a20.pth) |
| coat_mini | [coat_mini](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/coat/coat_mini.py) | 80.912 | 95.378 | 247 | 33.29 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/coat/coat_mini-2c6baf49.pth) |
| convmixer_768_32 | [convmixer_768_32](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convmixer/convmixer_768_32.py) | 80.08 | 94.992 | 4995 | 10.23 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convmixer/convmixer_768_32_ks7_p7_relu.pth.tar) |
| convmixer_1024_20_ks9_p14 | [convmixer_1024_20_ks9_p14](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convmixer/convmixer_1024_20_ks9_p14.py) | 81.742 | 95.578 | 2407 | 6.29 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convmixer/convmixer_1024_20_ks9_p14.pth.tar) |
| convmixer_1536_20 | [convmixer_1536_20](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convmixer/convmixer_1536_20.py) | 81.432 | 95.38 | 547 | 14.66 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convmixer/convmixer_1536_20_ks9_p7.pth.tar) |
| gmixer_24_224 | [gmixer_24_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/gmixer/gmixer_24_224.py) | 78.088 | 93.6 | 104 | 11.65 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/gmixer/gmixer_24_224_raa-7daf7ae6.pth) |
| gmlp_s16_224 | [gmlp_s16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/gmlp/gmlp_s16_224.py) | 77.204 | 93.358 | 81 | 11.15 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/gmlp/gmlp_s16_224.pth) |
| mixer_b16_224 | [mixer_b16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/mlp-mixer/mixer_b16_224.py) | 72.558 | 90.068 | 241 | 5.37 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/mlp-mixer/jx_mixer_b16_224-76587d61.pth) |
| mixer_l16_224 | [mixer_l16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/mlp-mixer/mixer_l16_224.py) | 68.34 | 86.11 | 804 | 11.74 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/mlp-mixer/jx_mixer_l16_224-92f9adc4.pth) |
| jx_nest_tiny | [jx_nest_tiny](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/nest/jx_nest_tiny.py) | 81.278 | 95.618 | 90 | 9.05 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/nest/jx_nest_tiny-e3428fb9.pth) |
| jx_nest_small | [jx_nest_small](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/nest/jx_nest_small.py) | 83.144 | 96.3 | 174 | 16.92 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/nest/jx_nest_small-422eaded.pth) |
| jx_nest_base | [jx_nest_base](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/nest/jx_nest_base.py) | 83.474 | 96.442 | 300 | 16.88 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/nest/jx_nest_base-8bc41011.pth) |
| pit_s_distilled_224 | [pit_s_distilled_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/pit/pit_s_distilled_224.py) | 83.144 | 96.3 | 109 | 7.00 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/pit/pit_s_distill_819.pth) |
| pit_b_distilled_224 | [pit_b_distilled_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/pit/pit_b_distilled_224.py) | 83.474 | 96.442 | 330 | 7.66 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/pit/pit_b_distill_840.pth) |
| twins_svt_small | [twins_svt_small](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/twins/twins_svt_small.py) | 81.598 | 95.55 | 657 | 14.07 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/twins/twins_svt_small-42e5f78c.pth) |
| twins_svt_base | [twins_svt_base](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/twins/twins_svt_base.py) | 82.882 | 96.234 | 1447 | 18.99 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/twins/twins_svt_base-c2265010.pth) |
| twins_svt_large | [twins_svt_large](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/twins/twins_svt_large.py) | 83.428 | 96.506 | 2567 | 19.11 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/twins/twins_svt_large-90f6aaa9.pth) |
| swin_base_patch4_window7_224 | [swin_base_patch4_window7_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/swint/swin_base_patch4_window7_224.py) | 84.714 | 97.444 | 375 | 23.47 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_base_patch4_window7_224_22kto1k.pth) |
| swin_large_patch4_window7_224 | [swin_large_patch4_window7_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/swint/swin_large_patch4_window7_224.py) | 85.826 | 97.816 | 788 | 23.29 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_large_patch4_window7_224_22kto1k.pth) |
| dynamic_swin_small_p4_w7_224 | [dynamic_swin_small_p4_w7_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/swint/dynamic_small_base_p4_w7_224.py) | 82.896 | 96.234 | 220 | 28.55 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_small_patch4_window7_224_statedict.pth) |
| dynamic_swin_tiny_p4_w7_224 | [dynamic_swin_tiny_p4_w7_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/swint/dynamic_swin_tiny_p4_w7_224.py) | 80.912 | 95.41 | 136 | 14.58 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_tiny_patch4_window7_224_statedict.pth) |
| shuffletrans_tiny_p4_w7_224 | [shuffletrans_tiny_p4_w7_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/shuffle_transformer/shuffletrans_tiny_p4_w7_224.py) | 82.176 | 96.05 | 5311 | 13.90 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/shuffle_transformer/shuffle_tiny.pth) |
(Note: these results were obtained by importing the official models, which requires torch.__version__ >= 1.9.0; inference input size defaults to 224, the benchmark machine defaults to a V100 16G, and "gpu memory" records GPU peak memory.)
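The gpu memory and inference time columns in both tables are the kind of numbers produced by timing single-image forward passes and reading CUDA's peak-memory counter; a hedged sketch of such a measurement (not the exact script behind these tables):

```python
# Sketch of how per-image inference time and GPU peak memory are
# typically measured (V100, 224x224 input); illustrative only.
import time

import timm
import torch

model = timm.create_model('convnext_tiny', pretrained=False).cuda().eval()
x = torch.randn(1, 3, 224, 224, device='cuda')

torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    for _ in range(10):  # warmup iterations
        model(x)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(100):
        model(x)
    torch.cuda.synchronize()

print('inference time (ms/img):', (time.time() - start) / 100 * 1000)
print('gpu peak memory (MB):', torch.cuda.max_memory_allocated() / 1024**2)
```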

View File

@@ -18,11 +18,8 @@ class BenchMarkMLP(nn.Module):
         self.dropout = nn.Dropout(p=0.5)
         self.pool = nn.AdaptiveAvgPool2d((1, 1))
         self.avg_pool = avg_pool
-        # self.fc2 = nn.Linear(feature_num, num_classes)
-        # self.relu2 = nn.ReLU()
-        # self._initialize_weights()

-    def init_weights(self, pretrained=None):
+    def init_weights(self):
         for m in self.modules():
             if isinstance(m, nn.Linear):
                 nn.init.kaiming_normal_(

@@ -33,7 +30,4 @@ class BenchMarkMLP(nn.Module):
         x = self.pool(x)
         x = self.fc1(x)
         x = self.relu1(x)
-        # x = self.dropout(x)
-        # x = self.fc2(x)
-        # x = self.relu2(x)
         return tuple([x])

View File

@@ -344,33 +344,17 @@ class BNInception(nn.Module):
             1024, 128, kernel_size=(1, 1), stride=(1, 1))
         self.inception_5b_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
         self.inception_5b_relu_pool_proj = nn.ReLU(inplace)
-        # self.last_linear = nn.Linear (1024, num_classes)
         self.num_classes = num_classes
         if num_classes > 0:
             self.last_linear = nn.Linear(1024, num_classes)
-        self.pretrained = model_urls[self.__class__.__name__]

-    def init_weights(self, pretrained=None):
-        if isinstance(pretrained, str) or isinstance(pretrained, dict):
-            logger = get_root_logger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            for m in self.modules():
-                if isinstance(m, nn.Conv2d):
-                    kaiming_init(m, mode='fan_in', nonlinearity='relu')
-                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                    constant_init(m, 1)
-            # if self.zero_init_residual:
-            #     for m in self.modules():
-            #         if isinstance(m, Bottleneck):
-            #             constant_init(m.norm3, 0)
-            #         elif isinstance(m, BasicBlock):
-            #             constant_init(m.norm2, 0)
-        else:
-            raise TypeError('pretrained must be a str or None')
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                kaiming_init(m, mode='fan_in', nonlinearity='relu')
+            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                constant_init(m, 1)

     def features(self, input):
         conv1_7x7_s2_out = self.conv1_7x7_s2(input)
         conv1_7x7_s2_bn_out = self.conv1_7x7_s2_bn(conv1_7x7_s2_out)

View File

@@ -62,7 +62,6 @@ def _fuse_convkx_and_bn_(convkx, bn):
     bn.bias[:] = bn.bias - the_bias_shift
     bn.running_var[:] = 1.0 - bn.eps
     bn.running_mean[:] = 0.0
-    # convkx.register_parameter('bias', bn.bias)
     convkx.bias = nn.Parameter(bn.bias)

@@ -76,7 +75,6 @@ def remove_bn_in_superblock(super_block):
     for block in the_seq_list:
         if isinstance(block, nn.BatchNorm2d):
             _fuse_convkx_and_bn_(last_block, block)
-            # print('--debug fuse shortcut bn')
         else:
             new_seq_list.append(block)
         last_block = block

@@ -92,7 +90,6 @@ def remove_bn_in_superblock(super_block):
     for block in the_seq_list:
         if isinstance(block, nn.BatchNorm2d):
             _fuse_convkx_and_bn_(last_block, block)
-            # print('--debug fuse conv bn')
         else:
             new_seq_list.append(block)
         last_block = block

@@ -1647,7 +1644,6 @@ class PlainNet(nn.Module):
             block_list.pop(-1)
         else:
             self.adptive_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
-            # AdaptiveAvgPool(out_channels=block_list[-1].out_channels, output_size=1)

         self.block_list = block_list
         if not no_create:

@@ -1663,30 +1659,14 @@ class PlainNet(nn.Module):
         self.plainnet_struct = str(self) + str(self.adptive_avg_pool)
         self.zero_init_residual = False
-        self.pretrained = model_urls[self.__class__.__name__ +
-                                     plainnet_struct_idx]

-    def init_weights(self, pretrained=None):
-        if isinstance(pretrained, str) or isinstance(pretrained, dict):
-            logger = get_root_logger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            for m in self.modules():
-                if isinstance(m, nn.Conv2d):
-                    kaiming_init(m, mode='fan_in', nonlinearity='relu')
-                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                    constant_init(m, 1)
-            # if self.zero_init_residual:
-            #     for m in self.modules():
-            #         if isinstance(m, Bottleneck):
-            #             constant_init(m.norm3, 0)
-            #         elif isinstance(m, BasicBlock):
-            #             constant_init(m.norm2, 0)
-        else:
-            raise TypeError('pretrained must be a str or None')
-        return
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                kaiming_init(m, mode='fan_in', nonlinearity='relu')
+            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                constant_init(m, 1)

     def forward(self, x):
         output = x
         for the_block in self.block_list:

@@ -1699,13 +1679,3 @@ class PlainNet(nn.Module):
         output = self.fc_linear(output)
         return [output]

-if __name__ == '__main__':
-    from torch import nn
-
-    input_image_size = 192
-    # model = genet_normal(pretrained=True, root='data/models/GENet_params/')
-    # print(type(model))
-    # print(model.block_list)
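For context on `_fuse_convkx_and_bn_` above: folding a BatchNorm into the preceding conv rescales the conv weights by gamma/sqrt(var + eps) and absorbs the mean shift into the bias, so the BN layer becomes a no-op at inference time. A generic sketch of the same fold (not the EasyCV helper itself):

```python
import torch
import torch.nn as nn


@torch.no_grad()
def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # BN(y) = gamma * (y - mean) / sqrt(var + eps) + beta; fold it into conv:
    # w' = w * gamma / sqrt(var + eps), b' = beta + (b - mean) * gamma / sqrt(var + eps)
    std = (bn.running_var + bn.eps).sqrt()
    scale = bn.weight / std
    conv.weight.mul_(scale.reshape(-1, 1, 1, 1))
    bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    conv.bias = nn.Parameter(bn.bias + (bias - bn.running_mean) * scale)
    return conv
```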

View File

@@ -184,7 +184,6 @@ class Bottleneck(nn.Module):
         if self.with_cp and x.requires_grad:
             raise NotImplementedError
-            # out = cp.checkpoint(_inner_forward, x)
         else:
             out = _inner_forward(x)

@@ -718,17 +717,13 @@ class HRNet(nn.Module):
         return nn.Sequential(*hr_modules), in_channels

-    def init_weights(self, pretrained=None):
+    def init_weights(self):
         """Initialize the weights in backbone.

         Args:
             pretrained (str, optional): Path to pre-trained weights.
                 Defaults to None.
         """
-        if isinstance(pretrained, str):
-            logger = get_root_logger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            for m in self.modules():
-                if isinstance(m, nn.Conv2d):
-                    normal_init(m, std=0.001)
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                normal_init(m, std=0.001)

@@ -741,8 +736,6 @@ class HRNet(nn.Module):
                     constant_init(m.norm3, 0)
                 elif isinstance(m, BasicBlock):
                     constant_init(m.norm2, 0)
-        else:
-            raise TypeError('pretrained must be a str or None')

     def forward(self, x):
         """Forward function."""

View File

@@ -60,13 +60,7 @@ class Inception3(nn.Module):
         if num_classes > 0:
             self.fc = nn.Linear(2048, num_classes)
-        self.pretrained = model_urls[self.__class__.__name__]

-    def init_weights(self, pretrained=None):
-        if pretrained is not None:
-            logger = get_root_logger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        else:
-            for m in self.modules():
-                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):

View File

@@ -958,24 +958,18 @@ class LiteHRNet(nn.Module):
         return nn.Sequential(*modules), in_channels

-    def init_weights(self, pretrained=None):
+    def init_weights(self):
         """Initialize the weights in backbone.

         Args:
             pretrained (str, optional): Path to pre-trained weights.
                 Defaults to None.
         """
-        if isinstance(pretrained, str):
-            logger = get_root_logger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            for m in self.modules():
-                if isinstance(m, nn.Conv2d):
-                    normal_init(m, std=0.001)
-                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                    constant_init(m, 1)
-        else:
-            raise TypeError('pretrained must be a str or None')
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                normal_init(m, std=0.001)
+            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                constant_init(m, 1)

     def forward(self, x):
         """Forward function."""
View File

@@ -73,9 +73,8 @@ class MaskedAutoencoderViT(nn.Module):
         ])
         self.norm = norm_layer(embed_dim)

-        self.apply(self._init_weights)
-
-    def _init_weights(self, m):
+    def init_weights(self):
+        for m in self.modules():
             if isinstance(m, nn.Linear):
                 # we use xavier_uniform following official JAX ViT:
                 torch.nn.init.xavier_uniform_(m.weight)
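MAE swaps the `self.apply(self._init_weights)` callback for the same explicit `init_weights()` loop used by the other backbones. Sketched in full below, following the xavier-uniform convention for `nn.Linear` noted in the context lines; the `nn.LayerNorm` branch is an assumption based on the reference MAE code, not shown in this hunk:

import torch
import torch.nn as nn

def init_weights(module: nn.Module):
    for m in module.modules():
        if isinstance(m, nn.Linear):
            # xavier_uniform, following the official JAX ViT init
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):  # assumed branch, elided in the hunk
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)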

View File

@@ -143,8 +143,6 @@ class MNASNet(torch.nn.Module):
         self.classifier = nn.Sequential(
             nn.Dropout(p=dropout, inplace=True),
             nn.Linear(1280, num_classes))
-        self.init_weights()
-        self.pretrained = model_urls[self.__class__.__name__ + str(alpha)]

     def forward(self, x):
         x = self.layers(x)
@@ -155,11 +153,7 @@ class MNASNet(torch.nn.Module):
         else:
             return [x]

-    def init_weights(self, pretrained=None):
-        if isinstance(pretrained, str) or isinstance(pretrained, dict):
-            logger = get_root_logger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            for m in self.modules():
+    def init_weights(self):
+        for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(
@@ -172,5 +166,3 @@ class MNASNet(torch.nn.Module):
             elif isinstance(m, nn.Linear):
                 nn.init.normal_(m.weight, 0.01)
                 nn.init.zeros_(m.bias)
-        else:
-            raise TypeError('pretrained must be a str or None')

View File

@@ -156,11 +156,7 @@ class MobileNetV2(nn.Module):
         self.pretrained = model_urls[self.__class__.__name__ + '_' +
                                      str(width_multi)]

-    def init_weights(self, pretrained=None):
-        if isinstance(pretrained, str) or isinstance(pretrained, dict):
-            logger = get_root_logger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            for m in self.modules():
+    def init_weights(self):
+        for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, mode='fan_out')
@@ -172,8 +168,6 @@ class MobileNetV2(nn.Module):
             elif isinstance(m, nn.Linear):
                 nn.init.normal_(m.weight, 0, 0.01)
                 nn.init.zeros_(m.bias)
-        else:
-            raise TypeError('pretrained must be a str or None')

     def forward(self, x):
         x = self.features(x)
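MNASNet and MobileNetV2 keep their torchvision-style random init and lose only the pretrained branch. The surviving scheme, sketched below (the BatchNorm branch sits in the elided hunk context and is an assumption here):

import torch.nn as nn

def init_weights(module: nn.Module):
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):  # assumed: elided by the hunk
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            nn.init.zeros_(m.bias)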
x = self.features(x) x = self.features(x)

View File

@@ -1,12 +1,15 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import importlib
 from distutils.version import LooseVersion

 import timm
 import torch
 import torch.nn as nn
+from timm.models.helpers import load_pretrained
+from timm.models.hub import download_cached_file

 from easycv.utils.checkpoint import load_checkpoint
-from easycv.utils.logger import get_root_logger
+from easycv.utils.logger import get_root_logger, print_log

 from ..modelzoo import timm_models as model_urls
 from ..registry import BACKBONES
 from .shuffle_transformer import (shuffletrans_base_p4_w7_224,
@@ -66,8 +69,6 @@ class PytorchImageModelWrapper(nn.Module):

     def __init__(self,
                  model_name='resnet50',
-                 pretrained=False,
-                 checkpoint_path=None,
                  scriptable=None,
                  exportable=None,
                  no_jit=None,
@@ -76,15 +77,16 @@ class PytorchImageModelWrapper(nn.Module):
         Inits PytorchImageModelWrapper by timm.create_models
         Args:
             model_name (str): name of model to instantiate
-            pretrained (bool): load pretrained ImageNet-1k weights if true
-            checkpoint_path (str): path of checkpoint to load after model is initialized
             scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
             exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
             no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)
         """
         super(PytorchImageModelWrapper, self).__init__()

+        self.model_name = model_name
+
         timm_model_names = timm.list_models(pretrained=False)
+        self.timm_model_names = timm_model_names
         assert model_name in timm_model_names or model_name in _MODEL_MAP, \
             f'{model_name} is not in model_list of timm/fair, please check the model_name!'
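With `pretrained` and `checkpoint_path` gone, the constructor always builds the bare architecture (`timm.create_model(..., pretrained=False)`) and records `model_name` and `timm_model_names` so that `init_weights` can resolve weights later. A hedged usage sketch (the import path is assumed from the EasyCV layout, not shown in this diff):

from easycv.models.backbones.pytorch_image_models_wrapper import (
    PytorchImageModelWrapper)

# architecture only -- nothing is downloaded at construction time
backbone = PytorchImageModelWrapper(model_name='swin_tiny_patch4_window7_224')
# weights are resolved from model_urls (or left random if pretrained is falsy)
backbone.init_weights(pretrained=True)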
@@ -94,52 +96,52 @@ class PytorchImageModelWrapper(nn.Module):
         # create model by timm
         if model_name in timm_model_names:
-            try:
-                if pretrained and (model_name in model_urls):
-                    self.model = timm.create_model(model_name, False, '',
-                                                   scriptable, exportable,
-                                                   no_jit, **kwargs)
-                    self.init_weights(model_urls[model_name])
-                    print('Info: Load model from %s' % model_urls[model_name])
-                    if checkpoint_path is not None:
-                        self.init_weights(checkpoint_path)
-                else:
-                    # load from timm
-                    if pretrained and model_name.startswith('swin_') and (
-                            LooseVersion(
-                                torch.__version__) <= LooseVersion('1.6.0')):
-                        print(
-                            'Warning: Pretrained SwinTransformer from timm may be zipfile extract'
-                            ' error while torch<=1.6.0')
-                    self.model = timm.create_model(model_name, pretrained,
-                                                   checkpoint_path, scriptable,
-                                                   exportable, no_jit,
-                                                   **kwargs)
-            # need fix: delete this except after pytorch 1.7 update in all production
-            # (dlc, dsw, studio, ev_predict_py3)
-            except Exception:
-                print(
-                    f'Error: Fail to create {model_name} with (pretrained={pretrained}, checkpoint_path={checkpoint_path} ...)'
-                )
-                print(
-                    f'Try to create {model_name} with pretrained=False, checkpoint_path=None and default params'
-                )
-                self.model = timm.create_model(model_name, False, '', None,
-                                               None, None, **kwargs)
-
-        # facebook model wrapper
-        if model_name in _MODEL_MAP:
+            self.model = timm.create_model(model_name, False, '', scriptable,
+                                           exportable, no_jit, **kwargs)
+        elif model_name in _MODEL_MAP:
             self.model = _MODEL_MAP[model_name](**kwargs)

+    def init_weights(self, pretrained=None):
+        """
+        Args:
+            if pretrained == True, load model from default path;
+            if pretrained == False or None, load from init weights.
+
+            if model_name in timm_model_names, load model from timm default path;
+            if model_name in _MODEL_MAP, load model from easycv default path
+        """
+        logger = get_root_logger()
         if pretrained:
-            if model_name in model_urls.keys():
+            if self.model_name in self.timm_model_names:
+                pretrained_path = model_urls[self.model_name]
+                print_log(
+                    'load model from default path: {}'.format(pretrained_path),
+                    logger)
+                if pretrained_path.endswith('.npz'):
+                    pretrained_loc = download_cached_file(
+                        pretrained_path, check_hash=False, progress=False)
+                    return self.model.load_pretrained(pretrained_loc)
+                else:
+                    backbone_module = importlib.import_module(
+                        self.model.__module__)
+                    return load_pretrained(
+                        self.model,
+                        default_cfg={'url': pretrained_path},
+                        filter_fn=backbone_module.checkpoint_filter_fn
+                        if hasattr(backbone_module, 'checkpoint_filter_fn')
+                        else None)
+            elif self.model_name in _MODEL_MAP:
+                if self.model_name in model_urls.keys():
+                    pretrained_path = model_urls[self.model_name]
+                    print_log(
+                        'load model from default path: {}'.format(
+                            pretrained_path), logger)
                     try_max = 3
                     try_idx = 0
                     while try_idx < try_max:
                         try:
                             state_dict = torch.hub.load_state_dict_from_url(
-                                url=model_urls[model_name],
+                                url=pretrained_path,
                                 map_location='cpu',
                             )
                             try_idx += try_max
@@ -147,31 +149,22 @@ class PytorchImageModelWrapper(nn.Module):
                         except Exception:
                             try_idx += 1
                             state_dict = {}
                             if try_idx == try_max:
-                                print(
-                                    'load from url failed ! oh my DLC & OSS, you boys really good! ',
-                                    model_urls[model_name])
+                                print_log(
+                                    f'load from url failed ! oh my DLC & OSS, you boys really good! {model_urls[self.model_name]}',
+                                    logger)
-                # for some model strict = False still failed when model doesn't exactly match
-                try:
-                    self.model.load_state_dict(state_dict, strict=False)
-                except Exception:
-                    print('load for model_name not all right')
-            else:
-                print('%s not in evtorch modelzoo!' % model_name)
-
-    def init_weights(self, pretrained=None):
-        # pretrained is the path of pretrained model offered by easycv
-        if pretrained is not None:
-            logger = get_root_logger()
-            load_checkpoint(
-                self.model,
-                pretrained,
-                map_location=torch.device('cpu'),
-                strict=False,
-                logger=logger)
-        else:
-            # init by timm
-            pass
+                    if 'model' in state_dict:
+                        state_dict = state_dict['model']
+                    self.model.load_state_dict(state_dict, strict=False)
+                else:
+                    raise ValueError('{} not in evtorch modelzoo!'.format(
+                        self.model_name))
+            else:
+                raise ValueError(
+                    'Error: Fail to create {} with (pretrained={}...)'.format(
+                        self.model_name, pretrained))
+        else:
+            self.model.init_weights()
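The net effect is a single dispatch in `init_weights`: timm-listed models resolve a URL from `model_urls` and load either `.npz` checkpoints via `download_cached_file` plus the model's own `load_pretrained` method (ViT/JAX weights), or regular checkpoints via timm's `load_pretrained` helper with the model module's `checkpoint_filter_fn` when one exists; EasyCV's `_MODEL_MAP` models fall back to `torch.hub.load_state_dict_from_url` with a retry loop. A condensed sketch of that dispatch with error handling trimmed — `load_default_pretrained` and `fetch_state_dict_with_retry` are hypothetical helper names, and the timm helper signatures match the timm 0.5.x-era imports used above:

import importlib

import torch
from timm.models.helpers import load_pretrained
from timm.models.hub import download_cached_file

def load_default_pretrained(model, url):
    """Mirror of the timm branch of init_weights above."""
    if url.endswith('.npz'):
        # JAX/ViT-style checkpoint: cache it locally, then use the
        # model's own .npz loader.
        local_path = download_cached_file(url, check_hash=False, progress=False)
        model.load_pretrained(local_path)
    else:
        # Reuse timm's loader; pick up the defining module's
        # checkpoint_filter_fn (state-dict key remapping) if present.
        module = importlib.import_module(model.__module__)
        filter_fn = getattr(module, 'checkpoint_filter_fn', None)
        load_pretrained(model, default_cfg={'url': url}, filter_fn=filter_fn)

def fetch_state_dict_with_retry(url, try_max=3):
    """Mirror of the while-loop used for _MODEL_MAP models."""
    for _ in range(try_max):
        try:
            state_dict = torch.hub.load_state_dict_from_url(
                url=url, map_location='cpu')
            # some checkpoints nest the weights under a 'model' key
            return state_dict.get('model', state_dict)
        except Exception:
            continue
    return {}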
     def forward(self, x):

View File

@@ -470,17 +470,12 @@ class ResNeSt(nn.Module):
         self.avgpool = GlobalAvgPool2d()
         self.drop = nn.Dropout(final_drop) if final_drop > 0.0 else None
         self.norm_layer = norm_layer
-        # self.fc = nn.Linear(512 * block.expansion, num_classes)
         if num_classes > 0:
             self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
             self.fc = nn.Linear(512 * block.expansion, num_classes)

-    def init_weights(self, pretrained=None):
-        if isinstance(pretrained, str) or isinstance(pretrained, dict):
-            logger = get_root_logger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            for m in self.modules():
+    def init_weights(self):
+        for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
@@ -488,8 +483,6 @@ class ResNeSt(nn.Module):
             elif isinstance(m, self.norm_layer):
                 m.weight.data.fill_(1)
                 m.bias.data.zero_()
-        else:
-            raise TypeError('pretrained must be a str or None')

     def _make_layer(self,
                     block,
@@ -613,7 +606,6 @@ class ResNeSt(nn.Module):
         if hasattr(self, 'fc'):
             x = self.avgpool(x)
-            # x = x.view(x.size(0), -1)
             x = torch.flatten(x, 1)
             if self.drop:
                 x = self.drop(x)

Some files were not shown because too many files have changed in this diff.