[Fix] Fix some inferencer bugs (#1706)

* [Fix] Fix some inferencer bugs

* fix
This commit is contained in:
Tong Gao 2023-02-09 18:31:25 +08:00 committed by GitHub
parent d8e615921d
commit 20a87d476c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 31 additions and 14 deletions

View File

@ -24,5 +24,5 @@ test_pipeline = [
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadKIEAnnotations'), dict(type='LoadKIEAnnotations'),
dict(type='Resize', scale=(1024, 512), keep_ratio=True), dict(type='Resize', scale=(1024, 512), keep_ratio=True),
dict(type='PackKIEInputs'), dict(type='PackKIEInputs', meta_keys=('img_path', )),
] ]

View File

@ -80,6 +80,7 @@ class BaseMMOCRInferencer(BaseInferencer):
Args: Args:
inputs (InputsType): Inputs for the inferencer. It can be a path inputs (InputsType): Inputs for the inferencer. It can be a path
to image / image directory, or an array, or a list of these. to image / image directory, or an array, or a list of these.
Note: If it's a numpy array, it should be in BGR order.
return_datasamples (bool): Whether to return results as return_datasamples (bool): Whether to return results as
:obj:`BaseDataElement`. Defaults to False. :obj:`BaseDataElement`. Defaults to False.
batch_size (int): Inference batch size. Defaults to 1. batch_size (int): Inference batch size. Defaults to 1.
@ -206,7 +207,7 @@ class BaseMMOCRInferencer(BaseInferencer):
img = mmcv.imfrombytes(img_bytes, channel_order='rgb') img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
img_name = osp.basename(single_input) img_name = osp.basename(single_input)
elif isinstance(single_input, np.ndarray): elif isinstance(single_input, np.ndarray):
img = single_input.copy() img = single_input.copy()[:, :, ::-1] # to RGB
img_num = str(self.num_visualized_imgs).zfill(8) img_num = str(self.num_visualized_imgs).zfill(8)
img_name = f'{img_num}.jpg' img_name = f'{img_num}.jpg'
else: else:

View File

@ -77,6 +77,13 @@ class KIEInferencer(BaseMMOCRInferencer):
self.novisual = all( self.novisual = all(
self._get_transform_idx(pipeline_cfg, t) == -1 self._get_transform_idx(pipeline_cfg, t) == -1
for t in self.loading_transforms) for t in self.loading_transforms)
# Remove Resize from test_pipeline, since SDMGR requires bbox
# annotations to be resized together with pictures, but visualization
# loads the original image from the disk.
# TODO: find a more elegant way to fix this
idx = self._get_transform_idx(pipeline_cfg, 'Resize')
if idx != -1:
pipeline_cfg.pop(idx)
# If it's in non-visual mode, self.pipeline will be specified. # If it's in non-visual mode, self.pipeline will be specified.
# Otherwise, file_pipeline and ndarray_pipeline will be specified. # Otherwise, file_pipeline and ndarray_pipeline will be specified.
if self.novisual: if self.novisual:
@ -93,6 +100,7 @@ class KIEInferencer(BaseMMOCRInferencer):
- img (str or ndarray): Path to the image or the image itself. If KIE - img (str or ndarray): Path to the image or the image itself. If KIE
Inferencer is used in no-visual mode, this key is not required. Inferencer is used in no-visual mode, this key is not required.
Note: If it's a numpy array, it should be in BGR order.
- img_shape (tuple(int, int)): Image shape in (H, W). In - img_shape (tuple(int, int)): Image shape in (H, W). In
- instances (list[dict]): A list of instances. - instances (list[dict]): A list of instances.
- bbox (ndarray(dtype=np.float32)): Shape (4, ). Bounding box. - bbox (ndarray(dtype=np.float32)): Shape (4, ). Bounding box.
@ -182,6 +190,7 @@ class KIEInferencer(BaseMMOCRInferencer):
- img (str or ndarray): Path to the image or the image itself. If KIE - img (str or ndarray): Path to the image or the image itself. If KIE
Inferencer is used in no-visual mode, this key is not required. Inferencer is used in no-visual mode, this key is not required.
Note: If it's a numpy array, it should be in BGR order.
- img_shape (tuple(int, int)): Image shape in (H, W). In - img_shape (tuple(int, int)): Image shape in (H, W). In
- instances (list[dict]): A list of instances. - instances (list[dict]): A list of instances.
- bbox (ndarray(dtype=np.float32)): Shape (4, ). Bounding box. - bbox (ndarray(dtype=np.float32)): Shape (4, ). Bounding box.
@ -286,10 +295,10 @@ class KIEInferencer(BaseMMOCRInferencer):
assert 'img' in single_input or 'img_shape' in single_input assert 'img' in single_input or 'img_shape' in single_input
if 'img' in single_input: if 'img' in single_input:
if isinstance(single_input['img'], str): if isinstance(single_input['img'], str):
img = mmcv.imread(single_input['img']) img = mmcv.imread(single_input['img'], channel_order='rgb')
img_name = osp.basename(single_input['img']) img_name = osp.basename(single_input['img'])
elif isinstance(single_input['img'], np.ndarray): elif isinstance(single_input['img'], np.ndarray):
img = single_input['img'].copy() img = single_input['img'].copy()[:, :, ::-1] # To RGB
img_name = f'{img_num}.jpg' img_name = f'{img_num}.jpg'
elif 'img_shape' in single_input: elif 'img_shape' in single_input:
img = np.zeros(single_input['img_shape'], dtype=np.uint8) img = np.zeros(single_input['img_shape'], dtype=np.uint8)

View File

@ -46,7 +46,7 @@ def parse_args():
help='Pretrained key information extraction algorithm. It\'s the path' help='Pretrained key information extraction algorithm. It\'s the path'
'to the config file or the model name defined in metafile.') 'to the config file or the model name defined in metafile.')
parser.add_argument( parser.add_argument(
'--kie-ckpt', '--kie-weights',
type=str, type=str,
default=None, default=None,
help='Path to the custom checkpoint file of the selected kie model. ' help='Path to the custom checkpoint file of the selected kie model. '
@ -77,7 +77,8 @@ def parse_args():
call_args = vars(parser.parse_args()) call_args = vars(parser.parse_args())
init_kws = [ init_kws = [
'det', 'det_weights', 'rec', 'rec_weights', 'kie', 'kie_ckpt', 'device' 'det', 'det_weights', 'rec', 'rec_weights', 'kie', 'kie_weights',
'device'
] ]
init_args = {} init_args = {}
for init_kw in init_kws: for init_kw in init_kws:

View File

@ -54,8 +54,9 @@ class TestTextDetinferencer(TestCase):
res_ndarray = self.inferencer(img, return_vis=True) res_ndarray = self.inferencer(img, return_vis=True)
self.assert_predictions_equal(res_path['predictions'], self.assert_predictions_equal(res_path['predictions'],
res_ndarray['predictions']) res_ndarray['predictions'])
self.assertIn('visualization', res_path) self.assertTrue(
self.assertIn('visualization', res_ndarray) np.allclose(res_path['visualization'],
res_ndarray['visualization']))
# multiple images # multiple images
img_paths = [ img_paths = [
@ -68,8 +69,10 @@ class TestTextDetinferencer(TestCase):
res_ndarray = self.inferencer(imgs, return_vis=True) res_ndarray = self.inferencer(imgs, return_vis=True)
self.assert_predictions_equal(res_path['predictions'], self.assert_predictions_equal(res_path['predictions'],
res_ndarray['predictions']) res_ndarray['predictions'])
self.assertIn('visualization', res_path) for i in range(len(img_paths)):
self.assertIn('visualization', res_ndarray) self.assertTrue(
np.allclose(res_path['visualization'][i],
res_ndarray['visualization'][i]))
# img dir, test different batch sizes # img dir, test different batch sizes
img_dir = 'tests/data/det_toy_dataset/imgs/test/' img_dir = 'tests/data/det_toy_dataset/imgs/test/'

View File

@ -52,8 +52,9 @@ class TestTextRecinferencer(TestCase):
res_ndarray = self.inferencer(img, return_vis=True) res_ndarray = self.inferencer(img, return_vis=True)
self.assert_predictions_equal(res_path['predictions'], self.assert_predictions_equal(res_path['predictions'],
res_ndarray['predictions']) res_ndarray['predictions'])
self.assertIn('visualization', res_path) self.assertTrue(
self.assertIn('visualization', res_ndarray) np.allclose(res_path['visualization'],
res_ndarray['visualization']))
# multiple images # multiple images
img_paths = [ img_paths = [
@ -66,8 +67,10 @@ class TestTextRecinferencer(TestCase):
res_ndarray = self.inferencer(imgs, return_vis=True) res_ndarray = self.inferencer(imgs, return_vis=True)
self.assert_predictions_equal(res_path['predictions'], self.assert_predictions_equal(res_path['predictions'],
res_ndarray['predictions']) res_ndarray['predictions'])
self.assertIn('visualization', res_path) for i in range(len(img_paths)):
self.assertIn('visualization', res_ndarray) self.assertTrue(
np.allclose(res_path['visualization'][i],
res_ndarray['visualization'][i]))
# img dir, test different batch sizes # img dir, test different batch sizes
img_dir = 'tests/data/rec_toy_dataset/imgs' img_dir = 'tests/data/rec_toy_dataset/imgs'