From f1fa61a48ac0acad94d89ab6d924936059da8c5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E6=98=95=E8=BE=B0?= Date: Mon, 18 Sep 2023 19:24:03 +0800 Subject: [PATCH] [Fix] Fix inferencer (#3333) --- mmseg/apis/mmseg_inferencer.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/mmseg/apis/mmseg_inferencer.py b/mmseg/apis/mmseg_inferencer.py index 1c72285c5..01f16ae79 100644 --- a/mmseg/apis/mmseg_inferencer.py +++ b/mmseg/apis/mmseg_inferencer.py @@ -59,7 +59,9 @@ class MMSegInferencer(BaseInferencer): preprocess_kwargs: set = set() forward_kwargs: set = {'mode', 'out_dir'} - visualize_kwargs: set = {'show', 'wait_time', 'img_out_dir', 'opacity'} + visualize_kwargs: set = { + 'show', 'wait_time', 'img_out_dir', 'opacity', 'return_vis' + } postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'} def __init__(self, @@ -137,6 +139,7 @@ class MMSegInferencer(BaseInferencer): inputs: InputsType, return_datasamples: bool = False, batch_size: int = 1, + return_vis: bool = False, show: bool = False, wait_time: int = 0, out_dir: str = '', @@ -188,11 +191,13 @@ class MMSegInferencer(BaseInferencer): wait_time=wait_time, img_out_dir=img_out_dir, pred_out_dir=pred_out_dir, + return_vis=return_vis, **kwargs) def visualize(self, inputs: list, preds: List[dict], + return_vis: bool = False, show: bool = False, wait_time: int = 0, img_out_dir: str = '', @@ -213,12 +218,12 @@ class MMSegInferencer(BaseInferencer): Returns: List[np.ndarray]: Visualization results. """ - if self.visualizer is None or (not show and img_out_dir == ''): + if not show and img_out_dir == '' and not return_vis: return None - - if getattr(self, 'visualizer') is None: + if self.visualizer is None: raise ValueError('Visualization needs the "visualizer" term' - 'defined in the config, but got None') + 'defined in the config, but got None.') + self.visualizer.set_dataset_meta(**self.model.dataset_meta) self.visualizer.alpha = opacity @@ -250,10 +255,11 @@ class MMSegInferencer(BaseInferencer): draw_gt=False, draw_pred=True, out_file=out_file) - results.append(self.visualizer.get_image()) + if return_vis: + results.append(self.visualizer.get_image()) self.num_visualized_imgs += 1 - return results + return results if return_vis else None def postprocess(self, preds: PredType,