[Fix] Fix some inferencer bugs (#1706)

* [Fix] Fix some inferencer bugs

* fix
Tong Gao 2023-02-09 18:31:25 +08:00 committed by GitHub
parent d8e615921d
commit 20a87d476c
6 changed files with 31 additions and 14 deletions

View File

@@ -24,5 +24,5 @@ test_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='LoadKIEAnnotations'),
     dict(type='Resize', scale=(1024, 512), keep_ratio=True),
-    dict(type='PackKIEInputs'),
+    dict(type='PackKIEInputs', meta_keys=('img_path', )),
 ]
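
Context for the hunk above: PackKIEInputs only packs the meta keys it is told to keep, and visualization needs 'img_path' to reload the original image from disk. A minimal sketch of reading a packed path back, assuming MMEngine's standard metainfo packing (the file path is made up):

from mmengine.structures import BaseDataElement

# Assumption: PackKIEInputs stores the listed meta_keys in the sample's
# metainfo, so the visualizer can recover the source file location.
sample = BaseDataElement(metainfo=dict(img_path='path/to/img.jpg'))
print(sample.metainfo['img_path'])  # -> 'path/to/img.jpg'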

View File

@@ -80,6 +80,7 @@ class BaseMMOCRInferencer(BaseInferencer):
         Args:
             inputs (InputsType): Inputs for the inferencer. It can be a path
                 to image / image directory, or an array, or a list of these.
+                Note: If it's a numpy array, it should be in BGR order.
             return_datasamples (bool): Whether to return results as
                 :obj:`BaseDataElement`. Defaults to False.
             batch_size (int): Inference batch size. Defaults to 1.
@@ -206,7 +207,7 @@ class BaseMMOCRInferencer(BaseInferencer):
                 img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
                 img_name = osp.basename(single_input)
             elif isinstance(single_input, np.ndarray):
-                img = single_input.copy()
+                img = single_input.copy()[:, :, ::-1]  # to RGB
                 img_num = str(self.num_visualized_imgs).zfill(8)
                 img_name = f'{img_num}.jpg'
             else:
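
Both edits in this file normalize images to RGB before visualization: files are decoded with channel_order='rgb', and raw numpy inputs (documented as BGR) are flipped with a channel-reversing slice. A self-contained sketch of that [:, :, ::-1] idiom, plain numpy only:

import numpy as np

# A BGR image whose single pixel is pure blue: channel 0 (B) is 255.
bgr = np.zeros((1, 1, 3), dtype=np.uint8)
bgr[0, 0, 0] = 255

# Reversing the last axis converts BGR -> RGB; it returns a
# negative-stride view, which is why the diff copies the user's
# array first instead of aliasing it.
rgb = bgr[:, :, ::-1]
assert rgb[0, 0, 2] == 255  # blue now sits in the last channel of RGB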

View File

@@ -77,6 +77,13 @@ class KIEInferencer(BaseMMOCRInferencer):
         self.novisual = all(
             self._get_transform_idx(pipeline_cfg, t) == -1
             for t in self.loading_transforms)
+        # Remove Resize from test_pipeline, since SDMGR requires bbox
+        # annotations to be resized together with pictures, but visualization
+        # loads the original image from the disk.
+        # TODO: find a more elegant way to fix this
+        idx = self._get_transform_idx(pipeline_cfg, 'Resize')
+        if idx != -1:
+            pipeline_cfg.pop(idx)
         # If it's in non-visual mode, self.pipeline will be specified.
         # Otherwise, file_pipeline and ndarray_pipeline will be specified.
         if self.novisual:
@@ -93,6 +100,7 @@ class KIEInferencer(BaseMMOCRInferencer):
         - img (str or ndarray): Path to the image or the image itself. If KIE
           Inferencer is used in no-visual mode, this key is not required.
+          Note: If it's a numpy array, it should be in BGR order.
         - img_shape (tuple(int, int)): Image shape in (H, W). In
         - instances (list[dict]): A list of instances.
           - bbox (ndarray(dtype=np.float32)): Shape (4, ). Bounding box.
@@ -182,6 +190,7 @@ class KIEInferencer(BaseMMOCRInferencer):
         - img (str or ndarray): Path to the image or the image itself. If KIE
           Inferencer is used in no-visual mode, this key is not required.
+          Note: If it's a numpy array, it should be in BGR order.
         - img_shape (tuple(int, int)): Image shape in (H, W). In
         - instances (list[dict]): A list of instances.
           - bbox (ndarray(dtype=np.float32)): Shape (4, ). Bounding box.
@@ -286,10 +295,10 @@ class KIEInferencer(BaseMMOCRInferencer):
             assert 'img' in single_input or 'img_shape' in single_input
             if 'img' in single_input:
                 if isinstance(single_input['img'], str):
-                    img = mmcv.imread(single_input['img'])
+                    img = mmcv.imread(single_input['img'], channel_order='rgb')
                     img_name = osp.basename(single_input['img'])
                 elif isinstance(single_input['img'], np.ndarray):
-                    img = single_input['img'].copy()
+                    img = single_input['img'].copy()[:, :, ::-1]  # To RGB
                     img_name = f'{img_num}.jpg'
             elif 'img_shape' in single_input:
                 img = np.zeros(single_input['img_shape'], dtype=np.uint8)
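
The Resize removal in the first hunk of this file edits the pipeline config in place. A standalone sketch of the same pattern, with a hypothetical get_transform_idx standing in for self._get_transform_idx:

# Hypothetical stand-in for KIEInferencer._get_transform_idx: index of the
# first transform whose 'type' matches the given name, or -1 if absent.
def get_transform_idx(pipeline_cfg, name):
    for i, transform in enumerate(pipeline_cfg):
        if transform['type'] == name:
            return i
    return -1

pipeline_cfg = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(1024, 512), keep_ratio=True),
    dict(type='PackKIEInputs', meta_keys=('img_path', )),
]
idx = get_transform_idx(pipeline_cfg, 'Resize')
if idx != -1:
    pipeline_cfg.pop(idx)  # bboxes now stay in the original image scale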

View File

@@ -46,7 +46,7 @@ def parse_args():
         help='Pretrained key information extraction algorithm. It\'s the path'
         'to the config file or the model name defined in metafile.')
     parser.add_argument(
-        '--kie-ckpt',
+        '--kie-weights',
         type=str,
         default=None,
         help='Path to the custom checkpoint file of the selected kie model. '
@@ -77,7 +77,8 @@ def parse_args():
     call_args = vars(parser.parse_args())
     init_kws = [
-        'det', 'det_weights', 'rec', 'rec_weights', 'kie', 'kie_ckpt', 'device'
+        'det', 'det_weights', 'rec', 'rec_weights', 'kie', 'kie_weights',
+        'device'
     ]
     init_args = {}
     for init_kw in init_kws:
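
The flag rename has to land in both hunks because argparse derives attribute names from the flag: '--kie-weights' becomes the key 'kie_weights' in vars(parser.parse_args()), so a stale 'kie_ckpt' entry in init_kws would never match anything. A minimal demonstration (the checkpoint name is made up):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--kie-weights', type=str, default=None)

# Dashes in the flag become underscores in the parsed namespace.
call_args = vars(parser.parse_args(['--kie-weights', 'sdmgr.pth']))
assert call_args['kie_weights'] == 'sdmgr.pth'
assert 'kie_ckpt' not in call_args  # the old key no longer exists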

View File

@@ -54,8 +54,9 @@ class TestTextDetinferencer(TestCase):
         res_ndarray = self.inferencer(img, return_vis=True)
         self.assert_predictions_equal(res_path['predictions'],
                                       res_ndarray['predictions'])
-        self.assertIn('visualization', res_path)
-        self.assertIn('visualization', res_ndarray)
+        self.assertTrue(
+            np.allclose(res_path['visualization'],
+                        res_ndarray['visualization']))
         # multiple images
         img_paths = [
@@ -68,8 +69,10 @@ class TestTextDetinferencer(TestCase):
         res_ndarray = self.inferencer(imgs, return_vis=True)
         self.assert_predictions_equal(res_path['predictions'],
                                       res_ndarray['predictions'])
-        self.assertIn('visualization', res_path)
-        self.assertIn('visualization', res_ndarray)
+        for i in range(len(img_paths)):
+            self.assertTrue(
+                np.allclose(res_path['visualization'][i],
+                            res_ndarray['visualization'][i]))
         # img dir, test different batch sizes
         img_dir = 'tests/data/det_toy_dataset/imgs/test/'
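
The replacement assertions in these tests are strictly stronger: assertIn only checked that a 'visualization' key existed, while np.allclose compares the rendered images element-wise, so path and ndarray inputs must now produce identical visualizations. A tiny illustration of the difference:

import numpy as np

a = {'visualization': np.ones((2, 2, 3))}
b = {'visualization': np.zeros((2, 2, 3))}

# Both results pass a key-existence check...
assert 'visualization' in a and 'visualization' in b
# ...but only the element-wise comparison catches differing renderings.
assert not np.allclose(a['visualization'], b['visualization'])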

View File

@@ -52,8 +52,9 @@ class TestTextRecinferencer(TestCase):
         res_ndarray = self.inferencer(img, return_vis=True)
         self.assert_predictions_equal(res_path['predictions'],
                                       res_ndarray['predictions'])
-        self.assertIn('visualization', res_path)
-        self.assertIn('visualization', res_ndarray)
+        self.assertTrue(
+            np.allclose(res_path['visualization'],
+                        res_ndarray['visualization']))
         # multiple images
         img_paths = [
@@ -66,8 +67,10 @@ class TestTextRecinferencer(TestCase):
         res_ndarray = self.inferencer(imgs, return_vis=True)
         self.assert_predictions_equal(res_path['predictions'],
                                       res_ndarray['predictions'])
-        self.assertIn('visualization', res_path)
-        self.assertIn('visualization', res_ndarray)
+        for i in range(len(img_paths)):
+            self.assertTrue(
+                np.allclose(res_path['visualization'][i],
+                            res_ndarray['visualization'][i]))
         # img dir, test different batch sizes
         img_dir = 'tests/data/rec_toy_dataset/imgs'