pass img_metas while exporting to onnx (#681)
* pass img_metas while exporting to onnx * remove try-catch in tools for beter debugging * use get * fix typopull/649/head
parent
5195ff9388
commit
17a7d60c7d
|
@ -61,6 +61,7 @@ def torch2onnx(img: Any,
|
|||
|
||||
torch_model = task_processor.init_pytorch_model(model_checkpoint)
|
||||
data, model_inputs = task_processor.create_input(img, input_shape)
|
||||
input_metas = dict(img_metas=data.get('img_metas', None))
|
||||
if not isinstance(model_inputs, torch.Tensor) and len(model_inputs) == 1:
|
||||
model_inputs = model_inputs[0]
|
||||
|
||||
|
@ -87,6 +88,7 @@ def torch2onnx(img: Any,
|
|||
export(
|
||||
torch_model,
|
||||
model_inputs,
|
||||
input_metas=input_metas,
|
||||
output_path_prefix=output_prefix,
|
||||
backend=backend,
|
||||
input_names=input_names,
|
||||
|
|
|
@ -6,7 +6,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
'mmcls.models.classifiers.ImageClassifier.forward', backend='default')
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmcls.models.classifiers.BaseClassifier.forward', backend='default')
|
||||
def base_classifier__forward(ctx, self, img, *args, **kwargs):
|
||||
def base_classifier__forward(ctx, self, img, return_loss=False, **kwargs):
|
||||
"""Rewrite `forward` of BaseClassifier for default backend.
|
||||
|
||||
Rewrite this function to call simple_test function,
|
||||
|
@ -23,5 +23,5 @@ def base_classifier__forward(ctx, self, img, *args, **kwargs):
|
|||
result(Tensor): The result of classifier.The tensor
|
||||
shape (batch_size,num_classes).
|
||||
"""
|
||||
result = self.simple_test(img, {})
|
||||
result = self.simple_test(img, **kwargs)
|
||||
return result
|
||||
|
|
|
@ -7,14 +7,13 @@ from mmdeploy.utils import is_dynamic_shape
|
|||
|
||||
@mark(
|
||||
'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks'])
|
||||
def __forward_impl(ctx, self, img, img_metas=None, **kwargs):
|
||||
def __forward_impl(ctx, self, img, img_metas, **kwargs):
|
||||
"""Rewrite and adding mark for `forward`.
|
||||
|
||||
Encapsulate this function for rewriting `forward` of BaseDetector.
|
||||
1. Add mark for BaseDetector.
|
||||
2. Support both dynamic and static export to onnx.
|
||||
"""
|
||||
assert isinstance(img_metas, dict)
|
||||
assert isinstance(img, torch.Tensor)
|
||||
|
||||
deploy_cfg = ctx.cfg
|
||||
|
@ -23,14 +22,18 @@ def __forward_impl(ctx, self, img, img_metas=None, **kwargs):
|
|||
img_shape = torch._shape_as_tensor(img)[2:]
|
||||
if not is_dynamic_flag:
|
||||
img_shape = [int(val) for val in img_shape]
|
||||
img_metas['img_shape'] = img_shape
|
||||
img_metas = [img_metas]
|
||||
img_metas[0]['img_shape'] = img_shape
|
||||
return self.simple_test(img, img_metas, **kwargs)
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.detectors.base.BaseDetector.forward')
|
||||
def base_detector__forward(ctx, self, img, img_metas=None, **kwargs):
|
||||
def base_detector__forward(ctx,
|
||||
self,
|
||||
img,
|
||||
img_metas=None,
|
||||
return_loss=False,
|
||||
**kwargs):
|
||||
"""Rewrite `forward` of BaseDetector for default backend.
|
||||
|
||||
Rewrite this function to:
|
||||
|
@ -56,14 +59,12 @@ def base_detector__forward(ctx, self, img, img_metas=None, **kwargs):
|
|||
corresponds to each class.
|
||||
"""
|
||||
if img_metas is None:
|
||||
img_metas = {}
|
||||
|
||||
while isinstance(img_metas, list):
|
||||
img_metas = [{}]
|
||||
else:
|
||||
assert len(img_metas) == 1, 'do not support aug_test'
|
||||
img_metas = img_metas[0]
|
||||
|
||||
if isinstance(img, list):
|
||||
img = torch.cat(img, 0)
|
||||
img = img[0]
|
||||
|
||||
if 'return_loss' in kwargs:
|
||||
kwargs.pop('return_loss')
|
||||
return __forward_impl(ctx, self, img, img_metas=img_metas, **kwargs)
|
||||
|
|
|
@ -23,12 +23,12 @@ def base_segmentor__forward(ctx, self, img, img_metas=None, **kwargs):
|
|||
torch.Tensor: Output segmentation map pf shape [N, 1, H, W].
|
||||
"""
|
||||
if img_metas is None:
|
||||
img_metas = {}
|
||||
while isinstance(img_metas, list):
|
||||
img_metas = [{}]
|
||||
else:
|
||||
assert len(img_metas) == 1, 'do not support aug_test'
|
||||
img_metas = img_metas[0]
|
||||
|
||||
if isinstance(img, list):
|
||||
img = torch.cat(img, 0)
|
||||
img = img[0]
|
||||
assert isinstance(img, torch.Tensor)
|
||||
|
||||
deploy_cfg = ctx.cfg
|
||||
|
@ -37,5 +37,5 @@ def base_segmentor__forward(ctx, self, img, img_metas=None, **kwargs):
|
|||
img_shape = img.shape[2:]
|
||||
if not is_dynamic_flag:
|
||||
img_shape = [int(val) for val in img_shape]
|
||||
img_metas['img_shape'] = img_shape
|
||||
img_metas[0]['img_shape'] = img_shape
|
||||
return self.simple_test(img, img_metas, **kwargs)
|
||||
|
|
|
@ -73,7 +73,7 @@ def test_baseclassifier_forward():
|
|||
def forward_train(self, imgs):
|
||||
return 'train'
|
||||
|
||||
def simple_test(self, img, tmp, **kwargs):
|
||||
def simple_test(self, img, tmp=None, **kwargs):
|
||||
return 'simple_test'
|
||||
|
||||
model = DummyClassifier().eval()
|
||||
|
|
|
@ -28,12 +28,8 @@ def main():
|
|||
output_prefix = args.output_prefix
|
||||
|
||||
logger.info(f'onnx2ncnn: \n\tonnx_path: {onnx_path} ')
|
||||
try:
|
||||
from_onnx(onnx_path, output_prefix)
|
||||
logger.info('onnx2ncnn success.')
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error('onnx2ncnn failed.')
|
||||
from_onnx(onnx_path, output_prefix)
|
||||
logger.info('onnx2ncnn success.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -49,15 +49,11 @@ def main():
|
|||
if isinstance(input_shapes[0], int):
|
||||
input_shapes = [input_shapes]
|
||||
|
||||
logger.info(f'onnx2ppl: \n\tonnx_path: {onnx_path} '
|
||||
logger.info(f'onnx2pplnn: \n\tonnx_path: {onnx_path} '
|
||||
f'\n\toutput_prefix: {output_prefix}'
|
||||
f'\n\topt_shapes: {input_shapes}')
|
||||
try:
|
||||
from_onnx(onnx_path, output_prefix, device, input_shapes)
|
||||
logger.info('onnx2tpplnn success.')
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error('onnx2tpplnn failed.')
|
||||
from_onnx(onnx_path, output_prefix, device, input_shapes)
|
||||
logger.info('onnx2pplnn success.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -55,22 +55,18 @@ def main():
|
|||
|
||||
logger.info(f'onnx2tensorrt: \n\tonnx_path: {onnx_path} '
|
||||
f'\n\tdeploy_cfg: {deploy_cfg_path}')
|
||||
try:
|
||||
from_onnx(
|
||||
onnx_path,
|
||||
output_prefix,
|
||||
input_shapes=final_params['input_shapes'],
|
||||
log_level=get_trt_log_level(),
|
||||
fp16_mode=final_params.get('fp16_mode', False),
|
||||
int8_mode=final_params.get('int8_mode', False),
|
||||
int8_param=int8_param,
|
||||
max_workspace_size=final_params.get('max_workspace_size', 0),
|
||||
device_id=device_id)
|
||||
from_onnx(
|
||||
onnx_path,
|
||||
output_prefix,
|
||||
input_shapes=final_params['input_shapes'],
|
||||
log_level=get_trt_log_level(),
|
||||
fp16_mode=final_params.get('fp16_mode', False),
|
||||
int8_mode=final_params.get('int8_mode', False),
|
||||
int8_param=int8_param,
|
||||
max_workspace_size=final_params.get('max_workspace_size', 0),
|
||||
device_id=device_id)
|
||||
|
||||
logger.info('onnx2tensorrt success.')
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error('onnx2tensorrt failed.')
|
||||
logger.info('onnx2tensorrt success.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue