From 8806b4e5486146e55b801ffa1a239211884cd2e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E6=98=95=E8=BE=B0?= Date: Mon, 3 Jul 2023 09:43:00 +0800 Subject: [PATCH] [Fix] Fix visualizor (#3154) ## Motivation **Current visualize result** ![rs-dev](https://github.com/open-mmlab/mmsegmentation/assets/15952744/147ea3f7-f632-457b-b257-031199320825) **Fixed the visualization result** ![rs-fix](https://github.com/open-mmlab/mmsegmentation/assets/15952744/98a86025-5a1e-4c2b-83e0-653dd659ba79) ## Modification remove mmengine `draw_binary_masks` api --- demo/inference_demo.ipynb | 9 ++++----- mmseg/apis/inference.py | 2 +- mmseg/visualization/local_visualizer.py | 14 +++++++------- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/demo/inference_demo.ipynb b/demo/inference_demo.ipynb index d26173fc6..455c5df4e 100644 --- a/demo/inference_demo.ipynb +++ b/demo/inference_demo.ipynb @@ -21,7 +21,6 @@ "outputs": [], "source": [ "import torch\n", - "import mmcv\n", "import matplotlib.pyplot as plt\n", "from mmengine.model.utils import revert_sync_batchnorm\n", "from mmseg.apis import init_model, inference_model, show_result_pyplot" @@ -48,7 +47,7 @@ "outputs": [], "source": [ "# build the model from a config file and a checkpoint file\n", - "model = init_model(config_file, checkpoint_file, device='cuda:0')" + "model = init_model(config_file, checkpoint_file, device='cpu')" ] }, { @@ -71,8 +70,8 @@ "outputs": [], "source": [ "# show the results\n", - "vis_result = show_result_pyplot(model, img, result)\n", - "plt.imshow(mmcv.bgr2rgb(vis_result))" + "vis_result = show_result_pyplot(model, img, result, show=False)\n", + "plt.imshow(vis_result)" ] }, { @@ -99,7 +98,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.10.11" }, "pycharm": { "stem_cell": { diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index 4aadffc79..81cd17d79 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -187,7 +187,7 @@ def show_result_pyplot(model: BaseSegmentor, if hasattr(model, 'module'): model = model.module if isinstance(img, str): - image = mmcv.imread(img) + image = mmcv.imread(img, channel_order='rgb') else: image = img if save_dir is not None: diff --git a/mmseg/visualization/local_visualizer.py b/mmseg/visualization/local_visualizer.py index 504004dfc..0d693e582 100644 --- a/mmseg/visualization/local_visualizer.py +++ b/mmseg/visualization/local_visualizer.py @@ -108,14 +108,14 @@ class SegLocalVisualizer(Visualizer): colors = [palette[label] for label in labels] - self.set_image(image) - - # draw semantic masks + mask = np.zeros_like(image, dtype=np.uint8) for label, color in zip(labels, colors): - self.draw_binary_masks( - sem_seg == label, colors=[color], alphas=self.alpha) + mask[sem_seg[0] == label, :] = color - return self.get_image() + color_seg = (image * (1 - self.alpha) + mask * self.alpha).astype( + np.uint8) + self.set_image(color_seg) + return color_seg def set_dataset_meta(self, classes: Optional[List] = None, @@ -226,6 +226,6 @@ class SegLocalVisualizer(Visualizer): self.show(drawn_img, win_name=name, wait_time=wait_time) if out_file is not None: - mmcv.imwrite(mmcv.bgr2rgb(drawn_img), out_file) + mmcv.imwrite(mmcv.rgb2bgr(drawn_img), out_file) else: self.add_image(name, drawn_img, step)