mirror of https://github.com/open-mmlab/mmocr.git
[Enhancement] decouple batch_size to det_batch_size, rec_batch_size and kie_batch_size in MMOCRInferencer (#1801)
* decouple batch_size to det_batch_size, rec_batch_size, kie_batch_size and chunk_size in MMOCRInferencer * remove chunk_size parameter * add Optional keyword in function definitions and doc strings * add det_batch_size, rec_batch_size, kie_batch_size in user_guides * minor formattingpull/1807/head
parent
22f40b79ed
commit
c886936117
|
@ -460,6 +460,9 @@ Here are extensive lists of parameters that you can use.
|
|||
| `inputs` | str/list/tuple/np.array | **required** | It can be a path to an image/a folder, an np array or a list/tuple (with img paths or np arrays) |
|
||||
| `return_datasamples` | bool | False | Whether to return results as DataSamples. If False, the results will be packed into a dict. |
|
||||
| `batch_size` | int | 1 | Inference batch size. |
|
||||
| `det_batch_size` | int, optional | None | Inference batch size for text detection model. Overwrite batch_size if it is not None. |
|
||||
| `rec_batch_size` | int, optional | None | Inference batch size for text recognition model. Overwrite batch_size if it is not None. |
|
||||
| `kie_batch_size` | int, optional | None | Inference batch size for KIE model. Overwrite batch_size if it is not None. |
|
||||
| `return_vis` | bool | False | Whether to return the visualization result. |
|
||||
| `print_result` | bool | False | Whether to print the inference result to the console. |
|
||||
| `show` | bool | False | Whether to display the visualization results in a popup window. |
|
||||
|
|
|
@ -457,6 +457,9 @@ outputs
|
|||
| `inputs` | str/list/tuple/np.array | **必需** | 它可以是一个图片/文件夹的路径,一个 numpy 数组,或者是一个包含图片路径或 numpy 数组的列表/元组 |
|
||||
| `return_datasamples` | bool | False | 是否将结果作为 DataSample 返回。如果为 False,结果将被打包成一个字典。 |
|
||||
| `batch_size` | int | 1 | 推理的批大小。 |
|
||||
| `det_batch_size` | int, 可选 | None | 推理的批大小 (文本检测模型)。如果不为 None,则覆盖 batch_size。 |
|
||||
| `rec_batch_size` | int, 可选 | None | 推理的批大小 (文本识别模型)。如果不为 None,则覆盖 batch_size。 |
|
||||
| `kie_batch_size` | int, 可选 | None | 推理的批大小 (关键信息提取模型)。如果不为 None,则覆盖 batch_size。 |
|
||||
| `return_vis` | bool | False | 是否返回可视化结果。 |
|
||||
| `print_result` | bool | False | 是否将推理结果打印到控制台。 |
|
||||
| `show` | bool | False | 是否在弹出窗口中显示可视化结果。 |
|
||||
|
|
|
@ -105,13 +105,27 @@ class MMOCRInferencer(BaseMMOCRInferencer):
|
|||
'supported yet.')
|
||||
return new_inputs
|
||||
|
||||
def forward(self, inputs: InputsType, batch_size: int,
|
||||
def forward(self,
|
||||
inputs: InputsType,
|
||||
batch_size: int = 1,
|
||||
det_batch_size: Optional[int] = None,
|
||||
rec_batch_size: Optional[int] = None,
|
||||
kie_batch_size: Optional[int] = None,
|
||||
**forward_kwargs) -> PredType:
|
||||
"""Forward the inputs to the model.
|
||||
|
||||
Args:
|
||||
inputs (InputsType): The inputs to be forwarded.
|
||||
batch_size (int): Batch size. Defaults to 1.
|
||||
det_batch_size (Optional[int]): Batch size for text detection
|
||||
model. Overwrite batch_size if it is not None.
|
||||
Defaults to None.
|
||||
rec_batch_size (Optional[int]): Batch size for text recognition
|
||||
model. Overwrite batch_size if it is not None.
|
||||
Defaults to None.
|
||||
kie_batch_size (Optional[int]): Batch size for KIE model.
|
||||
Overwrite batch_size if it is not None.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
Dict: The prediction results. Possibly with keys "det", "rec", and
|
||||
|
@ -119,20 +133,26 @@ class MMOCRInferencer(BaseMMOCRInferencer):
|
|||
"""
|
||||
result = {}
|
||||
forward_kwargs['progress_bar'] = False
|
||||
if det_batch_size is None:
|
||||
det_batch_size = batch_size
|
||||
if rec_batch_size is None:
|
||||
rec_batch_size = batch_size
|
||||
if kie_batch_size is None:
|
||||
kie_batch_size = batch_size
|
||||
if self.mode == 'rec':
|
||||
# The extra list wrapper here is for the ease of postprocessing
|
||||
self.rec_inputs = inputs
|
||||
predictions = self.textrec_inferencer(
|
||||
self.rec_inputs,
|
||||
return_datasamples=True,
|
||||
batch_size=batch_size,
|
||||
batch_size=rec_batch_size,
|
||||
**forward_kwargs)['predictions']
|
||||
result['rec'] = [[p] for p in predictions]
|
||||
elif self.mode.startswith('det'): # 'det'/'det_rec'/'det_rec_kie'
|
||||
result['det'] = self.textdet_inferencer(
|
||||
inputs,
|
||||
return_datasamples=True,
|
||||
batch_size=batch_size,
|
||||
batch_size=det_batch_size,
|
||||
**forward_kwargs)['predictions']
|
||||
if self.mode.startswith('det_rec'): # 'det_rec'/'det_rec_kie'
|
||||
result['rec'] = []
|
||||
|
@ -149,7 +169,7 @@ class MMOCRInferencer(BaseMMOCRInferencer):
|
|||
self.textrec_inferencer(
|
||||
self.rec_inputs,
|
||||
return_datasamples=True,
|
||||
batch_size=batch_size,
|
||||
batch_size=rec_batch_size,
|
||||
**forward_kwargs)['predictions'])
|
||||
if self.mode == 'det_rec_kie':
|
||||
self.kie_inputs = []
|
||||
|
@ -172,7 +192,7 @@ class MMOCRInferencer(BaseMMOCRInferencer):
|
|||
result['kie'] = self.kie_inferencer(
|
||||
self.kie_inputs,
|
||||
return_datasamples=True,
|
||||
batch_size=batch_size,
|
||||
batch_size=kie_batch_size,
|
||||
**forward_kwargs)['predictions']
|
||||
return result
|
||||
|
||||
|
@ -219,6 +239,9 @@ class MMOCRInferencer(BaseMMOCRInferencer):
|
|||
self,
|
||||
inputs: InputsType,
|
||||
batch_size: int = 1,
|
||||
det_batch_size: Optional[int] = None,
|
||||
rec_batch_size: Optional[int] = None,
|
||||
kie_batch_size: Optional[int] = None,
|
||||
out_dir: str = 'results/',
|
||||
return_vis: bool = False,
|
||||
save_vis: bool = False,
|
||||
|
@ -231,6 +254,15 @@ class MMOCRInferencer(BaseMMOCRInferencer):
|
|||
inputs (InputsType): Inputs for the inferencer. It can be a path
|
||||
to image / image directory, or an array, or a list of these.
|
||||
batch_size (int): Batch size. Defaults to 1.
|
||||
det_batch_size (Optional[int]): Batch size for text detection
|
||||
model. Overwrite batch_size if it is not None.
|
||||
Defaults to None.
|
||||
rec_batch_size (Optional[int]): Batch size for text recognition
|
||||
model. Overwrite batch_size if it is not None.
|
||||
Defaults to None.
|
||||
kie_batch_size (Optional[int]): Batch size for KIE model.
|
||||
Overwrite batch_size if it is not None.
|
||||
Defaults to None.
|
||||
out_dir (str): Output directory of results. Defaults to 'results/'.
|
||||
return_vis (bool): Whether to return the visualization result.
|
||||
Defaults to False.
|
||||
|
@ -269,12 +301,23 @@ class MMOCRInferencer(BaseMMOCRInferencer):
|
|||
**kwargs)
|
||||
|
||||
ori_inputs = self._inputs_to_list(inputs)
|
||||
if det_batch_size is None:
|
||||
det_batch_size = batch_size
|
||||
if rec_batch_size is None:
|
||||
rec_batch_size = batch_size
|
||||
if kie_batch_size is None:
|
||||
kie_batch_size = batch_size
|
||||
|
||||
chunked_inputs = super(BaseMMOCRInferencer,
|
||||
self)._get_chunk_data(ori_inputs, batch_size)
|
||||
results = {'predictions': [], 'visualization': []}
|
||||
for ori_input in track(chunked_inputs, description='Inference'):
|
||||
preds = self.forward(ori_input, batch_size, **forward_kwargs)
|
||||
preds = self.forward(
|
||||
ori_input,
|
||||
det_batch_size=det_batch_size,
|
||||
rec_batch_size=rec_batch_size,
|
||||
kie_batch_size=kie_batch_size,
|
||||
**forward_kwargs)
|
||||
visualization = self.visualize(
|
||||
ori_input, preds, img_out_dir=img_out_dir, **visualize_kwargs)
|
||||
batch_res = self.postprocess(
|
||||
|
|
Loading…
Reference in New Issue