mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
[feature] add ssl detection benchmark (#58)
* add ssl detection benchmark * wrape mmdel model with adding init_weights()
This commit is contained in:
parent
ecbfbbb359
commit
14af96e21d
264
benchmarks/selfsup/detection/coco/mask_rcnn_r50_fpn_1x_coco.py
Normal file
264
benchmarks/selfsup/detection/coco/mask_rcnn_r50_fpn_1x_coco.py
Normal file
@ -0,0 +1,264 @@
|
||||
_base_ = ['configs/base.py']
|
||||
|
||||
CLASSES = [
|
||||
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
|
||||
'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
|
||||
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
|
||||
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
|
||||
'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
|
||||
'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
|
||||
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
|
||||
'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
|
||||
'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
|
||||
'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
|
||||
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
|
||||
'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
|
||||
'hair drier', 'toothbrush'
|
||||
]
|
||||
|
||||
norm_cfg = dict(type='SyncBN', requires_grad=True)
|
||||
# model settings
|
||||
model = dict(
|
||||
type='MaskRCNN',
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
num_stages=4,
|
||||
out_indices=(1, 2, 3, 4),
|
||||
frozen_stages=-1,
|
||||
norm_cfg=norm_cfg,
|
||||
norm_eval=False),
|
||||
# mmdet ResNet
|
||||
# backbone=dict(
|
||||
# type='ResNet',
|
||||
# depth=50,
|
||||
# num_stages=4,
|
||||
# out_indices=(0, 1, 2, 3),
|
||||
# frozen_stages=-1,
|
||||
# norm_cfg=norm_cfg,
|
||||
# norm_eval=False,
|
||||
# style='pytorch',
|
||||
# init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
|
||||
neck=dict(
|
||||
type='FPN',
|
||||
norm_cfg=norm_cfg,
|
||||
in_channels=[256, 512, 1024, 2048],
|
||||
out_channels=256,
|
||||
num_outs=5),
|
||||
rpn_head=dict(
|
||||
type='RPNHead',
|
||||
in_channels=256,
|
||||
feat_channels=256,
|
||||
anchor_generator=dict(
|
||||
type='AnchorGenerator',
|
||||
scales=[8],
|
||||
ratios=[0.5, 1.0, 2.0],
|
||||
strides=[4, 8, 16, 32, 64]),
|
||||
bbox_coder=dict(
|
||||
type='DeltaXYWHBBoxCoder',
|
||||
target_means=[.0, .0, .0, .0],
|
||||
target_stds=[1.0, 1.0, 1.0, 1.0]),
|
||||
loss_cls=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
|
||||
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
|
||||
roi_head=dict(
|
||||
type='StandardRoIHead',
|
||||
bbox_roi_extractor=dict(
|
||||
type='SingleRoIExtractor',
|
||||
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
|
||||
out_channels=256,
|
||||
featmap_strides=[4, 8, 16, 32]),
|
||||
bbox_head=dict(
|
||||
type='Shared4Conv1FCBBoxHead',
|
||||
norm_cfg=norm_cfg,
|
||||
in_channels=256,
|
||||
fc_out_channels=1024,
|
||||
roi_feat_size=7,
|
||||
num_classes=80,
|
||||
bbox_coder=dict(
|
||||
type='DeltaXYWHBBoxCoder',
|
||||
target_means=[0., 0., 0., 0.],
|
||||
target_stds=[0.1, 0.1, 0.2, 0.2]),
|
||||
reg_class_agnostic=False,
|
||||
loss_cls=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
|
||||
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
|
||||
mask_roi_extractor=dict(
|
||||
type='SingleRoIExtractor',
|
||||
roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
|
||||
out_channels=256,
|
||||
featmap_strides=[4, 8, 16, 32]),
|
||||
mask_head=dict(
|
||||
type='FCNMaskHead',
|
||||
norm_cfg=norm_cfg,
|
||||
num_convs=4,
|
||||
in_channels=256,
|
||||
conv_out_channels=256,
|
||||
num_classes=80,
|
||||
loss_mask=dict(
|
||||
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
|
||||
# model training and testing settings
|
||||
train_cfg=dict(
|
||||
rpn=dict(
|
||||
assigner=dict(
|
||||
type='MaxIoUAssigner',
|
||||
pos_iou_thr=0.7,
|
||||
neg_iou_thr=0.3,
|
||||
min_pos_iou=0.3,
|
||||
match_low_quality=True,
|
||||
ignore_iof_thr=-1),
|
||||
sampler=dict(
|
||||
type='RandomSampler',
|
||||
num=256,
|
||||
pos_fraction=0.5,
|
||||
neg_pos_ub=-1,
|
||||
add_gt_as_proposals=False),
|
||||
allowed_border=-1,
|
||||
pos_weight=-1,
|
||||
debug=False),
|
||||
rpn_proposal=dict(
|
||||
nms_pre=2000,
|
||||
max_per_img=1000,
|
||||
nms=dict(type='nms', iou_threshold=0.7),
|
||||
min_bbox_size=0),
|
||||
rcnn=dict(
|
||||
assigner=dict(
|
||||
type='MaxIoUAssigner',
|
||||
pos_iou_thr=0.5,
|
||||
neg_iou_thr=0.5,
|
||||
min_pos_iou=0.5,
|
||||
match_low_quality=True,
|
||||
ignore_iof_thr=-1),
|
||||
sampler=dict(
|
||||
type='RandomSampler',
|
||||
num=512,
|
||||
pos_fraction=0.25,
|
||||
neg_pos_ub=-1,
|
||||
add_gt_as_proposals=True),
|
||||
mask_size=28,
|
||||
pos_weight=-1,
|
||||
debug=False)),
|
||||
test_cfg=dict(
|
||||
rpn=dict(
|
||||
nms_pre=1000,
|
||||
max_per_img=1000,
|
||||
nms=dict(type='nms', iou_threshold=0.7),
|
||||
min_bbox_size=0),
|
||||
rcnn=dict(
|
||||
score_thr=0.05,
|
||||
nms=dict(type='nms', iou_threshold=0.5),
|
||||
max_per_img=100,
|
||||
mask_thr_binary=0.5)))
|
||||
|
||||
mmlab_modules = [
|
||||
dict(type='mmdet', name='MaskRCNN', module='model'),
|
||||
# dict(type=MMDET, name='ResNet', module='backbone'), # comment out, use EasyCV ResNet
|
||||
dict(type='mmdet', name='FPN', module='neck'),
|
||||
dict(type='mmdet', name='RPNHead', module='head'),
|
||||
dict(type='mmdet', name='StandardRoIHead', module='head'),
|
||||
]
|
||||
|
||||
# dataset settings
|
||||
data_root = 'data/coco/'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
|
||||
train_pipeline = [
|
||||
dict(
|
||||
type='MMResize',
|
||||
img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736),
|
||||
(1333, 768), (1333, 800)],
|
||||
multiscale_mode='value',
|
||||
keep_ratio=True),
|
||||
dict(type='MMRandomFlip', flip_ratio=0.5),
|
||||
dict(type='MMNormalize', **img_norm_cfg),
|
||||
dict(type='MMPad', size_divisor=32),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(
|
||||
type='Collect',
|
||||
keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks'],
|
||||
meta_keys=('filename', 'ori_filename', 'ori_shape', 'ori_img_shape',
|
||||
'img_shape', 'pad_shape', 'scale_factor', 'flip',
|
||||
'flip_direction', 'img_norm_cfg'))
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(
|
||||
type='MMMultiScaleFlipAug',
|
||||
img_scale=(1333, 800),
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='MMResize', keep_ratio=True),
|
||||
dict(type='MMRandomFlip'),
|
||||
dict(type='MMNormalize', **img_norm_cfg),
|
||||
dict(type='MMPad', size_divisor=32),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(
|
||||
type='Collect',
|
||||
keys=['img'],
|
||||
meta_keys=('filename', 'ori_filename', 'ori_shape',
|
||||
'ori_img_shape', 'img_shape', 'pad_shape',
|
||||
'scale_factor', 'flip', 'flip_direction',
|
||||
'img_norm_cfg')),
|
||||
])
|
||||
]
|
||||
|
||||
train_dataset = dict(
|
||||
type='DetDataset',
|
||||
data_source=dict(
|
||||
type='DetSourceCoco',
|
||||
ann_file=data_root + 'annotations/instances_train2017.json',
|
||||
img_prefix=data_root + 'train2017/',
|
||||
pipeline=[
|
||||
dict(type='LoadImageFromFile', to_float32=True),
|
||||
dict(type='LoadAnnotations', with_bbox=True, with_mask=True)
|
||||
],
|
||||
classes=CLASSES,
|
||||
filter_empty_gt=True,
|
||||
iscrowd=False,
|
||||
),
|
||||
pipeline=train_pipeline)
|
||||
|
||||
val_dataset = dict(
|
||||
type='DetDataset',
|
||||
imgs_per_gpu=1,
|
||||
data_source=dict(
|
||||
type='DetSourceCoco',
|
||||
ann_file=data_root + 'annotations/instances_val2017.json',
|
||||
img_prefix=data_root + 'val2017/',
|
||||
pipeline=[
|
||||
dict(type='LoadImageFromFile', to_float32=True),
|
||||
dict(type='LoadAnnotations', with_bbox=True)
|
||||
],
|
||||
classes=CLASSES,
|
||||
test_mode=True,
|
||||
iscrowd=True),
|
||||
pipeline=test_pipeline)
|
||||
|
||||
data = dict(
|
||||
imgs_per_gpu=2, workers_per_gpu=2, train=train_dataset, val=val_dataset)
|
||||
|
||||
checkpoint_config = dict(interval=1)
|
||||
# optimizer
|
||||
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
|
||||
optimizer_config = dict(grad_clip=None)
|
||||
# learning policy
|
||||
lr_config = dict(
|
||||
policy='step',
|
||||
warmup='linear',
|
||||
warmup_iters=1000,
|
||||
warmup_ratio=0.001,
|
||||
step=[8, 11])
|
||||
total_epochs = 12
|
||||
|
||||
# evaluation
|
||||
eval_config = dict(interval=1, gpu_collect=False)
|
||||
eval_pipelines = [
|
||||
dict(
|
||||
mode='test',
|
||||
evaluators=[
|
||||
dict(type='CocoDetectionEvaluator', classes=CLASSES),
|
||||
dict(type='CocoMaskEvaluator', classes=CLASSES)
|
||||
],
|
||||
)
|
||||
]
|
@ -64,3 +64,10 @@ For detailed usage of benchmark tools, please refer to benchmark [README.md](../
|
||||
| **MAE** | [mae_vit_base_patch16_8xb64_100e_lrdecay075_fintune](../../benchmarks/selfsup/classification/imagenet/mae_vit_base_patch16_8xb64_100e_lrdecay075_fintune.py) | [mae_vit_base_patch16_8xb64_400e](https://github.com/alibaba/EasyCV/tree/master/configs/selfsup/mae/mae_vit_base_patch16_8xb64_400e.py) | 83.13 | [fintune model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-b-400/fintune_400.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-b-400/20220126_171312.log.json)|
|
||||
| | [mae_vit_base_patch16_8xb64_100e_lrdecay065_fintune](../../benchmarks/selfsup/classification/imagenet/mae_vit_base_patch16_8xb64_100e_lrdecay065_fintune.py) | [mae_vit_base_patch16_8xb64_1600e](https://github.com/alibaba/EasyCV/tree/master/configs/selfsup/mae/mae_vit_base_patch16_8xb64_1600e.py) | 83.55 | [fintune model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-b-1600/fintune_1600.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-b-1600/20220426_101532.log.json)|
|
||||
| | [mae_vit_large_patch16_8xb16_50e_lrdecay075_fintune](../../benchmarks/selfsup/classification/imagenet/mae_vit_large_patch16_8xb16_50e_lrdecay075_fintune.py) | [mae_vit_large_patch16_8xb32_1600e](https://github.com/alibaba/EasyCV/tree/master/configs/selfsup/mae/mae_vit_large_patch16_8xb32_1600e.py) | 85.70 | [fintune model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-l-1600/fintune_1600.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-l-1600/20220427_150629.log.json)|
|
||||
|
||||
### COCO2017 Object Detection
|
||||
|
||||
| Algorithm | Eval Config | Pretrained Config | mAP (Box) | mAP (Mask) | Download |
|
||||
| --------- | ------------------------------------------------------------ | ------------------------------------------------------------ | --------- | ---------- | ------------------------------------------------------------ |
|
||||
| SwAV | [mask_rcnn_r50_fpn_1x_coco](https://github.com/alibaba/EasyCV/tree/master/benchmarks/selfsup/detection/coco/mask_rcnn_r50_fpn_1x_coco.py) | [swav_resnet50_8xb32_200e](https://github.com/alibaba/EasyCV/tree/master/configs/selfsup/swav/swav_rn50_8xb32_200e_tfrecord.py) | 40.38 | 36.48 | [eval model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/benchmarks/detection/mask_rcnn_r50_fpn/mocov2_r50/epoch_12.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/benchmarks/detection/mask_rcnn_r50_fpn/mocov2_r50/20220510_164934.log.json) |
|
||||
| MoCo-v2 | [mask_rcnn_r50_fpn_1x_coco](https://github.com/alibaba/EasyCV/tree/master/benchmarks/selfsup/detection/coco/mask_rcnn_r50_fpn_1x_coco.py) | [mocov2_resnet50_8xb32_200e](https://github.com/alibaba/EasyCV/tree/master/configs/selfsup/mocov2/mocov2_rn50_8xb32_200e_tfrecord.py) | 39.9 | 35.8 | [eval model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/benchmarks/detection/mask_rcnn_r50_fpn/swav_r50/epoch_12.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/benchmarks/detection/mask_rcnn_r50_fpn/swav_r50/20220513_142102.log.json) |
|
||||
|
@ -148,9 +148,19 @@ class MMDetWrapper:
|
||||
|
||||
def wrap_module(self, cls, module_type):
|
||||
if module_type == 'model':
|
||||
self._wrap_model_init(cls)
|
||||
self._wrap_model_forward(cls)
|
||||
self._wrap_model_forward_test(cls)
|
||||
|
||||
def _wrap_model_init(self, cls):
|
||||
origin_init = cls.__init__
|
||||
|
||||
def _new_init(self, *args, **kwargs):
|
||||
origin_init(self, *args, **kwargs)
|
||||
self.init_weights()
|
||||
|
||||
setattr(cls, '__init__', _new_init)
|
||||
|
||||
def _wrap_model_forward(self, cls):
|
||||
origin_forward = cls.forward
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user