[Enhance]: Add more docstring. (#111)

* add docstring for apis

* add simple docstring for mmdet

* add simple docstring for mmseg

* add simple docstring for mmcls

* add simple docstring for mmedit

* add simple docstring for mmocr

* add simple docstring for rewriting

* update thresh for docstring coverage

* update

* update docstring

* solve comments

* remove unrelated symbol
This commit is contained in:
RunningLeon 2021-09-29 15:59:38 +08:00 committed by GitHub
parent 4587322441
commit de9498a8f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
49 changed files with 191 additions and 25 deletions

View File

@ -26,7 +26,7 @@ jobs:
- name: Check docstring coverage
run: |
pip install interrogate
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-regex "__repr__" --fail-under 30 mmdeploy
interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-regex "__repr__" --fail-under 80 mmdeploy
build_cpu:
runs-on: ubuntu-18.04

View File

@ -60,6 +60,7 @@ class HDF5Calibrator(trt.IInt8Calibrator):
self.calib_file.close()
def get_batch(self, names: Sequence[str], **kwargs):
"""Get batch data."""
if self.count < self.dataset_length:
ret = []
@ -95,13 +96,33 @@ class HDF5Calibrator(trt.IInt8Calibrator):
return None
def get_algorithm(self):
"""Get Calibration algo type.
Returns:
trt.CalibrationAlgoType: Calibration algo type.
"""
return self.algorithm
def get_batch_size(self):
"""Get batch size.
Returns:
int: An integer represents batch size.
"""
return self.batch_size
def read_calibration_cache(self, *args, **kwargs):
"""Read calibration cache.
Notes:
No need to implement this function.
"""
pass
def write_calibration_cache(self, cache, *args, **kwargs):
"""Write calibration cache.
Notes:
No need to implement this function.
"""
pass

View File

@ -215,7 +215,7 @@ class TRTWrapper(torch.nn.Module):
self._load_io_names()
def _load_io_names(self):
# get input and output names from engine
"""Load input/output names from engine."""
names = [_ for _ in self.engine]
input_names = list(filter(self.engine.binding_is_input, names))
output_names = list(set(names) - set(input_names))

View File

@ -5,6 +5,7 @@ from packaging import version
def parse_extractor_io_string(io_str):
"""Parse IO string for extractor."""
name, io_type = io_str.split(':')
assert io_type in ['input', 'output']
func_id = 0
@ -17,8 +18,9 @@ def parse_extractor_io_string(io_str):
return name, func_id, io_type
def _dfs_search_reacable_nodes_fast(self, node_output_name, graph_input_nodes,
def _dfs_search_reachable_nodes_fast(self, node_output_name, graph_input_nodes,
reachable_nodes):
"""Using DFS to search reachable nodes."""
outputs = {}
for index, node in enumerate(self.graph.node):
for name in node.output:
@ -54,7 +56,7 @@ def create_extractor(model: onnx.ModelProto):
assert version.parse(onnx.__version__) >= version.parse('1.8.0')
# patch extractor
onnx.utils.Extractor._dfs_search_reachable_nodes = \
_dfs_search_reacable_nodes_fast
_dfs_search_reachable_nodes_fast
extractor = onnx.utils.Extractor(model)
return extractor

View File

@ -10,6 +10,7 @@ MARK_FUNCTION_COUNT = dict()
def reset_mark_function_count():
"""Reset counter of mark function."""
for k in MARK_FUNCTION_COUNT:
MARK_FUNCTION_COUNT[k] = 0
@ -59,6 +60,7 @@ class Mark(torch.autograd.Function):
@FUNCTION_REWRITER.register_rewriter(
'mmdeploy.core.optimizers.function_marker.Mark.symbolic')
def mark_symbolic(rewriter, g, x, *args):
"""Rewrite symbolic of mark op."""
if cfg_apply_marks(rewriter.cfg):
return rewriter.origin_func(g, x, *args)
return x
@ -68,6 +70,7 @@ def mark_symbolic(rewriter, g, x, *args):
'mmdeploy.core.optimizers.function_marker.Mark.forward')
def forward_of_mark(rewriter, ctx, x, dtype, shape, func, func_id, type, name,
id, attrs):
"""Rewrite forward of mark op."""
deploy_cfg = rewriter.cfg
# save calib data
apply_marks = cfg_apply_marks(deploy_cfg)

View File

@ -6,6 +6,14 @@ from onnx.helper import get_attribute_value
def attribute_to_dict(attr: onnx.AttributeProto):
"""Convert onnx op attribute to dict.
Args:
attr (onnx.AttributeProto): Input onnx op attribute.
Returns:
dict: A dict contains info from op attribute.
"""
ret = {}
for a in attr:
value = get_attribute_value(a)
@ -16,6 +24,15 @@ def attribute_to_dict(attr: onnx.AttributeProto):
def remove_nodes(model: onnx.ModelProto, predicate: Callable):
"""Remove nodes from ONNX model.
Args:
model (onnx.ModelProto): Input onnx model.
predicate (Callable): A function to predicate a node.
Returns:
onnx.ModelProto: Modified onnx model.
"""
# ! this doesn't handle inputs/outputs
while True:
connect = None
@ -38,6 +55,14 @@ def remove_nodes(model: onnx.ModelProto, predicate: Callable):
def is_unused_mark(marks: Iterable[onnx.NodeProto]):
"""Check whether a mark is unused.
Args:
marks (Iterable[onnx.NodeProto]): A list of onnx NodeProto.
Returns:
bool: `True` if a mark node is in `marks`.
"""
def f(node):
if node.op_type == 'Mark':
@ -51,6 +76,7 @@ def is_unused_mark(marks: Iterable[onnx.NodeProto]):
def is_identity(node: onnx.NodeProto):
"""Check if an op is identity."""
return node.op_type == 'Identity'
@ -73,6 +99,13 @@ def get_new_name(attrs: onnx.ModelProto,
def rename_value(model: onnx.ModelProto, old_name: str, new_name: str):
"""Rename a node in an ONNX model.
Args:
model (onnx.ModelProto): Input onnx model.
old_name (str): Original node name in the model.
new_name (str): New node name in the model.
"""
if old_name == new_name:
return
logging.info(f'rename {old_name} -> {new_name}')
@ -95,6 +128,11 @@ def rename_value(model: onnx.ModelProto, old_name: str, new_name: str):
def remove_identity(model: onnx.ModelProto):
"""Remove identity node from an ONNX model.
Args:
model (onnx.ModelProto): Input onnx model.
"""
graph = model.graph
def simplify_inputs():

View File

@ -6,4 +6,5 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmcls.models.classifiers.BaseClassifier.forward', backend='default')
def forward_of_base_classifier(ctx, self, img, *args, **kwargs):
"""Rewrite `forward` for default backend."""
return self.simple_test(img, {})

View File

@ -4,4 +4,5 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.heads.ClsHead.post_process')
def post_process_of_cls_head(ctx, self, pred, **kwargs):
"""Rewrite `post_process` for default backend."""
return pred

View File

@ -4,4 +4,5 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.heads.MultiLabelClsHead.post_process')
def post_process_of_multi_label_head(ctx, self, pred, **kwargs):
"""Rewrite `post_process` for default backend."""
return pred

View File

@ -5,15 +5,13 @@ from torch.onnx import symbolic_helper as sym_help
from mmdeploy.core import SYMBOLIC_REGISTER
class DummyONNXNMSop(torch.autograd.Function):
"""DummyONNXNMSop.
This class is only for creating onnx::NonMaxSuppression.
"""
class ONNXNMSop(torch.autograd.Function):
"""Create onnx::NonMaxSuppression op."""
@staticmethod
def forward(ctx, boxes, scores, max_output_boxes_per_class, iou_threshold,
score_threshold):
"""Forward of onnx nms."""
batch_size, num_class, _ = scores.shape
score_threshold = float(score_threshold)
@ -51,9 +49,11 @@ class DummyONNXNMSop(torch.autograd.Function):
@SYMBOLIC_REGISTER.register_symbolic(
'mmdeploy.mmcv.ops.DummyONNXNMSop', backend='default')
'mmdeploy.mmcv.ops.ONNXNMSop', backend='default')
def nms_dynamic(ctx, g, boxes, scores, max_output_boxes_per_class,
iou_threshold, score_threshold):
"""Rewrite symbolic function for default backend."""
if not sym_help._is_value(max_output_boxes_per_class):
max_output_boxes_per_class = g.op(
'Constant',
@ -73,9 +73,11 @@ def nms_dynamic(ctx, g, boxes, scores, max_output_boxes_per_class,
@SYMBOLIC_REGISTER.register_symbolic(
'mmdeploy.mmcv.ops.DummyONNXNMSop', backend='tensorrt')
'mmdeploy.mmcv.ops.ONNXNMSop', backend='tensorrt')
def nms_static(ctx, g, boxes, scores, max_output_boxes_per_class,
iou_threshold, score_threshold):
"""Rewrite symbolic function for TensorRT backend."""
if sym_help._is_value(max_output_boxes_per_class):
max_output_boxes_per_class = sym_help._maybe_get_const(
max_output_boxes_per_class, 'i')
@ -98,6 +100,7 @@ def nms_static(ctx, g, boxes, scores, max_output_boxes_per_class,
class TRTBatchedNMSop(torch.autograd.Function):
"""Create mmcv::TRTBatchedNMS op for TensorRT backend."""
@staticmethod
def forward(ctx,
@ -109,6 +112,7 @@ class TRTBatchedNMSop(torch.autograd.Function):
iou_threshold,
score_threshold,
background_label_id=-1):
"""Forward of batched nms."""
batch_size, num_boxes, num_classes = scores.shape
out_boxes = min(num_boxes, after_topk)

View File

@ -8,6 +8,8 @@ from mmdeploy.core import SYMBOLIC_REGISTER
'mmcv.ops.roi_align.__self__', backend='default')
def roi_align_default(ctx, g, input, rois, output_size, spatial_scale,
sampling_ratio, pool_mode, aligned):
"""Rewrite symbolic function for default backend."""
return g.op(
'mmcv::MMCVRoiAlign',
input,

View File

@ -17,6 +17,7 @@ def delta2bbox(ctx,
clip_border=True,
add_ctr_clamp=False,
ctr_clamp=32):
"""Rewrite for ONNX exporting of default backend."""
means = deltas.new_tensor(means).view(1,
-1).repeat(1,
deltas.size(-1) // 4)
@ -81,6 +82,7 @@ def delta2bbox_ncnn(ctx,
clip_border=True,
add_ctr_clamp=False,
ctr_clamp=32):
"""Rewrite for ONNX exporting of NCNN backend."""
means = deltas.new_tensor(means).view(1, 1,
-1).repeat(1, deltas.size(-2),
deltas.size(-1) // 4).data

View File

@ -13,6 +13,7 @@ def tblr2bboxes(ctx,
normalize_by_wh=True,
max_shape=None,
clip_border=True):
"""Rewrite for ONNX exporting of default backend."""
if not isinstance(normalizer, float):
normalizer = torch.tensor(normalizer, device=priors.device)
assert len(normalizer) == 4, 'Normalizer must have length = 4'
@ -54,6 +55,7 @@ def delta2bbox_ncnn(ctx,
normalize_by_wh=True,
max_shape=None,
clip_border=True):
"""Rewrite for ONNX exporting of NCNN backend."""
assert priors.size(0) == tblr.size(0)
if priors.ndim == 3:
assert priors.size(1) == tblr.size(1)

View File

@ -2,10 +2,29 @@ import torch
import mmdeploy
from mmdeploy.core import FUNCTION_REWRITER, mark
from mmdeploy.mmcv.ops import DummyONNXNMSop, TRTBatchedNMSop
from mmdeploy.mmcv.ops import ONNXNMSop, TRTBatchedNMSop
def select_nms_index(scores, boxes, nms_index, batch_size, keep_top_k=-1):
def select_nms_index(scores: torch.Tensor,
boxes: torch.Tensor,
nms_index: torch.Tensor,
batch_size: int,
keep_top_k: int = -1):
"""Transform NMS output.
Args:
scores (Tensor): The detection scores of shape
[N, num_classes, num_boxes].
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
nms_index (Tensor): NMS output of bounding boxes indexing.
batch_size (int): Batch size of the input image.
keep_top_k (int): Number of top K boxes to keep after nms.
Defaults to -1.
Returns:
tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
and `labels` of shape [N, num_det].
"""
batch_inds, cls_inds = nms_index[:, 0], nms_index[:, 1]
box_inds = nms_index[:, 2]
@ -65,9 +84,9 @@ def _multiclass_nms(boxes,
(N, num_boxes, 4).
Args:
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4]
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
scores (Tensor): The detection scores of shape
[N, num_boxes, num_classes]
[N, num_boxes, num_classes].
max_output_boxes_per_class (int): Maximum number of output
boxes per class of nms. Defaults to 1000.
iou_threshold (float): IOU threshold of nms. Defaults to 0.5
@ -79,8 +98,8 @@ def _multiclass_nms(boxes,
Defaults to -1.
Returns:
tuple[Tensor, Tensor]: dets of shape [N, num_det, 5] and class labels
of shape [N, num_det].
tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
and `labels` of shape [N, num_det].
"""
max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class])
iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
@ -96,7 +115,7 @@ def _multiclass_nms(boxes,
scores = scores[batch_inds, topk_inds, :]
scores = scores.permute(0, 2, 1)
selected_indices = DummyONNXNMSop.apply(boxes, scores,
selected_indices = ONNXNMSop.apply(boxes, scores,
max_output_boxes_per_class,
iou_threshold, score_threshold)
@ -117,6 +136,7 @@ def multiclass_nms_static(ctx,
score_threshold=0.05,
pre_top_k=-1,
keep_top_k=-1):
"""Wrapper for `multiclass_nms` with TensorRT."""
boxes = boxes if boxes.dim() == 4 else boxes.unsqueeze(2)
keep_top_k = max_output_boxes_per_class if keep_top_k < 0 else min(
max_output_boxes_per_class, keep_top_k)
@ -129,5 +149,5 @@ def multiclass_nms_static(ctx,
@mark('multiclass_nms', inputs=['boxes', 'scores'], outputs=['dets', 'labels'])
def multiclass_nms(*args, **kwargs):
"""Wrapper function for _multiclass_nms."""
"""Wrapper function for `_multiclass_nms`."""
return mmdeploy.mmdet.core.post_processing._multiclass_nms(*args, **kwargs)

View File

@ -17,6 +17,7 @@ def get_bboxes_of_anchor_head(ctx,
with_nms=True,
cfg=None,
**kwargs):
"""Rewrite `get_bboxes` for default backend."""
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
@ -124,6 +125,7 @@ def get_bboxes_of_anchor_head_ncnn(ctx,
with_nms=True,
cfg=None,
**kwargs):
"""Rewrite `get_bboxes` for NCNN backend."""
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg
assert not is_dynamic_shape(deploy_cfg)

View File

@ -18,6 +18,7 @@ def get_bboxes_of_fcos_head(ctx,
with_nms=True,
cfg=None,
**kwargs):
"""Rewrite `get_bboxes` for default backend."""
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
@ -118,6 +119,7 @@ def get_bboxes_of_fcos_head_ncnn(ctx,
with_nms=True,
cfg=None,
**kwargs):
"""Rewrite `get_bboxes` for NCNN backend."""
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg
assert not is_dynamic_shape(deploy_cfg)

View File

@ -16,6 +16,7 @@ def get_bboxes_of_rpn_head(ctx,
with_nms=True,
cfg=None,
**kwargs):
"""Rewrite `get_bboxes` for default backend."""
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
@ -120,6 +121,7 @@ def get_bboxes_of_rpn_head_ncnn(ctx,
with_nms=True,
cfg=None,
**kwargs):
"""Rewrite `get_bboxes` for NCNN backend."""
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg
assert not is_dynamic_shape(deploy_cfg)

View File

@ -7,6 +7,7 @@ from mmdeploy.utils import is_dynamic_shape
@mark(
'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks'])
def _forward_of_base_detector_impl(ctx, self, img, img_metas=None, **kwargs):
"""Rewrite and adding mark for `forward`."""
assert isinstance(img_metas, dict)
assert isinstance(img, torch.Tensor)
@ -23,6 +24,7 @@ def _forward_of_base_detector_impl(ctx, self, img, img_metas=None, **kwargs):
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.BaseDetector.forward')
def forward_of_base_detector(ctx, self, img, img_metas=None, **kwargs):
"""Rewrite `forward` for default backend."""
if img_metas is None:
img_metas = {}

View File

@ -3,5 +3,6 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(func_name='mmdet.models.RPN.simple_test')
def simple_test_of_rpn(ctx, self, img, img_metas, **kwargs):
"""Rewrite `simple_test` for default backend."""
x = self.extract_feat(img)
return self.rpn_head.simple_test_rpn(x, img_metas)

View File

@ -4,5 +4,6 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.SingleStageDetector.simple_test')
def simple_test_of_single_stage(ctx, self, img, img_metas, **kwargs):
"""Rewrite `simple_test` for default backend."""
feat = self.extract_feat(img)
return self.bbox_head.simple_test(feat, img_metas, **kwargs)

View File

@ -5,6 +5,7 @@ from mmdeploy.core import FUNCTION_REWRITER, mark
'mmdet.models.TwoStageDetector.extract_feat')
@mark('extract_feat', inputs='img', outputs='feat')
def extract_feat_of_two_stage(ctx, self, img):
"""Rewrite `extract_feat` for default backend."""
return ctx.origin_func(self, img)
@ -16,6 +17,7 @@ def simple_test_of_two_stage(ctx,
img_metas,
proposals=None,
**kwargs):
"""Rewrite `simple_test` for default backend."""
assert self.with_bbox, 'Bbox head must be implemented.'
x = self.extract_feat(img)
if proposals is None:

View File

@ -15,6 +15,7 @@ from mmdeploy.utils import get_mmdet_params
inputs=['bbox_feats'],
outputs=['cls_score', 'bbox_pred'])
def forward_of_bbox_head(ctx, self, x):
"""Rewrite `forward` for default backend."""
return ctx.origin_func(self, x)
@ -22,7 +23,7 @@ def forward_of_bbox_head(ctx, self, x):
func_name='mmdet.models.roi_heads.BBoxHead.get_bboxes')
def get_bboxes_of_bbox_head(ctx, self, rois, cls_score, bbox_pred, img_shape,
cfg, **kwargs):
"""Rewrite `get_bboxes` for default backend."""
assert rois.ndim == 3, 'Only support export two stage ' \
'model to ONNX ' \
'with batch dimension. '

View File

@ -59,6 +59,7 @@ def forward_of_single_roi_extractor_static(ctx,
feats,
rois,
roi_scale_factor=None):
"""Rewrite `forward` for TensorRT backend."""
featmap_strides = self.featmap_strides
finest_scale = self.finest_scale
@ -83,6 +84,7 @@ def forward_of_single_roi_extractor_dynamic(ctx,
feats,
rois,
roi_scale_factor=None):
"""Rewrite `forward` for default backend."""
out_size = self.roi_layers[0].output_size
num_levels = len(feats)
roi_feats = feats[0].new_zeros(rois.shape[0], self.out_channels, *out_size)

View File

@ -5,6 +5,7 @@ from mmdeploy.core import FUNCTION_REWRITER
func_name='mmdet.models.roi_heads.StandardRoIHead.simple_test')
def simple_test_of_standard_roi_head(ctx, self, x, proposals, img_metas,
**kwargs):
"""Rewrite `simple_test` for default backend."""
assert self.with_bbox, 'Bbox head must be implemented.'
det_bboxes, det_labels = self.simple_test_bboxes(
x, img_metas, proposals, self.test_cfg, rescale=False)

View File

@ -8,6 +8,7 @@ from mmdeploy.core import FUNCTION_REWRITER
BBoxTestMixin.simple_test_bboxes')
def simple_test_bboxes_of_bbox_test_mixin(ctx, self, x, img_metas, proposals,
rcnn_test_cfg, **kwargs):
"""Rewrite `simple_test_bboxes` for default backend."""
rois = proposals
batch_index = torch.arange(
rois.shape[0], device=rois.device).float().view(-1, 1, 1).expand(

View File

@ -41,7 +41,9 @@ class SRCNNWrapper(nn.Module):
align_corners=False)
def forward(self, *args, **kwargs):
"""Run forward."""
return self._module(*args, **kwargs)
def init_weights(self, *args, **kwargs):
"""Initialize weights."""
return self._module.init_weights(*args, **kwargs)

View File

@ -9,6 +9,7 @@ def simple_test_of_single_stage_text_detector(ctx,
img_metas,
rescale=False,
**kwargs):
"""Rewrite `simple_test` for default backend."""
x = self.extract_feat(img)
outs = self.bbox_head(x)
return outs

View File

@ -6,6 +6,8 @@ from mmdeploy.core import FUNCTION_REWRITER
'.BidirectionalLSTM.forward',
backend='ncnn')
def forward_of_bidirectionallstm(ctx, self, input):
"""Rewrite `forward` for NCNN backend."""
self.rnn.batch_first = True
recurrent, _ = self.rnn(input)
self.rnn.batch_first = False

View File

@ -12,6 +12,7 @@ def forward_of_base_recognizer(ctx,
img_metas=None,
return_loss=False,
**kwargs):
"""Rewrite `forward` for NCNN backend."""
if img_metas is None:
img_metas = {}
assert isinstance(img_metas, dict)

View File

@ -6,6 +6,7 @@ from mmdeploy.core import FUNCTION_REWRITER
backend='ncnn')
def forward_train_of_crnndecoder(ctx, self, feat, out_enc, targets_dict,
img_metas):
"""Rewrite `forward_train` for NCNN backend."""
assert feat.size(2) == 1, 'feature height must be 1'
if self.rnn_flag:
x = feat.squeeze(2) # [N, C, W]

View File

@ -5,6 +5,7 @@ from mmdeploy.core import FUNCTION_REWRITER
func_name='mmocr.models.textrecog.EncodeDecodeRecognizer.simple_test')
def simple_test_of_encode_decode_recognizer(ctx, self, img, img_metas,
**kwargs):
"""Rewrite `forward` for default backend."""
feat = self.extract_feat(img)
out_enc = None

View File

@ -8,6 +8,7 @@ from mmdeploy.utils import is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter(
func_name='mmseg.models.decode_heads.ASPPHead.forward')
def forward_of_aspp_head(ctx, self, inputs):
"""Rewrite `forward` for default backend."""
x = self._transform_inputs(inputs)
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)

View File

@ -7,6 +7,7 @@ from mmdeploy.utils import is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter(
func_name='mmseg.models.decode_heads.psp_head.PPM.forward')
def forward_of_ppm(ctx, self, x):
"""Rewrite `forward` for default backend."""
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
# get origin input shape as tensor to support onnx dynamic shape

View File

@ -7,6 +7,7 @@ from mmdeploy.utils import is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter(
func_name='mmseg.models.segmentors.BaseSegmentor.forward')
def forward_of_base_segmentor(ctx, self, img, img_metas=None, **kwargs):
"""Rewrite `forward` for default backend."""
if img_metas is None:
img_metas = {}
assert isinstance(img_metas, dict)

View File

@ -7,6 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmseg.models.segmentors.EncoderDecoder.simple_test')
def simple_test_of_encoder_decoder(ctx, self, img, img_meta, **kwargs):
"""Rewrite `simple_test` for default backend."""
x = self.extract_feat(img)
seg_logit = self._decode_head_forward_test(x, img_meta)
seg_logit = resize(

View File

@ -6,6 +6,8 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.__getattribute__', backend='ncnn')
def getattribute_static(ctx, self, name):
"""Rewrite `__getattribute__` for NCNN backend."""
ret = ctx.origin_func(self, name)
if name == 'shape':
ret = torch.Size([int(s) for s in ret])

View File

@ -15,6 +15,7 @@ def group_norm_ncnn(
bias: Union[torch.Tensor, torch.NoneType] = None,
eps: float = 1e-05,
) -> torch.Tensor:
"""Rewrite `group_norm` for NCNN backend."""
input_shape = input.shape
batch_size = input_shape[0]
# We cannot use input.reshape(batch_size, num_groups, -1, 1)

View File

@ -10,6 +10,8 @@ def interpolate_static(ctx,
mode='nearest',
align_corners=None,
recompute_scale_factor=None):
"""Rewrite `interpolate` for NCNN backend."""
input_size = input.shape
if scale_factor is None:
scale_factor = [

View File

@ -13,6 +13,8 @@ def linear_ncnn(
weight: torch.Tensor,
bias: Union[torch.Tensor, torch.NoneType] = None,
):
"""Rewrite `linear` for NCNN backend."""
origin_func = ctx.origin_func
dim = input.dim()

View File

@ -4,6 +4,8 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.repeat', backend='tensorrt')
def repeat_static(ctx, input, *size):
"""Rewrite `repeat` for NCNN backend."""
origin_func = ctx.origin_func
if input.dim() == 1 and len(size) == 1:
return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)

View File

@ -6,6 +6,8 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.size', backend='ncnn')
def size_of_tensor_static(ctx, self, *args):
"""Rewrite `size` for NCNN backend."""
ret = ctx.origin_func(self, *args)
if isinstance(ret, torch.Tensor):
ret = int(ret)

View File

@ -7,6 +7,8 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.topk', backend='default')
def topk_dynamic(ctx, input, k, dim=None, largest=True, sorted=True):
"""Rewrite `interpolate` for default backend."""
if dim is None:
dim = int(input.ndim - 1)
size = input.shape[dim]
@ -24,6 +26,8 @@ def topk_dynamic(ctx, input, k, dim=None, largest=True, sorted=True):
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.topk', backend='tensorrt')
def topk_static(ctx, input, k, dim=None, largest=True, sorted=True):
"""Rewrite `interpolate` for TensorRT backend."""
if dim is None:
dim = int(input.ndim - 1)
size = input.shape[dim]

View File

@ -8,6 +8,7 @@ from mmdeploy.core import SYMBOLIC_REGISTER
def _adaptive_pool(name, type, tuple_fn, fn=None):
"""Generic adaptive pooling."""
@parse_args('v', 'is')
def symbolic_fn(g, input, output_size):
@ -53,14 +54,17 @@ adaptive_avg_pool3d = _adaptive_pool('adaptive_avg_pool3d', 'AveragePool',
@SYMBOLIC_REGISTER.register_symbolic('adaptive_avg_pool1d', is_pytorch=True)
def adaptive_avg_pool1d_op(ctx, *args):
"""Register default symbolic function for `adaptive_avg_pool1d`."""
return adaptive_avg_pool1d(*args)
@SYMBOLIC_REGISTER.register_symbolic('adaptive_avg_pool2d', is_pytorch=True)
def adaptive_avg_pool2d_op(ctx, *args):
"""Register default symbolic function for `adaptive_avg_pool2d`."""
return adaptive_avg_pool2d(*args)
@SYMBOLIC_REGISTER.register_symbolic('adaptive_avg_pool3d', is_pytorch=True)
def adaptive_avg_pool3d_op(ctx, *args):
"""Register default symbolic function for `adaptive_avg_pool3d`."""
return adaptive_avg_pool3d(*args)

View File

@ -10,6 +10,7 @@ def grid_sampler(g,
interpolation_mode,
padding_mode,
align_corners=False):
"""Symbolic function for `grid_sampler`."""
return g.op(
'mmcv::grid_sampler',
input,
@ -21,4 +22,5 @@ def grid_sampler(g,
@SYMBOLIC_REGISTER.register_symbolic('grid_sampler', is_pytorch=True)
def grid_sampler_default(ctx, *args):
"""Register default symbolic function for `grid_sampler`."""
return grid_sampler(*args)

View File

@ -11,6 +11,7 @@ from mmdeploy.core import SYMBOLIC_REGISTER
@parse_args('v', 'i', 'v', 'v', 'f', 'i')
def instance_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
"""Symbolic function for `instance_norm`."""
channel_size = _get_tensor_dim_size(input, 1)
if channel_size is not None:
assert channel_size % num_groups == 0
@ -60,8 +61,12 @@ def instance_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
_unsqueeze_helper(g, bias, axes))
# instance normalization is implemented in group norm in pytorch
@SYMBOLIC_REGISTER.register_symbolic(
'group_norm', backend='tensorrt', is_pytorch=True)
def instance_norm_trt(ctx, *args):
"""Register symbolic function for TensorRT backend.
Notes:
Instance normalization is implemented in group norm in pytorch.
"""
return instance_norm(*args)

View File

@ -5,6 +5,7 @@ from mmdeploy.core import SYMBOLIC_REGISTER
@SYMBOLIC_REGISTER.register_symbolic('squeeze', is_pytorch=True)
def squeeze_default(ctx, g, self, dim=None):
"""Register default symbolic function for `squeeze`."""
if dim is None:
dims = []
for i, size in enumerate(self.type().sizes()):

View File

@ -2,6 +2,7 @@ from enum import Enum
class AdvancedEnum(Enum):
"""Define an enumeration class."""
@classmethod
def get(cls, str, a):
@ -12,6 +13,7 @@ class AdvancedEnum(Enum):
class Task(AdvancedEnum):
"""Define task enumerations."""
TEXT_DETECTION = 'TextDetection'
TEXT_RECOGNITION = 'TextRecognition'
SEGMENTATION = 'Segmentation'
@ -21,6 +23,7 @@ class Task(AdvancedEnum):
class Codebase(AdvancedEnum):
"""Define codebase enumerations."""
MMDET = 'mmdet'
MMSEG = 'mmseg'
MMCLS = 'mmcls'
@ -29,6 +32,7 @@ class Codebase(AdvancedEnum):
class Backend(AdvancedEnum):
"""Define backend enumerations."""
PYTORCH = 'pytorch'
TENSORRT = 'tensorrt'
ONNXRUNTIME = 'onnxruntime'

View File

@ -12,6 +12,7 @@ from mmdeploy.utils import Backend, get_backend, get_onnx_config
class WrapFunction(nn.Module):
"""Simple wrapper for a function."""
def __init__(self, wrapped_function, **kwargs):
super(WrapFunction, self).__init__()

View File

@ -4,7 +4,15 @@ __version__ = '0.1.0'
short_version = __version__
def parse_version_info(version_str):
def parse_version_info(version_str: str):
"""Parse version from a string.
Args:
version_str (str): A string represents a version info.
Returns:
tuple: A sequence of integer and string represents version.
"""
version_info = []
for x in version_str.split('.'):
if x.isdigit():