204 lines
6.6 KiB
Python
204 lines
6.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import copy
|
|
import os
|
|
import random
|
|
from typing import Dict, List
|
|
|
|
import mmcv
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
|
|
from mmdeploy.codebase import import_codebase
|
|
from mmdeploy.utils import Backend, Codebase
|
|
from mmdeploy.utils.config_utils import get_ir_config
|
|
from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs,
|
|
get_rewrite_outputs)
|
|
|
|
try:
|
|
import_codebase(Codebase.MMROTATE)
|
|
except ImportError:
|
|
pytest.skip(
|
|
f'{Codebase.MMROTATE} is not installed.', allow_module_level=True)
|
|
|
|
|
|
def seed_everything(seed=1029):
|
|
random.seed(seed)
|
|
os.environ['PYTHONHASHSEED'] = str(seed)
|
|
np.random.seed(seed)
|
|
torch.manual_seed(seed)
|
|
if torch.cuda.is_available():
|
|
torch.cuda.manual_seed(seed)
|
|
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
|
|
torch.backends.cudnn.benchmark = False
|
|
torch.backends.cudnn.deterministic = True
|
|
torch.backends.cudnn.enabled = False
|
|
|
|
|
|
def convert_to_list(rewrite_output: Dict, output_names: List[str]) -> List:
|
|
"""Converts output from a dictionary to a list.
|
|
|
|
The new list will contain only those output values, whose names are in list
|
|
'output_names'.
|
|
"""
|
|
outputs = [
|
|
value for name, value in rewrite_output.items() if name in output_names
|
|
]
|
|
return outputs
|
|
|
|
|
|
def get_anchor_head_model():
|
|
"""AnchorHead Config."""
|
|
test_cfg = mmcv.Config(
|
|
dict(
|
|
nms_pre=2000,
|
|
min_bbox_size=0,
|
|
score_thr=0.05,
|
|
nms=dict(iou_thr=0.1),
|
|
max_per_img=2000))
|
|
|
|
from mmrotate.models.dense_heads import RotatedAnchorHead
|
|
model = RotatedAnchorHead(num_classes=4, in_channels=1, test_cfg=test_cfg)
|
|
model.requires_grad_(False)
|
|
|
|
return model
|
|
|
|
|
|
def _replace_r50_with_r18(model):
|
|
"""Replace ResNet50 with ResNet18 in config."""
|
|
model = copy.deepcopy(model)
|
|
if model.backbone.type == 'ResNet':
|
|
model.backbone.depth = 18
|
|
model.backbone.base_channels = 2
|
|
model.neck.in_channels = [2, 4, 8, 16]
|
|
return model
|
|
|
|
|
|
# @pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
|
|
# @pytest.mark.parametrize('model_cfg_path', [
|
|
# 'tests/test_codebase/test_mmrotate/data/single_stage_model.json'
|
|
# ])
|
|
# def test_forward_of_base_detector(model_cfg_path, backend):
|
|
# check_backend(backend)
|
|
# deploy_cfg = mmcv.Config(
|
|
# dict(
|
|
# backend_config=dict(type=backend.value),
|
|
# onnx_config=dict(
|
|
# output_names=['dets', 'labels'], input_shape=None),
|
|
# codebase_config=dict(
|
|
# type='mmrotate',
|
|
# task='RotatedDetection',
|
|
# post_processing=dict(
|
|
# score_threshold=0.05,
|
|
# iou_threshold=0.5,
|
|
# pre_top_k=-1,
|
|
# keep_top_k=100,
|
|
# ))))
|
|
|
|
# model_cfg = mmcv.Config(dict(model=mmcv.load(model_cfg_path)))
|
|
# model_cfg.model = _replace_r50_with_r18(model_cfg.model)
|
|
|
|
# from mmrotate.models import build_detector
|
|
|
|
# model_cfg.model.pretrained = None
|
|
# model_cfg.model.train_cfg = None
|
|
# model = build_detector(
|
|
# model_cfg.model, test_cfg= model_cfg.get('test_cfg'))
|
|
# model.cfg = model_cfg
|
|
# model.to('cpu')
|
|
|
|
# img = torch.randn(1, 3, 64, 64)
|
|
# rewrite_inputs = {'img': img}
|
|
# rewrite_outputs, _ = get_rewrite_outputs(
|
|
# wrapped_model=model,
|
|
# model_inputs=rewrite_inputs,
|
|
# deploy_cfg=deploy_cfg)
|
|
|
|
# assert rewrite_outputs is not None
|
|
|
|
|
|
def get_deploy_cfg(backend_type: Backend, ir_type: str):
|
|
return mmcv.Config(
|
|
dict(
|
|
backend_config=dict(type=backend_type.value),
|
|
onnx_config=dict(
|
|
type=ir_type,
|
|
output_names=['dets', 'labels'],
|
|
input_shape=None),
|
|
codebase_config=dict(
|
|
type='mmrotate',
|
|
task='RotatedDetection',
|
|
post_processing=dict(
|
|
score_threshold=0.05,
|
|
iou_threshold=0.1,
|
|
pre_top_k=2000,
|
|
keep_top_k=2000,
|
|
))))
|
|
|
|
|
|
@pytest.mark.parametrize('backend_type, ir_type',
|
|
[(Backend.ONNXRUNTIME, 'onnx')])
|
|
def test_base_dense_head_get_bboxes(backend_type: Backend, ir_type: str):
|
|
"""Test get_bboxes rewrite of base dense head."""
|
|
check_backend(backend_type)
|
|
anchor_head = get_anchor_head_model()
|
|
anchor_head.cpu().eval()
|
|
s = 128
|
|
img_metas = [{
|
|
'scale_factor': np.ones(4),
|
|
'pad_shape': (s, s, 3),
|
|
'img_shape': (s, s, 3)
|
|
}]
|
|
|
|
deploy_cfg = get_deploy_cfg(backend_type, ir_type)
|
|
output_names = get_ir_config(deploy_cfg).get('output_names', None)
|
|
|
|
# the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16),
|
|
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2).
|
|
# the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16),
|
|
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2)
|
|
seed_everything(1234)
|
|
cls_score = [
|
|
torch.rand(1, 36, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
|
|
]
|
|
seed_everything(5678)
|
|
bboxes = [torch.rand(1, 45, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]
|
|
|
|
# to get outputs of pytorch model
|
|
model_inputs = {
|
|
'cls_scores': cls_score,
|
|
'bbox_preds': bboxes,
|
|
'img_metas': img_metas
|
|
}
|
|
model_outputs = get_model_outputs(anchor_head, 'get_bboxes', model_inputs)
|
|
|
|
# to get outputs of onnx model after rewrite
|
|
img_metas[0]['img_shape'] = torch.Tensor([s, s])
|
|
wrapped_model = WrapModel(
|
|
anchor_head, 'get_bboxes', img_metas=img_metas, with_nms=True)
|
|
rewrite_inputs = {
|
|
'cls_scores': cls_score,
|
|
'bbox_preds': bboxes,
|
|
}
|
|
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
|
wrapped_model=wrapped_model,
|
|
model_inputs=rewrite_inputs,
|
|
deploy_cfg=deploy_cfg)
|
|
|
|
if is_backend_output:
|
|
if isinstance(rewrite_outputs, dict):
|
|
rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
|
|
for model_output, rewrite_output in zip(model_outputs[0],
|
|
rewrite_outputs):
|
|
model_output = model_output.squeeze().cpu().numpy()
|
|
rewrite_output = rewrite_output.squeeze()
|
|
# hard code to make two tensors with the same shape
|
|
# rewrite and original codes applied different nms strategy
|
|
assert np.allclose(
|
|
model_output[:rewrite_output.shape[0]][:2],
|
|
rewrite_output[:2],
|
|
rtol=1e-03,
|
|
atol=1e-05)
|
|
else:
|
|
assert rewrite_outputs is not None
|