mirror of https://github.com/alibaba/EasyCV.git
fix predictor
parent 2856af7f42
commit 0c992540ed
@ -108,7 +108,7 @@ class PredictorV2(object):
            model_path (str): Path of the model.
            config_file (Optional[str]): config file path for model and processor to init. Defaults to None.
            batch_size (int): batch size for forward.
-           device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
+           device (str | torch.device): Support str('cuda' or 'cpu') or torch.device, if is None, detect device automatically.
            save_results (bool): Whether to save predict results.
            save_path (str): File path for saving results, only valid when `save_results` is True.
            pipelines (list[dict]): Data pipeline configs.
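The docstring change above reflects that `device` may now be either a string or a `torch.device`. A minimal sketch of how a predictor could normalize such an argument is shown below; the helper name `resolve_device` is hypothetical and not part of EasyCV's API.

import torch

# Hypothetical helper, not EasyCV code: normalize a `device` argument that may
# be None, a str like 'cuda'/'cpu', or an existing torch.device.
def resolve_device(device=None):
    if device is None:
        # Detect the device automatically, as the docstring describes.
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if isinstance(device, torch.device):
        # Already a torch.device, pass it through unchanged.
        return device
    # A str such as 'cuda' or 'cpu'.
    return torch.device(device)

With this, a caller could pass either resolve_device('cuda') or resolve_device(torch.device('cuda:0')) and get a torch.device back in both cases.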
@ -257,9 +257,14 @@ class PredictorV2(object):
    def postprocess(self, inputs, *args, **kwargs):
        """Process model batch outputs.

        The "inputs" should be dict format as follows:
        {
            "key1": torch.Tensor or list, the first dimension should be batch_size,
            "key2": torch.Tensor or list, the first dimension should be batch_size,
            ...
        }
        """
        outputs = []
        out_i = {}
        batch_size = 1
        # get current batch size
        for k, batch_v in inputs.items():
@ -268,6 +273,7 @@ class PredictorV2(object):
                break

        for i in range(batch_size):
            out_i = {}
            for k, batch_v in inputs.items():
                if batch_v is not None:
                    out_i[k] = batch_v[i]
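The loop above splits the batched output dict back into one dict per sample. A self-contained sketch of the same idea, using made-up keys and tensors rather than real EasyCV model outputs:

import torch

# Example batched outputs; the first dimension of each value is batch_size.
# Keys and values here are illustrative, not actual EasyCV output names.
batched = {
    'scores': torch.rand(4, 10),
    'labels': [[1, 2], [3], [0, 5], [7]],
    'masks': None,  # None values are skipped, as in the diff above
}

# Get the current batch size from the first non-None value.
batch_size = 1
for k, batch_v in batched.items():
    if batch_v is not None:
        batch_size = len(batch_v)
        break

# Re-group per sample: each result dict holds the i-th entry of every key.
outputs = []
for i in range(batch_size):
    out_i = {k: v[i] for k, v in batched.items() if v is not None}
    outputs.append(out_i)

print(len(outputs))  # 4
print(outputs[0]['scores'].shape)  # torch.Size([10])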