mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
Refactor ViTDet backbone and simple feature pyramid (#177)
1. The ViTDet backbone ported from Detectron2 (d2) is about 20% faster than the ViTDet backbone originally reproduced in EasyCV. 2. bbox mAP on COCO val improves from 50.57 to 50.65.
This commit is contained in:
parent
1d5edf6d78
commit
9f01a37ad4
@ -101,13 +101,15 @@ val_dataset = dict(
|
|||||||
pipeline=test_pipeline)
|
pipeline=test_pipeline)
|
||||||
|
|
||||||
data = dict(
|
data = dict(
|
||||||
imgs_per_gpu=1, workers_per_gpu=2, train=train_dataset, val=val_dataset)
|
imgs_per_gpu=4, workers_per_gpu=2, train=train_dataset, val=val_dataset
|
||||||
|
) # total batch size 64 = 4 (images per GPU) x 8 (GPUs per node) x 2 (nodes)
|
||||||
|
|
||||||
# evaluation
|
# evaluation
|
||||||
eval_config = dict(interval=1, gpu_collect=False)
|
eval_config = dict(initial=False, interval=1, gpu_collect=False)
|
||||||
eval_pipelines = [
|
eval_pipelines = [
|
||||||
dict(
|
dict(
|
||||||
mode='test',
|
mode='test',
|
||||||
|
# dist_eval=True,
|
||||||
evaluators=[
|
evaluators=[
|
||||||
dict(type='CocoDetectionEvaluator', classes=CLASSES),
|
dict(type='CocoDetectionEvaluator', classes=CLASSES),
|
||||||
],
|
],
|
||||||
|
@ -101,13 +101,15 @@ val_dataset = dict(
|
|||||||
pipeline=test_pipeline)
|
pipeline=test_pipeline)
|
||||||
|
|
||||||
data = dict(
|
data = dict(
|
||||||
imgs_per_gpu=1, workers_per_gpu=2, train=train_dataset, val=val_dataset)
|
imgs_per_gpu=4, workers_per_gpu=2, train=train_dataset, val=val_dataset
|
||||||
|
) # total batch size 64 = 4 (images per GPU) x 8 (GPUs per node) x 2 (nodes)
|
||||||
|
|
||||||
# evaluation
|
# evaluation
|
||||||
eval_config = dict(interval=1, gpu_collect=False)
|
eval_config = dict(initial=False, interval=1, gpu_collect=False)
|
||||||
eval_pipelines = [
|
eval_pipelines = [
|
||||||
dict(
|
dict(
|
||||||
mode='test',
|
mode='test',
|
||||||
|
# dist_eval=True,
|
||||||
evaluators=[
|
evaluators=[
|
||||||
dict(type='CocoDetectionEvaluator', classes=CLASSES),
|
dict(type='CocoDetectionEvaluator', classes=CLASSES),
|
||||||
dict(type='CocoMaskEvaluator', classes=CLASSES)
|
dict(type='CocoMaskEvaluator', classes=CLASSES)
|
||||||
|
@ -1,3 +0,0 @@
|
|||||||
_base_ = './vitdet_100e.py'
|
|
||||||
|
|
||||||
model = dict(backbone=dict(aggregation='basicblock'))
|
|
@ -1,3 +0,0 @@
|
|||||||
_base_ = './vitdet_100e.py'
|
|
||||||
|
|
||||||
model = dict(backbone=dict(aggregation='bottleneck'))
|
|
231
configs/detection/vitdet/vitdet_cascade_mask_rcnn.py
Normal file
231
configs/detection/vitdet/vitdet_cascade_mask_rcnn.py
Normal file
@ -0,0 +1,231 @@
|
|||||||
|
# model settings
|
||||||
|
|
||||||
|
norm_cfg = dict(type='GN', num_groups=1, eps=1e-6, requires_grad=True)
|
||||||
|
|
||||||
|
pretrained = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-b-1600/warpper_mae_vit-base-p16-1600e.pth'
|
||||||
|
model = dict(
|
||||||
|
type='CascadeRCNN',
|
||||||
|
pretrained=pretrained,
|
||||||
|
backbone=dict(
|
||||||
|
type='ViTDet',
|
||||||
|
img_size=1024,
|
||||||
|
patch_size=16,
|
||||||
|
embed_dim=768,
|
||||||
|
depth=12,
|
||||||
|
num_heads=12,
|
||||||
|
drop_path_rate=0.1,
|
||||||
|
window_size=14,
|
||||||
|
mlp_ratio=4,
|
||||||
|
qkv_bias=True,
|
||||||
|
window_block_indexes=[
|
||||||
|
# blocks 2, 5, 8, 11 use global attention; the indexes listed here use window attention
|
||||||
|
0,
|
||||||
|
1,
|
||||||
|
3,
|
||||||
|
4,
|
||||||
|
6,
|
||||||
|
7,
|
||||||
|
9,
|
||||||
|
10,
|
||||||
|
],
|
||||||
|
residual_block_indexes=[],
|
||||||
|
use_rel_pos=True),
|
||||||
|
neck=dict(
|
||||||
|
type='SFP',
|
||||||
|
in_channels=768,
|
||||||
|
out_channels=256,
|
||||||
|
scale_factors=(4.0, 2.0, 1.0, 0.5),
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
num_outs=5),
|
||||||
|
rpn_head=dict(
|
||||||
|
type='RPNHead',
|
||||||
|
in_channels=256,
|
||||||
|
feat_channels=256,
|
||||||
|
num_convs=2,
|
||||||
|
anchor_generator=dict(
|
||||||
|
type='AnchorGenerator',
|
||||||
|
scales=[8],
|
||||||
|
ratios=[0.5, 1.0, 2.0],
|
||||||
|
strides=[4, 8, 16, 32, 64]),
|
||||||
|
bbox_coder=dict(
|
||||||
|
type='DeltaXYWHBBoxCoder',
|
||||||
|
target_means=[.0, .0, .0, .0],
|
||||||
|
target_stds=[1.0, 1.0, 1.0, 1.0]),
|
||||||
|
loss_cls=dict(
|
||||||
|
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
|
||||||
|
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
|
||||||
|
roi_head=dict(
|
||||||
|
type='CascadeRoIHead',
|
||||||
|
num_stages=3,
|
||||||
|
stage_loss_weights=[1, 0.5, 0.25],
|
||||||
|
bbox_roi_extractor=dict(
|
||||||
|
type='SingleRoIExtractor',
|
||||||
|
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
|
||||||
|
out_channels=256,
|
||||||
|
featmap_strides=[4, 8, 16, 32]),
|
||||||
|
bbox_head=[
|
||||||
|
dict(
|
||||||
|
type='Shared4Conv1FCBBoxHead',
|
||||||
|
conv_out_channels=256,
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
in_channels=256,
|
||||||
|
fc_out_channels=1024,
|
||||||
|
roi_feat_size=7,
|
||||||
|
num_classes=80,
|
||||||
|
bbox_coder=dict(
|
||||||
|
type='DeltaXYWHBBoxCoder',
|
||||||
|
target_means=[0., 0., 0., 0.],
|
||||||
|
target_stds=[0.1, 0.1, 0.2, 0.2]),
|
||||||
|
reg_class_agnostic=True,
|
||||||
|
loss_cls=dict(
|
||||||
|
type='CrossEntropyLoss',
|
||||||
|
use_sigmoid=False,
|
||||||
|
loss_weight=1.0),
|
||||||
|
loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
|
||||||
|
loss_weight=1.0)),
|
||||||
|
dict(
|
||||||
|
type='Shared4Conv1FCBBoxHead',
|
||||||
|
conv_out_channels=256,
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
in_channels=256,
|
||||||
|
fc_out_channels=1024,
|
||||||
|
roi_feat_size=7,
|
||||||
|
num_classes=80,
|
||||||
|
bbox_coder=dict(
|
||||||
|
type='DeltaXYWHBBoxCoder',
|
||||||
|
target_means=[0., 0., 0., 0.],
|
||||||
|
target_stds=[0.05, 0.05, 0.1, 0.1]),
|
||||||
|
reg_class_agnostic=True,
|
||||||
|
loss_cls=dict(
|
||||||
|
type='CrossEntropyLoss',
|
||||||
|
use_sigmoid=False,
|
||||||
|
loss_weight=1.0),
|
||||||
|
loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
|
||||||
|
loss_weight=1.0)),
|
||||||
|
dict(
|
||||||
|
type='Shared4Conv1FCBBoxHead',
|
||||||
|
conv_out_channels=256,
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
in_channels=256,
|
||||||
|
fc_out_channels=1024,
|
||||||
|
roi_feat_size=7,
|
||||||
|
num_classes=80,
|
||||||
|
bbox_coder=dict(
|
||||||
|
type='DeltaXYWHBBoxCoder',
|
||||||
|
target_means=[0., 0., 0., 0.],
|
||||||
|
target_stds=[0.033, 0.033, 0.067, 0.067]),
|
||||||
|
reg_class_agnostic=True,
|
||||||
|
loss_cls=dict(
|
||||||
|
type='CrossEntropyLoss',
|
||||||
|
use_sigmoid=False,
|
||||||
|
loss_weight=1.0),
|
||||||
|
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
|
||||||
|
],
|
||||||
|
mask_roi_extractor=dict(
|
||||||
|
type='SingleRoIExtractor',
|
||||||
|
roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
|
||||||
|
out_channels=256,
|
||||||
|
featmap_strides=[4, 8, 16, 32]),
|
||||||
|
mask_head=dict(
|
||||||
|
type='FCNMaskHead',
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
num_convs=4,
|
||||||
|
in_channels=256,
|
||||||
|
conv_out_channels=256,
|
||||||
|
num_classes=80,
|
||||||
|
loss_mask=dict(
|
||||||
|
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
|
||||||
|
# model training and testing settings
|
||||||
|
train_cfg=dict(
|
||||||
|
rpn=dict(
|
||||||
|
assigner=dict(
|
||||||
|
type='MaxIoUAssigner',
|
||||||
|
pos_iou_thr=0.7,
|
||||||
|
neg_iou_thr=0.3,
|
||||||
|
min_pos_iou=0.3,
|
||||||
|
match_low_quality=True,
|
||||||
|
ignore_iof_thr=-1),
|
||||||
|
sampler=dict(
|
||||||
|
type='RandomSampler',
|
||||||
|
num=256,
|
||||||
|
pos_fraction=0.5,
|
||||||
|
neg_pos_ub=-1,
|
||||||
|
add_gt_as_proposals=False),
|
||||||
|
allowed_border=0,
|
||||||
|
pos_weight=-1,
|
||||||
|
debug=False),
|
||||||
|
rpn_proposal=dict(
|
||||||
|
nms_pre=2000,
|
||||||
|
max_per_img=2000,
|
||||||
|
nms=dict(type='nms', iou_threshold=0.7),
|
||||||
|
min_bbox_size=0),
|
||||||
|
rcnn=[
|
||||||
|
dict(
|
||||||
|
assigner=dict(
|
||||||
|
type='MaxIoUAssigner',
|
||||||
|
pos_iou_thr=0.5,
|
||||||
|
neg_iou_thr=0.5,
|
||||||
|
min_pos_iou=0.5,
|
||||||
|
match_low_quality=False,
|
||||||
|
ignore_iof_thr=-1),
|
||||||
|
sampler=dict(
|
||||||
|
type='RandomSampler',
|
||||||
|
num=512,
|
||||||
|
pos_fraction=0.25,
|
||||||
|
neg_pos_ub=-1,
|
||||||
|
add_gt_as_proposals=True),
|
||||||
|
mask_size=28,
|
||||||
|
pos_weight=-1,
|
||||||
|
debug=False),
|
||||||
|
dict(
|
||||||
|
assigner=dict(
|
||||||
|
type='MaxIoUAssigner',
|
||||||
|
pos_iou_thr=0.6,
|
||||||
|
neg_iou_thr=0.6,
|
||||||
|
min_pos_iou=0.6,
|
||||||
|
match_low_quality=False,
|
||||||
|
ignore_iof_thr=-1),
|
||||||
|
sampler=dict(
|
||||||
|
type='RandomSampler',
|
||||||
|
num=512,
|
||||||
|
pos_fraction=0.25,
|
||||||
|
neg_pos_ub=-1,
|
||||||
|
add_gt_as_proposals=True),
|
||||||
|
mask_size=28,
|
||||||
|
pos_weight=-1,
|
||||||
|
debug=False),
|
||||||
|
dict(
|
||||||
|
assigner=dict(
|
||||||
|
type='MaxIoUAssigner',
|
||||||
|
pos_iou_thr=0.7,
|
||||||
|
neg_iou_thr=0.7,
|
||||||
|
min_pos_iou=0.7,
|
||||||
|
match_low_quality=False,
|
||||||
|
ignore_iof_thr=-1),
|
||||||
|
sampler=dict(
|
||||||
|
type='RandomSampler',
|
||||||
|
num=512,
|
||||||
|
pos_fraction=0.25,
|
||||||
|
neg_pos_ub=-1,
|
||||||
|
add_gt_as_proposals=True),
|
||||||
|
mask_size=28,
|
||||||
|
pos_weight=-1,
|
||||||
|
debug=False)
|
||||||
|
]),
|
||||||
|
test_cfg=dict(
|
||||||
|
rpn=dict(
|
||||||
|
nms_pre=1000,
|
||||||
|
max_per_img=1000,
|
||||||
|
nms=dict(type='nms', iou_threshold=0.7),
|
||||||
|
min_bbox_size=0),
|
||||||
|
rcnn=dict(
|
||||||
|
score_thr=0.05,
|
||||||
|
nms=dict(type='nms', iou_threshold=0.5),
|
||||||
|
max_per_img=100,
|
||||||
|
mask_thr_binary=0.5)))
|
||||||
|
|
||||||
|
mmlab_modules = [
|
||||||
|
dict(type='mmdet', name='CascadeRCNN', module='model'),
|
||||||
|
dict(type='mmdet', name='RPNHead', module='head'),
|
||||||
|
dict(type='mmdet', name='CascadeRoIHead', module='head'),
|
||||||
|
]
|
@ -0,0 +1,4 @@
|
|||||||
|
_base_ = [
|
||||||
|
'./vitdet_cascade_mask_rcnn.py', './lsj_coco_instance.py',
|
||||||
|
'./vitdet_schedule_100e.py'
|
||||||
|
]
|
@ -1,6 +1,6 @@
|
|||||||
# model settings
|
# model settings
|
||||||
|
|
||||||
norm_cfg = dict(type='GN', num_groups=1, requires_grad=True)
|
norm_cfg = dict(type='GN', num_groups=1, eps=1e-6, requires_grad=True)
|
||||||
|
|
||||||
pretrained = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-b-1600/warpper_mae_vit-base-p16-1600e.pth'
|
pretrained = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-b-1600/warpper_mae_vit-base-p16-1600e.pth'
|
||||||
model = dict(
|
model = dict(
|
||||||
@ -9,22 +9,32 @@ model = dict(
|
|||||||
backbone=dict(
|
backbone=dict(
|
||||||
type='ViTDet',
|
type='ViTDet',
|
||||||
img_size=1024,
|
img_size=1024,
|
||||||
|
patch_size=16,
|
||||||
embed_dim=768,
|
embed_dim=768,
|
||||||
depth=12,
|
depth=12,
|
||||||
num_heads=12,
|
num_heads=12,
|
||||||
|
drop_path_rate=0.1,
|
||||||
|
window_size=14,
|
||||||
mlp_ratio=4,
|
mlp_ratio=4,
|
||||||
qkv_bias=True,
|
qkv_bias=True,
|
||||||
qk_scale=None,
|
window_block_indexes=[
|
||||||
drop_rate=0.,
|
# blocks 2, 5, 8, 11 use global attention; the indexes listed here use window attention
|
||||||
attn_drop_rate=0.,
|
0,
|
||||||
drop_path_rate=0.1,
|
1,
|
||||||
use_abs_pos_emb=True,
|
3,
|
||||||
aggregation='attn',
|
4,
|
||||||
),
|
6,
|
||||||
|
7,
|
||||||
|
9,
|
||||||
|
10,
|
||||||
|
],
|
||||||
|
residual_block_indexes=[],
|
||||||
|
use_rel_pos=True),
|
||||||
neck=dict(
|
neck=dict(
|
||||||
type='SFP',
|
type='SFP',
|
||||||
in_channels=[768, 768, 768, 768],
|
in_channels=768,
|
||||||
out_channels=256,
|
out_channels=256,
|
||||||
|
scale_factors=(4.0, 2.0, 1.0, 0.5),
|
||||||
norm_cfg=norm_cfg,
|
norm_cfg=norm_cfg,
|
||||||
num_outs=5),
|
num_outs=5),
|
||||||
rpn_head=dict(
|
rpn_head=dict(
|
||||||
@ -32,7 +42,6 @@ model = dict(
|
|||||||
in_channels=256,
|
in_channels=256,
|
||||||
feat_channels=256,
|
feat_channels=256,
|
||||||
num_convs=2,
|
num_convs=2,
|
||||||
norm_cfg=norm_cfg,
|
|
||||||
anchor_generator=dict(
|
anchor_generator=dict(
|
||||||
type='AnchorGenerator',
|
type='AnchorGenerator',
|
||||||
scales=[8],
|
scales=[8],
|
||||||
@ -98,7 +107,7 @@ model = dict(
|
|||||||
pos_iou_thr=0.5,
|
pos_iou_thr=0.5,
|
||||||
neg_iou_thr=0.5,
|
neg_iou_thr=0.5,
|
||||||
min_pos_iou=0.5,
|
min_pos_iou=0.5,
|
||||||
match_low_quality=True,
|
match_low_quality=False,
|
||||||
ignore_iof_thr=-1),
|
ignore_iof_thr=-1),
|
||||||
sampler=dict(
|
sampler=dict(
|
||||||
type='RandomSampler',
|
type='RandomSampler',
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
_base_ = [
|
_base_ = [
|
||||||
'./vitdet_faster_rcnn.py', './lsj_coco_detection.py',
|
'./vitdet_faster_rcnn.py', './lsj_coco_instance.py',
|
||||||
'./vitdet_schedule_100e.py'
|
'./vitdet_schedule_100e.py'
|
||||||
]
|
]
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# model settings
|
# model settings
|
||||||
|
|
||||||
norm_cfg = dict(type='GN', num_groups=1, requires_grad=True)
|
norm_cfg = dict(type='GN', num_groups=1, eps=1e-6, requires_grad=True)
|
||||||
|
|
||||||
pretrained = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-b-1600/warpper_mae_vit-base-p16-1600e.pth'
|
pretrained = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-b-1600/warpper_mae_vit-base-p16-1600e.pth'
|
||||||
model = dict(
|
model = dict(
|
||||||
@ -9,22 +9,32 @@ model = dict(
|
|||||||
backbone=dict(
|
backbone=dict(
|
||||||
type='ViTDet',
|
type='ViTDet',
|
||||||
img_size=1024,
|
img_size=1024,
|
||||||
|
patch_size=16,
|
||||||
embed_dim=768,
|
embed_dim=768,
|
||||||
depth=12,
|
depth=12,
|
||||||
num_heads=12,
|
num_heads=12,
|
||||||
|
drop_path_rate=0.1,
|
||||||
|
window_size=14,
|
||||||
mlp_ratio=4,
|
mlp_ratio=4,
|
||||||
qkv_bias=True,
|
qkv_bias=True,
|
||||||
qk_scale=None,
|
window_block_indexes=[
|
||||||
drop_rate=0.,
|
# blocks 2, 5, 8, 11 use global attention; the indexes listed here use window attention
|
||||||
attn_drop_rate=0.,
|
0,
|
||||||
drop_path_rate=0.1,
|
1,
|
||||||
use_abs_pos_emb=True,
|
3,
|
||||||
aggregation='attn',
|
4,
|
||||||
),
|
6,
|
||||||
|
7,
|
||||||
|
9,
|
||||||
|
10,
|
||||||
|
],
|
||||||
|
residual_block_indexes=[],
|
||||||
|
use_rel_pos=True),
|
||||||
neck=dict(
|
neck=dict(
|
||||||
type='SFP',
|
type='SFP',
|
||||||
in_channels=[768, 768, 768, 768],
|
in_channels=768,
|
||||||
out_channels=256,
|
out_channels=256,
|
||||||
|
scale_factors=(4.0, 2.0, 1.0, 0.5),
|
||||||
norm_cfg=norm_cfg,
|
norm_cfg=norm_cfg,
|
||||||
num_outs=5),
|
num_outs=5),
|
||||||
rpn_head=dict(
|
rpn_head=dict(
|
||||||
@ -32,7 +42,6 @@ model = dict(
|
|||||||
in_channels=256,
|
in_channels=256,
|
||||||
feat_channels=256,
|
feat_channels=256,
|
||||||
num_convs=2,
|
num_convs=2,
|
||||||
norm_cfg=norm_cfg,
|
|
||||||
anchor_generator=dict(
|
anchor_generator=dict(
|
||||||
type='AnchorGenerator',
|
type='AnchorGenerator',
|
||||||
scales=[8],
|
scales=[8],
|
||||||
@ -112,7 +121,7 @@ model = dict(
|
|||||||
pos_iou_thr=0.5,
|
pos_iou_thr=0.5,
|
||||||
neg_iou_thr=0.5,
|
neg_iou_thr=0.5,
|
||||||
min_pos_iou=0.5,
|
min_pos_iou=0.5,
|
||||||
match_low_quality=True,
|
match_low_quality=False,
|
||||||
ignore_iof_thr=-1),
|
ignore_iof_thr=-1),
|
||||||
sampler=dict(
|
sampler=dict(
|
||||||
type='RandomSampler',
|
type='RandomSampler',
|
||||||
|
@ -1,26 +1,29 @@
|
|||||||
_base_ = 'configs/base.py'
|
_base_ = 'configs/base.py'
|
||||||
|
|
||||||
|
log_config = dict(
|
||||||
|
interval=200,
|
||||||
|
hooks=[
|
||||||
|
dict(type='TextLoggerHook'),
|
||||||
|
# dict(type='TensorboardLoggerHook')
|
||||||
|
])
|
||||||
|
|
||||||
checkpoint_config = dict(interval=10)
|
checkpoint_config = dict(interval=10)
|
||||||
|
|
||||||
# optimizer
|
# optimizer
|
||||||
paramwise_options = {
|
|
||||||
'norm': dict(weight_decay=0.),
|
|
||||||
'bias': dict(weight_decay=0.),
|
|
||||||
'pos_embed': dict(weight_decay=0.),
|
|
||||||
'cls_token': dict(weight_decay=0.)
|
|
||||||
}
|
|
||||||
optimizer = dict(
|
optimizer = dict(
|
||||||
type='AdamW',
|
type='AdamW',
|
||||||
lr=1e-4,
|
lr=1e-4,
|
||||||
betas=(0.9, 0.999),
|
betas=(0.9, 0.999),
|
||||||
weight_decay=0.1,
|
weight_decay=0.1,
|
||||||
paramwise_options=paramwise_options)
|
constructor='LayerDecayOptimizerConstructor',
|
||||||
optimizer_config = dict(grad_clip=None, loss_scale=512.)
|
paramwise_options=dict(num_layers=12, layer_decay_rate=0.7))
|
||||||
|
optimizer_config = dict(grad_clip=None)
|
||||||
# learning policy
|
# learning policy
|
||||||
lr_config = dict(
|
lr_config = dict(
|
||||||
policy='step',
|
policy='step',
|
||||||
warmup='linear',
|
warmup='linear',
|
||||||
warmup_iters=250,
|
warmup_iters=250,
|
||||||
warmup_ratio=0.067,
|
warmup_ratio=0.001,
|
||||||
step=[88, 96])
|
step=[88, 96])
|
||||||
total_epochs = 100
|
total_epochs = 100
|
||||||
|
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:ee64c0caef841c61c7e6344b7fe2c07a38fba07a8de81ff38c0686c641e0a283
|
oid sha256:c696a58a2963b5ac47317751f04ff45bfed4723f2f70bacf91eac711f9710e54
|
||||||
size 190356
|
size 189432
|
||||||
|
@ -22,7 +22,7 @@ Pretrained on COCO2017 dataset. (The result has been optimized with PAI-Blade, a
|
|||||||
|
|
||||||
| Algorithm | Config | Params<br/>(backbone/total) | inference time(V100)<br/>(ms/img) | bbox_mAP<sup>val<br/><sub>0.5:0.95</sub> | mask_mAP<sup>val<br/><sub>0.5:0.95</sub> | Download |
|
| Algorithm | Config | Params<br/>(backbone/total) | inference time(V100)<br/>(ms/img) | bbox_mAP<sup>val<br/><sub>0.5:0.95</sub> | mask_mAP<sup>val<br/><sub>0.5:0.95</sub> | Download |
|
||||||
| ---------- | ------------------------------------------------------------ | ------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
|
| ---------- | ------------------------------------------------------------ | ------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
|
||||||
| ViTDet_MaskRCNN | [vitdet_maskrcnn](https://github.com/alibaba/EasyCV/tree/master/configs/detection/vitdet/vitdet_100e.py) | 88M/118M | 163ms | 50.57 | 44.96 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/vitdet_maskrcnn.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/vitdet_maskrcnn.log.json) |
|
| ViTDet_MaskRCNN | [vitdet_maskrcnn](https://github.com/alibaba/EasyCV/tree/master/configs/detection/vitdet/vitdet_mask_rcnn_100e.py) | 86M/111M | 138ms | 50.65 | 45.41 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/epoch_100.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/20220901_135827.log.json) |
|
||||||
|
|
||||||
## FCOS
|
## FCOS
|
||||||
|
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
# Reference from https://github.com/ViTAE-Transformer/ViTDet/blob/main/mmcv_custom/layer_decay_optimizer_constructor.py
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
|
from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
|
||||||
@ -7,23 +5,32 @@ from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
|
|||||||
from .builder import OPTIMIZER_BUILDERS
|
from .builder import OPTIMIZER_BUILDERS
|
||||||
|
|
||||||
|
|
||||||
def get_num_layer_for_vit(var_name, num_max_layer, layer_sep=None):
|
def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
|
||||||
if var_name in ('backbone.cls_token', 'backbone.mask_token',
|
"""
|
||||||
'backbone.pos_embed'):
|
Calculate lr decay rate for different ViT blocks.
|
||||||
return 0
|
Reference from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py
|
||||||
elif var_name.startswith('backbone.patch_embed'):
|
Args:
|
||||||
return 0
|
name (string): parameter name.
|
||||||
elif var_name.startswith('backbone.blocks'):
|
lr_decay_rate (float): base lr decay rate.
|
||||||
layer_id = int(var_name.split('.')[2])
|
num_layers (int): number of ViT blocks.
|
||||||
return layer_id + 1
|
Returns:
|
||||||
else:
|
tuple (layer_id, scale): the layer index assigned to the parameter and the lr decay rate for it.
|
||||||
return num_max_layer - 1
|
"""
|
||||||
|
layer_id = num_layers + 1
|
||||||
|
if '.pos_embed' in name or '.patch_embed' in name:
|
||||||
|
layer_id = 0
|
||||||
|
elif '.blocks.' in name and '.residual.' not in name:
|
||||||
|
layer_id = int(name[name.find('.blocks.'):].split('.')[2]) + 1
|
||||||
|
|
||||||
|
scale = lr_decay_rate**(num_layers + 1 - layer_id)
|
||||||
|
|
||||||
|
return layer_id, scale
|
||||||
|
|
||||||
|
|
||||||
@OPTIMIZER_BUILDERS.register_module()
|
@OPTIMIZER_BUILDERS.register_module()
|
||||||
class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor):
|
class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor):
|
||||||
|
|
||||||
def add_params(self, params, module, prefix='', is_dcn_module=None):
|
def add_params(self, params, module):
|
||||||
"""Add all parameters of module to the params list.
|
"""Add all parameters of module to the params list.
|
||||||
The parameters of the given module will be added to the list of param
|
The parameters of the given module will be added to the list of param
|
||||||
groups, with specific rules defined by paramwise_cfg.
|
groups, with specific rules defined by paramwise_cfg.
|
||||||
@ -31,54 +38,41 @@ class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor):
|
|||||||
params (list[dict]): A list of param groups, it will be modified
|
params (list[dict]): A list of param groups, it will be modified
|
||||||
in place.
|
in place.
|
||||||
module (nn.Module): The module to be added.
|
module (nn.Module): The module to be added.
|
||||||
prefix (str): The prefix of the module
|
|
||||||
is_dcn_module (int|float|None): If the current module is a
|
Reference from https://github.com/ViTAE-Transformer/ViTDet/blob/main/mmcv_custom/layer_decay_optimizer_constructor.py
|
||||||
submodule of DCN, `is_dcn_module` will be passed to
|
Note: Currently, this optimizer constructor is built for ViTDet.
|
||||||
control conv_offset layer's learning rate. Defaults to None.
|
|
||||||
"""
|
"""
|
||||||
# get param-wise options
|
|
||||||
|
|
||||||
parameter_groups = {}
|
parameter_groups = {}
|
||||||
print(self.paramwise_cfg)
|
print(self.paramwise_cfg)
|
||||||
num_layers = self.paramwise_cfg.get('num_layers') + 2
|
lr_decay_rate = self.paramwise_cfg.get('layer_decay_rate')
|
||||||
layer_sep = self.paramwise_cfg.get('layer_sep', None)
|
num_layers = self.paramwise_cfg.get('num_layers')
|
||||||
layer_decay_rate = self.paramwise_cfg.get('layer_decay_rate')
|
|
||||||
print('Build LayerDecayOptimizerConstructor %f - %d' %
|
print('Build LayerDecayOptimizerConstructor %f - %d' %
|
||||||
(layer_decay_rate, num_layers))
|
(lr_decay_rate, num_layers))
|
||||||
|
lr = self.base_lr
|
||||||
weight_decay = self.base_wd
|
weight_decay = self.base_wd
|
||||||
|
|
||||||
custom_keys = self.paramwise_cfg.get('custom_keys', {})
|
|
||||||
# first sort with alphabet order and then sort with reversed len of str
|
|
||||||
sorted_keys = sorted(custom_keys.keys())
|
|
||||||
|
|
||||||
for name, param in module.named_parameters():
|
for name, param in module.named_parameters():
|
||||||
|
|
||||||
if not param.requires_grad:
|
if not param.requires_grad:
|
||||||
continue # frozen weights
|
continue # frozen weights
|
||||||
|
|
||||||
if len(param.shape) == 1 or name.endswith('.bias') or (
|
if 'backbone' in name and ('.norm' in name or '.pos_embed' in name
|
||||||
'pos_embed' in name) or ('cls_token'
|
or '.gn.' in name or '.ln.' in name):
|
||||||
in name) or ('rel_pos_' in name):
|
|
||||||
group_name = 'no_decay'
|
group_name = 'no_decay'
|
||||||
this_weight_decay = 0.
|
this_weight_decay = 0.
|
||||||
else:
|
else:
|
||||||
group_name = 'decay'
|
group_name = 'decay'
|
||||||
this_weight_decay = weight_decay
|
this_weight_decay = weight_decay
|
||||||
|
|
||||||
layer_id = get_num_layer_for_vit(name, num_layers, layer_sep)
|
if name.startswith('backbone'):
|
||||||
|
layer_id, scale = get_vit_lr_decay_rate(
|
||||||
|
name, lr_decay_rate=lr_decay_rate, num_layers=num_layers)
|
||||||
|
else:
|
||||||
|
layer_id, scale = -1, 1
|
||||||
group_name = 'layer_%d_%s' % (layer_id, group_name)
|
group_name = 'layer_%d_%s' % (layer_id, group_name)
|
||||||
|
|
||||||
# if the parameter match one of the custom keys, ignore other rules
|
|
||||||
this_lr_multi = 1.
|
|
||||||
for key in sorted_keys:
|
|
||||||
if key in f'{name}':
|
|
||||||
lr_mult = custom_keys[key].get('lr_mult', 1.)
|
|
||||||
this_lr_multi = lr_mult
|
|
||||||
group_name = '%s_%s' % (group_name, key)
|
|
||||||
break
|
|
||||||
|
|
||||||
if group_name not in parameter_groups:
|
if group_name not in parameter_groups:
|
||||||
scale = layer_decay_rate**(num_layers - layer_id - 1)
|
|
||||||
|
|
||||||
parameter_groups[group_name] = {
|
parameter_groups[group_name] = {
|
||||||
'weight_decay': this_weight_decay,
|
'weight_decay': this_weight_decay,
|
||||||
@ -86,7 +80,7 @@ class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor):
|
|||||||
'param_names': [],
|
'param_names': [],
|
||||||
'lr_scale': scale,
|
'lr_scale': scale,
|
||||||
'group_name': group_name,
|
'group_name': group_name,
|
||||||
'lr': scale * self.base_lr * this_lr_multi,
|
'lr': scale * lr,
|
||||||
}
|
}
|
||||||
|
|
||||||
parameter_groups[group_name]['params'].append(param)
|
parameter_groups[group_name]['params'].append(param)
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -37,7 +37,6 @@ class FPN(nn.Module):
|
|||||||
Default: None.
|
Default: None.
|
||||||
upsample_cfg (dict): Config dict for interpolate layer.
|
upsample_cfg (dict): Config dict for interpolate layer.
|
||||||
Default: dict(mode='nearest').
|
Default: dict(mode='nearest').
|
||||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
|
||||||
Example:
|
Example:
|
||||||
>>> import torch
|
>>> import torch
|
||||||
>>> in_channels = [2, 3, 5, 7]
|
>>> in_channels = [2, 3, 5, 7]
|
||||||
@ -67,8 +66,6 @@ class FPN(nn.Module):
|
|||||||
norm_cfg=None,
|
norm_cfg=None,
|
||||||
act_cfg=None,
|
act_cfg=None,
|
||||||
upsample_cfg=dict(mode='nearest')):
|
upsample_cfg=dict(mode='nearest')):
|
||||||
# init_cfg=dict(
|
|
||||||
# type='Xavier', layer='Conv2d', distribution='uniform')):
|
|
||||||
super(FPN, self).__init__()
|
super(FPN, self).__init__()
|
||||||
assert isinstance(in_channels, list)
|
assert isinstance(in_channels, list)
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
|
@ -2,26 +2,12 @@
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from mmcv.cnn import ConvModule
|
from mmcv.cnn import ConvModule
|
||||||
from mmcv.runner import BaseModule
|
|
||||||
|
|
||||||
from easycv.models.builder import NECKS
|
from easycv.models.builder import NECKS
|
||||||
|
|
||||||
|
|
||||||
class Norm2d(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, embed_dim):
|
|
||||||
super().__init__()
|
|
||||||
self.ln = nn.LayerNorm(embed_dim, eps=1e-6)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = x.permute(0, 2, 3, 1)
|
|
||||||
x = self.ln(x)
|
|
||||||
x = x.permute(0, 3, 1, 2).contiguous()
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
@NECKS.register_module()
|
@NECKS.register_module()
|
||||||
class SFP(BaseModule):
|
class SFP(nn.Module):
|
||||||
r"""Simple Feature Pyramid.
|
r"""Simple Feature Pyramid.
|
||||||
This is an implementation of paper `Exploring Plain Vision Transformer Backbones for Object Detection <https://arxiv.org/abs/2203.16527>`_.
|
This is an implementation of paper `Exploring Plain Vision Transformer Backbones for Object Detection <https://arxiv.org/abs/2203.16527>`_.
|
||||||
Args:
|
Args:
|
||||||
@ -32,25 +18,12 @@ class SFP(BaseModule):
|
|||||||
build the feature pyramid. Default: 0.
|
build the feature pyramid. Default: 0.
|
||||||
end_level (int): Index of the end input backbone level (exclusive) to
|
end_level (int): Index of the end input backbone level (exclusive) to
|
||||||
build the feature pyramid. Default: -1, which means the last level.
|
build the feature pyramid. Default: -1, which means the last level.
|
||||||
add_extra_convs (bool | str): If bool, it decides whether to add conv
|
|
||||||
layers on top of the original feature maps. Default to False.
|
|
||||||
If True, it is equivalent to `add_extra_convs='on_input'`.
|
|
||||||
If str, it specifies the source feature map of the extra convs.
|
|
||||||
Only the following options are allowed
|
|
||||||
- 'on_input': Last feat map of neck inputs (i.e. backbone feature).
|
|
||||||
- 'on_lateral': Last feature map after lateral convs.
|
|
||||||
- 'on_output': The last output feature map after fpn convs.
|
|
||||||
relu_before_extra_convs (bool): Whether to apply relu before the extra
|
|
||||||
conv. Default: False.
|
conv. Default: False.
|
||||||
no_norm_on_lateral (bool): Whether to apply norm on lateral.
|
|
||||||
Default: False.
|
Default: False.
|
||||||
conv_cfg (dict): Config dict for convolution layer. Default: None.
|
conv_cfg (dict): Config dict for convolution layer. Default: None.
|
||||||
norm_cfg (dict): Config dict for normalization layer. Default: None.
|
norm_cfg (dict): Config dict for normalization layer. Default: None.
|
||||||
act_cfg (str): Config dict for activation layer in ConvModule.
|
act_cfg (str): Config dict for activation layer in ConvModule.
|
||||||
Default: None.
|
Default: None.
|
||||||
upsample_cfg (dict): Config dict for interpolate layer.
|
|
||||||
Default: `dict(mode='nearest')`
|
|
||||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
|
||||||
Example:
|
Example:
|
||||||
>>> import torch
|
>>> import torch
|
||||||
>>> in_channels = [2, 3, 5, 7]
|
>>> in_channels = [2, 3, 5, 7]
|
||||||
@ -70,81 +43,53 @@ class SFP(BaseModule):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
in_channels,
|
in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
|
scale_factors,
|
||||||
num_outs,
|
num_outs,
|
||||||
start_level=0,
|
|
||||||
end_level=-1,
|
|
||||||
add_extra_convs=False,
|
|
||||||
relu_before_extra_convs=False,
|
|
||||||
no_norm_on_lateral=False,
|
|
||||||
conv_cfg=None,
|
conv_cfg=None,
|
||||||
norm_cfg=None,
|
norm_cfg=None,
|
||||||
act_cfg=None,
|
act_cfg=None):
|
||||||
upsample_cfg=dict(mode='nearest'),
|
super(SFP, self).__init__()
|
||||||
init_cfg=[
|
dim = in_channels
|
||||||
dict(
|
|
||||||
type='Xavier',
|
|
||||||
layer=['Conv2d'],
|
|
||||||
distribution='uniform'),
|
|
||||||
dict(type='Constant', layer=['LayerNorm'], val=1, bias=0)
|
|
||||||
]):
|
|
||||||
super(SFP, self).__init__(init_cfg)
|
|
||||||
assert isinstance(in_channels, list)
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
self.num_ins = len(in_channels)
|
self.scale_factors = scale_factors
|
||||||
|
self.num_ins = len(scale_factors)
|
||||||
self.num_outs = num_outs
|
self.num_outs = num_outs
|
||||||
self.relu_before_extra_convs = relu_before_extra_convs
|
|
||||||
self.no_norm_on_lateral = no_norm_on_lateral
|
|
||||||
self.upsample_cfg = upsample_cfg.copy()
|
|
||||||
|
|
||||||
if end_level == -1:
|
self.stages = []
|
||||||
self.backbone_end_level = self.num_ins
|
for idx, scale in enumerate(scale_factors):
|
||||||
assert num_outs >= self.num_ins - start_level
|
out_dim = dim
|
||||||
|
if scale == 4.0:
|
||||||
|
layers = [
|
||||||
|
nn.ConvTranspose2d(dim, dim // 2, 2, stride=2, padding=0),
|
||||||
|
nn.GroupNorm(1, dim // 2, eps=1e-6),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.ConvTranspose2d(
|
||||||
|
dim // 2, dim // 4, 2, stride=2, padding=0)
|
||||||
|
]
|
||||||
|
out_dim = dim // 4
|
||||||
|
elif scale == 2.0:
|
||||||
|
layers = [
|
||||||
|
nn.ConvTranspose2d(dim, dim // 2, 2, stride=2, padding=0)
|
||||||
|
]
|
||||||
|
out_dim = dim // 2
|
||||||
|
elif scale == 1.0:
|
||||||
|
layers = []
|
||||||
|
elif scale == 0.5:
|
||||||
|
layers = [nn.MaxPool2d(kernel_size=2, stride=2, padding=0)]
|
||||||
else:
|
else:
|
||||||
# if end_level < inputs, no extra level is allowed
|
raise NotImplementedError(
|
||||||
self.backbone_end_level = end_level
|
f'scale_factor={scale} is not supported yet.')
|
||||||
assert end_level <= len(in_channels)
|
|
||||||
assert num_outs == end_level - start_level
|
|
||||||
self.start_level = start_level
|
|
||||||
self.end_level = end_level
|
|
||||||
self.add_extra_convs = add_extra_convs
|
|
||||||
assert isinstance(add_extra_convs, (str, bool))
|
|
||||||
if isinstance(add_extra_convs, str):
|
|
||||||
# Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
|
|
||||||
assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
|
|
||||||
elif add_extra_convs: # True
|
|
||||||
self.add_extra_convs = 'on_input'
|
|
||||||
|
|
||||||
self.top_downs = nn.ModuleList()
|
layers.extend([
|
||||||
self.lateral_convs = nn.ModuleList()
|
ConvModule(
|
||||||
self.fpn_convs = nn.ModuleList()
|
out_dim,
|
||||||
|
|
||||||
for i in range(self.start_level, self.backbone_end_level):
|
|
||||||
if i == 0:
|
|
||||||
top_down = nn.Sequential(
|
|
||||||
nn.ConvTranspose2d(
|
|
||||||
in_channels[i], in_channels[i], 2, stride=2,
|
|
||||||
padding=0), Norm2d(in_channels[i]), nn.GELU(),
|
|
||||||
nn.ConvTranspose2d(
|
|
||||||
in_channels[i], in_channels[i], 2, stride=2,
|
|
||||||
padding=0))
|
|
||||||
elif i == 1:
|
|
||||||
top_down = nn.ConvTranspose2d(
|
|
||||||
in_channels[i], in_channels[i], 2, stride=2, padding=0)
|
|
||||||
elif i == 2:
|
|
||||||
top_down = nn.Identity()
|
|
||||||
elif i == 3:
|
|
||||||
top_down = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
|
|
||||||
|
|
||||||
l_conv = ConvModule(
|
|
||||||
in_channels[i],
|
|
||||||
out_channels,
|
out_channels,
|
||||||
1,
|
1,
|
||||||
conv_cfg=conv_cfg,
|
conv_cfg=conv_cfg,
|
||||||
norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
|
norm_cfg=norm_cfg,
|
||||||
act_cfg=act_cfg,
|
act_cfg=act_cfg,
|
||||||
inplace=False)
|
inplace=False),
|
||||||
fpn_conv = ConvModule(
|
ConvModule(
|
||||||
out_channels,
|
out_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
3,
|
3,
|
||||||
@ -153,75 +98,28 @@ class SFP(BaseModule):
|
|||||||
norm_cfg=norm_cfg,
|
norm_cfg=norm_cfg,
|
||||||
act_cfg=act_cfg,
|
act_cfg=act_cfg,
|
||||||
inplace=False)
|
inplace=False)
|
||||||
|
])
|
||||||
|
|
||||||
self.top_downs.append(top_down)
|
layers = nn.Sequential(*layers)
|
||||||
self.lateral_convs.append(l_conv)
|
self.add_module(f'sfp_{idx}', layers)
|
||||||
self.fpn_convs.append(fpn_conv)
|
self.stages.append(layers)
|
||||||
|
|
||||||
# add extra conv layers (e.g., RetinaNet)
|
def init_weights(self):
|
||||||
extra_levels = num_outs - self.backbone_end_level + self.start_level
|
pass
|
||||||
if self.add_extra_convs and extra_levels >= 1:
|
|
||||||
for i in range(extra_levels):
|
|
||||||
if i == 0 and self.add_extra_convs == 'on_input':
|
|
||||||
in_channels = self.in_channels[self.backbone_end_level - 1]
|
|
||||||
else:
|
|
||||||
in_channels = out_channels
|
|
||||||
extra_fpn_conv = ConvModule(
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
3,
|
|
||||||
stride=2,
|
|
||||||
padding=1,
|
|
||||||
conv_cfg=conv_cfg,
|
|
||||||
norm_cfg=norm_cfg,
|
|
||||||
act_cfg=act_cfg,
|
|
||||||
inplace=False)
|
|
||||||
self.fpn_convs.append(extra_fpn_conv)
|
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs):
|
||||||
"""Forward function."""
|
"""Forward function."""
|
||||||
assert len(inputs) == 1
|
features = inputs[0]
|
||||||
|
outs = []
|
||||||
|
|
||||||
# build top-down path
|
# part 1: build simple feature pyramid
|
||||||
features = [
|
for stage in self.stages:
|
||||||
top_down(inputs[0]) for _, top_down in enumerate(self.top_downs)
|
outs.append(stage(features))
|
||||||
]
|
|
||||||
assert len(features) == len(self.in_channels)
|
|
||||||
|
|
||||||
# build laterals
|
|
||||||
laterals = [
|
|
||||||
lateral_conv(features[i + self.start_level])
|
|
||||||
for i, lateral_conv in enumerate(self.lateral_convs)
|
|
||||||
]
|
|
||||||
|
|
||||||
used_backbone_levels = len(laterals)
|
|
||||||
|
|
||||||
# build outputs
|
|
||||||
# part 1: from original levels
|
|
||||||
outs = [
|
|
||||||
self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
|
|
||||||
]
|
|
||||||
# part 2: add extra levels
|
# part 2: add extra levels
|
||||||
if self.num_outs > len(outs):
|
if self.num_outs > self.num_ins:
|
||||||
# use max pool to get more levels on top of outputs
|
# use max pool to get more levels on top of outputs
|
||||||
# (e.g., Faster R-CNN, Mask R-CNN)
|
# (e.g., Faster R-CNN, Mask R-CNN)
|
||||||
if not self.add_extra_convs:
|
for i in range(self.num_outs - self.num_ins):
|
||||||
for i in range(self.num_outs - used_backbone_levels):
|
|
||||||
outs.append(F.max_pool2d(outs[-1], 1, stride=2))
|
outs.append(F.max_pool2d(outs[-1], 1, stride=2))
|
||||||
# add conv layers on top of original feature maps (RetinaNet)
|
|
||||||
else:
|
|
||||||
if self.add_extra_convs == 'on_input':
|
|
||||||
extra_source = inputs[self.backbone_end_level - 1]
|
|
||||||
elif self.add_extra_convs == 'on_lateral':
|
|
||||||
extra_source = laterals[-1]
|
|
||||||
elif self.add_extra_convs == 'on_output':
|
|
||||||
extra_source = outs[-1]
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
outs.append(self.fpn_convs[used_backbone_levels](extra_source))
|
|
||||||
for i in range(used_backbone_levels + 1, self.num_outs):
|
|
||||||
if self.relu_before_extra_convs:
|
|
||||||
outs.append(self.fpn_convs[i](F.relu(outs[-1])))
|
|
||||||
else:
|
|
||||||
outs.append(self.fpn_convs[i](outs[-1]))
|
|
||||||
return tuple(outs)
|
return tuple(outs)
|
||||||
|
@ -253,11 +253,11 @@ class DetrPredictor(PredictorInterface):
|
|||||||
img,
|
img,
|
||||||
bboxes,
|
bboxes,
|
||||||
labels=labels,
|
labels=labels,
|
||||||
colors='green',
|
colors='cyan',
|
||||||
text_color='white',
|
text_color='cyan',
|
||||||
font_size=20,
|
font_size=18,
|
||||||
thickness=1,
|
thickness=2,
|
||||||
font_scale=0.5,
|
font_scale=0.0,
|
||||||
show=show,
|
show=show,
|
||||||
out_file=out_file)
|
out_file=out_file)
|
||||||
|
|
||||||
|
@ -14,18 +14,27 @@ class ViTDetTest(unittest.TestCase):
|
|||||||
def test_vitdet(self):
|
def test_vitdet(self):
|
||||||
model = ViTDet(
|
model = ViTDet(
|
||||||
img_size=1024,
|
img_size=1024,
|
||||||
|
patch_size=16,
|
||||||
embed_dim=768,
|
embed_dim=768,
|
||||||
depth=12,
|
depth=12,
|
||||||
num_heads=12,
|
num_heads=12,
|
||||||
|
drop_path_rate=0.1,
|
||||||
|
window_size=14,
|
||||||
mlp_ratio=4,
|
mlp_ratio=4,
|
||||||
qkv_bias=True,
|
qkv_bias=True,
|
||||||
qk_scale=None,
|
window_block_indexes=[
|
||||||
drop_rate=0.,
|
# 2, 5, 8 11 for global attention
|
||||||
attn_drop_rate=0.,
|
0,
|
||||||
drop_path_rate=0.1,
|
1,
|
||||||
use_abs_pos_emb=True,
|
3,
|
||||||
aggregation='attn',
|
4,
|
||||||
)
|
6,
|
||||||
|
7,
|
||||||
|
9,
|
||||||
|
10,
|
||||||
|
],
|
||||||
|
residual_block_indexes=[],
|
||||||
|
use_rel_pos=True)
|
||||||
|
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
|
@ -155,7 +155,7 @@ class DetectorTest(unittest.TestCase):
|
|||||||
decimal=1)
|
decimal=1)
|
||||||
|
|
||||||
def test_vitdet_detector(self):
|
def test_vitdet_detector(self):
|
||||||
model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/vitdet_maskrcnn_export.pth'
|
model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/epoch_100_export.pth'
|
||||||
img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'
|
img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'
|
||||||
out_file = './result.jpg'
|
out_file = './result.jpg'
|
||||||
vitdet = DetrPredictor(model_path)
|
vitdet = DetrPredictor(model_path)
|
||||||
@ -167,63 +167,170 @@ class DetectorTest(unittest.TestCase):
|
|||||||
self.assertIn('detection_classes', output)
|
self.assertIn('detection_classes', output)
|
||||||
self.assertIn('detection_masks', output)
|
self.assertIn('detection_masks', output)
|
||||||
self.assertIn('img_metas', output)
|
self.assertIn('img_metas', output)
|
||||||
self.assertEqual(len(output['detection_boxes'][0]), 30)
|
self.assertEqual(len(output['detection_boxes'][0]), 33)
|
||||||
self.assertEqual(len(output['detection_scores'][0]), 30)
|
self.assertEqual(len(output['detection_scores'][0]), 33)
|
||||||
self.assertEqual(len(output['detection_classes'][0]), 30)
|
self.assertEqual(len(output['detection_classes'][0]), 33)
|
||||||
|
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
output['detection_classes'][0].tolist(),
|
output['detection_classes'][0].tolist(),
|
||||||
np.array([
|
np.array([
|
||||||
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
|
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
|
||||||
2, 2, 2, 7, 7, 13, 13, 13, 56
|
2, 2, 2, 2, 2, 2, 7, 7, 13, 13, 13, 56
|
||||||
],
|
],
|
||||||
dtype=np.int32).tolist())
|
dtype=np.int32).tolist())
|
||||||
|
|
||||||
assert_array_almost_equal(
|
assert_array_almost_equal(
|
||||||
output['detection_scores'][0],
|
output['detection_scores'][0],
|
||||||
np.array([
|
np.array([
|
||||||
0.99791867, 0.99665856, 0.99480623, 0.99060905, 0.9882515,
|
0.9975854158401489, 0.9965696334838867, 0.9922919869422913,
|
||||||
0.98319584, 0.9738879, 0.97290784, 0.9514897, 0.95104814,
|
0.9833580851554871, 0.983080267906189, 0.970454752445221,
|
||||||
0.9321701, 0.86165, 0.8228847, 0.7623552, 0.76129806,
|
0.9701289534568787, 0.9649872183799744, 0.9642795324325562,
|
||||||
0.6050861, 0.44348577, 0.3452973, 0.2895671, 0.22109479,
|
0.9642238020896912, 0.9529680609703064, 0.9403366446495056,
|
||||||
0.21265312, 0.17855245, 0.1205352, 0.08981906, 0.10596471,
|
0.9391788244247437, 0.8941807150840759, 0.8178097009658813,
|
||||||
0.05854294, 0.99749386, 0.9472857, 0.5945908, 0.09855112
|
0.8013413548469543, 0.6677654385566711, 0.3952914774417877,
|
||||||
|
0.33463895320892334, 0.32501447200775146, 0.27323535084724426,
|
||||||
|
0.20197080075740814, 0.15607696771621704, 0.1068163588643074,
|
||||||
|
0.10183875262737274, 0.09735643863677979, 0.06559795141220093,
|
||||||
|
0.08890066295862198, 0.076363705098629, 0.9954648613929749,
|
||||||
|
0.9212945699691772, 0.5224372148513794, 0.20555885136127472
|
||||||
],
|
],
|
||||||
dtype=np.float32),
|
dtype=np.float32),
|
||||||
decimal=2)
|
decimal=2)
|
||||||
|
|
||||||
assert_array_almost_equal(
|
assert_array_almost_equal(
|
||||||
output['detection_boxes'][0],
|
output['detection_boxes'][0],
|
||||||
np.array([[294.7058, 117.29371, 378.83713, 149.99928],
|
np.array([[
|
||||||
[609.05444, 112.526474, 633.2971, 136.35175],
|
294.22674560546875, 116.6078109741211, 379.4328918457031,
|
||||||
[481.4165, 110.987335, 522.5531, 130.01529],
|
150.14097595214844
|
||||||
[167.68184, 109.89049, 215.49057, 139.86987],
|
],
|
||||||
[374.75082, 110.68697, 433.10028, 136.23654],
|
[
|
||||||
[189.54971, 110.09322, 297.6167, 155.77412],
|
482.6017761230469, 110.75955963134766,
|
||||||
[266.5185, 105.37718, 326.54385, 127.916374],
|
522.8798828125, 129.71286010742188
|
||||||
[556.30225, 110.43166, 592.8248, 128.03764],
|
],
|
||||||
[432.49252, 105.086464, 484.0512, 132.272],
|
[
|
||||||
[0., 110.566444, 62.01249, 146.44017],
|
167.06460571289062, 109.95974731445312,
|
||||||
[591.74664, 110.43527, 619.73816, 126.68549],
|
212.83975219726562, 140.16102600097656
|
||||||
[99.126854, 90.947975, 118.46699, 101.11096],
|
],
|
||||||
[59.895264, 94.110054, 85.60521, 106.67633],
|
[
|
||||||
[142.95819, 96.61966, 165.96964, 104.95929],
|
609.2930908203125, 113.13909149169922,
|
||||||
[83.062515, 89.802605, 99.1546, 98.69074],
|
637.3115844726562, 136.4690704345703
|
||||||
[226.28802, 98.32568, 249.06772, 108.86408],
|
],
|
||||||
[136.67789, 94.75706, 154.62924, 104.289536],
|
[
|
||||||
[170.42459, 98.458694, 183.16309, 106.203156],
|
191.185791015625, 111.1408920288086, 301.31689453125,
|
||||||
[67.56731, 89.68286, 82.62955, 98.35645],
|
155.7731170654297
|
||||||
[222.80092, 97.828445, 239.02655, 108.29377],
|
],
|
||||||
[134.34427, 92.31653, 149.19615, 102.97457],
|
[
|
||||||
[613.5186, 102.27066, 636.0434, 112.813644],
|
431.2244873046875, 106.19962310791016,
|
||||||
[607.4787, 110.87984, 630.1123, 127.65646],
|
483.860595703125, 132.21627807617188
|
||||||
[135.13664, 90.989876, 155.67192, 100.18036],
|
],
|
||||||
[431.61505, 105.43844, 484.36508, 132.50078],
|
[
|
||||||
[189.92722, 110.38832, 297.74353, 155.95557],
|
267.48358154296875, 105.5920639038086,
|
||||||
[220.67035, 177.13489, 455.32092, 380.45712],
|
325.2832336425781, 127.11176300048828
|
||||||
[372.76584, 134.33807, 432.44357, 188.51534],
|
],
|
||||||
[50.403812, 110.543495, 70.4368, 119.65186],
|
[
|
||||||
[373.50272, 134.27258, 432.18475, 187.81824]]),
|
591.2138671875, 110.29329681396484,
|
||||||
|
619.8524169921875, 126.1990966796875
|
||||||
|
],
|
||||||
|
[
|
||||||
|
0.0, 110.7026596069336, 61.487945556640625,
|
||||||
|
146.33018493652344
|
||||||
|
],
|
||||||
|
[
|
||||||
|
555.9155883789062, 110.03486633300781,
|
||||||
|
591.7050170898438, 127.06097412109375
|
||||||
|
],
|
||||||
|
[
|
||||||
|
60.24559783935547, 94.12760162353516,
|
||||||
|
85.63741302490234, 106.66705322265625
|
||||||
|
],
|
||||||
|
[
|
||||||
|
99.02665710449219, 90.53657531738281,
|
||||||
|
118.83953094482422, 101.18717956542969
|
||||||
|
],
|
||||||
|
[
|
||||||
|
396.30438232421875, 111.59194946289062,
|
||||||
|
431.559814453125, 133.96914672851562
|
||||||
|
],
|
||||||
|
[
|
||||||
|
83.81543731689453, 89.65665435791016,
|
||||||
|
99.9166259765625, 98.25627899169922
|
||||||
|
],
|
||||||
|
[
|
||||||
|
139.29647827148438, 96.68000793457031,
|
||||||
|
165.22410583496094, 105.60000610351562
|
||||||
|
],
|
||||||
|
[
|
||||||
|
67.27152252197266, 89.42798614501953,
|
||||||
|
83.25617980957031, 98.0460205078125
|
||||||
|
],
|
||||||
|
[
|
||||||
|
223.74176025390625, 98.68321990966797,
|
||||||
|
250.42506408691406, 109.32588958740234
|
||||||
|
],
|
||||||
|
[
|
||||||
|
136.7582244873047, 96.51412963867188,
|
||||||
|
152.51190185546875, 104.73160552978516
|
||||||
|
],
|
||||||
|
[
|
||||||
|
221.71812438964844, 97.86445617675781,
|
||||||
|
238.9705810546875, 106.96803283691406
|
||||||
|
],
|
||||||
|
[
|
||||||
|
135.06964111328125, 91.80916595458984, 155.24609375,
|
||||||
|
102.20686340332031
|
||||||
|
],
|
||||||
|
[
|
||||||
|
169.11180114746094, 97.53628540039062,
|
||||||
|
182.88504028320312, 105.95404815673828
|
||||||
|
],
|
||||||
|
[
|
||||||
|
133.8811798095703, 91.00375366210938,
|
||||||
|
145.35507202148438, 102.3780288696289
|
||||||
|
],
|
||||||
|
[
|
||||||
|
614.2507934570312, 102.19828796386719,
|
||||||
|
636.5692749023438, 112.59198760986328
|
||||||
|
],
|
||||||
|
[
|
||||||
|
35.94759750366211, 91.7213363647461,
|
||||||
|
70.38274383544922, 117.19855499267578
|
||||||
|
],
|
||||||
|
[
|
||||||
|
554.6401977539062, 115.18976593017578,
|
||||||
|
562.0255737304688, 127.4429931640625
|
||||||
|
],
|
||||||
|
[
|
||||||
|
39.07550811767578, 92.73261260986328,
|
||||||
|
85.36636352539062, 106.73953247070312
|
||||||
|
],
|
||||||
|
[
|
||||||
|
200.85513305664062, 93.00469970703125,
|
||||||
|
219.73086547851562, 107.99642181396484
|
||||||
|
],
|
||||||
|
[
|
||||||
|
0.0, 111.18904876708984, 61.7393684387207,
|
||||||
|
146.72547912597656
|
||||||
|
],
|
||||||
|
[
|
||||||
|
191.88568115234375, 111.09577178955078,
|
||||||
|
299.4097900390625, 155.14639282226562
|
||||||
|
],
|
||||||
|
[
|
||||||
|
221.06834411621094, 176.6427001953125,
|
||||||
|
458.3475341796875, 378.89300537109375
|
||||||
|
],
|
||||||
|
[
|
||||||
|
372.7131652832031, 135.51429748535156,
|
||||||
|
433.2494201660156, 188.0106658935547
|
||||||
|
],
|
||||||
|
[
|
||||||
|
52.19819641113281, 110.3646011352539,
|
||||||
|
70.95110321044922, 120.10567474365234
|
||||||
|
],
|
||||||
|
[
|
||||||
|
376.1671447753906, 133.6930694580078,
|
||||||
|
432.2721862792969, 187.99481201171875
|
||||||
|
]]),
|
||||||
decimal=1)
|
decimal=1)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user