mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Enhance]: Add more docstring. (#111)
* add docstring for apis * add simple docstring for mmdet * add simple docstring for mmseg * add simple docstring for mmcls * add simple docstring for mmedit * add simple docstring for mmocr * add simple docstring for rewriting * update thresh for docstring coverage * update * update docstring * solve comments * remove unrelated symbol
This commit is contained in:
parent
4587322441
commit
de9498a8f2
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@ -26,7 +26,7 @@ jobs:
|
||||
- name: Check docstring coverage
|
||||
run: |
|
||||
pip install interrogate
|
||||
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-regex "__repr__" --fail-under 30 mmdeploy
|
||||
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-regex "__repr__" --fail-under 80 mmdeploy
|
||||
|
||||
build_cpu:
|
||||
runs-on: ubuntu-18.04
|
||||
|
@ -60,6 +60,7 @@ class HDF5Calibrator(trt.IInt8Calibrator):
|
||||
self.calib_file.close()
|
||||
|
||||
def get_batch(self, names: Sequence[str], **kwargs):
|
||||
"""Get batch data."""
|
||||
if self.count < self.dataset_length:
|
||||
|
||||
ret = []
|
||||
@ -95,13 +96,33 @@ class HDF5Calibrator(trt.IInt8Calibrator):
|
||||
return None
|
||||
|
||||
def get_algorithm(self):
|
||||
"""Get Calibration algo type.
|
||||
|
||||
Returns:
|
||||
trt.CalibrationAlgoType: Calibration algo type.
|
||||
"""
|
||||
return self.algorithm
|
||||
|
||||
def get_batch_size(self):
|
||||
"""Get batch size.
|
||||
|
||||
Returns:
|
||||
int: An integer represents batch size.
|
||||
"""
|
||||
return self.batch_size
|
||||
|
||||
def read_calibration_cache(self, *args, **kwargs):
|
||||
"""Read calibration cache.
|
||||
|
||||
Notes:
|
||||
No need to implement this function.
|
||||
"""
|
||||
pass
|
||||
|
||||
def write_calibration_cache(self, cache, *args, **kwargs):
|
||||
"""Write calibration cache.
|
||||
|
||||
Notes:
|
||||
No need to implement this function.
|
||||
"""
|
||||
pass
|
||||
|
@ -215,7 +215,7 @@ class TRTWrapper(torch.nn.Module):
|
||||
self._load_io_names()
|
||||
|
||||
def _load_io_names(self):
|
||||
# get input and output names from engine
|
||||
"""Load input/output names from engine."""
|
||||
names = [_ for _ in self.engine]
|
||||
input_names = list(filter(self.engine.binding_is_input, names))
|
||||
output_names = list(set(names) - set(input_names))
|
||||
|
@ -5,6 +5,7 @@ from packaging import version
|
||||
|
||||
|
||||
def parse_extractor_io_string(io_str):
|
||||
"""Parse IO string for extractor."""
|
||||
name, io_type = io_str.split(':')
|
||||
assert io_type in ['input', 'output']
|
||||
func_id = 0
|
||||
@ -17,8 +18,9 @@ def parse_extractor_io_string(io_str):
|
||||
return name, func_id, io_type
|
||||
|
||||
|
||||
def _dfs_search_reacable_nodes_fast(self, node_output_name, graph_input_nodes,
|
||||
def _dfs_search_reachable_nodes_fast(self, node_output_name, graph_input_nodes,
|
||||
reachable_nodes):
|
||||
"""Using DFS to search reachable nodes."""
|
||||
outputs = {}
|
||||
for index, node in enumerate(self.graph.node):
|
||||
for name in node.output:
|
||||
@ -54,7 +56,7 @@ def create_extractor(model: onnx.ModelProto):
|
||||
assert version.parse(onnx.__version__) >= version.parse('1.8.0')
|
||||
# patch extractor
|
||||
onnx.utils.Extractor._dfs_search_reachable_nodes = \
|
||||
_dfs_search_reacable_nodes_fast
|
||||
_dfs_search_reachable_nodes_fast
|
||||
|
||||
extractor = onnx.utils.Extractor(model)
|
||||
return extractor
|
||||
|
@ -10,6 +10,7 @@ MARK_FUNCTION_COUNT = dict()
|
||||
|
||||
|
||||
def reset_mark_function_count():
|
||||
"""Reset counter of mark function."""
|
||||
for k in MARK_FUNCTION_COUNT:
|
||||
MARK_FUNCTION_COUNT[k] = 0
|
||||
|
||||
@ -59,6 +60,7 @@ class Mark(torch.autograd.Function):
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdeploy.core.optimizers.function_marker.Mark.symbolic')
|
||||
def mark_symbolic(rewriter, g, x, *args):
|
||||
"""Rewrite symbolic of mark op."""
|
||||
if cfg_apply_marks(rewriter.cfg):
|
||||
return rewriter.origin_func(g, x, *args)
|
||||
return x
|
||||
@ -68,6 +70,7 @@ def mark_symbolic(rewriter, g, x, *args):
|
||||
'mmdeploy.core.optimizers.function_marker.Mark.forward')
|
||||
def forward_of_mark(rewriter, ctx, x, dtype, shape, func, func_id, type, name,
|
||||
id, attrs):
|
||||
"""Rewrite forward of mark op."""
|
||||
deploy_cfg = rewriter.cfg
|
||||
# save calib data
|
||||
apply_marks = cfg_apply_marks(deploy_cfg)
|
||||
|
@ -6,6 +6,14 @@ from onnx.helper import get_attribute_value
|
||||
|
||||
|
||||
def attribute_to_dict(attr: onnx.AttributeProto):
|
||||
"""Convert onnx op attribute to dict.
|
||||
|
||||
Args:
|
||||
attr (onnx.AttributeProto): Input onnx op attribute.
|
||||
|
||||
Returns:
|
||||
dict: A dict contains info from op attribute.
|
||||
"""
|
||||
ret = {}
|
||||
for a in attr:
|
||||
value = get_attribute_value(a)
|
||||
@ -16,6 +24,15 @@ def attribute_to_dict(attr: onnx.AttributeProto):
|
||||
|
||||
|
||||
def remove_nodes(model: onnx.ModelProto, predicate: Callable):
|
||||
"""Remove nodes from ONNX model.
|
||||
|
||||
Args:
|
||||
model (onnx.ModelProto): Input onnx model.
|
||||
predicate (Callable): A function to predicate a node.
|
||||
|
||||
Returns:
|
||||
onnx.ModelProto: Modified onnx model.
|
||||
"""
|
||||
# ! this doesn't handle inputs/outputs
|
||||
while True:
|
||||
connect = None
|
||||
@ -38,6 +55,14 @@ def remove_nodes(model: onnx.ModelProto, predicate: Callable):
|
||||
|
||||
|
||||
def is_unused_mark(marks: Iterable[onnx.NodeProto]):
|
||||
"""Check whether a mark is unused.
|
||||
|
||||
Args:
|
||||
marks (Iterable[onnx.NodeProto]): A list of onnx NodeProto.
|
||||
|
||||
Returns:
|
||||
bool: `True` if a mark node is in `marks`.
|
||||
"""
|
||||
|
||||
def f(node):
|
||||
if node.op_type == 'Mark':
|
||||
@ -51,6 +76,7 @@ def is_unused_mark(marks: Iterable[onnx.NodeProto]):
|
||||
|
||||
|
||||
def is_identity(node: onnx.NodeProto):
|
||||
"""Check if an op is identity."""
|
||||
return node.op_type == 'Identity'
|
||||
|
||||
|
||||
@ -73,6 +99,13 @@ def get_new_name(attrs: onnx.ModelProto,
|
||||
|
||||
|
||||
def rename_value(model: onnx.ModelProto, old_name: str, new_name: str):
|
||||
"""Rename a node in an ONNX model.
|
||||
|
||||
Args:
|
||||
model (onnx.ModelProto): Input onnx model.
|
||||
old_name (str): Original node name in the model.
|
||||
new_name (str): New node name in the model.
|
||||
"""
|
||||
if old_name == new_name:
|
||||
return
|
||||
logging.info(f'rename {old_name} -> {new_name}')
|
||||
@ -95,6 +128,11 @@ def rename_value(model: onnx.ModelProto, old_name: str, new_name: str):
|
||||
|
||||
|
||||
def remove_identity(model: onnx.ModelProto):
|
||||
"""Remove identity node from an ONNX model.
|
||||
|
||||
Args:
|
||||
model (onnx.ModelProto): Input onnx model.
|
||||
"""
|
||||
graph = model.graph
|
||||
|
||||
def simplify_inputs():
|
||||
|
@ -6,4 +6,5 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmcls.models.classifiers.BaseClassifier.forward', backend='default')
|
||||
def forward_of_base_classifier(ctx, self, img, *args, **kwargs):
|
||||
"""Rewrite `forward` for default backend."""
|
||||
return self.simple_test(img, {})
|
||||
|
@ -4,4 +4,5 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmcls.models.heads.ClsHead.post_process')
|
||||
def post_process_of_cls_head(ctx, self, pred, **kwargs):
|
||||
"""Rewrite `post_process` for default backend."""
|
||||
return pred
|
||||
|
@ -4,4 +4,5 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmcls.models.heads.MultiLabelClsHead.post_process')
|
||||
def post_process_of_multi_label_head(ctx, self, pred, **kwargs):
|
||||
"""Rewrite `post_process` for default backend."""
|
||||
return pred
|
||||
|
@ -5,15 +5,13 @@ from torch.onnx import symbolic_helper as sym_help
|
||||
from mmdeploy.core import SYMBOLIC_REGISTER
|
||||
|
||||
|
||||
class DummyONNXNMSop(torch.autograd.Function):
|
||||
"""DummyONNXNMSop.
|
||||
|
||||
This class is only for creating onnx::NonMaxSuppression.
|
||||
"""
|
||||
class ONNXNMSop(torch.autograd.Function):
|
||||
"""Create onnx::NonMaxSuppression op."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, boxes, scores, max_output_boxes_per_class, iou_threshold,
|
||||
score_threshold):
|
||||
"""Forward of onnx nms."""
|
||||
batch_size, num_class, _ = scores.shape
|
||||
|
||||
score_threshold = float(score_threshold)
|
||||
@ -51,9 +49,11 @@ class DummyONNXNMSop(torch.autograd.Function):
|
||||
|
||||
|
||||
@SYMBOLIC_REGISTER.register_symbolic(
|
||||
'mmdeploy.mmcv.ops.DummyONNXNMSop', backend='default')
|
||||
'mmdeploy.mmcv.ops.ONNXNMSop', backend='default')
|
||||
def nms_dynamic(ctx, g, boxes, scores, max_output_boxes_per_class,
|
||||
iou_threshold, score_threshold):
|
||||
"""Rewrite symbolic function for default backend."""
|
||||
|
||||
if not sym_help._is_value(max_output_boxes_per_class):
|
||||
max_output_boxes_per_class = g.op(
|
||||
'Constant',
|
||||
@ -73,9 +73,11 @@ def nms_dynamic(ctx, g, boxes, scores, max_output_boxes_per_class,
|
||||
|
||||
|
||||
@SYMBOLIC_REGISTER.register_symbolic(
|
||||
'mmdeploy.mmcv.ops.DummyONNXNMSop', backend='tensorrt')
|
||||
'mmdeploy.mmcv.ops.ONNXNMSop', backend='tensorrt')
|
||||
def nms_static(ctx, g, boxes, scores, max_output_boxes_per_class,
|
||||
iou_threshold, score_threshold):
|
||||
"""Rewrite symbolic function for TensorRT backend."""
|
||||
|
||||
if sym_help._is_value(max_output_boxes_per_class):
|
||||
max_output_boxes_per_class = sym_help._maybe_get_const(
|
||||
max_output_boxes_per_class, 'i')
|
||||
@ -98,6 +100,7 @@ def nms_static(ctx, g, boxes, scores, max_output_boxes_per_class,
|
||||
|
||||
|
||||
class TRTBatchedNMSop(torch.autograd.Function):
|
||||
"""Create mmcv::TRTBatchedNMS op for TensorRT backend."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx,
|
||||
@ -109,6 +112,7 @@ class TRTBatchedNMSop(torch.autograd.Function):
|
||||
iou_threshold,
|
||||
score_threshold,
|
||||
background_label_id=-1):
|
||||
"""Forward of batched nms."""
|
||||
batch_size, num_boxes, num_classes = scores.shape
|
||||
|
||||
out_boxes = min(num_boxes, after_topk)
|
||||
|
@ -8,6 +8,8 @@ from mmdeploy.core import SYMBOLIC_REGISTER
|
||||
'mmcv.ops.roi_align.__self__', backend='default')
|
||||
def roi_align_default(ctx, g, input, rois, output_size, spatial_scale,
|
||||
sampling_ratio, pool_mode, aligned):
|
||||
"""Rewrite symbolic function for default backend."""
|
||||
|
||||
return g.op(
|
||||
'mmcv::MMCVRoiAlign',
|
||||
input,
|
||||
|
@ -17,6 +17,7 @@ def delta2bbox(ctx,
|
||||
clip_border=True,
|
||||
add_ctr_clamp=False,
|
||||
ctr_clamp=32):
|
||||
"""Rewrite for ONNX exporting of default backend."""
|
||||
means = deltas.new_tensor(means).view(1,
|
||||
-1).repeat(1,
|
||||
deltas.size(-1) // 4)
|
||||
@ -81,6 +82,7 @@ def delta2bbox_ncnn(ctx,
|
||||
clip_border=True,
|
||||
add_ctr_clamp=False,
|
||||
ctr_clamp=32):
|
||||
"""Rewrite for ONNX exporting of NCNN backend."""
|
||||
means = deltas.new_tensor(means).view(1, 1,
|
||||
-1).repeat(1, deltas.size(-2),
|
||||
deltas.size(-1) // 4).data
|
||||
|
@ -13,6 +13,7 @@ def tblr2bboxes(ctx,
|
||||
normalize_by_wh=True,
|
||||
max_shape=None,
|
||||
clip_border=True):
|
||||
"""Rewrite for ONNX exporting of default backend."""
|
||||
if not isinstance(normalizer, float):
|
||||
normalizer = torch.tensor(normalizer, device=priors.device)
|
||||
assert len(normalizer) == 4, 'Normalizer must have length = 4'
|
||||
@ -54,6 +55,7 @@ def delta2bbox_ncnn(ctx,
|
||||
normalize_by_wh=True,
|
||||
max_shape=None,
|
||||
clip_border=True):
|
||||
"""Rewrite for ONNX exporting of NCNN backend."""
|
||||
assert priors.size(0) == tblr.size(0)
|
||||
if priors.ndim == 3:
|
||||
assert priors.size(1) == tblr.size(1)
|
||||
|
@ -2,10 +2,29 @@ import torch
|
||||
|
||||
import mmdeploy
|
||||
from mmdeploy.core import FUNCTION_REWRITER, mark
|
||||
from mmdeploy.mmcv.ops import DummyONNXNMSop, TRTBatchedNMSop
|
||||
from mmdeploy.mmcv.ops import ONNXNMSop, TRTBatchedNMSop
|
||||
|
||||
|
||||
def select_nms_index(scores, boxes, nms_index, batch_size, keep_top_k=-1):
|
||||
def select_nms_index(scores: torch.Tensor,
|
||||
boxes: torch.Tensor,
|
||||
nms_index: torch.Tensor,
|
||||
batch_size: int,
|
||||
keep_top_k: int = -1):
|
||||
"""Transform NMS output.
|
||||
|
||||
Args:
|
||||
scores (Tensor): The detection scores of shape
|
||||
[N, num_classes, num_boxes].
|
||||
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
|
||||
nms_index (Tensor): NMS output of bounding boxes indexing.
|
||||
batch_size (int): Batch size of the input image.
|
||||
keep_top_k (int): Number of top K boxes to keep after nms.
|
||||
Defaults to -1.
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
|
||||
and `labels` of shape [N, num_det].
|
||||
"""
|
||||
batch_inds, cls_inds = nms_index[:, 0], nms_index[:, 1]
|
||||
box_inds = nms_index[:, 2]
|
||||
|
||||
@ -65,9 +84,9 @@ def _multiclass_nms(boxes,
|
||||
(N, num_boxes, 4).
|
||||
|
||||
Args:
|
||||
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4]
|
||||
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
|
||||
scores (Tensor): The detection scores of shape
|
||||
[N, num_boxes, num_classes]
|
||||
[N, num_boxes, num_classes].
|
||||
max_output_boxes_per_class (int): Maximum number of output
|
||||
boxes per class of nms. Defaults to 1000.
|
||||
iou_threshold (float): IOU threshold of nms. Defaults to 0.5
|
||||
@ -79,8 +98,8 @@ def _multiclass_nms(boxes,
|
||||
Defaults to -1.
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, Tensor]: dets of shape [N, num_det, 5] and class labels
|
||||
of shape [N, num_det].
|
||||
tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
|
||||
and `labels` of shape [N, num_det].
|
||||
"""
|
||||
max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class])
|
||||
iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
|
||||
@ -96,7 +115,7 @@ def _multiclass_nms(boxes,
|
||||
scores = scores[batch_inds, topk_inds, :]
|
||||
|
||||
scores = scores.permute(0, 2, 1)
|
||||
selected_indices = DummyONNXNMSop.apply(boxes, scores,
|
||||
selected_indices = ONNXNMSop.apply(boxes, scores,
|
||||
max_output_boxes_per_class,
|
||||
iou_threshold, score_threshold)
|
||||
|
||||
@ -117,6 +136,7 @@ def multiclass_nms_static(ctx,
|
||||
score_threshold=0.05,
|
||||
pre_top_k=-1,
|
||||
keep_top_k=-1):
|
||||
"""Wrapper for `multiclass_nms` with TensorRT."""
|
||||
boxes = boxes if boxes.dim() == 4 else boxes.unsqueeze(2)
|
||||
keep_top_k = max_output_boxes_per_class if keep_top_k < 0 else min(
|
||||
max_output_boxes_per_class, keep_top_k)
|
||||
@ -129,5 +149,5 @@ def multiclass_nms_static(ctx,
|
||||
|
||||
@mark('multiclass_nms', inputs=['boxes', 'scores'], outputs=['dets', 'labels'])
|
||||
def multiclass_nms(*args, **kwargs):
|
||||
"""Wrapper function for _multiclass_nms."""
|
||||
"""Wrapper function for `_multiclass_nms`."""
|
||||
return mmdeploy.mmdet.core.post_processing._multiclass_nms(*args, **kwargs)
|
||||
|
@ -17,6 +17,7 @@ def get_bboxes_of_anchor_head(ctx,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` for default backend."""
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = ctx.cfg
|
||||
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
|
||||
@ -124,6 +125,7 @@ def get_bboxes_of_anchor_head_ncnn(ctx,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` for NCNN backend."""
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = ctx.cfg
|
||||
assert not is_dynamic_shape(deploy_cfg)
|
||||
|
@ -18,6 +18,7 @@ def get_bboxes_of_fcos_head(ctx,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` for default backend."""
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = ctx.cfg
|
||||
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
|
||||
@ -118,6 +119,7 @@ def get_bboxes_of_fcos_head_ncnn(ctx,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` for NCNN backend."""
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = ctx.cfg
|
||||
assert not is_dynamic_shape(deploy_cfg)
|
||||
|
@ -16,6 +16,7 @@ def get_bboxes_of_rpn_head(ctx,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` for default backend."""
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = ctx.cfg
|
||||
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
|
||||
@ -120,6 +121,7 @@ def get_bboxes_of_rpn_head_ncnn(ctx,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` for NCNN backend."""
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = ctx.cfg
|
||||
assert not is_dynamic_shape(deploy_cfg)
|
||||
|
@ -7,6 +7,7 @@ from mmdeploy.utils import is_dynamic_shape
|
||||
@mark(
|
||||
'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks'])
|
||||
def _forward_of_base_detector_impl(ctx, self, img, img_metas=None, **kwargs):
|
||||
"""Rewrite and adding mark for `forward`."""
|
||||
assert isinstance(img_metas, dict)
|
||||
assert isinstance(img, torch.Tensor)
|
||||
|
||||
@ -23,6 +24,7 @@ def _forward_of_base_detector_impl(ctx, self, img, img_metas=None, **kwargs):
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.BaseDetector.forward')
|
||||
def forward_of_base_detector(ctx, self, img, img_metas=None, **kwargs):
|
||||
"""Rewrite `forward` for default backend."""
|
||||
if img_metas is None:
|
||||
img_metas = {}
|
||||
|
||||
|
@ -3,5 +3,6 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(func_name='mmdet.models.RPN.simple_test')
|
||||
def simple_test_of_rpn(ctx, self, img, img_metas, **kwargs):
|
||||
"""Rewrite `simple_test` for default backend."""
|
||||
x = self.extract_feat(img)
|
||||
return self.rpn_head.simple_test_rpn(x, img_metas)
|
||||
|
@ -4,5 +4,6 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.SingleStageDetector.simple_test')
|
||||
def simple_test_of_single_stage(ctx, self, img, img_metas, **kwargs):
|
||||
"""Rewrite `simple_test` for default backend."""
|
||||
feat = self.extract_feat(img)
|
||||
return self.bbox_head.simple_test(feat, img_metas, **kwargs)
|
||||
|
@ -5,6 +5,7 @@ from mmdeploy.core import FUNCTION_REWRITER, mark
|
||||
'mmdet.models.TwoStageDetector.extract_feat')
|
||||
@mark('extract_feat', inputs='img', outputs='feat')
|
||||
def extract_feat_of_two_stage(ctx, self, img):
|
||||
"""Rewrite `extract_feat` for default backend."""
|
||||
return ctx.origin_func(self, img)
|
||||
|
||||
|
||||
@ -16,6 +17,7 @@ def simple_test_of_two_stage(ctx,
|
||||
img_metas,
|
||||
proposals=None,
|
||||
**kwargs):
|
||||
"""Rewrite `simple_test` for default backend."""
|
||||
assert self.with_bbox, 'Bbox head must be implemented.'
|
||||
x = self.extract_feat(img)
|
||||
if proposals is None:
|
||||
|
@ -15,6 +15,7 @@ from mmdeploy.utils import get_mmdet_params
|
||||
inputs=['bbox_feats'],
|
||||
outputs=['cls_score', 'bbox_pred'])
|
||||
def forward_of_bbox_head(ctx, self, x):
|
||||
"""Rewrite `forward` for default backend."""
|
||||
return ctx.origin_func(self, x)
|
||||
|
||||
|
||||
@ -22,7 +23,7 @@ def forward_of_bbox_head(ctx, self, x):
|
||||
func_name='mmdet.models.roi_heads.BBoxHead.get_bboxes')
|
||||
def get_bboxes_of_bbox_head(ctx, self, rois, cls_score, bbox_pred, img_shape,
|
||||
cfg, **kwargs):
|
||||
|
||||
"""Rewrite `get_bboxes` for default backend."""
|
||||
assert rois.ndim == 3, 'Only support export two stage ' \
|
||||
'model to ONNX ' \
|
||||
'with batch dimension. '
|
||||
|
@ -59,6 +59,7 @@ def forward_of_single_roi_extractor_static(ctx,
|
||||
feats,
|
||||
rois,
|
||||
roi_scale_factor=None):
|
||||
"""Rewrite `forward` for TensorRT backend."""
|
||||
featmap_strides = self.featmap_strides
|
||||
finest_scale = self.finest_scale
|
||||
|
||||
@ -83,6 +84,7 @@ def forward_of_single_roi_extractor_dynamic(ctx,
|
||||
feats,
|
||||
rois,
|
||||
roi_scale_factor=None):
|
||||
"""Rewrite `forward` for default backend."""
|
||||
out_size = self.roi_layers[0].output_size
|
||||
num_levels = len(feats)
|
||||
roi_feats = feats[0].new_zeros(rois.shape[0], self.out_channels, *out_size)
|
||||
|
@ -5,6 +5,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
func_name='mmdet.models.roi_heads.StandardRoIHead.simple_test')
|
||||
def simple_test_of_standard_roi_head(ctx, self, x, proposals, img_metas,
|
||||
**kwargs):
|
||||
"""Rewrite `simple_test` for default backend."""
|
||||
assert self.with_bbox, 'Bbox head must be implemented.'
|
||||
det_bboxes, det_labels = self.simple_test_bboxes(
|
||||
x, img_metas, proposals, self.test_cfg, rescale=False)
|
||||
|
@ -8,6 +8,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
BBoxTestMixin.simple_test_bboxes')
|
||||
def simple_test_bboxes_of_bbox_test_mixin(ctx, self, x, img_metas, proposals,
|
||||
rcnn_test_cfg, **kwargs):
|
||||
"""Rewrite `simple_test_bboxes` for default backend."""
|
||||
rois = proposals
|
||||
batch_index = torch.arange(
|
||||
rois.shape[0], device=rois.device).float().view(-1, 1, 1).expand(
|
||||
|
@ -41,7 +41,9 @@ class SRCNNWrapper(nn.Module):
|
||||
align_corners=False)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
"""Run forward."""
|
||||
return self._module(*args, **kwargs)
|
||||
|
||||
def init_weights(self, *args, **kwargs):
|
||||
"""Initialize weights."""
|
||||
return self._module.init_weights(*args, **kwargs)
|
||||
|
@ -9,6 +9,7 @@ def simple_test_of_single_stage_text_detector(ctx,
|
||||
img_metas,
|
||||
rescale=False,
|
||||
**kwargs):
|
||||
"""Rewrite `simple_test` for default backend."""
|
||||
x = self.extract_feat(img)
|
||||
outs = self.bbox_head(x)
|
||||
return outs
|
||||
|
@ -6,6 +6,8 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
'.BidirectionalLSTM.forward',
|
||||
backend='ncnn')
|
||||
def forward_of_bidirectionallstm(ctx, self, input):
|
||||
"""Rewrite `forward` for NCNN backend."""
|
||||
|
||||
self.rnn.batch_first = True
|
||||
recurrent, _ = self.rnn(input)
|
||||
self.rnn.batch_first = False
|
||||
|
@ -12,6 +12,7 @@ def forward_of_base_recognizer(ctx,
|
||||
img_metas=None,
|
||||
return_loss=False,
|
||||
**kwargs):
|
||||
"""Rewrite `forward` for NCNN backend."""
|
||||
if img_metas is None:
|
||||
img_metas = {}
|
||||
assert isinstance(img_metas, dict)
|
||||
|
@ -6,6 +6,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
backend='ncnn')
|
||||
def forward_train_of_crnndecoder(ctx, self, feat, out_enc, targets_dict,
|
||||
img_metas):
|
||||
"""Rewrite `forward_train` for NCNN backend."""
|
||||
assert feat.size(2) == 1, 'feature height must be 1'
|
||||
if self.rnn_flag:
|
||||
x = feat.squeeze(2) # [N, C, W]
|
||||
|
@ -5,6 +5,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
func_name='mmocr.models.textrecog.EncodeDecodeRecognizer.simple_test')
|
||||
def simple_test_of_encode_decode_recognizer(ctx, self, img, img_metas,
|
||||
**kwargs):
|
||||
"""Rewrite `forward` for default backend."""
|
||||
feat = self.extract_feat(img)
|
||||
|
||||
out_enc = None
|
||||
|
@ -8,6 +8,7 @@ from mmdeploy.utils import is_dynamic_shape
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmseg.models.decode_heads.ASPPHead.forward')
|
||||
def forward_of_aspp_head(ctx, self, inputs):
|
||||
"""Rewrite `forward` for default backend."""
|
||||
x = self._transform_inputs(inputs)
|
||||
deploy_cfg = ctx.cfg
|
||||
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
|
||||
|
@ -7,6 +7,7 @@ from mmdeploy.utils import is_dynamic_shape
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmseg.models.decode_heads.psp_head.PPM.forward')
|
||||
def forward_of_ppm(ctx, self, x):
|
||||
"""Rewrite `forward` for default backend."""
|
||||
deploy_cfg = ctx.cfg
|
||||
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
|
||||
# get origin input shape as tensor to support onnx dynamic shape
|
||||
|
@ -7,6 +7,7 @@ from mmdeploy.utils import is_dynamic_shape
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmseg.models.segmentors.BaseSegmentor.forward')
|
||||
def forward_of_base_segmentor(ctx, self, img, img_metas=None, **kwargs):
|
||||
"""Rewrite `forward` for default backend."""
|
||||
if img_metas is None:
|
||||
img_metas = {}
|
||||
assert isinstance(img_metas, dict)
|
||||
|
@ -7,6 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmseg.models.segmentors.EncoderDecoder.simple_test')
|
||||
def simple_test_of_encoder_decoder(ctx, self, img, img_meta, **kwargs):
|
||||
"""Rewrite `simple_test` for default backend."""
|
||||
x = self.extract_feat(img)
|
||||
seg_logit = self._decode_head_forward_test(x, img_meta)
|
||||
seg_logit = resize(
|
||||
|
@ -6,6 +6,8 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.Tensor.__getattribute__', backend='ncnn')
|
||||
def getattribute_static(ctx, self, name):
|
||||
"""Rewrite `__getattribute__` for NCNN backend."""
|
||||
|
||||
ret = ctx.origin_func(self, name)
|
||||
if name == 'shape':
|
||||
ret = torch.Size([int(s) for s in ret])
|
||||
|
@ -15,6 +15,7 @@ def group_norm_ncnn(
|
||||
bias: Union[torch.Tensor, torch.NoneType] = None,
|
||||
eps: float = 1e-05,
|
||||
) -> torch.Tensor:
|
||||
"""Rewrite `group_norm` for NCNN backend."""
|
||||
input_shape = input.shape
|
||||
batch_size = input_shape[0]
|
||||
# We cannot use input.reshape(batch_size, num_groups, -1, 1)
|
||||
|
@ -10,6 +10,8 @@ def interpolate_static(ctx,
|
||||
mode='nearest',
|
||||
align_corners=None,
|
||||
recompute_scale_factor=None):
|
||||
"""Rewrite `interpolate` for NCNN backend."""
|
||||
|
||||
input_size = input.shape
|
||||
if scale_factor is None:
|
||||
scale_factor = [
|
||||
|
@ -13,6 +13,8 @@ def linear_ncnn(
|
||||
weight: torch.Tensor,
|
||||
bias: Union[torch.Tensor, torch.NoneType] = None,
|
||||
):
|
||||
"""Rewrite `linear` for NCNN backend."""
|
||||
|
||||
origin_func = ctx.origin_func
|
||||
|
||||
dim = input.dim()
|
||||
|
@ -4,6 +4,8 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.Tensor.repeat', backend='tensorrt')
|
||||
def repeat_static(ctx, input, *size):
|
||||
"""Rewrite `repeat` for NCNN backend."""
|
||||
|
||||
origin_func = ctx.origin_func
|
||||
if input.dim() == 1 and len(size) == 1:
|
||||
return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)
|
||||
|
@ -6,6 +6,8 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.Tensor.size', backend='ncnn')
|
||||
def size_of_tensor_static(ctx, self, *args):
|
||||
"""Rewrite `size` for NCNN backend."""
|
||||
|
||||
ret = ctx.origin_func(self, *args)
|
||||
if isinstance(ret, torch.Tensor):
|
||||
ret = int(ret)
|
||||
|
@ -7,6 +7,8 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.Tensor.topk', backend='default')
|
||||
def topk_dynamic(ctx, input, k, dim=None, largest=True, sorted=True):
|
||||
"""Rewrite `interpolate` for default backend."""
|
||||
|
||||
if dim is None:
|
||||
dim = int(input.ndim - 1)
|
||||
size = input.shape[dim]
|
||||
@ -24,6 +26,8 @@ def topk_dynamic(ctx, input, k, dim=None, largest=True, sorted=True):
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.Tensor.topk', backend='tensorrt')
|
||||
def topk_static(ctx, input, k, dim=None, largest=True, sorted=True):
|
||||
"""Rewrite `interpolate` for TensorRT backend."""
|
||||
|
||||
if dim is None:
|
||||
dim = int(input.ndim - 1)
|
||||
size = input.shape[dim]
|
||||
|
@ -8,6 +8,7 @@ from mmdeploy.core import SYMBOLIC_REGISTER
|
||||
|
||||
|
||||
def _adaptive_pool(name, type, tuple_fn, fn=None):
|
||||
"""Generic adaptive pooling."""
|
||||
|
||||
@parse_args('v', 'is')
|
||||
def symbolic_fn(g, input, output_size):
|
||||
@ -53,14 +54,17 @@ adaptive_avg_pool3d = _adaptive_pool('adaptive_avg_pool3d', 'AveragePool',
|
||||
|
||||
@SYMBOLIC_REGISTER.register_symbolic('adaptive_avg_pool1d', is_pytorch=True)
|
||||
def adaptive_avg_pool1d_op(ctx, *args):
|
||||
"""Register default symbolic function for `adaptive_avg_pool1d`."""
|
||||
return adaptive_avg_pool1d(*args)
|
||||
|
||||
|
||||
@SYMBOLIC_REGISTER.register_symbolic('adaptive_avg_pool2d', is_pytorch=True)
|
||||
def adaptive_avg_pool2d_op(ctx, *args):
|
||||
"""Register default symbolic function for `adaptive_avg_pool2d`."""
|
||||
return adaptive_avg_pool2d(*args)
|
||||
|
||||
|
||||
@SYMBOLIC_REGISTER.register_symbolic('adaptive_avg_pool3d', is_pytorch=True)
|
||||
def adaptive_avg_pool3d_op(ctx, *args):
|
||||
"""Register default symbolic function for `adaptive_avg_pool3d`."""
|
||||
return adaptive_avg_pool3d(*args)
|
||||
|
@ -10,6 +10,7 @@ def grid_sampler(g,
|
||||
interpolation_mode,
|
||||
padding_mode,
|
||||
align_corners=False):
|
||||
"""Symbolic function for `grid_sampler`."""
|
||||
return g.op(
|
||||
'mmcv::grid_sampler',
|
||||
input,
|
||||
@ -21,4 +22,5 @@ def grid_sampler(g,
|
||||
|
||||
@SYMBOLIC_REGISTER.register_symbolic('grid_sampler', is_pytorch=True)
|
||||
def grid_sampler_default(ctx, *args):
|
||||
"""Register default symbolic function for `grid_sampler`."""
|
||||
return grid_sampler(*args)
|
||||
|
@ -11,6 +11,7 @@ from mmdeploy.core import SYMBOLIC_REGISTER
|
||||
|
||||
@parse_args('v', 'i', 'v', 'v', 'f', 'i')
|
||||
def instance_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
|
||||
"""Symbolic function for `instance_norm`."""
|
||||
channel_size = _get_tensor_dim_size(input, 1)
|
||||
if channel_size is not None:
|
||||
assert channel_size % num_groups == 0
|
||||
@ -60,8 +61,12 @@ def instance_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
|
||||
_unsqueeze_helper(g, bias, axes))
|
||||
|
||||
|
||||
# instance normalization is implemented in group norm in pytorch
|
||||
@SYMBOLIC_REGISTER.register_symbolic(
|
||||
'group_norm', backend='tensorrt', is_pytorch=True)
|
||||
def instance_norm_trt(ctx, *args):
|
||||
"""Register symbolic function for TensorRT backend.
|
||||
|
||||
Notes:
|
||||
Instance normalization is implemented in group norm in pytorch.
|
||||
"""
|
||||
return instance_norm(*args)
|
||||
|
@ -5,6 +5,7 @@ from mmdeploy.core import SYMBOLIC_REGISTER
|
||||
|
||||
@SYMBOLIC_REGISTER.register_symbolic('squeeze', is_pytorch=True)
|
||||
def squeeze_default(ctx, g, self, dim=None):
|
||||
"""Register default symbolic function for `squeeze`."""
|
||||
if dim is None:
|
||||
dims = []
|
||||
for i, size in enumerate(self.type().sizes()):
|
||||
|
@ -2,6 +2,7 @@ from enum import Enum
|
||||
|
||||
|
||||
class AdvancedEnum(Enum):
|
||||
"""Define an enumeration class."""
|
||||
|
||||
@classmethod
|
||||
def get(cls, str, a):
|
||||
@ -12,6 +13,7 @@ class AdvancedEnum(Enum):
|
||||
|
||||
|
||||
class Task(AdvancedEnum):
|
||||
"""Define task enumerations."""
|
||||
TEXT_DETECTION = 'TextDetection'
|
||||
TEXT_RECOGNITION = 'TextRecognition'
|
||||
SEGMENTATION = 'Segmentation'
|
||||
@ -21,6 +23,7 @@ class Task(AdvancedEnum):
|
||||
|
||||
|
||||
class Codebase(AdvancedEnum):
|
||||
"""Define codebase enumerations."""
|
||||
MMDET = 'mmdet'
|
||||
MMSEG = 'mmseg'
|
||||
MMCLS = 'mmcls'
|
||||
@ -29,6 +32,7 @@ class Codebase(AdvancedEnum):
|
||||
|
||||
|
||||
class Backend(AdvancedEnum):
|
||||
"""Define backend enumerations."""
|
||||
PYTORCH = 'pytorch'
|
||||
TENSORRT = 'tensorrt'
|
||||
ONNXRUNTIME = 'onnxruntime'
|
||||
|
@ -12,6 +12,7 @@ from mmdeploy.utils import Backend, get_backend, get_onnx_config
|
||||
|
||||
|
||||
class WrapFunction(nn.Module):
|
||||
"""Simple wrapper for a function."""
|
||||
|
||||
def __init__(self, wrapped_function, **kwargs):
|
||||
super(WrapFunction, self).__init__()
|
||||
|
@ -4,7 +4,15 @@ __version__ = '0.1.0'
|
||||
short_version = __version__
|
||||
|
||||
|
||||
def parse_version_info(version_str):
|
||||
def parse_version_info(version_str: str):
|
||||
"""Parse version from a string.
|
||||
|
||||
Args:
|
||||
version_str (str): A string represents a version info.
|
||||
|
||||
Returns:
|
||||
tuple: A sequence of integer and string represents version.
|
||||
"""
|
||||
version_info = []
|
||||
for x in version_str.split('.'):
|
||||
if x.isdigit():
|
||||
|
Loading…
x
Reference in New Issue
Block a user