61 lines
2.0 KiB
Python
61 lines
2.0 KiB
Python
|
from mmdet.models import build_detector
|
||
|
|
||
|
from mmcls.models import (MobileNetV2, MobileNetV3, RegNet, ResNeSt, ResNet,
|
||
|
ResNeXt, SEResNet, SEResNeXt, SwinTransformer)
|
||
|
|
||
|
# Backbone settings swapped into the RetinaNet test config, keyed by a short
# name that is also looked up in ``module_mapping``. Each value has the shape
# ``dict(backbone=<backbone cfg>)``.
# NOTE(review): the ``'mmcls.'`` prefix on ``type`` presumably makes the
# mmdet registry build the class from mmcls's registry scope — confirm
# against the installed mmcv version.
backbone_configs = dict(
    mobilenetv2=dict(
        backbone=dict(
            type='mmcls.MobileNetV2',
            widen_factor=1.0,
            # Norm layer overridden to GroupNorm with 2 groups.
            # NOTE(review): presumably to avoid depending on batch
            # statistics in the test — confirm.
            norm_cfg=dict(type='GN', num_groups=2, requires_grad=True),
            out_indices=(4, 7))),
    mobilenetv3=dict(
        backbone=dict(
            type='mmcls.MobileNetV3',
            norm_cfg=dict(type='GN', num_groups=2, requires_grad=True),
            # Indices 7 through 11 inclusive (formerly ``range(7, 12)``).
            # A tuple matches the other entries and, unlike ``range``, is
            # JSON-serializable if the config is ever dumped.
            out_indices=(7, 8, 9, 10, 11))),
    regnet=dict(backbone=dict(type='mmcls.RegNet', arch='regnetx_400mf')),
    resnext=dict(
        backbone=dict(
            type='mmcls.ResNeXt', depth=50, groups=32, width_per_group=4)),
    resnet=dict(backbone=dict(type='mmcls.ResNet', depth=50)),
    seresnet=dict(backbone=dict(type='mmcls.SEResNet', depth=50)),
    seresnext=dict(
        backbone=dict(
            type='mmcls.SEResNeXt', depth=50, groups=32, width_per_group=4)),
    resnest=dict(
        backbone=dict(
            type='mmcls.ResNeSt',
            depth=50,
            radix=2,
            reduction_factor=4,
            out_indices=(0, 1, 2, 3))),
    swin=dict(
        backbone=dict(
            type='mmcls.SwinTransformer', arch='small', drop_path_rate=0.2)))
# Expected backbone class for each key of ``backbone_configs``; the detector
# built from that entry must hold a ``backbone`` that is an instance of it.
module_mapping = dict(
    mobilenetv2=MobileNetV2,
    mobilenetv3=MobileNetV3,
    regnet=RegNet,
    resnext=ResNeXt,
    resnet=ResNet,
    seresnext=SEResNeXt,
    seresnet=SEResNet,
    resnest=ResNeSt,
    swin=SwinTransformer,
)
def test_mmdet_inference():
    """Build a RetinaNet detector around each mmcls backbone.

    For every entry in ``backbone_configs``, the backbone section of the
    RetinaNet test config is replaced, a detector is built with
    ``build_detector``, and ``model.backbone`` is checked to be an instance
    of the class listed for that entry in ``module_mapping``.
    """
    import copy

    from mmcv import Config

    # NOTE(review): this path is relative to the current working directory,
    # so it presumably expects pytest to run from the repository root.
    config_path = './tests/data/retinanet.py'
    config = Config.fromfile(config_path)

    # Both tables must cover the same backbones. A name added to only one of
    # them would otherwise go untested, or fail with a bare KeyError.
    config_names = set(backbone_configs)
    mapped_names = set(module_mapping)
    assert config_names == mapped_names, (
        'backbone_configs and module_mapping cover different backbones: '
        f'{sorted(config_names ^ mapped_names)}')

    for module_name, backbone_config in backbone_configs.items():
        # Hand the builder a private copy, in case building the detector
        # mutates the backbone dict, which is shared module-level state.
        config.model.backbone = copy.deepcopy(backbone_config['backbone'])
        model = build_detector(config.model)
        module = module_mapping[module_name]
        assert isinstance(model.backbone, module), (
            f'{module_name}: expected {module.__name__}, '
            f'got {type(model.backbone).__name__}')