[Fix] Fix mmseg.api.inference inference_segmentor (#1849)

* [Fix] Fix mmseg.api.inference inference_segmentor

Motivation
Fix inference_segmentor not working with multiple images path or images. List[str/ndarray]

Modification
- process images if instance is list

* fix typo

* Update mmseg/apis/inference.py

Co-authored-by: Hakjin Lee <nijkah@gmail.com>

Co-authored-by: Hakjin Lee <nijkah@gmail.com>
pull/2073/head
whooray 2022-09-14 01:13:43 +09:00 committed by GitHub
parent ca7c098767
commit ecd1ecb6ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 4 deletions

View File

@ -67,7 +67,7 @@ class LoadImage:
return results
def inference_segmentor(model, img):
def inference_segmentor(model, imgs):
"""Inference image(s) with the segmentor.
Args:
@ -84,9 +84,13 @@ def inference_segmentor(model, img):
test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
test_pipeline = Compose(test_pipeline)
# prepare data
data = dict(img=img)
data = test_pipeline(data)
data = collate([data], samples_per_gpu=1)
data = []
imgs = imgs if isinstance(imgs, list) else [imgs]
for img in imgs:
img_data = dict(img=img)
img_data = test_pipeline(img_data)
data.append(img_data)
data = collate(data, samples_per_gpu=len(imgs))
if next(model.parameters()).is_cuda:
# scatter to specified GPU
data = scatter(data, [device])[0]