diff --git a/README.md b/README.md index d87453b..e7dbe74 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,17 @@ Humans can recognize novel objects in this image despite having never seen them
+## Performance on COCO-Split setting + +### VOC to non-VOC cross-category generalization + +We train OLN on COCO VOC categories, and test on non-VOC categories. Note our AR@k evaluation does not count those proposals on the 'seen' classes into the budget (k), to avoid evaluating recall on see-class objects. + +| Method | AUC | AR@10 | AR@30 | AR@100 | AR@300 | AR@1000 | +|:--------------:|:-----:|:-----:|:-----:|:------:|:------:|:-------:| +| OLN-Box | 24.8 | 18.1 | 26.5 | 33.5 | 39.0 | 45.0 | + + ## Disclaimer diff --git a/configs/oln_box/class_agn_faster_rcnn.py b/configs/oln_box/class_agn_faster_rcnn.py new file mode 100644 index 0000000..0f8a2a9 --- /dev/null +++ b/configs/oln_box/class_agn_faster_rcnn.py @@ -0,0 +1,243 @@ +model = dict( + type='FasterRCNN', + pretrained='torchvision://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[1.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0.0, 0.0, 0.0, 0.0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0.0, 0.0, 0.0, 0.0], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=False, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_across_levels=False, + nms_pre=2000, + nms_post=2000, + max_num=2000, + nms_thr=0.7, + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_across_levels=False, + nms_pre=2000, + nms_post=2000, + max_num=2000, + nms_thr=0.7, + min_bbox_size=0), + rcnn=dict( + score_thr=0.0, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=1500, + ))) +dataset_type = 'CocoSplitDataset' +data_root = 'data/coco/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict( + type='Normalize', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + to_rgb=True), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1333, 800), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict( + type='Normalize', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + to_rgb=True), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + train=dict( + type='CocoSplitDataset', + ann_file='data/coco/annotations/instances_train2017.json', + img_prefix='data/coco/train2017/', + pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict( + type='Normalize', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + to_rgb=True), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) + ], + is_class_agnostic=True, + train_class='voc', + eval_class='nonvoc'), + val=dict( + type='CocoSplitDataset', + ann_file='data/coco/annotations/instances_val2017.json', + img_prefix='data/coco/val2017/', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1333, 800), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict( + type='Normalize', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + to_rgb=True), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ], + is_class_agnostic=True, + train_class='voc', + eval_class='nonvoc'), + test=dict( + type='CocoSplitDataset', + ann_file='data/coco/annotations/instances_val2017.json', + img_prefix='data/coco/val2017/', + pipeline=[ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1333, 800), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict( + type='Normalize', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + to_rgb=True), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) + ], + is_class_agnostic=True, + train_class='voc', + eval_class='nonvoc')) +evaluation = dict(interval=1, metric='bbox') +optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=None) +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=0.001, + step=[6, 7]) +total_epochs = 8 +checkpoint_config = dict(interval=2) +log_config = dict( + interval=10, + hooks=[dict(type='TextLoggerHook'), + dict(type='TensorboardLoggerHook')]) +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] +work_dir = './work_dirs/frcnn/' +gpu_ids = range(0, 8) diff --git a/mmdet/models/dense_heads/__pycache__/oln_rpn_head.cpython-37.pyc b/mmdet/models/dense_heads/__pycache__/oln_rpn_head.cpython-37.pyc index 7e16317..f025952 100644 Binary files a/mmdet/models/dense_heads/__pycache__/oln_rpn_head.cpython-37.pyc and b/mmdet/models/dense_heads/__pycache__/oln_rpn_head.cpython-37.pyc differ