fix predictor

pull/207/head
jiangnana.jnn 2022-09-27 15:39:49 +08:00
parent 2856af7f42
commit 0c992540ed
1 changed file with 8 additions and 2 deletions


@@ -108,7 +108,7 @@ class PredictorV2(object):
         model_path (str): Path of model path.
         config_file (Optinal[str]): config file path for model and processor to init. Defaults to None.
         batch_size (int): batch size for forward.
-        device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
+        device (str | torch.device): Support str('cuda' or 'cpu') or torch.device, if is None, detect device automatically.
         save_results (bool): Whether to save predict results.
         save_path (str): File path for saving results, only valid when `save_results` is True.
         pipelines (list[dict]): Data pipeline configs.
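
The widened `device` argument (str, torch.device, or None for automatic detection) would typically be normalized before use. A minimal sketch of such normalization, assuming a standalone helper named resolve_device that is not part of this repository:

    import torch

    # Hypothetical helper (not from this repository) illustrating how the
    # `device` argument described in the docstring could be normalized:
    # accept a str ('cuda' or 'cpu'), a torch.device, or None.
    def resolve_device(device=None):
        if device is None:
            # detect device automatically
            return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if isinstance(device, torch.device):
            return device
        return torch.device(device)

    resolve_device('cuda')               # -> device(type='cuda')
    resolve_device(torch.device('cpu'))  # -> device(type='cpu')
    resolve_device()                     # -> auto-detected device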
@@ -257,9 +257,14 @@ class PredictorV2(object):
     def postprocess(self, inputs, *args, **kwargs):
         """Process model batch outputs.
+        The "inputs" should be dict format as follows:
+        {
+            "key1": torch.Tensor or list, the first dimension should be batch_size,
+            "key2": torch.Tensor or list, the first dimension should be batch_size,
+            ...
+        }
         """
         outputs = []
-        out_i = {}
         batch_size = 1
         # get current batch size
         for k, batch_v in inputs.items():
@@ -268,6 +273,7 @@ class PredictorV2(object):
                 break
         for i in range(batch_size):
+            out_i = {}
             for k, batch_v in inputs.items():
                 if batch_v is not None:
                     out_i[k] = batch_v[i]
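
The fix moves the creation of `out_i` inside the batch loop: with a single dict created before the loop, every iteration would overwrite the same keys and the collected results would all alias one shared dict. A minimal, self-contained sketch of the corrected splitting logic (the split_batch name and the trailing outputs.append(out_i) are assumptions for illustration, since the diff is truncated):

    import torch

    def split_batch(inputs):
        """Split a batched dict into one dict per sample (sketch of the fixed loop)."""
        outputs = []
        batch_size = 1
        # get current batch size from the first non-None value
        for k, batch_v in inputs.items():
            if batch_v is not None:
                batch_size = len(batch_v)
                break
        for i in range(batch_size):
            out_i = {}  # fresh dict per sample -- the bug fix in this commit
            for k, batch_v in inputs.items():
                if batch_v is not None:
                    out_i[k] = batch_v[i]
            outputs.append(out_i)  # assumed collection step, not shown in the diff
        return outputs

    # Example with the input format described in the new docstring:
    batch = {'scores': torch.rand(2, 10), 'labels': [[1, 2], [3, 4]]}
    per_sample = split_batch(batch)  # -> list of 2 dicts, one per sample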