diff --git a/demo/batch_image_demo.py b/demo/batch_image_demo.py
index 0a6c9bc7..245405fc 100644
--- a/demo/batch_image_demo.py
+++ b/demo/batch_image_demo.py
@@ -18,9 +18,7 @@ def main():
         '--images',
         nargs='+',
         help='Image files to be predicted with batch mode, '
-        'separated by space, like "image_1.jpg image2.jpg". '
-        'If algorithm use augmentation test, only one '
-        'image file can be given.')
+        'separated by space, like "image_1.jpg image2.jpg".')
     parser.add_argument(
         '--device', default='cuda:0', help='Device used for inference.')
     parser.add_argument(
@@ -36,7 +34,7 @@ def main():
             0].pipeline
 
     # test multiple images
-    results = model_inference(model, args.images)
+    results = model_inference(model, args.images, batch_mode=True)
     print(f'results: {results}')
 
     save_path = Path(args.save_path)
diff --git a/demo/ocr_image_demo.py b/demo/ocr_image_demo.py
index 494869d2..27b83d39 100644
--- a/demo/ocr_image_demo.py
+++ b/demo/ocr_image_demo.py
@@ -17,6 +17,7 @@ def det_and_recog_inference(args, det_model, recog_model):
     det_result = model_inference(det_model, image)
     bboxes = det_result['boundary_result']
 
+    box_imgs = []
     for bbox in bboxes:
         box_res = {}
         box_res['box'] = [round(x) for x in bbox[:-1]]
@@ -29,17 +30,37 @@ def det_and_recog_inference(args, det_model, recog_model):
             max_y = max(bbox[1:-1:2])
             box = [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]
         box_img = crop_img(image, box)
+        if args.batch_mode:
+            box_imgs.append(box_img)
+        else:
+            recog_result = model_inference(recog_model, box_img)
+            text = recog_result['text']
+            text_score = recog_result['score']
+            if isinstance(text_score, list):
+                text_score = sum(text_score) / max(1, len(text))
+            box_res['text'] = text
+            box_res['text_score'] = text_score
 
-        recog_result = model_inference(recog_model, box_img)
-
-        text = recog_result['text']
-        text_score = recog_result['score']
-        if isinstance(text_score, list):
-            text_score = sum(text_score) / max(1, len(text))
-        box_res['text'] = text
-        box_res['text_score'] = text_score
         end2end_res['result'].append(box_res)
 
+    if args.batch_mode:
+        batch_size = args.batch_size
+        for chunk_idx in range(len(box_imgs) // batch_size + 1):
+            start_idx = chunk_idx * batch_size
+            end_idx = (chunk_idx + 1) * batch_size
+            chunk_box_imgs = box_imgs[start_idx:end_idx]
+            if len(chunk_box_imgs) == 0:
+                continue
+            recog_results = model_inference(
+                recog_model, chunk_box_imgs, batch_mode=True)
+            for i, recog_result in enumerate(recog_results):
+                text = recog_result['text']
+                text_score = recog_result['score']
+                if isinstance(text_score, list):
+                    text_score = sum(text_score) / max(1, len(text))
+                end2end_res['result'][start_idx + i]['text'] = text
+                end2end_res['result'][start_idx + i]['text_score'] = text_score
+
     return end2end_res
 
 
@@ -74,6 +95,16 @@ def main():
         'mmocr/textrecog/sar/'
         'sar_r31_parallel_decoder_academic-dba3a4a3.pth',
         help='Text recognition checkpint file (local or url).')
+    parser.add_argument(
+        '--batch-mode',
+        action='store_true',
+        help='Whether use batch mode for text recognition.')
+    parser.add_argument(
+        '--batch-size',
+        type=int,
+        default=4,
+        help='Batch size for text recognition inference '
+        'if batch_mode is True above.')
     parser.add_argument(
         '--device', default='cuda:0', help='Device used for inference.')
     parser.add_argument(
diff --git a/mmocr/apis/inference.py b/mmocr/apis/inference.py
index be955156..670ab1cf 100644
--- a/mmocr/apis/inference.py
+++ b/mmocr/apis/inference.py
@@ -7,16 +7,33 @@ from mmdet.datasets import replace_ImageToTensor
 from mmdet.datasets.pipelines import Compose
 
 
-def model_inference(model, imgs):
+def disable_text_recog_aug_test(cfg):
+    """Remove aug_test from test pipeline of text recognition.
+    Args:
+        cfg (mmcv.Config): Input config.
+
+    Returns:
+        cfg (mmcv.Config): Output config removing
+            `MultiRotateAugOCR` in test pipeline.
+    """
+    if cfg.data.test.pipeline[1].type == 'MultiRotateAugOCR':
+        cfg.data.test.pipeline = [
+            cfg.data.test.pipeline[0], *cfg.data.test.pipeline[1].transforms
+        ]
+
+    return cfg
+
+
+def model_inference(model, imgs, batch_mode=False):
     """Inference image(s) with the detector.
 
     Args:
         model (nn.Module): The loaded detector.
         imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
             Either image files or loaded images.
-
+        batch_mode (bool): If True, use batch mode for inference.
     Returns:
-        result (dict): Detection results.
+        result (dict): Predicted results.
     """
 
     if isinstance(imgs, (list, tuple)):
@@ -31,7 +48,18 @@ def model_inference(model, imgs):
         raise AssertionError('imgs must be strings or numpy arrays')
 
     is_ndarray = isinstance(imgs[0], np.ndarray)
+
     cfg = model.cfg
+
+    if batch_mode:
+        if cfg.data.test.pipeline[1].type == 'ResizeOCR':
+            if cfg.data.test.pipeline[1].max_width is None:
+                raise Exception('Free resize do not support batch mode '
+                                'since the image width is not fixed, '
+                                'for resize keeping aspect ratio and '
+                                'max_width is not give.')
+        cfg = disable_text_recog_aug_test(cfg)
+
     device = next(model.parameters()).device  # model device
 
     if is_ndarray:
@@ -54,8 +82,18 @@ def model_inference(model, imgs):
 
         # build the data pipeline
         data = test_pipeline(data)
+        # get tensor from list to stack for batch mode (text detection)
+        if batch_mode:
+            if cfg.data.test.pipeline[1].type == 'MultiScaleFlipAug':
+                for key, value in data.items():
+                    data[key] = value[0]
         datas.append(data)
 
+    if isinstance(datas[0]['img'], list) and len(datas) > 1:
+        raise Exception('aug test does not support '
+                        f'inference with batch size '
+                        f'{len(datas)}')
+
     data = collate(datas, samples_per_gpu=len(imgs))
 
     # process img_metas
@@ -67,7 +105,9 @@ def model_inference(model, imgs):
         data['img_metas'] = data['img_metas'].data
 
     if isinstance(data['img'], list):
-        data['img'] = [img for img in data['img'][0].data]
+        data['img'] = [img.data for img in data['img']]
+        if isinstance(data['img'][0], list):
+            data['img'] = [img[0] for img in data['img']]
     else:
         data['img'] = data['img'].data
 
diff --git a/mmocr/models/textrecog/recognizer/base.py b/mmocr/models/textrecog/recognizer/base.py
index 8e39feba..3870210b 100644
--- a/mmocr/models/textrecog/recognizer/base.py
+++ b/mmocr/models/textrecog/recognizer/base.py
@@ -75,10 +75,11 @@ class BaseRecognizer(nn.Module, metaclass=ABCMeta):
                 The outer list indicates images in a batch.
         """
         if isinstance(imgs, list):
-            assert len(imgs) == len(img_metas), ('aug test does not support '
-                                                 f'inference with batch size '
-                                                 f'{len(imgs)}')
             assert len(imgs) > 0
+            assert imgs[0].size(0) == 1, ('aug test does not support '
+                                          f'inference with batch size '
+                                          f'{imgs[0].size(0)}')
+            assert len(imgs) == len(img_metas)
             return self.aug_test(imgs, img_metas, **kwargs)
 
         return self.simple_test(imgs, img_metas, **kwargs)
diff --git a/tests/test_apis/test_model_inference.py b/tests/test_apis/test_model_inference.py
index dbf7c363..d964b5ff 100644
--- a/tests/test_apis/test_model_inference.py
+++ b/tests/test_apis/test_model_inference.py
@@ -20,20 +20,9 @@ def build_model(config_file):
     return model
 
 
-def disable_aug_test(model):
-    model.cfg.data.test.pipeline = [
-        model.cfg.data.test.pipeline[0],
-        *model.cfg.data.test.pipeline[1].transforms
-    ]
-
-    return model
-
-
 @pytest.mark.parametrize('cfg_file', [
     '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
     '../configs/textrecog/crnn/crnn_academic_dataset.py',
-    '../configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py',
-    '../configs/textrecog/robust_scanner/robustscanner_r31_academic.py',
     '../configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py',
     '../configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py'
 ])
@@ -53,45 +42,42 @@ def test_model_inference(cfg_file):
     model_inference(model, img)
 
 
-@pytest.mark.parametrize('cfg_file', [
-    '../configs/textrecog/crnn/crnn_academic_dataset.py',
-    '../configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py'
-])
-def test_model_batch_inference(cfg_file):
+@pytest.mark.parametrize(
+    'cfg_file',
+    ['../configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py'])
+def test_model_batch_inference_det(cfg_file):
     tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
     config_file = os.path.join(tmp_dir, cfg_file)
     model = build_model(config_file)
 
     sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_det.jpg')
-    results = model_inference(model, [sample_img_path, sample_img_path])
+    results = model_inference(model, [sample_img_path], batch_mode=True)
 
-    assert len(results) == 2
+    assert len(results) == 1
 
     # numpy inference
     img = imread(sample_img_path)
-    results = model_inference(model, [img, img])
+    results = model_inference(model, [img], batch_mode=True)
 
-    assert len(results) == 2
+    assert len(results) == 1
 
 
 @pytest.mark.parametrize('cfg_file', [
     '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
-    '../configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py',
-    '../configs/textrecog/robust_scanner/robustscanner_r31_academic.py',
 ])
-def test_model_batch_inference_raises_assertion_error_if_unsupported(cfg_file):
+def test_model_batch_inference_raises_exception_error_aug_test_recog(cfg_file):
     tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
     config_file = os.path.join(tmp_dir, cfg_file)
     model = build_model(config_file)
 
     with pytest.raises(
-            AssertionError,
+            Exception,
             match='aug test does not support inference with batch size'):
         sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_det.jpg')
         model_inference(model, [sample_img_path, sample_img_path])
 
     with pytest.raises(
-            AssertionError,
+            Exception,
             match='aug test does not support inference with batch size'):
         img = imread(sample_img_path)
         model_inference(model, [img, img])
@@ -99,22 +85,48 @@ def test_model_batch_inference_raises_assertion_error_if_unsupported(cfg_file):
 
 @pytest.mark.parametrize('cfg_file', [
     '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
-    '../configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py',
-    '../configs/textrecog/robust_scanner/robustscanner_r31_academic.py',
 ])
 def test_model_batch_inference_recog(cfg_file):
     tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
     config_file = os.path.join(tmp_dir, cfg_file)
     model = build_model(config_file)
-    model = disable_aug_test(model)
 
-    sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_det.jpg')
-    results = model_inference(model, [sample_img_path, sample_img_path])
+    sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_recog.jpg')
+    results = model_inference(
+        model, [sample_img_path, sample_img_path], batch_mode=True)
 
     assert len(results) == 2
 
     # numpy inference
     img = imread(sample_img_path)
-    results = model_inference(model, [img, img])
+    results = model_inference(model, [img, img], batch_mode=True)
 
     assert len(results) == 2
+
+
+@pytest.mark.parametrize(
+    'cfg_file', ['../configs/textrecog/crnn/crnn_academic_dataset.py'])
+def test_model_batch_inference_raises_exception_error_free_resize_recog(
+        cfg_file):
+    tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
+    config_file = os.path.join(tmp_dir, cfg_file)
+    model = build_model(config_file)
+
+    with pytest.raises(
+            Exception,
+            match='Free resize do not support batch mode '
+            'since the image width is not fixed, '
+            'for resize keeping aspect ratio and '
+            'max_width is not give.'):
+        sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_recog.jpg')
+        model_inference(
+            model, [sample_img_path, sample_img_path], batch_mode=True)
+
+    with pytest.raises(
+            Exception,
+            match='Free resize do not support batch mode '
+            'since the image width is not fixed, '
+            'for resize keeping aspect ratio and '
+            'max_width is not give.'):
+        img = imread(sample_img_path)
+        model_inference(model, [img, img], batch_mode=True)