204 lines
7.1 KiB
Python
Raw Normal View History

import argparse
import time
from typing import List, Tuple
import cv2
import loguru
import numpy as np
import onnxruntime as ort
logger = loguru.logger
def parse_args():
    """Parse the command-line arguments of the demo.

    Returns:
        argparse.Namespace: Namespace with ``onnx_file``, ``image_file``,
        ``input_size`` (list of two ints), ``device`` and ``save_path``.
    """
    parser = argparse.ArgumentParser(
        description='PP_Mobileseg ONNX inference demo.')
    parser.add_argument('onnx_file', help='ONNX file path')
    parser.add_argument('image_file', help='Input image file path')
    # The size is forwarded to ``cv2.resize`` as ``dsize``, which needs
    # exactly two values in (width, height) order. ``nargs=2`` lets argparse
    # reject a wrong count up front instead of failing later in ``main``.
    parser.add_argument(
        '--input-size',
        type=int,
        nargs=2,
        metavar=('WIDTH', 'HEIGHT'),
        default=[512, 512],
        help='input image size')
    parser.add_argument(
        '--device', help='device type for inference', default='cpu')
    parser.add_argument(
        '--save-path',
        help='path to save the output image',
        default='output.jpg')
    return parser.parse_args()
def preprocess(
    img: np.ndarray, input_size: Tuple[int, int] = (512, 512)
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """Resize and normalize an image for inference.

    Args:
        img (np.ndarray): Input image; the first two axes are taken as
            (height, width) and the last axis must broadcast against the
            3-element mean/std below.
        input_size (Tuple[int, int]): Target size, passed to ``cv2.resize``
            as ``dsize`` (i.e. ``(width, height)``).

    Returns:
        Tuple[np.ndarray, Tuple[int, int]]: The resized, normalized image
        (still channel-last) and the ``(height, width)`` of the original
        image.
    """
    img_shape = img.shape[:2]
    # Resize
    resized_img = cv2.resize(img, input_size)
    # Normalize with the ImageNet statistics scaled to the 0-255 range
    # (0.485 * 255 = 123.675, 0.456 * 255 = 116.28, 0.406 * 255 = 103.53).
    # NOTE(review): these values are in RGB order, while main() passes the
    # BGR image returned by cv2.imread without a channel swap -- confirm the
    # channel order the exported model expects.
    mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
    resized_img = (resized_img - mean) / std
    return resized_img, img_shape
def build_session(onnx_file: str, device: str = 'cpu') -> ort.InferenceSession:
    """Create an onnxruntime inference session for the given model.

    Args:
        onnx_file (str): ONNX file path.
        device (str): Device type for inference. ``'cpu'`` selects the CPU
            execution provider; any other value selects the CUDA one.

    Returns:
        ort.InferenceSession: The created ONNXRuntime session.
    """
    if device == 'cpu':
        execution_providers = ['CPUExecutionProvider']
    else:
        execution_providers = ['CUDAExecutionProvider']
    return ort.InferenceSession(
        path_or_bytes=onnx_file, providers=execution_providers)
def inference(sess: ort.InferenceSession, img: np.ndarray) -> List[np.ndarray]:
    """Run the PP_Mobileseg ONNX model on one preprocessed image.

    Args:
        sess (ort.InferenceSession): ONNXRuntime session.
        img (np.ndarray): Preprocessed image in channel-last (H, W, C)
            layout.

    Returns:
        outputs (List[np.ndarray]): Every output of the session, in the
        order reported by ``sess.get_outputs()``.
    """
    # build input: HWC -> CHW, prepend the batch axis and hand onnxruntime a
    # contiguous float32 array rather than relying on list conversion.
    input_img = np.ascontiguousarray(
        img.transpose(2, 0, 1)[np.newaxis], dtype=np.float32)
    sess_input = {sess.get_inputs()[0].name: input_img}
    # request all outputs the model declares
    sess_output = [out.name for out in sess.get_outputs()]
    # inference
    outputs = sess.run(output_names=sess_output, input_feed=sess_input)
    return outputs
def postprocess(outputs: List[np.ndarray],
                origin_shape: Tuple[int, int]) -> np.ndarray:
    """Postprocess outputs of PP_Mobileseg model.

    Args:
        outputs (List[np.ndarray]): Outputs of PP_Mobileseg model. Only
            ``outputs[0][0][0]`` is used as the 2-D label map.
        origin_shape (Tuple[int, int]): Size to resize the label map to,
            passed to ``cv2.resize`` as ``dsize`` (i.e. ``(width, height)``
            of the original image).

    Returns:
        seg_map (np.ndarray): Segmentation map as float32, resized to
        ``origin_shape``.
    """
    seg_map = outputs[0][0][0]
    # Class ids are categorical: nearest-neighbour keeps every pixel a valid
    # label, whereas the default bilinear interpolation would blend
    # neighbouring ids into meaningless intermediate values.
    seg_map = cv2.resize(
        seg_map.astype(np.float32),
        origin_shape,
        interpolation=cv2.INTER_NEAREST)
    return seg_map
def visualize(img: np.ndarray,
              seg_map: np.ndarray,
              filename: str = 'output.jpg',
              opacity: float = 0.8) -> np.ndarray:
    """Blend a colorized segmentation map over an image and save it.

    Args:
        img (np.ndarray): Image to draw on, in BGR channel order and with
            the same height and width as ``seg_map``.
        seg_map (np.ndarray): 2-D map of class labels; each label indexes
            into ``PALETTE``. Pixels whose value matches no palette index
            stay black in the overlay.
        filename (str): Path the blended image is written to.
        opacity (float): Weight of the color overlay, in ``[0, 1]``.

    Returns:
        np.ndarray: The blended image as uint8.
    """
    assert 0.0 <= opacity <= 1.0, 'opacity should be in range [0, 1]'
    palette = np.array(PALETTE)
    color_seg = np.zeros((seg_map.shape[0], seg_map.shape[1], 3),
                         dtype=np.uint8)
    for label, color in enumerate(palette):
        color_seg[seg_map == label, :] = color
    # convert to BGR
    color_seg = color_seg[..., ::-1]
    blended = img * (1 - opacity) + color_seg * opacity
    # The weighted sum is floating point; convert it back to an 8-bit image
    # explicitly so both the saved file and the returned array are ordinary
    # uint8 images.
    blended = np.clip(np.rint(blended), 0, 255).astype(np.uint8)
    cv2.imwrite(filename, blended)
    return blended
def main():
    """Run the demo end to end: read, preprocess, infer, postprocess, save.

    Raises:
        OSError: If the input image cannot be read.
        ValueError: If ``--input-size`` does not hold exactly two values.
    """
    args = parse_args()
    logger.info('Start running model inference...')
    # read image from file
    logger.info(f'1. Read image from file {args.image_file}...')
    img = cv2.imread(args.image_file)
    if img is None:
        # cv2.imread signals a missing or undecodable file by returning None
        raise OSError(f'Failed to read image file: {args.image_file}')
    # build onnx model
    logger.info(f'2. Build onnx model from {args.onnx_file}...')
    sess = build_session(args.onnx_file, args.device)
    # preprocess
    logger.info('3. Preprocess image...')
    model_input_size = tuple(args.input_size)
    if len(model_input_size) != 2:
        raise ValueError(
            f'--input-size expects exactly 2 values, got {model_input_size}')
    resized_img, origin_shape = preprocess(img, model_input_size)
    # inference
    logger.info('4. Inference...')
    # perf_counter is monotonic and high-resolution, suited to intervals
    start = time.perf_counter()
    outputs = inference(sess, resized_img)
    logger.info(f'Inference time: {time.perf_counter() - start:.4f}s')
    # postprocess
    logger.info('5. Postprocess...')
    h, w = origin_shape
    # cv2.resize expects (width, height), hence the swap
    seg_map = postprocess(outputs, (w, h))
    # visualize
    logger.info('6. Visualize...')
    visualize(img, seg_map, args.save_path)
    logger.info('Done...')
# Color table used by visualize(): entry i is the color painted for class
# label i. There are 150 entries, each treated as an RGB triple (visualize()
# flips the channels to BGR before blending).
# NOTE(review): presumably the 150-class ADE20K palette -- confirm against
# the dataset the model was trained on.
PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
[11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
[0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
[255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
[0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
[173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
[255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
[255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
[255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
[0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
[0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
[143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
[8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
[92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
[163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
[255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
[255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
[10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
[255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
[102, 255, 0], [92, 0, 255]]
if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main()