Refactor ViTDet backbone and simple feature pyramid (#177)

1. The ViTDet backbone implemented by detectron2 (d2) is about 20% faster than the ViTDet backbone originally reproduced by EasyCV, so the backbone and the simple feature pyramid are refactored to follow the d2 implementation.
2. bbox mAP improves from 50.57 to 50.65 (mask mAP from 44.96 to 45.41).
tuofeilun 2022-09-16 11:03:53 +08:00 committed by GitHub
parent 1d5edf6d78
commit 9f01a37ad4
20 changed files with 925 additions and 995 deletions

@@ -101,13 +101,15 @@ val_dataset = dict(
     pipeline=test_pipeline)

 data = dict(
-    imgs_per_gpu=1, workers_per_gpu=2, train=train_dataset, val=val_dataset)
+    imgs_per_gpu=4, workers_per_gpu=2, train=train_dataset, val=val_dataset
+)  # 64(total batch size) = 4 (batch size/per gpu) x 8 (gpu num) x 2(node)

 # evaluation
-eval_config = dict(interval=1, gpu_collect=False)
+eval_config = dict(initial=False, interval=1, gpu_collect=False)
 eval_pipelines = [
     dict(
         mode='test',
+        # dist_eval=True,
         evaluators=[
             dict(type='CocoDetectionEvaluator', classes=CLASSES),
         ],

@@ -101,13 +101,15 @@ val_dataset = dict(
     pipeline=test_pipeline)

 data = dict(
-    imgs_per_gpu=1, workers_per_gpu=2, train=train_dataset, val=val_dataset)
+    imgs_per_gpu=4, workers_per_gpu=2, train=train_dataset, val=val_dataset
+)  # 64(total batch size) = 4 (batch size/per gpu) x 8 (gpu num) x 2(node)

 # evaluation
-eval_config = dict(interval=1, gpu_collect=False)
+eval_config = dict(initial=False, interval=1, gpu_collect=False)
 eval_pipelines = [
     dict(
         mode='test',
+        # dist_eval=True,
         evaluators=[
             dict(type='CocoDetectionEvaluator', classes=CLASSES),
             dict(type='CocoMaskEvaluator', classes=CLASSES)

@@ -1,3 +0,0 @@
-_base_ = './vitdet_100e.py'
-
-model = dict(backbone=dict(aggregation='basicblock'))

@@ -1,3 +0,0 @@
-_base_ = './vitdet_100e.py'
-
-model = dict(backbone=dict(aggregation='bottleneck'))

@@ -0,0 +1,231 @@
# model settings
norm_cfg = dict(type='GN', num_groups=1, eps=1e-6, requires_grad=True)

pretrained = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-b-1600/warpper_mae_vit-base-p16-1600e.pth'

model = dict(
    type='CascadeRCNN',
    pretrained=pretrained,
    backbone=dict(
        type='ViTDet',
        img_size=1024,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        drop_path_rate=0.1,
        window_size=14,
        mlp_ratio=4,
        qkv_bias=True,
        window_block_indexes=[
            # 2, 5, 8, 11 for global attention
            0,
            1,
            3,
            4,
            6,
            7,
            9,
            10,
        ],
        residual_block_indexes=[],
        use_rel_pos=True),
    neck=dict(
        type='SFP',
        in_channels=768,
        out_channels=256,
        scale_factors=(4.0, 2.0, 1.0, 0.5),
        norm_cfg=norm_cfg,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        num_convs=2,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
    roi_head=dict(
        type='CascadeRoIHead',
        num_stages=3,
        stage_loss_weights=[1, 0.5, 0.25],
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=[
            dict(
                type='Shared4Conv1FCBBoxHead',
                conv_out_channels=256,
                norm_cfg=norm_cfg,
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.1, 0.1, 0.2, 0.2]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                               loss_weight=1.0)),
            dict(
                type='Shared4Conv1FCBBoxHead',
                conv_out_channels=256,
                norm_cfg=norm_cfg,
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.05, 0.05, 0.1, 0.1]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                               loss_weight=1.0)),
            dict(
                type='Shared4Conv1FCBBoxHead',
                conv_out_channels=256,
                norm_cfg=norm_cfg,
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.033, 0.033, 0.067, 0.067]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
        ],
        mask_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        mask_head=dict(
            type='FCNMaskHead',
            norm_cfg=norm_cfg,
            num_convs=4,
            in_channels=256,
            conv_out_channels=256,
            num_classes=80,
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=0,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=2000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=[
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.5,
                    neg_iou_thr=0.5,
                    min_pos_iou=0.5,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.6,
                    neg_iou_thr=0.6,
                    min_pos_iou=0.6,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.7,
                    neg_iou_thr=0.7,
                    min_pos_iou=0.7,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False)
        ]),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))

mmlab_modules = [
    dict(type='mmdet', name='CascadeRCNN', module='model'),
    dict(type='mmdet', name='RPNHead', module='head'),
    dict(type='mmdet', name='CascadeRoIHead', module='head'),
]
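
The window_block_indexes above lists the blocks that use 14x14 windowed attention; the remaining blocks (2, 5, 8, 11, i.e. every third block) keep global attention, following the ViTDet layout. A minimal sketch of how such a list can be derived (the helper below is illustrative, not part of the config):

def windowed_block_indexes(depth=12, global_indexes=(2, 5, 8, 11)):
    # Every block except the designated global-attention ones uses
    # windowed attention.
    return [i for i in range(depth) if i not in global_indexes]

assert windowed_block_indexes() == [0, 1, 3, 4, 6, 7, 9, 10]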

@@ -0,0 +1,4 @@
_base_ = [
    './vitdet_cascade_mask_rcnn.py', './lsj_coco_instance.py',
    './vitdet_schedule_100e.py'
]
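
As a quick usage sketch of the composed config (assuming an mmcv-style Config reader; the exact path is illustrative and depends on where this file lives in the repo):

from mmcv import Config

# _base_ files are merged recursively, so the composed config carries the
# model, data and schedule settings defined in the three files above.
cfg = Config.fromfile('configs/detection/vitdet/vitdet_cascade_mask_rcnn_100e.py')
print(cfg.model.backbone.type)    # 'ViTDet'
print(cfg.optimizer.constructor)  # 'LayerDecayOptimizerConstructor'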

@@ -1,6 +1,6 @@
 # model settings
-norm_cfg = dict(type='GN', num_groups=1, requires_grad=True)
+norm_cfg = dict(type='GN', num_groups=1, eps=1e-6, requires_grad=True)
 pretrained = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-b-1600/warpper_mae_vit-base-p16-1600e.pth'
 model = dict(
@@ -9,22 +9,32 @@ model = dict(
     backbone=dict(
         type='ViTDet',
         img_size=1024,
+        patch_size=16,
         embed_dim=768,
         depth=12,
         num_heads=12,
+        drop_path_rate=0.1,
+        window_size=14,
         mlp_ratio=4,
         qkv_bias=True,
-        qk_scale=None,
-        drop_rate=0.,
-        attn_drop_rate=0.,
-        drop_path_rate=0.1,
-        use_abs_pos_emb=True,
-        aggregation='attn',
-    ),
+        window_block_indexes=[
+            # 2, 5, 8, 11 for global attention
+            0,
+            1,
+            3,
+            4,
+            6,
+            7,
+            9,
+            10,
+        ],
+        residual_block_indexes=[],
+        use_rel_pos=True),
     neck=dict(
         type='SFP',
-        in_channels=[768, 768, 768, 768],
+        in_channels=768,
         out_channels=256,
+        scale_factors=(4.0, 2.0, 1.0, 0.5),
         norm_cfg=norm_cfg,
         num_outs=5),
     rpn_head=dict(
@@ -32,7 +42,6 @@ model = dict(
         in_channels=256,
         feat_channels=256,
         num_convs=2,
-        norm_cfg=norm_cfg,
         anchor_generator=dict(
             type='AnchorGenerator',
             scales=[8],
@@ -98,7 +107,7 @@ model = dict(
                 pos_iou_thr=0.5,
                 neg_iou_thr=0.5,
                 min_pos_iou=0.5,
-                match_low_quality=True,
+                match_low_quality=False,
                 ignore_iof_thr=-1),
             sampler=dict(
                 type='RandomSampler',

@@ -1,4 +1,4 @@
 _base_ = [
-    './vitdet_faster_rcnn.py', './lsj_coco_detection.py',
+    './vitdet_faster_rcnn.py', './lsj_coco_instance.py',
     './vitdet_schedule_100e.py'
 ]

@@ -1,6 +1,6 @@
 # model settings
-norm_cfg = dict(type='GN', num_groups=1, requires_grad=True)
+norm_cfg = dict(type='GN', num_groups=1, eps=1e-6, requires_grad=True)
 pretrained = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-b-1600/warpper_mae_vit-base-p16-1600e.pth'
 model = dict(
@@ -9,22 +9,32 @@ model = dict(
     backbone=dict(
         type='ViTDet',
         img_size=1024,
+        patch_size=16,
         embed_dim=768,
         depth=12,
         num_heads=12,
+        drop_path_rate=0.1,
+        window_size=14,
         mlp_ratio=4,
         qkv_bias=True,
-        qk_scale=None,
-        drop_rate=0.,
-        attn_drop_rate=0.,
-        drop_path_rate=0.1,
-        use_abs_pos_emb=True,
-        aggregation='attn',
-    ),
+        window_block_indexes=[
+            # 2, 5, 8, 11 for global attention
+            0,
+            1,
+            3,
+            4,
+            6,
+            7,
+            9,
+            10,
+        ],
+        residual_block_indexes=[],
+        use_rel_pos=True),
     neck=dict(
         type='SFP',
-        in_channels=[768, 768, 768, 768],
+        in_channels=768,
         out_channels=256,
+        scale_factors=(4.0, 2.0, 1.0, 0.5),
         norm_cfg=norm_cfg,
         num_outs=5),
     rpn_head=dict(
@@ -32,7 +42,6 @@ model = dict(
         in_channels=256,
         feat_channels=256,
         num_convs=2,
-        norm_cfg=norm_cfg,
         anchor_generator=dict(
             type='AnchorGenerator',
             scales=[8],
@@ -112,7 +121,7 @@ model = dict(
                 pos_iou_thr=0.5,
                 neg_iou_thr=0.5,
                 min_pos_iou=0.5,
-                match_low_quality=True,
+                match_low_quality=False,
                 ignore_iof_thr=-1),
             sampler=dict(
                 type='RandomSampler',

@@ -1,26 +1,29 @@
 _base_ = 'configs/base.py'

+log_config = dict(
+    interval=200,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+
 checkpoint_config = dict(interval=10)

 # optimizer
-paramwise_options = {
-    'norm': dict(weight_decay=0.),
-    'bias': dict(weight_decay=0.),
-    'pos_embed': dict(weight_decay=0.),
-    'cls_token': dict(weight_decay=0.)
-}
 optimizer = dict(
     type='AdamW',
     lr=1e-4,
     betas=(0.9, 0.999),
     weight_decay=0.1,
-    paramwise_options=paramwise_options)
-optimizer_config = dict(grad_clip=None, loss_scale=512.)
+    constructor='LayerDecayOptimizerConstructor',
+    paramwise_options=dict(num_layers=12, layer_decay_rate=0.7))
+optimizer_config = dict(grad_clip=None)

 # learning policy
 lr_config = dict(
     policy='step',
     warmup='linear',
     warmup_iters=250,
-    warmup_ratio=0.067,
+    warmup_ratio=0.001,
     step=[88, 96])
 total_epochs = 100
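
After the 250-iteration linear warmup (now starting from 0.001 of the base lr instead of 0.067), the step policy decays the lr at epochs 88 and 96. A rough sketch of the epoch-level value, assuming mmcv's default step gamma of 0.1:

def lr_at_epoch(epoch, base_lr=1e-4, steps=(88, 96), gamma=0.1):
    # Step policy: multiply by gamma at each milestone epoch reached
    # (warmup only spans the first 250 iterations and is omitted here).
    return base_lr * gamma ** sum(epoch >= s for s in steps)

for e in (0, 88, 96):
    print(e, lr_at_epoch(e))  # 1e-04, then 1e-05, then 1e-06 (up to float rounding)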

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ee64c0caef841c61c7e6344b7fe2c07a38fba07a8de81ff38c0686c641e0a283
-size 190356
+oid sha256:c696a58a2963b5ac47317751f04ff45bfed4723f2f70bacf91eac711f9710e54
+size 189432

@@ -22,7 +22,7 @@ Pretrained on COCO2017 dataset. (The result has been optimized with PAI-Blade, a
 | Algorithm | Config | Params<br/>(backbone/total) | inference time(V100)<br/>(ms/img) | bbox_mAP<sup>val<br/><sub>0.5:0.95</sub> | mask_mAP<sup>val<br/><sub>0.5:0.95</sub> | Download |
 | ---------- | ------------------------------------------------------------ | ------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
-| ViTDet_MaskRCNN | [vitdet_maskrcnn](https://github.com/alibaba/EasyCV/tree/master/configs/detection/vitdet/vitdet_100e.py) | 88M/118M | 163ms | 50.57 | 44.96 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/vitdet_maskrcnn.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/vitdet_maskrcnn.log.json) |
+| ViTDet_MaskRCNN | [vitdet_maskrcnn](https://github.com/alibaba/EasyCV/tree/master/configs/detection/vitdet/vitdet_mask_rcnn_100e.py) | 86M/111M | 138ms | 50.65 | 45.41 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/epoch_100.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/20220901_135827.log.json) |

 ## FCOS

@@ -1,5 +1,3 @@
-# Reference from https://github.com/ViTAE-Transformer/ViTDet/blob/main/mmcv_custom/layer_decay_optimizer_constructor.py
-
 import json

 from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
@@ -7,23 +5,32 @@ from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
 from .builder import OPTIMIZER_BUILDERS


-def get_num_layer_for_vit(var_name, num_max_layer, layer_sep=None):
-    if var_name in ('backbone.cls_token', 'backbone.mask_token',
-                    'backbone.pos_embed'):
-        return 0
-    elif var_name.startswith('backbone.patch_embed'):
-        return 0
-    elif var_name.startswith('backbone.blocks'):
-        layer_id = int(var_name.split('.')[2])
-        return layer_id + 1
-    else:
-        return num_max_layer - 1
+def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
+    """
+    Calculate lr decay rate for different ViT blocks.
+    Reference from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py
+    Args:
+        name (string): parameter name.
+        lr_decay_rate (float): base lr decay rate.
+        num_layers (int): number of ViT blocks.
+    Returns:
+        lr decay rate for the given parameter.
+    """
+    layer_id = num_layers + 1
+    if '.pos_embed' in name or '.patch_embed' in name:
+        layer_id = 0
+    elif '.blocks.' in name and '.residual.' not in name:
+        layer_id = int(name[name.find('.blocks.'):].split('.')[2]) + 1
+    scale = lr_decay_rate**(num_layers + 1 - layer_id)
+    return layer_id, scale


 @OPTIMIZER_BUILDERS.register_module()
 class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor):

-    def add_params(self, params, module, prefix='', is_dcn_module=None):
+    def add_params(self, params, module):
         """Add all parameters of module to the params list.
         The parameters of the given module will be added to the list of param
         groups, with specific rules defined by paramwise_cfg.
@@ -31,54 +38,41 @@ class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor):
             params (list[dict]): A list of param groups, it will be modified
                 in place.
             module (nn.Module): The module to be added.
-            prefix (str): The prefix of the module
-            is_dcn_module (int|float|None): If the current module is a
-                submodule of DCN, `is_dcn_module` will be passed to
-                control conv_offset layer's learning rate. Defaults to None.
+        Reference from https://github.com/ViTAE-Transformer/ViTDet/blob/main/mmcv_custom/layer_decay_optimizer_constructor.py
+        Note: Currently, this optimizer constructor is built for ViTDet.
         """
-        # get param-wise options
         parameter_groups = {}
         print(self.paramwise_cfg)
-        num_layers = self.paramwise_cfg.get('num_layers') + 2
-        layer_sep = self.paramwise_cfg.get('layer_sep', None)
-        layer_decay_rate = self.paramwise_cfg.get('layer_decay_rate')
+        lr_decay_rate = self.paramwise_cfg.get('layer_decay_rate')
+        num_layers = self.paramwise_cfg.get('num_layers')
         print('Build LayerDecayOptimizerConstructor %f - %d' %
-              (layer_decay_rate, num_layers))
+              (lr_decay_rate, num_layers))
+        lr = self.base_lr
         weight_decay = self.base_wd
-        custom_keys = self.paramwise_cfg.get('custom_keys', {})
-        # first sort with alphabet order and then sort with reversed len of str
-        sorted_keys = sorted(custom_keys.keys())

         for name, param in module.named_parameters():
             if not param.requires_grad:
                 continue  # frozen weights
-            if len(param.shape) == 1 or name.endswith('.bias') or (
-                    'pos_embed' in name) or ('cls_token'
-                                             in name) or ('rel_pos_' in name):
+            if 'backbone' in name and ('.norm' in name or '.pos_embed' in name
+                                       or '.gn.' in name or '.ln.' in name):
                 group_name = 'no_decay'
                 this_weight_decay = 0.
             else:
                 group_name = 'decay'
                 this_weight_decay = weight_decay

-            layer_id = get_num_layer_for_vit(name, num_layers, layer_sep)
+            if name.startswith('backbone'):
+                layer_id, scale = get_vit_lr_decay_rate(
+                    name, lr_decay_rate=lr_decay_rate, num_layers=num_layers)
+            else:
+                layer_id, scale = -1, 1
             group_name = 'layer_%d_%s' % (layer_id, group_name)

-            # if the parameter match one of the custom keys, ignore other rules
-            this_lr_multi = 1.
-            for key in sorted_keys:
-                if key in f'{name}':
-                    lr_mult = custom_keys[key].get('lr_mult', 1.)
-                    this_lr_multi = lr_mult
-                    group_name = '%s_%s' % (group_name, key)
-                    break

             if group_name not in parameter_groups:
-                scale = layer_decay_rate**(num_layers - layer_id - 1)
                 parameter_groups[group_name] = {
                     'weight_decay': this_weight_decay,
@@ -86,7 +80,7 @@ class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor):
                     'param_names': [],
                     'lr_scale': scale,
                     'group_name': group_name,
-                    'lr': scale * self.base_lr * this_lr_multi,
+                    'lr': scale * lr,
                 }

             parameter_groups[group_name]['params'].append(param)
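
With the num_layers=12 and layer_decay_rate=0.7 set in vitdet_schedule_100e.py, this scheme gives patch_embed the smallest multiplier (0.7**13, about 0.01) and non-backbone parameters the full base lr. A simplified, self-contained sketch of the same rule (it omits the '.residual.' exclusion used above):

def lr_scale(name, rate=0.7, num_layers=12):
    # Mirrors get_vit_lr_decay_rate: earlier blocks get geometrically
    # smaller multipliers, rate ** (num_layers + 1 - layer_id).
    layer_id = num_layers + 1
    if '.pos_embed' in name or '.patch_embed' in name:
        layer_id = 0
    elif '.blocks.' in name:
        layer_id = int(name.split('.blocks.')[1].split('.')[0]) + 1
    return rate ** (num_layers + 1 - layer_id)

print(lr_scale('backbone.patch_embed.proj.weight'))   # 0.7 ** 13 ~= 0.0097
print(lr_scale('backbone.blocks.11.attn.qkv.weight')) # 0.7 ** 1  = 0.7
print(lr_scale('roi_head.bbox_head.fc_cls.weight'))   # 0.7 ** 0  = 1.0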

File diff suppressed because it is too large

@@ -37,7 +37,6 @@ class FPN(nn.Module):
             Default: None.
         upsample_cfg (dict): Config dict for interpolate layer.
             Default: dict(mode='nearest').
-        init_cfg (dict or list[dict], optional): Initialization config dict.
     Example:
         >>> import torch
         >>> in_channels = [2, 3, 5, 7]
@@ -67,8 +66,6 @@ class FPN(nn.Module):
                  norm_cfg=None,
                  act_cfg=None,
                  upsample_cfg=dict(mode='nearest')):
-        # init_cfg=dict(
-        #     type='Xavier', layer='Conv2d', distribution='uniform')):
         super(FPN, self).__init__()
         assert isinstance(in_channels, list)
         self.in_channels = in_channels

@@ -2,26 +2,12 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from mmcv.cnn import ConvModule
-from mmcv.runner import BaseModule

 from easycv.models.builder import NECKS


-class Norm2d(nn.Module):
-
-    def __init__(self, embed_dim):
-        super().__init__()
-        self.ln = nn.LayerNorm(embed_dim, eps=1e-6)
-
-    def forward(self, x):
-        x = x.permute(0, 2, 3, 1)
-        x = self.ln(x)
-        x = x.permute(0, 3, 1, 2).contiguous()
-        return x
-
-
 @NECKS.register_module()
-class SFP(BaseModule):
+class SFP(nn.Module):
     r"""Simple Feature Pyramid.
     This is an implementation of paper `Exploring Plain Vision Transformer Backbones for Object Detection <https://arxiv.org/abs/2203.16527>`_.
     Args:
@@ -32,25 +18,12 @@ class SFP(BaseModule):
             build the feature pyramid. Default: 0.
         end_level (int): Index of the end input backbone level (exclusive) to
             build the feature pyramid. Default: -1, which means the last level.
-        add_extra_convs (bool | str): If bool, it decides whether to add conv
-            layers on top of the original feature maps. Default to False.
-            If True, it is equivalent to `add_extra_convs='on_input'`.
-            If str, it specifies the source feature map of the extra convs.
-            Only the following options are allowed
-            - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
-            - 'on_lateral': Last feature map after lateral convs.
-            - 'on_output': The last output feature map after fpn convs.
-        relu_before_extra_convs (bool): Whether to apply relu before the extra
             conv. Default: False.
-        no_norm_on_lateral (bool): Whether to apply norm on lateral.
            Default: False.
         conv_cfg (dict): Config dict for convolution layer. Default: None.
         norm_cfg (dict): Config dict for normalization layer. Default: None.
         act_cfg (str): Config dict for activation layer in ConvModule.
             Default: None.
-        upsample_cfg (dict): Config dict for interpolate layer.
-            Default: `dict(mode='nearest')`
-        init_cfg (dict or list[dict], optional): Initialization config dict.
     Example:
         >>> import torch
         >>> in_channels = [2, 3, 5, 7]
@@ -70,81 +43,53 @@ class SFP(BaseModule):
     def __init__(self,
                  in_channels,
                  out_channels,
+                 scale_factors,
                  num_outs,
-                 start_level=0,
-                 end_level=-1,
-                 add_extra_convs=False,
-                 relu_before_extra_convs=False,
-                 no_norm_on_lateral=False,
                  conv_cfg=None,
                  norm_cfg=None,
-                 act_cfg=None,
-                 upsample_cfg=dict(mode='nearest'),
-                 init_cfg=[
-                     dict(
-                         type='Xavier',
-                         layer=['Conv2d'],
-                         distribution='uniform'),
-                     dict(type='Constant', layer=['LayerNorm'], val=1, bias=0)
-                 ]):
-        super(SFP, self).__init__(init_cfg)
-        assert isinstance(in_channels, list)
-        self.in_channels = in_channels
+                 act_cfg=None):
+        super(SFP, self).__init__()
+        dim = in_channels
         self.out_channels = out_channels
-        self.num_ins = len(in_channels)
+        self.scale_factors = scale_factors
+        self.num_ins = len(scale_factors)
         self.num_outs = num_outs
-        self.relu_before_extra_convs = relu_before_extra_convs
-        self.no_norm_on_lateral = no_norm_on_lateral
-        self.upsample_cfg = upsample_cfg.copy()

-        if end_level == -1:
-            self.backbone_end_level = self.num_ins
-            assert num_outs >= self.num_ins - start_level
-        else:
-            # if end_level < inputs, no extra level is allowed
-            self.backbone_end_level = end_level
-            assert end_level <= len(in_channels)
-            assert num_outs == end_level - start_level
-        self.start_level = start_level
-        self.end_level = end_level
-        self.add_extra_convs = add_extra_convs
-        assert isinstance(add_extra_convs, (str, bool))
-        if isinstance(add_extra_convs, str):
-            # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
-            assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
-        elif add_extra_convs:  # True
-            self.add_extra_convs = 'on_input'
-
-        self.top_downs = nn.ModuleList()
-        self.lateral_convs = nn.ModuleList()
-        self.fpn_convs = nn.ModuleList()
-
-        for i in range(self.start_level, self.backbone_end_level):
-            if i == 0:
-                top_down = nn.Sequential(
-                    nn.ConvTranspose2d(
-                        in_channels[i], in_channels[i], 2, stride=2,
-                        padding=0), Norm2d(in_channels[i]), nn.GELU(),
-                    nn.ConvTranspose2d(
-                        in_channels[i], in_channels[i], 2, stride=2,
-                        padding=0))
-            elif i == 1:
-                top_down = nn.ConvTranspose2d(
-                    in_channels[i], in_channels[i], 2, stride=2, padding=0)
-            elif i == 2:
-                top_down = nn.Identity()
-            elif i == 3:
-                top_down = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
-
-            l_conv = ConvModule(
-                in_channels[i],
-                out_channels,
-                1,
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
-                act_cfg=act_cfg,
-                inplace=False)
-            fpn_conv = ConvModule(
-                out_channels,
-                out_channels,
-                3,
+        self.stages = []
+        for idx, scale in enumerate(scale_factors):
+            out_dim = dim
+            if scale == 4.0:
+                layers = [
+                    nn.ConvTranspose2d(dim, dim // 2, 2, stride=2, padding=0),
+                    nn.GroupNorm(1, dim // 2, eps=1e-6),
+                    nn.GELU(),
+                    nn.ConvTranspose2d(
+                        dim // 2, dim // 4, 2, stride=2, padding=0)
+                ]
+                out_dim = dim // 4
+            elif scale == 2.0:
+                layers = [
+                    nn.ConvTranspose2d(dim, dim // 2, 2, stride=2, padding=0)
+                ]
+                out_dim = dim // 2
+            elif scale == 1.0:
+                layers = []
+            elif scale == 0.5:
+                layers = [nn.MaxPool2d(kernel_size=2, stride=2, padding=0)]
+            else:
+                raise NotImplementedError(
+                    f'scale_factor={scale} is not supported yet.')
+
+            layers.extend([
+                ConvModule(
+                    out_dim,
+                    out_channels,
+                    1,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg,
+                    inplace=False),
+                ConvModule(
+                    out_channels,
+                    out_channels,
+                    3,
@@ -153,75 +98,28 @@ class SFP(BaseModule):
-                norm_cfg=norm_cfg,
-                act_cfg=act_cfg,
-                inplace=False)
-
-            self.top_downs.append(top_down)
-            self.lateral_convs.append(l_conv)
-            self.fpn_convs.append(fpn_conv)
-
-        # add extra conv layers (e.g., RetinaNet)
-        extra_levels = num_outs - self.backbone_end_level + self.start_level
-        if self.add_extra_convs and extra_levels >= 1:
-            for i in range(extra_levels):
-                if i == 0 and self.add_extra_convs == 'on_input':
-                    in_channels = self.in_channels[self.backbone_end_level - 1]
-                else:
-                    in_channels = out_channels
-                extra_fpn_conv = ConvModule(
-                    in_channels,
-                    out_channels,
-                    3,
-                    stride=2,
-                    padding=1,
-                    conv_cfg=conv_cfg,
-                    norm_cfg=norm_cfg,
-                    act_cfg=act_cfg,
-                    inplace=False)
-                self.fpn_convs.append(extra_fpn_conv)
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg,
+                    inplace=False)
+            ])
+            layers = nn.Sequential(*layers)
+            self.add_module(f'sfp_{idx}', layers)
+            self.stages.append(layers)
+
+    def init_weights(self):
+        pass

     def forward(self, inputs):
         """Forward function."""
-        assert len(inputs) == 1
-
-        # build top-down path
-        features = [
-            top_down(inputs[0]) for _, top_down in enumerate(self.top_downs)
-        ]
-        assert len(features) == len(self.in_channels)
-
-        # build laterals
-        laterals = [
-            lateral_conv(features[i + self.start_level])
-            for i, lateral_conv in enumerate(self.lateral_convs)
-        ]
-        used_backbone_levels = len(laterals)
-
-        # build outputs
-        # part 1: from original levels
-        outs = [
-            self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
-        ]
+        features = inputs[0]
+        outs = []
+
+        # part 1: build simple feature pyramid
+        for stage in self.stages:
+            outs.append(stage(features))

         # part 2: add extra levels
-        if self.num_outs > len(outs):
+        if self.num_outs > self.num_ins:
             # use max pool to get more levels on top of outputs
             # (e.g., Faster R-CNN, Mask R-CNN)
-            if not self.add_extra_convs:
-                for i in range(self.num_outs - used_backbone_levels):
-                    outs.append(F.max_pool2d(outs[-1], 1, stride=2))
-            # add conv layers on top of original feature maps (RetinaNet)
-            else:
-                if self.add_extra_convs == 'on_input':
-                    extra_source = inputs[self.backbone_end_level - 1]
-                elif self.add_extra_convs == 'on_lateral':
-                    extra_source = laterals[-1]
-                elif self.add_extra_convs == 'on_output':
-                    extra_source = outs[-1]
-                else:
-                    raise NotImplementedError
-                outs.append(self.fpn_convs[used_backbone_levels](extra_source))
-                for i in range(used_backbone_levels + 1, self.num_outs):
-                    if self.relu_before_extra_convs:
-                        outs.append(self.fpn_convs[i](F.relu(outs[-1])))
-                    else:
-                        outs.append(self.fpn_convs[i](outs[-1]))
+            for i in range(self.num_outs - self.num_ins):
+                outs.append(F.max_pool2d(outs[-1], 1, stride=2))
         return tuple(outs)
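
For a 1024x1024 input, the ViT backbone emits a single stride-16 feature map of shape (N, 768, 64, 64); the four stages above turn it into strides 4, 8, 16 and 32, and the max-pool in forward adds the stride-64 level. A quick shape check of the scale == 4.0 stage in plain PyTorch (the two trailing ConvModule projections from the real code are omitted for brevity):

import torch
import torch.nn as nn

dim = 768
# scale == 4.0: two stride-2 transposed convs upsample the stride-16 map
# to stride 4 while shrinking channels 768 -> 384 -> 192.
stage = nn.Sequential(
    nn.ConvTranspose2d(dim, dim // 2, 2, stride=2, padding=0),
    nn.GroupNorm(1, dim // 2, eps=1e-6),
    nn.GELU(),
    nn.ConvTranspose2d(dim // 2, dim // 4, 2, stride=2, padding=0),
)

x = torch.randn(1, dim, 64, 64)  # 1024 / 16 = 64
print(stage(x).shape)            # torch.Size([1, 192, 256, 256])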

@@ -253,11 +253,11 @@ class DetrPredictor(PredictorInterface):
             img,
             bboxes,
             labels=labels,
-            colors='green',
-            text_color='white',
-            font_size=20,
-            thickness=1,
-            font_scale=0.5,
+            colors='cyan',
+            text_color='cyan',
+            font_size=18,
+            thickness=2,
+            font_scale=0.0,
             show=show,
             out_file=out_file)

@@ -14,18 +14,27 @@ class ViTDetTest(unittest.TestCase):
     def test_vitdet(self):
         model = ViTDet(
             img_size=1024,
+            patch_size=16,
             embed_dim=768,
             depth=12,
             num_heads=12,
+            drop_path_rate=0.1,
+            window_size=14,
             mlp_ratio=4,
             qkv_bias=True,
-            qk_scale=None,
-            drop_rate=0.,
-            attn_drop_rate=0.,
-            drop_path_rate=0.1,
-            use_abs_pos_emb=True,
-            aggregation='attn',
-        )
+            window_block_indexes=[
+                # 2, 5, 8, 11 for global attention
+                0,
+                1,
+                3,
+                4,
+                6,
+                7,
+                9,
+                10,
+            ],
+            residual_block_indexes=[],
+            use_rel_pos=True)
         model.init_weights()
         model.train()

@@ -155,7 +155,7 @@ class DetectorTest(unittest.TestCase):
             decimal=1)

     def test_vitdet_detector(self):
-        model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/vitdet_maskrcnn_export.pth'
+        model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/epoch_100_export.pth'
         img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'
         out_file = './result.jpg'
         vitdet = DetrPredictor(model_path)
@@ -167,63 +167,170 @@ class DetectorTest(unittest.TestCase):
         self.assertIn('detection_classes', output)
         self.assertIn('detection_masks', output)
         self.assertIn('img_metas', output)
-        self.assertEqual(len(output['detection_boxes'][0]), 30)
-        self.assertEqual(len(output['detection_scores'][0]), 30)
-        self.assertEqual(len(output['detection_classes'][0]), 30)
+        self.assertEqual(len(output['detection_boxes'][0]), 33)
+        self.assertEqual(len(output['detection_scores'][0]), 33)
+        self.assertEqual(len(output['detection_classes'][0]), 33)

         self.assertListEqual(
             output['detection_classes'][0].tolist(),
             np.array([
-                2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
-                2, 2, 2, 7, 7, 13, 13, 13, 56
+                2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+                2, 2, 2, 2, 2, 2, 7, 7, 13, 13, 13, 56
             ],
                      dtype=np.int32).tolist())

         assert_array_almost_equal(
             output['detection_scores'][0],
             np.array([
-                0.99791867, 0.99665856, 0.99480623, 0.99060905, 0.9882515,
-                0.98319584, 0.9738879, 0.97290784, 0.9514897, 0.95104814,
-                0.9321701, 0.86165, 0.8228847, 0.7623552, 0.76129806,
-                0.6050861, 0.44348577, 0.3452973, 0.2895671, 0.22109479,
-                0.21265312, 0.17855245, 0.1205352, 0.08981906, 0.10596471,
-                0.05854294, 0.99749386, 0.9472857, 0.5945908, 0.09855112
+                0.9975854158401489, 0.9965696334838867, 0.9922919869422913,
+                0.9833580851554871, 0.983080267906189, 0.970454752445221,
+                0.9701289534568787, 0.9649872183799744, 0.9642795324325562,
+                0.9642238020896912, 0.9529680609703064, 0.9403366446495056,
+                0.9391788244247437, 0.8941807150840759, 0.8178097009658813,
+                0.8013413548469543, 0.6677654385566711, 0.3952914774417877,
+                0.33463895320892334, 0.32501447200775146, 0.27323535084724426,
+                0.20197080075740814, 0.15607696771621704, 0.1068163588643074,
+                0.10183875262737274, 0.09735643863677979, 0.06559795141220093,
+                0.08890066295862198, 0.076363705098629, 0.9954648613929749,
+                0.9212945699691772, 0.5224372148513794, 0.20555885136127472
             ],
                      dtype=np.float32),
             decimal=2)

         assert_array_almost_equal(
             output['detection_boxes'][0],
-            np.array([[294.7058, 117.29371, 378.83713, 149.99928],
-                      [609.05444, 112.526474, 633.2971, 136.35175],
-                      [481.4165, 110.987335, 522.5531, 130.01529],
-                      [167.68184, 109.89049, 215.49057, 139.86987],
-                      [374.75082, 110.68697, 433.10028, 136.23654],
-                      [189.54971, 110.09322, 297.6167, 155.77412],
-                      [266.5185, 105.37718, 326.54385, 127.916374],
-                      [556.30225, 110.43166, 592.8248, 128.03764],
-                      [432.49252, 105.086464, 484.0512, 132.272],
-                      [0., 110.566444, 62.01249, 146.44017],
-                      [591.74664, 110.43527, 619.73816, 126.68549],
-                      [99.126854, 90.947975, 118.46699, 101.11096],
-                      [59.895264, 94.110054, 85.60521, 106.67633],
-                      [142.95819, 96.61966, 165.96964, 104.95929],
-                      [83.062515, 89.802605, 99.1546, 98.69074],
-                      [226.28802, 98.32568, 249.06772, 108.86408],
-                      [136.67789, 94.75706, 154.62924, 104.289536],
-                      [170.42459, 98.458694, 183.16309, 106.203156],
-                      [67.56731, 89.68286, 82.62955, 98.35645],
-                      [222.80092, 97.828445, 239.02655, 108.29377],
-                      [134.34427, 92.31653, 149.19615, 102.97457],
-                      [613.5186, 102.27066, 636.0434, 112.813644],
-                      [607.4787, 110.87984, 630.1123, 127.65646],
-                      [135.13664, 90.989876, 155.67192, 100.18036],
-                      [431.61505, 105.43844, 484.36508, 132.50078],
-                      [189.92722, 110.38832, 297.74353, 155.95557],
-                      [220.67035, 177.13489, 455.32092, 380.45712],
-                      [372.76584, 134.33807, 432.44357, 188.51534],
-                      [50.403812, 110.543495, 70.4368, 119.65186],
-                      [373.50272, 134.27258, 432.18475, 187.81824]]),
+            np.array([[294.22674560546875, 116.6078109741211, 379.4328918457031, 150.14097595214844],
+                      [482.6017761230469, 110.75955963134766, 522.8798828125, 129.71286010742188],
+                      [167.06460571289062, 109.95974731445312, 212.83975219726562, 140.16102600097656],
+                      [609.2930908203125, 113.13909149169922, 637.3115844726562, 136.4690704345703],
+                      [191.185791015625, 111.1408920288086, 301.31689453125, 155.7731170654297],
+                      [431.2244873046875, 106.19962310791016, 483.860595703125, 132.21627807617188],
+                      [267.48358154296875, 105.5920639038086, 325.2832336425781, 127.11176300048828],
+                      [591.2138671875, 110.29329681396484, 619.8524169921875, 126.1990966796875],
+                      [0.0, 110.7026596069336, 61.487945556640625, 146.33018493652344],
+                      [555.9155883789062, 110.03486633300781, 591.7050170898438, 127.06097412109375],
+                      [60.24559783935547, 94.12760162353516, 85.63741302490234, 106.66705322265625],
+                      [99.02665710449219, 90.53657531738281, 118.83953094482422, 101.18717956542969],
+                      [396.30438232421875, 111.59194946289062, 431.559814453125, 133.96914672851562],
+                      [83.81543731689453, 89.65665435791016, 99.9166259765625, 98.25627899169922],
+                      [139.29647827148438, 96.68000793457031, 165.22410583496094, 105.60000610351562],
+                      [67.27152252197266, 89.42798614501953, 83.25617980957031, 98.0460205078125],
+                      [223.74176025390625, 98.68321990966797, 250.42506408691406, 109.32588958740234],
+                      [136.7582244873047, 96.51412963867188, 152.51190185546875, 104.73160552978516],
+                      [221.71812438964844, 97.86445617675781, 238.9705810546875, 106.96803283691406],
+                      [135.06964111328125, 91.80916595458984, 155.24609375, 102.20686340332031],
+                      [169.11180114746094, 97.53628540039062, 182.88504028320312, 105.95404815673828],
+                      [133.8811798095703, 91.00375366210938, 145.35507202148438, 102.3780288696289],
+                      [614.2507934570312, 102.19828796386719, 636.5692749023438, 112.59198760986328],
+                      [35.94759750366211, 91.7213363647461, 70.38274383544922, 117.19855499267578],
+                      [554.6401977539062, 115.18976593017578, 562.0255737304688, 127.4429931640625],
+                      [39.07550811767578, 92.73261260986328, 85.36636352539062, 106.73953247070312],
+                      [200.85513305664062, 93.00469970703125, 219.73086547851562, 107.99642181396484],
+                      [0.0, 111.18904876708984, 61.7393684387207, 146.72547912597656],
+                      [191.88568115234375, 111.09577178955078, 299.4097900390625, 155.14639282226562],
+                      [221.06834411621094, 176.6427001953125, 458.3475341796875, 378.89300537109375],
+                      [372.7131652832031, 135.51429748535156, 433.2494201660156, 188.0106658935547],
+                      [52.19819641113281, 110.3646011352539, 70.95110321044922, 120.10567474365234],
+                      [376.1671447753906, 133.6930694580078, 432.2721862792969, 187.99481201171875]]),
             decimal=1)