[Fix] Fix inference api and support setting palette to SegLocalVisualizer (#2475)

as title

Co-authored-by: MengzhangLI <mcmong@pku.edu.cn>
This commit is contained in:
谢昕辰 2023-01-20 21:40:13 +08:00 committed by GitHub
parent 7fc8ca0312
commit 37af545f6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 5 deletions

View File

@ -93,8 +93,9 @@ ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
def _preprare_data(imgs: ImageType, model: BaseSegmentor):
cfg = model.cfg
if dict(type='LoadAnnotations') in cfg.test_pipeline:
cfg.test_pipeline.remove(dict(type='LoadAnnotations'))
for t in cfg.test_pipeline:
if t.get('type') == 'LoadAnnotations':
cfg.test_pipeline.remove(t)
is_batch = True
if not isinstance(imgs, (list, tuple)):

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union
import mmcv
import numpy as np
@ -9,6 +9,7 @@ from mmengine.visualization import Visualizer
from mmseg.registry import VISUALIZERS
from mmseg.structures import SegDataSample
from mmseg.utils import get_classes, get_palette
@VISUALIZERS.register_module()
@ -55,14 +56,23 @@ class SegLocalVisualizer(Visualizer):
image: Optional[np.ndarray] = None,
vis_backends: Optional[Dict] = None,
save_dir: Optional[str] = None,
palette: Optional[Union[str, List]] = None,
classes: Optional[Union[str, List]] = None,
dataset_name: Optional[str] = None,
alpha: float = 0.8,
**kwargs):
super().__init__(name, image, vis_backends, save_dir, **kwargs)
self.alpha = alpha
self.alpha: float = alpha
# Set default value. When calling
# `SegLocalVisualizer().dataset_meta=xxx`,
# it will override the default value.
self.dataset_meta = {}
if dataset_name is None:
dataset_name = 'cityscapes'
classes = classes if classes else get_classes(dataset_name)
palette = palette if palette else get_palette(dataset_name)
assert len(classes) == len(
palette), 'The length of classes should be equal to palette'
self.dataset_meta: dict = {'classes': classes, 'palette': palette}
def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,
classes: Optional[Tuple[str]],