support img input as ndarray in inference.py (#87)
* support img input as ndarray * revise according to commentspull/91/head
parent
fe45b241c5
commit
149ee3a30d
|
@ -52,7 +52,7 @@ def inference_model(model, img):
|
|||
|
||||
Args:
|
||||
model (nn.Module): The loaded classifier.
|
||||
img (str/ndarray): The image filename.
|
||||
img (str/ndarray): The image filename or loaded image.
|
||||
|
||||
Returns:
|
||||
result (dict): The classification results that contains
|
||||
|
@ -61,9 +61,15 @@ def inference_model(model, img):
|
|||
cfg = model.cfg
|
||||
device = next(model.parameters()).device # model device
|
||||
# build the data pipeline
|
||||
if isinstance(img, str):
|
||||
if cfg.data.test.pipeline[0]['type'] != 'LoadImageFromFile':
|
||||
cfg.data.test.pipeline.insert(0, dict(type='LoadImageFromFile'))
|
||||
data = dict(img_info=dict(filename=img), img_prefix=None)
|
||||
else:
|
||||
if cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile':
|
||||
cfg.data.test.pipeline.pop(0)
|
||||
data = dict(img=img)
|
||||
test_pipeline = Compose(cfg.data.test.pipeline)
|
||||
# prepare data
|
||||
data = dict(img_info=dict(filename=img), img_prefix=None)
|
||||
data = test_pipeline(data)
|
||||
data = collate([data], samples_per_gpu=1)
|
||||
if next(model.parameters()).is_cuda:
|
||||
|
|
Loading…
Reference in New Issue