From 43b4efb122f1c4e934ee2588f40210e8c34eed5f Mon Sep 17 00:00:00 2001 From: MengzhangLI Date: Thu, 9 Jun 2022 17:20:28 +0800 Subject: [PATCH] [Fix] Fix image_demo.py error (#1640) * [Fix] Fix image_demo.py error * [Fix] Fix image_demo.py error * fix * delete plt.cla() --- demo/image_demo.py | 4 +++- docs/en/get_started.md | 2 +- mmseg/apis/inference.py | 7 ++++++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/demo/image_demo.py b/demo/image_demo.py index 05e1a7913..87d6d6c41 100644 --- a/demo/image_demo.py +++ b/demo/image_demo.py @@ -10,6 +10,7 @@ def main(): parser.add_argument('img', help='Image file') parser.add_argument('config', help='Config file') parser.add_argument('checkpoint', help='Checkpoint file') + parser.add_argument('--out-file', default=None, help='Path to output file') parser.add_argument( '--device', default='cuda:0', help='Device used for inference') parser.add_argument( @@ -33,7 +34,8 @@ def main(): args.img, result, get_palette(args.palette), - opacity=args.opacity) + opacity=args.opacity, + out_file=args.out_file) if __name__ == '__main__': diff --git a/docs/en/get_started.md b/docs/en/get_started.md index f540982b1..bbe3d5795 100644 --- a/docs/en/get_started.md +++ b/docs/en/get_started.md @@ -80,7 +80,7 @@ The downloading will take several seconds or more, depending on your network env Option (a). If you install mmsegmentation from source, just run the following command. ```shell -python demo/image_demo.py demo/demo.jpg pspnet_r50-d8_512x1024_40k_cityscapes.py pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth --device cpu --out-file result.jpg +python demo/image_demo.py demo/demo.png configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth --device cuda:0 --out-file result.jpg ``` You will see a new image `result.jpg` on your current folder, where segmentation masks are covered on all objects. diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index 906943804..a2a8ab0cb 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -106,7 +106,8 @@ def show_result_pyplot(model, fig_size=(15, 10), opacity=0.5, title='', - block=True): + block=True, + out_file=None): """Visualize the segmentation results on the image. Args: @@ -124,6 +125,8 @@ def show_result_pyplot(model, Default is ''. block (bool): Whether to block the pyplot figure. Default is True. + out_file (str or None): The path to write the image. + Default: None. """ if hasattr(model, 'module'): model = model.module @@ -134,3 +137,5 @@ def show_result_pyplot(model, plt.title(title) plt.tight_layout() plt.show(block=block) + if out_file is not None: + mmcv.imwrite(img, out_file)