[Fix] Fix image_demo.py error (#1640)

* [Fix] Fix image_demo.py error

* [Fix] Fix image_demo.py error

* fix

* delete plt.cla()
pull/1667/head
MengzhangLI 2022-06-09 17:20:28 +08:00 committed by GitHub
parent 775d05c54f
commit 43b4efb122
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 10 additions and 3 deletions

View File

@ -10,6 +10,7 @@ def main():
parser.add_argument('img', help='Image file')
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument('--out-file', default=None, help='Path to output file')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
@ -33,7 +34,8 @@ def main():
args.img,
result,
get_palette(args.palette),
opacity=args.opacity)
opacity=args.opacity,
out_file=args.out_file)
if __name__ == '__main__':

View File

@ -80,7 +80,7 @@ The downloading will take several seconds or more, depending on your network env
Option (a). If you install mmsegmentation from source, just run the following command.
```shell
python demo/image_demo.py demo/demo.jpg pspnet_r50-d8_512x1024_40k_cityscapes.py pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth --device cpu --out-file result.jpg
python demo/image_demo.py demo/demo.png configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth --device cuda:0 --out-file result.jpg
```
You will see a new image `result.jpg` on your current folder, where segmentation masks are covered on all objects.

View File

@ -106,7 +106,8 @@ def show_result_pyplot(model,
fig_size=(15, 10),
opacity=0.5,
title='',
block=True):
block=True,
out_file=None):
"""Visualize the segmentation results on the image.
Args:
@ -124,6 +125,8 @@ def show_result_pyplot(model,
Default is ''.
block (bool): Whether to block the pyplot figure.
Default is True.
out_file (str or None): The path to write the image.
Default: None.
"""
if hasattr(model, 'module'):
model = model.module
@ -134,3 +137,5 @@ def show_result_pyplot(model,
plt.title(title)
plt.tight_layout()
plt.show(block=block)
if out_file is not None:
mmcv.imwrite(img, out_file)