diff --git a/docs/en/user_guides/inference.md b/docs/en/user_guides/inference.md
index f96737a3..0687a327 100644
--- a/docs/en/user_guides/inference.md
+++ b/docs/en/user_guides/inference.md
@@ -460,6 +460,9 @@ Here are extensive lists of parameters that you can use.
 | `inputs` | str/list/tuple/np.array | **required** | It can be a path to an image/a folder, an np array or a list/tuple (with img paths or np arrays) |
 | `return_datasamples` | bool | False | Whether to return results as DataSamples. If False, the results will be packed into a dict. |
 | `batch_size` | int | 1 | Inference batch size. |
+| `det_batch_size` | int, optional | None | Inference batch size for the text detection model. Overwrites `batch_size` if it is not None. |
+| `rec_batch_size` | int, optional | None | Inference batch size for the text recognition model. Overwrites `batch_size` if it is not None. |
+| `kie_batch_size` | int, optional | None | Inference batch size for the KIE model. Overwrites `batch_size` if it is not None. |
 | `return_vis` | bool | False | Whether to return the visualization result. |
 | `print_result` | bool | False | Whether to print the inference result to the console. |
 | `show` | bool | False | Whether to display the visualization results in a popup window. |
diff --git a/docs/zh_cn/user_guides/inference.md b/docs/zh_cn/user_guides/inference.md
index 9da8b35a..7b69dcee 100644
--- a/docs/zh_cn/user_guides/inference.md
+++ b/docs/zh_cn/user_guides/inference.md
@@ -457,6 +457,9 @@ outputs
 | `inputs` | str/list/tuple/np.array | **必需** | 它可以是一个图片/文件夹的路径,一个 numpy 数组,或者是一个包含图片路径或 numpy 数组的列表/元组 |
 | `return_datasamples` | bool | False | 是否将结果作为 DataSample 返回。如果为 False,结果将被打包成一个字典。 |
 | `batch_size` | int | 1 | 推理的批大小。 |
+| `det_batch_size` | int, 可选 | None | 推理的批大小 (文本检测模型)。如果不为 None,则覆盖 batch_size。 |
+| `rec_batch_size` | int, 可选 | None | 推理的批大小 (文本识别模型)。如果不为 None,则覆盖 batch_size。 |
+| `kie_batch_size` | int, 可选 | None | 推理的批大小 (关键信息提取模型)。如果不为 None,则覆盖 batch_size。 |
 | `return_vis` | bool | False | 是否返回可视化结果。 |
 | `print_result` | bool | False | 是否将推理结果打印到控制台。 |
 | `show` | bool | False | 是否在弹出窗口中显示可视化结果。 |
diff --git a/mmocr/apis/inferencers/mmocr_inferencer.py b/mmocr/apis/inferencers/mmocr_inferencer.py
index 66b2084d..be7f7423 100644
--- a/mmocr/apis/inferencers/mmocr_inferencer.py
+++ b/mmocr/apis/inferencers/mmocr_inferencer.py
@@ -105,13 +105,27 @@ class MMOCRInferencer(BaseMMOCRInferencer):
                 'supported yet.')
         return new_inputs
 
-    def forward(self, inputs: InputsType, batch_size: int,
+    def forward(self,
+                inputs: InputsType,
+                batch_size: int = 1,
+                det_batch_size: Optional[int] = None,
+                rec_batch_size: Optional[int] = None,
+                kie_batch_size: Optional[int] = None,
                 **forward_kwargs) -> PredType:
         """Forward the inputs to the model.
 
         Args:
             inputs (InputsType): The inputs to be forwarded.
             batch_size (int): Batch size. Defaults to 1.
+            det_batch_size (Optional[int]): Batch size for the text detection
+                model. Overwrites batch_size if it is not None.
+                Defaults to None.
+            rec_batch_size (Optional[int]): Batch size for the text
+                recognition model. Overwrites batch_size if it is not None.
+                Defaults to None.
+            kie_batch_size (Optional[int]): Batch size for the KIE model.
+                Overwrites batch_size if it is not None.
+                Defaults to None.
 
         Returns:
             Dict: The prediction results. Possibly with keys "det", "rec", and
@@ -119,20 +133,26 @@ class MMOCRInferencer(BaseMMOCRInferencer):
         """
         result = {}
         forward_kwargs['progress_bar'] = False
+        if det_batch_size is None:
+            det_batch_size = batch_size
+        if rec_batch_size is None:
+            rec_batch_size = batch_size
+        if kie_batch_size is None:
+            kie_batch_size = batch_size
         if self.mode == 'rec':
             # The extra list wrapper here is for the ease of postprocessing
             self.rec_inputs = inputs
             predictions = self.textrec_inferencer(
                 self.rec_inputs,
                 return_datasamples=True,
-                batch_size=batch_size,
+                batch_size=rec_batch_size,
                 **forward_kwargs)['predictions']
             result['rec'] = [[p] for p in predictions]
         elif self.mode.startswith('det'):  # 'det'/'det_rec'/'det_rec_kie'
             result['det'] = self.textdet_inferencer(
                 inputs,
                 return_datasamples=True,
-                batch_size=batch_size,
+                batch_size=det_batch_size,
                 **forward_kwargs)['predictions']
             if self.mode.startswith('det_rec'):  # 'det_rec'/'det_rec_kie'
                 result['rec'] = []
@@ -149,7 +169,7 @@ class MMOCRInferencer(BaseMMOCRInferencer):
                         self.textrec_inferencer(
                             self.rec_inputs,
                             return_datasamples=True,
-                            batch_size=batch_size,
+                            batch_size=rec_batch_size,
                             **forward_kwargs)['predictions'])
             if self.mode == 'det_rec_kie':
                 self.kie_inputs = []
@@ -172,7 +192,7 @@ class MMOCRInferencer(BaseMMOCRInferencer):
                 result['kie'] = self.kie_inferencer(
                     self.kie_inputs,
                     return_datasamples=True,
-                    batch_size=batch_size,
+                    batch_size=kie_batch_size,
                     **forward_kwargs)['predictions']
         return result
 
@@ -219,6 +239,9 @@ class MMOCRInferencer(BaseMMOCRInferencer):
         self,
         inputs: InputsType,
         batch_size: int = 1,
+        det_batch_size: Optional[int] = None,
+        rec_batch_size: Optional[int] = None,
+        kie_batch_size: Optional[int] = None,
         out_dir: str = 'results/',
         return_vis: bool = False,
         save_vis: bool = False,
@@ -231,6 +254,15 @@
             inputs (InputsType): Inputs for the inferencer. It can be a path
                 to image / image directory, or an array, or a list of these.
             batch_size (int): Batch size. Defaults to 1.
+            det_batch_size (Optional[int]): Batch size for the text detection
+                model. Overwrites batch_size if it is not None.
+                Defaults to None.
+            rec_batch_size (Optional[int]): Batch size for the text
+                recognition model. Overwrites batch_size if it is not None.
+                Defaults to None.
+            kie_batch_size (Optional[int]): Batch size for the KIE model.
+                Overwrites batch_size if it is not None.
+                Defaults to None.
             out_dir (str): Output directory of results. Defaults to 'results/'.
             return_vis (bool): Whether to return the visualization result.
                 Defaults to False.
@@ -269,12 +301,23 @@ class MMOCRInferencer(BaseMMOCRInferencer):
             **kwargs)
 
         ori_inputs = self._inputs_to_list(inputs)
+        if det_batch_size is None:
+            det_batch_size = batch_size
+        if rec_batch_size is None:
+            rec_batch_size = batch_size
+        if kie_batch_size is None:
+            kie_batch_size = batch_size
         chunked_inputs = super(BaseMMOCRInferencer,
                                self)._get_chunk_data(ori_inputs, batch_size)
         results = {'predictions': [], 'visualization': []}
         for ori_input in track(chunked_inputs, description='Inference'):
-            preds = self.forward(ori_input, batch_size, **forward_kwargs)
+            preds = self.forward(
+                ori_input,
+                det_batch_size=det_batch_size,
+                rec_batch_size=rec_batch_size,
+                kie_batch_size=kie_batch_size,
+                **forward_kwargs)
             visualization = self.visualize(
                 ori_input, preds, img_out_dir=img_out_dir, **visualize_kwargs)
             batch_res = self.postprocess(
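For reference, a minimal usage sketch of the new per-model batch sizes. The model names and the demo image path below are illustrative; any detector/recognizer registered in MMOCR works the same way:

```python
from mmocr.apis import MMOCRInferencer

# Build an end-to-end detection + recognition pipeline.
# 'DBNet' and 'SAR' are illustrative choices of registered MMOCR models.
ocr = MMOCRInferencer(det='DBNet', rec='SAR')

# The per-model batch sizes added by this patch: each one, when not
# None, overrides `batch_size` for its sub-model; any left as None
# falls back to `batch_size`.
results = ocr(
    'demo/demo_text_ocr.jpg',  # illustrative path; use your own image(s)
    batch_size=1,
    det_batch_size=2,   # detection batches whole images
    rec_batch_size=16,  # recognition batches small text crops, so go larger
)
print(results['predictions'])
```

Recognition usually benefits most from its own, larger batch size, since each detected text instance becomes a separate crop and many crops can be batched per input image. Note that `_get_chunk_data` still chunks the input images by the plain `batch_size`.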