mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
pass img_metas while exporting to onnx (#681)
* pass img_metas while exporting to onnx * remove try-catch in tools for beter debugging * use get * fix typo
This commit is contained in:
parent
5195ff9388
commit
17a7d60c7d
@ -61,6 +61,7 @@ def torch2onnx(img: Any,
|
|||||||
|
|
||||||
torch_model = task_processor.init_pytorch_model(model_checkpoint)
|
torch_model = task_processor.init_pytorch_model(model_checkpoint)
|
||||||
data, model_inputs = task_processor.create_input(img, input_shape)
|
data, model_inputs = task_processor.create_input(img, input_shape)
|
||||||
|
input_metas = dict(img_metas=data.get('img_metas', None))
|
||||||
if not isinstance(model_inputs, torch.Tensor) and len(model_inputs) == 1:
|
if not isinstance(model_inputs, torch.Tensor) and len(model_inputs) == 1:
|
||||||
model_inputs = model_inputs[0]
|
model_inputs = model_inputs[0]
|
||||||
|
|
||||||
@ -87,6 +88,7 @@ def torch2onnx(img: Any,
|
|||||||
export(
|
export(
|
||||||
torch_model,
|
torch_model,
|
||||||
model_inputs,
|
model_inputs,
|
||||||
|
input_metas=input_metas,
|
||||||
output_path_prefix=output_prefix,
|
output_path_prefix=output_prefix,
|
||||||
backend=backend,
|
backend=backend,
|
||||||
input_names=input_names,
|
input_names=input_names,
|
||||||
|
@ -6,7 +6,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||||||
'mmcls.models.classifiers.ImageClassifier.forward', backend='default')
|
'mmcls.models.classifiers.ImageClassifier.forward', backend='default')
|
||||||
@FUNCTION_REWRITER.register_rewriter(
|
@FUNCTION_REWRITER.register_rewriter(
|
||||||
'mmcls.models.classifiers.BaseClassifier.forward', backend='default')
|
'mmcls.models.classifiers.BaseClassifier.forward', backend='default')
|
||||||
def base_classifier__forward(ctx, self, img, *args, **kwargs):
|
def base_classifier__forward(ctx, self, img, return_loss=False, **kwargs):
|
||||||
"""Rewrite `forward` of BaseClassifier for default backend.
|
"""Rewrite `forward` of BaseClassifier for default backend.
|
||||||
|
|
||||||
Rewrite this function to call simple_test function,
|
Rewrite this function to call simple_test function,
|
||||||
@ -23,5 +23,5 @@ def base_classifier__forward(ctx, self, img, *args, **kwargs):
|
|||||||
result(Tensor): The result of classifier.The tensor
|
result(Tensor): The result of classifier.The tensor
|
||||||
shape (batch_size,num_classes).
|
shape (batch_size,num_classes).
|
||||||
"""
|
"""
|
||||||
result = self.simple_test(img, {})
|
result = self.simple_test(img, **kwargs)
|
||||||
return result
|
return result
|
||||||
|
@ -7,14 +7,13 @@ from mmdeploy.utils import is_dynamic_shape
|
|||||||
|
|
||||||
@mark(
|
@mark(
|
||||||
'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks'])
|
'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks'])
|
||||||
def __forward_impl(ctx, self, img, img_metas=None, **kwargs):
|
def __forward_impl(ctx, self, img, img_metas, **kwargs):
|
||||||
"""Rewrite and adding mark for `forward`.
|
"""Rewrite and adding mark for `forward`.
|
||||||
|
|
||||||
Encapsulate this function for rewriting `forward` of BaseDetector.
|
Encapsulate this function for rewriting `forward` of BaseDetector.
|
||||||
1. Add mark for BaseDetector.
|
1. Add mark for BaseDetector.
|
||||||
2. Support both dynamic and static export to onnx.
|
2. Support both dynamic and static export to onnx.
|
||||||
"""
|
"""
|
||||||
assert isinstance(img_metas, dict)
|
|
||||||
assert isinstance(img, torch.Tensor)
|
assert isinstance(img, torch.Tensor)
|
||||||
|
|
||||||
deploy_cfg = ctx.cfg
|
deploy_cfg = ctx.cfg
|
||||||
@ -23,14 +22,18 @@ def __forward_impl(ctx, self, img, img_metas=None, **kwargs):
|
|||||||
img_shape = torch._shape_as_tensor(img)[2:]
|
img_shape = torch._shape_as_tensor(img)[2:]
|
||||||
if not is_dynamic_flag:
|
if not is_dynamic_flag:
|
||||||
img_shape = [int(val) for val in img_shape]
|
img_shape = [int(val) for val in img_shape]
|
||||||
img_metas['img_shape'] = img_shape
|
img_metas[0]['img_shape'] = img_shape
|
||||||
img_metas = [img_metas]
|
|
||||||
return self.simple_test(img, img_metas, **kwargs)
|
return self.simple_test(img, img_metas, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@FUNCTION_REWRITER.register_rewriter(
|
@FUNCTION_REWRITER.register_rewriter(
|
||||||
'mmdet.models.detectors.base.BaseDetector.forward')
|
'mmdet.models.detectors.base.BaseDetector.forward')
|
||||||
def base_detector__forward(ctx, self, img, img_metas=None, **kwargs):
|
def base_detector__forward(ctx,
|
||||||
|
self,
|
||||||
|
img,
|
||||||
|
img_metas=None,
|
||||||
|
return_loss=False,
|
||||||
|
**kwargs):
|
||||||
"""Rewrite `forward` of BaseDetector for default backend.
|
"""Rewrite `forward` of BaseDetector for default backend.
|
||||||
|
|
||||||
Rewrite this function to:
|
Rewrite this function to:
|
||||||
@ -56,14 +59,12 @@ def base_detector__forward(ctx, self, img, img_metas=None, **kwargs):
|
|||||||
corresponds to each class.
|
corresponds to each class.
|
||||||
"""
|
"""
|
||||||
if img_metas is None:
|
if img_metas is None:
|
||||||
img_metas = {}
|
img_metas = [{}]
|
||||||
|
else:
|
||||||
while isinstance(img_metas, list):
|
assert len(img_metas) == 1, 'do not support aug_test'
|
||||||
img_metas = img_metas[0]
|
img_metas = img_metas[0]
|
||||||
|
|
||||||
if isinstance(img, list):
|
if isinstance(img, list):
|
||||||
img = torch.cat(img, 0)
|
img = img[0]
|
||||||
|
|
||||||
if 'return_loss' in kwargs:
|
|
||||||
kwargs.pop('return_loss')
|
|
||||||
return __forward_impl(ctx, self, img, img_metas=img_metas, **kwargs)
|
return __forward_impl(ctx, self, img, img_metas=img_metas, **kwargs)
|
||||||
|
@ -23,12 +23,12 @@ def base_segmentor__forward(ctx, self, img, img_metas=None, **kwargs):
|
|||||||
torch.Tensor: Output segmentation map pf shape [N, 1, H, W].
|
torch.Tensor: Output segmentation map pf shape [N, 1, H, W].
|
||||||
"""
|
"""
|
||||||
if img_metas is None:
|
if img_metas is None:
|
||||||
img_metas = {}
|
img_metas = [{}]
|
||||||
while isinstance(img_metas, list):
|
else:
|
||||||
|
assert len(img_metas) == 1, 'do not support aug_test'
|
||||||
img_metas = img_metas[0]
|
img_metas = img_metas[0]
|
||||||
|
|
||||||
if isinstance(img, list):
|
if isinstance(img, list):
|
||||||
img = torch.cat(img, 0)
|
img = img[0]
|
||||||
assert isinstance(img, torch.Tensor)
|
assert isinstance(img, torch.Tensor)
|
||||||
|
|
||||||
deploy_cfg = ctx.cfg
|
deploy_cfg = ctx.cfg
|
||||||
@ -37,5 +37,5 @@ def base_segmentor__forward(ctx, self, img, img_metas=None, **kwargs):
|
|||||||
img_shape = img.shape[2:]
|
img_shape = img.shape[2:]
|
||||||
if not is_dynamic_flag:
|
if not is_dynamic_flag:
|
||||||
img_shape = [int(val) for val in img_shape]
|
img_shape = [int(val) for val in img_shape]
|
||||||
img_metas['img_shape'] = img_shape
|
img_metas[0]['img_shape'] = img_shape
|
||||||
return self.simple_test(img, img_metas, **kwargs)
|
return self.simple_test(img, img_metas, **kwargs)
|
||||||
|
@ -73,7 +73,7 @@ def test_baseclassifier_forward():
|
|||||||
def forward_train(self, imgs):
|
def forward_train(self, imgs):
|
||||||
return 'train'
|
return 'train'
|
||||||
|
|
||||||
def simple_test(self, img, tmp, **kwargs):
|
def simple_test(self, img, tmp=None, **kwargs):
|
||||||
return 'simple_test'
|
return 'simple_test'
|
||||||
|
|
||||||
model = DummyClassifier().eval()
|
model = DummyClassifier().eval()
|
||||||
|
@ -28,12 +28,8 @@ def main():
|
|||||||
output_prefix = args.output_prefix
|
output_prefix = args.output_prefix
|
||||||
|
|
||||||
logger.info(f'onnx2ncnn: \n\tonnx_path: {onnx_path} ')
|
logger.info(f'onnx2ncnn: \n\tonnx_path: {onnx_path} ')
|
||||||
try:
|
|
||||||
from_onnx(onnx_path, output_prefix)
|
from_onnx(onnx_path, output_prefix)
|
||||||
logger.info('onnx2ncnn success.')
|
logger.info('onnx2ncnn success.')
|
||||||
except Exception as e:
|
|
||||||
logger.error(e)
|
|
||||||
logger.error('onnx2ncnn failed.')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -49,15 +49,11 @@ def main():
|
|||||||
if isinstance(input_shapes[0], int):
|
if isinstance(input_shapes[0], int):
|
||||||
input_shapes = [input_shapes]
|
input_shapes = [input_shapes]
|
||||||
|
|
||||||
logger.info(f'onnx2ppl: \n\tonnx_path: {onnx_path} '
|
logger.info(f'onnx2pplnn: \n\tonnx_path: {onnx_path} '
|
||||||
f'\n\toutput_prefix: {output_prefix}'
|
f'\n\toutput_prefix: {output_prefix}'
|
||||||
f'\n\topt_shapes: {input_shapes}')
|
f'\n\topt_shapes: {input_shapes}')
|
||||||
try:
|
|
||||||
from_onnx(onnx_path, output_prefix, device, input_shapes)
|
from_onnx(onnx_path, output_prefix, device, input_shapes)
|
||||||
logger.info('onnx2tpplnn success.')
|
logger.info('onnx2pplnn success.')
|
||||||
except Exception as e:
|
|
||||||
logger.error(e)
|
|
||||||
logger.error('onnx2tpplnn failed.')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -55,7 +55,6 @@ def main():
|
|||||||
|
|
||||||
logger.info(f'onnx2tensorrt: \n\tonnx_path: {onnx_path} '
|
logger.info(f'onnx2tensorrt: \n\tonnx_path: {onnx_path} '
|
||||||
f'\n\tdeploy_cfg: {deploy_cfg_path}')
|
f'\n\tdeploy_cfg: {deploy_cfg_path}')
|
||||||
try:
|
|
||||||
from_onnx(
|
from_onnx(
|
||||||
onnx_path,
|
onnx_path,
|
||||||
output_prefix,
|
output_prefix,
|
||||||
@ -68,9 +67,6 @@ def main():
|
|||||||
device_id=device_id)
|
device_id=device_id)
|
||||||
|
|
||||||
logger.info('onnx2tensorrt success.')
|
logger.info('onnx2tensorrt success.')
|
||||||
except Exception as e:
|
|
||||||
logger.error(e)
|
|
||||||
logger.error('onnx2tensorrt failed.')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
Loading…
x
Reference in New Issue
Block a user