[Fix] Fix mmseg.api.inference inference_segmentor (#1849)

* [Fix] Fix mmseg.api.inference inference_segmentor

Motivation
Fix inference_segmentor not working with multiple images path or images. List[str/ndarray]

Modification
- process images if instance is list

* fix typo

* Update mmseg/apis/inference.py

Co-authored-by: Hakjin Lee <nijkah@gmail.com>

Co-authored-by: Hakjin Lee <nijkah@gmail.com>
pull/2073/head
whooray 2022-09-14 01:13:43 +09:00 committed by GitHub
parent ca7c098767
commit ecd1ecb6ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 4 deletions

View File

@ -67,7 +67,7 @@ class LoadImage:
return results return results
def inference_segmentor(model, img): def inference_segmentor(model, imgs):
"""Inference image(s) with the segmentor. """Inference image(s) with the segmentor.
Args: Args:
@ -84,9 +84,13 @@ def inference_segmentor(model, img):
test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
test_pipeline = Compose(test_pipeline) test_pipeline = Compose(test_pipeline)
# prepare data # prepare data
data = dict(img=img) data = []
data = test_pipeline(data) imgs = imgs if isinstance(imgs, list) else [imgs]
data = collate([data], samples_per_gpu=1) for img in imgs:
img_data = dict(img=img)
img_data = test_pipeline(img_data)
data.append(img_data)
data = collate(data, samples_per_gpu=len(imgs))
if next(model.parameters()).is_cuda: if next(model.parameters()).is_cuda:
# scatter to specified GPU # scatter to specified GPU
data = scatter(data, [device])[0] data = scatter(data, [device])[0]