mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Fix inference api and support setting palette to SegLocalVisualizer (#2475)
as title Co-authored-by: MengzhangLI <mcmong@pku.edu.cn>
This commit is contained in:
parent
7fc8ca0312
commit
37af545f6b
@ -93,8 +93,9 @@ ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
|
||||
def _preprare_data(imgs: ImageType, model: BaseSegmentor):
|
||||
|
||||
cfg = model.cfg
|
||||
if dict(type='LoadAnnotations') in cfg.test_pipeline:
|
||||
cfg.test_pipeline.remove(dict(type='LoadAnnotations'))
|
||||
for t in cfg.test_pipeline:
|
||||
if t.get('type') == 'LoadAnnotations':
|
||||
cfg.test_pipeline.remove(t)
|
||||
|
||||
is_batch = True
|
||||
if not isinstance(imgs, (list, tuple)):
|
||||
|
@ -1,5 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
@ -9,6 +9,7 @@ from mmengine.visualization import Visualizer
|
||||
|
||||
from mmseg.registry import VISUALIZERS
|
||||
from mmseg.structures import SegDataSample
|
||||
from mmseg.utils import get_classes, get_palette
|
||||
|
||||
|
||||
@VISUALIZERS.register_module()
|
||||
@ -55,14 +56,23 @@ class SegLocalVisualizer(Visualizer):
|
||||
image: Optional[np.ndarray] = None,
|
||||
vis_backends: Optional[Dict] = None,
|
||||
save_dir: Optional[str] = None,
|
||||
palette: Optional[Union[str, List]] = None,
|
||||
classes: Optional[Union[str, List]] = None,
|
||||
dataset_name: Optional[str] = None,
|
||||
alpha: float = 0.8,
|
||||
**kwargs):
|
||||
super().__init__(name, image, vis_backends, save_dir, **kwargs)
|
||||
self.alpha = alpha
|
||||
self.alpha: float = alpha
|
||||
# Set default value. When calling
|
||||
# `SegLocalVisualizer().dataset_meta=xxx`,
|
||||
# it will override the default value.
|
||||
self.dataset_meta = {}
|
||||
if dataset_name is None:
|
||||
dataset_name = 'cityscapes'
|
||||
classes = classes if classes else get_classes(dataset_name)
|
||||
palette = palette if palette else get_palette(dataset_name)
|
||||
assert len(classes) == len(
|
||||
palette), 'The length of classes should be equal to palette'
|
||||
self.dataset_meta: dict = {'classes': classes, 'palette': palette}
|
||||
|
||||
def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,
|
||||
classes: Optional[Tuple[str]],
|
||||
|
Loading…
x
Reference in New Issue
Block a user