support img input as ndarray in inference.py (#87)
* support img input as ndarray * revise according to commentspull/91/head
parent
fe45b241c5
commit
149ee3a30d
|
@ -52,7 +52,7 @@ def inference_model(model, img):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (nn.Module): The loaded classifier.
|
model (nn.Module): The loaded classifier.
|
||||||
img (str/ndarray): The image filename.
|
img (str/ndarray): The image filename or loaded image.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
result (dict): The classification results that contains
|
result (dict): The classification results that contains
|
||||||
|
@ -61,9 +61,15 @@ def inference_model(model, img):
|
||||||
cfg = model.cfg
|
cfg = model.cfg
|
||||||
device = next(model.parameters()).device # model device
|
device = next(model.parameters()).device # model device
|
||||||
# build the data pipeline
|
# build the data pipeline
|
||||||
|
if isinstance(img, str):
|
||||||
|
if cfg.data.test.pipeline[0]['type'] != 'LoadImageFromFile':
|
||||||
|
cfg.data.test.pipeline.insert(0, dict(type='LoadImageFromFile'))
|
||||||
|
data = dict(img_info=dict(filename=img), img_prefix=None)
|
||||||
|
else:
|
||||||
|
if cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile':
|
||||||
|
cfg.data.test.pipeline.pop(0)
|
||||||
|
data = dict(img=img)
|
||||||
test_pipeline = Compose(cfg.data.test.pipeline)
|
test_pipeline = Compose(cfg.data.test.pipeline)
|
||||||
# prepare data
|
|
||||||
data = dict(img_info=dict(filename=img), img_prefix=None)
|
|
||||||
data = test_pipeline(data)
|
data = test_pipeline(data)
|
||||||
data = collate([data], samples_per_gpu=1)
|
data = collate([data], samples_per_gpu=1)
|
||||||
if next(model.parameters()).is_cuda:
|
if next(model.parameters()).is_cuda:
|
||||||
|
|
Loading…
Reference in New Issue