diff --git a/benchmarks/selfsup/detection/coco/mask_rcnn_r50_fpn_1x_coco.py b/benchmarks/selfsup/detection/coco/mask_rcnn_r50_fpn_1x_coco.py
new file mode 100644
index 00000000..ffd11de7
--- /dev/null
+++ b/benchmarks/selfsup/detection/coco/mask_rcnn_r50_fpn_1x_coco.py
@@ -0,0 +1,264 @@
+_base_ = ['configs/base.py']  # NOTE(review): presumably pulls shared defaults from the base config - confirm
+
+CLASSES = [  # category names, reused below by both data sources and both evaluators
+    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
+    'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
+    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
+    'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
+    'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
+    'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
+    'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
+    'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
+    'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
+    'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
+    'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
+    'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
+    'hair drier', 'toothbrush'
+]
+
+norm_cfg = dict(type='SyncBN', requires_grad=True)  # shared by backbone, neck, bbox head and mask head
+# model settings
+model = dict(
+    type='MaskRCNN',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(1, 2, 3, 4),  # the disabled mmdet variant below uses (0, 1, 2, 3) instead
+        frozen_stages=-1,
+        norm_cfg=norm_cfg,
+        norm_eval=False),
+    # mmdet ResNet (alternative backbone, left disabled)
+    # backbone=dict(
+    #     type='ResNet',
+    #     depth=50,
+    #     num_stages=4,
+    #     out_indices=(0, 1, 2, 3),
+    #     frozen_stages=-1,
+    #     norm_cfg=norm_cfg,
+    #     norm_eval=False,
+    #     style='pytorch',
+    #     init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
+    neck=dict(
+        type='FPN',
+        norm_cfg=norm_cfg,
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_generator=dict(
+            type='AnchorGenerator',
+            scales=[8],
+            ratios=[0.5, 1.0, 2.0],
+            strides=[4, 8, 16, 32, 64]),
+        bbox_coder=dict(
+            type='DeltaXYWHBBoxCoder',
+            target_means=[.0, .0, .0, .0],
+            target_stds=[1.0, 1.0, 1.0, 1.0]),
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
+    roi_head=dict(
+        type='StandardRoIHead',
+        bbox_roi_extractor=dict(
+            type='SingleRoIExtractor',
+            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32]),
+        bbox_head=dict(
+            type='Shared4Conv1FCBBoxHead',
+            norm_cfg=norm_cfg,
+            in_channels=256,
+            fc_out_channels=1024,
+            roi_feat_size=7,
+            num_classes=80,  # NOTE(review): expected to equal len(CLASSES) - confirm
+            bbox_coder=dict(
+                type='DeltaXYWHBBoxCoder',
+                target_means=[0., 0., 0., 0.],
+                target_stds=[0.1, 0.1, 0.2, 0.2]),
+            reg_class_agnostic=False,
+            loss_cls=dict(
+                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
+        mask_roi_extractor=dict(
+            type='SingleRoIExtractor',
+            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32]),
+        mask_head=dict(
+            type='FCNMaskHead',
+            norm_cfg=norm_cfg,
+            num_convs=4,
+            in_channels=256,
+            conv_out_channels=256,
+            num_classes=80,  # same value as bbox_head.num_classes above
+            loss_mask=dict(
+                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
+    # model training and testing settings
+    train_cfg=dict(
+        rpn=dict(
+            assigner=dict(
+                type='MaxIoUAssigner',
+                pos_iou_thr=0.7,
+                neg_iou_thr=0.3,
+                min_pos_iou=0.3,
+                match_low_quality=True,
+                ignore_iof_thr=-1),
+            sampler=dict(
+                type='RandomSampler',
+                num=256,
+                pos_fraction=0.5,
+                neg_pos_ub=-1,
+                add_gt_as_proposals=False),
+            allowed_border=-1,
+            pos_weight=-1,
+            debug=False),
+        rpn_proposal=dict(
+            nms_pre=2000,
+            max_per_img=1000,
+            nms=dict(type='nms', iou_threshold=0.7),
+            min_bbox_size=0),
+        rcnn=dict(
+            assigner=dict(
+                type='MaxIoUAssigner',
+                pos_iou_thr=0.5,
+                neg_iou_thr=0.5,
+                min_pos_iou=0.5,
+                match_low_quality=True,
+                ignore_iof_thr=-1),
+            sampler=dict(
+                type='RandomSampler',
+                num=512,
+                pos_fraction=0.25,
+                neg_pos_ub=-1,
+                add_gt_as_proposals=True),
+            mask_size=28,
+            pos_weight=-1,
+            debug=False)),
+    test_cfg=dict(
+        rpn=dict(
+            nms_pre=1000,
+            max_per_img=1000,
+            nms=dict(type='nms', iou_threshold=0.7),
+            min_bbox_size=0),
+        rcnn=dict(
+            score_thr=0.05,
+            nms=dict(type='nms', iou_threshold=0.5),
+            max_per_img=100,
+            mask_thr_binary=0.5)))
+
+mmlab_modules = [
+    dict(type='mmdet', name='MaskRCNN', module='model'),
+    # dict(type=MMDET, name='ResNet', module='backbone'), # comment out, use EasyCV ResNet
+    dict(type='mmdet', name='FPN', module='neck'),
+    dict(type='mmdet', name='RPNHead', module='head'),
+    dict(type='mmdet', name='StandardRoIHead', module='head'),
+]
+
+# dataset settings
+data_root = 'data/coco/'  # prefix for the annotation files and image folders below
+img_norm_cfg = dict(  # unpacked into MMNormalize in both pipelines
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+train_pipeline = [
+    dict(
+        type='MMResize',
+        img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736),
+                   (1333, 768), (1333, 800)],
+        multiscale_mode='value',  # NOTE(review): presumably picks one of the listed scales per sample - confirm
+        keep_ratio=True),
+    dict(type='MMRandomFlip', flip_ratio=0.5),
+    dict(type='MMNormalize', **img_norm_cfg),
+    dict(type='MMPad', size_divisor=32),
+    dict(type='DefaultFormatBundle'),
+    dict(
+        type='Collect',
+        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks'],
+        meta_keys=('filename', 'ori_filename', 'ori_shape', 'ori_img_shape',
+                   'img_shape', 'pad_shape', 'scale_factor', 'flip',
+                   'flip_direction', 'img_norm_cfg'))
+]
+
+test_pipeline = [
+    dict(
+        type='MMMultiScaleFlipAug',
+        img_scale=(1333, 800),
+        flip=False,
+        transforms=[
+            dict(type='MMResize', keep_ratio=True),
+            dict(type='MMRandomFlip'),
+            dict(type='MMNormalize', **img_norm_cfg),
+            dict(type='MMPad', size_divisor=32),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(
+                type='Collect',
+                keys=['img'],
+                meta_keys=('filename', 'ori_filename', 'ori_shape',
+                           'ori_img_shape', 'img_shape', 'pad_shape',
+                           'scale_factor', 'flip', 'flip_direction',
+                           'img_norm_cfg')),
+        ])
+]
+
+train_dataset = dict(
+    type='DetDataset',
+    data_source=dict(
+        type='DetSourceCoco',
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        pipeline=[
+            dict(type='LoadImageFromFile', to_float32=True),
+            dict(type='LoadAnnotations', with_bbox=True, with_mask=True)
+        ],
+        classes=CLASSES,
+        filter_empty_gt=True,
+        iscrowd=False,
+    ),
+    pipeline=train_pipeline)
+
+val_dataset = dict(
+    type='DetDataset',
+    imgs_per_gpu=1,
+    data_source=dict(
+        type='DetSourceCoco',
+        ann_file=data_root + 'annotations/instances_val2017.json',
+        img_prefix=data_root + 'val2017/',
+        pipeline=[
+            dict(type='LoadImageFromFile', to_float32=True),
+            dict(type='LoadAnnotations', with_bbox=True)  # NOTE(review): no with_mask here although CocoMaskEvaluator is configured below - confirm intended
+        ],
+        classes=CLASSES,
+        test_mode=True,
+        iscrowd=True),
+    pipeline=test_pipeline)
+
+data = dict(
+    imgs_per_gpu=2, workers_per_gpu=2, train=train_dataset, val=val_dataset)
+
+checkpoint_config = dict(interval=1)
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+    policy='step',
+    warmup='linear',
+    warmup_iters=1000,
+    warmup_ratio=0.001,
+    step=[8, 11])  # NOTE(review): presumably epoch indices, given total_epochs = 12 - confirm
+total_epochs = 12
+
+# evaluation
+eval_config = dict(interval=1, gpu_collect=False)
+eval_pipelines = [
+    dict(
+        mode='test',
+        evaluators=[
+            dict(type='CocoDetectionEvaluator', classes=CLASSES),
+            dict(type='CocoMaskEvaluator', classes=CLASSES)
+        ],
+    )
+]
diff --git a/docs/source/model_zoo_ssl.md b/docs/source/model_zoo_ssl.md
index da800b48..0c0f0d7a 100644
--- a/docs/source/model_zoo_ssl.md
+++ b/docs/source/model_zoo_ssl.md
@@ -64,3 +64,10 @@ For detailed usage of benchmark tools, please refer to benchmark [README.md](../
 | **MAE** | [mae_vit_base_patch16_8xb64_100e_lrdecay075_fintune](../../benchmarks/selfsup/classification/imagenet/mae_vit_base_patch16_8xb64_100e_lrdecay075_fintune.py) 
| [mae_vit_base_patch16_8xb64_400e](https://github.com/alibaba/EasyCV/tree/master/configs/selfsup/mae/mae_vit_base_patch16_8xb64_400e.py) | 83.13 | [fintune model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-b-400/fintune_400.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-b-400/20220126_171312.log.json)| | | [mae_vit_base_patch16_8xb64_100e_lrdecay065_fintune](../../benchmarks/selfsup/classification/imagenet/mae_vit_base_patch16_8xb64_100e_lrdecay065_fintune.py) | [mae_vit_base_patch16_8xb64_1600e](https://github.com/alibaba/EasyCV/tree/master/configs/selfsup/mae/mae_vit_base_patch16_8xb64_1600e.py) | 83.55 | [fintune model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-b-1600/fintune_1600.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-b-1600/20220426_101532.log.json)| | | [mae_vit_large_patch16_8xb16_50e_lrdecay075_fintune](../../benchmarks/selfsup/classification/imagenet/mae_vit_large_patch16_8xb16_50e_lrdecay075_fintune.py) | [mae_vit_large_patch16_8xb32_1600e](https://github.com/alibaba/EasyCV/tree/master/configs/selfsup/mae/mae_vit_large_patch16_8xb32_1600e.py) | 85.70 | [fintune model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-l-1600/fintune_1600.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/mae/vit-l-1600/20220427_150629.log.json)| + +### COCO2017 Object Detection + +| Algorithm | Eval Config | Pretrained Config | mAP (Box) | mAP (Mask) | Download | +| --------- | ------------------------------------------------------------ | ------------------------------------------------------------ | --------- | ---------- | ------------------------------------------------------------ | +| SwAV | 
[mask_rcnn_r50_fpn_1x_coco](https://github.com/alibaba/EasyCV/tree/master/benchmarks/selfsup/detection/coco/mask_rcnn_r50_fpn_1x_coco.py) | [swav_resnet50_8xb32_200e](https://github.com/alibaba/EasyCV/tree/master/configs/selfsup/swav/swav_rn50_8xb32_200e_tfrecord.py) | 40.38 | 36.48 | [eval model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/benchmarks/detection/mask_rcnn_r50_fpn/swav_r50/epoch_12.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/benchmarks/detection/mask_rcnn_r50_fpn/swav_r50/20220513_142102.log.json) | +| MoCo-v2 | [mask_rcnn_r50_fpn_1x_coco](https://github.com/alibaba/EasyCV/tree/master/benchmarks/selfsup/detection/coco/mask_rcnn_r50_fpn_1x_coco.py) | [mocov2_resnet50_8xb32_200e](https://github.com/alibaba/EasyCV/tree/master/configs/selfsup/mocov2/mocov2_rn50_8xb32_200e_tfrecord.py) | 39.9 | 35.8 | [eval model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/benchmarks/detection/mask_rcnn_r50_fpn/mocov2_r50/epoch_12.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/selfsup/benchmarks/detection/mask_rcnn_r50_fpn/mocov2_r50/20220510_164934.log.json) | diff --git a/easycv/utils/mmlab_utils.py b/easycv/utils/mmlab_utils.py index 9d5ad9f4..d0995f0b 100644 --- a/easycv/utils/mmlab_utils.py +++ b/easycv/utils/mmlab_utils.py @@ -148,9 +148,19 @@ class MMDetWrapper: def wrap_module(self, cls, module_type): if module_type == 'model': + self._wrap_model_init(cls) self._wrap_model_forward(cls) self._wrap_model_forward_test(cls) + def _wrap_model_init(self, cls): + origin_init = cls.__init__ + + def _new_init(self, *args, **kwargs): + origin_init(self, *args, **kwargs) + self.init_weights() + + setattr(cls, '__init__', _new_init) + def _wrap_model_forward(self, cls): origin_forward = cls.forward