Mirror of https://github.com/open-mmlab/mmocr.git, synced 2025-06-03 21:54:47 +08:00
[Enhancement] decouple batch_size to det_batch_size, rec_batch_size and kie_batch_size in MMOCRInferencer (#1801)
* decouple batch_size to det_batch_size, rec_batch_size, kie_batch_size and chunk_size in MMOCRInferencer
* remove chunk_size parameter
* add Optional keyword in function definitions and doc strings
* add det_batch_size, rec_batch_size, kie_batch_size in user_guides
* minor formatting
This commit is contained in:
parent 22f40b79ed
commit c886936117
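
For orientation, here is a minimal usage sketch of the decoupled batch sizes this commit introduces. The model aliases and the image path below are illustrative assumptions, not part of the change itself:

```python
# Hedged sketch: per-stage batch sizes in MMOCRInferencer after this commit.
# The 'DBNet'/'SAR'/'SDMGR' aliases and the image path are assumed examples.
from mmocr.apis import MMOCRInferencer

# Build an end-to-end det -> rec -> kie pipeline.
ocr = MMOCRInferencer(det='DBNet', rec='SAR', kie='SDMGR')

# Each per-stage value overrides the shared `batch_size` when not None,
# e.g. to batch the cheap detector more aggressively than the recognizer.
results = ocr('demo/demo_kie.jpeg',
              det_batch_size=8,
              rec_batch_size=4,
              kie_batch_size=1)
```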
```diff
@@ -460,6 +460,9 @@ Here are extensive lists of parameters that you can use.
 | `inputs` | str/list/tuple/np.array | **required** | It can be a path to an image/a folder, an np array or a list/tuple (with img paths or np arrays) |
 | `return_datasamples` | bool | False | Whether to return results as DataSamples. If False, the results will be packed into a dict. |
 | `batch_size` | int | 1 | Inference batch size. |
+| `det_batch_size` | int, optional | None | Inference batch size for text detection model. Overwrite batch_size if it is not None. |
+| `rec_batch_size` | int, optional | None | Inference batch size for text recognition model. Overwrite batch_size if it is not None. |
+| `kie_batch_size` | int, optional | None | Inference batch size for KIE model. Overwrite batch_size if it is not None. |
 | `return_vis` | bool | False | Whether to return the visualization result. |
 | `print_result` | bool | False | Whether to print the inference result to the console. |
 | `show` | bool | False | Whether to display the visualization results in a popup window. |
```
```diff
@@ -457,6 +457,9 @@ outputs
 | `inputs` | str/list/tuple/np.array | **required** | It can be a path to an image/folder, a numpy array, or a list/tuple containing image paths or numpy arrays |
 | `return_datasamples` | bool | False | Whether to return the results as DataSamples. If False, the results will be packed into a dict. |
 | `batch_size` | int | 1 | Inference batch size. |
+| `det_batch_size` | int, optional | None | Inference batch size (text detection model). Overrides batch_size if it is not None. |
+| `rec_batch_size` | int, optional | None | Inference batch size (text recognition model). Overrides batch_size if it is not None. |
+| `kie_batch_size` | int, optional | None | Inference batch size (key information extraction model). Overrides batch_size if it is not None. |
 | `return_vis` | bool | False | Whether to return the visualization result. |
 | `print_result` | bool | False | Whether to print the inference result to the console. |
 | `show` | bool | False | Whether to display the visualization results in a popup window. |
```
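
To make the override rule documented in the tables above concrete, a short sketch (the model aliases and input directory are assumptions):

```python
from mmocr.apis import MMOCRInferencer

ocr = MMOCRInferencer(det='DBNet', rec='CRNN')

# The shared `batch_size` still applies to every stage by default ...
ocr('imgs/', batch_size=4)                     # det uses 4, rec uses 4

# ... while a per-stage value overrides it for that stage only.
ocr('imgs/', batch_size=4, det_batch_size=16)  # det uses 16, rec uses 4
```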
```diff
@@ -105,13 +105,27 @@ class MMOCRInferencer(BaseMMOCRInferencer):
                     'supported yet.')
         return new_inputs
 
-    def forward(self, inputs: InputsType, batch_size: int,
+    def forward(self,
+                inputs: InputsType,
+                batch_size: int = 1,
+                det_batch_size: Optional[int] = None,
+                rec_batch_size: Optional[int] = None,
+                kie_batch_size: Optional[int] = None,
                 **forward_kwargs) -> PredType:
         """Forward the inputs to the model.
 
         Args:
             inputs (InputsType): The inputs to be forwarded.
             batch_size (int): Batch size. Defaults to 1.
+            det_batch_size (Optional[int]): Batch size for text detection
+                model. Overwrite batch_size if it is not None.
+                Defaults to None.
+            rec_batch_size (Optional[int]): Batch size for text recognition
+                model. Overwrite batch_size if it is not None.
+                Defaults to None.
+            kie_batch_size (Optional[int]): Batch size for KIE model.
+                Overwrite batch_size if it is not None.
+                Defaults to None.
 
         Returns:
             Dict: The prediction results. Possibly with keys "det", "rec", and
```
```diff
@@ -119,20 +133,26 @@ class MMOCRInferencer(BaseMMOCRInferencer):
         """
         result = {}
         forward_kwargs['progress_bar'] = False
+        if det_batch_size is None:
+            det_batch_size = batch_size
+        if rec_batch_size is None:
+            rec_batch_size = batch_size
+        if kie_batch_size is None:
+            kie_batch_size = batch_size
         if self.mode == 'rec':
             # The extra list wrapper here is for the ease of postprocessing
             self.rec_inputs = inputs
             predictions = self.textrec_inferencer(
                 self.rec_inputs,
                 return_datasamples=True,
-                batch_size=batch_size,
+                batch_size=rec_batch_size,
                 **forward_kwargs)['predictions']
             result['rec'] = [[p] for p in predictions]
         elif self.mode.startswith('det'):  # 'det'/'det_rec'/'det_rec_kie'
             result['det'] = self.textdet_inferencer(
                 inputs,
                 return_datasamples=True,
-                batch_size=batch_size,
+                batch_size=det_batch_size,
                 **forward_kwargs)['predictions']
             if self.mode.startswith('det_rec'):  # 'det_rec'/'det_rec_kie'
                 result['rec'] = []
```
```diff
@@ -149,7 +169,7 @@ class MMOCRInferencer(BaseMMOCRInferencer):
                     self.textrec_inferencer(
                         self.rec_inputs,
                         return_datasamples=True,
-                        batch_size=batch_size,
+                        batch_size=rec_batch_size,
                         **forward_kwargs)['predictions'])
                 if self.mode == 'det_rec_kie':
                     self.kie_inputs = []
```
```diff
@@ -172,7 +192,7 @@ class MMOCRInferencer(BaseMMOCRInferencer):
                 result['kie'] = self.kie_inferencer(
                     self.kie_inputs,
                     return_datasamples=True,
-                    batch_size=batch_size,
+                    batch_size=kie_batch_size,
                     **forward_kwargs)['predictions']
         return result
 
```
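
The fallback logic added to `forward` above reduces to a simple rule: each stage-specific batch size defaults to the shared `batch_size` when left as None. A standalone sketch of that rule (the helper name is hypothetical, not MMOCR API):

```python
from typing import Optional, Tuple

def resolve_batch_sizes(batch_size: int = 1,
                        det_batch_size: Optional[int] = None,
                        rec_batch_size: Optional[int] = None,
                        kie_batch_size: Optional[int] = None
                        ) -> Tuple[int, int, int]:
    """Return the effective (det, rec, kie) batch sizes."""
    return (det_batch_size if det_batch_size is not None else batch_size,
            rec_batch_size if rec_batch_size is not None else batch_size,
            kie_batch_size if kie_batch_size is not None else batch_size)

assert resolve_batch_sizes(4) == (4, 4, 4)                      # all fall back
assert resolve_batch_sizes(4, det_batch_size=16) == (16, 4, 4)  # det overrides
```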
```diff
@@ -219,6 +239,9 @@ class MMOCRInferencer(BaseMMOCRInferencer):
             self,
             inputs: InputsType,
             batch_size: int = 1,
+            det_batch_size: Optional[int] = None,
+            rec_batch_size: Optional[int] = None,
+            kie_batch_size: Optional[int] = None,
             out_dir: str = 'results/',
             return_vis: bool = False,
             save_vis: bool = False,
```
```diff
@@ -231,6 +254,15 @@ class MMOCRInferencer(BaseMMOCRInferencer):
             inputs (InputsType): Inputs for the inferencer. It can be a path
                 to image / image directory, or an array, or a list of these.
             batch_size (int): Batch size. Defaults to 1.
+            det_batch_size (Optional[int]): Batch size for text detection
+                model. Overwrite batch_size if it is not None.
+                Defaults to None.
+            rec_batch_size (Optional[int]): Batch size for text recognition
+                model. Overwrite batch_size if it is not None.
+                Defaults to None.
+            kie_batch_size (Optional[int]): Batch size for KIE model.
+                Overwrite batch_size if it is not None.
+                Defaults to None.
             out_dir (str): Output directory of results. Defaults to 'results/'.
             return_vis (bool): Whether to return the visualization result.
                 Defaults to False.
```
```diff
@@ -269,12 +301,23 @@ class MMOCRInferencer(BaseMMOCRInferencer):
             **kwargs)
 
         ori_inputs = self._inputs_to_list(inputs)
+        if det_batch_size is None:
+            det_batch_size = batch_size
+        if rec_batch_size is None:
+            rec_batch_size = batch_size
+        if kie_batch_size is None:
+            kie_batch_size = batch_size
 
         chunked_inputs = super(BaseMMOCRInferencer,
                                self)._get_chunk_data(ori_inputs, batch_size)
         results = {'predictions': [], 'visualization': []}
         for ori_input in track(chunked_inputs, description='Inference'):
-            preds = self.forward(ori_input, batch_size, **forward_kwargs)
+            preds = self.forward(
+                ori_input,
+                det_batch_size=det_batch_size,
+                rec_batch_size=rec_batch_size,
+                kie_batch_size=kie_batch_size,
+                **forward_kwargs)
             visualization = self.visualize(
                 ori_input, preds, img_out_dir=img_out_dir, **visualize_kwargs)
             batch_res = self.postprocess(
```