[Project] Add pp_mobileseg onnx inference demo (#3268)
## Motivation

Add a model deployment example.

## Modification

Add an inference script and update the README.

## BC-breaking (Optional)

None.

## Use cases (Optional)

In the README.
Parent: 0391bc4998 · Commit: 92774182ba
@@ -43,6 +43,63 @@

Same as other models in MMSegmentation, you can run the following command to test the model:

```shell
./tools/dist_test.sh projects/pp_mobileseg/configs/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_base.py checkpoints/pp_mobileseg_mobilenetv3_2xb16_3rdparty-base_512x512-ade20k-f12b44f3.pth 8
```
## Inference with ONNXRuntime

### Prerequisites

**1. Install onnxruntime inference engine.**

Choose one of the following ways to install onnxruntime.
- CPU version

  ```shell
  pip install onnxruntime==1.15.1
  wget https://github.com/microsoft/onnxruntime/releases/download/v1.15.1/onnxruntime-linux-x64-1.15.1.tgz
  tar -zxvf onnxruntime-linux-x64-1.15.1.tgz
  export ONNXRUNTIME_DIR=$(pwd)/onnxruntime-linux-x64-1.15.1
  export LD_LIBRARY_PATH=$ONNXRUNTIME_DIR/lib:$LD_LIBRARY_PATH
  ```
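Optionally, you can verify the installation before moving on. The snippet below is a minimal Python sketch (not part of this project) that checks the installed onnxruntime version and the available execution providers.

```python
# Quick sanity check for the onnxruntime installation (illustrative only).
import onnxruntime as ort

print(ort.__version__)                # expected: 1.15.1
print(ort.get_available_providers())  # should include 'CPUExecutionProvider'
```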
**2. Convert model to onnx file**

- Install `mim` and `mmdeploy`.

  ```shell
  pip install openmim
  mim install mmdeploy
  git clone https://github.com/open-mmlab/mmdeploy.git
  ```
- Download pp_mobileseg model.

  ```shell
  wget https://download.openmmlab.com/mmsegmentation/v0.5/pp_mobileseg/pp_mobileseg_mobilenetv3_2xb16_3rdparty-tiny_512x512-ade20k-a351ebf5.pth
  ```
- Convert model to onnx files.

  ```shell
  python mmdeploy/tools/deploy.py mmdeploy/configs/mmseg/segmentation_onnxruntime_dynamic.py \
      configs/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_tiny.py \
      pp_mobileseg_mobilenetv3_2xb16_3rdparty-tiny_512x512-ade20k-a351ebf5.pth \
      ../../demo/demo.png \
      --work-dir mmdeploy_model/mmseg/ort \
      --show
  ```
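If the conversion succeeds, the exported model is written to the `--work-dir` above as `end2end.onnx`. The snippet below is a minimal sketch (not part of this project) that loads the exported file with onnxruntime and prints its input/output signatures; the file path is assumed from the command above.

```python
# Inspect the exported ONNX model before running the demo (illustrative only).
import onnxruntime as ort

sess = ort.InferenceSession(
    'mmdeploy_model/mmseg/ort/end2end.onnx',
    providers=['CPUExecutionProvider'])
for node in sess.get_inputs():
    print('input :', node.name, node.shape, node.type)
for node in sess.get_outputs():
    print('output:', node.name, node.shape, node.type)
```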
**3. Run demo**

```shell
python inference_onnx.py ${ONNX_FILE_PATH} ${IMAGE_PATH} [--input-size ${MODEL_INPUT_SIZE}] [--device ${DEVICE}] [--save-path ${OUTPUT_IMAGE_PATH}]
```

Example:

```shell
python inference_onnx.py mmdeploy_model/mmseg/ort/end2end.onnx ../../demo/demo.png
```

The blended segmentation result is saved to `output.jpg` by default; use `--save-path` to change it.
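Besides the command line, the helper functions defined in `inference_onnx.py` (shown below) can be reused programmatically. The following is a minimal sketch, assuming the script's directory is on `PYTHONPATH` and the ONNX/image paths from the example above.

```python
# Reuse the demo's helpers as a library (illustrative sketch).
import cv2

from inference_onnx import (build_session, inference, postprocess, preprocess,
                            visualize)

img = cv2.imread('../../demo/demo.png')                       # BGR image, HxWx3
sess = build_session('mmdeploy_model/mmseg/ort/end2end.onnx', device='cpu')
resized_img, origin_shape = preprocess(img, (512, 512))       # resize + normalize
outputs = inference(sess, resized_img)                        # ONNXRuntime forward pass
h, w = origin_shape
seg_map = postprocess(outputs, (w, h))                        # labels back to (w, h)
visualize(img, seg_map, 'output.jpg', opacity=0.8)            # save palette overlay
```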
## Citation

If you find our project useful in your research, please consider citing:
projects/pp_mobileseg/inference_onnx.py (new file, 203 lines)

```python
import argparse
import time
from typing import List, Tuple

import cv2
import loguru
import numpy as np
import onnxruntime as ort

logger = loguru.logger


def parse_args():
    parser = argparse.ArgumentParser(
        description='PP_Mobileseg ONNX inference demo.')
    parser.add_argument('onnx_file', help='ONNX file path')
    parser.add_argument('image_file', help='Input image file path')
    parser.add_argument(
        '--input-size',
        type=int,
        nargs='+',
        default=[512, 512],
        help='input image size')
    parser.add_argument(
        '--device', help='device type for inference', default='cpu')
    parser.add_argument(
        '--save-path',
        help='path to save the output image',
        default='output.jpg')
    args = parser.parse_args()
    return args


def preprocess(
    img: np.ndarray, input_size: Tuple[int, int] = (512, 512)
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """Preprocess the image for inference.

    Returns the normalized, resized image and the original (height, width).
    """
    img_shape = img.shape[:2]
    # Resize to the model input size
    resized_img = cv2.resize(img, input_size)

    # Normalize with the ImageNet mean/std used by the training config
    mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
    resized_img = (resized_img - mean) / std

    return resized_img, img_shape


def build_session(onnx_file: str, device: str = 'cpu') -> ort.InferenceSession:
    """Build onnxruntime session.

    Args:
        onnx_file (str): ONNX file path.
        device (str): Device type for inference.

    Returns:
        sess (ort.InferenceSession): ONNXRuntime session.
    """
    providers = ['CPUExecutionProvider'
                 ] if device == 'cpu' else ['CUDAExecutionProvider']
    sess = ort.InferenceSession(path_or_bytes=onnx_file, providers=providers)

    return sess


def inference(sess: ort.InferenceSession, img: np.ndarray) -> List[np.ndarray]:
    """Run inference with the PP_Mobileseg ONNX model.

    Args:
        sess (ort.InferenceSession): ONNXRuntime session.
        img (np.ndarray): Preprocessed input image in (H, W, C) layout.

    Returns:
        outputs (List[np.ndarray]): Outputs of the PP_Mobileseg model.
    """
    # build input: HWC -> NCHW with a batch dimension
    input_img = img.transpose(2, 0, 1)[np.newaxis, ...].astype(np.float32)

    # build input feed and output names
    sess_input = {sess.get_inputs()[0].name: input_img}
    sess_output = [out.name for out in sess.get_outputs()]

    # inference
    outputs = sess.run(output_names=sess_output, input_feed=sess_input)

    return outputs


def postprocess(outputs: List[np.ndarray],
                origin_shape: Tuple[int, int]) -> np.ndarray:
    """Postprocess outputs of the PP_Mobileseg model.

    Args:
        outputs (List[np.ndarray]): Outputs of the PP_Mobileseg model.
        origin_shape (Tuple[int, int]): Target (width, height) used to resize
            the segmentation map back to the original image size.

    Returns:
        seg_map (np.ndarray): Segmentation map.
    """
    # take the (H, W) label map from the batched output
    seg_map = outputs[0][0][0]
    # nearest-neighbor interpolation keeps the resized values valid class ids
    seg_map = cv2.resize(
        seg_map.astype(np.float32),
        origin_shape,
        interpolation=cv2.INTER_NEAREST)
    return seg_map


def visualize(img: np.ndarray,
              seg_map: np.ndarray,
              filename: str = 'output.jpg',
              opacity: float = 0.8) -> np.ndarray:
    """Blend the colorized segmentation map with the image and save it.

    Args:
        img (np.ndarray): Original BGR image.
        seg_map (np.ndarray): Segmentation map with ADE20K class indices.
        filename (str): Path of the saved visualization.
        opacity (float): Opacity of the segmentation overlay.

    Returns:
        img (np.ndarray): Blended image.
    """
    assert 0.0 <= opacity <= 1.0, 'opacity should be in range [0, 1]'
    palette = np.array(PALETTE)
    color_seg = np.zeros((seg_map.shape[0], seg_map.shape[1], 3),
                         dtype=np.uint8)
    for label, color in enumerate(palette):
        color_seg[seg_map == label, :] = color
    # convert to BGR
    color_seg = color_seg[..., ::-1]

    img = img * (1 - opacity) + color_seg * opacity
    # cast back to uint8 so cv2.imwrite can encode the image
    img = img.astype(np.uint8)
    cv2.imwrite(filename, img)

    return img


def main():
    args = parse_args()
    logger.info('Start running model inference...')

    # read image from file
    logger.info(f'1. Read image from file {args.image_file}...')
    img = cv2.imread(args.image_file)

    # build onnx model
    logger.info(f'2. Build onnx model from {args.onnx_file}...')
    sess = build_session(args.onnx_file, args.device)

    # preprocess
    logger.info('3. Preprocess image...')
    model_input_size = tuple(args.input_size)
    assert len(model_input_size) == 2
    resized_img, origin_shape = preprocess(img, model_input_size)

    # inference
    logger.info('4. Inference...')
    start = time.time()
    outputs = inference(sess, resized_img)
    logger.info(f'Inference time: {time.time() - start:.4f}s')

    # postprocess
    logger.info('5. Postprocess...')
    h, w = origin_shape
    seg_map = postprocess(outputs, (w, h))

    # visualize
    logger.info('6. Visualize...')
    visualize(img, seg_map, args.save_path)

    logger.info('Done...')


# ADE20K color palette used by `visualize` to colorize the predicted labels.
PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
           [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
           [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
           [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
           [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
           [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
           [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
           [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
           [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
           [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
           [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
           [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
           [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
           [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
           [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
           [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
           [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
           [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
           [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
           [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
           [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
           [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
           [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
           [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
           [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
           [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
           [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
           [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
           [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
           [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
           [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
           [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
           [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
           [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
           [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
           [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
           [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
           [102, 255, 0], [92, 0, 255]]


if __name__ == '__main__':
    main()
```