From 149ee3a30d6ea936fbc5b0374fa4760071588577 Mon Sep 17 00:00:00 2001 From: LXXXXR <73265258+LXXXXR@users.noreply.github.com> Date: Thu, 19 Nov 2020 18:58:25 +0800 Subject: [PATCH] support img input as ndarray in inference.py (#87) * support img input as ndarray * revise according to comments --- mmcls/apis/inference.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/mmcls/apis/inference.py b/mmcls/apis/inference.py index 985d689ed..9511875ca 100644 --- a/mmcls/apis/inference.py +++ b/mmcls/apis/inference.py @@ -52,7 +52,7 @@ def inference_model(model, img): Args: model (nn.Module): The loaded classifier. - img (str/ndarray): The image filename. + img (str/ndarray): The image filename or loaded image. Returns: result (dict): The classification results that contains @@ -61,9 +61,15 @@ def inference_model(model, img): cfg = model.cfg device = next(model.parameters()).device # model device # build the data pipeline + if isinstance(img, str): + if cfg.data.test.pipeline[0]['type'] != 'LoadImageFromFile': + cfg.data.test.pipeline.insert(0, dict(type='LoadImageFromFile')) + data = dict(img_info=dict(filename=img), img_prefix=None) + else: + if cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile': + cfg.data.test.pipeline.pop(0) + data = dict(img=img) test_pipeline = Compose(cfg.data.test.pipeline) - # prepare data - data = dict(img_info=dict(filename=img), img_prefix=None) data = test_pipeline(data) data = collate([data], samples_per_gpu=1) if next(model.parameters()).is_cuda: