mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
686 lines
23 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import random
from typing import Dict, List

import mmcv
import numpy as np
import pytest
import torch

from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Backend, Codebase
from mmdeploy.utils.config_utils import get_ir_config
from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs,
                                 get_rewrite_outputs)

try:
    import_codebase(Codebase.MMROTATE)
except ImportError:
    pytest.skip(
        f'{Codebase.MMROTATE} is not installed.', allow_module_level=True)


def seed_everything(seed=1029):
    """Seed Python, NumPy and PyTorch RNGs for reproducible test inputs."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.enabled = False


def convert_to_list(rewrite_output: Dict, output_names: List[str]) -> List:
    """Converts output from a dictionary to a list.

    The returned list contains only the values whose names are in
    ``output_names``, in the iteration order of ``rewrite_output``.
    """
    outputs = [
        value for name, value in rewrite_output.items() if name in output_names
    ]
    return outputs


def get_anchor_head_model():
    """RotatedAnchorHead Config."""
    test_cfg = mmcv.Config(
        dict(
            nms_pre=2000,
            min_bbox_size=0,
            score_thr=0.05,
            nms=dict(iou_thr=0.1),
            max_per_img=2000))

    from mmrotate.models.dense_heads import RotatedAnchorHead
    model = RotatedAnchorHead(num_classes=4, in_channels=1, test_cfg=test_cfg)
    model.requires_grad_(False)

    return model


def _replace_r50_with_r18(model):
    """Replace ResNet50 with ResNet18 in config."""
    model = copy.deepcopy(model)
    if model.backbone.type == 'ResNet':
        model.backbone.depth = 18
        model.backbone.base_channels = 2
        model.neck.in_channels = [2, 4, 8, 16]
    return model


@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
@pytest.mark.parametrize(
    'model_cfg_path',
    ['tests/test_codebase/test_mmrotate/data/single_stage_model.json'])
def test_forward_of_base_detector(model_cfg_path, backend):
    check_backend(backend)
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend.value),
            onnx_config=dict(
                output_names=['dets', 'labels'], input_shape=None),
            codebase_config=dict(
                type='mmrotate',
                task='RotatedDetection',
                post_processing=dict(
                    score_threshold=0.05,
                    iou_threshold=0.5,
                    pre_top_k=-1,
                    keep_top_k=100,
                ))))

    model_cfg = mmcv.Config(dict(model=mmcv.load(model_cfg_path)))
    model_cfg.model = _replace_r50_with_r18(model_cfg.model)

    from mmrotate.models import build_detector

    model_cfg.model.pretrained = None
    model_cfg.model.train_cfg = None
    model = build_detector(model_cfg.model, test_cfg=model_cfg.get('test_cfg'))
    model.cfg = model_cfg
    model.to('cpu')

    img = torch.randn(1, 3, 64, 64)
    rewrite_inputs = {'img': img}
    rewrite_outputs, _ = get_rewrite_outputs(
        wrapped_model=model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)

    assert rewrite_outputs is not None


def get_deploy_cfg(backend_type: Backend, ir_type: str):
    """Build a RotatedDetection deploy config with NMS post-processing."""
    return mmcv.Config(
        dict(
            backend_config=dict(type=backend_type.value),
            onnx_config=dict(
                type=ir_type,
                output_names=['dets', 'labels'],
                input_shape=None),
            codebase_config=dict(
                type='mmrotate',
                task='RotatedDetection',
                post_processing=dict(
                    score_threshold=0.05,
                    iou_threshold=0.1,
                    pre_top_k=2000,
                    keep_top_k=2000,
                ))))


@pytest.mark.parametrize('backend_type, ir_type',
                         [(Backend.ONNXRUNTIME, 'onnx')])
def test_base_dense_head_get_bboxes(backend_type: Backend, ir_type: str):
    """Test get_bboxes rewrite of base dense head."""
    check_backend(backend_type)
    anchor_head = get_anchor_head_model()
    anchor_head.cpu().eval()
    s = 128
    img_metas = [{
        'scale_factor': np.ones(4),
        'pad_shape': (s, s, 3),
        'img_shape': (s, s, 3)
    }]

    deploy_cfg = get_deploy_cfg(backend_type, ir_type)
    output_names = get_ir_config(deploy_cfg).get('output_names', None)

    # cls_scores sizes: (1, 36, 32, 32), (1, 36, 16, 16),
    # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2).
    # bbox_preds sizes: (1, 45, 32, 32), (1, 45, 16, 16),
    # (1, 45, 8, 8), (1, 45, 4, 4), (1, 45, 2, 2).
    seed_everything(1234)
    cls_score = [
        torch.rand(1, 36, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
    ]
    seed_everything(5678)
    bboxes = [torch.rand(1, 45, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]

    # outputs of the original PyTorch model
    model_inputs = {
        'cls_scores': cls_score,
        'bbox_preds': bboxes,
        'img_metas': img_metas
    }
    model_outputs = get_model_outputs(anchor_head, 'get_bboxes', model_inputs)

    # outputs of the ONNX model exported after rewriting
    img_metas[0]['img_shape'] = torch.Tensor([s, s])
    wrapped_model = WrapModel(
        anchor_head, 'get_bboxes', img_metas=img_metas, with_nms=True)
    rewrite_inputs = {
        'cls_scores': cls_score,
        'bbox_preds': bboxes,
    }
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)

    if is_backend_output:
        if isinstance(rewrite_outputs, dict):
            rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
        for model_output, rewrite_output in zip(model_outputs[0],
                                                rewrite_outputs):
            model_output = model_output.squeeze().cpu().numpy()
            rewrite_output = rewrite_output.squeeze()
            # The rewritten and original code apply different NMS
            # strategies, so truncate the PyTorch output to the rewritten
            # output's length and compare only the first two entries.
            assert np.allclose(
                model_output[:rewrite_output.shape[0]][:2],
                rewrite_output[:2],
                rtol=1e-03,
                atol=1e-05)
    else:
        assert rewrite_outputs is not None


def get_single_roi_extractor():
    """RotatedSingleRoIExtractor Config."""
    from mmrotate.models.roi_heads import RotatedSingleRoIExtractor
    roi_layer = dict(
        type='RoIAlignRotated', out_size=7, sample_num=2, clockwise=True)
    out_channels = 1
    featmap_strides = [4, 8, 16, 32]
    model = RotatedSingleRoIExtractor(roi_layer, out_channels,
                                      featmap_strides).eval()

    return model


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_rotated_single_roi_extractor(backend_type: Backend):
    check_backend(backend_type)

    single_roi_extractor = get_single_roi_extractor()
    output_names = ['roi_feat']
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend_type.value),
            onnx_config=dict(output_names=output_names, input_shape=None),
            codebase_config=dict(
                type='mmrotate',
                task='RotatedDetection',
            )))

    seed_everything(1234)
    out_channels = single_roi_extractor.out_channels
    feats = [
        torch.rand((1, out_channels, 200, 336)),
        torch.rand((1, out_channels, 100, 168)),
        torch.rand((1, out_channels, 50, 84)),
        torch.rand((1, out_channels, 25, 42)),
    ]
    seed_everything(5678)
    rois = torch.tensor(
        [[0.0000, 587.8285, 52.1405, 886.2484, 341.5644, 0.0000]])

    model_inputs = {
        'feats': feats,
        'rois': rois,
    }
    model_outputs = get_model_outputs(single_roi_extractor, 'forward',
                                      model_inputs)

    backend_outputs, _ = get_rewrite_outputs(
        wrapped_model=single_roi_extractor,
        model_inputs=model_inputs,
        deploy_cfg=deploy_cfg)
    if isinstance(backend_outputs, dict):
        backend_outputs = backend_outputs.values()
    for model_output, backend_output in zip(model_outputs[0], backend_outputs):
        model_output = model_output.squeeze().cpu().numpy()
        backend_output = backend_output.squeeze()
        assert np.allclose(
            model_output, backend_output, rtol=1e-03, atol=1e-05)


def get_oriented_rpn_head_model():
    """Oriented RPN Head Config."""
    test_cfg = mmcv.Config(
        dict(
            nms_pre=2000,
            min_bbox_size=0,
            score_thr=0.05,
            nms=dict(iou_thr=0.1),
            max_per_img=2000))
    from mmrotate.models.dense_heads import OrientedRPNHead
    model = OrientedRPNHead(
        in_channels=1,
        version='le90',
        bbox_coder=dict(type='MidpointOffsetCoder', angle_range='le90'),
        test_cfg=test_cfg)

    model.requires_grad_(False)
    return model


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_get_bboxes_of_oriented_rpn_head(backend_type: Backend):
    check_backend(backend_type)
    head = get_oriented_rpn_head_model()
    head.cpu().eval()
    s = 128
    img_metas = [{
        'scale_factor': np.ones(4),
        'pad_shape': (s, s, 3),
        'img_shape': (s, s, 3)
    }]

    output_names = ['dets', 'labels']
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend_type.value),
            onnx_config=dict(output_names=output_names, input_shape=None),
            codebase_config=dict(
                type='mmrotate',
                task='RotatedDetection',
                post_processing=dict(
                    score_threshold=0.05,
                    iou_threshold=0.1,
                    pre_top_k=2000,
                    keep_top_k=2000))))

    # cls_scores sizes: (1, 9, 32, 32), (1, 9, 16, 16),
    # (1, 9, 8, 8), (1, 9, 4, 4), (1, 9, 2, 2).
    # bbox_preds sizes: (1, 54, 32, 32), (1, 54, 16, 16),
    # (1, 54, 8, 8), (1, 54, 4, 4), (1, 54, 2, 2).
    seed_everything(1234)
    cls_score = [
        torch.rand(1, 9, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
    ]
    seed_everything(5678)
    bboxes = [torch.rand(1, 54, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]

    # outputs of the ONNX model exported after rewriting
    img_metas[0]['img_shape'] = torch.Tensor([s, s])
    wrapped_model = WrapModel(
        head, 'get_bboxes', img_metas=img_metas, with_nms=True)
    rewrite_inputs = {
        'cls_scores': cls_score,
        'bbox_preds': bboxes,
    }
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)
    assert rewrite_outputs is not None


def get_rotated_rpn_head_model():
    """Rotated RPN Head Config."""
    test_cfg = mmcv.Config(
        dict(
            nms_pre=2000,
            min_bbox_size=0,
            score_thr=0.05,
            nms=dict(iou_thr=0.1),
            max_per_img=2000))
    from mmrotate.models.dense_heads import RotatedRPNHead
    model = RotatedRPNHead(
        version='le90',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[0.0, 0.0, 0.0, 0.0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        test_cfg=test_cfg)

    model.requires_grad_(False)
    return model


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_get_bboxes_of_rotated_rpn_head(backend_type: Backend):
    check_backend(backend_type)
    head = get_rotated_rpn_head_model()
    head.cpu().eval()
    s = 128
    img_metas = [{
        'scale_factor': np.ones(4),
        'pad_shape': (s, s, 3),
        'img_shape': (s, s, 3)
    }]

    output_names = ['dets', 'labels']
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend_type.value),
            onnx_config=dict(output_names=output_names, input_shape=None),
            codebase_config=dict(
                type='mmrotate',
                task='RotatedDetection',
                post_processing=dict(
                    score_threshold=0.05,
                    iou_threshold=0.1,
                    pre_top_k=2000,
                    keep_top_k=2000))))

    # cls_scores sizes: (1, 3, 32, 32), (1, 3, 16, 16),
    # (1, 3, 8, 8), (1, 3, 4, 4), (1, 3, 2, 2).
    # bbox_preds sizes: (1, 18, 32, 32), (1, 18, 16, 16),
    # (1, 18, 8, 8), (1, 18, 4, 4), (1, 18, 2, 2).
    seed_everything(1234)
    cls_score = [
        torch.rand(1, 3, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
    ]
    seed_everything(5678)
    bboxes = [torch.rand(1, 18, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]

    # outputs of the ONNX model exported after rewriting
    img_metas[0]['img_shape'] = torch.Tensor([s, s])
    wrapped_model = WrapModel(
        head, 'get_bboxes', img_metas=img_metas, with_nms=True)
    rewrite_inputs = {
        'cls_scores': cls_score,
        'bbox_preds': bboxes,
    }
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)
    assert rewrite_outputs is not None


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_rotate_standard_roi_head__simple_test(backend_type: Backend):
    check_backend(backend_type)
    from mmrotate.models.roi_heads import OrientedStandardRoIHead
    output_names = ['dets', 'labels']
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend_type.value),
            onnx_config=dict(output_names=output_names, input_shape=None),
            codebase_config=dict(
                type='mmrotate',
                task='RotatedDetection',
                post_processing=dict(
                    score_threshold=0.05,
                    iou_threshold=0.1,
                    pre_top_k=2000,
                    keep_top_k=2000))))
    angle_version = 'le90'
    test_cfg = mmcv.Config(
        dict(
            nms_pre=2000,
            min_bbox_size=0,
            score_thr=0.05,
            nms=dict(iou_thr=0.1),
            max_per_img=2000))
    head = OrientedStandardRoIHead(
        bbox_roi_extractor=dict(
            type='RotatedSingleRoIExtractor',
            roi_layer=dict(
                type='RoIAlignRotated',
                out_size=7,
                sample_num=2,
                clockwise=True),
            out_channels=3,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='RotatedShared2FCBBoxHead',
            in_channels=3,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=15,
            bbox_coder=dict(
                type='DeltaXYWHAOBBoxCoder',
                angle_range=angle_version,
                norm_factor=None,
                edge_swap=True,
                proj_xy=True,
                target_means=(.0, .0, .0, .0, .0),
                target_stds=(0.1, 0.1, 0.2, 0.2, 0.1)),
            reg_class_agnostic=True),
        test_cfg=test_cfg)
    head.cpu().eval()

    seed_everything(1234)
    x = [torch.rand(1, 3, pow(2, i), pow(2, i)) for i in range(4, 0, -1)]
    proposals = [torch.rand(1, 100, 6), torch.randint(0, 10, (1, 100))]
    img_metas = [{'img_shape': torch.tensor([224, 224])}]

    wrapped_model = WrapModel(
        head, 'simple_test', proposals=proposals, img_metas=img_metas)
    rewrite_inputs = {'x': x}
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)
    assert rewrite_outputs is not None


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_gv_ratio_roi_head__simple_test(backend_type: Backend):
    check_backend(backend_type)
    from mmrotate.models.roi_heads import GVRatioRoIHead
    output_names = ['dets', 'labels']
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend_type.value),
            onnx_config=dict(output_names=output_names, input_shape=None),
            codebase_config=dict(
                type='mmrotate',
                task='RotatedDetection',
                post_processing=dict(
                    score_threshold=0.05,
                    iou_threshold=0.1,
                    pre_top_k=2000,
                    keep_top_k=2000,
                    max_output_boxes_per_class=1000))))
    angle_version = 'le90'
    test_cfg = mmcv.Config(
        dict(
            nms_pre=2000,
            min_bbox_size=0,
            score_thr=0.05,
            nms=dict(iou_thr=0.1),
            max_per_img=2000))
    head = GVRatioRoIHead(
        version=angle_version,
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=3,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='GVBBoxHead',
            version=angle_version,
            num_shared_fcs=2,
            in_channels=3,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=15,
            ratio_thr=0.8,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=(.0, .0, .0, .0),
                target_stds=(0.1, 0.1, 0.2, 0.2)),
            fix_coder=dict(type='GVFixCoder', angle_range=angle_version),
            ratio_coder=dict(type='GVRatioCoder', angle_range=angle_version),
            reg_class_agnostic=True),
        test_cfg=test_cfg)
    head.cpu().eval()

    seed_everything(1234)
    x = [torch.rand(1, 3, pow(2, i), pow(2, i)) for i in range(4, 0, -1)]
    # Build horizontal proposals (x1, y1, x2, y2) with x2 >= x1 and y2 >= y1.
    bboxes = torch.rand(1, 100, 2)
    bboxes = torch.cat(
        [bboxes, bboxes + torch.rand(1, 100, 2) + torch.rand(1, 100, 1)],
        dim=-1)
    proposals = [bboxes, torch.randint(0, 10, (1, 100))]
    img_metas = [{'img_shape': torch.tensor([224, 224])}]

    wrapped_model = WrapModel(
        head, 'simple_test', proposals=proposals, img_metas=img_metas)
    rewrite_inputs = {'x': x}
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)
    assert rewrite_outputs is not None


def get_roi_trans_roi_head_model():
    """RoITransRoIHead Config."""
    angle_version = 'le90'

    num_stages = 2
    stage_loss_weights = [1, 1]
    version = angle_version
    bbox_roi_extractor = [
        dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=64,
            featmap_strides=[4, 8, 16, 32]),
        dict(
            type='RotatedSingleRoIExtractor',
            roi_layer=dict(
                type='RoIAlignRotated',
                out_size=7,
                sample_num=2,
                clockwise=True),
            out_channels=64,
            featmap_strides=[4, 8, 16, 32]),
    ]

    bbox_head = [
        dict(
            type='RotatedShared2FCBBoxHead',
            in_channels=64,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=15,
            bbox_coder=dict(
                type='DeltaXYWHAHBBoxCoder',
                angle_range=angle_version,
                norm_factor=2,
                edge_swap=True,
                target_means=[0., 0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2, 1]),
            reg_class_agnostic=True,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
        dict(
            type='RotatedShared2FCBBoxHead',
            in_channels=64,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=15,
            bbox_coder=dict(
                type='DeltaXYWHAOBBoxCoder',
                angle_range=angle_version,
                norm_factor=None,
                edge_swap=True,
                proj_xy=True,
                target_means=[0., 0., 0., 0., 0.],
                target_stds=[0.05, 0.05, 0.1, 0.1, 0.5]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
    ]
    test_cfg = mmcv.Config(
        dict(
            nms_pre=2000,
            min_bbox_size=0,
            score_thr=0.05,
            nms=dict(iou_thr=0.1),
            max_per_img=2000))

    args = [num_stages, stage_loss_weights, bbox_roi_extractor, bbox_head]
    kwargs = {'version': version, 'test_cfg': test_cfg}

    from mmrotate.models.roi_heads import RoITransRoIHead
    model = RoITransRoIHead(*args, **kwargs).eval()
    return model


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_simple_test_of_roi_trans_roi_head(backend_type: Backend):
    check_backend(backend_type)

    roi_head = get_roi_trans_roi_head_model()
    roi_head.cpu()

    seed_everything(1234)
    x = [
        torch.rand((1, 64, 32, 32)),
        torch.rand((1, 64, 16, 16)),
        torch.rand((1, 64, 8, 8)),
        torch.rand((1, 64, 4, 4)),
    ]
    proposals = torch.tensor([[[58.8285, 52.1405, 188.2484, 141.5644, 0.5]]])
    labels = torch.tensor([[[0.]]])
    s = 256
    img_metas = [{
        'img_shape': torch.tensor([s, s]),
        'ori_shape': torch.tensor([s, s]),
        'scale_factor': torch.tensor([1, 1, 1, 1])
    }]

    model_inputs = {
        'x': x,
    }

    output_names = ['det_bboxes', 'det_labels']
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend_type.value),
            onnx_config=dict(output_names=output_names, input_shape=None),
            codebase_config=dict(
                type='mmrotate',
                task='RotatedDetection',
                post_processing=dict(
                    score_threshold=0.05,
                    iou_threshold=0.1,
                    pre_top_k=2000,
                    keep_top_k=2000))))

    wrapped_model = WrapModel(
        roi_head,
        'simple_test',
        proposal_list=[proposals, labels],
        img_metas=img_metas)
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=model_inputs,
        deploy_cfg=deploy_cfg)

    assert rewrite_outputs is not None