mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
Expand cls model zoo (#55)
1. expand cls model zoo 2. uniform load_pretrained
This commit is contained in:
parent
d76d5d79fc
commit
3ee118f065
@ -2,7 +2,6 @@ _base_ = '../../base.py'
|
||||
# model settings
|
||||
model = dict(
|
||||
type='Classification',
|
||||
pretrained=None,
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
|
@ -2,7 +2,6 @@ _base_ = '../../base.py'
|
||||
# model settings
|
||||
model = dict(
|
||||
type='Classification',
|
||||
pretrained=None,
|
||||
backbone=dict(
|
||||
type='PytorchImageModelWrapper',
|
||||
model_name='swin_tiny_patch4_window7_224',
|
||||
|
@ -7,7 +7,6 @@ log_config = dict(
|
||||
# model settings
|
||||
model = dict(
|
||||
type='Classification',
|
||||
pretrained=None,
|
||||
backbone=dict(type='HRNet', arch='w18', multi_scale_output=True),
|
||||
neck=dict(type='HRFuseScales', in_channels=(18, 36, 72, 144)),
|
||||
head=dict(
|
||||
|
@ -1,3 +1,5 @@
|
||||
_base_ = './hrnetw18_b32x8_100e_jpg.py'
|
||||
# model settings
|
||||
model = dict(backbone=dict(type='HRNet', arch='w30', multi_scale_output=True))
|
||||
model = dict(
|
||||
backbone=dict(type='HRNet', arch='w30', multi_scale_output=True),
|
||||
neck=dict(type='HRFuseScales', in_channels=(30, 60, 120, 240)))
|
||||
|
@ -1,3 +1,5 @@
|
||||
_base_ = './hrnetw18_b32x8_100e_jpg.py'
|
||||
# model settings
|
||||
model = dict(backbone=dict(type='HRNet', arch='w32', multi_scale_output=True))
|
||||
model = dict(
|
||||
backbone=dict(type='HRNet', arch='w32', multi_scale_output=True),
|
||||
neck=dict(type='HRFuseScales', in_channels=(32, 64, 128, 256)))
|
||||
|
@ -1,3 +1,5 @@
|
||||
_base_ = './hrnetw18_b32x8_100e_jpg.py'
|
||||
# model settings
|
||||
model = dict(backbone=dict(type='HRNet', arch='w40', multi_scale_output=True))
|
||||
model = dict(
|
||||
backbone=dict(type='HRNet', arch='w40', multi_scale_output=True),
|
||||
neck=dict(type='HRFuseScales', in_channels=(40, 80, 160, 320)))
|
||||
|
@ -1,3 +1,5 @@
|
||||
_base_ = './hrnetw18_b32x8_100e_jpg.py'
|
||||
# model settings
|
||||
model = dict(backbone=dict(type='HRNet', arch='w44', multi_scale_output=True))
|
||||
model = dict(
|
||||
backbone=dict(type='HRNet', arch='w44', multi_scale_output=True),
|
||||
neck=dict(type='HRFuseScales', in_channels=(44, 88, 176, 352)))
|
||||
|
@ -1,4 +1,6 @@
|
||||
_base_ = './hrnetw18_b32x8_100e_jpg.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(type='HRNet', arch='w48', multi_scale_output=True))
|
||||
model = dict(
|
||||
backbone=dict(type='HRNet', arch='w48', multi_scale_output=True),
|
||||
neck=dict(type='HRFuseScales', in_channels=(48, 96, 192, 384)))
|
||||
|
@ -1,3 +1,5 @@
|
||||
_base_ = './hrnetw18_b32x8_100e_jpg.py'
|
||||
# model settings
|
||||
model = dict(backbone=dict(type='HRNet', arch='w64', multi_scale_output=True))
|
||||
model = dict(
|
||||
backbone=dict(type='HRNet', arch='w64', multi_scale_output=True),
|
||||
neck=dict(type='HRFuseScales', in_channels=(64, 128, 256, 512)))
|
||||
|
@ -2,18 +2,8 @@ _base_ = './resnet50_b32x8_100e_jpg.py'
|
||||
# model settings
|
||||
model = dict(
|
||||
type='Classification',
|
||||
pretrained=None,
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=101,
|
||||
out_indices=[4], # 0: conv-1, x: stage-x
|
||||
norm_cfg=dict(type='BN')),
|
||||
head=dict(
|
||||
type='ClsHead',
|
||||
with_avg_pool=True,
|
||||
in_channels=2048,
|
||||
loss_config=dict(
|
||||
type='CrossEntropyLossWithLabelSmooth',
|
||||
label_smooth=0,
|
||||
),
|
||||
num_classes=1000))
|
||||
norm_cfg=dict(type='BN')))
|
||||
|
@ -1,19 +1,8 @@
|
||||
_base_ = './resnet50_b32x8_100e_jpg.py'
|
||||
# model settings
|
||||
model = dict(
|
||||
type='Classification',
|
||||
pretrained=None,
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
out_indices=[4], # 0: conv-1, x: stage-x
|
||||
norm_cfg=dict(type='BN')),
|
||||
head=dict(
|
||||
type='ClsHead',
|
||||
with_avg_pool=True,
|
||||
in_channels=2048,
|
||||
loss_config=dict(
|
||||
type='CrossEntropyLossWithLabelSmooth',
|
||||
label_smooth=0,
|
||||
),
|
||||
num_classes=1000))
|
||||
norm_cfg=dict(type='BN')))
|
||||
|
@ -7,7 +7,6 @@ log_config = dict(
|
||||
# model settings
|
||||
model = dict(
|
||||
type='Classification',
|
||||
pretrained=None,
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
|
@ -7,7 +7,6 @@ log_config = dict(
|
||||
# model settings
|
||||
model = dict(
|
||||
type='Classification',
|
||||
pretrained=None,
|
||||
backbone=dict(
|
||||
type='ResNeXt',
|
||||
depth=50,
|
||||
|
@ -16,7 +16,6 @@ model = dict(
|
||||
mode='batch',
|
||||
label_smoothing=0.1,
|
||||
num_classes=1000),
|
||||
pretrained=None,
|
||||
backbone=dict(
|
||||
type='PytorchImageModelWrapper',
|
||||
model_name='swin_tiny_patch4_window7_224',
|
||||
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='cait_s24_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='cait_xxs24_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='cait_xxs36_224'))
|
4
configs/classification/imagenet/timm/coat/coat_mini.py
Normal file
4
configs/classification/imagenet/timm/coat/coat_mini.py
Normal file
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='coat_mini'))
|
4
configs/classification/imagenet/timm/coat/coat_tiny.py
Normal file
4
configs/classification/imagenet/timm/coat/coat_tiny.py
Normal file
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='coat_tiny'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='convit_base'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='convit_small'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='convit_tiny'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='convmixer_1024_20_ks9_p14'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='convmixer_1536_20'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='convmixer_768_32'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='convnext_base'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='convnext_large'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='convnext_small'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='convnext_tiny'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='crossvit_base_240'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='crossvit_small_240'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='crossvit_tiny_240'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='deit_base_distilled_patch16_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='deit_base_patch16_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='gmixer_24_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='gmlp_s16_224'))
|
4
configs/classification/imagenet/timm/levit/levit_128.py
Normal file
4
configs/classification/imagenet/timm/levit/levit_128.py
Normal file
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='levit_128'))
|
4
configs/classification/imagenet/timm/levit/levit_192.py
Normal file
4
configs/classification/imagenet/timm/levit/levit_192.py
Normal file
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='levit_192'))
|
4
configs/classification/imagenet/timm/levit/levit_256.py
Normal file
4
configs/classification/imagenet/timm/levit/levit_256.py
Normal file
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='levit_256'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='mixer_b16_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='mixer_l16_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='mobilevit_s'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='mobilevit_xs'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='mobilevit_xxs'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='jx_nest_base'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='jx_nest_small'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='jx_nest_tiny'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='pit_b_distilled_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='pit_s_distilled_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='poolformer_m36'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='poolformer_m48'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='poolformer_s12'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='poolformer_s24'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='poolformer_s36'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='resmlp_12_distilled_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='resmlp_24_distilled_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='resmlp_36_distilled_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='resmlp_big_24_distilled_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='sequencer2d_l'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='sequencer2d_m'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='sequencer2d_s'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='shuffletrans_base_p4_w7_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='shuffletrans_small_p4_w7_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='shuffletrans_tiny_p4_w7_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='dynamic_swin_small_p4_w7_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='dynamic_swin_tiny_p4_w7_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='swin_base_patch4_window7_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='swin_large_patch4_window7_224'))
|
@ -1,45 +1,48 @@
|
||||
_base_ = 'configs/base.py'
|
||||
|
||||
log_config = dict(
|
||||
interval=100,
|
||||
interval=10,
|
||||
hooks=[dict(type='TextLoggerHook'),
|
||||
dict(type='TensorboardLoggerHook')])
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='Classification',
|
||||
pretrained=None,
|
||||
train_preprocess=['mixUp'],
|
||||
pretrained=True,
|
||||
mixup_cfg=dict(
|
||||
mixup_alpha=0.2,
|
||||
prob=1.0,
|
||||
mode='batch',
|
||||
label_smoothing=0.1,
|
||||
num_classes=1000),
|
||||
backbone=dict(
|
||||
type='PytorchImageModelWrapper',
|
||||
# model_name='pit_xs_distilled_224',
|
||||
# model_name='swin_small_patch4_window7_224',
|
||||
# model_name='swin_tiny_patch4_window7_224',
|
||||
# model_name='swin_base_patch4_window7_224_in22k',
|
||||
model_name='vit_deit_small_distilled_patch16_224',
|
||||
# model_name = 'vit_deit_small_distilled_patch16_224',
|
||||
# model_name = 'resnet50',
|
||||
model_name='vit_base_patch16_224',
|
||||
num_classes=1000,
|
||||
pretrained=True,
|
||||
),
|
||||
head=dict(
|
||||
# type='ClsHead', with_avg_pool=True, in_channels=384,
|
||||
# type='ClsHead', with_avg_pool=True, in_channels=768,
|
||||
# type='ClsHead', with_avg_pool=True, in_channels=1024,
|
||||
# num_classes=0)
|
||||
type='ClsHead',
|
||||
loss_config={
|
||||
'type': 'SoftTargetCrossEntropy',
|
||||
},
|
||||
with_fc=False))
|
||||
|
||||
data_train_list = 'data/imagenet_raw/meta/train_labeled.txt'
|
||||
data_train_root = 'data/imagenet_raw/train/'
|
||||
data_test_list = 'data/imagenet_raw/meta/val_labeled.txt'
|
||||
data_test_root = 'data/imagenet_raw/validation/'
|
||||
data_all_list = 'data/imagenet_raw/meta/all_labeled.txt'
|
||||
data_root = 'data/imagenet_raw/'
|
||||
load_from = None
|
||||
|
||||
data_train_list = './imagenet_raw/meta/train_labeled.txt'
|
||||
data_train_root = './imagenet_raw/train/'
|
||||
data_test_list = './imagenet_raw/meta/val_labeled.txt'
|
||||
data_test_root = './imagenet_raw/val/'
|
||||
data_all_list = './imagenet_raw/meta/all_labeled.txt'
|
||||
data_root = './imagenet_raw/'
|
||||
|
||||
dataset_type = 'ClsDataset'
|
||||
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
train_pipeline = [
|
||||
dict(type='RandomResizedCrop', size=224),
|
||||
dict(type='RandomHorizontalFlip'),
|
||||
dict(type='MMAutoAugment'),
|
||||
dict(type='ToTensor'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Collect', keys=['img', 'gt_labels'])
|
||||
@ -53,7 +56,7 @@ test_pipeline = [
|
||||
]
|
||||
|
||||
data = dict(
|
||||
imgs_per_gpu=32, # total 256
|
||||
imgs_per_gpu=64, # total 256
|
||||
workers_per_gpu=8,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
@ -84,11 +87,25 @@ eval_pipelines = [
|
||||
custom_hooks = []
|
||||
|
||||
# optimizer
|
||||
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
|
||||
optimizer = dict(
|
||||
type='AdamW',
|
||||
lr=0.003,
|
||||
weight_decay=0.3,
|
||||
paramwise_options={
|
||||
'cls_token': dict(weight_decay=0.),
|
||||
'pos_embed': dict(weight_decay=0.),
|
||||
})
|
||||
optimizer_config = dict(grad_clip=dict(max_norm=1.0), update_interval=8)
|
||||
|
||||
# learning policy
|
||||
lr_config = dict(policy='step', step=[30, 60, 90])
|
||||
checkpoint_config = dict(interval=10)
|
||||
lr_config = dict(
|
||||
policy='CosineAnnealing',
|
||||
min_lr=0,
|
||||
warmup='linear',
|
||||
warmup_iters=10000,
|
||||
warmup_ratio=1e-4,
|
||||
)
|
||||
checkpoint_config = dict(interval=30)
|
||||
|
||||
# runtime settings
|
||||
total_epochs = 90
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='tnt_s_patch16_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='twins_svt_base'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='twins_svt_large'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='twins_svt_small'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='vit_base_patch16_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='vit_large_patch16_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='xcit_large_24_p8_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='xcit_large_24_p8_224_dist'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='xcit_medium_24_p8_224'))
|
@ -0,0 +1,4 @@
|
||||
_base_ = '../timm_config.py'
|
||||
|
||||
# model settings
|
||||
model = dict(backbone=dict(model_name='xcit_medium_24_p8_224_dist'))
|
@ -15,7 +15,6 @@ model = dict(
|
||||
mode='batch',
|
||||
label_smoothing=0.1,
|
||||
num_classes=1000),
|
||||
pretrained=None,
|
||||
backbone=dict(
|
||||
type='PytorchImageModelWrapper',
|
||||
model_name='vit_base_patch16_224',
|
||||
|
@ -12,7 +12,6 @@ export = dict(export_neck=True)
|
||||
# model settings
|
||||
model = dict(
|
||||
type='Classification',
|
||||
pretrained=None,
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
|
@ -19,7 +19,6 @@ export = dict(export_neck=True)
|
||||
# model settings
|
||||
model = dict(
|
||||
type='Classification',
|
||||
pretrained=None,
|
||||
backbone=dict(
|
||||
type='PytorchImageModelWrapper',
|
||||
# model_name='pit_xs_distilled_224',
|
||||
|
@ -19,7 +19,6 @@ export = dict(export_neck=True)
|
||||
# model settings
|
||||
model = dict(
|
||||
type='Classification',
|
||||
pretrained=None,
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
|
@ -44,7 +44,6 @@ work_dir = 'oss://path/to/work_dirs/classification/'
|
||||
# model settings
|
||||
model = dict(
|
||||
type='Classification',
|
||||
# pretrained=None,
|
||||
backbone=dict(
|
||||
type='PytorchImageModelWrapper',
|
||||
|
||||
|
@ -49,7 +49,6 @@ work_dir = 'oss://path/to/work_dirs/classification/'
|
||||
# model settings
|
||||
model = dict(
|
||||
type='Classification',
|
||||
# pretrained=None,
|
||||
backbone=dict(
|
||||
type='PytorchImageModelWrapper',
|
||||
|
||||
|
@ -7,7 +7,6 @@ log_config = dict(
|
||||
# model settings
|
||||
model = dict(
|
||||
type='Classification',
|
||||
pretrained=None,
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
|
@ -7,7 +7,6 @@ log_config = dict(
|
||||
# model settings
|
||||
model = dict(
|
||||
type='Classification',
|
||||
pretrained=None,
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
|
@ -13,9 +13,9 @@ load_from = None
|
||||
# model settings
|
||||
model = dict(
|
||||
type='Classification',
|
||||
train_preprocess=['randomErasing'],
|
||||
pretrained=
|
||||
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/timm/swin_base_patch4_window7_224_22k_statedict.pth',
|
||||
train_preprocess=['randomErasing'],
|
||||
backbone=dict(
|
||||
type='PytorchImageModelWrapper',
|
||||
model_name='swin_base_patch4_window7_224_in22k'
|
||||
|
@ -2,22 +2,77 @@
|
||||
|
||||
## Benchmarks
|
||||
|
||||
| Algorithm | Config | Top-1 (%) | Top-5 (%) | Download |
|
||||
| --------- | ------------------------------------------------------------ | --------- | --------- | ------------------------------------------------------------ |
|
||||
| resnet50(raw) | [resnet50(raw)](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py) | 76.454 | 93.084 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/resnet50/epoch_100.pth) |
|
||||
| resnet50(tfrecord) | [resnet50(tfrecord)](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnet/imagenet_rn50_tfrecord.py) | 76.266 | 92.972 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/resnet50/epoch_100.pth) |
|
||||
| resnet101 | [resnet101](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnet/imagenet_resnet101_jpg.py) | 78.152 | 93.922 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/resnet101/epoch_100.pth) |
|
||||
| resnet152 | [resnet152](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnet/imagenet_resnet152_jpg.py) | 78.544 | 94.206 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/resnet152/epoch_100.pth) |
|
||||
| resnext50-32x4d | [resnext50-32x4d](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnext/imagenet_resnext50-32x4d_jpg.py) | 77.604 | 93.856 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnet50/epoch_100.pth) |
|
||||
| resnext101-32x4d | [resnext101-32x4d](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnext/imagenet_resnext101-32x4d_jpg.py) | 78.568 | 94.344 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext50-32x4d/epoch_100.pth) |
|
||||
| resnext101-32x8d | [resnext101-32x8d](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnext/imagenet_resnext101-32x8d_jpg.py) | 79.468 | 94.434 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext101-32x8d/epoch_100.pth) |
|
||||
| resnext152-32x4d | [resnext152-32x4d](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnext/imagenet_resnext152-32x4d_jpg.py) | 78.994 | 94.462 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext152-32x4d/epoch_100.pth) |
|
||||
| hrnetw18 | [hrnetw18](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw18_jpg.py) | 76.258 | 92.976 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw18/epoch_100.pth) |
|
||||
| hrnetw30 | [hrnetw30](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw30_jpg.py) | 77.66 | 93.862 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw30/epoch_100.pth) |
|
||||
| hrnetw32 | [hrnetw32](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw32_jpg.py) | 77.994 | 93.976 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw32/epoch_100.pth) |
|
||||
| hrnetw40 | [hrnetw40](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw40_jpg.py) | 78.142 | 93.956 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw40/epoch_100.pth) |
|
||||
| hrnetw44 | [hrnetw44](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw44_jpg.py) | 79.266 | 94.476 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw44/epoch_100.pth) |
|
||||
| hrnetw48 | [hrnetw48](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw48_jpg.py) | 79.636 | 94.802 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw48/epoch_100.pth) |
|
||||
| hrnetw64 | [hrnetw64](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw64_jpg.py) | 79.884 | 95.04 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/hrnetw64/epoch_100.pth) |
|
||||
| vit-base-patch16 | [vit-base-patch16](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/vit/imagenet_vit_base_patch16_224_jpg.py) | 76.082 | 92.026 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/vit/vit-base-patch16/epoch_300.pth) |
|
||||
| swin-tiny-patch4-window7 | [swin-tiny-patch4-window7](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/swint/imagenet_swin_tiny_patch4_window7_224_jpg.py) | 80.528 | 94.822 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/swint/swin-tiny-patch4-window7/epoch_300.pth) |
|
||||
| Algorithm | Config | Top-1 (%) | Top-5 (%) | gpu memory (MB) | inference time (ms/img) | Download |
|
||||
| --------- | ------------------------------------------------------------ | --------- | --------- | --------- | --------- | ------------------------------------------------------------ |
|
||||
| resnet50(raw) | [resnet50(raw)](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py) | 76.454 | 93.084 | 2412 | 8.59 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/resnet50/epoch_100.pth) |
|
||||
| resnet50(tfrecord) | [resnet50(tfrecord)](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnet/imagenet_rn50_tfrecord.py) | 76.266 | 92.972 | 2412 | 8.59 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/resnet50/epoch_100.pth) |
|
||||
| resnet101 | [resnet101](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnet/imagenet_resnet101_jpg.py) | 78.152 | 93.922 | 2484 | 16.77 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/resnet101/epoch_100.pth) |
|
||||
| resnet152 | [resnet152](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnet/imagenet_resnet152_jpg.py) | 78.544 | 94.206 | 2544 | 24.69 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/resnet152/epoch_100.pth) |
|
||||
| resnext50-32x4d | [resnext50-32x4d](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnext/imagenet_resnext50-32x4d_jpg.py) | 77.604 | 93.856 | 4718 | 12.88 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnet50/epoch_100.pth) |
|
||||
| resnext101-32x4d | [resnext101-32x4d](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnext/imagenet_resnext101-32x4d_jpg.py) | 78.568 | 94.344 | 4792 | 26.84 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext50-32x4d/epoch_100.pth) |
|
||||
| resnext101-32x8d | [resnext101-32x8d](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnext/imagenet_resnext101-32x8d_jpg.py) | 79.468 | 94.434 | 9582 | 27.52 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext101-32x8d/epoch_100.pth) |
|
||||
| resnext152-32x4d | [resnext152-32x4d](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnext/imagenet_resnext152-32x4d_jpg.py) | 78.994 | 94.462 | 4852 | 41.08 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext152-32x4d/epoch_100.pth) |
|
||||
| hrnetw18 | [hrnetw18](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw18_jpg.py) | 76.258 | 92.976 | 4701 | 54.55 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw18/epoch_100.pth) |
|
||||
| hrnetw30 | [hrnetw30](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw30_jpg.py) | 77.66 | 93.862 | 4766 | 54. 30 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw30/epoch_100.pth) |
|
||||
| hrnetw32 | [hrnetw32](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw32_jpg.py) | 77.994 | 93.976 | 4780 | 53.48 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw32/epoch_100.pth) |
|
||||
| hrnetw40 | [hrnetw40](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw40_jpg.py) | 78.142 | 93.956 | 4843 | 54.31 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw40/epoch_100.pth) |
|
||||
| hrnetw44 | [hrnetw44](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw44_jpg.py) | 79.266 | 94.476 | 4884 | 54.83 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw44/epoch_100.pth) |
|
||||
| hrnetw48 | [hrnetw48](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw48_jpg.py) | 79.636 | 94.802 | 4916 | 54.14 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw48/epoch_100.pth) |
|
||||
| hrnetw64 | [hrnetw64](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw64_jpg.py) | 79.884 | 95.04 | 5120 | 54.74 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/hrnetw64/epoch_100.pth) |
|
||||
| vit-base-patch16 | [vit-base-patch16](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/vit/imagenet_vit_base_patch16_224_jpg.py) | 76.082 | 92.026 | 346 | 8.03 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/vit/vit-base-patch16/epoch_300.pth) |
|
||||
| swin-tiny-patch4-window7 | [swin-tiny-patch4-window7](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/swint/imagenet_swin_tiny_patch4_window7_224_jpg.py) | 80.528 | 94.822 | 132 | 12.94 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/swint/swin-tiny-patch4-window7/epoch_300.pth) |
|
||||
|
||||
(ps: 通过EasyCV训练得到模型结果,推理的输入尺寸默认为224,机器默认为V100 16G,其中gpu memory记录的是gpu peak memory)
|
||||
|
||||
| Algorithm | Config | Top-1 (%) | Top-5 (%) | gpu memory (MB) | inference time (ms/img) | Download |
|
||||
| --------- | ------------------------------------------------------------ | --------- | --------- | --------- | --------- | ------------------------------------------------------------ |
|
||||
| vit_base_patch16_224 | [vit_base_patch16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/vit/vit_base_patch16_224.py) | 78.096 | 94.324 | 346 | 8.03 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/vit/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz) |
|
||||
| vit_large_patch16_224 | [vit_large_patch16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/vit/vit_large_patch16_224.py) | 84.404 | 97.276 | 1171 | 16.30 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/vit/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz) |
|
||||
| deit_base_patch16_224 | [deit_base_patch16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/deit/deit_base_patch16_224.py) | 81.756 | 95.6 | 346 | 7.98 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/deit/deit_base_patch16_224-b5f2ef4d.pth) |
|
||||
| deit_base_distilled_patch16_224 | [deit_base_distilled_patch16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/deit/deit_base_distilled_patch16_224.py) | 83.232 | 96.476 | 349 | 8.07 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/deit/deit_base_distilled_patch16_224-df68dfff.pth) |
|
||||
| xcit_medium_24_p8_224 | [xcit_medium_24_p8_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/xcit/xcit_medium_24_p8_224.py) | 83.348 | 96.21 | 884 | 31.77 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/xcit/xcit_medium_24_p8_224.pth) |
|
||||
| xcit_medium_24_p8_224_dist | [xcit_medium_24_p8_224_dist](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/xcit/xcit_medium_24_p8_224_dist.py) | 84.876 | 97.164 | 884 | 32.08 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/xcit/xcit_medium_24_p8_224_dist.pth) |
|
||||
| xcit_large_24_p8_224 | [xcit_large_24_p8_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/xcit/xcit_large_24_p8_224.py) | 83.986 | 96.47 | 1962 | 37.44 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/xcit/xcit_large_24_p8_224.pth) |
|
||||
| xcit_large_24_p8_224_dist | [xcit_large_24_p8_224_dist](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/xcit/xcit_large_24_p8_224_dist.py) | 85.022 | 97.29 | 1962 | 37.44 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/xcit/xcit_large_24_p8_224_dist.pth) |
|
||||
| tnt_s_patch16_224 | [tnt_s_patch16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/tnt/tnt_s_patch16_224.py) | 76.934 | 93.388 | 100 | 18.92 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/tnt/tnt_s_patch16_224.pth.tar) |
|
||||
| convit_tiny | [convit_tiny](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convit/convit_tiny.py) | 72.954 | 91.68 | 31 | 10.79 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convit/convit_tiny.pth) |
|
||||
| convit_small | [convit_small](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convit/convit_small.py) | 81.342 | 95.784 | 122 | 11.23 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convit/convit_small.pth) |
|
||||
| convit_base | [convit_base](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convit/convit_base.py) | 82.27 | 95.916 | 358 | 11.26 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convit/convit_base.pth) |
|
||||
| cait_xxs24_224 | [cait_xxs24_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/cait/cait_xxs24_224.py) | 78.45 | 94.154 | 50 | 22.62 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/cait/XXS24_224.pth) |
|
||||
| cait_xxs36_224 | [cait_xxs36_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/cait/cait_xxs36_224.py) | 79.788 | 94.87 | 71 | 33.25 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/cait/XXS36_224.pth) |
|
||||
| cait_s24_224 | [cait_s24_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/cait/cait_s24_224.py) | 83.302 | 96.568 | 190 | 23.74 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/cait/cait_s24_224.pth) |
|
||||
| levit_128 | [levit_128](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/levit/levit_128.py) | 78.468 | 93.874 | 76 | 15.33 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/levit/LeViT-128-b88c2750.pth) |
|
||||
| levit_192 | [levit_192](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/levit/levit_192.py) | 79.72 | 94.664 | 128 | 15.17 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/levit/LeViT-192-92712e41.pth) |
|
||||
| levit_256 | [levit_256](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/levit/levit_256.py) | 81.432 | 95.38 | 222 | 15.27 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/levit/LeViT-256-13b5763e.pth) |
|
||||
| convnext_tiny | [convnext_tiny](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convnext/convnext_tiny.py) | 81.878 | 95.836 | 128 | 7.17 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convnext/convnext_tiny_1k_224_ema.pth) |
|
||||
| convnext_small | [convnext_small](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convnext/convnext_small.py) | 82.836 | 96.458 | 213 | 12.89 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convnext/convnext_small_1k_224_ema.pth) |
|
||||
| convnext_base | [convnext_base](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convnext/convnext_base.py) | 83.73 | 96.692 | 364 | 13.04 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convnext/convnext_base_1k_224_ema.pth) |
|
||||
| convnext_large | [convnext_large](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convnext/convnext_large.py) | 84.164 | 96.844 | 781 | 13.78 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convnext/convnext_large_1k_224_ema.pth) |
|
||||
| resmlp_12_distilled_224 | [resmlp_12_distilled_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/resmlp/resmlp_12_distilled_224.py) | 77.876 | 93.532 | 66 | 4.90 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/resmlp/resmlp_12_dist.pth) |
|
||||
| resmlp_24_distilled_224 | [resmlp_24_distilled_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/resmlp/resmlp_24_distilled_224.py) | 80.548 | 95.204 | 124 | 9.07 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/resmlp/resmlp_24_dist.pth) |
|
||||
| resmlp_36_distilled_224 | [resmlp_36_distilled_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/resmlp/resmlp_36_distilled_224.py) | 80.944 | 95.416 | 181 | 13.56 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/resmlp/resmlp_36_dist.pth) |
|
||||
| resmlp_big_24_distilled_224 | [resmlp_big_24_distilled_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/resmlp/resmlp_big_24_distilled_224.py) | 83.45 | 96.65 | 534 | 20.48 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/resmlp/resmlpB_24_dist.pth) |
|
||||
| coat_tiny | [coat_tiny](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/coat/coat_tiny.py) | 78.112 | 93.972 | 127 | 33.09 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/coat/coat_tiny-473c2a20.pth) |
|
||||
| coat_mini | [coat_mini](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/coat/coat_mini.py) | 80.912 | 95.378 | 247 | 33.29 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/coat/coat_mini-2c6baf49.pth) |
|
||||
| convmixer_768_32 | [convmixer_768_32](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convmixer/convmixer_768_32.py) | 80.08 | 94.992 | 4995 | 10.23 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convmixer/convmixer_768_32_ks7_p7_relu.pth.tar) |
|
||||
| convmixer_1024_20_ks9_p14 | [convmixer_1024_20_ks9_p14](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convmixer/convmixer_1024_20_ks9_p14.py) | 81.742 | 95.578 | 2407 | 6.29 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convmixer/convmixer_1024_20_ks9_p14.pth.tar) |
|
||||
| convmixer_1536_20 | [convmixer_1536_20](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convmixer/convmixer_1536_20.py) | 81.432 | 95.38 | 547 | 14.66 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convmixer/convmixer_1536_20_ks9_p7.pth.tar) |
|
||||
| gmixer_24_224 | [gmixer_24_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/gmixer/gmixer_24_224.py) | 78.088 | 93.6 | 104 | 11.65 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/gmixer/gmixer_24_224_raa-7daf7ae6.pth) |
|
||||
| gmlp_s16_224 | [gmlp_s16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/gmlp/gmlp_s16_224.py) | 77.204 | 93.358 | 81 | 11.15 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/gmlp/gmlp_s16_224.pth) |
|
||||
| mixer_b16_224 | [mixer_b16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/mlp-mixer/mixer_b16_224.py) | 72.558 | 90.068 | 241 | 5.37 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/mlp-mixer/jx_mixer_b16_224-76587d61.pth) |
|
||||
| mixer_l16_224 | [mixer_l16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/mlp-mixer/mixer_l16_224.py) | 68.34 | 86.11 | 804 | 11.74 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/mlp-mixer/jx_mixer_l16_224-92f9adc4.pth) |
|
||||
| jx_nest_tiny | [jx_nest_tiny](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/nest/jx_nest_tiny.py) | 81.278 | 95.618 | 90 | 9.05 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/nest/jx_nest_tiny-e3428fb9.pth) |
|
||||
| jx_nest_small | [jx_nest_tiny](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/nest/jx_nest_small.py) | 83.144 | 96.3 | 174 | 16.92 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/nest/jx_nest_small-422eaded.pth) |
|
||||
| jx_nest_base | [jx_nest_base](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/nest/jx_nest_base.py) | 83.474 | 96.442 | 300 | 16.88 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/nest/jx_nest_base-8bc41011.pth) |
|
||||
| pit_s_distilled_224 | [pit_s_distilled_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/pit/pit_s_distilled_224.py) | 83.144 | 96.3 | 109 | 7.00 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/pit/pit_s_distill_819.pth) |
|
||||
| pit_b_distilled_224 | [pit_b_distilled_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/pit/pit_b_distilled_224.py) | 83.474 | 96.442 | 330 | 7.66 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/pit/pit_b_distill_840.pth) |
|
||||
| twins_svt_small | [twins_svt_small](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/twins/twins_svt_small.py) | 81.598 | 95.55 | 657 | 14.07 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/twins/twins_svt_small-42e5f78c.pth) |
|
||||
| twins_svt_base | [twins_svt_base](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/twins/twins_svt_base.py) | 82.882 | 96.234 | 1447 | 18.99 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/twins/twins_svt_base-c2265010.pth) |
|
||||
| twins_svt_large | [twins_svt_large](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/twins/twins_svt_large.py) | 83.428 | 96.506 | 2567 | 19.11 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/twins/twins_svt_large-90f6aaa9.pth) |
|
||||
| swin_base_patch4_window7_224 | [swin_base_patch4_window7_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/swint/swin_base_patch4_window7_224.py) | 84.714 | 97.444 | 375 | 23.47 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_base_patch4_window7_224_22kto1k.pth) |
|
||||
| swin_large_patch4_window7_224 | [swin_large_patch4_window7_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/swint/swin_large_patch4_window7_224.py) | 85.826 | 97.816 | 788 | 23.29 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_large_patch4_window7_224_22kto1k.pth) |
|
||||
| dynamic_swin_small_p4_w7_224 | [dynamic_swin_small_p4_w7_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/swint/dynamic_small_base_p4_w7_224.py) | 82.896 | 96.234 | 220 | 28.55 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_small_patch4_window7_224_statedict.pth) |
|
||||
| dynamic_swin_tiny_p4_w7_224 | [dynamic_swin_tiny_p4_w7_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/swint/dynamic_swin_tiny_p4_w7_224.py) | 80.912 | 95.41 | 136 | 14.58 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_tiny_patch4_window7_224_statedict.pth) |
|
||||
| shuffletrans_tiny_p4_w7_224 | [shuffletrans_tiny_p4_w7_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/shuffle_transformer/shuffletrans_tiny_p4_w7_224.py) | 82.176 | 96.05 | 5311 | 13.90 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/shuffle_transformer/shuffle_tiny.pth) |
|
||||
|
||||
(ps: 通过导入官方模型得到推理结果,需要torch.__version__ >= 1.9.0,推理的输入尺寸默认为224,机器默认为V100 16G,其中gpu memory记录的是gpu peak memory)
|
||||
|
@ -18,11 +18,8 @@ class BenchMarkMLP(nn.Module):
|
||||
self.dropout = nn.Dropout(p=0.5)
|
||||
self.pool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.avg_pool = avg_pool
|
||||
# self.fc2 = nn.Linear(feature_num, num_classes)
|
||||
# self.relu2 = nn.ReLU()
|
||||
# self._initialize_weights()
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.kaiming_normal_(
|
||||
@ -33,7 +30,4 @@ class BenchMarkMLP(nn.Module):
|
||||
x = self.pool(x)
|
||||
x = self.fc1(x)
|
||||
x = self.relu1(x)
|
||||
# x = self.dropout(x)
|
||||
# x = self.fc2(x)
|
||||
# x = self.relu2(x)
|
||||
return tuple([x])
|
||||
|
@ -344,32 +344,16 @@ class BNInception(nn.Module):
|
||||
1024, 128, kernel_size=(1, 1), stride=(1, 1))
|
||||
self.inception_5b_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
|
||||
self.inception_5b_relu_pool_proj = nn.ReLU(inplace)
|
||||
# self.last_linear = nn.Linear (1024, num_classes)
|
||||
self.num_classes = num_classes
|
||||
if num_classes > 0:
|
||||
self.last_linear = nn.Linear(1024, num_classes)
|
||||
|
||||
self.pretrained = model_urls[self.__class__.__name__]
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
if isinstance(pretrained, str) or isinstance(pretrained, dict):
|
||||
logger = get_root_logger()
|
||||
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
||||
elif pretrained is None:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m, mode='fan_in', nonlinearity='relu')
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
||||
constant_init(m, 1)
|
||||
|
||||
# if self.zero_init_residual:
|
||||
# for m in self.modules():
|
||||
# if isinstance(m, Bottleneck):
|
||||
# constant_init(m.norm3, 0)
|
||||
# elif isinstance(m, BasicBlock):
|
||||
# constant_init(m.norm2, 0)
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m, mode='fan_in', nonlinearity='relu')
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
||||
constant_init(m, 1)
|
||||
|
||||
def features(self, input):
|
||||
conv1_7x7_s2_out = self.conv1_7x7_s2(input)
|
||||
|
@ -62,7 +62,6 @@ def _fuse_convkx_and_bn_(convkx, bn):
|
||||
bn.bias[:] = bn.bias - the_bias_shift
|
||||
bn.running_var[:] = 1.0 - bn.eps
|
||||
bn.running_mean[:] = 0.0
|
||||
# convkx.register_parameter('bias', bn.bias)
|
||||
convkx.bias = nn.Parameter(bn.bias)
|
||||
|
||||
|
||||
@ -76,7 +75,6 @@ def remove_bn_in_superblock(super_block):
|
||||
for block in the_seq_list:
|
||||
if isinstance(block, nn.BatchNorm2d):
|
||||
_fuse_convkx_and_bn_(last_block, block)
|
||||
# print('--debug fuse shortcut bn')
|
||||
else:
|
||||
new_seq_list.append(block)
|
||||
last_block = block
|
||||
@ -92,7 +90,6 @@ def remove_bn_in_superblock(super_block):
|
||||
for block in the_seq_list:
|
||||
if isinstance(block, nn.BatchNorm2d):
|
||||
_fuse_convkx_and_bn_(last_block, block)
|
||||
# print('--debug fuse conv bn')
|
||||
else:
|
||||
new_seq_list.append(block)
|
||||
last_block = block
|
||||
@ -1647,7 +1644,6 @@ class PlainNet(nn.Module):
|
||||
block_list.pop(-1)
|
||||
else:
|
||||
self.adptive_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
# AdaptiveAvgPool(out_channels=block_list[-1].out_channels, output_size=1)
|
||||
|
||||
self.block_list = block_list
|
||||
if not no_create:
|
||||
@ -1663,29 +1659,13 @@ class PlainNet(nn.Module):
|
||||
|
||||
self.plainnet_struct = str(self) + str(self.adptive_avg_pool)
|
||||
self.zero_init_residual = False
|
||||
self.pretrained = model_urls[self.__class__.__name__ +
|
||||
plainnet_struct_idx]
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
if isinstance(pretrained, str) or isinstance(pretrained, dict):
|
||||
logger = get_root_logger()
|
||||
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
||||
elif pretrained is None:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m, mode='fan_in', nonlinearity='relu')
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
||||
constant_init(m, 1)
|
||||
|
||||
# if self.zero_init_residual:
|
||||
# for m in self.modules():
|
||||
# if isinstance(m, Bottleneck):
|
||||
# constant_init(m.norm3, 0)
|
||||
# elif isinstance(m, BasicBlock):
|
||||
# constant_init(m.norm2, 0)
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
return
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m, mode='fan_in', nonlinearity='relu')
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
||||
constant_init(m, 1)
|
||||
|
||||
def forward(self, x):
|
||||
output = x
|
||||
@ -1699,13 +1679,3 @@ class PlainNet(nn.Module):
|
||||
output = self.fc_linear(output)
|
||||
|
||||
return [output]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from torch import nn
|
||||
|
||||
input_image_size = 192
|
||||
|
||||
# model = genet_normal(pretrained=True, root='data/models/GENet_params/')
|
||||
# print(type(model))
|
||||
# print(model.block_list)
|
||||
|
@ -184,7 +184,6 @@ class Bottleneck(nn.Module):
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
raise NotImplementedError
|
||||
# out = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
out = _inner_forward(x)
|
||||
|
||||
@ -718,31 +717,25 @@ class HRNet(nn.Module):
|
||||
|
||||
return nn.Sequential(*hr_modules), in_channels
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
def init_weights(self):
|
||||
"""Initialize the weights in backbone.
|
||||
|
||||
Args:
|
||||
pretrained (str, optional): Path to pre-trained weights.
|
||||
Defaults to None.
|
||||
"""
|
||||
if isinstance(pretrained, str):
|
||||
logger = get_root_logger()
|
||||
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
||||
elif pretrained is None:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
normal_init(m, std=0.001)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
||||
constant_init(m, 1)
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
normal_init(m, std=0.001)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
||||
constant_init(m, 1)
|
||||
|
||||
if self.zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
constant_init(m.norm3, 0)
|
||||
elif isinstance(m, BasicBlock):
|
||||
constant_init(m.norm2, 0)
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
if self.zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
constant_init(m.norm3, 0)
|
||||
elif isinstance(m, BasicBlock):
|
||||
constant_init(m.norm2, 0)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
@ -60,23 +60,17 @@ class Inception3(nn.Module):
|
||||
if num_classes > 0:
|
||||
self.fc = nn.Linear(2048, num_classes)
|
||||
|
||||
self.pretrained = model_urls[self.__class__.__name__]
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
if pretrained is not None:
|
||||
logger = get_root_logger()
|
||||
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
||||
else:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m, mode='fan_in', nonlinearity='relu')
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
||||
constant_init(m, 1)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m, mode='fan_in', nonlinearity='relu')
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
||||
constant_init(m, 1)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.transform_input:
|
||||
|
@ -958,24 +958,18 @@ class LiteHRNet(nn.Module):
|
||||
|
||||
return nn.Sequential(*modules), in_channels
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
def init_weights(self):
|
||||
"""Initialize the weights in backbone.
|
||||
|
||||
Args:
|
||||
pretrained (str, optional): Path to pre-trained weights.
|
||||
Defaults to None.
|
||||
"""
|
||||
if isinstance(pretrained, str):
|
||||
logger = get_root_logger()
|
||||
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
||||
elif pretrained is None:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
normal_init(m, std=0.001)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
||||
constant_init(m, 1)
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
normal_init(m, std=0.001)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
||||
constant_init(m, 1)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
@ -73,17 +73,16 @@ class MaskedAutoencoderViT(nn.Module):
|
||||
])
|
||||
self.norm = norm_layer(embed_dim)
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
# we use xavier_uniform following official JAX ViT:
|
||||
torch.nn.init.xavier_uniform_(m.weight)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
# we use xavier_uniform following official JAX ViT:
|
||||
torch.nn.init.xavier_uniform_(m.weight)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def random_masking(self, x, mask_ratio):
|
||||
"""
|
||||
|
@ -143,8 +143,6 @@ class MNASNet(torch.nn.Module):
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Dropout(p=dropout, inplace=True),
|
||||
nn.Linear(1280, num_classes))
|
||||
self.init_weights()
|
||||
self.pretrained = model_urls[self.__class__.__name__ + str(alpha)]
|
||||
|
||||
def forward(self, x):
|
||||
x = self.layers(x)
|
||||
@ -155,22 +153,16 @@ class MNASNet(torch.nn.Module):
|
||||
else:
|
||||
return [x]
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
if isinstance(pretrained, str) or isinstance(pretrained, dict):
|
||||
logger = get_root_logger()
|
||||
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
||||
elif pretrained is None:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu')
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.ones_(m.weight)
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu')
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0.01)
|
||||
nn.init.zeros_(m.bias)
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0.01)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
@ -156,24 +156,18 @@ class MobileNetV2(nn.Module):
|
||||
self.pretrained = model_urls[self.__class__.__name__ + '_' +
|
||||
str(width_multi)]
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
if isinstance(pretrained, str) or isinstance(pretrained, dict):
|
||||
logger = get_root_logger()
|
||||
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
||||
elif pretrained is None:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.ones_(m.weight)
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.zeros_(m.bias)
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
|
@ -1,12 +1,15 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import importlib
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import timm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from timm.models.helpers import load_pretrained
|
||||
from timm.models.hub import download_cached_file
|
||||
|
||||
from easycv.utils.checkpoint import load_checkpoint
|
||||
from easycv.utils.logger import get_root_logger
|
||||
from easycv.utils.logger import get_root_logger, print_log
|
||||
from ..modelzoo import timm_models as model_urls
|
||||
from ..registry import BACKBONES
|
||||
from .shuffle_transformer import (shuffletrans_base_p4_w7_224,
|
||||
@ -66,8 +69,6 @@ class PytorchImageModelWrapper(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
model_name='resnet50',
|
||||
pretrained=False,
|
||||
checkpoint_path=None,
|
||||
scriptable=None,
|
||||
exportable=None,
|
||||
no_jit=None,
|
||||
@ -76,15 +77,16 @@ class PytorchImageModelWrapper(nn.Module):
|
||||
Inits PytorchImageModelWrapper by timm.create_models
|
||||
Args:
|
||||
model_name (str): name of model to instantiate
|
||||
pretrained (bool): load pretrained ImageNet-1k weights if true
|
||||
checkpoint_path (str): path of checkpoint to load after model is initialized
|
||||
scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
|
||||
exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
|
||||
no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)
|
||||
"""
|
||||
super(PytorchImageModelWrapper, self).__init__()
|
||||
|
||||
self.model_name = model_name
|
||||
|
||||
timm_model_names = timm.list_models(pretrained=False)
|
||||
self.timm_model_names = timm_model_names
|
||||
assert model_name in timm_model_names or model_name in _MODEL_MAP, \
|
||||
f'{model_name} is not in model_list of timm/fair, please check the model_name!'
|
||||
|
||||
@ -94,52 +96,52 @@ class PytorchImageModelWrapper(nn.Module):
|
||||
|
||||
# create model by timm
|
||||
if model_name in timm_model_names:
|
||||
try:
|
||||
if pretrained and (model_name in model_urls):
|
||||
self.model = timm.create_model(model_name, False, '',
|
||||
scriptable, exportable,
|
||||
no_jit, **kwargs)
|
||||
self.init_weights(model_urls[model_name])
|
||||
print('Info: Load model from %s' % model_urls[model_name])
|
||||
|
||||
if checkpoint_path is not None:
|
||||
self.init_weights(checkpoint_path)
|
||||
else:
|
||||
# load from timm
|
||||
if pretrained and model_name.startswith('swin_') and (
|
||||
LooseVersion(
|
||||
torch.__version__) <= LooseVersion('1.6.0')):
|
||||
print(
|
||||
'Warning: Pretrained SwinTransformer from timm may be zipfile extract'
|
||||
' error while torch<=1.6.0')
|
||||
self.model = timm.create_model(model_name, pretrained,
|
||||
checkpoint_path, scriptable,
|
||||
exportable, no_jit,
|
||||
**kwargs)
|
||||
|
||||
# need fix: delete this except after pytorch 1.7 update in all production
|
||||
# (dlc, dsw, studio, ev_predict_py3)
|
||||
except Exception:
|
||||
print(
|
||||
f'Error: Fail to create {model_name} with (pretrained={pretrained}, checkpoint_path={checkpoint_path} ...)'
|
||||
)
|
||||
print(
|
||||
f'Try to create {model_name} with pretrained=False, checkpoint_path=None and default params'
|
||||
)
|
||||
self.model = timm.create_model(model_name, False, '', None,
|
||||
None, None, **kwargs)
|
||||
|
||||
# facebook model wrapper
|
||||
if model_name in _MODEL_MAP:
|
||||
self.model = timm.create_model(model_name, False, '', scriptable,
|
||||
exportable, no_jit, **kwargs)
|
||||
elif model_name in _MODEL_MAP:
|
||||
self.model = _MODEL_MAP[model_name](**kwargs)
|
||||
if pretrained:
|
||||
if model_name in model_urls.keys():
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
"""
|
||||
Args:
|
||||
if pretrained == True, load model from default path;
|
||||
if pretrained == False or None, load from init weights.
|
||||
|
||||
if model_name in timm_model_names, load model from timm default path;
|
||||
if model_name in _MODEL_MAP, load model from easycv default path
|
||||
"""
|
||||
logger = get_root_logger()
|
||||
if pretrained:
|
||||
if self.model_name in self.timm_model_names:
|
||||
pretrained_path = model_urls[self.model_name]
|
||||
print_log(
|
||||
'load model from default path: {}'.format(pretrained_path),
|
||||
logger)
|
||||
if pretrained_path.endswith('.npz'):
|
||||
pretrained_loc = download_cached_file(
|
||||
pretrained_path, check_hash=False, progress=False)
|
||||
return self.model.load_pretrained(pretrained_loc)
|
||||
else:
|
||||
backbone_module = importlib.import_module(
|
||||
self.model.__module__)
|
||||
return load_pretrained(
|
||||
self.model,
|
||||
default_cfg={'url': pretrained_path},
|
||||
filter_fn=backbone_module.checkpoint_filter_fn
|
||||
if hasattr(backbone_module, 'checkpoint_filter_fn')
|
||||
else None)
|
||||
elif self.model_name in _MODEL_MAP:
|
||||
if self.model_name in model_urls.keys():
|
||||
pretrained_path = model_urls[self.model_name]
|
||||
print_log(
|
||||
'load model from default path: {}'.format(
|
||||
pretrained_path), logger)
|
||||
try_max = 3
|
||||
try_idx = 0
|
||||
while try_idx < try_max:
|
||||
try:
|
||||
state_dict = torch.hub.load_state_dict_from_url(
|
||||
url=model_urls[model_name],
|
||||
url=pretrained_path,
|
||||
map_location='cpu',
|
||||
)
|
||||
try_idx += try_max
|
||||
@ -147,31 +149,22 @@ class PytorchImageModelWrapper(nn.Module):
|
||||
try_idx += 1
|
||||
state_dict = {}
|
||||
if try_idx == try_max:
|
||||
print(
|
||||
'load from url failed ! oh my DLC & OSS, you boys really good! ',
|
||||
model_urls[model_name])
|
||||
print_log(
|
||||
f'load from url failed ! oh my DLC & OSS, you boys really good! {model_urls[self.model_name]}',
|
||||
logger)
|
||||
|
||||
# for some model strict = False still failed when model doesn't exactly match
|
||||
try:
|
||||
self.model.load_state_dict(state_dict, strict=False)
|
||||
except Exception:
|
||||
print('load for model_name not all right')
|
||||
if 'model' in state_dict:
|
||||
state_dict = state_dict['model']
|
||||
self.model.load_state_dict(state_dict, strict=False)
|
||||
else:
|
||||
print('%s not in evtorch modelzoo!' % model_name)
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
# pretrained is the path of pretrained model offered by easycv
|
||||
if pretrained is not None:
|
||||
logger = get_root_logger()
|
||||
load_checkpoint(
|
||||
self.model,
|
||||
pretrained,
|
||||
map_location=torch.device('cpu'),
|
||||
strict=False,
|
||||
logger=logger)
|
||||
raise ValueError('{} not in evtorch modelzoo!'.format(
|
||||
self.model_name))
|
||||
else:
|
||||
raise ValueError(
|
||||
'Error: Fail to create {} with (pretrained={}...)'.format(
|
||||
self.model_name, pretrained))
|
||||
else:
|
||||
# init by timm
|
||||
pass
|
||||
self.model.init_weights()
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
|
@ -470,26 +470,19 @@ class ResNeSt(nn.Module):
|
||||
self.avgpool = GlobalAvgPool2d()
|
||||
self.drop = nn.Dropout(final_drop) if final_drop > 0.0 else None
|
||||
self.norm_layer = norm_layer
|
||||
# self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
if num_classes > 0:
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
if isinstance(pretrained, str) or isinstance(pretrained, dict):
|
||||
logger = get_root_logger()
|
||||
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
||||
elif pretrained is None:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
elif isinstance(m, self.norm_layer):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
elif isinstance(m, self.norm_layer):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
def _make_layer(self,
|
||||
block,
|
||||
@ -613,7 +606,6 @@ class ResNeSt(nn.Module):
|
||||
|
||||
if hasattr(self, 'fc'):
|
||||
x = self.avgpool(x)
|
||||
# x = x.view(x.size(0), -1)
|
||||
x = torch.flatten(x, 1)
|
||||
if self.drop:
|
||||
x = self.drop(x)
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user