diff --git a/configs/classification/cifar10/r50_b128_300e_jpg.py b/configs/classification/cifar10/r50_b128_300e_jpg.py index fa628304..3099c312 100644 --- a/configs/classification/cifar10/r50_b128_300e_jpg.py +++ b/configs/classification/cifar10/r50_b128_300e_jpg.py @@ -2,7 +2,6 @@ _base_ = '../../base.py' # model settings model = dict( type='Classification', - pretrained=None, backbone=dict( type='ResNet', depth=50, diff --git a/configs/classification/cifar10/swintiny_b64_5e_jpg.py b/configs/classification/cifar10/swintiny_b64_5e_jpg.py index 695a9b31..dffb32a7 100644 --- a/configs/classification/cifar10/swintiny_b64_5e_jpg.py +++ b/configs/classification/cifar10/swintiny_b64_5e_jpg.py @@ -2,7 +2,6 @@ _base_ = '../../base.py' # model settings model = dict( type='Classification', - pretrained=None, backbone=dict( type='PytorchImageModelWrapper', model_name='swin_tiny_patch4_window7_224', diff --git a/configs/classification/imagenet/hrnet/hrnetw18_b32x8_100e_jpg.py b/configs/classification/imagenet/hrnet/hrnetw18_b32x8_100e_jpg.py index bc14d535..104db80e 100644 --- a/configs/classification/imagenet/hrnet/hrnetw18_b32x8_100e_jpg.py +++ b/configs/classification/imagenet/hrnet/hrnetw18_b32x8_100e_jpg.py @@ -7,7 +7,6 @@ log_config = dict( # model settings model = dict( type='Classification', - pretrained=None, backbone=dict(type='HRNet', arch='w18', multi_scale_output=True), neck=dict(type='HRFuseScales', in_channels=(18, 36, 72, 144)), head=dict( diff --git a/configs/classification/imagenet/hrnet/imagenet_hrnetw30_jpg.py b/configs/classification/imagenet/hrnet/imagenet_hrnetw30_jpg.py index e308be7e..2135bb06 100644 --- a/configs/classification/imagenet/hrnet/imagenet_hrnetw30_jpg.py +++ b/configs/classification/imagenet/hrnet/imagenet_hrnetw30_jpg.py @@ -1,3 +1,5 @@ _base_ = './hrnetw18_b32x8_100e_jpg.py' # model settings -model = dict(backbone=dict(type='HRNet', arch='w30', multi_scale_output=True)) +model = dict( + backbone=dict(type='HRNet', arch='w30', multi_scale_output=True), + neck=dict(type='HRFuseScales', in_channels=(30, 60, 120, 240))) diff --git a/configs/classification/imagenet/hrnet/imagenet_hrnetw32_jpg.py b/configs/classification/imagenet/hrnet/imagenet_hrnetw32_jpg.py index f683df9e..1c1192bc 100644 --- a/configs/classification/imagenet/hrnet/imagenet_hrnetw32_jpg.py +++ b/configs/classification/imagenet/hrnet/imagenet_hrnetw32_jpg.py @@ -1,3 +1,5 @@ _base_ = './hrnetw18_b32x8_100e_jpg.py' # model settings -model = dict(backbone=dict(type='HRNet', arch='w32', multi_scale_output=True)) +model = dict( + backbone=dict(type='HRNet', arch='w32', multi_scale_output=True), + neck=dict(type='HRFuseScales', in_channels=(32, 64, 128, 256))) diff --git a/configs/classification/imagenet/hrnet/imagenet_hrnetw40_jpg.py b/configs/classification/imagenet/hrnet/imagenet_hrnetw40_jpg.py index b940fe33..d77c4c30 100644 --- a/configs/classification/imagenet/hrnet/imagenet_hrnetw40_jpg.py +++ b/configs/classification/imagenet/hrnet/imagenet_hrnetw40_jpg.py @@ -1,3 +1,5 @@ _base_ = './hrnetw18_b32x8_100e_jpg.py' # model settings -model = dict(backbone=dict(type='HRNet', arch='w40', multi_scale_output=True)) +model = dict( + backbone=dict(type='HRNet', arch='w40', multi_scale_output=True), + neck=dict(type='HRFuseScales', in_channels=(40, 80, 160, 320))) diff --git a/configs/classification/imagenet/hrnet/imagenet_hrnetw44_jpg.py b/configs/classification/imagenet/hrnet/imagenet_hrnetw44_jpg.py index 36b106da..50fc9bd8 100644 --- a/configs/classification/imagenet/hrnet/imagenet_hrnetw44_jpg.py +++ 
b/configs/classification/imagenet/hrnet/imagenet_hrnetw44_jpg.py @@ -1,3 +1,5 @@ _base_ = './hrnetw18_b32x8_100e_jpg.py' # model settings -model = dict(backbone=dict(type='HRNet', arch='w44', multi_scale_output=True)) +model = dict( + backbone=dict(type='HRNet', arch='w44', multi_scale_output=True), + neck=dict(type='HRFuseScales', in_channels=(44, 88, 176, 352))) diff --git a/configs/classification/imagenet/hrnet/imagenet_hrnetw48_jpg.py b/configs/classification/imagenet/hrnet/imagenet_hrnetw48_jpg.py index 16cda45e..c4d60d3f 100644 --- a/configs/classification/imagenet/hrnet/imagenet_hrnetw48_jpg.py +++ b/configs/classification/imagenet/hrnet/imagenet_hrnetw48_jpg.py @@ -1,4 +1,6 @@ _base_ = './hrnetw18_b32x8_100e_jpg.py' # model settings -model = dict(backbone=dict(type='HRNet', arch='w48', multi_scale_output=True)) +model = dict( + backbone=dict(type='HRNet', arch='w48', multi_scale_output=True), + neck=dict(type='HRFuseScales', in_channels=(48, 96, 192, 384))) diff --git a/configs/classification/imagenet/hrnet/imagenet_hrnetw64_jpg.py b/configs/classification/imagenet/hrnet/imagenet_hrnetw64_jpg.py index 36ea6ae1..133e2a1e 100644 --- a/configs/classification/imagenet/hrnet/imagenet_hrnetw64_jpg.py +++ b/configs/classification/imagenet/hrnet/imagenet_hrnetw64_jpg.py @@ -1,3 +1,5 @@ _base_ = './hrnetw18_b32x8_100e_jpg.py' # model settings -model = dict(backbone=dict(type='HRNet', arch='w64', multi_scale_output=True)) +model = dict( + backbone=dict(type='HRNet', arch='w64', multi_scale_output=True), + neck=dict(type='HRFuseScales', in_channels=(64, 128, 256, 512))) diff --git a/configs/classification/imagenet/resnet/imagenet_resnet101_jpg.py b/configs/classification/imagenet/resnet/imagenet_resnet101_jpg.py index 32ad3512..cc16641b 100644 --- a/configs/classification/imagenet/resnet/imagenet_resnet101_jpg.py +++ b/configs/classification/imagenet/resnet/imagenet_resnet101_jpg.py @@ -2,18 +2,8 @@ _base_ = './resnet50_b32x8_100e_jpg.py' # model settings model = dict( type='Classification', - pretrained=None, backbone=dict( type='ResNet', depth=101, out_indices=[4], # 0: conv-1, x: stage-x - norm_cfg=dict(type='BN')), - head=dict( - type='ClsHead', - with_avg_pool=True, - in_channels=2048, - loss_config=dict( - type='CrossEntropyLossWithLabelSmooth', - label_smooth=0, - ), - num_classes=1000)) + norm_cfg=dict(type='BN'))) diff --git a/configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py b/configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py index 11f727d5..99f2d169 100644 --- a/configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py +++ b/configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py @@ -1,19 +1,8 @@ _base_ = './resnet50_b32x8_100e_jpg.py' # model settings model = dict( - type='Classification', - pretrained=None, backbone=dict( type='ResNet', depth=50, out_indices=[4], # 0: conv-1, x: stage-x - norm_cfg=dict(type='BN')), - head=dict( - type='ClsHead', - with_avg_pool=True, - in_channels=2048, - loss_config=dict( - type='CrossEntropyLossWithLabelSmooth', - label_smooth=0, - ), - num_classes=1000)) + norm_cfg=dict(type='BN'))) diff --git a/configs/classification/imagenet/resnet/resnet50_b32x8_100e_jpg.py b/configs/classification/imagenet/resnet/resnet50_b32x8_100e_jpg.py index d788bcaf..03124f20 100644 --- a/configs/classification/imagenet/resnet/resnet50_b32x8_100e_jpg.py +++ b/configs/classification/imagenet/resnet/resnet50_b32x8_100e_jpg.py @@ -7,7 +7,6 @@ log_config = dict( # model settings model = dict( type='Classification', - 
pretrained=None, backbone=dict( type='ResNet', depth=50, diff --git a/configs/classification/imagenet/resnext/resnext50-32x4d_b32x8_100e_jpg.py b/configs/classification/imagenet/resnext/resnext50-32x4d_b32x8_100e_jpg.py index b0ad9c51..22f0f7c1 100644 --- a/configs/classification/imagenet/resnext/resnext50-32x4d_b32x8_100e_jpg.py +++ b/configs/classification/imagenet/resnext/resnext50-32x4d_b32x8_100e_jpg.py @@ -7,7 +7,6 @@ log_config = dict( # model settings model = dict( type='Classification', - pretrained=None, backbone=dict( type='ResNeXt', depth=50, diff --git a/configs/classification/imagenet/swint/swin_tiny_patch4_window7_224_b64x16_300e_jpg.py b/configs/classification/imagenet/swint/swin_tiny_patch4_window7_224_b64x16_300e_jpg.py index 4fcb7a21..cbb472f9 100644 --- a/configs/classification/imagenet/swint/swin_tiny_patch4_window7_224_b64x16_300e_jpg.py +++ b/configs/classification/imagenet/swint/swin_tiny_patch4_window7_224_b64x16_300e_jpg.py @@ -16,7 +16,6 @@ model = dict( mode='batch', label_smoothing=0.1, num_classes=1000), - pretrained=None, backbone=dict( type='PytorchImageModelWrapper', model_name='swin_tiny_patch4_window7_224', diff --git a/configs/classification/imagenet/timm/cait/cait_s24_224.py b/configs/classification/imagenet/timm/cait/cait_s24_224.py new file mode 100644 index 00000000..763fa3b4 --- /dev/null +++ b/configs/classification/imagenet/timm/cait/cait_s24_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='cait_s24_224')) diff --git a/configs/classification/imagenet/timm/cait/cait_xxs24_224.py b/configs/classification/imagenet/timm/cait/cait_xxs24_224.py new file mode 100644 index 00000000..84206b8e --- /dev/null +++ b/configs/classification/imagenet/timm/cait/cait_xxs24_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='cait_xxs24_224')) diff --git a/configs/classification/imagenet/timm/cait/cait_xxs36_224.py b/configs/classification/imagenet/timm/cait/cait_xxs36_224.py new file mode 100644 index 00000000..d509f4a8 --- /dev/null +++ b/configs/classification/imagenet/timm/cait/cait_xxs36_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='cait_xxs36_224')) diff --git a/configs/classification/imagenet/timm/coat/coat_mini.py b/configs/classification/imagenet/timm/coat/coat_mini.py new file mode 100644 index 00000000..af77e823 --- /dev/null +++ b/configs/classification/imagenet/timm/coat/coat_mini.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='coat_mini')) diff --git a/configs/classification/imagenet/timm/coat/coat_tiny.py b/configs/classification/imagenet/timm/coat/coat_tiny.py new file mode 100644 index 00000000..0a9bca68 --- /dev/null +++ b/configs/classification/imagenet/timm/coat/coat_tiny.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='coat_tiny')) diff --git a/configs/classification/imagenet/timm/convit/convit_base.py b/configs/classification/imagenet/timm/convit/convit_base.py new file mode 100644 index 00000000..c27b8768 --- /dev/null +++ b/configs/classification/imagenet/timm/convit/convit_base.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='convit_base')) diff --git a/configs/classification/imagenet/timm/convit/convit_small.py b/configs/classification/imagenet/timm/convit/convit_small.py new 
file mode 100644 index 00000000..f63e8411 --- /dev/null +++ b/configs/classification/imagenet/timm/convit/convit_small.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='convit_small')) diff --git a/configs/classification/imagenet/timm/convit/convit_tiny.py b/configs/classification/imagenet/timm/convit/convit_tiny.py new file mode 100644 index 00000000..3ee9525f --- /dev/null +++ b/configs/classification/imagenet/timm/convit/convit_tiny.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='convit_tiny')) diff --git a/configs/classification/imagenet/timm/convmixer/convmixer_1024_20_ks9_p14.py b/configs/classification/imagenet/timm/convmixer/convmixer_1024_20_ks9_p14.py new file mode 100644 index 00000000..4eef59e4 --- /dev/null +++ b/configs/classification/imagenet/timm/convmixer/convmixer_1024_20_ks9_p14.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='convmixer_1024_20_ks9_p14')) diff --git a/configs/classification/imagenet/timm/convmixer/convmixer_1536_20.py b/configs/classification/imagenet/timm/convmixer/convmixer_1536_20.py new file mode 100644 index 00000000..9dfa25dc --- /dev/null +++ b/configs/classification/imagenet/timm/convmixer/convmixer_1536_20.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='convmixer_1536_20')) diff --git a/configs/classification/imagenet/timm/convmixer/convmixer_768_32.py b/configs/classification/imagenet/timm/convmixer/convmixer_768_32.py new file mode 100644 index 00000000..4c14a249 --- /dev/null +++ b/configs/classification/imagenet/timm/convmixer/convmixer_768_32.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='convmixer_768_32')) diff --git a/configs/classification/imagenet/timm/convnext/convnext_base.py b/configs/classification/imagenet/timm/convnext/convnext_base.py new file mode 100644 index 00000000..39c941ce --- /dev/null +++ b/configs/classification/imagenet/timm/convnext/convnext_base.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='convnext_base')) diff --git a/configs/classification/imagenet/timm/convnext/convnext_large.py b/configs/classification/imagenet/timm/convnext/convnext_large.py new file mode 100644 index 00000000..2cc138f2 --- /dev/null +++ b/configs/classification/imagenet/timm/convnext/convnext_large.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='convnext_large')) diff --git a/configs/classification/imagenet/timm/convnext/convnext_small.py b/configs/classification/imagenet/timm/convnext/convnext_small.py new file mode 100644 index 00000000..e003323f --- /dev/null +++ b/configs/classification/imagenet/timm/convnext/convnext_small.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='convnext_small')) diff --git a/configs/classification/imagenet/timm/convnext/convnext_tiny.py b/configs/classification/imagenet/timm/convnext/convnext_tiny.py new file mode 100644 index 00000000..c4eb09de --- /dev/null +++ b/configs/classification/imagenet/timm/convnext/convnext_tiny.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='convnext_tiny')) diff --git a/configs/classification/imagenet/timm/crossvit/crossvit_base_240.py 
b/configs/classification/imagenet/timm/crossvit/crossvit_base_240.py new file mode 100644 index 00000000..223476a2 --- /dev/null +++ b/configs/classification/imagenet/timm/crossvit/crossvit_base_240.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='crossvit_base_240')) diff --git a/configs/classification/imagenet/timm/crossvit/crossvit_small_240.py b/configs/classification/imagenet/timm/crossvit/crossvit_small_240.py new file mode 100644 index 00000000..d7202787 --- /dev/null +++ b/configs/classification/imagenet/timm/crossvit/crossvit_small_240.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='crossvit_small_240')) diff --git a/configs/classification/imagenet/timm/crossvit/crossvit_tiny_240.py b/configs/classification/imagenet/timm/crossvit/crossvit_tiny_240.py new file mode 100644 index 00000000..d385a7c9 --- /dev/null +++ b/configs/classification/imagenet/timm/crossvit/crossvit_tiny_240.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='crossvit_tiny_240')) diff --git a/configs/classification/imagenet/timm/deit/deit_base_distilled_patch16_224.py b/configs/classification/imagenet/timm/deit/deit_base_distilled_patch16_224.py new file mode 100644 index 00000000..cf5b7558 --- /dev/null +++ b/configs/classification/imagenet/timm/deit/deit_base_distilled_patch16_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='deit_base_distilled_patch16_224')) diff --git a/configs/classification/imagenet/timm/deit/deit_base_patch16_224.py b/configs/classification/imagenet/timm/deit/deit_base_patch16_224.py new file mode 100644 index 00000000..4f751769 --- /dev/null +++ b/configs/classification/imagenet/timm/deit/deit_base_patch16_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='deit_base_patch16_224')) diff --git a/configs/classification/imagenet/timm/gmixer/gmixer_24_224.py b/configs/classification/imagenet/timm/gmixer/gmixer_24_224.py new file mode 100644 index 00000000..c2a6e487 --- /dev/null +++ b/configs/classification/imagenet/timm/gmixer/gmixer_24_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='gmixer_24_224')) diff --git a/configs/classification/imagenet/timm/gmlp/gmlp_s16_224.py b/configs/classification/imagenet/timm/gmlp/gmlp_s16_224.py new file mode 100644 index 00000000..852b2c0e --- /dev/null +++ b/configs/classification/imagenet/timm/gmlp/gmlp_s16_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='gmlp_s16_224')) diff --git a/configs/classification/imagenet/timm/levit/levit_128.py b/configs/classification/imagenet/timm/levit/levit_128.py new file mode 100644 index 00000000..405e1004 --- /dev/null +++ b/configs/classification/imagenet/timm/levit/levit_128.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='levit_128')) diff --git a/configs/classification/imagenet/timm/levit/levit_192.py b/configs/classification/imagenet/timm/levit/levit_192.py new file mode 100644 index 00000000..affb90eb --- /dev/null +++ b/configs/classification/imagenet/timm/levit/levit_192.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='levit_192')) diff --git 
a/configs/classification/imagenet/timm/levit/levit_256.py b/configs/classification/imagenet/timm/levit/levit_256.py new file mode 100644 index 00000000..9b5684e1 --- /dev/null +++ b/configs/classification/imagenet/timm/levit/levit_256.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='levit_256')) diff --git a/configs/classification/imagenet/timm/mlp-mixer/mixer_b16_224.py b/configs/classification/imagenet/timm/mlp-mixer/mixer_b16_224.py new file mode 100644 index 00000000..89125a16 --- /dev/null +++ b/configs/classification/imagenet/timm/mlp-mixer/mixer_b16_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='mixer_b16_224')) diff --git a/configs/classification/imagenet/timm/mlp-mixer/mixer_l16_224.py b/configs/classification/imagenet/timm/mlp-mixer/mixer_l16_224.py new file mode 100644 index 00000000..40912a23 --- /dev/null +++ b/configs/classification/imagenet/timm/mlp-mixer/mixer_l16_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='mixer_l16_224')) diff --git a/configs/classification/imagenet/timm/mobilevit/mobilevit_s.py b/configs/classification/imagenet/timm/mobilevit/mobilevit_s.py new file mode 100644 index 00000000..98a6805a --- /dev/null +++ b/configs/classification/imagenet/timm/mobilevit/mobilevit_s.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='mobilevit_s')) diff --git a/configs/classification/imagenet/timm/mobilevit/mobilevit_xs.py b/configs/classification/imagenet/timm/mobilevit/mobilevit_xs.py new file mode 100644 index 00000000..827f2401 --- /dev/null +++ b/configs/classification/imagenet/timm/mobilevit/mobilevit_xs.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='mobilevit_xs')) diff --git a/configs/classification/imagenet/timm/mobilevit/mobilevit_xxs.py b/configs/classification/imagenet/timm/mobilevit/mobilevit_xxs.py new file mode 100644 index 00000000..65b200a3 --- /dev/null +++ b/configs/classification/imagenet/timm/mobilevit/mobilevit_xxs.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='mobilevit_xxs')) diff --git a/configs/classification/imagenet/timm/nest/jx_nest_base.py b/configs/classification/imagenet/timm/nest/jx_nest_base.py new file mode 100644 index 00000000..261ff3ca --- /dev/null +++ b/configs/classification/imagenet/timm/nest/jx_nest_base.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='jx_nest_base')) diff --git a/configs/classification/imagenet/timm/nest/jx_nest_small.py b/configs/classification/imagenet/timm/nest/jx_nest_small.py new file mode 100644 index 00000000..dc4a2e71 --- /dev/null +++ b/configs/classification/imagenet/timm/nest/jx_nest_small.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='jx_nest_small')) diff --git a/configs/classification/imagenet/timm/nest/jx_nest_tiny.py b/configs/classification/imagenet/timm/nest/jx_nest_tiny.py new file mode 100644 index 00000000..49a90599 --- /dev/null +++ b/configs/classification/imagenet/timm/nest/jx_nest_tiny.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='jx_nest_tiny')) diff --git a/configs/classification/imagenet/timm/pit/pit_b_distilled_224.py 
b/configs/classification/imagenet/timm/pit/pit_b_distilled_224.py new file mode 100644 index 00000000..c726aa50 --- /dev/null +++ b/configs/classification/imagenet/timm/pit/pit_b_distilled_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='pit_b_distilled_224')) diff --git a/configs/classification/imagenet/timm/pit/pit_s_distilled_224.py b/configs/classification/imagenet/timm/pit/pit_s_distilled_224.py new file mode 100644 index 00000000..1f0b72e0 --- /dev/null +++ b/configs/classification/imagenet/timm/pit/pit_s_distilled_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='pit_s_distilled_224')) diff --git a/configs/classification/imagenet/timm/poolformer/poolformer_m36.py b/configs/classification/imagenet/timm/poolformer/poolformer_m36.py new file mode 100644 index 00000000..59842d12 --- /dev/null +++ b/configs/classification/imagenet/timm/poolformer/poolformer_m36.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='poolformer_m36')) diff --git a/configs/classification/imagenet/timm/poolformer/poolformer_m48.py b/configs/classification/imagenet/timm/poolformer/poolformer_m48.py new file mode 100644 index 00000000..162c1b6a --- /dev/null +++ b/configs/classification/imagenet/timm/poolformer/poolformer_m48.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='poolformer_m48')) diff --git a/configs/classification/imagenet/timm/poolformer/poolformer_s12.py b/configs/classification/imagenet/timm/poolformer/poolformer_s12.py new file mode 100644 index 00000000..2ec5c74c --- /dev/null +++ b/configs/classification/imagenet/timm/poolformer/poolformer_s12.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='poolformer_s12')) diff --git a/configs/classification/imagenet/timm/poolformer/poolformer_s24.py b/configs/classification/imagenet/timm/poolformer/poolformer_s24.py new file mode 100644 index 00000000..153b0ad9 --- /dev/null +++ b/configs/classification/imagenet/timm/poolformer/poolformer_s24.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='poolformer_s24')) diff --git a/configs/classification/imagenet/timm/poolformer/poolformer_s36.py b/configs/classification/imagenet/timm/poolformer/poolformer_s36.py new file mode 100644 index 00000000..5a175875 --- /dev/null +++ b/configs/classification/imagenet/timm/poolformer/poolformer_s36.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='poolformer_s36')) diff --git a/configs/classification/imagenet/timm/resmlp/resmlp_12_distilled_224.py b/configs/classification/imagenet/timm/resmlp/resmlp_12_distilled_224.py new file mode 100644 index 00000000..46756407 --- /dev/null +++ b/configs/classification/imagenet/timm/resmlp/resmlp_12_distilled_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='resmlp_12_distilled_224')) diff --git a/configs/classification/imagenet/timm/resmlp/resmlp_24_distilled_224.py b/configs/classification/imagenet/timm/resmlp/resmlp_24_distilled_224.py new file mode 100644 index 00000000..2a34879d --- /dev/null +++ b/configs/classification/imagenet/timm/resmlp/resmlp_24_distilled_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = 
dict(backbone=dict(model_name='resmlp_24_distilled_224')) diff --git a/configs/classification/imagenet/timm/resmlp/resmlp_36_distilled_224.py b/configs/classification/imagenet/timm/resmlp/resmlp_36_distilled_224.py new file mode 100644 index 00000000..6cddde73 --- /dev/null +++ b/configs/classification/imagenet/timm/resmlp/resmlp_36_distilled_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='resmlp_36_distilled_224')) diff --git a/configs/classification/imagenet/timm/resmlp/resmlp_big_24_distilled_224.py b/configs/classification/imagenet/timm/resmlp/resmlp_big_24_distilled_224.py new file mode 100644 index 00000000..e604c5f4 --- /dev/null +++ b/configs/classification/imagenet/timm/resmlp/resmlp_big_24_distilled_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='resmlp_big_24_distilled_224')) diff --git a/configs/classification/imagenet/timm/sequencer/sequencer2d_l.py b/configs/classification/imagenet/timm/sequencer/sequencer2d_l.py new file mode 100644 index 00000000..90dac68c --- /dev/null +++ b/configs/classification/imagenet/timm/sequencer/sequencer2d_l.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='sequencer2d_l')) diff --git a/configs/classification/imagenet/timm/sequencer/sequencer2d_m.py b/configs/classification/imagenet/timm/sequencer/sequencer2d_m.py new file mode 100644 index 00000000..c2b745c2 --- /dev/null +++ b/configs/classification/imagenet/timm/sequencer/sequencer2d_m.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='sequencer2d_m')) diff --git a/configs/classification/imagenet/timm/sequencer/sequencer2d_s.py b/configs/classification/imagenet/timm/sequencer/sequencer2d_s.py new file mode 100644 index 00000000..d064c395 --- /dev/null +++ b/configs/classification/imagenet/timm/sequencer/sequencer2d_s.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='sequencer2d_s')) diff --git a/configs/classification/imagenet/timm/shuffle_transformer/shuffletrans_base_p4_w7_224.py b/configs/classification/imagenet/timm/shuffle_transformer/shuffletrans_base_p4_w7_224.py new file mode 100644 index 00000000..17a3522e --- /dev/null +++ b/configs/classification/imagenet/timm/shuffle_transformer/shuffletrans_base_p4_w7_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='shuffletrans_base_p4_w7_224')) diff --git a/configs/classification/imagenet/timm/shuffle_transformer/shuffletrans_small_p4_w7_224.py b/configs/classification/imagenet/timm/shuffle_transformer/shuffletrans_small_p4_w7_224.py new file mode 100644 index 00000000..c467f700 --- /dev/null +++ b/configs/classification/imagenet/timm/shuffle_transformer/shuffletrans_small_p4_w7_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='shuffletrans_small_p4_w7_224')) diff --git a/configs/classification/imagenet/timm/shuffle_transformer/shuffletrans_tiny_p4_w7_224.py b/configs/classification/imagenet/timm/shuffle_transformer/shuffletrans_tiny_p4_w7_224.py new file mode 100644 index 00000000..30b1db3c --- /dev/null +++ b/configs/classification/imagenet/timm/shuffle_transformer/shuffletrans_tiny_p4_w7_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = 
dict(backbone=dict(model_name='shuffletrans_tiny_p4_w7_224')) diff --git a/configs/classification/imagenet/timm/swint/dynamic_swin_small_p4_w7_224.py b/configs/classification/imagenet/timm/swint/dynamic_swin_small_p4_w7_224.py new file mode 100644 index 00000000..bc99874c --- /dev/null +++ b/configs/classification/imagenet/timm/swint/dynamic_swin_small_p4_w7_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='dynamic_swin_small_p4_w7_224')) diff --git a/configs/classification/imagenet/timm/swint/dynamic_swin_tiny_p4_w7_224.py b/configs/classification/imagenet/timm/swint/dynamic_swin_tiny_p4_w7_224.py new file mode 100644 index 00000000..a4c8d187 --- /dev/null +++ b/configs/classification/imagenet/timm/swint/dynamic_swin_tiny_p4_w7_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='dynamic_swin_tiny_p4_w7_224')) diff --git a/configs/classification/imagenet/timm/swint/swin_base_patch4_window7_224.py b/configs/classification/imagenet/timm/swint/swin_base_patch4_window7_224.py new file mode 100644 index 00000000..1bae92d1 --- /dev/null +++ b/configs/classification/imagenet/timm/swint/swin_base_patch4_window7_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='swin_base_patch4_window7_224')) diff --git a/configs/classification/imagenet/timm/swint/swin_large_patch4_window7_224.py b/configs/classification/imagenet/timm/swint/swin_large_patch4_window7_224.py new file mode 100644 index 00000000..50441669 --- /dev/null +++ b/configs/classification/imagenet/timm/swint/swin_large_patch4_window7_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='swin_large_patch4_window7_224')) diff --git a/configs/classification/imagenet/timm/imagenet_transformer_jpg.py b/configs/classification/imagenet/timm/timm_config.py similarity index 59% rename from configs/classification/imagenet/timm/imagenet_transformer_jpg.py rename to configs/classification/imagenet/timm/timm_config.py index b8a84f7a..562a77ad 100644 --- a/configs/classification/imagenet/timm/imagenet_transformer_jpg.py +++ b/configs/classification/imagenet/timm/timm_config.py @@ -1,45 +1,48 @@ _base_ = 'configs/base.py' + log_config = dict( - interval=100, + interval=10, hooks=[dict(type='TextLoggerHook'), dict(type='TensorboardLoggerHook')]) # model settings model = dict( type='Classification', - pretrained=None, + train_preprocess=['mixUp'], + pretrained=True, + mixup_cfg=dict( + mixup_alpha=0.2, + prob=1.0, + mode='batch', + label_smoothing=0.1, + num_classes=1000), backbone=dict( type='PytorchImageModelWrapper', - # model_name='pit_xs_distilled_224', - # model_name='swin_small_patch4_window7_224', - # model_name='swin_tiny_patch4_window7_224', - # model_name='swin_base_patch4_window7_224_in22k', - model_name='vit_deit_small_distilled_patch16_224', - # model_name = 'vit_deit_small_distilled_patch16_224', - # model_name = 'resnet50', + model_name='vit_base_patch16_224', num_classes=1000, - pretrained=True, ), head=dict( - # type='ClsHead', with_avg_pool=True, in_channels=384, - # type='ClsHead', with_avg_pool=True, in_channels=768, - # type='ClsHead', with_avg_pool=True, in_channels=1024, - # num_classes=0) type='ClsHead', + loss_config={ + 'type': 'SoftTargetCrossEntropy', + }, with_fc=False)) -data_train_list = 'data/imagenet_raw/meta/train_labeled.txt' -data_train_root = 'data/imagenet_raw/train/' 
-data_test_list = 'data/imagenet_raw/meta/val_labeled.txt' -data_test_root = 'data/imagenet_raw/validation/' -data_all_list = 'data/imagenet_raw/meta/all_labeled.txt' -data_root = 'data/imagenet_raw/' +load_from = None + +data_train_list = './imagenet_raw/meta/train_labeled.txt' +data_train_root = './imagenet_raw/train/' +data_test_list = './imagenet_raw/meta/val_labeled.txt' +data_test_root = './imagenet_raw/val/' +data_all_list = './imagenet_raw/meta/all_labeled.txt' +data_root = './imagenet_raw/' dataset_type = 'ClsDataset' img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) train_pipeline = [ dict(type='RandomResizedCrop', size=224), dict(type='RandomHorizontalFlip'), + dict(type='MMAutoAugment'), dict(type='ToTensor'), dict(type='Normalize', **img_norm_cfg), dict(type='Collect', keys=['img', 'gt_labels']) @@ -53,7 +56,7 @@ test_pipeline = [ ] data = dict( - imgs_per_gpu=32, # total 256 + imgs_per_gpu=64, # total 256 workers_per_gpu=8, train=dict( type=dataset_type, @@ -84,11 +87,25 @@ eval_pipelines = [ custom_hooks = [] # optimizer -optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001) +optimizer = dict( + type='AdamW', + lr=0.003, + weight_decay=0.3, + paramwise_options={ + 'cls_token': dict(weight_decay=0.), + 'pos_embed': dict(weight_decay=0.), + }) +optimizer_config = dict(grad_clip=dict(max_norm=1.0), update_interval=8) # learning policy -lr_config = dict(policy='step', step=[30, 60, 90]) -checkpoint_config = dict(interval=10) +lr_config = dict( + policy='CosineAnnealing', + min_lr=0, + warmup='linear', + warmup_iters=10000, + warmup_ratio=1e-4, +) +checkpoint_config = dict(interval=30) # runtime settings total_epochs = 90 diff --git a/configs/classification/imagenet/timm/tnt/tnt_s_patch16_224.py b/configs/classification/imagenet/timm/tnt/tnt_s_patch16_224.py new file mode 100644 index 00000000..47b55cc7 --- /dev/null +++ b/configs/classification/imagenet/timm/tnt/tnt_s_patch16_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='tnt_s_patch16_224')) diff --git a/configs/classification/imagenet/timm/twins/twins_svt_base.py b/configs/classification/imagenet/timm/twins/twins_svt_base.py new file mode 100644 index 00000000..a388b0d0 --- /dev/null +++ b/configs/classification/imagenet/timm/twins/twins_svt_base.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='twins_svt_base')) diff --git a/configs/classification/imagenet/timm/twins/twins_svt_large.py b/configs/classification/imagenet/timm/twins/twins_svt_large.py new file mode 100644 index 00000000..eaf1fcb3 --- /dev/null +++ b/configs/classification/imagenet/timm/twins/twins_svt_large.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='twins_svt_large')) diff --git a/configs/classification/imagenet/timm/twins/twins_svt_small.py b/configs/classification/imagenet/timm/twins/twins_svt_small.py new file mode 100644 index 00000000..c5c1c118 --- /dev/null +++ b/configs/classification/imagenet/timm/twins/twins_svt_small.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='twins_svt_small')) diff --git a/configs/classification/imagenet/timm/vit/vit_base_patch16_224.py b/configs/classification/imagenet/timm/vit/vit_base_patch16_224.py new file mode 100644 index 00000000..b4d60077 --- /dev/null +++ b/configs/classification/imagenet/timm/vit/vit_base_patch16_224.py @@ -0,0 
+1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='vit_base_patch16_224')) diff --git a/configs/classification/imagenet/timm/vit/vit_large_patch16_224.py b/configs/classification/imagenet/timm/vit/vit_large_patch16_224.py new file mode 100644 index 00000000..aa2dfeaa --- /dev/null +++ b/configs/classification/imagenet/timm/vit/vit_large_patch16_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='vit_large_patch16_224')) diff --git a/configs/classification/imagenet/timm/xcit/xcit_large_24_p8_224.py b/configs/classification/imagenet/timm/xcit/xcit_large_24_p8_224.py new file mode 100644 index 00000000..fb026548 --- /dev/null +++ b/configs/classification/imagenet/timm/xcit/xcit_large_24_p8_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='xcit_large_24_p8_224')) diff --git a/configs/classification/imagenet/timm/xcit/xcit_large_24_p8_224_dist.py b/configs/classification/imagenet/timm/xcit/xcit_large_24_p8_224_dist.py new file mode 100644 index 00000000..83ee4fca --- /dev/null +++ b/configs/classification/imagenet/timm/xcit/xcit_large_24_p8_224_dist.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='xcit_large_24_p8_224_dist')) diff --git a/configs/classification/imagenet/timm/xcit/xcit_medium_24_p8_224.py b/configs/classification/imagenet/timm/xcit/xcit_medium_24_p8_224.py new file mode 100644 index 00000000..c56a5713 --- /dev/null +++ b/configs/classification/imagenet/timm/xcit/xcit_medium_24_p8_224.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='xcit_medium_24_p8_224')) diff --git a/configs/classification/imagenet/timm/xcit/xcit_medium_24_p8_224_dist.py b/configs/classification/imagenet/timm/xcit/xcit_medium_24_p8_224_dist.py new file mode 100644 index 00000000..0db87288 --- /dev/null +++ b/configs/classification/imagenet/timm/xcit/xcit_medium_24_p8_224_dist.py @@ -0,0 +1,4 @@ +_base_ = '../timm_config.py' + +# model settings +model = dict(backbone=dict(model_name='xcit_medium_24_p8_224_dist')) diff --git a/configs/classification/imagenet/vit/vit_base_patch16_224_b64x64_300e_jpg.py b/configs/classification/imagenet/vit/vit_base_patch16_224_b64x64_300e_jpg.py index e041198e..b4fe5e6b 100644 --- a/configs/classification/imagenet/vit/vit_base_patch16_224_b64x64_300e_jpg.py +++ b/configs/classification/imagenet/vit/vit_base_patch16_224_b64x64_300e_jpg.py @@ -15,7 +15,6 @@ model = dict( mode='batch', label_smoothing=0.1, num_classes=1000), - pretrained=None, backbone=dict( type='PytorchImageModelWrapper', model_name='vit_base_patch16_224', diff --git a/configs/config_templates/classification.py b/configs/config_templates/classification.py index 67fcd306..e3c14aae 100644 --- a/configs/config_templates/classification.py +++ b/configs/config_templates/classification.py @@ -12,7 +12,6 @@ export = dict(export_neck=True) # model settings model = dict( type='Classification', - pretrained=None, backbone=dict( type='ResNet', depth=50, diff --git a/configs/config_templates/classification_oss.py b/configs/config_templates/classification_oss.py index 5519d94b..76d52cc7 100644 --- a/configs/config_templates/classification_oss.py +++ b/configs/config_templates/classification_oss.py @@ -19,7 +19,6 @@ export = dict(export_neck=True) # model settings model = dict( type='Classification', - pretrained=None, backbone=dict( 
type='PytorchImageModelWrapper', # model_name='pit_xs_distilled_224', diff --git a/configs/config_templates/classification_tfrecord_oss.py b/configs/config_templates/classification_tfrecord_oss.py index 7ebd3e1b..bd8b4d4d 100644 --- a/configs/config_templates/classification_tfrecord_oss.py +++ b/configs/config_templates/classification_tfrecord_oss.py @@ -19,7 +19,6 @@ export = dict(export_neck=True) # model settings model = dict( type='Classification', - pretrained=None, backbone=dict( type='ResNet', depth=50, diff --git a/configs/config_templates/metric_learning/modelparallel_softmaxbased_tfrecord_oss.py b/configs/config_templates/metric_learning/modelparallel_softmaxbased_tfrecord_oss.py index 486f2f33..0bcf9c0f 100644 --- a/configs/config_templates/metric_learning/modelparallel_softmaxbased_tfrecord_oss.py +++ b/configs/config_templates/metric_learning/modelparallel_softmaxbased_tfrecord_oss.py @@ -44,7 +44,6 @@ work_dir = 'oss://path/to/work_dirs/classification/' # model settings model = dict( type='Classification', - # pretrained=None, backbone=dict( type='PytorchImageModelWrapper', diff --git a/configs/config_templates/metric_learning/softmaxbased_tfrecord_oss.py b/configs/config_templates/metric_learning/softmaxbased_tfrecord_oss.py index 36fa4a3f..7ce3b7cf 100644 --- a/configs/config_templates/metric_learning/softmaxbased_tfrecord_oss.py +++ b/configs/config_templates/metric_learning/softmaxbased_tfrecord_oss.py @@ -49,7 +49,6 @@ work_dir = 'oss://path/to/work_dirs/classification/' # model settings model = dict( type='Classification', - # pretrained=None, backbone=dict( type='PytorchImageModelWrapper', diff --git a/configs/metric_learning/cub_resnet50_jpg.py b/configs/metric_learning/cub_resnet50_jpg.py index 64f5d125..82a774d3 100644 --- a/configs/metric_learning/cub_resnet50_jpg.py +++ b/configs/metric_learning/cub_resnet50_jpg.py @@ -7,7 +7,6 @@ log_config = dict( # model settings model = dict( type='Classification', - pretrained=None, backbone=dict( type='ResNet', depth=50, diff --git a/configs/metric_learning/imagenet_resnet50_1000kid_jpg.py b/configs/metric_learning/imagenet_resnet50_1000kid_jpg.py index cdde110a..e6af3303 100644 --- a/configs/metric_learning/imagenet_resnet50_1000kid_jpg.py +++ b/configs/metric_learning/imagenet_resnet50_1000kid_jpg.py @@ -7,7 +7,6 @@ log_config = dict( # model settings model = dict( type='Classification', - pretrained=None, backbone=dict( type='ResNet', depth=50, diff --git a/configs/metric_learning/sop_timm_swinb_local.py b/configs/metric_learning/sop_timm_swinb_local.py index 7cb04097..930976b5 100644 --- a/configs/metric_learning/sop_timm_swinb_local.py +++ b/configs/metric_learning/sop_timm_swinb_local.py @@ -13,9 +13,9 @@ load_from = None # model settings model = dict( type='Classification', + train_preprocess=['randomErasing'], pretrained= 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/timm/swin_base_patch4_window7_224_22k_statedict.pth', - train_preprocess=['randomErasing'], backbone=dict( type='PytorchImageModelWrapper', model_name='swin_base_patch4_window7_224_in22k' diff --git a/docs/source/model_zoo_cls.md b/docs/source/model_zoo_cls.md index eb33217f..b988d528 100644 --- a/docs/source/model_zoo_cls.md +++ b/docs/source/model_zoo_cls.md @@ -2,22 +2,77 @@ ## Benchmarks -| Algorithm | Config | Top-1 (%) | Top-5 (%) | Download | -| --------- | ------------------------------------------------------------ | --------- | --------- | ------------------------------------------------------------ | -| 
resnet50(raw) | [resnet50(raw)](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py) | 76.454 | 93.084 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/resnet50/epoch_100.pth) | -| resnet50(tfrecord) | [resnet50(tfrecord)](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnet/imagenet_rn50_tfrecord.py) | 76.266 | 92.972 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/resnet50/epoch_100.pth) | -| resnet101 | [resnet101](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnet/imagenet_resnet101_jpg.py) | 78.152 | 93.922 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/resnet101/epoch_100.pth) | -| resnet152 | [resnet152](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnet/imagenet_resnet152_jpg.py) | 78.544 | 94.206 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/resnet152/epoch_100.pth) | -| resnext50-32x4d | [resnext50-32x4d](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnext/imagenet_resnext50-32x4d_jpg.py) | 77.604 | 93.856 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnet50/epoch_100.pth) | -| resnext101-32x4d | [resnext101-32x4d](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnext/imagenet_resnext101-32x4d_jpg.py) | 78.568 | 94.344 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext50-32x4d/epoch_100.pth) | -| resnext101-32x8d | [resnext101-32x8d](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnext/imagenet_resnext101-32x8d_jpg.py) | 79.468 | 94.434 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext101-32x8d/epoch_100.pth) | -| resnext152-32x4d | [resnext152-32x4d](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnext/imagenet_resnext152-32x4d_jpg.py) | 78.994 | 94.462 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext152-32x4d/epoch_100.pth) | -| hrnetw18 | [hrnetw18](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw18_jpg.py) | 76.258 | 92.976 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw18/epoch_100.pth) | -| hrnetw30 | [hrnetw30](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw30_jpg.py) | 77.66 | 93.862 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw30/epoch_100.pth) | -| hrnetw32 | [hrnetw32](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw32_jpg.py) | 77.994 | 93.976 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw32/epoch_100.pth) | -| hrnetw40 | [hrnetw40](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw40_jpg.py) | 78.142 | 93.956 | 
[model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw40/epoch_100.pth) | -| hrnetw44 | [hrnetw44](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw44_jpg.py) | 79.266 | 94.476 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw44/epoch_100.pth) | -| hrnetw48 | [hrnetw48](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw48_jpg.py) | 79.636 | 94.802 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw48/epoch_100.pth) | -| hrnetw64 | [hrnetw64](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw64_jpg.py) | 79.884 | 95.04 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/hrnetw64/epoch_100.pth) | -| vit-base-patch16 | [vit-base-patch16](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/vit/imagenet_vit_base_patch16_224_jpg.py) | 76.082 | 92.026 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/vit/vit-base-patch16/epoch_300.pth) | -| swin-tiny-patch4-window7 | [swin-tiny-patch4-window7](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/swint/imagenet_swin_tiny_patch4_window7_224_jpg.py) | 80.528 | 94.822 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/swint/swin-tiny-patch4-window7/epoch_300.pth) | +| Algorithm | Config | Top-1 (%) | Top-5 (%) | gpu memory (MB) | inference time (ms/img) | Download | +| --------- | ------------------------------------------------------------ | --------- | --------- | --------- | --------- | ------------------------------------------------------------ | +| resnet50(raw) | [resnet50(raw)](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py) | 76.454 | 93.084 | 2412 | 8.59 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/resnet50/epoch_100.pth) | +| resnet50(tfrecord) | [resnet50(tfrecord)](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnet/imagenet_rn50_tfrecord.py) | 76.266 | 92.972 | 2412 | 8.59 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/resnet50/epoch_100.pth) | +| resnet101 | [resnet101](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnet/imagenet_resnet101_jpg.py) | 78.152 | 93.922 | 2484 | 16.77 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/resnet101/epoch_100.pth) | +| resnet152 | [resnet152](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnet/imagenet_resnet152_jpg.py) | 78.544 | 94.206 | 2544 | 24.69 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/resnet152/epoch_100.pth) | +| resnext50-32x4d | [resnext50-32x4d](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnext/imagenet_resnext50-32x4d_jpg.py) | 77.604 | 93.856 | 4718 | 12.88 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnet50/epoch_100.pth) | +| resnext101-32x4d | 
[resnext101-32x4d](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnext/imagenet_resnext101-32x4d_jpg.py) | 78.568 | 94.344 | 4792 | 26.84 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext50-32x4d/epoch_100.pth) | +| resnext101-32x8d | [resnext101-32x8d](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnext/imagenet_resnext101-32x8d_jpg.py) | 79.468 | 94.434 | 9582 | 27.52 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext101-32x8d/epoch_100.pth) | +| resnext152-32x4d | [resnext152-32x4d](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/resnext/imagenet_resnext152-32x4d_jpg.py) | 78.994 | 94.462 | 4852 | 41.08 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnext/resnext152-32x4d/epoch_100.pth) | +| hrnetw18 | [hrnetw18](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw18_jpg.py) | 76.258 | 92.976 | 4701 | 54.55 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw18/epoch_100.pth) | +| hrnetw30 | [hrnetw30](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw30_jpg.py) | 77.66 | 93.862 | 4766 | 54.30 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw30/epoch_100.pth) | +| hrnetw32 | [hrnetw32](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw32_jpg.py) | 77.994 | 93.976 | 4780 | 53.48 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw32/epoch_100.pth) | +| hrnetw40 | [hrnetw40](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw40_jpg.py) | 78.142 | 93.956 | 4843 | 54.31 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw40/epoch_100.pth) | +| hrnetw44 | [hrnetw44](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw44_jpg.py) | 79.266 | 94.476 | 4884 | 54.83 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw44/epoch_100.pth) | +| hrnetw48 | [hrnetw48](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw48_jpg.py) | 79.636 | 94.802 | 4916 | 54.14 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/hrnet/hrnetw48/epoch_100.pth) | +| hrnetw64 | [hrnetw64](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/hrnet/imagenet_hrnetw64_jpg.py) | 79.884 | 95.04 | 5120 | 54.74 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/resnet/hrnetw64/epoch_100.pth) | +| vit-base-patch16 | [vit-base-patch16](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/vit/imagenet_vit_base_patch16_224_jpg.py) | 76.082 | 92.026 | 346 | 8.03 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/vit/vit-base-patch16/epoch_300.pth) | +| swin-tiny-patch4-window7 |
[swin-tiny-patch4-window7](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/swint/imagenet_swin_tiny_patch4_window7_224_jpg.py) | 80.528 | 94.822 | 132 | 12.94 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/swint/swin-tiny-patch4-window7/epoch_300.pth) | + +(ps: results are from models trained with EasyCV; the inference input size defaults to 224 and the benchmark machine defaults to a V100 16G, where gpu memory records the GPU peak memory) + +| Algorithm | Config | Top-1 (%) | Top-5 (%) | gpu memory (MB) | inference time (ms/img) | Download | +| --------- | ------------------------------------------------------------ | --------- | --------- | --------- | --------- | ------------------------------------------------------------ | +| vit_base_patch16_224 | [vit_base_patch16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/vit/vit_base_patch16_224.py) | 78.096 | 94.324 | 346 | 8.03 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/vit/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz) | +| vit_large_patch16_224 | [vit_large_patch16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/vit/vit_large_patch16_224.py) | 84.404 | 97.276 | 1171 | 16.30 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/vit/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz) | +| deit_base_patch16_224 | [deit_base_patch16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/deit/deit_base_patch16_224.py) | 81.756 | 95.6 | 346 | 7.98 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/deit/deit_base_patch16_224-b5f2ef4d.pth) | +| deit_base_distilled_patch16_224 | [deit_base_distilled_patch16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/deit/deit_base_distilled_patch16_224.py) | 83.232 | 96.476 | 349 | 8.07 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/deit/deit_base_distilled_patch16_224-df68dfff.pth) | +| xcit_medium_24_p8_224 | [xcit_medium_24_p8_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/xcit/xcit_medium_24_p8_224.py) | 83.348 | 96.21 | 884 | 31.77 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/xcit/xcit_medium_24_p8_224.pth) | +| xcit_medium_24_p8_224_dist | [xcit_medium_24_p8_224_dist](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/xcit/xcit_medium_24_p8_224_dist.py) | 84.876 | 97.164 | 884 | 32.08 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/xcit/xcit_medium_24_p8_224_dist.pth) | +| xcit_large_24_p8_224 | [xcit_large_24_p8_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/xcit/xcit_large_24_p8_224.py) | 83.986 | 96.47 | 1962 | 37.44 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/xcit/xcit_large_24_p8_224.pth) | +| xcit_large_24_p8_224_dist | [xcit_large_24_p8_224_dist](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/xcit/xcit_large_24_p8_224_dist.py) | 85.022 | 97.29 | 1962 | 37.44 |
[model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/xcit/xcit_large_24_p8_224_dist.pth) | +| tnt_s_patch16_224 | [tnt_s_patch16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/tnt/tnt_s_patch16_224.py) | 76.934 | 93.388 | 100 | 18.92 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/tnt/tnt_s_patch16_224.pth.tar) | +| convit_tiny | [convit_tiny](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convit/convit_tiny.py) | 72.954 | 91.68 | 31 | 10.79 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convit/convit_tiny.pth) | +| convit_small | [convit_small](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convit/convit_small.py) | 81.342 | 95.784 | 122 | 11.23 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convit/convit_small.pth) | +| convit_base | [convit_base](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convit/convit_base.py) | 82.27 | 95.916 | 358 | 11.26 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convit/convit_base.pth) | +| cait_xxs24_224 | [cait_xxs24_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/cait/cait_xxs24_224.py) | 78.45 | 94.154 | 50 | 22.62 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/cait/XXS24_224.pth) | +| cait_xxs36_224 | [cait_xxs36_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/cait/cait_xxs36_224.py) | 79.788 | 94.87 | 71 | 33.25 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/cait/XXS36_224.pth) | +| cait_s24_224 | [cait_s24_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/cait/cait_s24_224.py) | 83.302 | 96.568 | 190 | 23.74 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/cait/cait_s24_224.pth) | +| levit_128 | [levit_128](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/levit/levit_128.py) | 78.468 | 93.874 | 76 | 15.33 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/levit/LeViT-128-b88c2750.pth) | +| levit_192 | [levit_192](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/levit/levit_192.py) | 79.72 | 94.664 | 128 | 15.17 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/levit/LeViT-192-92712e41.pth) | +| levit_256 | [levit_256](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/levit/levit_256.py) | 81.432 | 95.38 | 222 | 15.27 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/levit/LeViT-256-13b5763e.pth) | +| convnext_tiny | [convnext_tiny](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convnext/convnext_tiny.py) | 81.878 | 95.836 | 128 | 7.17 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convnext/convnext_tiny_1k_224_ema.pth) | +| convnext_small | 
[convnext_small](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convnext/convnext_small.py) | 82.836 | 96.458 | 213 | 12.89 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convnext/convnext_small_1k_224_ema.pth) | +| convnext_base | [convnext_base](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convnext/convnext_base.py) | 83.73 | 96.692 | 364 | 13.04 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convnext/convnext_base_1k_224_ema.pth) | +| convnext_large | [convnext_large](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convnext/convnext_large.py) | 84.164 | 96.844 | 781 | 13.78 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convnext/convnext_large_1k_224_ema.pth) | +| resmlp_12_distilled_224 | [resmlp_12_distilled_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/resmlp/resmlp_12_distilled_224.py) | 77.876 | 93.532 | 66 | 4.90 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/resmlp/resmlp_12_dist.pth) | +| resmlp_24_distilled_224 | [resmlp_24_distilled_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/resmlp/resmlp_24_distilled_224.py) | 80.548 | 95.204 | 124 | 9.07 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/resmlp/resmlp_24_dist.pth) | +| resmlp_36_distilled_224 | [resmlp_36_distilled_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/resmlp/resmlp_36_distilled_224.py) | 80.944 | 95.416 | 181 | 13.56 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/resmlp/resmlp_36_dist.pth) | +| resmlp_big_24_distilled_224 | [resmlp_big_24_distilled_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/resmlp/resmlp_big_24_distilled_224.py) | 83.45 | 96.65 | 534 | 20.48 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/resmlp/resmlpB_24_dist.pth) | +| coat_tiny | [coat_tiny](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/coat/coat_tiny.py) | 78.112 | 93.972 | 127 | 33.09 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/coat/coat_tiny-473c2a20.pth) | +| coat_mini | [coat_mini](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/coat/coat_mini.py) | 80.912 | 95.378 | 247 | 33.29 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/coat/coat_mini-2c6baf49.pth) | +| convmixer_768_32 | [convmixer_768_32](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convmixer/convmixer_768_32.py) | 80.08 | 94.992 | 4995 | 10.23 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convmixer/convmixer_768_32_ks7_p7_relu.pth.tar) | +| convmixer_1024_20_ks9_p14 | [convmixer_1024_20_ks9_p14](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convmixer/convmixer_1024_20_ks9_p14.py) | 81.742 | 95.578 | 2407 | 6.29 | 
[model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convmixer/convmixer_1024_20_ks9_p14.pth.tar) | +| convmixer_1536_20 | [convmixer_1536_20](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/convmixer/convmixer_1536_20.py) | 81.432 | 95.38 | 547 | 14.66 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convmixer/convmixer_1536_20_ks9_p7.pth.tar) | +| gmixer_24_224 | [gmixer_24_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/gmixer/gmixer_24_224.py) | 78.088 | 93.6 | 104 | 11.65 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/gmixer/gmixer_24_224_raa-7daf7ae6.pth) | +| gmlp_s16_224 | [gmlp_s16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/gmlp/gmlp_s16_224.py) | 77.204 | 93.358 | 81 | 11.15 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/gmlp/gmlp_s16_224.pth) | +| mixer_b16_224 | [mixer_b16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/mlp-mixer/mixer_b16_224.py) | 72.558 | 90.068 | 241 | 5.37 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/mlp-mixer/jx_mixer_b16_224-76587d61.pth) | +| mixer_l16_224 | [mixer_l16_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/mlp-mixer/mixer_l16_224.py) | 68.34 | 86.11 | 804 | 11.74 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/mlp-mixer/jx_mixer_l16_224-92f9adc4.pth) | +| jx_nest_tiny | [jx_nest_tiny](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/nest/jx_nest_tiny.py) | 81.278 | 95.618 | 90 | 9.05 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/nest/jx_nest_tiny-e3428fb9.pth) | +| jx_nest_small | [jx_nest_small](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/nest/jx_nest_small.py) | 83.144 | 96.3 | 174 | 16.92 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/nest/jx_nest_small-422eaded.pth) | +| jx_nest_base | [jx_nest_base](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/nest/jx_nest_base.py) | 83.474 | 96.442 | 300 | 16.88 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/nest/jx_nest_base-8bc41011.pth) | +| pit_s_distilled_224 | [pit_s_distilled_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/pit/pit_s_distilled_224.py) | 83.144 | 96.3 | 109 | 7.00 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/pit/pit_s_distill_819.pth) | +| pit_b_distilled_224 | [pit_b_distilled_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/pit/pit_b_distilled_224.py) | 83.474 | 96.442 | 330 | 7.66 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/pit/pit_b_distill_840.pth) | +| twins_svt_small | [twins_svt_small](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/twins/twins_svt_small.py) | 81.598 | 95.55 | 657 | 14.07 |
[model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/twins/twins_svt_small-42e5f78c.pth) | +| twins_svt_base | [twins_svt_base](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/twins/twins_svt_base.py) | 82.882 | 96.234 | 1447 | 18.99 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/twins/twins_svt_base-c2265010.pth) | +| twins_svt_large | [twins_svt_large](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/twins/twins_svt_large.py) | 83.428 | 96.506 | 2567 | 19.11 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/twins/twins_svt_large-90f6aaa9.pth) | +| swin_base_patch4_window7_224 | [swin_base_patch4_window7_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/swint/swin_base_patch4_window7_224.py) | 84.714 | 97.444 | 375 | 23.47 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_base_patch4_window7_224_22kto1k.pth) | +| swin_large_patch4_window7_224 | [swin_large_patch4_window7_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/swint/swin_large_patch4_window7_224.py) | 85.826 | 97.816 | 788 | 23.29 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_large_patch4_window7_224_22kto1k.pth) | +| dynamic_swin_small_p4_w7_224 | [dynamic_swin_small_p4_w7_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/swint/dynamic_small_base_p4_w7_224.py) | 82.896 | 96.234 | 220 | 28.55 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_small_patch4_window7_224_statedict.pth) | +| dynamic_swin_tiny_p4_w7_224 | [dynamic_swin_tiny_p4_w7_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/swint/dynamic_swin_tiny_p4_w7_224.py) | 80.912 | 95.41 | 136 | 14.58 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_tiny_patch4_window7_224_statedict.pth) | +| shuffletrans_tiny_p4_w7_224 | [shuffletrans_tiny_p4_w7_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/timm/shuffle_transformer/shuffletrans_tiny_p4_w7_224.py) | 82.176 | 96.05 | 5311 | 13.90 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/shuffle_transformer/shuffle_tiny.pth) | + +(ps: the inference results are obtained by importing the official models, which requires torch.__version__ >= 1.9.0; the default inference input size is 224, the default machine is a V100 16G, and gpu memory records the GPU peak memory) diff --git a/easycv/models/backbones/benchmark_mlp.py b/easycv/models/backbones/benchmark_mlp.py index 69aa0fff..5c20ee1b 100644 --- a/easycv/models/backbones/benchmark_mlp.py +++ b/easycv/models/backbones/benchmark_mlp.py @@ -18,11 +18,8 @@ class BenchMarkMLP(nn.Module): self.dropout = nn.Dropout(p=0.5) self.pool = nn.AdaptiveAvgPool2d((1, 1)) self.avg_pool = avg_pool - # self.fc2 = nn.Linear(feature_num, num_classes) - # self.relu2 = nn.ReLU() - # self._initialize_weights() - def init_weights(self, pretrained=None): + def init_weights(self): for m in self.modules(): if isinstance(m, nn.Linear): nn.init.kaiming_normal_( @@ -33,7 +30,4 @@ class BenchMarkMLP(nn.Module): x = self.pool(x) x = self.fc1(x) x = self.relu1(x) - # x = self.dropout(x) - # x = self.fc2(x) - # x =
self.relu2(x) return tuple([x]) diff --git a/easycv/models/backbones/bninception.py b/easycv/models/backbones/bninception.py index 8b782258..5c4ad53a 100644 --- a/easycv/models/backbones/bninception.py +++ b/easycv/models/backbones/bninception.py @@ -344,32 +344,16 @@ class BNInception(nn.Module): 1024, 128, kernel_size=(1, 1), stride=(1, 1)) self.inception_5b_pool_proj_bn = nn.BatchNorm2d(128, affine=True) self.inception_5b_relu_pool_proj = nn.ReLU(inplace) - # self.last_linear = nn.Linear (1024, num_classes) self.num_classes = num_classes if num_classes > 0: self.last_linear = nn.Linear(1024, num_classes) - self.pretrained = model_urls[self.__class__.__name__] - - def init_weights(self, pretrained=None): - if isinstance(pretrained, str) or isinstance(pretrained, dict): - logger = get_root_logger() - load_checkpoint(self, pretrained, strict=False, logger=logger) - elif pretrained is None: - for m in self.modules(): - if isinstance(m, nn.Conv2d): - kaiming_init(m, mode='fan_in', nonlinearity='relu') - elif isinstance(m, (_BatchNorm, nn.GroupNorm)): - constant_init(m, 1) - - # if self.zero_init_residual: - # for m in self.modules(): - # if isinstance(m, Bottleneck): - # constant_init(m.norm3, 0) - # elif isinstance(m, BasicBlock): - # constant_init(m.norm2, 0) - else: - raise TypeError('pretrained must be a str or None') + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m, mode='fan_in', nonlinearity='relu') + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) def features(self, input): conv1_7x7_s2_out = self.conv1_7x7_s2(input) diff --git a/easycv/models/backbones/genet.py b/easycv/models/backbones/genet.py index 8d432ffa..f579c8ec 100644 --- a/easycv/models/backbones/genet.py +++ b/easycv/models/backbones/genet.py @@ -62,7 +62,6 @@ def _fuse_convkx_and_bn_(convkx, bn): bn.bias[:] = bn.bias - the_bias_shift bn.running_var[:] = 1.0 - bn.eps bn.running_mean[:] = 0.0 - # convkx.register_parameter('bias', bn.bias) convkx.bias = nn.Parameter(bn.bias) @@ -76,7 +75,6 @@ def remove_bn_in_superblock(super_block): for block in the_seq_list: if isinstance(block, nn.BatchNorm2d): _fuse_convkx_and_bn_(last_block, block) - # print('--debug fuse shortcut bn') else: new_seq_list.append(block) last_block = block @@ -92,7 +90,6 @@ def remove_bn_in_superblock(super_block): for block in the_seq_list: if isinstance(block, nn.BatchNorm2d): _fuse_convkx_and_bn_(last_block, block) - # print('--debug fuse conv bn') else: new_seq_list.append(block) last_block = block @@ -1647,7 +1644,6 @@ class PlainNet(nn.Module): block_list.pop(-1) else: self.adptive_avg_pool = nn.AdaptiveAvgPool2d((1, 1)) - # AdaptiveAvgPool(out_channels=block_list[-1].out_channels, output_size=1) self.block_list = block_list if not no_create: @@ -1663,29 +1659,13 @@ class PlainNet(nn.Module): self.plainnet_struct = str(self) + str(self.adptive_avg_pool) self.zero_init_residual = False - self.pretrained = model_urls[self.__class__.__name__ + - plainnet_struct_idx] - def init_weights(self, pretrained=None): - if isinstance(pretrained, str) or isinstance(pretrained, dict): - logger = get_root_logger() - load_checkpoint(self, pretrained, strict=False, logger=logger) - elif pretrained is None: - for m in self.modules(): - if isinstance(m, nn.Conv2d): - kaiming_init(m, mode='fan_in', nonlinearity='relu') - elif isinstance(m, (_BatchNorm, nn.GroupNorm)): - constant_init(m, 1) - - # if self.zero_init_residual: - # for m in self.modules(): - # if isinstance(m, Bottleneck): - # 
constant_init(m.norm3, 0) - # elif isinstance(m, BasicBlock): - # constant_init(m.norm2, 0) - else: - raise TypeError('pretrained must be a str or None') - return + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m, mode='fan_in', nonlinearity='relu') + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) def forward(self, x): output = x @@ -1699,13 +1679,3 @@ class PlainNet(nn.Module): output = self.fc_linear(output) return [output] - - -if __name__ == '__main__': - from torch import nn - - input_image_size = 192 - - # model = genet_normal(pretrained=True, root='data/models/GENet_params/') - # print(type(model)) - # print(model.block_list) diff --git a/easycv/models/backbones/hrnet.py b/easycv/models/backbones/hrnet.py index f18b57b0..f8daeec7 100644 --- a/easycv/models/backbones/hrnet.py +++ b/easycv/models/backbones/hrnet.py @@ -184,7 +184,6 @@ class Bottleneck(nn.Module): if self.with_cp and x.requires_grad: raise NotImplementedError - # out = cp.checkpoint(_inner_forward, x) else: out = _inner_forward(x) @@ -718,31 +717,25 @@ class HRNet(nn.Module): return nn.Sequential(*hr_modules), in_channels - def init_weights(self, pretrained=None): + def init_weights(self): - """Initialize the weights in backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. """ + """Initialize the weights in backbone.""" - if isinstance(pretrained, str): - logger = get_root_logger() - load_checkpoint(self, pretrained, strict=False, logger=logger) - elif pretrained is None: - for m in self.modules(): - if isinstance(m, nn.Conv2d): - normal_init(m, std=0.001) - elif isinstance(m, (_BatchNorm, nn.GroupNorm)): - constant_init(m, 1) + for m in self.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) - if self.zero_init_residual: - for m in self.modules(): - if isinstance(m, Bottleneck): - constant_init(m.norm3, 0) - elif isinstance(m, BasicBlock): - constant_init(m.norm2, 0) - else: - raise TypeError('pretrained must be a str or None') + if self.zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + constant_init(m.norm3, 0) + elif isinstance(m, BasicBlock): + constant_init(m.norm2, 0) def forward(self, x): """Forward function.""" diff --git a/easycv/models/backbones/inceptionv3.py b/easycv/models/backbones/inceptionv3.py index e4c165d2..d6f1b5f1 100644 --- a/easycv/models/backbones/inceptionv3.py +++ b/easycv/models/backbones/inceptionv3.py @@ -60,23 +60,17 @@ class Inception3(nn.Module): if num_classes > 0: self.fc = nn.Linear(2048, num_classes) - self.pretrained = model_urls[self.__class__.__name__] - - def init_weights(self, pretrained=None): - if pretrained is not None: - logger = get_root_logger() - load_checkpoint(self, pretrained, strict=False, logger=logger) - else: - for m in self.modules(): - if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - kaiming_init(m, mode='fan_in', nonlinearity='relu') - elif isinstance(m, (_BatchNorm, nn.GroupNorm)): - constant_init(m, 1) - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m, mode='fan_in', nonlinearity='relu') + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) def forward(self, x): if self.transform_input: diff --git a/easycv/models/backbones/lighthrnet.py b/easycv/models/backbones/lighthrnet.py index 7777c4c3..24db40e0 100644 --- a/easycv/models/backbones/lighthrnet.py +++ b/easycv/models/backbones/lighthrnet.py @@ -958,24 +958,18 @@ class LiteHRNet(nn.Module): return nn.Sequential(*modules), in_channels - def init_weights(self, pretrained=None): + def init_weights(self): - """Initialize the weights in backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. """ + """Initialize the weights in backbone.""" - if isinstance(pretrained, str): - logger = get_root_logger() - load_checkpoint(self, pretrained, strict=False, logger=logger) - elif pretrained is None: - for m in self.modules(): - if isinstance(m, nn.Conv2d): - normal_init(m, std=0.001) - elif isinstance(m, (_BatchNorm, nn.GroupNorm)): - constant_init(m, 1) - else: - raise TypeError('pretrained must be a str or None') + for m in self.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) def forward(self, x): """Forward function.""" diff --git a/easycv/models/backbones/mae_vit_transformer.py b/easycv/models/backbones/mae_vit_transformer.py index 56b1b13e..49975ebf 100644 --- a/easycv/models/backbones/mae_vit_transformer.py +++ b/easycv/models/backbones/mae_vit_transformer.py @@ -73,17 +73,16 @@ class MaskedAutoencoderViT(nn.Module): ]) self.norm = norm_layer(embed_dim) - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - # we use xavier_uniform following official JAX ViT: - torch.nn.init.xavier_uniform_(m.weight) - if isinstance(m, nn.Linear) and m.bias is not None: + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) + nn.init.constant_(m.weight, 1.0) def random_masking(self, x, mask_ratio): """ diff --git a/easycv/models/backbones/mnasnet.py b/easycv/models/backbones/mnasnet.py index f85f2066..e1bdfcce 100644 --- a/easycv/models/backbones/mnasnet.py +++ b/easycv/models/backbones/mnasnet.py @@ -143,8 +143,6 @@ class MNASNet(torch.nn.Module): self.classifier = nn.Sequential( nn.Dropout(p=dropout, inplace=True), nn.Linear(1280, num_classes)) - self.init_weights() - self.pretrained = model_urls[self.__class__.__name__ + str(alpha)] def forward(self, x): x = self.layers(x) @@ -155,22 +153,16 @@ class MNASNet(torch.nn.Module): else: return [x] - def init_weights(self, pretrained=None): - if isinstance(pretrained, str) or isinstance(pretrained, dict): - logger = get_root_logger() - load_checkpoint(self, pretrained, strict=False, logger=logger) - elif pretrained is None: - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_( - m.weight, mode='fan_out', nonlinearity='relu') - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.BatchNorm2d): - nn.init.ones_(m.weight) + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None:
nn.init.zeros_(m.bias) - elif isinstance(m, nn.Linear): - nn.init.normal_(m.weight, 0.01) - nn.init.zeros_(m.bias) - else: - raise TypeError('pretrained must be a str or None') + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0.01) + nn.init.zeros_(m.bias) diff --git a/easycv/models/backbones/mobilenetv2.py b/easycv/models/backbones/mobilenetv2.py index f7b14906..0f1d8eb2 100644 --- a/easycv/models/backbones/mobilenetv2.py +++ b/easycv/models/backbones/mobilenetv2.py @@ -156,24 +156,18 @@ class MobileNetV2(nn.Module): self.pretrained = model_urls[self.__class__.__name__ + '_' + str(width_multi)] - def init_weights(self, pretrained=None): - if isinstance(pretrained, str) or isinstance(pretrained, dict): - logger = get_root_logger() - load_checkpoint(self, pretrained, strict=False, logger=logger) - elif pretrained is None: - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out') - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.BatchNorm2d): - nn.init.ones_(m.weight) + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: nn.init.zeros_(m.bias) - elif isinstance(m, nn.Linear): - nn.init.normal_(m.weight, 0, 0.01) - nn.init.zeros_(m.bias) - else: - raise TypeError('pretrained must be a str or None') + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) def forward(self, x): x = self.features(x) diff --git a/easycv/models/backbones/pytorch_image_models_wrapper.py b/easycv/models/backbones/pytorch_image_models_wrapper.py index 61908c0b..8ab85450 100644 --- a/easycv/models/backbones/pytorch_image_models_wrapper.py +++ b/easycv/models/backbones/pytorch_image_models_wrapper.py @@ -1,12 +1,15 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
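(The hunk below reworks `PytorchImageModelWrapper`: `__init__` now only instantiates the bare timm/FAIR model, and all pretrained-weight handling moves into `init_weights(pretrained=...)`. A minimal usage sketch of the new flow; the model name is an illustrative assumption, not part of the patch:)

```python
# Sketch only (not part of this patch): exercises the refactored wrapper below.
# 'vit_base_patch16_224' is an illustrative choice; the constructor accepts any
# name known to timm or to EasyCV's _MODEL_MAP.
from easycv.models.backbones.pytorch_image_models_wrapper import \
    PytorchImageModelWrapper

backbone = PytorchImageModelWrapper(model_name='vit_base_patch16_224')

# pretrained=True resolves the default checkpoint URL from the modelzoo
# timm_models dict; .npz files go through timm's load_pretrained helpers.
backbone.init_weights(pretrained=True)

# pretrained=False/None falls back to the wrapped model's own init_weights().
```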
+import importlib from distutils.version import LooseVersion import timm import torch import torch.nn as nn +from timm.models.helpers import load_pretrained +from timm.models.hub import download_cached_file from easycv.utils.checkpoint import load_checkpoint -from easycv.utils.logger import get_root_logger +from easycv.utils.logger import get_root_logger, print_log from ..modelzoo import timm_models as model_urls from ..registry import BACKBONES from .shuffle_transformer import (shuffletrans_base_p4_w7_224, @@ -66,8 +69,6 @@ class PytorchImageModelWrapper(nn.Module): def __init__(self, model_name='resnet50', - pretrained=False, - checkpoint_path=None, scriptable=None, exportable=None, no_jit=None, @@ -76,15 +77,16 @@ class PytorchImageModelWrapper(nn.Module): Inits PytorchImageModelWrapper by timm.create_models Args: model_name (str): name of model to instantiate - pretrained (bool): load pretrained ImageNet-1k weights if true - checkpoint_path (str): path of checkpoint to load after model is initialized scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet) exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet) no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only) """ super(PytorchImageModelWrapper, self).__init__() + self.model_name = model_name + timm_model_names = timm.list_models(pretrained=False) + self.timm_model_names = timm_model_names assert model_name in timm_model_names or model_name in _MODEL_MAP, \ f'{model_name} is not in model_list of timm/fair, please check the model_name!' @@ -94,52 +96,52 @@ class PytorchImageModelWrapper(nn.Module): # create model by timm if model_name in timm_model_names: - try: - if pretrained and (model_name in model_urls): - self.model = timm.create_model(model_name, False, '', - scriptable, exportable, - no_jit, **kwargs) - self.init_weights(model_urls[model_name]) - print('Info: Load model from %s' % model_urls[model_name]) - - if checkpoint_path is not None: - self.init_weights(checkpoint_path) - else: - # load from timm - if pretrained and model_name.startswith('swin_') and ( - LooseVersion( - torch.__version__) <= LooseVersion('1.6.0')): - print( - 'Warning: Pretrained SwinTransformer from timm may be zipfile extract' - ' error while torch<=1.6.0') - self.model = timm.create_model(model_name, pretrained, - checkpoint_path, scriptable, - exportable, no_jit, - **kwargs) - - # need fix: delete this except after pytorch 1.7 update in all production - # (dlc, dsw, studio, ev_predict_py3) - except Exception: - print( - f'Error: Fail to create {model_name} with (pretrained={pretrained}, checkpoint_path={checkpoint_path} ...)' - ) - print( - f'Try to create {model_name} with pretrained=False, checkpoint_path=None and default params' - ) - self.model = timm.create_model(model_name, False, '', None, - None, None, **kwargs) - - # facebook model wrapper - if model_name in _MODEL_MAP: + self.model = timm.create_model(model_name, False, '', scriptable, + exportable, no_jit, **kwargs) + elif model_name in _MODEL_MAP: self.model = _MODEL_MAP[model_name](**kwargs) - if pretrained: - if model_name in model_urls.keys(): + + def init_weights(self, pretrained=None): + """ + Args: + if pretrained == True, load model from default path; + if pretrained == False or None, load from init weights. 
+ + + if model_name in timm_model_names, load model from timm default path; + if model_name in _MODEL_MAP, load model from easycv default path + """ + logger = get_root_logger() + if pretrained: + if self.model_name in self.timm_model_names: + pretrained_path = model_urls[self.model_name] + print_log( + 'load model from default path: {}'.format(pretrained_path), + logger) + if pretrained_path.endswith('.npz'): + pretrained_loc = download_cached_file( + pretrained_path, check_hash=False, progress=False) + return self.model.load_pretrained(pretrained_loc) + else: + backbone_module = importlib.import_module( + self.model.__module__) + return load_pretrained( + self.model, + default_cfg={'url': pretrained_path}, + filter_fn=backbone_module.checkpoint_filter_fn + if hasattr(backbone_module, 'checkpoint_filter_fn') + else None) + elif self.model_name in _MODEL_MAP: + if self.model_name in model_urls.keys(): + pretrained_path = model_urls[self.model_name] + print_log( + 'load model from default path: {}'.format( + pretrained_path), logger) try_max = 3 try_idx = 0 while try_idx < try_max: try: state_dict = torch.hub.load_state_dict_from_url( - url=model_urls[model_name], + url=pretrained_path, map_location='cpu', ) try_idx += try_max @@ -147,31 +149,22 @@ class PytorchImageModelWrapper(nn.Module): try_idx += 1 state_dict = {} if try_idx == try_max: - print( - 'load from url failed ! oh my DLC & OSS, you boys really good! ', - model_urls[model_name]) + print_log( + f'Failed to load checkpoint from url: {model_urls[self.model_name]}', + logger) - # for some model strict = False still failed when model doesn't exactly match - try: - self.model.load_state_dict(state_dict, strict=False) - except Exception: - print('load for model_name not all right') + if 'model' in state_dict: + state_dict = state_dict['model'] + self.model.load_state_dict(state_dict, strict=False) else: - print('%s not in evtorch modelzoo!' % model_name) - - def init_weights(self, pretrained=None): - # pretrained is the path of pretrained model offered by easycv - if pretrained is not None: - logger = get_root_logger() - load_checkpoint( - self.model, - pretrained, - map_location=torch.device('cpu'), - strict=False, - logger=logger) + raise ValueError('{} not in evtorch modelzoo!'.format( + self.model_name)) + else: + raise ValueError( + 'Error: Failed to create {} with (pretrained={}...)'.format( + self.model_name, pretrained)) else: - # init by timm - pass + self.model.init_weights() def forward(self, x): diff --git a/easycv/models/backbones/resnest.py b/easycv/models/backbones/resnest.py index 5c5714f8..eb16d4e1 100644 --- a/easycv/models/backbones/resnest.py +++ b/easycv/models/backbones/resnest.py @@ -470,26 +470,19 @@ class ResNeSt(nn.Module): self.avgpool = GlobalAvgPool2d() self.drop = nn.Dropout(final_drop) if final_drop > 0.0 else None self.norm_layer = norm_layer - # self.fc = nn.Linear(512 * block.expansion, num_classes) if num_classes > 0: self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512 * block.expansion, num_classes) - def init_weights(self, pretrained=None): - if isinstance(pretrained, str) or isinstance(pretrained, dict): - logger = get_root_logger() - load_checkpoint(self, pretrained, strict=False, logger=logger) - elif pretrained is None: - for m in self.modules(): - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2.
/ n)) - elif isinstance(m, self.norm_layer): - m.weight.data.fill_(1) - m.bias.data.zero_() - else: - raise TypeError('pretrained must be a str or None') + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, self.norm_layer): + m.weight.data.fill_(1) + m.bias.data.zero_() def _make_layer(self, block, @@ -613,7 +606,6 @@ class ResNeSt(nn.Module): if hasattr(self, 'fc'): x = self.avgpool(x) - # x = x.view(x.size(0), -1) x = torch.flatten(x, 1) if self.drop: x = self.drop(x) diff --git a/easycv/models/backbones/resnet.py b/easycv/models/backbones/resnet.py index 3ab712ae..1a0b316c 100644 --- a/easycv/models/backbones/resnet.py +++ b/easycv/models/backbones/resnet.py @@ -403,8 +403,6 @@ class ResNet(nn.Module): self.frelu = frelu self.multi_grid = multi_grid self.contract_dilation = contract_dilation - self.pretrained = model_urls.get(self.__class__.__name__ + str(depth), - None) self._make_stem_layer(in_channels, stem_channels) @@ -518,25 +516,19 @@ class ResNet(nn.Module): for param in m.parameters(): param.requires_grad = False - def init_weights(self, pretrained=None): - if isinstance(pretrained, str) or isinstance(pretrained, dict): - logger = get_root_logger() - load_checkpoint(self, pretrained, strict=False, logger=logger) - elif pretrained is None: - for m in self.modules(): - if isinstance(m, nn.Conv2d): - kaiming_init(m, mode='fan_in', nonlinearity='relu') - elif isinstance(m, (_BatchNorm, nn.GroupNorm)): - constant_init(m, 1) + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m, mode='fan_in', nonlinearity='relu') + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) - if self.zero_init_residual: - for m in self.modules(): - if isinstance(m, Bottleneck): - constant_init(m.norm3, 0) - elif isinstance(m, BasicBlock): - constant_init(m.norm2, 0) - else: - raise TypeError('pretrained must be a str or None') + if self.zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + constant_init(m.norm3, 0) + elif isinstance(m, BasicBlock): + constant_init(m.norm2, 0) def forward(self, x): outs = [] diff --git a/easycv/models/backbones/resnet_jit.py b/easycv/models/backbones/resnet_jit.py index e27a2e97..c6f3ceca 100644 --- a/easycv/models/backbones/resnet_jit.py +++ b/easycv/models/backbones/resnet_jit.py @@ -4,7 +4,6 @@ from typing import List import torch import torch.nn as nn from mmcv.cnn import constant_init, kaiming_init -# from mmcv.runner import load_checkpoint from torch.nn.modules.batchnorm import _BatchNorm from easycv.utils.checkpoint import load_checkpoint @@ -183,13 +182,6 @@ class Bottleneck(nn.Module): def forward(self, x): - # if self.with_cp and x.requires_grad: - # out = cp.checkpoint(self._inner_forward, x) - # else: - # out = self._inner_forward(x) - - # out = self.relu(out) - out = self._inner_forward(x) out = self.relu(out) return out @@ -350,8 +342,6 @@ class ResNetJIT(nn.Module): norm_cfg=norm_cfg) self.inplanes = planes * self.block.expansion layer_name = 'layer{}'.format(i + 1) - # self.add_module(layer_name, res_layer) - # self.res_layers.append(layer_name) self.res_layers.add_module(layer_name, res_layer) self._freeze_stages() @@ -361,7 +351,6 @@ class ResNetJIT(nn.Module): @property def norm1(self): - # return getattr(self, self.norm1_name) return getattr(self, 'bn1') def _make_stem_layer(self, in_channels): @@ -375,7 
+364,6 @@ class ResNetJIT(nn.Module): bias=False) self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1) self.bn1 = norm1 - # self.add_module(self.norm1_name, norm1) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) @@ -392,30 +380,23 @@ class ResNetJIT(nn.Module): for param in m.parameters(): param.requires_grad = False - def init_weights(self, pretrained=None): - if isinstance(pretrained, str) or isinstance(pretrained, dict): - logger = get_root_logger() - load_checkpoint(self, pretrained, strict=False, logger=logger) - elif pretrained is None: - for m in self.modules(): - if isinstance(m, nn.Conv2d): - kaiming_init(m, mode='fan_in', nonlinearity='relu') - elif isinstance(m, (_BatchNorm, nn.GroupNorm)): - constant_init(m, 1) + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m, mode='fan_in', nonlinearity='relu') + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) - if self.zero_init_residual: - for m in self.modules(): - if isinstance(m, Bottleneck): - constant_init(m.norm3, 0) - elif isinstance(m, BasicBlock): - constant_init(m.norm2, 0) - else: - raise TypeError('pretrained must be a str or None') + if self.zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + constant_init(m.norm3, 0) + elif isinstance(m, BasicBlock): + constant_init(m.norm2, 0) def forward(self, x: torch.Tensor) -> List[torch.Tensor]: outs = [] x = self.conv1(x) - # x = self.norm1(x) x = self.bn1(x) x = self.relu(x) # r50: 64x128x128 if 0 in self.out_indices: diff --git a/easycv/models/backbones/shuffle_transformer.py b/easycv/models/backbones/shuffle_transformer.py index a7e30377..98067988 100644 --- a/easycv/models/backbones/shuffle_transformer.py +++ b/easycv/models/backbones/shuffle_transformer.py @@ -428,30 +428,17 @@ class ShuffleTransformer(nn.Module): # Classifier head self.head = nn.Linear( dims[3], num_classes) if num_classes > 0 else nn.Identity() - self.apply(self._init_weights) - def _init_weights(self, m): - if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)): - nn.init.constant_(m.weight, 1.0) - nn.init.constant_(m.bias, 0) - elif isinstance(m, (nn.Linear, nn.Conv2d)): - trunc_normal_(m.weight, std=.02) - if isinstance(m, (nn.Linear, nn.Conv2d)) and m.bias is not None: + def init_weights(self): + for m in self.modules(): + if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)): + nn.init.constant_(m.weight, 1.0) nn.init.constant_(m.bias, 0) - - def init_weights(self, pretrained=None): - if isinstance(pretrained, str) or isinstance(pretrained, dict): - logger = get_root_logger() - load_checkpoint( - self, - pretrained, - map_location='cpu', - strict=False, - logger=logger) - elif pretrained is None: - self.apply(self._init_weights) - else: - raise TypeError('pretrained must be a str or None') + elif isinstance(m, (nn.Linear, nn.Conv2d)): + trunc_normal_(m.weight, std=.02) + if isinstance(m, + (nn.Linear, nn.Conv2d)) and m.bias is not None: + nn.init.constant_(m.bias, 0) @torch.jit.ignore def no_weight_decay(self): diff --git a/easycv/models/backbones/swin_transformer_dynamic.py b/easycv/models/backbones/swin_transformer_dynamic.py index 20f6967b..d7857dbf 100644 --- a/easycv/models/backbones/swin_transformer_dynamic.py +++ b/easycv/models/backbones/swin_transformer_dynamic.py @@ -770,30 +770,15 @@ class SwinTransformer(nn.Module): if self.use_dense_prediction: self.head_dense = None - self.apply(self._init_weights) - - def 
_init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def init_weights_unused(self, pretrained=None): - if isinstance(pretrained, str) or isinstance(pretrained, dict): - logger = get_root_logger() - load_checkpoint( - self, - pretrained, - map_location='cpu', - strict=False, - logger=logger) - elif pretrained is None: - self.apply(self._init_weights) - else: - raise TypeError('pretrained must be a str or None') + nn.init.constant_(m.weight, 1.0) @torch.jit.ignore def no_weight_decay(self): @@ -982,80 +967,6 @@ class SwinTransformer(nn.Module): flops += self.num_features * self.num_classes return flops - def init_weights(self, pretrained='', pretrained_layers=[], verbose=True): - if os.path.isfile(pretrained): - pretrained_dict = torch.load(pretrained, map_location='cpu') - logging.info(f'=> loading pretrained model {pretrained}') - model_dict = self.state_dict() - pretrained_dict = { - k: v - for k, v in pretrained_dict.items() if k in model_dict.keys() - } - need_init_state_dict = {} - for k, v in pretrained_dict.items(): - need_init = ( - k.split('.')[0] in pretrained_layers - or pretrained_layers[0] == '*' - or 'relative_position_index' not in k - or 'attn_mask' not in k) - - if need_init: - if verbose: - logging.info(f'=> init {k} from {pretrained}') - - if 'relative_position_bias_table' in k and v.size( - ) != model_dict[k].size(): - relative_position_bias_table_pretrained = v - relative_position_bias_table_current = model_dict[k] - L1, nH1 = relative_position_bias_table_pretrained.size( - ) - L2, nH2 = relative_position_bias_table_current.size() - if nH1 != nH2: - logging.info(f'Error in loading {k}, passing') - else: - if L1 != L2: - logging.info( - '=> load_pretrained: resized variant: {} to {}' - .format((L1, nH1), (L2, nH2))) - S1 = int(L1**0.5) - S2 = int(L2**0.5) - relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( - relative_position_bias_table_pretrained. 
- permute(1, 0).view(1, nH1, S1, S1), - size=(S2, S2), - mode='bicubic') - v = relative_position_bias_table_pretrained_resized.view( - nH2, L2).permute(1, 0) - - if 'absolute_pos_embed' in k and v.size( - ) != model_dict[k].size(): - absolute_pos_embed_pretrained = v - absolute_pos_embed_current = model_dict[k] - _, L1, C1 = absolute_pos_embed_pretrained.size() - _, L2, C2 = absolute_pos_embed_current.size() - if C1 != C1: - logging.info(f'Error in loading {k}, passing') - else: - if L1 != L2: - logging.info( - '=> load_pretrained: resized variant: {} to {}' - .format((1, L1, C1), (1, L2, C2))) - S1 = int(L1**0.5) - S2 = int(L2**0.5) - absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape( - -1, S1, S1, C1) - absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute( - 0, 3, 1, 2) - absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate( - absolute_pos_embed_pretrained, - size=(S2, S2), - mode='bicubic') - v = absolute_pos_embed_pretrained_resized.permute( - 0, 2, 3, 1).flatten(1, 2) - - need_init_state_dict[k] = v - self.load_state_dict(need_init_state_dict, strict=False) - def freeze_pretrained_layers(self, frozen_layers=[]): for name, module in self.named_modules(): if (name.split('.')[0] in frozen_layers @@ -1075,43 +986,6 @@ class SwinTransformer(nn.Module): return self -# @register_model -# def get_cls_model(config, is_teacher=False, use_dense_prediction=False, **kwargs): -# swin_spec = config.MODEL.SPEC -# swin = SwinTransformer( -# img_size=config.TRAIN.IMAGE_SIZE[0], -# in_chans=3, -# num_classes=config.MODEL.NUM_CLASSES, -# patch_size=swin_spec['PATCH_SIZE'], -# embed_dim=swin_spec['DIM_EMBED'], -# depths=swin_spec['DEPTHS'], -# num_heads=swin_spec['NUM_HEADS'], -# window_size=swin_spec['WINDOW_SIZE'], -# mlp_ratio=swin_spec['MLP_RATIO'], -# qkv_bias=swin_spec['QKV_BIAS'], -# drop_rate=swin_spec['DROP_RATE'], -# attn_drop_rate=swin_spec['ATTN_DROP_RATE'], -# drop_path_rate= 0.0 if is_teacher else swin_spec['DROP_PATH_RATE'], -# norm_layer=partial(nn.LayerNorm, eps=1e-6), -# ape=swin_spec['USE_APE'], -# patch_norm=swin_spec['PATCH_NORM'], -# use_dense_prediction=use_dense_prediction, -# ) - -# if config.MODEL.INIT_WEIGHTS: -# swin.init_weights( -# config.MODEL.PRETRAINED, -# config.MODEL.PRETRAINED_LAYERS, -# config.VERBOSE -# ) - -# # freeze the specified pre-trained layers (if any) -# if config.FINETUNE.FINETUNE: -# swin.freeze_pretrained_layers(config.FINETUNE.FROZEN_LAYERS) - -# return swin - - def dynamic_swin_tiny_p4_w7_224(pretrained=False, **kwargs): model = SwinTransformer( img_size=224, diff --git a/easycv/models/backbones/vit_transfomer_dynamic.py b/easycv/models/backbones/vit_transfomer_dynamic.py index 0dda79ac..0e2badca 100644 --- a/easycv/models/backbones/vit_transfomer_dynamic.py +++ b/easycv/models/backbones/vit_transfomer_dynamic.py @@ -241,30 +241,16 @@ class VisionTransformer(nn.Module): trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.cls_token, std=.02) - self.apply(self._init_weights) - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def 
init_weights(self, pretrained=None): - if isinstance(pretrained, str) or isinstance(pretrained, dict): - logger = get_root_logger() - load_checkpoint( - self, - pretrained, - map_location='cpu', - strict=False, - logger=logger) - elif pretrained is None: - self.apply(self._init_weights) - else: - raise TypeError('pretrained must be a str or None') + nn.init.constant_(m.weight, 1.0) def forward(self, x): # convert to list diff --git a/easycv/models/backbones/xcit_transformer.py b/easycv/models/backbones/xcit_transformer.py index 0588b150..4c06c5b7 100644 --- a/easycv/models/backbones/xcit_transformer.py +++ b/easycv/models/backbones/xcit_transformer.py @@ -476,30 +476,16 @@ class XCiT(nn.Module): # Classifier head trunc_normal_(self.cls_token, std=.02) - self.apply(self._init_weights) - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def init_weights(self, pretrained=None): - if isinstance(pretrained, str) or isinstance(pretrained, dict): - logger = get_root_logger() - load_checkpoint( - self, - pretrained, - map_location='cpu', - strict=False, - logger=logger) - elif pretrained is None: - self.apply(self._init_weights) - else: - raise TypeError('pretrained must be a str or None') + nn.init.constant_(m.weight, 1.0) @torch.jit.ignore def no_weight_decay(self): diff --git a/easycv/models/classification/classification.py b/easycv/models/classification/classification.py index 4d7aaabb..a339e560 100644 --- a/easycv/models/classification/classification.py +++ b/easycv/models/classification/classification.py @@ -7,7 +7,8 @@ import torch.nn as nn from mmcv.runner import get_dist_info from timm.data.mixup import Mixup -from easycv.utils.logger import print_log +from easycv.utils.checkpoint import load_checkpoint +from easycv.utils.logger import get_root_logger, print_log from easycv.utils.preprocess_function import (bninceptionPre, gaussianBlur, mixUpCls, randomErasing) from .. import builder @@ -16,17 +17,15 @@ from ..registry import MODELS from ..utils import Sobel -def distill_loss(cls_score, teacher_score, tempreature=1.0): - """ Soft cross entropy loss - """ - log_prob = torch.nn.functional.log_softmax(cls_score / tempreature, dim=-1) - targets_prob = torch.nn.functional.softmax( - teacher_score / tempreature, dim=-1) - return (torch.sum(-targets_prob * log_prob, dim=1)).mean() - - @MODELS.register_module class Classification(BaseModel): + """ + Args: + pretrained: Select one {str or True or False/None}. + if pretrained == str, load model from specified path; + if pretrained == True, load model from default path(currently only supports timm); + if pretrained == False or None, load from init weights. 
+ """ def __init__(self, backbone, @@ -34,11 +33,11 @@ class Classification(BaseModel): with_sobel=False, head=None, neck=None, - teacher=None, pretrained=None, mixup_cfg=None): super(Classification, self).__init__() self.with_sobel = with_sobel + self.pretrained = pretrained if with_sobel: self.sobel_layer = Sobel() else: @@ -74,6 +73,7 @@ class Classification(BaseModel): self.train_preprocess = [ self.preprocess_key_map[i] for i in train_preprocess ] + self.backbone = builder.build_backbone(backbone) assert head is not None, 'Classification head should be configed' @@ -103,29 +103,23 @@ class Classification(BaseModel): for idx, n in enumerate(tmp_neck_list): setattr(self, 'neck_%d' % idx, n) - if teacher is not None: - self.temperature = teacher.pop('temperature', 1) - self.teacher_loss_weight = teacher.pop('loss_weight', 1.0) - teacher_pretrained = teacher.pop('pretrained', None) - - self.teacher = builder.build_backbone(teacher) - if teacher_pretrained is None: - self.teacher.init_weights(pretrained=self.teacher.pretrained) - else: - self.teacher.init_weights(pretrained=teacher_pretrained) - self.teacher.eval() - else: - self.teacher = None - - self.init_weights(pretrained=pretrained) self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) self.activate_fn = nn.Softmax(dim=1) self.extract_list = ['neck'] - def init_weights(self, pretrained=None): - if pretrained is not None: - print_log('load model from: {}'.format(pretrained), logger='root') - self.backbone.init_weights(pretrained=pretrained) + self.init_weights() + + def init_weights(self): + if isinstance(self.pretrained, str): + logger = get_root_logger() + load_checkpoint( + self.backbone, self.pretrained, strict=False, logger=logger) + else: + print_log('load model from init weights') + if self.backbone.__class__.__name__ == 'PytorchImageModelWrapper': + self.backbone.init_weights(pretrained=self.pretrained) + else: + self.backbone.init_weights() for idx in range(self.head_num): h = getattr(self, 'head_%d' % idx) @@ -190,13 +184,6 @@ class Classification(BaseModel): else: losses['loss'] = hlosses['loss'] - # need to check this head can be teacher - if self.teacher is not None: - with torch.no_grad(): - teacher_outs = self.teacher(img)[0].detach() - losses['loss'] += self.teacher_loss_weight * distill_loss( - outs[0], teacher_outs, self.temperature) - return losses # @torch.jit.unused diff --git a/easycv/models/modelzoo.py b/easycv/models/modelzoo.py index 51c0e0a2..9fdd8590 100644 --- a/easycv/models/modelzoo.py +++ b/easycv/models/modelzoo.py @@ -10,15 +10,6 @@ resnet = { 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/resnet/torchvision/resnet101.pth', 'ResNet152': 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/resnet/torchvision/resnet152.pth', - # 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', - # 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', - # 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', - # 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', - # 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', - # 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', - # 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', - # 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', - # 'wide_resnet101_2': 
'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', } mobilenetv2 = { @@ -67,26 +58,134 @@ resnest = { } timm_models = { - 'resnet50': - 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/timm/resnet50_ram-a26f946b.pth', - 'resnet18': - 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/resnet/torchvision/resnet18.pth', - 'resnet34': - 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/resnet/torchvision/resnet34.pth', - 'resnet101': - 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/resnet/torchvision/resnet101.pth', - 'resnet152': - 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/resnet/torchvision/resnet152.pth', - 'swin_base_patch4_window7_224_in22k': - 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/timm/swin_base_patch4_window7_224_22k_statedict.pth', - 'swin_small_patch4_window7_224': - 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/timm/swin_small_patch4_window7_224_statedict.pth', - 'swin_tiny_patch4_window7_224': - 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/timm/swin_tiny_patch4_window7_224_statedict.pth', - 'vit_deit_tiny_distilled_patch16_224': - 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/timm/vit_deit_tiny_distilled_patch16_224_pytorch151.pth', - 'vit_deit_small_distilled_patch16_224': - 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/timm/vit_deit_small_distilled_patch16_224_pytorch151.pth', + 'vit_base_patch16_224': + 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/vit/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz', + 'vit_large_patch16_224': + 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/vit/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', + 'deit_base_patch16_224': + 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/deit/deit_base_patch16_224-b5f2ef4d.pth', + 'deit_base_distilled_patch16_224': + 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/deit/deit_base_distilled_patch16_224-df68dfff.pth', + 'swin_base_patch4_window7_224': + 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_base_patch4_window7_224_22kto1k.pth', + 'swin_large_patch4_window7_224': + 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_large_patch4_window7_224_22kto1k.pth', + 'swin_v2_cr_small_224': + 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_v2_cr_small_224-0813c165.pth', + 'swin_v2_cr_small_ns_224': + 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth', + 'swin_v2_cr_tiny_ns_224': + 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/swint/swin_v2_cr_tiny_ns_224-ba8166c6.pth', + 'xcit_medium_24_p8_224': + 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/xcit/xcit_medium_24_p8_224.pth', + 'xcit_medium_24_p8_224_dist': + 
'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/xcit/xcit_medium_24_p8_224_dist.pth',
+    'xcit_large_24_p8_224':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/xcit/xcit_large_24_p8_224.pth',
+    'xcit_large_24_p8_224_dist':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/xcit/xcit_large_24_p8_224_dist.pth',
+    'twins_svt_small':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/twins/twins_svt_small-42e5f78c.pth',
+    'twins_svt_base':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/twins/twins_svt_base-c2265010.pth',
+    'twins_svt_large':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/twins/twins_svt_large-90f6aaa9.pth',
+    'tnt_s_patch16_224':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/tnt/tnt_s_patch16_224.pth.tar',
+    'pit_s_distilled_224':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/pit/pit_s_distill_819.pth',
+    'pit_b_distilled_224':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/pit/pit_b_distill_840.pth',
+    'jx_nest_tiny':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/nest/jx_nest_tiny-e3428fb9.pth',
+    'jx_nest_small':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/nest/jx_nest_small-422eaded.pth',
+    'jx_nest_base':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/nest/jx_nest_base-8bc41011.pth',
+    'crossvit_tiny_240':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/crossvit/crossvit_tiny_224.pth',
+    'crossvit_small_240':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/crossvit/crossvit_small_224.pth',
+    'crossvit_base_240':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/crossvit/crossvit_base_224.pth',
+    'convit_tiny':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convit/convit_tiny.pth',
+    'convit_small':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convit/convit_small.pth',
+    'convit_base':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convit/convit_base.pth',
+    'coat_tiny':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/coat/coat_tiny-473c2a20.pth',
+    'coat_mini':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/coat/coat_mini-2c6baf49.pth',
+    'cait_xxs24_224':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/cait/XXS24_224.pth',
+    'cait_xxs36_224':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/cait/XXS36_224.pth',
+    'cait_s24_224':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/cait/S24_224.pth',
+    'levit_128':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/levit/LeViT-128-b88c2750.pth',
+    'levit_192':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/levit/LeViT-192-92712e41.pth',
+    'levit_256':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/levit/LeViT-256-13b5763e.pth',
+    'convmixer_1536_20':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convmixer/convmixer_1536_20_ks9_p7.pth.tar',
+    'convmixer_768_32':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convmixer/convmixer_768_32_ks7_p7_relu.pth.tar',
+    'convmixer_1024_20_ks9_p14':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convmixer/convmixer_1024_20_ks9_p14.pth.tar',
+    'convnext_tiny':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convnext/convnext_tiny_1k_224_ema.pth',
+    'convnext_small':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convnext/convnext_small_1k_224_ema.pth',
+    'convnext_base':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convnext/convnext_base_1k_224_ema.pth',
+    'convnext_large':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/convnext/convnext_large_1k_224_ema.pth',
+    'mixer_b16_224':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/mlp-mixer/jx_mixer_b16_224-76587d61.pth',
+    'mixer_l16_224':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/mlp-mixer/jx_mixer_l16_224-92f9adc4.pth',
+    'gmixer_24_224':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/gmixer/gmixer_24_224_raa-7daf7ae6.pth',
+    'resmlp_12_distilled_224':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/resmlp/resmlp_12_dist.pth',
+    'resmlp_24_distilled_224':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/resmlp/resmlp_24_dist.pth',
+    'resmlp_36_distilled_224':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/resmlp/resmlp_36_dist.pth',
+    'resmlp_big_24_distilled_224':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/resmlp/resmlpB_24_dist.pth',
+    'gmlp_s16_224':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/gmlp/gmlp_s16_224_raa-10536d42.pth',
+    'mobilevit_xxs':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/mobilevit/mobilevit_xxs-ad385b40.pth',
+    'mobilevit_xs':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/mobilevit/mobilevit_xs-8fbd6366.pth',
+    'mobilevit_s':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/mobilevit/mobilevit_s-38a5a959.pth',
+    'poolformer_s12':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/poolformer/poolformer_s12.pth.tar',
+    'poolformer_s24':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/poolformer/poolformer_s24.pth.tar',
+    'poolformer_s36':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/poolformer/poolformer_s36.pth.tar',
+    'poolformer_m36':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/poolformer/poolformer_m36.pth.tar',
+    'poolformer_m48':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/poolformer/poolformer_m48.pth.tar',
+    'volo_d1_224':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/volo/d1_224_84.2.pth.tar',
+    'volo_d2_224':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/volo/d2_224_85.2.pth.tar',
+    'volo_d3_224':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/volo/d3_224_85.4.pth.tar',
+    'volo_d4_224':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/volo/d4_224_85.7.pth.tar',
+    'volo_d5_224':
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/timm/volo/d5_224_86.10.pth.tar',
 
     # facebook xcit
     'xcit_small_12_p16':
@@ -97,16 +196,14 @@ timm_models = {
     'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/xcit/dino_xcit_medium_24_p16_pretrain.pth',  # 512
     'xcit_medium_24_p8':
     'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/xcit/dino_xcit_medium_24_p8_pretrain.pth',  # 512
-    'xcit_large_24_p8':
-    'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/xcit/xcit_large_24_p8_224.pth',
 
     # shuffle_trans
     'shuffletrans_base_p4_w7_224':
-    'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/imagenet/shuffle_transformer/shuffle_base.pth',
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/shuffle_transformer/shuffle_base.pth',
     'shuffletrans_small_p4_w7_224':
-    'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/imagenet/shuffle_transformer/shuffle_small.pth',
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/shuffle_transformer/shuffle_small.pth',
     'shuffletrans_tiny_p4_w7_224':
-    'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/imagenet/shuffle_transformer/shuffle_tiny.pth',
+    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/shuffle_transformer/shuffle_tiny.pth',
 
     # dynamic swint:
     'dynamic_swin_base_p4_w7_224':
@@ -115,20 +212,4 @@ timm_models = {
     'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/timm/swin_small_patch4_window7_224_statedict.pth',
     'dynamic_swin_tiny_p4_w7_224':
     'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/timm/swin_tiny_patch4_window7_224_statedict.pth',
-
-    # dynamic vit:
-    'dynamic_deit_tiny_p16':
-    'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/timm/vit_deit_tiny_distilled_patch16_224_pytorch151.pth',
-    'dynamic_deit_small_p16':
-    'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/timm/vit_deit_small_distilled_patch16_224_pytorch151.pth',
-    'resnet50':
-    'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/timm/resnet50_ram-a26f946b.pth',
-    'resnet18':
-    'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/resnet/torchvision/resnet18.pth',
-    'resnet34':
-    'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/resnet/torchvision/resnet34.pth',
-    'resnet101':
-    'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/resnet/torchvision/resnet101.pth',
-    'resnet152':
-    'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/resnet/torchvision/resnet152.pth',
 }
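Note: a minimal sketch of how a checkpoint keyed in the timm_models table above is consumed (the config fields follow the repo's existing configs and tests; the actual resolution of model_name to the mirrored OSS URL happens inside PytorchImageModelWrapper and is assumed from context, not shown in this diff):

# Sketch: build a timm backbone through EasyCV's registry. The model_name is
# assumed to be looked up in the timm_models mapping above, so pretrained
# weights come from the mirrored OSS URL rather than timm's default source.
from easycv.models import builder

backbone_cfg = dict(
    type='PytorchImageModelWrapper',
    model_name='swin_tiny_patch4_window7_224',  # a key in timm_models
    num_classes=0,
    global_pool='')
backbone = builder.build_backbone(backbone_cfg)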
diff --git a/easycv/models/pose/top_down.py b/easycv/models/pose/top_down.py
index 91cb89db..ad6647ad 100644
--- a/easycv/models/pose/top_down.py
+++ b/easycv/models/pose/top_down.py
@@ -13,6 +13,8 @@ from easycv.core.visualization import imshow_bboxes, imshow_keypoints
 from easycv.models import builder
 from easycv.models.base import BaseModel
 from easycv.models.builder import MODELS
+from easycv.utils.checkpoint import load_checkpoint
+from easycv.utils.logger import get_root_logger
 
 
 @MODELS.register_module()
@@ -39,6 +41,7 @@ class TopDown(BaseModel):
                  loss_pose=None):
         super().__init__()
         self.fp16_enabled = False
+        self.pretrained = pretrained
 
         self.backbone = builder.build_backbone(backbone)
 
@@ -62,7 +65,7 @@ class TopDown(BaseModel):
 
             self.keypoint_head = builder.build_head(keypoint_head)
 
-        self.init_weights(pretrained=pretrained)
+        self.init_weights()
 
     @property
     def with_neck(self):
@@ -74,9 +77,14 @@ class TopDown(BaseModel):
         """Check if has keypoint_head."""
         return hasattr(self, 'keypoint_head')
 
-    def init_weights(self, pretrained=None):
+    def init_weights(self):
         """Weight initialization for model."""
-        self.backbone.init_weights(pretrained)
+        if isinstance(self.pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(
+                self.backbone, self.pretrained, strict=False, logger=logger)
+        else:
+            self.backbone.init_weights()
         if self.with_neck:
             self.neck.init_weights()
         if self.with_keypoint:
diff --git a/easycv/models/selfsup/byol.py b/easycv/models/selfsup/byol.py
index 7d19895b..240e8d92 100644
--- a/easycv/models/selfsup/byol.py
+++ b/easycv/models/selfsup/byol.py
@@ -2,7 +2,8 @@
 import torch
 import torch.nn as nn
 
-from easycv.utils.logger import print_log
+from easycv.utils.checkpoint import load_checkpoint
+from easycv.utils.logger import get_root_logger, print_log
 from .. import builder
 from ..base import BaseModel
 from ..registry import MODELS
@@ -21,6 +22,8 @@ class BYOL(BaseModel):
                  base_momentum=0.996,
                  **kwargs):
         super(BYOL, self).__init__()
+
+        self.pretrained = pretrained
         self.online_net = nn.Sequential(
             builder.build_backbone(backbone), builder.build_neck(neck))
         self.target_net = nn.Sequential(
@@ -29,15 +32,21 @@ class BYOL(BaseModel):
         for param in self.target_net.parameters():
             param.requires_grad = False
         self.head = builder.build_head(head)
-        self.init_weights(pretrained=pretrained)
+        self.init_weights()
 
         self.base_momentum = base_momentum
         self.momentum = base_momentum
 
-    def init_weights(self, pretrained=None):
-        if pretrained is not None:
-            print_log('load model from: {}'.format(pretrained), logger='root')
-        self.online_net[0].init_weights(pretrained=pretrained)  # backbone
+    def init_weights(self):
+        if isinstance(self.pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(
+                self.online_net[0],
+                self.pretrained,
+                strict=False,
+                logger=logger)
+        else:
+            self.online_net[0].init_weights()
         self.online_net[1].init_weights(init_linear='kaiming')  # projection
         for param_ol, param_tgt in zip(self.online_net.parameters(),
                                        self.target_net.parameters()):
diff --git a/easycv/models/selfsup/mixco.py b/easycv/models/selfsup/mixco.py
index 958e1315..9a970bc8 100644
--- a/easycv/models/selfsup/mixco.py
+++ b/easycv/models/selfsup/mixco.py
@@ -4,6 +4,8 @@ import torch
 import torch.nn as nn
 from mmcv.runner import get_dist_info
 
+from easycv.utils.checkpoint import load_checkpoint
+from easycv.utils.logger import get_root_logger
 from easycv.utils.preprocess_function import mixUp
 from .. import builder
 from ..registry import MODELS
diff --git a/easycv/models/selfsup/moby.py b/easycv/models/selfsup/moby.py
index 32ebb4a5..5ff9307e 100644
--- a/easycv/models/selfsup/moby.py
+++ b/easycv/models/selfsup/moby.py
@@ -3,7 +3,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from easycv.utils.logger import print_log
+from easycv.utils.checkpoint import load_checkpoint
+from easycv.utils.logger import get_root_logger, print_log
 from easycv.utils.preprocess_function import gaussianBlur, randomGrayScale
 from .. import builder
 from ..base import BaseModel
@@ -46,6 +47,8 @@ class MoBY(BaseModel):
 
         super(MoBY, self).__init__()
 
+        self.pretrained = pretrained
+
         self.preprocess_key_map = {
             'randomGrayScale': randomGrayScale,
             'gaussianBlur': gaussianBlur
@@ -92,7 +95,7 @@ class MoBY(BaseModel):
             nn.SyncBatchNorm.convert_sync_batchnorm(self.predictor)
 
         # set parameters
-        self.init_weights(pretrained=pretrained)
+        self.init_weights()
 
         self.queue_len = queue_len
         self.momentum = momentum
         self.contrast_temperature = contrast_temperature
@@ -112,10 +115,13 @@ class MoBY(BaseModel):
 
         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
 
-    def init_weights(self, pretrained=None):
-        if pretrained is not None:
-            print_log('load model from: {}'.format(pretrained), logger='root')
-        self.encoder_q.init_weights(pretrained=pretrained)
+    def init_weights(self):
+        if isinstance(self.pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(
+                self.encoder_q, self.pretrained, strict=False, logger=logger)
+        else:
+            self.encoder_q.init_weights()
         for param_q, param_k in zip(self.encoder_q.parameters(),
                                     self.encoder_k.parameters()):
             param_k.data.copy_(param_q.data)
diff --git a/easycv/models/selfsup/moco.py b/easycv/models/selfsup/moco.py
index d380b5ba..9b9b5864 100644
--- a/easycv/models/selfsup/moco.py
+++ b/easycv/models/selfsup/moco.py
@@ -2,7 +2,8 @@
 import torch
 import torch.nn as nn
 
-from easycv.utils.logger import print_log
+from easycv.utils.checkpoint import load_checkpoint
+from easycv.utils.logger import get_root_logger, print_log
 from easycv.utils.preprocess_function import gaussianBlur, randomGrayScale
 from .. import builder
 from ..base import BaseModel
@@ -28,6 +29,8 @@ class MOCO(BaseModel):
                  **kwargs):
         super(MOCO, self).__init__()
 
+        self.pretrained = pretrained
+
         self.preprocess_key_map = {
             'randomGrayScale': randomGrayScale,
             'gaussianBlur': gaussianBlur
@@ -43,7 +46,7 @@ class MOCO(BaseModel):
         for param in self.encoder_k.parameters():
             param.requires_grad = False
         self.head = builder.build_head(head)
-        self.init_weights(pretrained=pretrained)
+        self.init_weights()
 
         self.queue_len = queue_len
         self.momentum = momentum
@@ -54,10 +57,16 @@ class MOCO(BaseModel):
         self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))
         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
 
-    def init_weights(self, pretrained=None):
-        if pretrained is not None:
-            print_log('load model from: {}'.format(pretrained), logger='root')
-        self.encoder_q[0].init_weights(pretrained=pretrained)
+    def init_weights(self):
+        if isinstance(self.pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(
+                self.encoder_q[0],
+                self.pretrained,
+                strict=False,
+                logger=logger)
+        else:
+            self.encoder_q[0].init_weights()
         self.encoder_q[1].init_weights(init_linear='kaiming')
         for param_q, param_k in zip(self.encoder_q.parameters(),
                                     self.encoder_k.parameters()):
diff --git a/easycv/models/selfsup/simclr.py b/easycv/models/selfsup/simclr.py
index 348ff579..ac824e9d 100644
--- a/easycv/models/selfsup/simclr.py
+++ b/easycv/models/selfsup/simclr.py
@@ -1,7 +1,8 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import torch
 
-from easycv.utils.logger import print_log
+from easycv.utils.checkpoint import load_checkpoint
+from easycv.utils.logger import get_root_logger, print_log
 from easycv.utils.preprocess_function import gaussianBlur, randomGrayScale
 from .. import builder
 from ..base import BaseModel
@@ -19,6 +20,7 @@ class SimCLR(BaseModel):
                  head=None,
                  pretrained=None):
         super(SimCLR, self).__init__()
+        self.pretrained = pretrained
         self.backbone = builder.build_backbone(backbone)
 
         self.preprocess_key_map = {
@@ -30,7 +32,7 @@ class SimCLR(BaseModel):
         ]
         self.neck = builder.build_neck(neck)
         self.head = builder.build_head(head)
-        self.init_weights(pretrained=pretrained)
+        self.init_weights()
 
     @staticmethod
     def _create_buffer(N):
@@ -42,10 +44,13 @@ class SimCLR(BaseModel):
         neg_mask[pos_ind] = 0
         return mask, pos_ind, neg_mask
 
-    def init_weights(self, pretrained=None):
-        if pretrained is not None:
-            print_log('load model from: {}'.format(pretrained), logger='root')
-        self.backbone.init_weights(pretrained=pretrained)
+    def init_weights(self):
+        if isinstance(self.pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(
+                self.backbone, self.pretrained, strict=False, logger=logger)
+        else:
+            self.backbone.init_weights()
         self.neck.init_weights(init_linear='kaiming')
 
     def forward_backbone(self, img):
diff --git a/easycv/models/selfsup/swav.py b/easycv/models/selfsup/swav.py
index d68ba760..2e3e9f56 100644
--- a/easycv/models/selfsup/swav.py
+++ b/easycv/models/selfsup/swav.py
@@ -5,7 +5,8 @@ import torch.distributed as dist
 import torch.nn as nn
 from mmcv.runner import get_dist_info
 
-from easycv.utils.logger import print_log
+from easycv.utils.checkpoint import load_checkpoint
+from easycv.utils.logger import get_root_logger, print_log
 from easycv.utils.preprocess_function import gaussianBlur, randomGrayScale
 from .. import builder
 from ..base import BaseModel
@@ -22,6 +23,7 @@ class SWAV(BaseModel):
                  config=None,
                  pretrained=None):
         super(SWAV, self).__init__()
+        self.pretrained = pretrained
         self.backbone = builder.build_backbone(backbone)
 
         self.preprocess_key_map = {
@@ -49,13 +51,16 @@ class SWAV(BaseModel):
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
         self.softmax = nn.Softmax(dim=1).cuda()
         self.use_the_queue = False
-        self.init_weights(pretrained=pretrained)
+        self.init_weights()
         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
 
-    def init_weights(self, pretrained=None):
-        if pretrained is not None:
-            print_log('load model from: {}'.format(pretrained), logger='root')
-        self.backbone.init_weights(pretrained=pretrained)
+    def init_weights(self):
+        if isinstance(self.pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(
+                self.backbone, self.pretrained, strict=False, logger=logger)
+        else:
+            self.backbone.init_weights()
         self.neck.init_weights(init_linear='kaiming')
 
         # if torch.load(pretrained).get("prototypes.weight", None) is not None:
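Note: every model class above now follows the same initialization contract. A usage sketch under that contract (the config and checkpoint paths are hypothetical): pretrained is captured in __init__, init_weights() takes no arguments, a string value is loaded non-strictly via load_checkpoint, and anything else falls back to the backbone's default initialization.

from easycv.models import build_model
from easycv.utils.config_tools import mmcv_config_fromfile

cfg = mmcv_config_fromfile('path/to/model_config.py')  # hypothetical config
cfg.model.pretrained = 'path/to/backbone_checkpoint.pth'  # hypothetical path
model = build_model(cfg.model)  # __init__ already calls self.init_weights()
model.init_weights()  # can be re-run later; it re-reads self.pretrained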
diff --git a/easycv/utils/checkpoint.py b/easycv/utils/checkpoint.py
index 5a60f713..4bf0af60 100644
--- a/easycv/utils/checkpoint.py
+++ b/easycv/utils/checkpoint.py
@@ -1,12 +1,10 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
-import time
 
 import torch
 from mmcv.parallel import is_module_wrapper
 from mmcv.runner import load_checkpoint as mmcv_load_checkpoint
-from mmcv.runner.checkpoint import (_save_to_state_dict, get_state_dict,
-                                    weights_to_cpu)
+from mmcv.runner.checkpoint import get_state_dict, weights_to_cpu
 from torch.optim import Optimizer
 
 from easycv.file import io
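Note: a hedged sketch of calling this wrapper directly; the signature used here (model, path, strict, logger) is taken from the call sites added above, and the wrapper delegates to mmcv's load_checkpoint.

import torch.nn as nn

from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger

module = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU())  # stand-in module
# strict=False tolerates missing/unexpected keys, which is what lets the
# init_weights() implementations above initialize just a backbone from a
# larger checkpoint. The path below is hypothetical.
load_checkpoint(module, 'checkpoint.pth', strict=False, logger=get_root_logger())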
diff --git a/tests/models/backbones/test_bninception.py b/tests/models/backbones/test_bninception.py
index 2c587b00..c923b2c2 100644
--- a/tests/models/backbones/test_bninception.py
+++ b/tests/models/backbones/test_bninception.py
@@ -49,12 +49,10 @@ class BNInceptionTest(unittest.TestCase):
         original_weight = net.conv2_3x3.weight
         original_weight = copy.deepcopy(original_weight.cpu().data.numpy())
 
-        net.init_weights(net.pretrained)
+        net.init_weights()
         load_weight = net.conv2_3x3.weight.cpu().data.numpy()
         self.assertFalse(np.allclose(original_weight, load_weight))
-        self.assertTrue(
-            net.pretrained == modelzoo.bninception['BNInception'])
 
 
 if __name__ == '__main__':
diff --git a/tests/models/backbones/test_genet.py b/tests/models/backbones/test_genet.py
index c01f9d79..6880109b 100644
--- a/tests/models/backbones/test_genet.py
+++ b/tests/models/backbones/test_genet.py
@@ -56,15 +56,12 @@ class PlainNetTest(unittest.TestCase):
             original_weight = copy.deepcopy(
                 original_weight.cpu().data.numpy())
 
-            net.init_weights(net.pretrained)
+            net.init_weights()
             load_weight = net.module_list[0].netblock.weight.cpu(
             ).data.numpy()
             self.assertFalse(np.allclose(original_weight, load_weight))
 
-            self.assertTrue(net.pretrained == modelzoo.genet['PlainNet' +
-                                                             struct])
-
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/models/backbones/test_inceptionv3.py b/tests/models/backbones/test_inceptionv3.py
index 3a332e99..e1ac422b 100644
--- a/tests/models/backbones/test_inceptionv3.py
+++ b/tests/models/backbones/test_inceptionv3.py
@@ -53,14 +53,11 @@ class InceptionV3Test(unittest.TestCase):
         original_weight = net.Conv2d_1a_3x3.conv.weight
         original_weight = copy.deepcopy(original_weight.cpu().data.numpy())
 
-        net.init_weights(net.pretrained)
+        net.init_weights()
         load_weight = net.Conv2d_1a_3x3.conv.weight.cpu().data.numpy()
         self.assertFalse(np.allclose(original_weight, load_weight))
 
-        self.assertTrue(
-            net.pretrained == modelzoo.inceptionv3['Inception3'])
-
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/models/backbones/test_mnasnet.py b/tests/models/backbones/test_mnasnet.py
index a061df30..2308896c 100644
--- a/tests/models/backbones/test_mnasnet.py
+++ b/tests/models/backbones/test_mnasnet.py
@@ -56,14 +56,11 @@ class MnasnetTest(unittest.TestCase):
             original_weight = copy.deepcopy(
                 original_weight.cpu().data.numpy())
 
-            net.init_weights(net.pretrained)
+            net.init_weights()
             load_weight = net.layers[0].weight.cpu().data.numpy()
             self.assertFalse(np.allclose(original_weight, load_weight))
 
-            self.assertTrue(net.pretrained == modelzoo.mnasnet['MNASNet' +
-                                                               '1.0'])
-
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/models/backbones/test_mobilenetv2.py b/tests/models/backbones/test_mobilenetv2.py
index e6d7ff4c..439d6a91 100644
--- a/tests/models/backbones/test_mobilenetv2.py
+++ b/tests/models/backbones/test_mobilenetv2.py
@@ -61,15 +61,11 @@ class MobileNetTest(unittest.TestCase):
             original_weight = copy.deepcopy(
                 original_weight.cpu().data.numpy())
 
-            net.init_weights(net.pretrained)
+            net.init_weights()
             load_weight = net.features[0][0].weight.cpu().data.numpy()
             self.assertFalse(np.allclose(original_weight, load_weight))
 
-            self.assertTrue(
-                net.pretrained == modelzoo.mobilenetv2['MobileNetV2_' +
-                                                       str(width_multi)])
-
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/models/backbones/test_pytorch_image_models_wrapper.py b/tests/models/backbones/test_pytorch_image_models_wrapper.py
index 51dcf7d7..c26790a5 100644
--- a/tests/models/backbones/test_pytorch_image_models_wrapper.py
+++ b/tests/models/backbones/test_pytorch_image_models_wrapper.py
@@ -20,7 +20,6 @@ class PytorchImageModelWrapperTest(unittest.TestCase):
 
         net = PytorchImageModelWrapper(
             model_name='swin_tiny_patch4_window7_224',
-            pretrained=False,
             num_classes=0,
             global_pool='').to('cuda')
         net.eval()
@@ -39,9 +38,7 @@ class PytorchImageModelWrapperTest(unittest.TestCase):
         a = torch.rand(batch_size, 3, 224, 224).to('cuda')
 
         net = PytorchImageModelWrapper(
-            model_name='efficientnet_b0',
-            pretrained=False,
-            num_classes=0,
+            model_name='efficientnet_b0', num_classes=0,
             global_pool='').to('cuda')
         net.eval()
 
@@ -66,7 +63,6 @@ class PytorchImageModelWrapperTest(unittest.TestCase):
         # swin_tiny_patch4_window7_224
         net = PytorchImageModelWrapper(
             model_name='swin_tiny_patch4_window7_224',
-            pretrained=True,
             num_classes=0,
             global_pool='').to('cuda')
         net.eval()
@@ -74,7 +70,6 @@ class PytorchImageModelWrapperTest(unittest.TestCase):
 
         net_random_init = PytorchImageModelWrapper(
             model_name='swin_tiny_patch4_window7_224',
-            pretrained=False,
             num_classes=0,
             global_pool='').to('cuda')
         net_random_init.eval()
diff --git a/tests/models/selfsup/test_moby.py b/tests/models/selfsup/test_moby.py
index ade5932e..313a10e1 100644
--- a/tests/models/selfsup/test_moby.py
+++ b/tests/models/selfsup/test_moby.py
@@ -14,11 +14,11 @@ _base_model_cfg = dict(
     momentum=0.99,
     pretrained=None,
     backbone=dict(
-        type='PytorchImageModelWrapper',
-        model_name='resnet50',  # 2048
-        num_classes=0,
-        pretrained=True,
-    ),
+        type='ResNet',
+        depth=50,
+        in_channels=3,
+        out_indices=[4],  # 0: conv-1, x: stage-x
+        norm_cfg=dict(type='BN')),
     neck=dict(
         type='MoBYMLP',
         in_channels=2048,
diff --git a/tools/test_inference_time.py b/tools/test_inference_time.py
new file mode 100644
index 00000000..647108dd
--- /dev/null
+++ b/tools/test_inference_time.py
@@ -0,0 +1,68 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import argparse
+
+import numpy as np
+import torch
+import tqdm
+from torch.backends import cudnn
+
+from easycv.models import build_model
+from easycv.utils.config_tools import mmcv_config_fromfile
+
+cudnn.benchmark = True
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='EasyCV model memory and inference_time test')
+    parser.add_argument('config', help='test config file path')
+    parser.add_argument(
+        'gpu', type=str, choices=['0', '1', '2', '3', '4', '5', '6', '7'])
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    cfg = mmcv_config_fromfile(args.config)
+
+    device = torch.device('cuda:{}'.format(args.gpu))
+    model = build_model(cfg.model).to(device)
+    repetitions = 300
+
+    dummy_input = torch.rand(1, 3, 224, 224).to(device)
+
+    # Warm up: the GPU may be in a power-saving state, so run some forward
+    # passes first to bring it to steady clocks before measuring.
+    print('warm up ...\n')
+    with torch.no_grad():
+        for _ in range(100):
+            _ = model.forward_test(dummy_input)
+
+    # Synchronize: wait for all pending GPU work to finish before the CPU
+    # main thread continues.
+    torch.cuda.synchronize()
+
+    # Set up CUDA events for measuring time. This is PyTorch's officially
+    # recommended interface and should be the most reliable way to time GPU work.
+    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(
+        enable_timing=True)
+    # Container for the per-iteration timings.
+    timings = np.zeros((repetitions, 1))
+
+    print('testing ...\n')
+    with torch.no_grad():
+        for rep in tqdm.tqdm(range(repetitions)):
+            starter.record()
+            _ = model.forward_test(dummy_input)
+            ender.record()
+            torch.cuda.synchronize()  # Wait for the GPU work to complete.
+            # Elapsed time between starter and ender, in milliseconds.
+            curr_time = starter.elapsed_time(ender)
+            timings[rep] = curr_time
+
+    avg = timings.sum() / repetitions
+    print(torch.cuda.memory_summary(device))
+    print('\navg={}\n'.format(avg))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tools/train.py b/tools/train.py
index 4c92ba9d..a95134ba 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -39,7 +39,6 @@ def parse_args():
     parser = argparse.ArgumentParser(description='Train a model')
     parser.add_argument(
         'config', help='train config file path', type=str, default=None)
-    # parser.add_argument('--config', help='train config file path', default="configs/classification/imagenet/r50.py")
     parser.add_argument(
         '--work_dir',
         type=str,
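Note: the new benchmarking script is invoked with a config path and a GPU index, matching its parse_args above (the config path here is a placeholder):

    python tools/test_inference_time.py path/to/model_config.py 0

It builds cfg.model on cuda:<gpu>, runs 100 warm-up iterations of forward_test on a 1x3x224x224 input, then reports the average latency over 300 timed repetitions together with torch.cuda.memory_summary for the device.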