mirror of https://github.com/alibaba/EasyCV.git
fix predictor
parent
2856af7f42
commit
0c992540ed
|
@ -108,7 +108,7 @@ class PredictorV2(object):
|
||||||
model_path (str): Path of model path.
|
model_path (str): Path of model path.
|
||||||
config_file (Optinal[str]): config file path for model and processor to init. Defaults to None.
|
config_file (Optinal[str]): config file path for model and processor to init. Defaults to None.
|
||||||
batch_size (int): batch size for forward.
|
batch_size (int): batch size for forward.
|
||||||
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
|
device (str | torch.device): Support str('cuda' or 'cpu') or torch.device, if is None, detect device automatically.
|
||||||
save_results (bool): Whether to save predict results.
|
save_results (bool): Whether to save predict results.
|
||||||
save_path (str): File path for saving results, only valid when `save_results` is True.
|
save_path (str): File path for saving results, only valid when `save_results` is True.
|
||||||
pipelines (list[dict]): Data pipeline configs.
|
pipelines (list[dict]): Data pipeline configs.
|
||||||
|
@ -257,9 +257,14 @@ class PredictorV2(object):
|
||||||
|
|
||||||
def postprocess(self, inputs, *args, **kwargs):
|
def postprocess(self, inputs, *args, **kwargs):
|
||||||
"""Process model batch outputs.
|
"""Process model batch outputs.
|
||||||
|
The "inputs" should be dict format as follows:
|
||||||
|
{
|
||||||
|
"key1": torch.Tensor or list, the first dimension should be batch_size,
|
||||||
|
"key2": torch.Tensor or list, the first dimension should be batch_size,
|
||||||
|
...
|
||||||
|
}
|
||||||
"""
|
"""
|
||||||
outputs = []
|
outputs = []
|
||||||
out_i = {}
|
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
# get current batch size
|
# get current batch size
|
||||||
for k, batch_v in inputs.items():
|
for k, batch_v in inputs.items():
|
||||||
|
@ -268,6 +273,7 @@ class PredictorV2(object):
|
||||||
break
|
break
|
||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
|
out_i = {}
|
||||||
for k, batch_v in inputs.items():
|
for k, batch_v in inputs.items():
|
||||||
if batch_v is not None:
|
if batch_v is not None:
|
||||||
out_i[k] = batch_v[i]
|
out_i[k] = batch_v[i]
|
||||||
|
|
Loading…
Reference in New Issue