From 0c992540ed71dfb243f0c3185003a5101f093e90 Mon Sep 17 00:00:00 2001
From: "jiangnana.jnn"
Date: Tue, 27 Sep 2022 15:39:49 +0800
Subject: [PATCH] fix predictor

---
 easycv/predictors/base.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/easycv/predictors/base.py b/easycv/predictors/base.py
index cb5e5e4a..6c64fc21 100644
--- a/easycv/predictors/base.py
+++ b/easycv/predictors/base.py
@@ -108,7 +108,7 @@ class PredictorV2(object):
         model_path (str): Path of model path.
         config_file (Optinal[str]): config file path for model and processor to init. Defaults to None.
         batch_size (int): batch size for forward.
-        device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
+        device (str | torch.device): Support str('cuda' or 'cpu') or torch.device, if is None, detect device automatically.
         save_results (bool): Whether to save predict results.
         save_path (str): File path for saving results, only valid when `save_results` is True.
         pipelines (list[dict]): Data pipeline configs.
@@ -257,9 +257,14 @@ class PredictorV2(object):

     def postprocess(self, inputs, *args, **kwargs):
         """Process model batch outputs.
+        The "inputs" should be dict format as follows:
+        {
+            "key1": torch.Tensor or list, the first dimension should be batch_size,
+            "key2": torch.Tensor or list, the first dimension should be batch_size,
+            ...
+        }
         """
         outputs = []
-        out_i = {}
         batch_size = 1
         # get current batch size
         for k, batch_v in inputs.items():
@@ -268,6 +273,7 @@ class PredictorV2(object):
                 break

         for i in range(batch_size):
+            out_i = {}
             for k, batch_v in inputs.items():
                 if batch_v is not None:
                     out_i[k] = batch_v[i]
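
For context, the core fix above moves `out_i = {}` inside the per-sample loop: before the patch a single dict was created once, so every iteration wrote into the same object instead of producing one result dict per sample. Below is a minimal standalone sketch of the corrected postprocess logic under that reading of the hunks; the `outputs.append(out_i)` call, the None handling, and the keys 'scores' and 'labels' in the usage check are assumptions not shown in the patch itself.

import torch


def postprocess(inputs):
    """Split a dict of batched outputs into a list of per-sample dicts.

    Each value of `inputs` is a torch.Tensor or list whose first
    dimension is the batch size, as the patched docstring describes.
    """
    outputs = []
    batch_size = 1
    # get current batch size from the first non-None value
    for k, batch_v in inputs.items():
        if batch_v is not None:
            batch_size = len(batch_v)
            break

    for i in range(batch_size):
        out_i = {}  # fresh dict per sample; reusing one shared dict was the bug
        for k, batch_v in inputs.items():
            if batch_v is not None:
                out_i[k] = batch_v[i]
            else:
                out_i[k] = None  # assumption: keep missing keys as None
        outputs.append(out_i)  # assumption: per-sample results are collected in a list
    return outputs


if __name__ == '__main__':
    # hypothetical batched model outputs with batch_size == 2
    batch = {'scores': torch.tensor([0.9, 0.7]), 'labels': [1, 0]}
    results = postprocess(batch)
    assert len(results) == 2 and results[0]['labels'] == 1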