From 37af545f6b43dd9df5695861893807bf48504e12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E6=98=95=E8=BE=B0?= Date: Fri, 20 Jan 2023 21:40:13 +0800 Subject: [PATCH] [Fix] Fix inference api and support setting palette to SegLocalVisualizer (#2475) as title Co-authored-by: MengzhangLI --- mmseg/apis/inference.py | 5 +++-- mmseg/visualization/local_visualizer.py | 16 +++++++++++++--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index 9abc85d62..d1cc54559 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -93,8 +93,9 @@ ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]] def _preprare_data(imgs: ImageType, model: BaseSegmentor): cfg = model.cfg - if dict(type='LoadAnnotations') in cfg.test_pipeline: - cfg.test_pipeline.remove(dict(type='LoadAnnotations')) + for t in cfg.test_pipeline: + if t.get('type') == 'LoadAnnotations': + cfg.test_pipeline.remove(t) is_batch = True if not isinstance(imgs, (list, tuple)): diff --git a/mmseg/visualization/local_visualizer.py b/mmseg/visualization/local_visualizer.py index 070b06b73..27443f2c5 100644 --- a/mmseg/visualization/local_visualizer.py +++ b/mmseg/visualization/local_visualizer.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import mmcv import numpy as np @@ -9,6 +9,7 @@ from mmengine.visualization import Visualizer from mmseg.registry import VISUALIZERS from mmseg.structures import SegDataSample +from mmseg.utils import get_classes, get_palette @VISUALIZERS.register_module() @@ -55,14 +56,23 @@ class SegLocalVisualizer(Visualizer): image: Optional[np.ndarray] = None, vis_backends: Optional[Dict] = None, save_dir: Optional[str] = None, + palette: Optional[Union[str, List]] = None, + classes: Optional[Union[str, List]] = None, + dataset_name: Optional[str] = None, alpha: float = 0.8, **kwargs): super().__init__(name, image, vis_backends, save_dir, **kwargs) - self.alpha = alpha + self.alpha: float = alpha # Set default value. When calling # `SegLocalVisualizer().dataset_meta=xxx`, # it will override the default value. - self.dataset_meta = {} + if dataset_name is None: + dataset_name = 'cityscapes' + classes = classes if classes else get_classes(dataset_name) + palette = palette if palette else get_palette(dataset_name) + assert len(classes) == len( + palette), 'The length of classes should be equal to palette' + self.dataset_meta: dict = {'classes': classes, 'palette': palette} def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData, classes: Optional[Tuple[str]],