# Source: mmpretrain/tests/test_mmdet_inference.py
# (file-viewer metadata: 61 lines, 2.0 KiB, Python)

from mmdet.models import build_detector
from mmcls.models import (MobileNetV2, MobileNetV3, RegNet, ResNeSt, ResNet,
ResNeXt, SEResNet, SEResNeXt, SwinTransformer)
# Backbone settings to swap into the RetinaNet test config, keyed by a short
# name that is also a key of ``module_mapping``. The ``mmcls.`` prefix on each
# ``type`` presumably scopes the registry lookup to mmcls (assumption based on
# the test asserting mmcls classes; confirm against the registry docs).
backbone_configs = {
    'mobilenetv2': {
        'backbone': {
            'type': 'mmcls.MobileNetV2',
            'widen_factor': 1.0,
            'norm_cfg': {
                'type': 'GN',
                'num_groups': 2,
                'requires_grad': True,
            },
            'out_indices': (4, 7),
        },
    },
    'mobilenetv3': {
        'backbone': {
            'type': 'mmcls.MobileNetV3',
            'norm_cfg': {
                'type': 'GN',
                'num_groups': 2,
                'requires_grad': True,
            },
            'out_indices': range(7, 12),
        },
    },
    'regnet': {
        'backbone': {'type': 'mmcls.RegNet', 'arch': 'regnetx_400mf'},
    },
    'resnext': {
        'backbone': {
            'type': 'mmcls.ResNeXt',
            'depth': 50,
            'groups': 32,
            'width_per_group': 4,
        },
    },
    'resnet': {
        'backbone': {'type': 'mmcls.ResNet', 'depth': 50},
    },
    'seresnet': {
        'backbone': {'type': 'mmcls.SEResNet', 'depth': 50},
    },
    'seresnext': {
        'backbone': {
            'type': 'mmcls.SEResNeXt',
            'depth': 50,
            'groups': 32,
            'width_per_group': 4,
        },
    },
    'resnest': {
        'backbone': {
            'type': 'mmcls.ResNeSt',
            'depth': 50,
            'radix': 2,
            'reduction_factor': 4,
            'out_indices': (0, 1, 2, 3),
        },
    },
    'swin': {
        'backbone': {
            'type': 'mmcls.SwinTransformer',
            'arch': 'small',
            'drop_path_rate': 0.2,
        },
    },
}
# Expected mmcls backbone class for each entry in ``backbone_configs``; the
# test checks that the built detector's ``backbone`` is an instance of it.
module_mapping = dict(
    mobilenetv2=MobileNetV2,
    mobilenetv3=MobileNetV3,
    regnet=RegNet,
    resnext=ResNeXt,
    resnet=ResNet,
    seresnext=SEResNeXt,
    seresnet=SEResNet,
    resnest=ResNeSt,
    swin=SwinTransformer,
)
def test_mmdet_inference():
    """Check that each mmcls backbone can be built inside an mmdet detector.

    Loads the RetinaNet test config once. For every entry in
    ``backbone_configs``, the test replaces the config's ``model.backbone``,
    builds the detector with ``build_detector`` and asserts that
    ``model.backbone`` is an instance of the class listed in
    ``module_mapping`` under the same name.
    """
    from mmcv import Config

    # NOTE(review): this path is relative to the current working directory,
    # so the test presumably has to be launched from the repository root.
    config_path = './tests/data/retinanet.py'
    config = Config.fromfile(config_path)
    for module_name, backbone_config in backbone_configs.items():
        # The same Config object is reused; only its backbone entry changes
        # on each iteration.
        config.model.backbone = backbone_config['backbone']
        model = build_detector(config.model)
        expected_cls = module_mapping[module_name]
        # Name the failing backbone so one bad entry is easy to identify
        # among the many checked in this single test.
        assert isinstance(model.backbone, expected_cls), (
            f'{module_name}: expected {expected_cls.__name__}, '
            f'got {type(model.backbone).__name__}')