From 9f01a37ad4df57b30430c41df08459025174e8fd Mon Sep 17 00:00:00 2001 From: tuofeilun <38110862+tuofeilunhifi@users.noreply.github.com> Date: Fri, 16 Sep 2022 11:03:53 +0800 Subject: [PATCH] Refactor ViTDet backbone and simple feature pyramid (#177) 1. The vitdet backbone implemented by d2 is about 20% faster than the vitdet backbone originally reproduced by easycv. 2. 50.57 -> 50.65 --- .../detection/vitdet/lsj_coco_detection.py | 6 +- configs/detection/vitdet/lsj_coco_instance.py | 6 +- .../vitdet/vitdet_basicblock_100e.py | 3 - .../vitdet/vitdet_bottleneck_100e.py | 3 - .../vitdet/vitdet_cascade_mask_rcnn.py | 231 ++++ .../vitdet/vitdet_cascade_mask_rcnn_100e.py | 4 + .../detection/vitdet/vitdet_faster_rcnn.py | 31 +- .../vitdet/vitdet_faster_rcnn_100e.py | 2 +- configs/detection/vitdet/vitdet_mask_rcnn.py | 31 +- ...itdet_100e.py => vitdet_mask_rcnn_100e.py} | 0 .../detection/vitdet/vitdet_schedule_100e.py | 21 +- docs/source/_static/result.jpg | 4 +- docs/source/model_zoo_det.md | 2 +- .../layer_decay_optimizer_constructor.py | 78 +- easycv/models/backbones/vitdet.py | 1057 ++++++----------- easycv/models/detection/necks/fpn.py | 3 - easycv/models/detection/necks/sfp.py | 216 +--- easycv/predictors/detector.py | 10 +- tests/models/backbones/test_vitdet.py | 23 +- tests/predictors/test_detector.py | 189 ++- 20 files changed, 925 insertions(+), 995 deletions(-) delete mode 100644 configs/detection/vitdet/vitdet_basicblock_100e.py delete mode 100644 configs/detection/vitdet/vitdet_bottleneck_100e.py create mode 100644 configs/detection/vitdet/vitdet_cascade_mask_rcnn.py create mode 100644 configs/detection/vitdet/vitdet_cascade_mask_rcnn_100e.py rename configs/detection/vitdet/{vitdet_100e.py => vitdet_mask_rcnn_100e.py} (100%) diff --git a/configs/detection/vitdet/lsj_coco_detection.py b/configs/detection/vitdet/lsj_coco_detection.py index f5da1064..fb243a23 100644 --- a/configs/detection/vitdet/lsj_coco_detection.py +++ b/configs/detection/vitdet/lsj_coco_detection.py @@ -101,13 +101,15 @@ val_dataset = dict( pipeline=test_pipeline) data = dict( - imgs_per_gpu=1, workers_per_gpu=2, train=train_dataset, val=val_dataset) + imgs_per_gpu=4, workers_per_gpu=2, train=train_dataset, val=val_dataset +) # 64(total batch size) = 4 (batch size/per gpu) x 8 (gpu num) x 2(node) # evaluation -eval_config = dict(interval=1, gpu_collect=False) +eval_config = dict(initial=False, interval=1, gpu_collect=False) eval_pipelines = [ dict( mode='test', + # dist_eval=True, evaluators=[ dict(type='CocoDetectionEvaluator', classes=CLASSES), ], diff --git a/configs/detection/vitdet/lsj_coco_instance.py b/configs/detection/vitdet/lsj_coco_instance.py index a42aa040..5271363f 100644 --- a/configs/detection/vitdet/lsj_coco_instance.py +++ b/configs/detection/vitdet/lsj_coco_instance.py @@ -101,13 +101,15 @@ val_dataset = dict( pipeline=test_pipeline) data = dict( - imgs_per_gpu=1, workers_per_gpu=2, train=train_dataset, val=val_dataset) + imgs_per_gpu=4, workers_per_gpu=2, train=train_dataset, val=val_dataset +) # 64(total batch size) = 4 (batch size/per gpu) x 8 (gpu num) x 2(node) # evaluation -eval_config = dict(interval=1, gpu_collect=False) +eval_config = dict(initial=False, interval=1, gpu_collect=False) eval_pipelines = [ dict( mode='test', + # dist_eval=True, evaluators=[ dict(type='CocoDetectionEvaluator', classes=CLASSES), dict(type='CocoMaskEvaluator', classes=CLASSES) diff --git a/configs/detection/vitdet/vitdet_basicblock_100e.py b/configs/detection/vitdet/vitdet_basicblock_100e.py deleted file mode 
100644 index a3ea54e7..00000000 --- a/configs/detection/vitdet/vitdet_basicblock_100e.py +++ /dev/null @@ -1,3 +0,0 @@ -_base_ = './vitdet_100e.py' - -model = dict(backbone=dict(aggregation='basicblock')) diff --git a/configs/detection/vitdet/vitdet_bottleneck_100e.py b/configs/detection/vitdet/vitdet_bottleneck_100e.py deleted file mode 100644 index a6031797..00000000 --- a/configs/detection/vitdet/vitdet_bottleneck_100e.py +++ /dev/null @@ -1,3 +0,0 @@ -_base_ = './vitdet_100e.py' - -model = dict(backbone=dict(aggregation='bottleneck')) diff --git a/configs/detection/vitdet/vitdet_cascade_mask_rcnn.py b/configs/detection/vitdet/vitdet_cascade_mask_rcnn.py new file mode 100644 index 00000000..dfe0d68d --- /dev/null +++ b/configs/detection/vitdet/vitdet_cascade_mask_rcnn.py @@ -0,0 +1,231 @@ +# model settings + +norm_cfg = dict(type='GN', num_groups=1, eps=1e-6, requires_grad=True) + +pretrained = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-b-1600/warpper_mae_vit-base-p16-1600e.pth' +model = dict( + type='CascadeRCNN', + pretrained=pretrained, + backbone=dict( + type='ViTDet', + img_size=1024, + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + drop_path_rate=0.1, + window_size=14, + mlp_ratio=4, + qkv_bias=True, + window_block_indexes=[ + # 2, 5, 8 11 for global attention + 0, + 1, + 3, + 4, + 6, + 7, + 9, + 10, + ], + residual_block_indexes=[], + use_rel_pos=True), + neck=dict( + type='SFP', + in_channels=768, + out_channels=256, + scale_factors=(4.0, 2.0, 1.0, 0.5), + norm_cfg=norm_cfg, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + num_convs=2, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), + roi_head=dict( + type='CascadeRoIHead', + num_stages=3, + stage_loss_weights=[1, 0.5, 0.25], + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=[ + dict( + type='Shared4Conv1FCBBoxHead', + conv_out_channels=256, + norm_cfg=norm_cfg, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, + loss_weight=1.0)), + dict( + type='Shared4Conv1FCBBoxHead', + conv_out_channels=256, + norm_cfg=norm_cfg, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.05, 0.05, 0.1, 0.1]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, + loss_weight=1.0)), + dict( + type='Shared4Conv1FCBBoxHead', + conv_out_channels=256, + norm_cfg=norm_cfg, + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + 
target_stds=[0.033, 0.033, 0.067, 0.067]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)) + ], + mask_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + mask_head=dict( + type='FCNMaskHead', + norm_cfg=norm_cfg, + num_convs=4, + in_channels=256, + conv_out_channels=256, + num_classes=80, + loss_mask=dict( + type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=2000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=[ + dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False), + dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.6, + neg_iou_thr=0.6, + min_pos_iou=0.6, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False), + dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.7, + min_pos_iou=0.7, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False) + ]), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100, + mask_thr_binary=0.5))) + +mmlab_modules = [ + dict(type='mmdet', name='CascadeRCNN', module='model'), + dict(type='mmdet', name='RPNHead', module='head'), + dict(type='mmdet', name='CascadeRoIHead', module='head'), +] diff --git a/configs/detection/vitdet/vitdet_cascade_mask_rcnn_100e.py b/configs/detection/vitdet/vitdet_cascade_mask_rcnn_100e.py new file mode 100644 index 00000000..bbbc339f --- /dev/null +++ b/configs/detection/vitdet/vitdet_cascade_mask_rcnn_100e.py @@ -0,0 +1,4 @@ +_base_ = [ + './vitdet_cascade_mask_rcnn.py', './lsj_coco_instance.py', + './vitdet_schedule_100e.py' +] diff --git a/configs/detection/vitdet/vitdet_faster_rcnn.py b/configs/detection/vitdet/vitdet_faster_rcnn.py index 48604d8b..0a00b397 100644 --- a/configs/detection/vitdet/vitdet_faster_rcnn.py +++ b/configs/detection/vitdet/vitdet_faster_rcnn.py @@ -1,6 +1,6 @@ # model settings -norm_cfg = dict(type='GN', num_groups=1, requires_grad=True) +norm_cfg = dict(type='GN', num_groups=1, eps=1e-6, requires_grad=True) pretrained = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-b-1600/warpper_mae_vit-base-p16-1600e.pth' model = dict( @@ -9,22 +9,32 @@ model = dict( 
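# The three cascade stages above keep the same DeltaXYWHBBoxCoder but raise the IoU
# thresholds (0.5 -> 0.6 -> 0.7) while shrinking target_stds from (0.1, 0.1, 0.2, 0.2)
# down to (0.033, 0.033, 0.067, 0.067): later stages see proposals that already sit
# close to the ground truth, so smaller stds keep the normalized regression targets in
# a useful range. Illustrative sketch of that normalization only -- plain R-CNN delta
# encoding in (cx, cy, w, h) form, not the mmdet DeltaXYWHBBoxCoder implementation:
import math

def encode_delta(proposal, gt, stds):
    px, py, pw, ph = proposal
    gx, gy, gw, gh = gt
    return ((gx - px) / pw / stds[0], (gy - py) / ph / stds[1],
            math.log(gw / pw) / stds[2], math.log(gh / ph) / stds[3])

proposal = (100.0, 100.0, 50.0, 50.0)  # cx, cy, w, h
gt = (102.0, 101.0, 54.0, 52.0)
for stds in [(0.1, 0.1, 0.2, 0.2), (0.05, 0.05, 0.1, 0.1),
             (0.033, 0.033, 0.067, 0.067)]:
    print(stds, [round(v, 2) for v in encode_delta(proposal, gt, stds)])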
backbone=dict( type='ViTDet', img_size=1024, + patch_size=16, embed_dim=768, depth=12, num_heads=12, + drop_path_rate=0.1, + window_size=14, mlp_ratio=4, qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.1, - use_abs_pos_emb=True, - aggregation='attn', - ), + window_block_indexes=[ + # 2, 5, 8 11 for global attention + 0, + 1, + 3, + 4, + 6, + 7, + 9, + 10, + ], + residual_block_indexes=[], + use_rel_pos=True), neck=dict( type='SFP', - in_channels=[768, 768, 768, 768], + in_channels=768, out_channels=256, + scale_factors=(4.0, 2.0, 1.0, 0.5), norm_cfg=norm_cfg, num_outs=5), rpn_head=dict( @@ -32,7 +42,6 @@ model = dict( in_channels=256, feat_channels=256, num_convs=2, - norm_cfg=norm_cfg, anchor_generator=dict( type='AnchorGenerator', scales=[8], @@ -98,7 +107,7 @@ model = dict( pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0.5, - match_low_quality=True, + match_low_quality=False, ignore_iof_thr=-1), sampler=dict( type='RandomSampler', diff --git a/configs/detection/vitdet/vitdet_faster_rcnn_100e.py b/configs/detection/vitdet/vitdet_faster_rcnn_100e.py index 5a43b575..bfeab9d1 100644 --- a/configs/detection/vitdet/vitdet_faster_rcnn_100e.py +++ b/configs/detection/vitdet/vitdet_faster_rcnn_100e.py @@ -1,4 +1,4 @@ _base_ = [ - './vitdet_faster_rcnn.py', './lsj_coco_detection.py', + './vitdet_faster_rcnn.py', './lsj_coco_instance.py', './vitdet_schedule_100e.py' ] diff --git a/configs/detection/vitdet/vitdet_mask_rcnn.py b/configs/detection/vitdet/vitdet_mask_rcnn.py index 890f6e8f..6b1ed1ce 100644 --- a/configs/detection/vitdet/vitdet_mask_rcnn.py +++ b/configs/detection/vitdet/vitdet_mask_rcnn.py @@ -1,6 +1,6 @@ # model settings -norm_cfg = dict(type='GN', num_groups=1, requires_grad=True) +norm_cfg = dict(type='GN', num_groups=1, eps=1e-6, requires_grad=True) pretrained = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-b-1600/warpper_mae_vit-base-p16-1600e.pth' model = dict( @@ -9,22 +9,32 @@ model = dict( backbone=dict( type='ViTDet', img_size=1024, + patch_size=16, embed_dim=768, depth=12, num_heads=12, + drop_path_rate=0.1, + window_size=14, mlp_ratio=4, qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.1, - use_abs_pos_emb=True, - aggregation='attn', - ), + window_block_indexes=[ + # 2, 5, 8 11 for global attention + 0, + 1, + 3, + 4, + 6, + 7, + 9, + 10, + ], + residual_block_indexes=[], + use_rel_pos=True), neck=dict( type='SFP', - in_channels=[768, 768, 768, 768], + in_channels=768, out_channels=256, + scale_factors=(4.0, 2.0, 1.0, 0.5), norm_cfg=norm_cfg, num_outs=5), rpn_head=dict( @@ -32,7 +42,6 @@ model = dict( in_channels=256, feat_channels=256, num_convs=2, - norm_cfg=norm_cfg, anchor_generator=dict( type='AnchorGenerator', scales=[8], @@ -112,7 +121,7 @@ model = dict( pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0.5, - match_low_quality=True, + match_low_quality=False, ignore_iof_thr=-1), sampler=dict( type='RandomSampler', diff --git a/configs/detection/vitdet/vitdet_100e.py b/configs/detection/vitdet/vitdet_mask_rcnn_100e.py similarity index 100% rename from configs/detection/vitdet/vitdet_100e.py rename to configs/detection/vitdet/vitdet_mask_rcnn_100e.py diff --git a/configs/detection/vitdet/vitdet_schedule_100e.py b/configs/detection/vitdet/vitdet_schedule_100e.py index e659b1f6..a9160eba 100644 --- a/configs/detection/vitdet/vitdet_schedule_100e.py +++ b/configs/detection/vitdet/vitdet_schedule_100e.py @@ -1,26 +1,29 @@ _base_ = 'configs/base.py' 
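# In the depth-12 backbones configured above, window_block_indexes = [0, 1, 3, 4, 6,
# 7, 9, 10] leaves blocks 2, 5, 8 and 11 (one per group of three, as the inline
# comment notes) running global attention, matching ViTDet's windowed/global
# interleaving. Quick sanity check in plain Python (no EasyCV imports assumed):
depth = 12
window_block_indexes = [0, 1, 3, 4, 6, 7, 9, 10]
global_blocks = sorted(set(range(depth)) - set(window_block_indexes))
assert global_blocks == [2, 5, 8, 11]
print('windowed blocks:', window_block_indexes)
print('global blocks  :', global_blocks)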
+log_config = dict( + interval=200, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) + checkpoint_config = dict(interval=10) + # optimizer -paramwise_options = { - 'norm': dict(weight_decay=0.), - 'bias': dict(weight_decay=0.), - 'pos_embed': dict(weight_decay=0.), - 'cls_token': dict(weight_decay=0.) -} optimizer = dict( type='AdamW', lr=1e-4, betas=(0.9, 0.999), weight_decay=0.1, - paramwise_options=paramwise_options) -optimizer_config = dict(grad_clip=None, loss_scale=512.) + constructor='LayerDecayOptimizerConstructor', + paramwise_options=dict(num_layers=12, layer_decay_rate=0.7)) +optimizer_config = dict(grad_clip=None) # learning policy lr_config = dict( policy='step', warmup='linear', warmup_iters=250, - warmup_ratio=0.067, + warmup_ratio=0.001, step=[88, 96]) total_epochs = 100 diff --git a/docs/source/_static/result.jpg b/docs/source/_static/result.jpg index 5bb73d81..d63bad1d 100644 --- a/docs/source/_static/result.jpg +++ b/docs/source/_static/result.jpg @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ee64c0caef841c61c7e6344b7fe2c07a38fba07a8de81ff38c0686c641e0a283 -size 190356 +oid sha256:c696a58a2963b5ac47317751f04ff45bfed4723f2f70bacf91eac711f9710e54 +size 189432 diff --git a/docs/source/model_zoo_det.md b/docs/source/model_zoo_det.md index 03eb3588..474496f0 100644 --- a/docs/source/model_zoo_det.md +++ b/docs/source/model_zoo_det.md @@ -22,7 +22,7 @@ Pretrained on COCO2017 dataset. (The result has been optimized with PAI-Blade, a | Algorithm | Config | Params
(backbone/total) | inference time(V100)
(ms/img) | bbox_mAPval
0.5:0.95 | mask_mAPval
0.5:0.95 | Download | | ---------- | ------------------------------------------------------------ | ------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | -| ViTDet_MaskRCNN | [vitdet_maskrcnn](https://github.com/alibaba/EasyCV/tree/master/configs/detection/vitdet/vitdet_100e.py) | 88M/118M | 163ms | 50.57 | 44.96 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/vitdet_maskrcnn.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/vitdet_maskrcnn.log.json) | +| ViTDet_MaskRCNN | [vitdet_maskrcnn](https://github.com/alibaba/EasyCV/tree/master/configs/detection/vitdet/vitdet_mask_rcnn_100e.py) | 86M/111M | 138ms | 50.65 | 45.41 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/epoch_100.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/20220901_135827.log.json) | ## FCOS diff --git a/easycv/core/optimizer/layer_decay_optimizer_constructor.py b/easycv/core/optimizer/layer_decay_optimizer_constructor.py index 45625494..310bb38c 100644 --- a/easycv/core/optimizer/layer_decay_optimizer_constructor.py +++ b/easycv/core/optimizer/layer_decay_optimizer_constructor.py @@ -1,5 +1,3 @@ -# Reference from https://github.com/ViTAE-Transformer/ViTDet/blob/main/mmcv_custom/layer_decay_optimizer_constructor.py - import json from mmcv.runner import DefaultOptimizerConstructor, get_dist_info @@ -7,23 +5,32 @@ from mmcv.runner import DefaultOptimizerConstructor, get_dist_info from .builder import OPTIMIZER_BUILDERS -def get_num_layer_for_vit(var_name, num_max_layer, layer_sep=None): - if var_name in ('backbone.cls_token', 'backbone.mask_token', - 'backbone.pos_embed'): - return 0 - elif var_name.startswith('backbone.patch_embed'): - return 0 - elif var_name.startswith('backbone.blocks'): - layer_id = int(var_name.split('.')[2]) - return layer_id + 1 - else: - return num_max_layer - 1 +def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12): + """ + Calculate lr decay rate for different ViT blocks. + Reference from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if '.pos_embed' in name or '.patch_embed' in name: + layer_id = 0 + elif '.blocks.' in name and '.residual.' not in name: + layer_id = int(name[name.find('.blocks.'):].split('.')[2]) + 1 + + scale = lr_decay_rate**(num_layers + 1 - layer_id) + + return layer_id, scale @OPTIMIZER_BUILDERS.register_module() class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor): - def add_params(self, params, module, prefix='', is_dcn_module=None): + def add_params(self, params, module): """Add all parameters of module to the params list. The parameters of the given module will be added to the list of param groups, with specific rules defined by paramwise_cfg. @@ -31,54 +38,41 @@ class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor): params (list[dict]): A list of param groups, it will be modified in place. 
module (nn.Module): The module to be added. - prefix (str): The prefix of the module - is_dcn_module (int|float|None): If the current module is a - submodule of DCN, `is_dcn_module` will be passed to - control conv_offset layer's learning rate. Defaults to None. + + Reference from https://github.com/ViTAE-Transformer/ViTDet/blob/main/mmcv_custom/layer_decay_optimizer_constructor.py + Note: Currently, this optimizer constructor is built for ViTDet. """ - # get param-wise options parameter_groups = {} print(self.paramwise_cfg) - num_layers = self.paramwise_cfg.get('num_layers') + 2 - layer_sep = self.paramwise_cfg.get('layer_sep', None) - layer_decay_rate = self.paramwise_cfg.get('layer_decay_rate') + lr_decay_rate = self.paramwise_cfg.get('layer_decay_rate') + num_layers = self.paramwise_cfg.get('num_layers') print('Build LayerDecayOptimizerConstructor %f - %d' % - (layer_decay_rate, num_layers)) + (lr_decay_rate, num_layers)) + lr = self.base_lr weight_decay = self.base_wd - custom_keys = self.paramwise_cfg.get('custom_keys', {}) - # first sort with alphabet order and then sort with reversed len of str - sorted_keys = sorted(custom_keys.keys()) - for name, param in module.named_parameters(): if not param.requires_grad: continue # frozen weights - if len(param.shape) == 1 or name.endswith('.bias') or ( - 'pos_embed' in name) or ('cls_token' - in name) or ('rel_pos_' in name): + if 'backbone' in name and ('.norm' in name or '.pos_embed' in name + or '.gn.' in name or '.ln.' in name): group_name = 'no_decay' this_weight_decay = 0. else: group_name = 'decay' this_weight_decay = weight_decay - layer_id = get_num_layer_for_vit(name, num_layers, layer_sep) + if name.startswith('backbone'): + layer_id, scale = get_vit_lr_decay_rate( + name, lr_decay_rate=lr_decay_rate, num_layers=num_layers) + else: + layer_id, scale = -1, 1 group_name = 'layer_%d_%s' % (layer_id, group_name) - # if the parameter match one of the custom keys, ignore other rules - this_lr_multi = 1. - for key in sorted_keys: - if key in f'{name}': - lr_mult = custom_keys[key].get('lr_mult', 1.) - this_lr_multi = lr_mult - group_name = '%s_%s' % (group_name, key) - break - if group_name not in parameter_groups: - scale = layer_decay_rate**(num_layers - layer_id - 1) parameter_groups[group_name] = { 'weight_decay': this_weight_decay, @@ -86,7 +80,7 @@ class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor): 'param_names': [], 'lr_scale': scale, 'group_name': group_name, - 'lr': scale * self.base_lr * this_lr_multi, + 'lr': scale * lr, } parameter_groups[group_name]['params'].append(param) diff --git a/easycv/models/backbones/vitdet.py b/easycv/models/backbones/vitdet.py index 83e11efa..9380f740 100644 --- a/easycv/models/backbones/vitdet.py +++ b/easycv/models/backbones/vitdet.py @@ -1,5 +1,3 @@ -# Copyright 2018-2023 OpenMMLab. All rights reserved. 
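# With the schedule's paramwise_options (num_layers=12, layer_decay_rate=0.7), the new
# get_vit_lr_decay_rate gives the patch/position embeddings the smallest lr scale and
# the last transformer block the largest backbone scale; parameters outside the
# backbone keep the base lr (add_params assigns them scale 1). A small sketch of the
# resulting scales, assuming the function stays importable from the module edited here:
from easycv.core.optimizer.layer_decay_optimizer_constructor import \
    get_vit_lr_decay_rate

for name in ['backbone.pos_embed',
             'backbone.patch_embed.proj.weight',
             'backbone.blocks.0.attn.qkv.weight',
             'backbone.blocks.11.mlp.fc2.weight']:
    layer_id, scale = get_vit_lr_decay_rate(name, lr_decay_rate=0.7, num_layers=12)
    print(f'{name}: layer_id={layer_id}, lr_scale={scale:.4f}')
# pos_embed/patch_embed -> 0.7**13 ~= 0.0097, block 0 -> 0.7**12 ~= 0.0138, block 11 -> 0.7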
-# Reference: https://github.com/ViTAE-Transformer/ViTDet/blob/main/mmdet/models/backbones/vit.py import math from functools import partial @@ -7,793 +5,466 @@ import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint -from mmcv.cnn import build_norm_layer, constant_init, kaiming_init -from mmcv.runner import get_dist_info -from timm.models.layers import to_2tuple, trunc_normal_ -from torch.nn.modules.batchnorm import _BatchNorm +from timm.models.layers import DropPath, trunc_normal_ -from easycv.models.utils import DropPath, Mlp +from easycv.models.utils import Mlp from easycv.utils.checkpoint import load_checkpoint from easycv.utils.logger import get_root_logger from ..registry import BACKBONES -from ..utils import build_conv_layer - - -class BasicBlock(nn.Module): - expansion = 1 - - def __init__(self, - inplanes, - planes, - stride=1, - dilation=1, - conv_cfg=None, - norm_cfg=dict(type='BN')): - super(BasicBlock, self).__init__() - - self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) - self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) - - self.conv1 = build_conv_layer( - conv_cfg, - inplanes, - planes, - 3, - stride=stride, - padding=dilation, - dilation=dilation, - bias=False) - self.add_module(self.norm1_name, norm1) - self.conv2 = build_conv_layer( - conv_cfg, planes, planes, 3, padding=1, bias=False) - self.add_module(self.norm2_name, norm2) - - self.relu = nn.ReLU(inplace=True) - self.stride = stride - self.dilation = dilation - - @property - def norm1(self): - return getattr(self, self.norm1_name) - - @property - def norm2(self): - return getattr(self, self.norm2_name) - - def forward(self, x, H, W): - B, _, C = x.shape - x = x.permute(0, 2, 1).reshape(B, -1, H, W) - identity = x - - out = self.conv1(x) - out = self.norm1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.norm2(out) - - out += identity - out = self.relu(out) - out = out.flatten(2).transpose(1, 2) - return out - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, - inplanes, - planes, - stride=1, - dilation=1, - conv_cfg=None, - norm_cfg=dict(type='BN')): - """Bottleneck block for ResNet. - If style is "pytorch", the stride-two layer is the 3x3 conv layer, - if it is "caffe", the stride-two layer is the first 1x1 conv layer. 
- """ - super(Bottleneck, self).__init__() - - self.inplanes = inplanes - self.planes = planes - self.stride = stride - self.dilation = dilation - self.conv_cfg = conv_cfg - self.norm_cfg = norm_cfg - - self.conv1_stride = 1 - self.conv2_stride = stride - - self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) - self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) - self.norm3_name, norm3 = build_norm_layer( - norm_cfg, planes * self.expansion, postfix=3) - - self.conv1 = build_conv_layer( - conv_cfg, - inplanes, - planes, - kernel_size=1, - stride=self.conv1_stride, - bias=False) - self.add_module(self.norm1_name, norm1) - self.conv2 = build_conv_layer( - conv_cfg, - planes, - planes, - kernel_size=3, - stride=self.conv2_stride, - padding=dilation, - dilation=dilation, - bias=False) - self.add_module(self.norm2_name, norm2) - self.conv3 = build_conv_layer( - conv_cfg, - planes, - planes * self.expansion, - kernel_size=1, - bias=False) - self.add_module(self.norm3_name, norm3) - - self.relu = nn.ReLU(inplace=True) - - @property - def norm1(self): - return getattr(self, self.norm1_name) - - @property - def norm2(self): - return getattr(self, self.norm2_name) - - @property - def norm3(self): - return getattr(self, self.norm3_name) - - def forward(self, x, H, W): - B, _, C = x.shape - x = x.permute(0, 2, 1).reshape(B, -1, H, W) - identity = x - - out = self.conv1(x) - out = self.norm1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.norm2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.norm3(out) - - out += identity - out = self.relu(out) - out = out.flatten(2).transpose(1, 2) - return out - - -class Attention(nn.Module): - - def __init__(self, - dim, - num_heads=8, - qkv_bias=False, - qk_scale=None, - attn_drop=0., - proj_drop=0., - window_size=None, - attn_head_dim=None): - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - if attn_head_dim is not None: - head_dim = attn_head_dim - all_head_dim = head_dim * self.num_heads - # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights - self.scale = qk_scale or head_dim**-0.5 - - self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias) - self.window_size = window_size - q_size = window_size[0] - kv_size = q_size - rel_sp_dim = 2 * q_size - 1 - self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim)) - self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim)) - - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(all_head_dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x, H, W, rel_pos_bias=None): - B, N, C = x.shape - # qkv_bias = None - # if self.q_bias is not None: - # qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) - # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - qkv = self.qkv(x) - qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[ - 2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - attn = calc_rel_pos_spatial(attn, q, self.window_size, - self.window_size, self.rel_pos_h, - self.rel_pos_w) - # if self.relative_position_bias_table is not None: - # relative_position_bias = \ - # self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - # self.window_size[0] * self.window_size[1] + 1, - # 
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH - # relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - # attn = attn + relative_position_bias.unsqueeze(0) - - # if rel_pos_bias is not None: - # attn = attn + rel_pos_bias - - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, -1) - x = self.proj(x) - x = self.proj_drop(x) - return x def window_partition(x, window_size): """ + Partition into non-overlapping windows with padding if needed. Args: - x: (B, H, W, C) - window_size (int): window size + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. Returns: - windows: (num_windows*B, window_size, window_size, C) + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition """ B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, - C) + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, + window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows + return windows, (Hp, Wp) -def window_reverse(windows, window_size, H, W): +def window_unpartition(windows, window_size, pad_hw, hw): """ + Window unpartition into original sequences and removing padding. Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. Returns: - x: (B, H, W, C) + x: unpartitioned sequences with [B, H, W, C]. """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() return x -def calc_rel_pos_spatial( - attn, - q, - q_shape, - k_shape, - rel_pos_h, - rel_pos_w, -): +def get_rel_pos(q_size, k_size, rel_pos): """ - Spatial Relative Positional Embeddings. + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + Returns: + Extracted positional embeddings according to relative positions. """ - sp_idx = 0 - q_h, q_w = q_shape - k_h, k_w = k_shape + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. 
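# The new window_partition/window_unpartition pair pads the channels-last (B, H, W, C)
# token map so H and W become multiples of window_size, and the inverse crops the
# padding away again. Round-trip sketch, assuming both helpers remain importable from
# easycv.models.backbones.vitdet as laid out in this patch:
import torch
from easycv.models.backbones.vitdet import window_partition, window_unpartition

x = torch.randn(1, 25, 25, 8)                 # 25 is not a multiple of 14
windows, pad_hw = window_partition(x, 14)
print(windows.shape, pad_hw)                  # torch.Size([4, 14, 14, 8]) (28, 28)
y = window_unpartition(windows, 14, pad_hw, (25, 25))
assert y.shape == x.shape and torch.equal(x, y)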
+ rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode='linear', + ) + rel_pos_resized = rel_pos_resized.reshape(-1, + max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos - # Scale up rel pos if shapes for q and k are different. - q_h_ratio = max(k_h / q_h, 1.0) - k_h_ratio = max(q_h / k_h, 1.0) - dist_h = ( - torch.arange(q_h)[:, None] * q_h_ratio - - torch.arange(k_h)[None, :] * k_h_ratio) - dist_h += (k_h - 1) * k_h_ratio - q_w_ratio = max(k_w / q_w, 1.0) - k_w_ratio = max(q_w / k_w, 1.0) - dist_w = ( - torch.arange(q_w)[:, None] * q_w_ratio - - torch.arange(k_w)[None, :] * k_w_ratio) - dist_w += (k_w - 1) * k_w_ratio + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - + k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) - Rh = rel_pos_h[dist_h.long()] - Rw = rel_pos_w[dist_w.long()] + return rel_pos_resized[relative_coords.long()] - B, n_head, q_N, dim = q.shape - r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim) - rel_h = torch.einsum('byhwc,hkc->byhwk', r_q, Rh) - rel_w = torch.einsum('byhwc,wkc->byhwk', r_q, Rw) +def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size): + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) - attn[:, :, sp_idx:, sp_idx:] = ( - attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w) + - rel_h[:, :, :, :, :, None] + rel_w[:, :, :, :, None, :]).view( - B, -1, q_h * q_w, k_h * k_w) + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh) + rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw) + + attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + + rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w) return attn -class WindowAttention(nn.Module): - """ Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. +def get_abs_pos(abs_pos, has_cls_token, hw): + """ + Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token + dimension for the original embeddings. Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. 
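# add_decomposed_rel_pos adds the MViTv2-style decomposed relative position term to
# the attention logits, using get_rel_pos to slice (and, if needed, interpolate) the
# per-axis embedding tables. Shape sketch with toy sizes, assuming both helpers are
# importable from easycv.models.backbones.vitdet:
import torch
from easycv.models.backbones.vitdet import add_decomposed_rel_pos, get_rel_pos

q_h = q_w = 4
head_dim, n_head = 8, 2
rel_pos_h = torch.randn(2 * q_h - 1, head_dim)    # (2 * size - 1, head_dim), as in Attention
rel_pos_w = torch.randn(2 * q_w - 1, head_dim)
print(get_rel_pos(q_h, q_h, rel_pos_h).shape)     # torch.Size([4, 4, 8])

q = torch.randn(n_head, q_h * q_w, head_dim)      # (B * num_heads, H*W, head_dim)
attn = torch.zeros(n_head, q_h * q_w, q_h * q_w)
attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, (q_h, q_w), (q_h, q_w))
print(attn.shape)                                 # torch.Size([2, 16, 16])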
Default: 0.0 + abs_pos (Tensor): absolute positional embeddings with (1, num_position, C). + has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token. + hw (Tuple): size of input image tokens. + Returns: + Absolute positional embeddings after processing with shape (1, H, W, C) + """ + h, w = hw + if has_cls_token: + abs_pos = abs_pos[:, 1:] + xy_num = abs_pos.shape[1] + size = int(math.sqrt(xy_num)) + assert size * size == xy_num + + if size != h or size != w: + new_abs_pos = F.interpolate( + abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2), + size=(h, w), + mode='bicubic', + align_corners=False, + ) + + return new_abs_pos.permute(0, 2, 3, 1) + else: + return abs_pos.reshape(1, h, w, -1) + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. """ def __init__(self, - dim, - window_size, - num_heads, - qkv_bias=True, - qk_scale=None, - attn_drop=0., - proj_drop=0., - attn_head_dim=None): - + kernel_size=(16, 16), + stride=(16, 16), + padding=(0, 0), + in_chans=3, + embed_dim=768): + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=kernel_size, + stride=stride, + padding=padding) + + def forward(self, x): + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + use_rel_pos=False, + rel_pos_zero_init=True, + input_size=None, + ): + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool: If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (int or None): Input resolution for calculating the relative positional + parameter size. + """ super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim**-0.5 - - q_size = window_size[0] - kv_size = window_size[1] - rel_sp_dim = 2 * q_size - 1 - self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim)) - self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim)) + self.scale = head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - # trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter( + torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter( + torch.zeros(2 * input_size[1] - 1, head_dim)) - def forward(self, x, H, W): - """ Forward function. 
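# For the 1024x1024 LSJ inputs used by these configs, the new PatchEmbed yields a
# channels-last 64x64 token map, and get_abs_pos drops the MAE checkpoint's cls token
# and bicubically resizes its 14x14 position grid to that resolution. Shape sketch,
# assuming both names are importable from easycv.models.backbones.vitdet:
import torch
from easycv.models.backbones.vitdet import PatchEmbed, get_abs_pos

patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), embed_dim=768)
tokens = patch_embed(torch.randn(1, 3, 1024, 1024))
print(tokens.shape)                                    # torch.Size([1, 64, 64, 768])

pretrain_pos_embed = torch.zeros(1, 14 * 14 + 1, 768)  # 224 / 16 = 14, plus cls token
pos = get_abs_pos(pretrain_pos_embed, True, (64, 64))
print(pos.shape)                                       # torch.Size([1, 64, 64, 768])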
- Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - x = x.reshape(B_, H, W, C) - pad_l = pad_t = 0 - pad_r = (self.window_size[1] - - W % self.window_size[1]) % self.window_size[1] - pad_b = (self.window_size[0] - - H % self.window_size[0]) % self.window_size[0] + if not rel_pos_zero_init: + trunc_normal_(self.rel_pos_h, std=0.02) + trunc_normal_(self.rel_pos_w, std=0.02) - x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) - _, Hp, Wp, _ = x.shape + def forward(self, x): + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, + -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) - x = window_partition( - x, self.window_size[0]) # nW*B, window_size, window_size, C - x = x.view(-1, self.window_size[1] * self.window_size[0], - C) # nW*B, window_size*window_size, C - B_w = x.shape[0] - N_w = x.shape[1] - qkv = self.qkv(x).reshape(B_w, N_w, 3, self.num_heads, - C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[ - 2] # make torchscript happy (cannot use tensor as tuple) + attn = (q * self.scale) @ k.transpose(-2, -1) - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, + self.rel_pos_w, (H, W), (H, W)) - attn = calc_rel_pos_spatial(attn, q, self.window_size, - self.window_size, self.rel_pos_h, - self.rel_pos_w) - - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_w, N_w, C) + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, + -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) x = self.proj(x) - x = self.proj_drop(x) - - x = x.view(-1, self.window_size[1], self.window_size[0], C) - x = window_reverse(x, self.window_size[0], Hp, Wp) # B H' W' C - - if pad_r > 0 or pad_b > 0: - x = x[:, :H, :W, :].contiguous() - - x = x.view(B_, H * W, C) return x class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" - def __init__(self, - dim, - num_heads, - mlp_ratio=4., - qkv_bias=False, - qk_scale=None, - drop=0., - attn_drop=0., - drop_path=0., - init_values=None, - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - window_size=None, - attn_head_dim=None, - window=False, - aggregation='attn'): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=True, + drop_path=0.0, + norm_layer=nn.LayerNorm, + act_layer=nn.GELU, + use_rel_pos=False, + rel_pos_zero_init=True, + window_size=0, + use_residual_block=False, + input_size=None, + ): + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + drop_path (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then not + use window attention. + use_residual_block (bool): If True, use a residual block after the MLP block. 
+ input_size (int or None): Input resolution for calculating the relative positional + parameter size. + """ super().__init__() self.norm1 = norm_layer(dim) - self.aggregation = aggregation - self.window = window - if not window: - if aggregation == 'attn': - self.attn = Attention( - dim, - num_heads=num_heads, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - attn_drop=attn_drop, - proj_drop=drop, - window_size=window_size, - attn_head_dim=attn_head_dim) - else: - self.attn = WindowAttention( - dim, - num_heads=num_heads, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - attn_drop=attn_drop, - proj_drop=drop, - window_size=window_size, - attn_head_dim=attn_head_dim) - if aggregation == 'basicblock': - self.conv_aggregation = BasicBlock( - inplanes=dim, planes=dim) - elif aggregation == 'bottleneck': - self.conv_aggregation = Bottleneck( - inplanes=dim, planes=dim // 4) - else: - self.attn = WindowAttention( - dim, - num_heads=num_heads, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - attn_drop=attn_drop, - proj_drop=drop, - window_size=window_size, - attn_head_dim=attn_head_dim) - # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else + (window_size, window_size), + ) + self.drop_path = DropPath( - drop_path) if drop_path > 0. else nn.Identity() + drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp( in_features=dim, - hidden_features=mlp_hidden_dim, - act_layer=act_layer, - drop=drop) + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer) - if init_values is not None: - self.gamma_1 = nn.Parameter( - init_values * torch.ones((dim)), requires_grad=True) - self.gamma_2 = nn.Parameter( - init_values * torch.ones((dim)), requires_grad=True) - else: - self.gamma_1, self.gamma_2 = None, None + self.window_size = window_size - def forward(self, x, H, W): - if self.gamma_1 is None: - x = x + self.drop_path(self.attn(self.norm1(x), H, W)) - x = x + self.drop_path(self.mlp(self.norm2(x))) - else: - x = x + self.drop_path( - self.gamma_1 * self.attn(self.norm1(x), H, W)) - x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) - if not self.window and self.aggregation != 'attn': - x = self.conv_aggregation(x, H, W) - return x - - -class PatchEmbed(nn.Module): - """ Image to Patch Embedding - """ - - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - num_patches = (img_size[1] // patch_size[1]) * ( - img_size[0] // patch_size[0]) - self.patch_shape = (img_size[0] // patch_size[0], - img_size[1] // patch_size[1]) - self.img_size = img_size - self.patch_size = patch_size - self.num_patches = num_patches - - self.proj = nn.Conv2d( - in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - - def forward(self, x, **kwargs): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - # assert H == self.img_size[0] and W == self.img_size[1], \ - # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
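# Each refactored Block works on channels-last (B, H, W, C) maps: blocks with
# window_size > 0 partition into windows around the attention call, while
# window_size=0 blocks attend globally over the whole grid. Minimal sketch, assuming
# Block is importable from easycv.models.backbones.vitdet as defined in this patch:
import torch
from easycv.models.backbones.vitdet import Block

x = torch.randn(1, 32, 32, 768)
windowed = Block(dim=768, num_heads=12, use_rel_pos=True, window_size=14,
                 input_size=(32, 32))
global_blk = Block(dim=768, num_heads=12, use_rel_pos=True, window_size=0,
                   input_size=(32, 32))
with torch.no_grad():
    print(windowed(x).shape, global_blk(x).shape)  # both torch.Size([1, 32, 32, 768])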
- x = self.proj(x) - Hp, Wp = x.shape[2], x.shape[3] - - x = x.flatten(2).transpose(1, 2) - return x, (Hp, Wp) - - -class HybridEmbed(nn.Module): - """ CNN Feature Map Embedding - Extract feature map from CNN, flatten, project to embedding dim. - """ - - def __init__(self, - backbone, - img_size=224, - feature_size=None, - in_chans=3, - embed_dim=768): - super().__init__() - assert isinstance(backbone, nn.Module) - img_size = to_2tuple(img_size) - self.img_size = img_size - self.backbone = backbone - if feature_size is None: - with torch.no_grad(): - # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature - # map for all networks, the feature metadata has reliable channel and stride info, but using - # stride to calc feature dim requires info about padding of each stage that isn't captured. - training = backbone.training - if training: - backbone.eval() - o = self.backbone( - torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] - feature_size = o.shape[-2:] - feature_dim = o.shape[1] - backbone.train(training) - else: - feature_size = to_2tuple(feature_size) - feature_dim = self.backbone.feature_info.channels()[-1] - self.num_patches = feature_size[0] * feature_size[1] - self.proj = nn.Linear(feature_dim, embed_dim) + self.use_residual_block = use_residual_block def forward(self, x): - x = self.backbone(x)[-1] - x = x.flatten(2).transpose(1, 2) - x = self.proj(x) + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + if self.use_residual_block: + x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + return x -class Norm2d(nn.Module): - - def __init__(self, embed_dim): - super().__init__() - self.ln = nn.LayerNorm(embed_dim, eps=1e-6) - - def forward(self, x): - x = x.permute(0, 2, 3, 1) - x = self.ln(x) - x = x.permute(0, 3, 1, 2).contiguous() - return x - - -# todo: refactor vitdet and vit_transformer_dynamic @BACKBONES.register_module() class ViTDet(nn.Module): - """ Vision Transformer with support for patch or hybrid CNN input stage + """ + This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`. 
+ "Exploring Plain Vision Transformer Backbones for Object Detection", + https://arxiv.org/abs/2203.16527 """ - def __init__(self, - img_size=224, - patch_size=16, - in_chans=3, - num_classes=80, - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4., - qkv_bias=False, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0., - hybrid_backbone=None, - norm_layer=None, - init_values=None, - use_checkpoint=False, - use_abs_pos_emb=False, - use_rel_pos_bias=False, - use_shared_rel_pos_bias=False, - out_indices=[11], - interval=3, - pretrained=None, - aggregation='attn'): + def __init__( + self, + img_size=1024, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + drop_path_rate=0.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + use_abs_pos=True, + use_rel_pos=False, + rel_pos_zero_init=True, + window_size=0, + window_block_indexes=(), + residual_block_indexes=(), + use_act_checkpoint=False, + pretrain_img_size=224, + pretrain_use_cls_token=True, + pretrained=None, + ): + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + drop_path_rate (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + window_block_indexes (list): Indexes for blocks using window attention. + residual_block_indexes (list): Indexes for blocks using conv propagation. + use_act_checkpoint (bool): If True, use activation checkpointing. + pretrain_img_size (int): input image size for pretraining models. + pretrain_use_cls_token (bool): If True, pretrainig models use class token. + """ super().__init__() - norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) - self.num_classes = num_classes - self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.pretrain_use_cls_token = pretrain_use_cls_token + self.use_act_checkpoint = use_act_checkpoint - if hybrid_backbone is not None: - self.patch_embed = HybridEmbed( - hybrid_backbone, - img_size=img_size, - in_chans=in_chans, - embed_dim=embed_dim) - else: - self.patch_embed = PatchEmbed( - img_size=img_size, - patch_size=patch_size, - in_chans=in_chans, - embed_dim=embed_dim) + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) - num_patches = self.patch_embed.num_patches - - self.out_indices = out_indices - - if use_abs_pos_emb: + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. 
+ num_patches = (pretrain_img_size // patch_size) * ( + pretrain_img_size // patch_size) + num_positions = (num_patches + + 1) if pretrain_use_cls_token else num_patches self.pos_embed = nn.Parameter( - torch.zeros(1, num_patches, embed_dim)) + torch.zeros(1, num_positions, embed_dim)) else: self.pos_embed = None - self.pos_drop = nn.Dropout(p=drop_rate) + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) - ] # stochastic depth decay rule - self.use_rel_pos_bias = use_rel_pos_bias - self.use_checkpoint = use_checkpoint - self.blocks = nn.ModuleList([ - Block( + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, - init_values=init_values, - window_size=(14, 14) if - ((i + 1) % interval != 0 - or aggregation != 'attn') else self.patch_embed.patch_shape, - window=((i + 1) % interval != 0), - aggregation=aggregation) for i in range(depth) - ]) + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i in window_block_indexes else 0, + use_residual_block=i in residual_block_indexes, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) if self.pos_embed is not None: - trunc_normal_(self.pos_embed, std=.02) - - self.norm = norm_layer(embed_dim) + trunc_normal_(self.pos_embed, std=0.02) + self.apply(self._init_weights) self.pretrained = pretrained - self._register_load_state_dict_pre_hook(self._prepare_checkpoint_hook) - def fix_init_weight(self): - - def rescale(param, layer_id): - param.div_(math.sqrt(2.0 * layer_id)) - - for layer_id, layer in enumerate(self.blocks): - rescale(layer.attn.proj.weight.data, layer_id + 1) - rescale(layer.mlp.fc2.weight.data, layer_id + 1) - - def init_weights(self, pretrained=None): - """Initialize the weights in backbone. - Args: - pretrained (str, optional): Path to pre-trained weights. - Defaults to None. 
- """ - self.fix_init_weight() - pretrained = pretrained or self.pretrained - - def _init_weights(m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) - if isinstance(m, nn.Conv2d): - kaiming_init(m, mode='fan_in', nonlinearity='relu') - elif isinstance(m, (_BatchNorm, nn.GroupNorm)): - constant_init(m, 1) - - if isinstance(m, Bottleneck): - constant_init(m.norm3, 0) - elif isinstance(m, BasicBlock): - constant_init(m.norm2, 0) - - if isinstance(pretrained, str): - self.apply(_init_weights) + def init_weights(self): + if isinstance(self.pretrained, str): logger = get_root_logger() - load_checkpoint(self, pretrained, strict=False, logger=logger) - elif pretrained is None: - self.apply(_init_weights) - else: - raise TypeError('pretrained must be a str or None') - - def _prepare_checkpoint_hook(self, state_dict, prefix, *args, **kwargs): - rank, _ = get_dist_info() - if 'pos_embed' in state_dict: - pos_embed_checkpoint = state_dict['pos_embed'] - embedding_size = pos_embed_checkpoint.shape[-1] - H, W = self.patch_embed.patch_shape - num_patches = self.patch_embed.num_patches - num_extra_tokens = 1 - # height (== width) for the checkpoint position embedding - orig_size = int( - (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5) - # height (== width) for the new position embedding - new_size = int(num_patches**0.5) - # class_token and dist_token are kept unchanged - if orig_size != new_size: - if rank == 0: - print('Position interpolate from %dx%d to %dx%d' % - (orig_size, orig_size, H, W)) - # extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] - # only the position tokens are interpolated - pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] - pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, - embedding_size).permute( - 0, 3, 1, 2) - pos_tokens = torch.nn.functional.interpolate( - pos_tokens, - size=(H, W), - mode='bicubic', - align_corners=False) - new_pos_embed = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) - # new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) - state_dict['pos_embed'] = new_pos_embed - - def get_num_layers(self): - return len(self.blocks) - - @torch.jit.ignore - def no_weight_decay(self): - return {'pos_embed', 'cls_token'} - - def forward_features(self, x): - B, C, H, W = x.shape - x, (Hp, Wp) = self.patch_embed(x) - batch_size, seq_len, _ = x.size() - - if self.pos_embed is not None: - x = x + self.pos_embed - x = self.pos_drop(x) - - outs = [] - for i, blk in enumerate(self.blocks): - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x) - else: - x = blk(x, Hp, Wp) - - x = self.norm(x) - xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp) - - outs.append(xp) - - return tuple(outs) + load_checkpoint(self, self.pretrained, strict=False, logger=logger) def forward(self, x): - x = self.forward_features(x) - return x + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + get_abs_pos(self.pos_embed, self.pretrain_use_cls_token, + (x.shape[1], x.shape[2])) + + for blk in self.blocks: + if self.use_act_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + + 
outputs = [x.permute(0, 3, 1, 2)] + return outputs diff --git a/easycv/models/detection/necks/fpn.py b/easycv/models/detection/necks/fpn.py index 6d14bbef..8018903c 100644 --- a/easycv/models/detection/necks/fpn.py +++ b/easycv/models/detection/necks/fpn.py @@ -37,7 +37,6 @@ class FPN(nn.Module): Default: None. upsample_cfg (dict): Config dict for interpolate layer. Default: dict(mode='nearest'). - init_cfg (dict or list[dict], optional): Initialization config dict. Example: >>> import torch >>> in_channels = [2, 3, 5, 7] @@ -67,8 +66,6 @@ class FPN(nn.Module): norm_cfg=None, act_cfg=None, upsample_cfg=dict(mode='nearest')): - # init_cfg=dict( - # type='Xavier', layer='Conv2d', distribution='uniform')): super(FPN, self).__init__() assert isinstance(in_channels, list) self.in_channels = in_channels diff --git a/easycv/models/detection/necks/sfp.py b/easycv/models/detection/necks/sfp.py index be1273b0..b588f643 100644 --- a/easycv/models/detection/necks/sfp.py +++ b/easycv/models/detection/necks/sfp.py @@ -2,26 +2,12 @@ import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule -from mmcv.runner import BaseModule from easycv.models.builder import NECKS -class Norm2d(nn.Module): - - def __init__(self, embed_dim): - super().__init__() - self.ln = nn.LayerNorm(embed_dim, eps=1e-6) - - def forward(self, x): - x = x.permute(0, 2, 3, 1) - x = self.ln(x) - x = x.permute(0, 3, 1, 2).contiguous() - return x - - @NECKS.register_module() -class SFP(BaseModule): +class SFP(nn.Module): r"""Simple Feature Pyramid. This is an implementation of paper `Exploring Plain Vision Transformer Backbones for Object Detection `_. Args: @@ -32,25 +18,12 @@ class SFP(BaseModule): build the feature pyramid. Default: 0. end_level (int): Index of the end input backbone level (exclusive) to build the feature pyramid. Default: -1, which means the last level. - add_extra_convs (bool | str): If bool, it decides whether to add conv - layers on top of the original feature maps. Default to False. - If True, it is equivalent to `add_extra_convs='on_input'`. - If str, it specifies the source feature map of the extra convs. - Only the following options are allowed - - 'on_input': Last feat map of neck inputs (i.e. backbone feature). - - 'on_lateral': Last feature map after lateral convs. - - 'on_output': The last output feature map after fpn convs. - relu_before_extra_convs (bool): Whether to apply relu before the extra conv. Default: False. - no_norm_on_lateral (bool): Whether to apply norm on lateral. Default: False. conv_cfg (dict): Config dict for convolution layer. Default: None. norm_cfg (dict): Config dict for normalization layer. Default: None. act_cfg (str): Config dict for activation layer in ConvModule. Default: None. - upsample_cfg (dict): Config dict for interpolate layer. - Default: `dict(mode='nearest')` - init_cfg (dict or list[dict], optional): Initialization config dict. 
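# Unlike the old multi-output forward, the refactored backbone returns a single
# stride-16 feature map in (B, C, H, W) order and leaves pyramid construction to the
# SFP neck. Rough smoke test with random weights (no pretrained checkpoint loaded),
# assuming ViTDet is importable from easycv.models.backbones.vitdet; a 512 input is
# used only to keep it light -- the configs' 1024 LSJ input would give a 64x64 map:
import torch
from easycv.models.backbones.vitdet import ViTDet

backbone = ViTDet(
    img_size=1024, patch_size=16, embed_dim=768, depth=12, num_heads=12,
    drop_path_rate=0.1, window_size=14, mlp_ratio=4, qkv_bias=True,
    window_block_indexes=[0, 1, 3, 4, 6, 7, 9, 10],  # 2, 5, 8, 11 stay global
    residual_block_indexes=[], use_rel_pos=True)
backbone.init_weights()
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 512, 512))
print(len(feats), feats[0].shape)                    # 1 torch.Size([1, 768, 32, 32])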
Example: >>> import torch >>> in_channels = [2, 3, 5, 7] @@ -70,158 +43,83 @@ class SFP(BaseModule): def __init__(self, in_channels, out_channels, + scale_factors, num_outs, - start_level=0, - end_level=-1, - add_extra_convs=False, - relu_before_extra_convs=False, - no_norm_on_lateral=False, conv_cfg=None, norm_cfg=None, - act_cfg=None, - upsample_cfg=dict(mode='nearest'), - init_cfg=[ - dict( - type='Xavier', - layer=['Conv2d'], - distribution='uniform'), - dict(type='Constant', layer=['LayerNorm'], val=1, bias=0) - ]): - super(SFP, self).__init__(init_cfg) - assert isinstance(in_channels, list) - self.in_channels = in_channels + act_cfg=None): + super(SFP, self).__init__() + dim = in_channels self.out_channels = out_channels - self.num_ins = len(in_channels) + self.scale_factors = scale_factors + self.num_ins = len(scale_factors) self.num_outs = num_outs - self.relu_before_extra_convs = relu_before_extra_convs - self.no_norm_on_lateral = no_norm_on_lateral - self.upsample_cfg = upsample_cfg.copy() - if end_level == -1: - self.backbone_end_level = self.num_ins - assert num_outs >= self.num_ins - start_level - else: - # if end_level < inputs, no extra level is allowed - self.backbone_end_level = end_level - assert end_level <= len(in_channels) - assert num_outs == end_level - start_level - self.start_level = start_level - self.end_level = end_level - self.add_extra_convs = add_extra_convs - assert isinstance(add_extra_convs, (str, bool)) - if isinstance(add_extra_convs, str): - # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' - assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') - elif add_extra_convs: # True - self.add_extra_convs = 'on_input' - - self.top_downs = nn.ModuleList() - self.lateral_convs = nn.ModuleList() - self.fpn_convs = nn.ModuleList() - - for i in range(self.start_level, self.backbone_end_level): - if i == 0: - top_down = nn.Sequential( + self.stages = [] + for idx, scale in enumerate(scale_factors): + out_dim = dim + if scale == 4.0: + layers = [ + nn.ConvTranspose2d(dim, dim // 2, 2, stride=2, padding=0), + nn.GroupNorm(1, dim // 2, eps=1e-6), + nn.GELU(), nn.ConvTranspose2d( - in_channels[i], in_channels[i], 2, stride=2, - padding=0), Norm2d(in_channels[i]), nn.GELU(), - nn.ConvTranspose2d( - in_channels[i], in_channels[i], 2, stride=2, - padding=0)) - elif i == 1: - top_down = nn.ConvTranspose2d( - in_channels[i], in_channels[i], 2, stride=2, padding=0) - elif i == 2: - top_down = nn.Identity() - elif i == 3: - top_down = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + dim // 2, dim // 4, 2, stride=2, padding=0) + ] + out_dim = dim // 4 + elif scale == 2.0: + layers = [ + nn.ConvTranspose2d(dim, dim // 2, 2, stride=2, padding=0) + ] + out_dim = dim // 2 + elif scale == 1.0: + layers = [] + elif scale == 0.5: + layers = [nn.MaxPool2d(kernel_size=2, stride=2, padding=0)] + else: + raise NotImplementedError( + f'scale_factor={scale} is not supported yet.') - l_conv = ConvModule( - in_channels[i], - out_channels, - 1, - conv_cfg=conv_cfg, - norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, - act_cfg=act_cfg, - inplace=False) - fpn_conv = ConvModule( - out_channels, - out_channels, - 3, - padding=1, - conv_cfg=conv_cfg, - norm_cfg=norm_cfg, - act_cfg=act_cfg, - inplace=False) - - self.top_downs.append(top_down) - self.lateral_convs.append(l_conv) - self.fpn_convs.append(fpn_conv) - - # add extra conv layers (e.g., RetinaNet) - extra_levels = num_outs - self.backbone_end_level + self.start_level - if self.add_extra_convs and 
extra_levels >= 1: - for i in range(extra_levels): - if i == 0 and self.add_extra_convs == 'on_input': - in_channels = self.in_channels[self.backbone_end_level - 1] - else: - in_channels = out_channels - extra_fpn_conv = ConvModule( - in_channels, + layers.extend([ + ConvModule( + out_dim, + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False), + ConvModule( + out_channels, out_channels, 3, - stride=2, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg, inplace=False) - self.fpn_convs.append(extra_fpn_conv) + ]) + + layers = nn.Sequential(*layers) + self.add_module(f'sfp_{idx}', layers) + self.stages.append(layers) + + def init_weights(self): + pass def forward(self, inputs): """Forward function.""" - assert len(inputs) == 1 + features = inputs[0] + outs = [] - # build top-down path - features = [ - top_down(inputs[0]) for _, top_down in enumerate(self.top_downs) - ] - assert len(features) == len(self.in_channels) + # part 1: build simple feature pyramid + for stage in self.stages: + outs.append(stage(features)) - # build laterals - laterals = [ - lateral_conv(features[i + self.start_level]) - for i, lateral_conv in enumerate(self.lateral_convs) - ] - - used_backbone_levels = len(laterals) - - # build outputs - # part 1: from original levels - outs = [ - self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) - ] # part 2: add extra levels - if self.num_outs > len(outs): + if self.num_outs > self.num_ins: # use max pool to get more levels on top of outputs # (e.g., Faster R-CNN, Mask R-CNN) - if not self.add_extra_convs: - for i in range(self.num_outs - used_backbone_levels): - outs.append(F.max_pool2d(outs[-1], 1, stride=2)) - # add conv layers on top of original feature maps (RetinaNet) - else: - if self.add_extra_convs == 'on_input': - extra_source = inputs[self.backbone_end_level - 1] - elif self.add_extra_convs == 'on_lateral': - extra_source = laterals[-1] - elif self.add_extra_convs == 'on_output': - extra_source = outs[-1] - else: - raise NotImplementedError - outs.append(self.fpn_convs[used_backbone_levels](extra_source)) - for i in range(used_backbone_levels + 1, self.num_outs): - if self.relu_before_extra_convs: - outs.append(self.fpn_convs[i](F.relu(outs[-1]))) - else: - outs.append(self.fpn_convs[i](outs[-1])) + for i in range(self.num_outs - self.num_ins): + outs.append(F.max_pool2d(outs[-1], 1, stride=2)) return tuple(outs) diff --git a/easycv/predictors/detector.py b/easycv/predictors/detector.py index f9d05992..017d671e 100644 --- a/easycv/predictors/detector.py +++ b/easycv/predictors/detector.py @@ -253,11 +253,11 @@ class DetrPredictor(PredictorInterface): img, bboxes, labels=labels, - colors='green', - text_color='white', - font_size=20, - thickness=1, - font_scale=0.5, + colors='cyan', + text_color='cyan', + font_size=18, + thickness=2, + font_scale=0.0, show=show, out_file=out_file) diff --git a/tests/models/backbones/test_vitdet.py b/tests/models/backbones/test_vitdet.py index 3f0350a2..82012aed 100644 --- a/tests/models/backbones/test_vitdet.py +++ b/tests/models/backbones/test_vitdet.py @@ -14,18 +14,27 @@ class ViTDetTest(unittest.TestCase): def test_vitdet(self): model = ViTDet( img_size=1024, + patch_size=16, embed_dim=768, depth=12, num_heads=12, + drop_path_rate=0.1, + window_size=14, mlp_ratio=4, qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.1, - use_abs_pos_emb=True, - aggregation='attn', - ) + window_block_indexes=[ + # 2, 5, 8 11 for global 
attention + 0, + 1, + 3, + 4, + 6, + 7, + 9, + 10, + ], + residual_block_indexes=[], + use_rel_pos=True) model.init_weights() model.train() diff --git a/tests/predictors/test_detector.py b/tests/predictors/test_detector.py index 9187d3a7..c3be2ed6 100644 --- a/tests/predictors/test_detector.py +++ b/tests/predictors/test_detector.py @@ -155,7 +155,7 @@ class DetectorTest(unittest.TestCase): decimal=1) def test_vitdet_detector(self): - model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/vitdet_maskrcnn_export.pth' + model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/epoch_100_export.pth' img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg' out_file = './result.jpg' vitdet = DetrPredictor(model_path) @@ -167,63 +167,170 @@ class DetectorTest(unittest.TestCase): self.assertIn('detection_classes', output) self.assertIn('detection_masks', output) self.assertIn('img_metas', output) - self.assertEqual(len(output['detection_boxes'][0]), 30) - self.assertEqual(len(output['detection_scores'][0]), 30) - self.assertEqual(len(output['detection_classes'][0]), 30) + self.assertEqual(len(output['detection_boxes'][0]), 33) + self.assertEqual(len(output['detection_scores'][0]), 33) + self.assertEqual(len(output['detection_classes'][0]), 33) self.assertListEqual( output['detection_classes'][0].tolist(), np.array([ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 7, 7, 13, 13, 13, 56 + 2, 2, 2, 2, 2, 2, 7, 7, 13, 13, 13, 56 ], dtype=np.int32).tolist()) assert_array_almost_equal( output['detection_scores'][0], np.array([ - 0.99791867, 0.99665856, 0.99480623, 0.99060905, 0.9882515, - 0.98319584, 0.9738879, 0.97290784, 0.9514897, 0.95104814, - 0.9321701, 0.86165, 0.8228847, 0.7623552, 0.76129806, - 0.6050861, 0.44348577, 0.3452973, 0.2895671, 0.22109479, - 0.21265312, 0.17855245, 0.1205352, 0.08981906, 0.10596471, - 0.05854294, 0.99749386, 0.9472857, 0.5945908, 0.09855112 + 0.9975854158401489, 0.9965696334838867, 0.9922919869422913, + 0.9833580851554871, 0.983080267906189, 0.970454752445221, + 0.9701289534568787, 0.9649872183799744, 0.9642795324325562, + 0.9642238020896912, 0.9529680609703064, 0.9403366446495056, + 0.9391788244247437, 0.8941807150840759, 0.8178097009658813, + 0.8013413548469543, 0.6677654385566711, 0.3952914774417877, + 0.33463895320892334, 0.32501447200775146, 0.27323535084724426, + 0.20197080075740814, 0.15607696771621704, 0.1068163588643074, + 0.10183875262737274, 0.09735643863677979, 0.06559795141220093, + 0.08890066295862198, 0.076363705098629, 0.9954648613929749, + 0.9212945699691772, 0.5224372148513794, 0.20555885136127472 ], dtype=np.float32), decimal=2) assert_array_almost_equal( output['detection_boxes'][0], - np.array([[294.7058, 117.29371, 378.83713, 149.99928], - [609.05444, 112.526474, 633.2971, 136.35175], - [481.4165, 110.987335, 522.5531, 130.01529], - [167.68184, 109.89049, 215.49057, 139.86987], - [374.75082, 110.68697, 433.10028, 136.23654], - [189.54971, 110.09322, 297.6167, 155.77412], - [266.5185, 105.37718, 326.54385, 127.916374], - [556.30225, 110.43166, 592.8248, 128.03764], - [432.49252, 105.086464, 484.0512, 132.272], - [0., 110.566444, 62.01249, 146.44017], - [591.74664, 110.43527, 619.73816, 126.68549], - [99.126854, 90.947975, 118.46699, 101.11096], - [59.895264, 94.110054, 85.60521, 106.67633], - [142.95819, 96.61966, 165.96964, 104.95929], - [83.062515, 89.802605, 99.1546, 
98.69074], - [226.28802, 98.32568, 249.06772, 108.86408], - [136.67789, 94.75706, 154.62924, 104.289536], - [170.42459, 98.458694, 183.16309, 106.203156], - [67.56731, 89.68286, 82.62955, 98.35645], - [222.80092, 97.828445, 239.02655, 108.29377], - [134.34427, 92.31653, 149.19615, 102.97457], - [613.5186, 102.27066, 636.0434, 112.813644], - [607.4787, 110.87984, 630.1123, 127.65646], - [135.13664, 90.989876, 155.67192, 100.18036], - [431.61505, 105.43844, 484.36508, 132.50078], - [189.92722, 110.38832, 297.74353, 155.95557], - [220.67035, 177.13489, 455.32092, 380.45712], - [372.76584, 134.33807, 432.44357, 188.51534], - [50.403812, 110.543495, 70.4368, 119.65186], - [373.50272, 134.27258, 432.18475, 187.81824]]), + np.array([[ + 294.22674560546875, 116.6078109741211, 379.4328918457031, + 150.14097595214844 + ], + [ + 482.6017761230469, 110.75955963134766, + 522.8798828125, 129.71286010742188 + ], + [ + 167.06460571289062, 109.95974731445312, + 212.83975219726562, 140.16102600097656 + ], + [ + 609.2930908203125, 113.13909149169922, + 637.3115844726562, 136.4690704345703 + ], + [ + 191.185791015625, 111.1408920288086, 301.31689453125, + 155.7731170654297 + ], + [ + 431.2244873046875, 106.19962310791016, + 483.860595703125, 132.21627807617188 + ], + [ + 267.48358154296875, 105.5920639038086, + 325.2832336425781, 127.11176300048828 + ], + [ + 591.2138671875, 110.29329681396484, + 619.8524169921875, 126.1990966796875 + ], + [ + 0.0, 110.7026596069336, 61.487945556640625, + 146.33018493652344 + ], + [ + 555.9155883789062, 110.03486633300781, + 591.7050170898438, 127.06097412109375 + ], + [ + 60.24559783935547, 94.12760162353516, + 85.63741302490234, 106.66705322265625 + ], + [ + 99.02665710449219, 90.53657531738281, + 118.83953094482422, 101.18717956542969 + ], + [ + 396.30438232421875, 111.59194946289062, + 431.559814453125, 133.96914672851562 + ], + [ + 83.81543731689453, 89.65665435791016, + 99.9166259765625, 98.25627899169922 + ], + [ + 139.29647827148438, 96.68000793457031, + 165.22410583496094, 105.60000610351562 + ], + [ + 67.27152252197266, 89.42798614501953, + 83.25617980957031, 98.0460205078125 + ], + [ + 223.74176025390625, 98.68321990966797, + 250.42506408691406, 109.32588958740234 + ], + [ + 136.7582244873047, 96.51412963867188, + 152.51190185546875, 104.73160552978516 + ], + [ + 221.71812438964844, 97.86445617675781, + 238.9705810546875, 106.96803283691406 + ], + [ + 135.06964111328125, 91.80916595458984, 155.24609375, + 102.20686340332031 + ], + [ + 169.11180114746094, 97.53628540039062, + 182.88504028320312, 105.95404815673828 + ], + [ + 133.8811798095703, 91.00375366210938, + 145.35507202148438, 102.3780288696289 + ], + [ + 614.2507934570312, 102.19828796386719, + 636.5692749023438, 112.59198760986328 + ], + [ + 35.94759750366211, 91.7213363647461, + 70.38274383544922, 117.19855499267578 + ], + [ + 554.6401977539062, 115.18976593017578, + 562.0255737304688, 127.4429931640625 + ], + [ + 39.07550811767578, 92.73261260986328, + 85.36636352539062, 106.73953247070312 + ], + [ + 200.85513305664062, 93.00469970703125, + 219.73086547851562, 107.99642181396484 + ], + [ + 0.0, 111.18904876708984, 61.7393684387207, + 146.72547912597656 + ], + [ + 191.88568115234375, 111.09577178955078, + 299.4097900390625, 155.14639282226562 + ], + [ + 221.06834411621094, 176.6427001953125, + 458.3475341796875, 378.89300537109375 + ], + [ + 372.7131652832031, 135.51429748535156, + 433.2494201660156, 188.0106658935547 + ], + [ + 52.19819641113281, 110.3646011352539, + 70.95110321044922, 120.10567474365234 
+ ], + [ + 376.1671447753906, 133.6930694580078, + 432.2721862792969, 187.99481201171875 + ]]), decimal=1)
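
Note on the refactored backbone forward above: patch_embed now yields a (B, H, W, C) feature map and the absolute position embedding is resized to that grid by get_abs_pos before the transformer blocks run. The helper itself is outside this hunk, so the standalone sketch below is only an assumption about its behavior, inferred from the call site and the reference ViTDet code; the name resize_abs_pos is made up here and is not the EasyCV function.

    import math
    import torch
    import torch.nn.functional as F

    def resize_abs_pos(abs_pos, has_cls_token, hw):
        """Resize a pretrained (1, N, C) position embedding to an (H, W) grid."""
        h, w = hw
        if has_cls_token:
            abs_pos = abs_pos[:, 1:]                  # drop the pretraining cls token
        size = int(math.sqrt(abs_pos.shape[1]))       # pretraining grid is square
        assert size * size == abs_pos.shape[1]
        if (size, size) == (h, w):
            return abs_pos.reshape(1, h, w, -1)
        resized = F.interpolate(
            abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
            size=(h, w), mode='bicubic', align_corners=False)
        return resized.permute(0, 2, 3, 1)            # back to (1, H, W, C)

    pos = torch.randn(1, 1 + 14 * 14, 768)            # e.g. a 224/16 pretraining grid plus a cls token
    print(resize_abs_pos(pos, True, (64, 64)).shape)  # torch.Size([1, 64, 64, 768])

Functionally this is roughly the interpolation that the removed _prepare_checkpoint_hook performed once at checkpoint-load time, moved into the forward pass so the feature grid size no longer has to match the stored pos_embed.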
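
Note on the refactored SFP above: instead of lateral and top-down convs over multiple backbone levels, it synthesizes the whole pyramid from the single stride-16 ViTDet map according to scale_factors, then max-pools once more to reach num_outs levels. A minimal standalone sketch of that scale handling follows; the names make_sfp_stage and feat are illustrative only, and the per-stage 1x1 and 3x3 ConvModule projections are omitted for brevity.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def make_sfp_stage(dim, scale):
        """Up/down-sampling stem for one pyramid level, keyed by scale factor."""
        if scale == 4.0:      # stride 16 -> 4
            return nn.Sequential(
                nn.ConvTranspose2d(dim, dim // 2, 2, stride=2),
                nn.GroupNorm(1, dim // 2, eps=1e-6),
                nn.GELU(),
                nn.ConvTranspose2d(dim // 2, dim // 4, 2, stride=2))
        if scale == 2.0:      # stride 16 -> 8
            return nn.ConvTranspose2d(dim, dim // 2, 2, stride=2)
        if scale == 1.0:      # stride 16 -> 16
            return nn.Identity()
        if scale == 0.5:      # stride 16 -> 32
            return nn.MaxPool2d(kernel_size=2, stride=2)
        raise NotImplementedError(f'scale_factor={scale} is not supported yet.')

    feat = torch.randn(1, 768, 64, 64)                 # ViTDet output for a 1024x1024 input
    outs = [make_sfp_stage(768, s)(feat) for s in (4.0, 2.0, 1.0, 0.5)]
    outs.append(F.max_pool2d(outs[-1], 1, stride=2))   # extra level (stride 64) for the RPN
    print([o.shape[-1] for o in outs])                 # [256, 128, 64, 32, 16]

With the 1x1 + 3x3 ConvModules appended to each stage, as in the diff, every level is projected to the same out_channels before the tuple of feature maps is returned.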