diff --git a/requirements/tests.txt b/requirements/tests.txt index 4babae51..29d351b5 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -2,6 +2,7 @@ codecov flake8 interrogate isort==4.3.21 +mmdet pytest xdoctest >= 0.10.0 yapf diff --git a/setup.cfg b/setup.cfg index 40ae0ed9..28b3b92e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,6 +14,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = pkg_resources,setuptools known_first_party = mmcls -known_third_party = PIL,cv2,matplotlib,mmcv,numpy,onnxruntime,pytest,seaborn,torch,torchvision,ts +known_third_party = PIL,cv2,matplotlib,mmcv,mmdet,numpy,onnxruntime,pytest,seaborn,torch,torchvision,ts no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/tests/data/retinanet.py b/tests/data/retinanet.py new file mode 100644 index 00000000..56e43fa7 --- /dev/null +++ b/tests/data/retinanet.py @@ -0,0 +1,60 @@ +# model settings +model = dict( + type='RetinaNet', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_input', + num_outs=5), + bbox_head=dict( + type='RetinaHead', + num_classes=80, + in_channels=256, + stacked_convs=4, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + octave_base_scale=4, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + # model training and testing settings + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0, + ignore_iof_thr=-1), + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100)) diff --git a/tests/test_mmdet_inference.py b/tests/test_mmdet_inference.py new file mode 100644 index 00000000..152f7893 --- /dev/null +++ b/tests/test_mmdet_inference.py @@ -0,0 +1,60 @@ +from mmdet.models import build_detector + +from mmcls.models import (MobileNetV2, MobileNetV3, RegNet, ResNeSt, ResNet, + ResNeXt, SEResNet, SEResNeXt, SwinTransformer) + +backbone_configs = dict( + mobilenetv2=dict( + backbone=dict( + type='mmcls.MobileNetV2', + widen_factor=1.0, + norm_cfg=dict(type='GN', num_groups=2, requires_grad=True), + out_indices=(4, 7))), + mobilenetv3=dict( + backbone=dict( + type='mmcls.MobileNetV3', + norm_cfg=dict(type='GN', num_groups=2, requires_grad=True), + out_indices=range(7, 12))), + regnet=dict(backbone=dict(type='mmcls.RegNet', arch='regnetx_400mf')), + resnext=dict( + backbone=dict( + type='mmcls.ResNeXt', depth=50, groups=32, width_per_group=4)), + resnet=dict(backbone=dict(type='mmcls.ResNet', depth=50)), + seresnet=dict(backbone=dict(type='mmcls.SEResNet', depth=50)), + seresnext=dict( + backbone=dict( + type='mmcls.SEResNeXt', depth=50, groups=32, width_per_group=4)), + resnest=dict( + backbone=dict( + type='mmcls.ResNeSt', + depth=50, + radix=2, + reduction_factor=4, + out_indices=(0, 1, 2, 3))), + swin=dict( + backbone=dict( + type='mmcls.SwinTransformer', arch='small', drop_path_rate=0.2))) + +module_mapping = { + 'mobilenetv2': MobileNetV2, + 'mobilenetv3': MobileNetV3, + 'regnet': RegNet, + 'resnext': ResNeXt, + 'resnet': ResNet, + 'seresnext': SEResNeXt, + 'seresnet': SEResNet, + 'resnest': ResNeSt, + 'swin': SwinTransformer +} + + +def test_mmdet_inference(): + from mmcv import Config + config_path = './tests/data/retinanet.py' + config = Config.fromfile(config_path) + + for module_name, backbone_config in backbone_configs.items(): + config.model.backbone = backbone_config['backbone'] + model = build_detector(config.model) + module = module_mapping[module_name] + assert isinstance(model.backbone, module)