diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index 5660ed900..14f25056d 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -96,13 +96,16 @@ def _prepare_input_img(img_path, return mm_inputs -def _update_input_img(img_list, img_meta_list): +def _update_input_img(img_list, img_meta_list, update_ori_shape=False): # update img and its meta list - N = img_list[0].size(0) + N, C, H, W = img_list[0].shape img_meta = img_meta_list[0][0] - img_shape = img_meta['img_shape'] - ori_shape = img_meta['ori_shape'] - pad_shape = img_meta['pad_shape'] + img_shape = (H, W, C) + if update_ori_shape: + ori_shape = img_shape + else: + ori_shape = img_meta['ori_shape'] + pad_shape = img_shape new_img_meta_list = [[{ 'img_shape': img_shape, @@ -220,7 +223,7 @@ def pytorch2onnx(model, # update img_meta img_list, img_meta_list = _update_input_img( - img_list, img_meta_list) + img_list, img_meta_list, test_mode == 'whole') # check the numerical value # get pytorch output