#!/usr/bin/env python
# Mirror of https://github.com/WongKinYiu/yolov7.git

import argparse
import sys

import cv2
import numpy as np

import tritonclient.grpc as grpcclient
from tritonclient.utils import InferenceServerException

from processing import preprocess, postprocess
from render import render_box, render_filled_box, get_text_size, render_text, RAND_COLORS
from labels import COCOLabels

INPUT_NAMES = ["images"]
OUTPUT_NAMES = ["num_dets", "det_boxes", "det_scores", "det_classes"]
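
# The four output tensors correspond to the NMS outputs of the exported
# YOLOv7 model. Typical shapes for batch size 1 (assuming the default export
# with a top-k of 100 detections; verify against your own model):
#   num_dets:    [1, 1]       int32
#   det_boxes:   [1, 100, 4]  float32, xyxy in network input coordinates
#   det_scores:  [1, 100]     float32
#   det_classes: [1, 100]     int32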

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('mode',
                        choices=['dummy', 'image', 'video'],
                        default='dummy',
                        help='Run mode. \'dummy\' will send an empty buffer to the server to test if inference works. \'image\' will process an image. \'video\' will process a video.')
    parser.add_argument('input',
                        type=str,
                        nargs='?',
                        help='Input file to load from in image or video mode')
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        required=False,
                        default='yolov7',
                        help='Inference model name, default yolov7')
    parser.add_argument('--width',
                        type=int,
                        required=False,
                        default=640,
                        help='Inference model input width, default 640')
    parser.add_argument('--height',
                        type=int,
                        required=False,
                        default=640,
                        help='Inference model input height, default 640')
    parser.add_argument('-u',
                        '--url',
                        type=str,
                        required=False,
                        default='localhost:8001',
                        help='Inference server URL, default localhost:8001')
    parser.add_argument('-o',
                        '--out',
                        type=str,
                        required=False,
                        default='',
                        help='Write output into file instead of displaying it')
    parser.add_argument('-f',
                        '--fps',
                        type=float,
                        required=False,
                        default=24.0,
                        help='Video output fps, default 24.0 FPS')
    parser.add_argument('-i',
                        '--model-info',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Print model status, configuration and statistics')
    parser.add_argument('-v',
                        '--verbose',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable verbose client output')
    parser.add_argument('-t',
                        '--client-timeout',
                        type=float,
                        required=False,
                        default=None,
                        help='Client timeout in seconds, default no timeout')
    parser.add_argument('-s',
                        '--ssl',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable SSL encrypted channel to the server')
    parser.add_argument('-r',
                        '--root-certificates',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded root certificates, default none')
    parser.add_argument('-p',
                        '--private-key',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded private key, default none')
    parser.add_argument('-x',
                        '--certificate-chain',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded certificate chain, default none')

    FLAGS = parser.parse_args()
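
    # Example invocations (assuming this script is saved as client.py and a
    # 'yolov7' model is deployed on the server):
    #   python client.py dummy
    #   python client.py image data/dog.jpg --out result.jpg
    #   python client.py video input.mp4 --out result.mp4 --fps 30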

    # Create server context
    try:
        triton_client = grpcclient.InferenceServerClient(
            url=FLAGS.url,
            verbose=FLAGS.verbose,
            ssl=FLAGS.ssl,
            root_certificates=FLAGS.root_certificates,
            private_key=FLAGS.private_key,
            certificate_chain=FLAGS.certificate_chain)
    except Exception as e:
        print("context creation failed: " + str(e))
        sys.exit(1)

    # Health check
    if not triton_client.is_server_live():
        print("FAILED: is_server_live")
        sys.exit(1)

    if not triton_client.is_server_ready():
        print("FAILED: is_server_ready")
        sys.exit(1)

    if not triton_client.is_model_ready(FLAGS.model):
        print("FAILED: is_model_ready")
        sys.exit(1)

    if FLAGS.model_info:
        # Model metadata
        try:
            metadata = triton_client.get_model_metadata(FLAGS.model)
            print(metadata)
        except InferenceServerException as ex:
            if "Request for unknown model" not in ex.message():
                print("FAILED: get_model_metadata")
                print("Got: {}".format(ex.message()))
                sys.exit(1)
            else:
                print("FAILED: get_model_metadata")
                sys.exit(1)

        # Model configuration
        try:
            config = triton_client.get_model_config(FLAGS.model)
            if not (config.config.name == FLAGS.model):
                print("FAILED: get_model_config")
                sys.exit(1)
            print(config)
        except InferenceServerException as ex:
            print("FAILED: get_model_config")
            print("Got: {}".format(ex.message()))
            sys.exit(1)
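
    # The --width/--height flags must match the deployed model's input. A
    # minimal sketch for reading them from the server instead (assumes a
    # single NCHW image input; field names follow the gRPC model metadata
    # response, so verify against your tritonclient version):
    #   meta = triton_client.get_model_metadata(FLAGS.model)
    #   batch, channels, height, width = meta.inputs[0].shape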

    # DUMMY MODE
    if FLAGS.mode == 'dummy':
        print("Running in 'dummy' mode")
        print("Creating dummy buffer filled with ones...")
        inputs = []
        outputs = []
        # NOTE: the tensor is declared as [1, 3, width, height]; with the
        # default square 640x640 input the order is interchangeable, but a
        # non-square model would expect NCHW, i.e. [1, 3, height, width].
        inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32"))
        inputs[0].set_data_from_numpy(np.ones(shape=(1, 3, FLAGS.width, FLAGS.height), dtype=np.float32))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3]))

        print("Invoking inference...")
        results = triton_client.infer(model_name=FLAGS.model,
                                      inputs=inputs,
                                      outputs=outputs,
                                      client_timeout=FLAGS.client_timeout)
        if FLAGS.model_info:
            statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
            if len(statistics.model_stats) != 1:
                print("FAILED: get_inference_statistics")
                sys.exit(1)
            print(statistics)
        print("Done")

        for output in OUTPUT_NAMES:
            result = results.as_numpy(output)
            print(f"Received result buffer \"{output}\" of size {result.shape}")
            print(f"Naive buffer sum: {np.sum(result)}")

    # IMAGE MODE
    if FLAGS.mode == 'image':
        print("Running in 'image' mode")
        if not FLAGS.input:
            print("FAILED: no input image")
            sys.exit(1)

        inputs = []
        outputs = []
        inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32"))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3]))

        print("Creating buffer from image file...")
        input_image = cv2.imread(str(FLAGS.input))
        if input_image is None:
            print(f"FAILED: could not load input image {str(FLAGS.input)}")
            sys.exit(1)
        input_image_buffer = preprocess(input_image, [FLAGS.width, FLAGS.height])
        input_image_buffer = np.expand_dims(input_image_buffer, axis=0)

        inputs[0].set_data_from_numpy(input_image_buffer)

        print("Invoking inference...")
        results = triton_client.infer(model_name=FLAGS.model,
                                      inputs=inputs,
                                      outputs=outputs,
                                      client_timeout=FLAGS.client_timeout)
        if FLAGS.model_info:
            statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
            if len(statistics.model_stats) != 1:
                print("FAILED: get_inference_statistics")
                sys.exit(1)
            print(statistics)
        print("Done")

        for output in OUTPUT_NAMES:
            result = results.as_numpy(output)
            print(f"Received result buffer \"{output}\" of size {result.shape}")
            print(f"Naive buffer sum: {np.sum(result)}")

        num_dets = results.as_numpy(OUTPUT_NAMES[0])
        det_boxes = results.as_numpy(OUTPUT_NAMES[1])
        det_scores = results.as_numpy(OUTPUT_NAMES[2])
        det_classes = results.as_numpy(OUTPUT_NAMES[3])
        detected_objects = postprocess(num_dets, det_boxes, det_scores, det_classes, input_image.shape[1], input_image.shape[0], [FLAGS.width, FLAGS.height])
        print(f"Detected objects: {len(detected_objects)}")

        for box in detected_objects:
            print(f"{COCOLabels(box.classID).name}: {box.confidence}")
            input_image = render_box(input_image, box.box(), color=tuple(RAND_COLORS[box.classID % 64].tolist()))
            size = get_text_size(input_image, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", normalised_scaling=0.6)
            input_image = render_filled_box(input_image, (box.x1 - 3, box.y1 - 3, box.x1 + size[0], box.y1 + size[1]), color=(220, 220, 220))
            input_image = render_text(input_image, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", (box.x1, box.y1), color=(30, 30, 30), normalised_scaling=0.5)

        if FLAGS.out:
            cv2.imwrite(FLAGS.out, input_image)
            print(f"Saved result to {FLAGS.out}")
        else:
            cv2.imshow('image', input_image)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
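
    # Note: triton_client.infer() is a blocking call, so the loop below
    # processes frames strictly one at a time. For higher throughput one
    # could switch to the client's async_infer() with a callback; that is a
    # design choice, not something this script requires.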

    # VIDEO MODE
    if FLAGS.mode == 'video':
        print("Running in 'video' mode")
        if not FLAGS.input:
            print("FAILED: no input video")
            sys.exit(1)

        inputs = []
        outputs = []
        inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32"))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3]))

        print("Opening input video stream...")
        cap = cv2.VideoCapture(FLAGS.input)
        if not cap.isOpened():
            print(f"FAILED: cannot open video {FLAGS.input}")
            sys.exit(1)

        counter = 0
        out = None
        print("Invoking inference...")
        while True:
            ret, frame = cap.read()
            if not ret:
                print("failed to fetch next frame")
                break

            if counter == 0 and FLAGS.out:
                print("Opening output video stream...")
                fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')
                out = cv2.VideoWriter(FLAGS.out, fourcc, FLAGS.fps, (frame.shape[1], frame.shape[0]))

            input_image_buffer = preprocess(frame, [FLAGS.width, FLAGS.height])
            input_image_buffer = np.expand_dims(input_image_buffer, axis=0)

            inputs[0].set_data_from_numpy(input_image_buffer)

            results = triton_client.infer(model_name=FLAGS.model,
                                          inputs=inputs,
                                          outputs=outputs,
                                          client_timeout=FLAGS.client_timeout)

            num_dets = results.as_numpy(OUTPUT_NAMES[0])
            det_boxes = results.as_numpy(OUTPUT_NAMES[1])
            det_scores = results.as_numpy(OUTPUT_NAMES[2])
            det_classes = results.as_numpy(OUTPUT_NAMES[3])
            detected_objects = postprocess(num_dets, det_boxes, det_scores, det_classes, frame.shape[1], frame.shape[0], [FLAGS.width, FLAGS.height])
            print(f"Frame {counter}: {len(detected_objects)} objects")
            counter += 1

            for box in detected_objects:
                print(f"{COCOLabels(box.classID).name}: {box.confidence}")
                frame = render_box(frame, box.box(), color=tuple(RAND_COLORS[box.classID % 64].tolist()))
                size = get_text_size(frame, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", normalised_scaling=0.6)
                frame = render_filled_box(frame, (box.x1 - 3, box.y1 - 3, box.x1 + size[0], box.y1 + size[1]), color=(220, 220, 220))
                frame = render_text(frame, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", (box.x1, box.y1), color=(30, 30, 30), normalised_scaling=0.5)

            if FLAGS.out:
                out.write(frame)
            else:
                cv2.imshow('image', frame)
                if cv2.waitKey(1) == ord('q'):
                    break

        if FLAGS.model_info:
            statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
            if len(statistics.model_stats) != 1:
                print("FAILED: get_inference_statistics")
                sys.exit(1)
            print(statistics)
        print("Done")

        cap.release()
        if FLAGS.out:
            out.release()
        else:
            cv2.destroyAllWindows()