mmdeploy/tests/test_codebase/test_mmrotate/test_mmrotate_models.py

204 lines
6.6 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import random
from typing import Dict, List
import mmcv
import numpy as np
import pytest
import torch
from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Backend, Codebase
from mmdeploy.utils.config_utils import get_ir_config
from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs,
get_rewrite_outputs)
# Import the MMRotate codebase up front. When it is not installed the import
# raises ImportError, and the whole module is skipped (allow_module_level=True)
# rather than erroring out during collection.
try:
    import_codebase(Codebase.MMROTATE)
except ImportError:
    pytest.skip(
        f'{Codebase.MMROTATE} is not installed.', allow_module_level=True)
def seed_everything(seed=1029):
    """Seed every random number generator used by these tests.

    Seeds Python's ``random``, NumPy and PyTorch (CPU) and exports the seed
    through the ``PYTHONHASHSEED`` environment variable. When CUDA is
    available, the CUDA generators are seeded as well and the cudnn backend
    flags are set (benchmark off, deterministic on, cudnn disabled).

    Args:
        seed (int): Seed value applied to all generators. Defaults to 1029.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    # Covers the multi-GPU case.
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    cudnn.enabled = False
def convert_to_list(rewrite_output: Dict, output_names: List[str]) -> List:
    """Pick the requested outputs out of a name -> value mapping.

    Args:
        rewrite_output (Dict): Mapping from output name to output value.
        output_names (List[str]): Names of the outputs to keep.

    Returns:
        List: Values whose names appear in ``output_names``, in the iteration
        order of ``rewrite_output``.
    """
    selected = []
    for key, value in rewrite_output.items():
        if key in output_names:
            selected.append(value)
    return selected
def get_anchor_head_model():
    """Create a ``RotatedAnchorHead`` with gradients disabled.

    The head is built with 4 classes, a single input channel and a test
    config carrying the NMS / score-threshold settings below.

    Returns:
        RotatedAnchorHead: The head, after ``requires_grad_(False)``.
    """
    head_test_cfg = mmcv.Config({
        'nms_pre': 2000,
        'min_bbox_size': 0,
        'score_thr': 0.05,
        'nms': {
            'iou_thr': 0.1
        },
        'max_per_img': 2000,
    })

    # Imported lazily, inside the function, as in the rest of this module.
    from mmrotate.models.dense_heads import RotatedAnchorHead

    head = RotatedAnchorHead(
        num_classes=4, in_channels=1, test_cfg=head_test_cfg)
    head.requires_grad_(False)
    return head
def _replace_r50_with_r18(model):
    """Return a copy of a model config with a shrunken ResNet-18 backbone.

    The input is deep-copied and never modified. If the copy's backbone type
    is ``'ResNet'``, its depth is set to 18, its base channels to 2 and the
    neck input channels to ``[2, 4, 8, 16]``; any other backbone type is
    returned as an unchanged copy.
    """
    patched = copy.deepcopy(model)
    if patched.backbone.type != 'ResNet':
        return patched

    patched.backbone.depth = 18
    patched.backbone.base_channels = 2
    patched.neck.in_channels = [2, 4, 8, 16]
    return patched
# @pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
# @pytest.mark.parametrize('model_cfg_path', [
# 'tests/test_codebase/test_mmrotate/data/single_stage_model.json'
# ])
# def test_forward_of_base_detector(model_cfg_path, backend):
# check_backend(backend)
# deploy_cfg = mmcv.Config(
# dict(
# backend_config=dict(type=backend.value),
# onnx_config=dict(
# output_names=['dets', 'labels'], input_shape=None),
# codebase_config=dict(
# type='mmrotate',
# task='RotatedDetection',
# post_processing=dict(
# score_threshold=0.05,
# iou_threshold=0.5,
# pre_top_k=-1,
# keep_top_k=100,
# ))))
# model_cfg = mmcv.Config(dict(model=mmcv.load(model_cfg_path)))
# model_cfg.model = _replace_r50_with_r18(model_cfg.model)
# from mmrotate.models import build_detector
# model_cfg.model.pretrained = None
# model_cfg.model.train_cfg = None
# model = build_detector(
# model_cfg.model, test_cfg= model_cfg.get('test_cfg'))
# model.cfg = model_cfg
# model.to('cpu')
# img = torch.randn(1, 3, 64, 64)
# rewrite_inputs = {'img': img}
# rewrite_outputs, _ = get_rewrite_outputs(
# wrapped_model=model,
# model_inputs=rewrite_inputs,
# deploy_cfg=deploy_cfg)
# assert rewrite_outputs is not None
def get_deploy_cfg(backend_type: Backend, ir_type: str):
    """Assemble the deploy config used by the rewrite tests.

    Args:
        backend_type (Backend): Backend whose ``value`` is stored as the
            backend config type.
        ir_type (str): Stored as the ``type`` of the ``onnx_config`` section.

    Returns:
        mmcv.Config: Config with ``backend_config``, ``onnx_config`` (outputs
        named ``dets`` and ``labels``) and an mmrotate ``codebase_config``
        holding the post-processing thresholds.
    """
    post_processing = {
        'score_threshold': 0.05,
        'iou_threshold': 0.1,
        'pre_top_k': 2000,
        'keep_top_k': 2000,
    }
    codebase_config = {
        'type': 'mmrotate',
        'task': 'RotatedDetection',
        'post_processing': post_processing,
    }
    onnx_config = {
        'type': ir_type,
        'output_names': ['dets', 'labels'],
        'input_shape': None,
    }
    backend_config = {'type': backend_type.value}
    return mmcv.Config({
        'backend_config': backend_config,
        'onnx_config': onnx_config,
        'codebase_config': codebase_config,
    })
@pytest.mark.parametrize('backend_type, ir_type',
                         [(Backend.ONNXRUNTIME, 'onnx')])
def test_base_dense_head_get_bboxes(backend_type: Backend, ir_type: str):
    """Compare the rewritten ``get_bboxes`` of the anchor head with PyTorch.

    The head's ``get_bboxes`` is run once on the plain PyTorch model and once
    through the rewrite path; when the rewrite output comes from the backend,
    the leading rows of each output pair must agree within tolerance.
    """
    check_backend(backend_type)
    anchor_head = get_anchor_head_model()
    anchor_head.cpu().eval()

    img_size = 128
    img_metas = [
        dict(
            scale_factor=np.ones(4),
            pad_shape=(img_size, img_size, 3),
            img_shape=(img_size, img_size, 3))
    ]

    deploy_cfg = get_deploy_cfg(backend_type, ir_type)
    output_names = get_ir_config(deploy_cfg).get('output_names', None)

    # Five square feature maps with sides 32, 16, 8, 4 and 2.
    # cls_score shapes: (1, 36, 32, 32), (1, 36, 16, 16), (1, 36, 8, 8),
    #                   (1, 36, 4, 4), (1, 36, 2, 2).
    # bbox_pred shapes: (1, 45, 32, 32), (1, 45, 16, 16), (1, 45, 8, 8),
    #                   (1, 45, 4, 4), (1, 45, 2, 2).
    feat_sizes = [2**level for level in range(5, 0, -1)]
    seed_everything(1234)
    cls_score = [torch.rand(1, 36, side, side) for side in feat_sizes]
    seed_everything(5678)
    bboxes = [torch.rand(1, 45, side, side) for side in feat_sizes]

    # Reference outputs from the unmodified PyTorch model.
    model_outputs = get_model_outputs(
        anchor_head, 'get_bboxes',
        dict(cls_scores=cls_score, bbox_preds=bboxes, img_metas=img_metas))

    # Outputs after the rewrite: img_shape is replaced by a tensor and
    # img_metas is bound into the wrapper rather than passed as an input.
    img_metas[0]['img_shape'] = torch.Tensor([img_size, img_size])
    wrapped_model = WrapModel(
        anchor_head, 'get_bboxes', img_metas=img_metas, with_nms=True)
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=dict(cls_scores=cls_score, bbox_preds=bboxes),
        deploy_cfg=deploy_cfg)

    if not is_backend_output:
        assert rewrite_outputs is not None
        return

    if isinstance(rewrite_outputs, dict):
        rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
    for expected, actual in zip(model_outputs[0], rewrite_outputs):
        expected = expected.squeeze().cpu().numpy()
        actual = actual.squeeze()
        # The rewrite and the original code apply different NMS strategies,
        # so the reference is truncated to the rewrite's length and only the
        # first two entries are compared.
        num_kept = actual.shape[0]
        assert np.allclose(
            expected[:num_kept][:2], actual[:2], rtol=1e-03, atol=1e-05)