pass img_metas while exporting to onnx (#681)

* pass img_metas while exporting to onnx

* remove try-catch in tools for beter debugging

* use get

* fix typo
This commit is contained in:
RunningLeon 2022-06-30 17:33:24 +08:00 committed by GitHub
parent 5195ff9388
commit 17a7d60c7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 38 additions and 47 deletions

View File

@ -61,6 +61,7 @@ def torch2onnx(img: Any,
torch_model = task_processor.init_pytorch_model(model_checkpoint) torch_model = task_processor.init_pytorch_model(model_checkpoint)
data, model_inputs = task_processor.create_input(img, input_shape) data, model_inputs = task_processor.create_input(img, input_shape)
input_metas = dict(img_metas=data.get('img_metas', None))
if not isinstance(model_inputs, torch.Tensor) and len(model_inputs) == 1: if not isinstance(model_inputs, torch.Tensor) and len(model_inputs) == 1:
model_inputs = model_inputs[0] model_inputs = model_inputs[0]
@ -87,6 +88,7 @@ def torch2onnx(img: Any,
export( export(
torch_model, torch_model,
model_inputs, model_inputs,
input_metas=input_metas,
output_path_prefix=output_prefix, output_path_prefix=output_prefix,
backend=backend, backend=backend,
input_names=input_names, input_names=input_names,

View File

@ -6,7 +6,7 @@ from mmdeploy.core import FUNCTION_REWRITER
'mmcls.models.classifiers.ImageClassifier.forward', backend='default') 'mmcls.models.classifiers.ImageClassifier.forward', backend='default')
@FUNCTION_REWRITER.register_rewriter( @FUNCTION_REWRITER.register_rewriter(
'mmcls.models.classifiers.BaseClassifier.forward', backend='default') 'mmcls.models.classifiers.BaseClassifier.forward', backend='default')
def base_classifier__forward(ctx, self, img, *args, **kwargs): def base_classifier__forward(ctx, self, img, return_loss=False, **kwargs):
"""Rewrite `forward` of BaseClassifier for default backend. """Rewrite `forward` of BaseClassifier for default backend.
Rewrite this function to call simple_test function, Rewrite this function to call simple_test function,
@ -23,5 +23,5 @@ def base_classifier__forward(ctx, self, img, *args, **kwargs):
result(Tensor): The result of classifier.The tensor result(Tensor): The result of classifier.The tensor
shape (batch_size,num_classes). shape (batch_size,num_classes).
""" """
result = self.simple_test(img, {}) result = self.simple_test(img, **kwargs)
return result return result

View File

@ -7,14 +7,13 @@ from mmdeploy.utils import is_dynamic_shape
@mark( @mark(
'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks']) 'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks'])
def __forward_impl(ctx, self, img, img_metas=None, **kwargs): def __forward_impl(ctx, self, img, img_metas, **kwargs):
"""Rewrite and adding mark for `forward`. """Rewrite and adding mark for `forward`.
Encapsulate this function for rewriting `forward` of BaseDetector. Encapsulate this function for rewriting `forward` of BaseDetector.
1. Add mark for BaseDetector. 1. Add mark for BaseDetector.
2. Support both dynamic and static export to onnx. 2. Support both dynamic and static export to onnx.
""" """
assert isinstance(img_metas, dict)
assert isinstance(img, torch.Tensor) assert isinstance(img, torch.Tensor)
deploy_cfg = ctx.cfg deploy_cfg = ctx.cfg
@ -23,14 +22,18 @@ def __forward_impl(ctx, self, img, img_metas=None, **kwargs):
img_shape = torch._shape_as_tensor(img)[2:] img_shape = torch._shape_as_tensor(img)[2:]
if not is_dynamic_flag: if not is_dynamic_flag:
img_shape = [int(val) for val in img_shape] img_shape = [int(val) for val in img_shape]
img_metas['img_shape'] = img_shape img_metas[0]['img_shape'] = img_shape
img_metas = [img_metas]
return self.simple_test(img, img_metas, **kwargs) return self.simple_test(img, img_metas, **kwargs)
@FUNCTION_REWRITER.register_rewriter( @FUNCTION_REWRITER.register_rewriter(
'mmdet.models.detectors.base.BaseDetector.forward') 'mmdet.models.detectors.base.BaseDetector.forward')
def base_detector__forward(ctx, self, img, img_metas=None, **kwargs): def base_detector__forward(ctx,
self,
img,
img_metas=None,
return_loss=False,
**kwargs):
"""Rewrite `forward` of BaseDetector for default backend. """Rewrite `forward` of BaseDetector for default backend.
Rewrite this function to: Rewrite this function to:
@ -56,14 +59,12 @@ def base_detector__forward(ctx, self, img, img_metas=None, **kwargs):
corresponds to each class. corresponds to each class.
""" """
if img_metas is None: if img_metas is None:
img_metas = {} img_metas = [{}]
else:
while isinstance(img_metas, list): assert len(img_metas) == 1, 'do not support aug_test'
img_metas = img_metas[0] img_metas = img_metas[0]
if isinstance(img, list): if isinstance(img, list):
img = torch.cat(img, 0) img = img[0]
if 'return_loss' in kwargs:
kwargs.pop('return_loss')
return __forward_impl(ctx, self, img, img_metas=img_metas, **kwargs) return __forward_impl(ctx, self, img, img_metas=img_metas, **kwargs)

View File

@ -23,12 +23,12 @@ def base_segmentor__forward(ctx, self, img, img_metas=None, **kwargs):
torch.Tensor: Output segmentation map pf shape [N, 1, H, W]. torch.Tensor: Output segmentation map pf shape [N, 1, H, W].
""" """
if img_metas is None: if img_metas is None:
img_metas = {} img_metas = [{}]
while isinstance(img_metas, list): else:
assert len(img_metas) == 1, 'do not support aug_test'
img_metas = img_metas[0] img_metas = img_metas[0]
if isinstance(img, list): if isinstance(img, list):
img = torch.cat(img, 0) img = img[0]
assert isinstance(img, torch.Tensor) assert isinstance(img, torch.Tensor)
deploy_cfg = ctx.cfg deploy_cfg = ctx.cfg
@ -37,5 +37,5 @@ def base_segmentor__forward(ctx, self, img, img_metas=None, **kwargs):
img_shape = img.shape[2:] img_shape = img.shape[2:]
if not is_dynamic_flag: if not is_dynamic_flag:
img_shape = [int(val) for val in img_shape] img_shape = [int(val) for val in img_shape]
img_metas['img_shape'] = img_shape img_metas[0]['img_shape'] = img_shape
return self.simple_test(img, img_metas, **kwargs) return self.simple_test(img, img_metas, **kwargs)

View File

@ -73,7 +73,7 @@ def test_baseclassifier_forward():
def forward_train(self, imgs): def forward_train(self, imgs):
return 'train' return 'train'
def simple_test(self, img, tmp, **kwargs): def simple_test(self, img, tmp=None, **kwargs):
return 'simple_test' return 'simple_test'
model = DummyClassifier().eval() model = DummyClassifier().eval()

View File

@ -28,12 +28,8 @@ def main():
output_prefix = args.output_prefix output_prefix = args.output_prefix
logger.info(f'onnx2ncnn: \n\tonnx_path: {onnx_path} ') logger.info(f'onnx2ncnn: \n\tonnx_path: {onnx_path} ')
try:
from_onnx(onnx_path, output_prefix) from_onnx(onnx_path, output_prefix)
logger.info('onnx2ncnn success.') logger.info('onnx2ncnn success.')
except Exception as e:
logger.error(e)
logger.error('onnx2ncnn failed.')
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -49,15 +49,11 @@ def main():
if isinstance(input_shapes[0], int): if isinstance(input_shapes[0], int):
input_shapes = [input_shapes] input_shapes = [input_shapes]
logger.info(f'onnx2ppl: \n\tonnx_path: {onnx_path} ' logger.info(f'onnx2pplnn: \n\tonnx_path: {onnx_path} '
f'\n\toutput_prefix: {output_prefix}' f'\n\toutput_prefix: {output_prefix}'
f'\n\topt_shapes: {input_shapes}') f'\n\topt_shapes: {input_shapes}')
try:
from_onnx(onnx_path, output_prefix, device, input_shapes) from_onnx(onnx_path, output_prefix, device, input_shapes)
logger.info('onnx2tpplnn success.') logger.info('onnx2pplnn success.')
except Exception as e:
logger.error(e)
logger.error('onnx2tpplnn failed.')
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -55,7 +55,6 @@ def main():
logger.info(f'onnx2tensorrt: \n\tonnx_path: {onnx_path} ' logger.info(f'onnx2tensorrt: \n\tonnx_path: {onnx_path} '
f'\n\tdeploy_cfg: {deploy_cfg_path}') f'\n\tdeploy_cfg: {deploy_cfg_path}')
try:
from_onnx( from_onnx(
onnx_path, onnx_path,
output_prefix, output_prefix,
@ -68,9 +67,6 @@ def main():
device_id=device_id) device_id=device_id)
logger.info('onnx2tensorrt success.') logger.info('onnx2tensorrt success.')
except Exception as e:
logger.error(e)
logger.error('onnx2tensorrt failed.')
if __name__ == '__main__': if __name__ == '__main__':