From 9f01a37ad4df57b30430c41df08459025174e8fd Mon Sep 17 00:00:00 2001
From: tuofeilun <38110862+tuofeilunhifi@users.noreply.github.com>
Date: Fri, 16 Sep 2022 11:03:53 +0800
Subject: [PATCH] Refactor ViTDet backbone and simple feature pyramid (#177)

1. The ViTDet backbone reimplemented following detectron2 (d2) is about 20% faster than the ViTDet backbone originally reproduced in EasyCV.
2. COCO bbox mAP improves from 50.57 to 50.65 (mask mAP from 44.96 to 45.41).
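
As a quick reference, a minimal sketch of the backbone/neck config migration this patch makes (all keys and values below are copied from the diffs that follow; nothing here is new API):

    # old EasyCV-style ViTDet config (removed by this patch)
    backbone_old = dict(
        type='ViTDet', img_size=1024, embed_dim=768, depth=12, num_heads=12,
        mlp_ratio=4, qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
        drop_path_rate=0.1, use_abs_pos_emb=True, aggregation='attn')
    neck_old = dict(
        type='SFP', in_channels=[768, 768, 768, 768], out_channels=256,
        norm_cfg=dict(type='GN', num_groups=1, requires_grad=True), num_outs=5)

    # new d2-style ViTDet config (added by this patch): window attention in 8 of the
    # 12 blocks, global attention in blocks 2, 5, 8, 11, decomposed relative position
    # embeddings instead of the old aggregation blocks
    backbone_new = dict(
        type='ViTDet', img_size=1024, patch_size=16, embed_dim=768, depth=12,
        num_heads=12, drop_path_rate=0.1, window_size=14, mlp_ratio=4, qkv_bias=True,
        window_block_indexes=[0, 1, 3, 4, 6, 7, 9, 10], residual_block_indexes=[],
        use_rel_pos=True)
    neck_new = dict(
        type='SFP', in_channels=768, out_channels=256, scale_factors=(4.0, 2.0, 1.0, 0.5),
        norm_cfg=dict(type='GN', num_groups=1, eps=1e-6, requires_grad=True), num_outs=5)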
---
.../detection/vitdet/lsj_coco_detection.py | 6 +-
configs/detection/vitdet/lsj_coco_instance.py | 6 +-
.../vitdet/vitdet_basicblock_100e.py | 3 -
.../vitdet/vitdet_bottleneck_100e.py | 3 -
.../vitdet/vitdet_cascade_mask_rcnn.py | 231 ++++
.../vitdet/vitdet_cascade_mask_rcnn_100e.py | 4 +
.../detection/vitdet/vitdet_faster_rcnn.py | 31 +-
.../vitdet/vitdet_faster_rcnn_100e.py | 2 +-
configs/detection/vitdet/vitdet_mask_rcnn.py | 31 +-
...itdet_100e.py => vitdet_mask_rcnn_100e.py} | 0
.../detection/vitdet/vitdet_schedule_100e.py | 21 +-
docs/source/_static/result.jpg | 4 +-
docs/source/model_zoo_det.md | 2 +-
.../layer_decay_optimizer_constructor.py | 78 +-
easycv/models/backbones/vitdet.py | 1057 ++++++-----------
easycv/models/detection/necks/fpn.py | 3 -
easycv/models/detection/necks/sfp.py | 216 +---
easycv/predictors/detector.py | 10 +-
tests/models/backbones/test_vitdet.py | 23 +-
tests/predictors/test_detector.py | 189 ++-
20 files changed, 925 insertions(+), 995 deletions(-)
delete mode 100644 configs/detection/vitdet/vitdet_basicblock_100e.py
delete mode 100644 configs/detection/vitdet/vitdet_bottleneck_100e.py
create mode 100644 configs/detection/vitdet/vitdet_cascade_mask_rcnn.py
create mode 100644 configs/detection/vitdet/vitdet_cascade_mask_rcnn_100e.py
rename configs/detection/vitdet/{vitdet_100e.py => vitdet_mask_rcnn_100e.py} (100%)
diff --git a/configs/detection/vitdet/lsj_coco_detection.py b/configs/detection/vitdet/lsj_coco_detection.py
index f5da1064..fb243a23 100644
--- a/configs/detection/vitdet/lsj_coco_detection.py
+++ b/configs/detection/vitdet/lsj_coco_detection.py
@@ -101,13 +101,15 @@ val_dataset = dict(
pipeline=test_pipeline)
data = dict(
- imgs_per_gpu=1, workers_per_gpu=2, train=train_dataset, val=val_dataset)
+ imgs_per_gpu=4, workers_per_gpu=2, train=train_dataset, val=val_dataset
+) # total batch size 64 = 4 (imgs_per_gpu) x 8 (gpus per node) x 2 (nodes)
# evaluation
-eval_config = dict(interval=1, gpu_collect=False)
+eval_config = dict(initial=False, interval=1, gpu_collect=False)
eval_pipelines = [
dict(
mode='test',
+ # dist_eval=True,
evaluators=[
dict(type='CocoDetectionEvaluator', classes=CLASSES),
],
diff --git a/configs/detection/vitdet/lsj_coco_instance.py b/configs/detection/vitdet/lsj_coco_instance.py
index a42aa040..5271363f 100644
--- a/configs/detection/vitdet/lsj_coco_instance.py
+++ b/configs/detection/vitdet/lsj_coco_instance.py
@@ -101,13 +101,15 @@ val_dataset = dict(
pipeline=test_pipeline)
data = dict(
- imgs_per_gpu=1, workers_per_gpu=2, train=train_dataset, val=val_dataset)
+ imgs_per_gpu=4, workers_per_gpu=2, train=train_dataset, val=val_dataset
+) # total batch size 64 = 4 (imgs_per_gpu) x 8 (gpus per node) x 2 (nodes)
# evaluation
-eval_config = dict(interval=1, gpu_collect=False)
+eval_config = dict(initial=False, interval=1, gpu_collect=False)
eval_pipelines = [
dict(
mode='test',
+ # dist_eval=True,
evaluators=[
dict(type='CocoDetectionEvaluator', classes=CLASSES),
dict(type='CocoMaskEvaluator', classes=CLASSES)
diff --git a/configs/detection/vitdet/vitdet_basicblock_100e.py b/configs/detection/vitdet/vitdet_basicblock_100e.py
deleted file mode 100644
index a3ea54e7..00000000
--- a/configs/detection/vitdet/vitdet_basicblock_100e.py
+++ /dev/null
@@ -1,3 +0,0 @@
-_base_ = './vitdet_100e.py'
-
-model = dict(backbone=dict(aggregation='basicblock'))
diff --git a/configs/detection/vitdet/vitdet_bottleneck_100e.py b/configs/detection/vitdet/vitdet_bottleneck_100e.py
deleted file mode 100644
index a6031797..00000000
--- a/configs/detection/vitdet/vitdet_bottleneck_100e.py
+++ /dev/null
@@ -1,3 +0,0 @@
-_base_ = './vitdet_100e.py'
-
-model = dict(backbone=dict(aggregation='bottleneck'))
diff --git a/configs/detection/vitdet/vitdet_cascade_mask_rcnn.py b/configs/detection/vitdet/vitdet_cascade_mask_rcnn.py
new file mode 100644
index 00000000..dfe0d68d
--- /dev/null
+++ b/configs/detection/vitdet/vitdet_cascade_mask_rcnn.py
@@ -0,0 +1,231 @@
+# model settings
+
+norm_cfg = dict(type='GN', num_groups=1, eps=1e-6, requires_grad=True)
+
+pretrained = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-b-1600/warpper_mae_vit-base-p16-1600e.pth'
+model = dict(
+ type='CascadeRCNN',
+ pretrained=pretrained,
+ backbone=dict(
+ type='ViTDet',
+ img_size=1024,
+ patch_size=16,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ drop_path_rate=0.1,
+ window_size=14,
+ mlp_ratio=4,
+ qkv_bias=True,
+ window_block_indexes=[
+            # blocks 2, 5, 8, 11 use global attention
+ 0,
+ 1,
+ 3,
+ 4,
+ 6,
+ 7,
+ 9,
+ 10,
+ ],
+ residual_block_indexes=[],
+ use_rel_pos=True),
+ neck=dict(
+ type='SFP',
+ in_channels=768,
+ out_channels=256,
+ scale_factors=(4.0, 2.0, 1.0, 0.5),
+ norm_cfg=norm_cfg,
+ num_outs=5),
+ rpn_head=dict(
+ type='RPNHead',
+ in_channels=256,
+ feat_channels=256,
+ num_convs=2,
+ anchor_generator=dict(
+ type='AnchorGenerator',
+ scales=[8],
+ ratios=[0.5, 1.0, 2.0],
+ strides=[4, 8, 16, 32, 64]),
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ target_means=[.0, .0, .0, .0],
+ target_stds=[1.0, 1.0, 1.0, 1.0]),
+ loss_cls=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+ loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+ roi_head=dict(
+ type='CascadeRoIHead',
+ num_stages=3,
+ stage_loss_weights=[1, 0.5, 0.25],
+ bbox_roi_extractor=dict(
+ type='SingleRoIExtractor',
+ roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
+ out_channels=256,
+ featmap_strides=[4, 8, 16, 32]),
+ bbox_head=[
+ dict(
+ type='Shared4Conv1FCBBoxHead',
+ conv_out_channels=256,
+ norm_cfg=norm_cfg,
+ in_channels=256,
+ fc_out_channels=1024,
+ roi_feat_size=7,
+ num_classes=80,
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ target_means=[0., 0., 0., 0.],
+ target_stds=[0.1, 0.1, 0.2, 0.2]),
+ reg_class_agnostic=True,
+ loss_cls=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0),
+ loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
+ loss_weight=1.0)),
+ dict(
+ type='Shared4Conv1FCBBoxHead',
+ conv_out_channels=256,
+ norm_cfg=norm_cfg,
+ in_channels=256,
+ fc_out_channels=1024,
+ roi_feat_size=7,
+ num_classes=80,
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ target_means=[0., 0., 0., 0.],
+ target_stds=[0.05, 0.05, 0.1, 0.1]),
+ reg_class_agnostic=True,
+ loss_cls=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0),
+ loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
+ loss_weight=1.0)),
+ dict(
+ type='Shared4Conv1FCBBoxHead',
+ conv_out_channels=256,
+ norm_cfg=norm_cfg,
+ in_channels=256,
+ fc_out_channels=1024,
+ roi_feat_size=7,
+ num_classes=80,
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ target_means=[0., 0., 0., 0.],
+ target_stds=[0.033, 0.033, 0.067, 0.067]),
+ reg_class_agnostic=True,
+ loss_cls=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0),
+ loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
+ ],
+ mask_roi_extractor=dict(
+ type='SingleRoIExtractor',
+ roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
+ out_channels=256,
+ featmap_strides=[4, 8, 16, 32]),
+ mask_head=dict(
+ type='FCNMaskHead',
+ norm_cfg=norm_cfg,
+ num_convs=4,
+ in_channels=256,
+ conv_out_channels=256,
+ num_classes=80,
+ loss_mask=dict(
+ type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
+ # model training and testing settings
+ train_cfg=dict(
+ rpn=dict(
+ assigner=dict(
+ type='MaxIoUAssigner',
+ pos_iou_thr=0.7,
+ neg_iou_thr=0.3,
+ min_pos_iou=0.3,
+ match_low_quality=True,
+ ignore_iof_thr=-1),
+ sampler=dict(
+ type='RandomSampler',
+ num=256,
+ pos_fraction=0.5,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=False),
+ allowed_border=0,
+ pos_weight=-1,
+ debug=False),
+ rpn_proposal=dict(
+ nms_pre=2000,
+ max_per_img=2000,
+ nms=dict(type='nms', iou_threshold=0.7),
+ min_bbox_size=0),
+ rcnn=[
+ dict(
+ assigner=dict(
+ type='MaxIoUAssigner',
+ pos_iou_thr=0.5,
+ neg_iou_thr=0.5,
+ min_pos_iou=0.5,
+ match_low_quality=False,
+ ignore_iof_thr=-1),
+ sampler=dict(
+ type='RandomSampler',
+ num=512,
+ pos_fraction=0.25,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True),
+ mask_size=28,
+ pos_weight=-1,
+ debug=False),
+ dict(
+ assigner=dict(
+ type='MaxIoUAssigner',
+ pos_iou_thr=0.6,
+ neg_iou_thr=0.6,
+ min_pos_iou=0.6,
+ match_low_quality=False,
+ ignore_iof_thr=-1),
+ sampler=dict(
+ type='RandomSampler',
+ num=512,
+ pos_fraction=0.25,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True),
+ mask_size=28,
+ pos_weight=-1,
+ debug=False),
+ dict(
+ assigner=dict(
+ type='MaxIoUAssigner',
+ pos_iou_thr=0.7,
+ neg_iou_thr=0.7,
+ min_pos_iou=0.7,
+ match_low_quality=False,
+ ignore_iof_thr=-1),
+ sampler=dict(
+ type='RandomSampler',
+ num=512,
+ pos_fraction=0.25,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True),
+ mask_size=28,
+ pos_weight=-1,
+ debug=False)
+ ]),
+ test_cfg=dict(
+ rpn=dict(
+ nms_pre=1000,
+ max_per_img=1000,
+ nms=dict(type='nms', iou_threshold=0.7),
+ min_bbox_size=0),
+ rcnn=dict(
+ score_thr=0.05,
+ nms=dict(type='nms', iou_threshold=0.5),
+ max_per_img=100,
+ mask_thr_binary=0.5)))
+
+mmlab_modules = [
+ dict(type='mmdet', name='CascadeRCNN', module='model'),
+ dict(type='mmdet', name='RPNHead', module='head'),
+ dict(type='mmdet', name='CascadeRoIHead', module='head'),
+]
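For context: the ViTDet backbone emits a single stride-16 feature map, and SFP rebuilds the pyramid by rescaling it with scale_factors, so the pyramid strides line up with the RPN anchor strides. A small illustrative check (not part of the config; the mechanism for the fifth level is whatever SFP does when num_outs exceeds len(scale_factors)):

    patch_stride = 16
    scale_factors = (4.0, 2.0, 1.0, 0.5)
    strides = [int(patch_stride / s) for s in scale_factors]  # [4, 8, 16, 32]
    # with num_outs=5, SFP appends one coarser level, giving [4, 8, 16, 32, 64],
    # which matches anchor_generator strides=[4, 8, 16, 32, 64] above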
diff --git a/configs/detection/vitdet/vitdet_cascade_mask_rcnn_100e.py b/configs/detection/vitdet/vitdet_cascade_mask_rcnn_100e.py
new file mode 100644
index 00000000..bbbc339f
--- /dev/null
+++ b/configs/detection/vitdet/vitdet_cascade_mask_rcnn_100e.py
@@ -0,0 +1,4 @@
+_base_ = [
+ './vitdet_cascade_mask_rcnn.py', './lsj_coco_instance.py',
+ './vitdet_schedule_100e.py'
+]
diff --git a/configs/detection/vitdet/vitdet_faster_rcnn.py b/configs/detection/vitdet/vitdet_faster_rcnn.py
index 48604d8b..0a00b397 100644
--- a/configs/detection/vitdet/vitdet_faster_rcnn.py
+++ b/configs/detection/vitdet/vitdet_faster_rcnn.py
@@ -1,6 +1,6 @@
# model settings
-norm_cfg = dict(type='GN', num_groups=1, requires_grad=True)
+norm_cfg = dict(type='GN', num_groups=1, eps=1e-6, requires_grad=True)
pretrained = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-b-1600/warpper_mae_vit-base-p16-1600e.pth'
model = dict(
@@ -9,22 +9,32 @@ model = dict(
backbone=dict(
type='ViTDet',
img_size=1024,
+ patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
+ drop_path_rate=0.1,
+ window_size=14,
mlp_ratio=4,
qkv_bias=True,
- qk_scale=None,
- drop_rate=0.,
- attn_drop_rate=0.,
- drop_path_rate=0.1,
- use_abs_pos_emb=True,
- aggregation='attn',
- ),
+ window_block_indexes=[
+            # blocks 2, 5, 8, 11 use global attention
+ 0,
+ 1,
+ 3,
+ 4,
+ 6,
+ 7,
+ 9,
+ 10,
+ ],
+ residual_block_indexes=[],
+ use_rel_pos=True),
neck=dict(
type='SFP',
- in_channels=[768, 768, 768, 768],
+ in_channels=768,
out_channels=256,
+ scale_factors=(4.0, 2.0, 1.0, 0.5),
norm_cfg=norm_cfg,
num_outs=5),
rpn_head=dict(
@@ -32,7 +42,6 @@ model = dict(
in_channels=256,
feat_channels=256,
num_convs=2,
- norm_cfg=norm_cfg,
anchor_generator=dict(
type='AnchorGenerator',
scales=[8],
@@ -98,7 +107,7 @@ model = dict(
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
- match_low_quality=True,
+ match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
diff --git a/configs/detection/vitdet/vitdet_faster_rcnn_100e.py b/configs/detection/vitdet/vitdet_faster_rcnn_100e.py
index 5a43b575..bfeab9d1 100644
--- a/configs/detection/vitdet/vitdet_faster_rcnn_100e.py
+++ b/configs/detection/vitdet/vitdet_faster_rcnn_100e.py
@@ -1,4 +1,4 @@
_base_ = [
- './vitdet_faster_rcnn.py', './lsj_coco_detection.py',
+ './vitdet_faster_rcnn.py', './lsj_coco_instance.py',
'./vitdet_schedule_100e.py'
]
diff --git a/configs/detection/vitdet/vitdet_mask_rcnn.py b/configs/detection/vitdet/vitdet_mask_rcnn.py
index 890f6e8f..6b1ed1ce 100644
--- a/configs/detection/vitdet/vitdet_mask_rcnn.py
+++ b/configs/detection/vitdet/vitdet_mask_rcnn.py
@@ -1,6 +1,6 @@
# model settings
-norm_cfg = dict(type='GN', num_groups=1, requires_grad=True)
+norm_cfg = dict(type='GN', num_groups=1, eps=1e-6, requires_grad=True)
pretrained = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-b-1600/warpper_mae_vit-base-p16-1600e.pth'
model = dict(
@@ -9,22 +9,32 @@ model = dict(
backbone=dict(
type='ViTDet',
img_size=1024,
+ patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
+ drop_path_rate=0.1,
+ window_size=14,
mlp_ratio=4,
qkv_bias=True,
- qk_scale=None,
- drop_rate=0.,
- attn_drop_rate=0.,
- drop_path_rate=0.1,
- use_abs_pos_emb=True,
- aggregation='attn',
- ),
+ window_block_indexes=[
+            # blocks 2, 5, 8, 11 use global attention
+ 0,
+ 1,
+ 3,
+ 4,
+ 6,
+ 7,
+ 9,
+ 10,
+ ],
+ residual_block_indexes=[],
+ use_rel_pos=True),
neck=dict(
type='SFP',
- in_channels=[768, 768, 768, 768],
+ in_channels=768,
out_channels=256,
+ scale_factors=(4.0, 2.0, 1.0, 0.5),
norm_cfg=norm_cfg,
num_outs=5),
rpn_head=dict(
@@ -32,7 +42,6 @@ model = dict(
in_channels=256,
feat_channels=256,
num_convs=2,
- norm_cfg=norm_cfg,
anchor_generator=dict(
type='AnchorGenerator',
scales=[8],
@@ -112,7 +121,7 @@ model = dict(
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
- match_low_quality=True,
+ match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
diff --git a/configs/detection/vitdet/vitdet_100e.py b/configs/detection/vitdet/vitdet_mask_rcnn_100e.py
similarity index 100%
rename from configs/detection/vitdet/vitdet_100e.py
rename to configs/detection/vitdet/vitdet_mask_rcnn_100e.py
diff --git a/configs/detection/vitdet/vitdet_schedule_100e.py b/configs/detection/vitdet/vitdet_schedule_100e.py
index e659b1f6..a9160eba 100644
--- a/configs/detection/vitdet/vitdet_schedule_100e.py
+++ b/configs/detection/vitdet/vitdet_schedule_100e.py
@@ -1,26 +1,29 @@
_base_ = 'configs/base.py'
+log_config = dict(
+ interval=200,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ # dict(type='TensorboardLoggerHook')
+ ])
+
checkpoint_config = dict(interval=10)
+
# optimizer
-paramwise_options = {
- 'norm': dict(weight_decay=0.),
- 'bias': dict(weight_decay=0.),
- 'pos_embed': dict(weight_decay=0.),
- 'cls_token': dict(weight_decay=0.)
-}
optimizer = dict(
type='AdamW',
lr=1e-4,
betas=(0.9, 0.999),
weight_decay=0.1,
- paramwise_options=paramwise_options)
-optimizer_config = dict(grad_clip=None, loss_scale=512.)
+ constructor='LayerDecayOptimizerConstructor',
+ paramwise_options=dict(num_layers=12, layer_decay_rate=0.7))
+optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=250,
- warmup_ratio=0.067,
+ warmup_ratio=0.001,
step=[88, 96])
total_epochs = 100
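For context, a quick sketch of what the new schedule settings amount to (assuming mmcv's linear warmup rule and the default step gamma of 0.1; both are assumptions about the base config, not part of this patch):

    base_lr = 1e-4
    warmup_ratio = 0.001
    warmup_start_lr = base_lr * warmup_ratio  # 1e-7, ramped linearly up to 1e-4 over 250 iters
    # after warmup, each parameter group's lr is base_lr * lr_scale (layer decay, see
    # LayerDecayOptimizerConstructor) and is stepped down at epochs 88 and 96 of 100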
diff --git a/docs/source/_static/result.jpg b/docs/source/_static/result.jpg
index 5bb73d81..d63bad1d 100644
--- a/docs/source/_static/result.jpg
+++ b/docs/source/_static/result.jpg
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:ee64c0caef841c61c7e6344b7fe2c07a38fba07a8de81ff38c0686c641e0a283
-size 190356
+oid sha256:c696a58a2963b5ac47317751f04ff45bfed4723f2f70bacf91eac711f9710e54
+size 189432
diff --git a/docs/source/model_zoo_det.md b/docs/source/model_zoo_det.md
index 03eb3588..474496f0 100644
--- a/docs/source/model_zoo_det.md
+++ b/docs/source/model_zoo_det.md
@@ -22,7 +22,7 @@ Pretrained on COCO2017 dataset. (The result has been optimized with PAI-Blade, a
 | Algorithm | Config | Params<br/>(backbone/total) | inference time(V100)<br/>(ms/img) | bbox_mAPval<br/>0.5:0.95 | mask_mAPval<br/>0.5:0.95 | Download |
| ---------- | ------------------------------------------------------------ | ------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
-| ViTDet_MaskRCNN | [vitdet_maskrcnn](https://github.com/alibaba/EasyCV/tree/master/configs/detection/vitdet/vitdet_100e.py) | 88M/118M | 163ms | 50.57 | 44.96 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/vitdet_maskrcnn.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/vitdet_maskrcnn.log.json) |
+| ViTDet_MaskRCNN | [vitdet_maskrcnn](https://github.com/alibaba/EasyCV/tree/master/configs/detection/vitdet/vitdet_mask_rcnn_100e.py) | 86M/111M | 138ms | 50.65 | 45.41 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/epoch_100.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/20220901_135827.log.json) |
## FCOS
diff --git a/easycv/core/optimizer/layer_decay_optimizer_constructor.py b/easycv/core/optimizer/layer_decay_optimizer_constructor.py
index 45625494..310bb38c 100644
--- a/easycv/core/optimizer/layer_decay_optimizer_constructor.py
+++ b/easycv/core/optimizer/layer_decay_optimizer_constructor.py
@@ -1,5 +1,3 @@
-# Reference from https://github.com/ViTAE-Transformer/ViTDet/blob/main/mmcv_custom/layer_decay_optimizer_constructor.py
-
import json
from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
@@ -7,23 +5,32 @@ from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
from .builder import OPTIMIZER_BUILDERS
-def get_num_layer_for_vit(var_name, num_max_layer, layer_sep=None):
- if var_name in ('backbone.cls_token', 'backbone.mask_token',
- 'backbone.pos_embed'):
- return 0
- elif var_name.startswith('backbone.patch_embed'):
- return 0
- elif var_name.startswith('backbone.blocks'):
- layer_id = int(var_name.split('.')[2])
- return layer_id + 1
- else:
- return num_max_layer - 1
+def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
+ """
+ Calculate lr decay rate for different ViT blocks.
+ Reference from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py
+ Args:
+ name (string): parameter name.
+ lr_decay_rate (float): base lr decay rate.
+ num_layers (int): number of ViT blocks.
+ Returns:
+ lr decay rate for the given parameter.
+ """
+ layer_id = num_layers + 1
+ if '.pos_embed' in name or '.patch_embed' in name:
+ layer_id = 0
+ elif '.blocks.' in name and '.residual.' not in name:
+ layer_id = int(name[name.find('.blocks.'):].split('.')[2]) + 1
+
+ scale = lr_decay_rate**(num_layers + 1 - layer_id)
+
+ return layer_id, scale
@OPTIMIZER_BUILDERS.register_module()
class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor):
- def add_params(self, params, module, prefix='', is_dcn_module=None):
+ def add_params(self, params, module):
"""Add all parameters of module to the params list.
The parameters of the given module will be added to the list of param
groups, with specific rules defined by paramwise_cfg.
@@ -31,54 +38,41 @@ class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor):
params (list[dict]): A list of param groups, it will be modified
in place.
module (nn.Module): The module to be added.
- prefix (str): The prefix of the module
- is_dcn_module (int|float|None): If the current module is a
- submodule of DCN, `is_dcn_module` will be passed to
- control conv_offset layer's learning rate. Defaults to None.
+
+ Reference from https://github.com/ViTAE-Transformer/ViTDet/blob/main/mmcv_custom/layer_decay_optimizer_constructor.py
+ Note: Currently, this optimizer constructor is built for ViTDet.
"""
- # get param-wise options
parameter_groups = {}
print(self.paramwise_cfg)
- num_layers = self.paramwise_cfg.get('num_layers') + 2
- layer_sep = self.paramwise_cfg.get('layer_sep', None)
- layer_decay_rate = self.paramwise_cfg.get('layer_decay_rate')
+ lr_decay_rate = self.paramwise_cfg.get('layer_decay_rate')
+ num_layers = self.paramwise_cfg.get('num_layers')
print('Build LayerDecayOptimizerConstructor %f - %d' %
- (layer_decay_rate, num_layers))
+ (lr_decay_rate, num_layers))
+ lr = self.base_lr
weight_decay = self.base_wd
- custom_keys = self.paramwise_cfg.get('custom_keys', {})
- # first sort with alphabet order and then sort with reversed len of str
- sorted_keys = sorted(custom_keys.keys())
-
for name, param in module.named_parameters():
if not param.requires_grad:
continue # frozen weights
- if len(param.shape) == 1 or name.endswith('.bias') or (
- 'pos_embed' in name) or ('cls_token'
- in name) or ('rel_pos_' in name):
+ if 'backbone' in name and ('.norm' in name or '.pos_embed' in name
+ or '.gn.' in name or '.ln.' in name):
group_name = 'no_decay'
this_weight_decay = 0.
else:
group_name = 'decay'
this_weight_decay = weight_decay
- layer_id = get_num_layer_for_vit(name, num_layers, layer_sep)
+ if name.startswith('backbone'):
+ layer_id, scale = get_vit_lr_decay_rate(
+ name, lr_decay_rate=lr_decay_rate, num_layers=num_layers)
+ else:
+ layer_id, scale = -1, 1
group_name = 'layer_%d_%s' % (layer_id, group_name)
- # if the parameter match one of the custom keys, ignore other rules
- this_lr_multi = 1.
- for key in sorted_keys:
- if key in f'{name}':
- lr_mult = custom_keys[key].get('lr_mult', 1.)
- this_lr_multi = lr_mult
- group_name = '%s_%s' % (group_name, key)
- break
-
if group_name not in parameter_groups:
- scale = layer_decay_rate**(num_layers - layer_id - 1)
parameter_groups[group_name] = {
'weight_decay': this_weight_decay,
@@ -86,7 +80,7 @@ class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor):
'param_names': [],
'lr_scale': scale,
'group_name': group_name,
- 'lr': scale * self.base_lr * this_lr_multi,
+ 'lr': scale * lr,
}
parameter_groups[group_name]['params'].append(param)
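For context, a short worked example of the new get_vit_lr_decay_rate helper with the values configured in vitdet_schedule_100e.py (num_layers=12, layer_decay_rate=0.7); the import path simply follows the file location of this diff:

    from easycv.core.optimizer.layer_decay_optimizer_constructor import \
        get_vit_lr_decay_rate

    # patch_embed / pos_embed sit at layer_id 0 -> smallest scale, 0.7 ** 13
    print(get_vit_lr_decay_rate('backbone.patch_embed.proj.weight', 0.7, 12))  # (0, ~0.0097)
    # block i maps to layer_id i + 1
    print(get_vit_lr_decay_rate('backbone.blocks.0.attn.qkv.weight', 0.7, 12))  # (1, ~0.0138)
    print(get_vit_lr_decay_rate('backbone.blocks.11.mlp.fc2.weight', 0.7, 12))  # (12, 0.7)
    # non-backbone parameters (neck, heads) are handled in add_params with scale 1.0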
diff --git a/easycv/models/backbones/vitdet.py b/easycv/models/backbones/vitdet.py
index 83e11efa..9380f740 100644
--- a/easycv/models/backbones/vitdet.py
+++ b/easycv/models/backbones/vitdet.py
@@ -1,5 +1,3 @@
-# Copyright 2018-2023 OpenMMLab. All rights reserved.
-# Reference: https://github.com/ViTAE-Transformer/ViTDet/blob/main/mmdet/models/backbones/vit.py
import math
from functools import partial
@@ -7,793 +5,466 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
-from mmcv.cnn import build_norm_layer, constant_init, kaiming_init
-from mmcv.runner import get_dist_info
-from timm.models.layers import to_2tuple, trunc_normal_
-from torch.nn.modules.batchnorm import _BatchNorm
+from timm.models.layers import DropPath, trunc_normal_
-from easycv.models.utils import DropPath, Mlp
+from easycv.models.utils import Mlp
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger
from ..registry import BACKBONES
-from ..utils import build_conv_layer
-
-
-class BasicBlock(nn.Module):
- expansion = 1
-
- def __init__(self,
- inplanes,
- planes,
- stride=1,
- dilation=1,
- conv_cfg=None,
- norm_cfg=dict(type='BN')):
- super(BasicBlock, self).__init__()
-
- self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
- self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
-
- self.conv1 = build_conv_layer(
- conv_cfg,
- inplanes,
- planes,
- 3,
- stride=stride,
- padding=dilation,
- dilation=dilation,
- bias=False)
- self.add_module(self.norm1_name, norm1)
- self.conv2 = build_conv_layer(
- conv_cfg, planes, planes, 3, padding=1, bias=False)
- self.add_module(self.norm2_name, norm2)
-
- self.relu = nn.ReLU(inplace=True)
- self.stride = stride
- self.dilation = dilation
-
- @property
- def norm1(self):
- return getattr(self, self.norm1_name)
-
- @property
- def norm2(self):
- return getattr(self, self.norm2_name)
-
- def forward(self, x, H, W):
- B, _, C = x.shape
- x = x.permute(0, 2, 1).reshape(B, -1, H, W)
- identity = x
-
- out = self.conv1(x)
- out = self.norm1(out)
- out = self.relu(out)
-
- out = self.conv2(out)
- out = self.norm2(out)
-
- out += identity
- out = self.relu(out)
- out = out.flatten(2).transpose(1, 2)
- return out
-
-
-class Bottleneck(nn.Module):
- expansion = 4
-
- def __init__(self,
- inplanes,
- planes,
- stride=1,
- dilation=1,
- conv_cfg=None,
- norm_cfg=dict(type='BN')):
- """Bottleneck block for ResNet.
- If style is "pytorch", the stride-two layer is the 3x3 conv layer,
- if it is "caffe", the stride-two layer is the first 1x1 conv layer.
- """
- super(Bottleneck, self).__init__()
-
- self.inplanes = inplanes
- self.planes = planes
- self.stride = stride
- self.dilation = dilation
- self.conv_cfg = conv_cfg
- self.norm_cfg = norm_cfg
-
- self.conv1_stride = 1
- self.conv2_stride = stride
-
- self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
- self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
- self.norm3_name, norm3 = build_norm_layer(
- norm_cfg, planes * self.expansion, postfix=3)
-
- self.conv1 = build_conv_layer(
- conv_cfg,
- inplanes,
- planes,
- kernel_size=1,
- stride=self.conv1_stride,
- bias=False)
- self.add_module(self.norm1_name, norm1)
- self.conv2 = build_conv_layer(
- conv_cfg,
- planes,
- planes,
- kernel_size=3,
- stride=self.conv2_stride,
- padding=dilation,
- dilation=dilation,
- bias=False)
- self.add_module(self.norm2_name, norm2)
- self.conv3 = build_conv_layer(
- conv_cfg,
- planes,
- planes * self.expansion,
- kernel_size=1,
- bias=False)
- self.add_module(self.norm3_name, norm3)
-
- self.relu = nn.ReLU(inplace=True)
-
- @property
- def norm1(self):
- return getattr(self, self.norm1_name)
-
- @property
- def norm2(self):
- return getattr(self, self.norm2_name)
-
- @property
- def norm3(self):
- return getattr(self, self.norm3_name)
-
- def forward(self, x, H, W):
- B, _, C = x.shape
- x = x.permute(0, 2, 1).reshape(B, -1, H, W)
- identity = x
-
- out = self.conv1(x)
- out = self.norm1(out)
- out = self.relu(out)
-
- out = self.conv2(out)
- out = self.norm2(out)
- out = self.relu(out)
-
- out = self.conv3(out)
- out = self.norm3(out)
-
- out += identity
- out = self.relu(out)
- out = out.flatten(2).transpose(1, 2)
- return out
-
-
-class Attention(nn.Module):
-
- def __init__(self,
- dim,
- num_heads=8,
- qkv_bias=False,
- qk_scale=None,
- attn_drop=0.,
- proj_drop=0.,
- window_size=None,
- attn_head_dim=None):
- super().__init__()
- self.num_heads = num_heads
- head_dim = dim // num_heads
- if attn_head_dim is not None:
- head_dim = attn_head_dim
- all_head_dim = head_dim * self.num_heads
- # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
- self.scale = qk_scale or head_dim**-0.5
-
- self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
- self.window_size = window_size
- q_size = window_size[0]
- kv_size = q_size
- rel_sp_dim = 2 * q_size - 1
- self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
- self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
-
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(all_head_dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- def forward(self, x, H, W, rel_pos_bias=None):
- B, N, C = x.shape
- # qkv_bias = None
- # if self.q_bias is not None:
- # qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
- # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
- qkv = self.qkv(x)
- qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[
- 2] # make torchscript happy (cannot use tensor as tuple)
-
- q = q * self.scale
- attn = (q @ k.transpose(-2, -1))
- attn = calc_rel_pos_spatial(attn, q, self.window_size,
- self.window_size, self.rel_pos_h,
- self.rel_pos_w)
- # if self.relative_position_bias_table is not None:
- # relative_position_bias = \
- # self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
- # self.window_size[0] * self.window_size[1] + 1,
- # self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
- # relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
- # attn = attn + relative_position_bias.unsqueeze(0)
-
- # if rel_pos_bias is not None:
- # attn = attn + rel_pos_bias
-
- attn = attn.softmax(dim=-1)
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
def window_partition(x, window_size):
"""
+ Partition into non-overlapping windows with padding if needed.
Args:
- x: (B, H, W, C)
- window_size (int): window size
+ x (tensor): input tokens with [B, H, W, C].
+ window_size (int): window size.
Returns:
- windows: (num_windows*B, window_size, window_size, C)
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
+ (Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
- x = x.view(B, H // window_size, window_size, W // window_size, window_size,
- C)
+
+ pad_h = (window_size - H % window_size) % window_size
+ pad_w = (window_size - W % window_size) % window_size
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+ Hp, Wp = H + pad_h, W + pad_w
+
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size,
+ window_size, C)
windows = x.permute(0, 1, 3, 2, 4,
5).contiguous().view(-1, window_size, window_size, C)
- return windows
+ return windows, (Hp, Wp)
-def window_reverse(windows, window_size, H, W):
+def window_unpartition(windows, window_size, pad_hw, hw):
"""
+    Window unpartition into original sequences and remove padding.
Args:
- windows: (num_windows*B, window_size, window_size, C)
- window_size (int): Window size
- H (int): Height of image
- W (int): Width of image
+        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+ window_size (int): window size.
+ pad_hw (Tuple): padded height and width (Hp, Wp).
+ hw (Tuple): original height and width (H, W) before padding.
Returns:
- x: (B, H, W, C)
+ x: unpartitioned sequences with [B, H, W, C].
"""
- B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.view(B, H // window_size, W // window_size, window_size,
+ Hp, Wp = pad_hw
+ H, W = hw
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size,
window_size, -1)
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+ if Hp > H or Wp > W:
+ x = x[:, :H, :W, :].contiguous()
return x
-def calc_rel_pos_spatial(
- attn,
- q,
- q_shape,
- k_shape,
- rel_pos_h,
- rel_pos_w,
-):
+def get_rel_pos(q_size, k_size, rel_pos):
"""
- Spatial Relative Positional Embeddings.
+ Get relative positional embeddings according to the relative positions of
+ query and key sizes.
+ Args:
+ q_size (int): size of query q.
+ k_size (int): size of key k.
+ rel_pos (Tensor): relative position embeddings (L, C).
+ Returns:
+ Extracted positional embeddings according to relative positions.
"""
- sp_idx = 0
- q_h, q_w = q_shape
- k_h, k_w = k_shape
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
+ # Interpolate rel pos if needed.
+ if rel_pos.shape[0] != max_rel_dist:
+        # Resize the rel pos table with linear interpolation to length max_rel_dist.
+ rel_pos_resized = F.interpolate(
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+ size=max_rel_dist,
+ mode='linear',
+ )
+ rel_pos_resized = rel_pos_resized.reshape(-1,
+ max_rel_dist).permute(1, 0)
+ else:
+ rel_pos_resized = rel_pos
- # Scale up rel pos if shapes for q and k are different.
- q_h_ratio = max(k_h / q_h, 1.0)
- k_h_ratio = max(q_h / k_h, 1.0)
- dist_h = (
- torch.arange(q_h)[:, None] * q_h_ratio -
- torch.arange(k_h)[None, :] * k_h_ratio)
- dist_h += (k_h - 1) * k_h_ratio
- q_w_ratio = max(k_w / q_w, 1.0)
- k_w_ratio = max(q_w / k_w, 1.0)
- dist_w = (
- torch.arange(q_w)[:, None] * q_w_ratio -
- torch.arange(k_w)[None, :] * k_w_ratio)
- dist_w += (k_w - 1) * k_w_ratio
+ # Scale the coords with short length if shapes for q and k are different.
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+ relative_coords = (q_coords -
+ k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
- Rh = rel_pos_h[dist_h.long()]
- Rw = rel_pos_w[dist_w.long()]
+ return rel_pos_resized[relative_coords.long()]
- B, n_head, q_N, dim = q.shape
- r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim)
- rel_h = torch.einsum('byhwc,hkc->byhwk', r_q, Rh)
- rel_w = torch.einsum('byhwc,wkc->byhwk', r_q, Rw)
+def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
+ """
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
+ Args:
+ attn (Tensor): attention map.
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
+ Returns:
+ attn (Tensor): attention map with added relative positional embeddings.
+ """
+ q_h, q_w = q_size
+ k_h, k_w = k_size
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
- attn[:, :, sp_idx:, sp_idx:] = (
- attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w) +
- rel_h[:, :, :, :, :, None] + rel_w[:, :, :, :, None, :]).view(
- B, -1, q_h * q_w, k_h * k_w)
+ B, _, dim = q.shape
+ r_q = q.reshape(B, q_h, q_w, dim)
+ rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh)
+ rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw)
+
+ attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] +
+ rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w)
return attn
-class WindowAttention(nn.Module):
- """ Window based multi-head self attention (W-MSA) module with relative position bias.
- It supports both of shifted and non-shifted window.
+def get_abs_pos(abs_pos, has_cls_token, hw):
+ """
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
+ dimension for the original embeddings.
Args:
- dim (int): Number of input channels.
- window_size (tuple[int]): The height and width of the window.
- num_heads (int): Number of attention heads.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
+ hw (Tuple): size of input image tokens.
+ Returns:
+ Absolute positional embeddings after processing with shape (1, H, W, C)
+ """
+ h, w = hw
+ if has_cls_token:
+ abs_pos = abs_pos[:, 1:]
+ xy_num = abs_pos.shape[1]
+ size = int(math.sqrt(xy_num))
+ assert size * size == xy_num
+
+ if size != h or size != w:
+ new_abs_pos = F.interpolate(
+ abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
+ size=(h, w),
+ mode='bicubic',
+ align_corners=False,
+ )
+
+ return new_abs_pos.permute(0, 2, 3, 1)
+ else:
+ return abs_pos.reshape(1, h, w, -1)
+
+
+class PatchEmbed(nn.Module):
+ """
+ Image to Patch Embedding.
"""
def __init__(self,
- dim,
- window_size,
- num_heads,
- qkv_bias=True,
- qk_scale=None,
- attn_drop=0.,
- proj_drop=0.,
- attn_head_dim=None):
-
+ kernel_size=(16, 16),
+ stride=(16, 16),
+ padding=(0, 0),
+ in_chans=3,
+ embed_dim=768):
+ """
+ Args:
+ kernel_size (Tuple): kernel size of the projection layer.
+ stride (Tuple): stride of the projection layer.
+ padding (Tuple): padding size of the projection layer.
+ in_chans (int): Number of input image channels.
+            embed_dim (int): Patch embedding dimension.
+ """
+ super().__init__()
+
+ self.proj = nn.Conv2d(
+ in_chans,
+ embed_dim,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding)
+
+ def forward(self, x):
+ x = self.proj(x)
+ # B C H W -> B H W C
+ x = x.permute(0, 2, 3, 1)
+ return x
+
+
+class Attention(nn.Module):
+ """Multi-head Attention block with relative position embeddings."""
+
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=True,
+ use_rel_pos=False,
+ rel_pos_zero_init=True,
+ input_size=None,
+ ):
+ """
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+            qkv_bias (bool): If True, add a learnable bias to query, key, value.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+            input_size (tuple(int, int) or None): Input resolution for calculating the relative
+                positional parameter size.
+ """
super().__init__()
- self.dim = dim
- self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
- self.scale = qk_scale or head_dim**-0.5
-
- q_size = window_size[0]
- kv_size = window_size[1]
- rel_sp_dim = 2 * q_size - 1
- self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
- self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
+ self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
- # trunc_normal_(self.relative_position_bias_table, std=.02)
- self.softmax = nn.Softmax(dim=-1)
+ self.use_rel_pos = use_rel_pos
+ if self.use_rel_pos:
+ # initialize relative positional embeddings
+ self.rel_pos_h = nn.Parameter(
+ torch.zeros(2 * input_size[0] - 1, head_dim))
+ self.rel_pos_w = nn.Parameter(
+ torch.zeros(2 * input_size[1] - 1, head_dim))
- def forward(self, x, H, W):
- """ Forward function.
- Args:
- x: input features with shape of (num_windows*B, N, C)
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
- """
- B_, N, C = x.shape
- x = x.reshape(B_, H, W, C)
- pad_l = pad_t = 0
- pad_r = (self.window_size[1] -
- W % self.window_size[1]) % self.window_size[1]
- pad_b = (self.window_size[0] -
- H % self.window_size[0]) % self.window_size[0]
+ if not rel_pos_zero_init:
+ trunc_normal_(self.rel_pos_h, std=0.02)
+ trunc_normal_(self.rel_pos_w, std=0.02)
- x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
- _, Hp, Wp, _ = x.shape
+ def forward(self, x):
+ B, H, W, _ = x.shape
+ # qkv with shape (3, B, nHead, H * W, C)
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads,
+ -1).permute(2, 0, 3, 1, 4)
+ # q, k, v with shape (B * nHead, H * W, C)
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
- x = window_partition(
- x, self.window_size[0]) # nW*B, window_size, window_size, C
- x = x.view(-1, self.window_size[1] * self.window_size[0],
- C) # nW*B, window_size*window_size, C
- B_w = x.shape[0]
- N_w = x.shape[1]
- qkv = self.qkv(x).reshape(B_w, N_w, 3, self.num_heads,
- C // self.num_heads).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[
- 2] # make torchscript happy (cannot use tensor as tuple)
+ attn = (q * self.scale) @ k.transpose(-2, -1)
- q = q * self.scale
- attn = (q @ k.transpose(-2, -1))
+ if self.use_rel_pos:
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h,
+ self.rel_pos_w, (H, W), (H, W))
- attn = calc_rel_pos_spatial(attn, q, self.window_size,
- self.window_size, self.rel_pos_h,
- self.rel_pos_w)
-
- attn = self.softmax(attn)
-
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B_w, N_w, C)
+ attn = attn.softmax(dim=-1)
+ x = (attn @ v).view(B, self.num_heads, H, W,
+ -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = self.proj(x)
- x = self.proj_drop(x)
-
- x = x.view(-1, self.window_size[1], self.window_size[0], C)
- x = window_reverse(x, self.window_size[0], Hp, Wp) # B H' W' C
-
- if pad_r > 0 or pad_b > 0:
- x = x[:, :H, :W, :].contiguous()
-
- x = x.view(B_, H * W, C)
return x
class Block(nn.Module):
+    """Transformer block with support for window attention and residual propagation blocks."""
- def __init__(self,
- dim,
- num_heads,
- mlp_ratio=4.,
- qkv_bias=False,
- qk_scale=None,
- drop=0.,
- attn_drop=0.,
- drop_path=0.,
- init_values=None,
- act_layer=nn.GELU,
- norm_layer=nn.LayerNorm,
- window_size=None,
- attn_head_dim=None,
- window=False,
- aggregation='attn'):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_path=0.0,
+ norm_layer=nn.LayerNorm,
+ act_layer=nn.GELU,
+ use_rel_pos=False,
+ rel_pos_zero_init=True,
+ window_size=0,
+ use_residual_block=False,
+ input_size=None,
+ ):
+ """
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads in each ViT block.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ drop_path (float): Stochastic depth rate.
+ norm_layer (nn.Module): Normalization layer.
+ act_layer (nn.Module): Activation layer.
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+            window_size (int): Window size for window attention blocks. If it equals 0,
+                window attention is not used.
+ use_residual_block (bool): If True, use a residual block after the MLP block.
+            input_size (tuple(int, int) or None): Input resolution for calculating the relative
+                positional parameter size.
+ """
super().__init__()
self.norm1 = norm_layer(dim)
- self.aggregation = aggregation
- self.window = window
- if not window:
- if aggregation == 'attn':
- self.attn = Attention(
- dim,
- num_heads=num_heads,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- attn_drop=attn_drop,
- proj_drop=drop,
- window_size=window_size,
- attn_head_dim=attn_head_dim)
- else:
- self.attn = WindowAttention(
- dim,
- num_heads=num_heads,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- attn_drop=attn_drop,
- proj_drop=drop,
- window_size=window_size,
- attn_head_dim=attn_head_dim)
- if aggregation == 'basicblock':
- self.conv_aggregation = BasicBlock(
- inplanes=dim, planes=dim)
- elif aggregation == 'bottleneck':
- self.conv_aggregation = Bottleneck(
- inplanes=dim, planes=dim // 4)
- else:
- self.attn = WindowAttention(
- dim,
- num_heads=num_heads,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- attn_drop=attn_drop,
- proj_drop=drop,
- window_size=window_size,
- attn_head_dim=attn_head_dim)
- # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ use_rel_pos=use_rel_pos,
+ rel_pos_zero_init=rel_pos_zero_init,
+ input_size=input_size if window_size == 0 else
+ (window_size, window_size),
+ )
+
self.drop_path = DropPath(
- drop_path) if drop_path > 0. else nn.Identity()
+ drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
- hidden_features=mlp_hidden_dim,
- act_layer=act_layer,
- drop=drop)
+ hidden_features=int(dim * mlp_ratio),
+ act_layer=act_layer)
- if init_values is not None:
- self.gamma_1 = nn.Parameter(
- init_values * torch.ones((dim)), requires_grad=True)
- self.gamma_2 = nn.Parameter(
- init_values * torch.ones((dim)), requires_grad=True)
- else:
- self.gamma_1, self.gamma_2 = None, None
+ self.window_size = window_size
- def forward(self, x, H, W):
- if self.gamma_1 is None:
- x = x + self.drop_path(self.attn(self.norm1(x), H, W))
- x = x + self.drop_path(self.mlp(self.norm2(x)))
- else:
- x = x + self.drop_path(
- self.gamma_1 * self.attn(self.norm1(x), H, W))
- x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
- if not self.window and self.aggregation != 'attn':
- x = self.conv_aggregation(x, H, W)
- return x
-
-
-class PatchEmbed(nn.Module):
- """ Image to Patch Embedding
- """
-
- def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- num_patches = (img_size[1] // patch_size[1]) * (
- img_size[0] // patch_size[0])
- self.patch_shape = (img_size[0] // patch_size[0],
- img_size[1] // patch_size[1])
- self.img_size = img_size
- self.patch_size = patch_size
- self.num_patches = num_patches
-
- self.proj = nn.Conv2d(
- in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
-
- def forward(self, x, **kwargs):
- B, C, H, W = x.shape
- # FIXME look at relaxing size constraints
- # assert H == self.img_size[0] and W == self.img_size[1], \
- # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
- x = self.proj(x)
- Hp, Wp = x.shape[2], x.shape[3]
-
- x = x.flatten(2).transpose(1, 2)
- return x, (Hp, Wp)
-
-
-class HybridEmbed(nn.Module):
- """ CNN Feature Map Embedding
- Extract feature map from CNN, flatten, project to embedding dim.
- """
-
- def __init__(self,
- backbone,
- img_size=224,
- feature_size=None,
- in_chans=3,
- embed_dim=768):
- super().__init__()
- assert isinstance(backbone, nn.Module)
- img_size = to_2tuple(img_size)
- self.img_size = img_size
- self.backbone = backbone
- if feature_size is None:
- with torch.no_grad():
- # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
- # map for all networks, the feature metadata has reliable channel and stride info, but using
- # stride to calc feature dim requires info about padding of each stage that isn't captured.
- training = backbone.training
- if training:
- backbone.eval()
- o = self.backbone(
- torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
- feature_size = o.shape[-2:]
- feature_dim = o.shape[1]
- backbone.train(training)
- else:
- feature_size = to_2tuple(feature_size)
- feature_dim = self.backbone.feature_info.channels()[-1]
- self.num_patches = feature_size[0] * feature_size[1]
- self.proj = nn.Linear(feature_dim, embed_dim)
+ self.use_residual_block = use_residual_block
def forward(self, x):
- x = self.backbone(x)[-1]
- x = x.flatten(2).transpose(1, 2)
- x = self.proj(x)
+ shortcut = x
+ x = self.norm1(x)
+ # Window partition
+ if self.window_size > 0:
+ H, W = x.shape[1], x.shape[2]
+ x, pad_hw = window_partition(x, self.window_size)
+
+ x = self.attn(x)
+ # Reverse window partition
+ if self.window_size > 0:
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ if self.use_residual_block:
+ x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
+
return x
-class Norm2d(nn.Module):
-
- def __init__(self, embed_dim):
- super().__init__()
- self.ln = nn.LayerNorm(embed_dim, eps=1e-6)
-
- def forward(self, x):
- x = x.permute(0, 2, 3, 1)
- x = self.ln(x)
- x = x.permute(0, 3, 1, 2).contiguous()
- return x
-
-
-# todo: refactor vitdet and vit_transformer_dynamic
@BACKBONES.register_module()
class ViTDet(nn.Module):
- """ Vision Transformer with support for patch or hybrid CNN input stage
+ """
+ This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
+ "Exploring Plain Vision Transformer Backbones for Object Detection",
+ https://arxiv.org/abs/2203.16527
"""
- def __init__(self,
- img_size=224,
- patch_size=16,
- in_chans=3,
- num_classes=80,
- embed_dim=768,
- depth=12,
- num_heads=12,
- mlp_ratio=4.,
- qkv_bias=False,
- qk_scale=None,
- drop_rate=0.,
- attn_drop_rate=0.,
- drop_path_rate=0.,
- hybrid_backbone=None,
- norm_layer=None,
- init_values=None,
- use_checkpoint=False,
- use_abs_pos_emb=False,
- use_rel_pos_bias=False,
- use_shared_rel_pos_bias=False,
- out_indices=[11],
- interval=3,
- pretrained=None,
- aggregation='attn'):
+ def __init__(
+ self,
+ img_size=1024,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_path_rate=0.0,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ act_layer=nn.GELU,
+ use_abs_pos=True,
+ use_rel_pos=False,
+ rel_pos_zero_init=True,
+ window_size=0,
+ window_block_indexes=(),
+ residual_block_indexes=(),
+ use_act_checkpoint=False,
+ pretrain_img_size=224,
+ pretrain_use_cls_token=True,
+ pretrained=None,
+ ):
+ """
+ Args:
+ img_size (int): Input image size.
+ patch_size (int): Patch size.
+ in_chans (int): Number of input image channels.
+ embed_dim (int): Patch embedding dimension.
+ depth (int): Depth of ViT.
+ num_heads (int): Number of attention heads in each ViT block.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ drop_path_rate (float): Stochastic depth rate.
+ norm_layer (nn.Module): Normalization layer.
+ act_layer (nn.Module): Activation layer.
+ use_abs_pos (bool): If True, use absolute positional embeddings.
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+ window_size (int): Window size for window attention blocks.
+ window_block_indexes (list): Indexes for blocks using window attention.
+ residual_block_indexes (list): Indexes for blocks using conv propagation.
+ use_act_checkpoint (bool): If True, use activation checkpointing.
+ pretrain_img_size (int): input image size for pretraining models.
+            pretrain_use_cls_token (bool): If True, pretraining models use class token.
+ """
super().__init__()
- norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
- self.num_classes = num_classes
- self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.pretrain_use_cls_token = pretrain_use_cls_token
+ self.use_act_checkpoint = use_act_checkpoint
- if hybrid_backbone is not None:
- self.patch_embed = HybridEmbed(
- hybrid_backbone,
- img_size=img_size,
- in_chans=in_chans,
- embed_dim=embed_dim)
- else:
- self.patch_embed = PatchEmbed(
- img_size=img_size,
- patch_size=patch_size,
- in_chans=in_chans,
- embed_dim=embed_dim)
+ self.patch_embed = PatchEmbed(
+ kernel_size=(patch_size, patch_size),
+ stride=(patch_size, patch_size),
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ )
- num_patches = self.patch_embed.num_patches
-
- self.out_indices = out_indices
-
- if use_abs_pos_emb:
+ if use_abs_pos:
+ # Initialize absolute positional embedding with pretrain image size.
+ num_patches = (pretrain_img_size // patch_size) * (
+ pretrain_img_size // patch_size)
+ num_positions = (num_patches +
+ 1) if pretrain_use_cls_token else num_patches
self.pos_embed = nn.Parameter(
- torch.zeros(1, num_patches, embed_dim))
+ torch.zeros(1, num_positions, embed_dim))
else:
self.pos_embed = None
- self.pos_drop = nn.Dropout(p=drop_rate)
+ # stochastic depth decay rule
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
- ] # stochastic depth decay rule
- self.use_rel_pos_bias = use_rel_pos_bias
- self.use_checkpoint = use_checkpoint
- self.blocks = nn.ModuleList([
- Block(
+ self.blocks = nn.ModuleList()
+ for i in range(depth):
+ block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop_rate,
- attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
- init_values=init_values,
- window_size=(14, 14) if
- ((i + 1) % interval != 0
- or aggregation != 'attn') else self.patch_embed.patch_shape,
- window=((i + 1) % interval != 0),
- aggregation=aggregation) for i in range(depth)
- ])
+ act_layer=act_layer,
+ use_rel_pos=use_rel_pos,
+ rel_pos_zero_init=rel_pos_zero_init,
+ window_size=window_size if i in window_block_indexes else 0,
+ use_residual_block=i in residual_block_indexes,
+ input_size=(img_size // patch_size, img_size // patch_size),
+ )
+ self.blocks.append(block)
if self.pos_embed is not None:
- trunc_normal_(self.pos_embed, std=.02)
-
- self.norm = norm_layer(embed_dim)
+ trunc_normal_(self.pos_embed, std=0.02)
+ self.apply(self._init_weights)
self.pretrained = pretrained
- self._register_load_state_dict_pre_hook(self._prepare_checkpoint_hook)
- def fix_init_weight(self):
-
- def rescale(param, layer_id):
- param.div_(math.sqrt(2.0 * layer_id))
-
- for layer_id, layer in enumerate(self.blocks):
- rescale(layer.attn.proj.weight.data, layer_id + 1)
- rescale(layer.mlp.fc2.weight.data, layer_id + 1)
-
- def init_weights(self, pretrained=None):
- """Initialize the weights in backbone.
- Args:
- pretrained (str, optional): Path to pre-trained weights.
- Defaults to None.
- """
- self.fix_init_weight()
- pretrained = pretrained or self.pretrained
-
- def _init_weights(m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
- if isinstance(m, nn.Conv2d):
- kaiming_init(m, mode='fan_in', nonlinearity='relu')
- elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
- constant_init(m, 1)
-
- if isinstance(m, Bottleneck):
- constant_init(m.norm3, 0)
- elif isinstance(m, BasicBlock):
- constant_init(m.norm2, 0)
-
- if isinstance(pretrained, str):
- self.apply(_init_weights)
+ def init_weights(self):
+ if isinstance(self.pretrained, str):
logger = get_root_logger()
- load_checkpoint(self, pretrained, strict=False, logger=logger)
- elif pretrained is None:
- self.apply(_init_weights)
- else:
- raise TypeError('pretrained must be a str or None')
-
- def _prepare_checkpoint_hook(self, state_dict, prefix, *args, **kwargs):
- rank, _ = get_dist_info()
- if 'pos_embed' in state_dict:
- pos_embed_checkpoint = state_dict['pos_embed']
- embedding_size = pos_embed_checkpoint.shape[-1]
- H, W = self.patch_embed.patch_shape
- num_patches = self.patch_embed.num_patches
- num_extra_tokens = 1
- # height (== width) for the checkpoint position embedding
- orig_size = int(
- (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
- # height (== width) for the new position embedding
- new_size = int(num_patches**0.5)
- # class_token and dist_token are kept unchanged
- if orig_size != new_size:
- if rank == 0:
- print('Position interpolate from %dx%d to %dx%d' %
- (orig_size, orig_size, H, W))
- # extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
- # only the position tokens are interpolated
- pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
- pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
- embedding_size).permute(
- 0, 3, 1, 2)
- pos_tokens = torch.nn.functional.interpolate(
- pos_tokens,
- size=(H, W),
- mode='bicubic',
- align_corners=False)
- new_pos_embed = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
- # new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
- state_dict['pos_embed'] = new_pos_embed
-
- def get_num_layers(self):
- return len(self.blocks)
-
- @torch.jit.ignore
- def no_weight_decay(self):
- return {'pos_embed', 'cls_token'}
-
- def forward_features(self, x):
- B, C, H, W = x.shape
- x, (Hp, Wp) = self.patch_embed(x)
- batch_size, seq_len, _ = x.size()
-
- if self.pos_embed is not None:
- x = x + self.pos_embed
- x = self.pos_drop(x)
-
- outs = []
- for i, blk in enumerate(self.blocks):
- if self.use_checkpoint:
- x = checkpoint.checkpoint(blk, x)
- else:
- x = blk(x, Hp, Wp)
-
- x = self.norm(x)
- xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp)
-
- outs.append(xp)
-
- return tuple(outs)
+ load_checkpoint(self, self.pretrained, strict=False, logger=logger)
def forward(self, x):
- x = self.forward_features(x)
- return x
+ x = self.patch_embed(x)
+ if self.pos_embed is not None:
+ x = x + get_abs_pos(self.pos_embed, self.pretrain_use_cls_token,
+ (x.shape[1], x.shape[2]))
+
+ for blk in self.blocks:
+ if self.use_act_checkpoint:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+
+ outputs = [x.permute(0, 3, 1, 2)]
+ return outputs
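For context, a minimal shape-check sketch of the rewritten backbone (illustrative only; it uses a lighter 512x512 setting than the 1024 used in the configs, and the import path simply follows the file location of this diff):

    import torch
    from easycv.models.backbones.vitdet import (ViTDet, window_partition,
                                                window_unpartition)

    # window round trip: a 32x32 token map with window_size=14 is padded to 42x42,
    # split into 3x3=9 windows per image, then cropped back to the original size
    x = torch.randn(2, 32, 32, 768)
    windows, pad_hw = window_partition(x, 14)
    assert windows.shape == (2 * 9, 14, 14, 768) and pad_hw == (42, 42)
    assert window_unpartition(windows, 14, pad_hw, (32, 32)).shape == x.shape

    # backbone forward: the model now returns a single stride-16 feature map
    model = ViTDet(
        img_size=512, patch_size=16, embed_dim=768, depth=12, num_heads=12,
        window_size=14, window_block_indexes=[0, 1, 3, 4, 6, 7, 9, 10],
        use_rel_pos=True).eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 512, 512))
    assert len(feats) == 1 and feats[0].shape == (1, 768, 32, 32)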
diff --git a/easycv/models/detection/necks/fpn.py b/easycv/models/detection/necks/fpn.py
index 6d14bbef..8018903c 100644
--- a/easycv/models/detection/necks/fpn.py
+++ b/easycv/models/detection/necks/fpn.py
@@ -37,7 +37,6 @@ class FPN(nn.Module):
Default: None.
upsample_cfg (dict): Config dict for interpolate layer.
Default: dict(mode='nearest').
- init_cfg (dict or list[dict], optional): Initialization config dict.
Example:
>>> import torch
>>> in_channels = [2, 3, 5, 7]
@@ -67,8 +66,6 @@ class FPN(nn.Module):
norm_cfg=None,
act_cfg=None,
upsample_cfg=dict(mode='nearest')):
- # init_cfg=dict(
- # type='Xavier', layer='Conv2d', distribution='uniform')):
super(FPN, self).__init__()
assert isinstance(in_channels, list)
self.in_channels = in_channels
diff --git a/easycv/models/detection/necks/sfp.py b/easycv/models/detection/necks/sfp.py
index be1273b0..b588f643 100644
--- a/easycv/models/detection/necks/sfp.py
+++ b/easycv/models/detection/necks/sfp.py
@@ -2,26 +2,12 @@
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
-from mmcv.runner import BaseModule
from easycv.models.builder import NECKS
-class Norm2d(nn.Module):
-
- def __init__(self, embed_dim):
- super().__init__()
- self.ln = nn.LayerNorm(embed_dim, eps=1e-6)
-
- def forward(self, x):
- x = x.permute(0, 2, 3, 1)
- x = self.ln(x)
- x = x.permute(0, 3, 1, 2).contiguous()
- return x
-
-
@NECKS.register_module()
-class SFP(BaseModule):
+class SFP(nn.Module):
r"""Simple Feature Pyramid.
- This is an implementation of paper `Exploring Plain Vision Transformer Backbones for Object Detection `_.
+ This is an implementation of the paper `Exploring Plain Vision Transformer Backbones for Object Detection <https://arxiv.org/abs/2203.16527>`_.
Args:
@@ -32,25 +18,10 @@ class SFP(BaseModule):
build the feature pyramid. Default: 0.
end_level (int): Index of the end input backbone level (exclusive) to
build the feature pyramid. Default: -1, which means the last level.
- add_extra_convs (bool | str): If bool, it decides whether to add conv
- layers on top of the original feature maps. Default to False.
- If True, it is equivalent to `add_extra_convs='on_input'`.
- If str, it specifies the source feature map of the extra convs.
- Only the following options are allowed
- - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
- - 'on_lateral': Last feature map after lateral convs.
- - 'on_output': The last output feature map after fpn convs.
- relu_before_extra_convs (bool): Whether to apply relu before the extra
- conv. Default: False.
- no_norm_on_lateral (bool): Whether to apply norm on lateral.
- Default: False.
conv_cfg (dict): Config dict for convolution layer. Default: None.
norm_cfg (dict): Config dict for normalization layer. Default: None.
act_cfg (str): Config dict for activation layer in ConvModule.
Default: None.
- upsample_cfg (dict): Config dict for interpolate layer.
- Default: `dict(mode='nearest')`
- init_cfg (dict or list[dict], optional): Initialization config dict.
Example:
>>> import torch
>>> in_channels = [2, 3, 5, 7]
@@ -70,158 +41,83 @@ class SFP(BaseModule):
def __init__(self,
in_channels,
out_channels,
+ scale_factors,
num_outs,
- start_level=0,
- end_level=-1,
- add_extra_convs=False,
- relu_before_extra_convs=False,
- no_norm_on_lateral=False,
conv_cfg=None,
norm_cfg=None,
- act_cfg=None,
- upsample_cfg=dict(mode='nearest'),
- init_cfg=[
- dict(
- type='Xavier',
- layer=['Conv2d'],
- distribution='uniform'),
- dict(type='Constant', layer=['LayerNorm'], val=1, bias=0)
- ]):
- super(SFP, self).__init__(init_cfg)
- assert isinstance(in_channels, list)
- self.in_channels = in_channels
+ act_cfg=None):
+ super(SFP, self).__init__()
+ dim = in_channels
self.out_channels = out_channels
- self.num_ins = len(in_channels)
+ self.scale_factors = scale_factors
+ self.num_ins = len(scale_factors)
self.num_outs = num_outs
- self.relu_before_extra_convs = relu_before_extra_convs
- self.no_norm_on_lateral = no_norm_on_lateral
- self.upsample_cfg = upsample_cfg.copy()
- if end_level == -1:
- self.backbone_end_level = self.num_ins
- assert num_outs >= self.num_ins - start_level
- else:
- # if end_level < inputs, no extra level is allowed
- self.backbone_end_level = end_level
- assert end_level <= len(in_channels)
- assert num_outs == end_level - start_level
- self.start_level = start_level
- self.end_level = end_level
- self.add_extra_convs = add_extra_convs
- assert isinstance(add_extra_convs, (str, bool))
- if isinstance(add_extra_convs, str):
- # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
- assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
- elif add_extra_convs: # True
- self.add_extra_convs = 'on_input'
-
- self.top_downs = nn.ModuleList()
- self.lateral_convs = nn.ModuleList()
- self.fpn_convs = nn.ModuleList()
-
- for i in range(self.start_level, self.backbone_end_level):
- if i == 0:
- top_down = nn.Sequential(
+ self.stages = []
+ for idx, scale in enumerate(scale_factors):
+ out_dim = dim
+ if scale == 4.0:
+ layers = [
+ nn.ConvTranspose2d(dim, dim // 2, 2, stride=2, padding=0),
+ nn.GroupNorm(1, dim // 2, eps=1e-6),
+ nn.GELU(),
nn.ConvTranspose2d(
- in_channels[i], in_channels[i], 2, stride=2,
- padding=0), Norm2d(in_channels[i]), nn.GELU(),
- nn.ConvTranspose2d(
- in_channels[i], in_channels[i], 2, stride=2,
- padding=0))
- elif i == 1:
- top_down = nn.ConvTranspose2d(
- in_channels[i], in_channels[i], 2, stride=2, padding=0)
- elif i == 2:
- top_down = nn.Identity()
- elif i == 3:
- top_down = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
+ dim // 2, dim // 4, 2, stride=2, padding=0)
+ ]
+ out_dim = dim // 4
+ elif scale == 2.0:
+ layers = [
+ nn.ConvTranspose2d(dim, dim // 2, 2, stride=2, padding=0)
+ ]
+ out_dim = dim // 2
+ elif scale == 1.0:
+ layers = []
+ elif scale == 0.5:
+ layers = [nn.MaxPool2d(kernel_size=2, stride=2, padding=0)]
+ else:
+ raise NotImplementedError(
+ f'scale_factor={scale} is not supported yet.')
- l_conv = ConvModule(
- in_channels[i],
- out_channels,
- 1,
- conv_cfg=conv_cfg,
- norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
- act_cfg=act_cfg,
- inplace=False)
- fpn_conv = ConvModule(
- out_channels,
- out_channels,
- 3,
- padding=1,
- conv_cfg=conv_cfg,
- norm_cfg=norm_cfg,
- act_cfg=act_cfg,
- inplace=False)
-
- self.top_downs.append(top_down)
- self.lateral_convs.append(l_conv)
- self.fpn_convs.append(fpn_conv)
-
- # add extra conv layers (e.g., RetinaNet)
- extra_levels = num_outs - self.backbone_end_level + self.start_level
- if self.add_extra_convs and extra_levels >= 1:
- for i in range(extra_levels):
- if i == 0 and self.add_extra_convs == 'on_input':
- in_channels = self.in_channels[self.backbone_end_level - 1]
- else:
- in_channels = out_channels
- extra_fpn_conv = ConvModule(
- in_channels,
+ layers.extend([
+ ConvModule(
+ out_dim,
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False),
+ ConvModule(
+ out_channels,
out_channels,
3,
- stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
inplace=False)
- self.fpn_convs.append(extra_fpn_conv)
+ ])
+
+ layers = nn.Sequential(*layers)
+ self.add_module(f'sfp_{idx}', layers)
+ self.stages.append(layers)
+
+ def init_weights(self):
+ pass
def forward(self, inputs):
"""Forward function."""
- assert len(inputs) == 1
+ features = inputs[0]
+ outs = []
- # build top-down path
- features = [
- top_down(inputs[0]) for _, top_down in enumerate(self.top_downs)
- ]
- assert len(features) == len(self.in_channels)
+ # part 1: build simple feature pyramid
+ for stage in self.stages:
+ outs.append(stage(features))
- # build laterals
- laterals = [
- lateral_conv(features[i + self.start_level])
- for i, lateral_conv in enumerate(self.lateral_convs)
- ]
-
- used_backbone_levels = len(laterals)
-
- # build outputs
- # part 1: from original levels
- outs = [
- self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
- ]
# part 2: add extra levels
- if self.num_outs > len(outs):
+ if self.num_outs > self.num_ins:
# use max pool to get more levels on top of outputs
# (e.g., Faster R-CNN, Mask R-CNN)
- if not self.add_extra_convs:
- for i in range(self.num_outs - used_backbone_levels):
- outs.append(F.max_pool2d(outs[-1], 1, stride=2))
- # add conv layers on top of original feature maps (RetinaNet)
- else:
- if self.add_extra_convs == 'on_input':
- extra_source = inputs[self.backbone_end_level - 1]
- elif self.add_extra_convs == 'on_lateral':
- extra_source = laterals[-1]
- elif self.add_extra_convs == 'on_output':
- extra_source = outs[-1]
- else:
- raise NotImplementedError
- outs.append(self.fpn_convs[used_backbone_levels](extra_source))
- for i in range(used_backbone_levels + 1, self.num_outs):
- if self.relu_before_extra_convs:
- outs.append(self.fpn_convs[i](F.relu(outs[-1])))
- else:
- outs.append(self.fpn_convs[i](outs[-1]))
+ for i in range(self.num_outs - self.num_ins):
+ outs.append(F.max_pool2d(outs[-1], 1, stride=2))
return tuple(outs)
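For reference, the refactored SFP builds one head per entry in `scale_factors` from a single stride-16 backbone map and max-pools the last output for any remaining levels. Below is a minimal sketch in plain torch (no ConvModule/norm on the 1x1/3x3 projections) of the resulting strides and channel widths, assuming ViT-Base style settings: in_channels=768, out_channels=256, scale_factors=(4.0, 2.0, 1.0, 0.5), num_outs=5; it is not the EasyCV module itself.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

dim, out_channels, num_outs = 768, 256, 5
scale_factors = (4.0, 2.0, 1.0, 0.5)

stages = []
for scale in scale_factors:
    out_dim = dim
    if scale == 4.0:    # stride 16 -> 4: two 2x transposed convs
        layers = [
            nn.ConvTranspose2d(dim, dim // 2, 2, stride=2),
            nn.GroupNorm(1, dim // 2, eps=1e-6),
            nn.GELU(),
            nn.ConvTranspose2d(dim // 2, dim // 4, 2, stride=2),
        ]
        out_dim = dim // 4
    elif scale == 2.0:  # stride 16 -> 8: one 2x transposed conv
        layers = [nn.ConvTranspose2d(dim, dim // 2, 2, stride=2)]
        out_dim = dim // 2
    elif scale == 1.0:  # stride 16 kept as-is
        layers = []
    elif scale == 0.5:  # stride 16 -> 32: 2x max pool
        layers = [nn.MaxPool2d(kernel_size=2, stride=2)]
    # every level is projected to out_channels by a 1x1 then a 3x3 conv
    layers += [
        nn.Conv2d(out_dim, out_channels, 1),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
    ]
    stages.append(nn.Sequential(*layers))

x = torch.randn(1, dim, 64, 64)  # single ViT output for a 1024x1024 image
outs = [stage(x) for stage in stages]
while len(outs) < num_outs:      # extra stride-64 level, as in forward() above
    outs.append(F.max_pool2d(outs[-1], 1, stride=2))
print([tuple(o.shape) for o in outs])
# [(1, 256, 256, 256), (1, 256, 128, 128), (1, 256, 64, 64),
#  (1, 256, 32, 32), (1, 256, 16, 16)]
```

The five outputs correspond to strides 4/8/16/32/64, matching what the downstream RPN and RoI heads expect from a standard FPN.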
diff --git a/easycv/predictors/detector.py b/easycv/predictors/detector.py
index f9d05992..017d671e 100644
--- a/easycv/predictors/detector.py
+++ b/easycv/predictors/detector.py
@@ -253,11 +253,11 @@ class DetrPredictor(PredictorInterface):
img,
bboxes,
labels=labels,
- colors='green',
- text_color='white',
- font_size=20,
- thickness=1,
- font_scale=0.5,
+ colors='cyan',
+ text_color='cyan',
+ font_size=18,
+ thickness=2,
+ font_scale=0.0,
show=show,
out_file=out_file)
diff --git a/tests/models/backbones/test_vitdet.py b/tests/models/backbones/test_vitdet.py
index 3f0350a2..82012aed 100644
--- a/tests/models/backbones/test_vitdet.py
+++ b/tests/models/backbones/test_vitdet.py
@@ -14,18 +14,27 @@ class ViTDetTest(unittest.TestCase):
def test_vitdet(self):
model = ViTDet(
img_size=1024,
+ patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
+ drop_path_rate=0.1,
+ window_size=14,
mlp_ratio=4,
qkv_bias=True,
- qk_scale=None,
- drop_rate=0.,
- attn_drop_rate=0.,
- drop_path_rate=0.1,
- use_abs_pos_emb=True,
- aggregation='attn',
- )
+ window_block_indexes=[
+ # 2, 5, 8, 11 for global attention
+ 0,
+ 1,
+ 3,
+ 4,
+ 6,
+ 7,
+ 9,
+ 10,
+ ],
+ residual_block_indexes=[],
+ use_rel_pos=True)
model.init_weights()
model.train()
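The updated test keeps blocks 2, 5, 8 and 11 on global attention and windows the remaining blocks. A throwaway snippet (variable names illustrative) showing how such a `window_block_indexes` list can be derived for a depth-12 ViT-Base:

```python
depth = 12
global_attn_indexes = {2, 5, 8, 11}  # last block of every 3-block group stays global
window_block_indexes = [i for i in range(depth) if i not in global_attn_indexes]
assert window_block_indexes == [0, 1, 3, 4, 6, 7, 9, 10]
```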
diff --git a/tests/predictors/test_detector.py b/tests/predictors/test_detector.py
index 9187d3a7..c3be2ed6 100644
--- a/tests/predictors/test_detector.py
+++ b/tests/predictors/test_detector.py
@@ -155,7 +155,7 @@ class DetectorTest(unittest.TestCase):
decimal=1)
def test_vitdet_detector(self):
- model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/vitdet_maskrcnn_export.pth'
+ model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/epoch_100_export.pth'
img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'
out_file = './result.jpg'
vitdet = DetrPredictor(model_path)
@@ -167,63 +167,170 @@ class DetectorTest(unittest.TestCase):
self.assertIn('detection_classes', output)
self.assertIn('detection_masks', output)
self.assertIn('img_metas', output)
- self.assertEqual(len(output['detection_boxes'][0]), 30)
- self.assertEqual(len(output['detection_scores'][0]), 30)
- self.assertEqual(len(output['detection_classes'][0]), 30)
+ self.assertEqual(len(output['detection_boxes'][0]), 33)
+ self.assertEqual(len(output['detection_scores'][0]), 33)
+ self.assertEqual(len(output['detection_classes'][0]), 33)
self.assertListEqual(
output['detection_classes'][0].tolist(),
np.array([
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
- 2, 2, 2, 7, 7, 13, 13, 13, 56
+ 2, 2, 2, 2, 2, 2, 7, 7, 13, 13, 13, 56
],
dtype=np.int32).tolist())
assert_array_almost_equal(
output['detection_scores'][0],
np.array([
- 0.99791867, 0.99665856, 0.99480623, 0.99060905, 0.9882515,
- 0.98319584, 0.9738879, 0.97290784, 0.9514897, 0.95104814,
- 0.9321701, 0.86165, 0.8228847, 0.7623552, 0.76129806,
- 0.6050861, 0.44348577, 0.3452973, 0.2895671, 0.22109479,
- 0.21265312, 0.17855245, 0.1205352, 0.08981906, 0.10596471,
- 0.05854294, 0.99749386, 0.9472857, 0.5945908, 0.09855112
+ 0.9975854158401489, 0.9965696334838867, 0.9922919869422913,
+ 0.9833580851554871, 0.983080267906189, 0.970454752445221,
+ 0.9701289534568787, 0.9649872183799744, 0.9642795324325562,
+ 0.9642238020896912, 0.9529680609703064, 0.9403366446495056,
+ 0.9391788244247437, 0.8941807150840759, 0.8178097009658813,
+ 0.8013413548469543, 0.6677654385566711, 0.3952914774417877,
+ 0.33463895320892334, 0.32501447200775146, 0.27323535084724426,
+ 0.20197080075740814, 0.15607696771621704, 0.1068163588643074,
+ 0.10183875262737274, 0.09735643863677979, 0.06559795141220093,
+ 0.08890066295862198, 0.076363705098629, 0.9954648613929749,
+ 0.9212945699691772, 0.5224372148513794, 0.20555885136127472
],
dtype=np.float32),
decimal=2)
assert_array_almost_equal(
output['detection_boxes'][0],
- np.array([[294.7058, 117.29371, 378.83713, 149.99928],
- [609.05444, 112.526474, 633.2971, 136.35175],
- [481.4165, 110.987335, 522.5531, 130.01529],
- [167.68184, 109.89049, 215.49057, 139.86987],
- [374.75082, 110.68697, 433.10028, 136.23654],
- [189.54971, 110.09322, 297.6167, 155.77412],
- [266.5185, 105.37718, 326.54385, 127.916374],
- [556.30225, 110.43166, 592.8248, 128.03764],
- [432.49252, 105.086464, 484.0512, 132.272],
- [0., 110.566444, 62.01249, 146.44017],
- [591.74664, 110.43527, 619.73816, 126.68549],
- [99.126854, 90.947975, 118.46699, 101.11096],
- [59.895264, 94.110054, 85.60521, 106.67633],
- [142.95819, 96.61966, 165.96964, 104.95929],
- [83.062515, 89.802605, 99.1546, 98.69074],
- [226.28802, 98.32568, 249.06772, 108.86408],
- [136.67789, 94.75706, 154.62924, 104.289536],
- [170.42459, 98.458694, 183.16309, 106.203156],
- [67.56731, 89.68286, 82.62955, 98.35645],
- [222.80092, 97.828445, 239.02655, 108.29377],
- [134.34427, 92.31653, 149.19615, 102.97457],
- [613.5186, 102.27066, 636.0434, 112.813644],
- [607.4787, 110.87984, 630.1123, 127.65646],
- [135.13664, 90.989876, 155.67192, 100.18036],
- [431.61505, 105.43844, 484.36508, 132.50078],
- [189.92722, 110.38832, 297.74353, 155.95557],
- [220.67035, 177.13489, 455.32092, 380.45712],
- [372.76584, 134.33807, 432.44357, 188.51534],
- [50.403812, 110.543495, 70.4368, 119.65186],
- [373.50272, 134.27258, 432.18475, 187.81824]]),
+ np.array([[
+ 294.22674560546875, 116.6078109741211, 379.4328918457031,
+ 150.14097595214844
+ ],
+ [
+ 482.6017761230469, 110.75955963134766,
+ 522.8798828125, 129.71286010742188
+ ],
+ [
+ 167.06460571289062, 109.95974731445312,
+ 212.83975219726562, 140.16102600097656
+ ],
+ [
+ 609.2930908203125, 113.13909149169922,
+ 637.3115844726562, 136.4690704345703
+ ],
+ [
+ 191.185791015625, 111.1408920288086, 301.31689453125,
+ 155.7731170654297
+ ],
+ [
+ 431.2244873046875, 106.19962310791016,
+ 483.860595703125, 132.21627807617188
+ ],
+ [
+ 267.48358154296875, 105.5920639038086,
+ 325.2832336425781, 127.11176300048828
+ ],
+ [
+ 591.2138671875, 110.29329681396484,
+ 619.8524169921875, 126.1990966796875
+ ],
+ [
+ 0.0, 110.7026596069336, 61.487945556640625,
+ 146.33018493652344
+ ],
+ [
+ 555.9155883789062, 110.03486633300781,
+ 591.7050170898438, 127.06097412109375
+ ],
+ [
+ 60.24559783935547, 94.12760162353516,
+ 85.63741302490234, 106.66705322265625
+ ],
+ [
+ 99.02665710449219, 90.53657531738281,
+ 118.83953094482422, 101.18717956542969
+ ],
+ [
+ 396.30438232421875, 111.59194946289062,
+ 431.559814453125, 133.96914672851562
+ ],
+ [
+ 83.81543731689453, 89.65665435791016,
+ 99.9166259765625, 98.25627899169922
+ ],
+ [
+ 139.29647827148438, 96.68000793457031,
+ 165.22410583496094, 105.60000610351562
+ ],
+ [
+ 67.27152252197266, 89.42798614501953,
+ 83.25617980957031, 98.0460205078125
+ ],
+ [
+ 223.74176025390625, 98.68321990966797,
+ 250.42506408691406, 109.32588958740234
+ ],
+ [
+ 136.7582244873047, 96.51412963867188,
+ 152.51190185546875, 104.73160552978516
+ ],
+ [
+ 221.71812438964844, 97.86445617675781,
+ 238.9705810546875, 106.96803283691406
+ ],
+ [
+ 135.06964111328125, 91.80916595458984, 155.24609375,
+ 102.20686340332031
+ ],
+ [
+ 169.11180114746094, 97.53628540039062,
+ 182.88504028320312, 105.95404815673828
+ ],
+ [
+ 133.8811798095703, 91.00375366210938,
+ 145.35507202148438, 102.3780288696289
+ ],
+ [
+ 614.2507934570312, 102.19828796386719,
+ 636.5692749023438, 112.59198760986328
+ ],
+ [
+ 35.94759750366211, 91.7213363647461,
+ 70.38274383544922, 117.19855499267578
+ ],
+ [
+ 554.6401977539062, 115.18976593017578,
+ 562.0255737304688, 127.4429931640625
+ ],
+ [
+ 39.07550811767578, 92.73261260986328,
+ 85.36636352539062, 106.73953247070312
+ ],
+ [
+ 200.85513305664062, 93.00469970703125,
+ 219.73086547851562, 107.99642181396484
+ ],
+ [
+ 0.0, 111.18904876708984, 61.7393684387207,
+ 146.72547912597656
+ ],
+ [
+ 191.88568115234375, 111.09577178955078,
+ 299.4097900390625, 155.14639282226562
+ ],
+ [
+ 221.06834411621094, 176.6427001953125,
+ 458.3475341796875, 378.89300537109375
+ ],
+ [
+ 372.7131652832031, 135.51429748535156,
+ 433.2494201660156, 188.0106658935547
+ ],
+ [
+ 52.19819641113281, 110.3646011352539,
+ 70.95110321044922, 120.10567474365234
+ ],
+ [
+ 376.1671447753906, 133.6930694580078,
+ 432.2721862792969, 187.99481201171875
+ ]]),
decimal=1)