support img input as ndarray in inference.py (#87)

* support img input as ndarray

* revise according to comments
pull/91/head
LXXXXR 2020-11-19 18:58:25 +08:00 committed by GitHub
parent fe45b241c5
commit 149ee3a30d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 9 additions and 3 deletions

View File

@ -52,7 +52,7 @@ def inference_model(model, img):
Args:
model (nn.Module): The loaded classifier.
img (str/ndarray): The image filename.
img (str/ndarray): The image filename or loaded image.
Returns:
result (dict): The classification results that contains
@ -61,9 +61,15 @@ def inference_model(model, img):
cfg = model.cfg
device = next(model.parameters()).device # model device
# build the data pipeline
if isinstance(img, str):
if cfg.data.test.pipeline[0]['type'] != 'LoadImageFromFile':
cfg.data.test.pipeline.insert(0, dict(type='LoadImageFromFile'))
data = dict(img_info=dict(filename=img), img_prefix=None)
else:
if cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile':
cfg.data.test.pipeline.pop(0)
data = dict(img=img)
test_pipeline = Compose(cfg.data.test.pipeline)
# prepare data
data = dict(img_info=dict(filename=img), img_prefix=None)
data = test_pipeline(data)
data = collate([data], samples_per_gpu=1)
if next(model.parameters()).is_cuda: