parent
805eda81ea
commit
8558163753
|
@ -96,13 +96,16 @@ def _prepare_input_img(img_path,
|
|||
return mm_inputs
|
||||
|
||||
|
||||
def _update_input_img(img_list, img_meta_list):
|
||||
def _update_input_img(img_list, img_meta_list, update_ori_shape=False):
|
||||
# update img and its meta list
|
||||
N = img_list[0].size(0)
|
||||
N, C, H, W = img_list[0].shape
|
||||
img_meta = img_meta_list[0][0]
|
||||
img_shape = img_meta['img_shape']
|
||||
img_shape = (H, W, C)
|
||||
if update_ori_shape:
|
||||
ori_shape = img_shape
|
||||
else:
|
||||
ori_shape = img_meta['ori_shape']
|
||||
pad_shape = img_meta['pad_shape']
|
||||
pad_shape = img_shape
|
||||
new_img_meta_list = [[{
|
||||
'img_shape':
|
||||
img_shape,
|
||||
|
@ -220,7 +223,7 @@ def pytorch2onnx(model,
|
|||
|
||||
# update img_meta
|
||||
img_list, img_meta_list = _update_input_img(
|
||||
img_list, img_meta_list)
|
||||
img_list, img_meta_list, test_mode == 'whole')
|
||||
|
||||
# check the numerical value
|
||||
# get pytorch output
|
||||
|
|
Loading…
Reference in New Issue