From 2eaf7ee8b3342e29a04fd63e45ba5db30ec64f3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E6=98=95=E8=BE=B0?= Date: Wed, 14 Sep 2022 16:58:19 +0800 Subject: [PATCH] [Doc] Update inference doc (#2052) * refactor inference doc * introduce segdatasample * add vis example * fix * Update docs/en/user_guides/3_inference.md * update link Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com> --- docs/en/user_guides/3_inference.md | 142 ++++++++++++++++++++++++++--- mmseg/apis/inference.py | 6 +- 2 files changed, 135 insertions(+), 13 deletions(-) diff --git a/docs/en/user_guides/3_inference.md b/docs/en/user_guides/3_inference.md index 816c61c43..6b6f6f7f3 100644 --- a/docs/en/user_guides/3_inference.md +++ b/docs/en/user_guides/3_inference.md @@ -2,24 +2,142 @@ MMSegmentation provides pre-trained models for semantic segmentation in [Model Zoo](../model_zoo.md), and supports multiple standard datasets, including Cityscapes, ADE20K, etc. This note will show how to use existing models to inference on given images. -As for how to test existing models on standard datasets, please see this [guide](./4_train_test.md#Test-models-on-standard-datasets) +As for how to test existing models on standard datasets, please see this [guide](./4_train_test.md) -## Inference on given images +## Inference API -MMSegmentation provides high-level Python APIs for inference on images. Here is an example of building the model and inference on given images. -Please download the [pre-trained model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x1024_80k_cityscapes/pspnet_r50-d8_512x1024_80k_cityscapes_20200606_112131-2376f12b.pth) to the path specified by `checkpoint_file` first. +MMSegmentation provides several interfaces for users to easily use pre-trained models for inference. + +- [mmseg.apis.init_model](#mmsegapisinit_model) +- [mmseg.apis.inference_model](#mmsegapisinference_model) +- [mmseg.apis.show_result_pyplot](#mmsegapisshow_result_pyplot) + +### mmseg.apis.init_model + +Initialize a segmentor from config file. + +Parameters: + +- config (str, `Path`, or `mmengine.Config`) - Config file path or the config object. +- checkpoint (str, optional) - Checkpoint path. If left as None, the model will not load any weights. +- device (str, optional) - CPU/CUDA device option. Default 'cuda:0'. +- cfg_options (dict, optional) - Options to override some settings in the used config. + +Returns: + +- nn.Module: The constructed segmentor. + +Example: + +```python +from mmseg.apis import init_model +from mmseg.utils import register_all_modules + +config_path = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py' +checkpoint_path = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth' + +# register all modules in mmseg into the registries +register_all_modules() + +# initialize model without checkpoint +model = init_model(config_path) + +# init model and load checkpoint +model = init_model(config_path, checkpoint_path) + +# init model and load checkpoint on CPU +model = init_model(config_path, checkpoint_path, 'cpu') +``` + +### mmseg.apis.inference_model + +Inference image(s) with the segmentor. + +Parameters: + +- model (nn.Module) - The loaded segmentor +- imgs (str, np.ndarray, or list\[str/np.ndarray\]) - Either image files or loaded images + +Returns: + +- `SegDataSample` or list\[`SegDataSample`\]: If imgs is a list or tuple, the same length list type results will be returned, otherwise return the segmentation results directly. + +**Note:** [SegDataSample](https://github.com/open-mmlab/mmsegmentation/blob/1.x/mmseg/structures/seg_data_sample.py) is a data structure interface of MMSegmentation, it is used as interfaces between different components. `SegDataSample` implement the abstract data element `mmengine.structures.BaseDataElement`, please refer to data element [documentation](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/data_element.html) in [MMEngine](https://github.com/open-mmlab/mmengine) for more information. + +The attributes in `SegDataSample` are divided into several parts: + +- `gt_sem_seg` (`PixelData`) - Ground truth of semantic segmentation. +- `pred_sem_seg` (`PixelData`) - Prediction of semantic segmentation. +- `seg_logits` (`PixelData`) - Predicted logits of semantic segmentation. + +**Note** [PixelData](https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/pixel_data.py) is the data structure for pixel-level annotations or predictions, please refer to PixelData [documentation](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/data_element.html) in [MMEngine](https://github.com/open-mmlab/mmengine) for more information. + +Example: ```python from mmseg.apis import init_model, inference_model from mmseg.utils import register_all_modules -# Specify the path to model config and checkpoint file -config_file = 'configs/pspnet/pspnet_r50-d8_512x1024_80k_cityscapes.py' -checkpoint_file = 'checkpoints/pspnet_r50-d8_512x1024_80k_cityscapes_20200606_112131-2376f12b.pth' + +config_path = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py' +checkpoint_path = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth' +img_path = 'demo/demo.png' + # register all modules in mmseg into the registries register_all_modules() -# build the model from a config file and a checkpoint file -model = init_model(config_file, checkpoint_file, device='cuda:0') -# test image pair, and save the results -img = 'demo/demo.png' -result = inference_model(model, img) + +model = init_model(config_path, checkpoint_path) +result = inference_model(model, img_path) ``` + +### mmseg.apis.show_result_pyplot + +Visualize the segmentation results on the image. + +Parameters: + +- model (nn.Module) - The loaded segmentor. +- img (str or np.ndarray) - Image filename or loaded image. +- result (`SegDataSample`) - The prediction SegDataSample result. +- opacity (float) - Opacity of painted segmentation map. Default `0.5`, must be in `(0, 1]` range. +- title (str) - The title of pyplot figure. Default is ''. +- draw_gt (bool) - Whether to draw GT SegDataSample. Default to `True`. +- draw_pred (draws_pred) - Whether to draw Prediction SegDataSample. Default to `True`. +- wait_time (float) - The interval of show (s), 0 is the special value that means "forever". Default to `0`. +- show (bool) - Whether to display the drawn image. Default to `True`. +- save_dir (str, optional) - Save file dir for all storage backends. If it is `None`, the backend storage will not save any data. +- out_file (str, optional) - Path to output file. Default to `None`. + +Returns: + +- np.ndarray: the drawn image which channel is RGB. + +Example: + +```python +from mmseg.apis import init_model, inference_model, show_result_pyplot +from mmseg.utils import register_all_modules + +config_path = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py' +checkpoint_path = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth' +img_path = 'demo/demo.png' + +# register all modules in mmseg into the registries +register_all_modules() + +# build the model from a config file and a checkpoint file +model = init_model(config_path, checkpoint_path, device='cuda:0') + +# inference on given image +result = inference_model(model, img_path) + +# display the segmentation result +vis_image = show_result_pyplot(model, img_path, result) + +# save the visualization result, the output image would be found at the path `work_dirs/result.png` +vis_iamge = show_result_pyplot(model, img_path, result, out_file='work_dirs/result.png') + +# Modify the time of displaying images, note that 0 is the special value that means "forever". +vis_image = show_result_pyplot(model, img_path, result, wait_time=5) +``` + +**Note:** If your current device doesn't have graphical user interface, it is recommended that setting `show` to `False` and specify the `out_file` or `save_dir` to save the results. If you would like to display the result on a window, no special settings are required. diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index 552ea7a83..9abc85d62 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -169,12 +169,16 @@ def show_result_pyplot(model: BaseSegmentor, draw_gt (bool): Whether to draw GT SegDataSample. Default to True. draw_pred (bool): Whether to draw Prediction SegDataSample. Defaults to True. - wait_time (float): The interval of show (s). Defaults to 0. + wait_time (float): The interval of show (s). 0 is the special value + that means "forever". Defaults to 0. show (bool): Whether to display the drawn image. Default to True. save_dir (str, optional): Save file dir for all storage backends. If it is None, the backend storage will not save any data. out_file (str, optional): Path to output file. Default to None. + + Returns: + np.ndarray: the drawn image which channel is RGB. """ if hasattr(model, 'module'): model = model.module