diff --git a/deploy/triton-inference-server/README.md b/deploy/triton-inference-server/README.md
new file mode 100644
index 0000000..22b01e3
--- /dev/null
+++ b/deploy/triton-inference-server/README.md
@@ -0,0 +1,161 @@
+# YOLOv7 on Triton Inference Server
+
+Instructions to deploy YOLOv7 as a TensorRT engine to [Triton Inference Server](https://github.com/NVIDIA/triton-inference-server).
+
+Triton Inference Server takes care of model deployment with many out-of-the-box benefits, like a GRPC and HTTP interface, automatic scheduling on multiple GPUs, shared memory (even on GPU), dynamic server-side batching, health metrics and memory resource management.
+
+There are no additional dependencies needed to run this deployment, except a working docker daemon with GPU support.
+
+## Export TensorRT
+
+See https://github.com/WongKinYiu/yolov7#export for more info.
+
+```bash
+# Pytorch Yolov7 -> ONNX with grid, EfficientNMS plugin and dynamic batch size
+python export.py --weights ./yolov7.pt --grid --end2end --dynamic-batch --simplify --topk-all 100 --iou-thres 0.65 --conf-thres 0.35 --img-size 640 640
+# ONNX -> TensorRT with trtexec and docker
+docker run -it --rm --gpus=all nvcr.io/nvidia/tensorrt:22.06-py3
+# Copy onnx -> container: docker cp yolov7.onnx <container-id>:/workspace/
+# Export with FP16 precision, min batch 1, opt batch 8 and max batch 8
+./tensorrt/bin/trtexec --onnx=yolov7.onnx --minShapes=images:1x3x640x640 --optShapes=images:8x3x640x640 --maxShapes=images:8x3x640x640 --fp16 --workspace=4096 --saveEngine=yolov7-fp16-1x8x8.engine --timingCacheFile=timing.cache
+# Test engine
+./tensorrt/bin/trtexec --loadEngine=yolov7-fp16-1x8x8.engine
+# Copy engine -> host: docker cp <container-id>:/workspace/yolov7-fp16-1x8x8.engine .
+```
+
+Example output of the engine test on an RTX 3090:
+
+```
+[I] === Performance summary ===
+[I] Throughput: 73.4985 qps
+[I] Latency: min = 14.8578 ms, max = 15.8344 ms, mean = 15.07 ms, median = 15.0422 ms, percentile(99%) = 15.7443 ms
+[I] End-to-End Host Latency: min = 25.8715 ms, max = 28.4102 ms, mean = 26.672 ms, median = 26.6082 ms, percentile(99%) = 27.8314 ms
+[I] Enqueue Time: min = 0.793701 ms, max = 1.47144 ms, mean = 1.2008 ms, median = 1.28644 ms, percentile(99%) = 1.38965 ms
+[I] H2D Latency: min = 1.50073 ms, max = 1.52454 ms, mean = 1.51225 ms, median = 1.51404 ms, percentile(99%) = 1.51941 ms
+[I] GPU Compute Time: min = 13.3386 ms, max = 14.3186 ms, mean = 13.5448 ms, median = 13.5178 ms, percentile(99%) = 14.2151 ms
+[I] D2H Latency: min = 0.00878906 ms, max = 0.0172729 ms, mean = 0.0128844 ms, median = 0.0125732 ms, percentile(99%) = 0.0166016 ms
+[I] Total Host Walltime: 3.04768 s
+[I] Total GPU Compute Time: 3.03404 s
+[I] Explanations of the performance metrics are printed in the verbose logs.
+```
+Note: 73.5 qps x batch 8 = 588 fps @ ~15 ms latency.
+
+## Model Repository
+
+See [Triton Model Repository Documentation](https://github.com/triton-inference-server/server/blob/main/docs/model_repository.md#model-repository) for more info.
+
+```bash
+# Create folder structure
+mkdir -p triton-deploy/models/yolov7/1/
+touch triton-deploy/models/yolov7/config.pbtxt
+# Place model
+mv yolov7-fp16-1x8x8.engine triton-deploy/models/yolov7/1/model.plan
+```
+
+## Model Configuration
+
+See [Triton Model Configuration Documentation](https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md#model-configuration) for more info.
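+
+The serialized TensorRT engine (`model.plan`) already carries its input and output tensor metadata, so Triton can derive most settings automatically and the minimal configuration below is sufficient. For reference, an explicit variant is sketched here first; the tensor names and the fixed limit of 100 detections follow from the export flags above (`--topk-all 100`, `--img-size 640 640`), and the data types assume the usual EfficientNMS plugin outputs, so verify the exact dims and types against your engine before relying on this sketch:
+
+```
+name: "yolov7"
+platform: "tensorrt_plan"
+max_batch_size: 8
+input [
+  {
+    name: "images"
+    data_type: TYPE_FP32
+    dims: [ 3, 640, 640 ]
+  }
+]
+output [
+  {
+    name: "num_dets"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+  },
+  {
+    name: "det_boxes"
+    data_type: TYPE_FP32
+    dims: [ 100, 4 ]
+  },
+  {
+    name: "det_scores"
+    data_type: TYPE_FP32
+    dims: [ 100 ]
+  },
+  {
+    name: "det_classes"
+    data_type: TYPE_INT32
+    dims: [ 100 ]
+  }
+]
+dynamic_batching { }
+```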
+
+Minimal configuration for `triton-deploy/models/yolov7/config.pbtxt`:
+
+```
+name: "yolov7"
+platform: "tensorrt_plan"
+max_batch_size: 8
+dynamic_batching { }
+```
+
+Example repository:
+
+```bash
+$ tree triton-deploy/
+triton-deploy/
+└── models
+    └── yolov7
+        ├── 1
+        │   └── model.plan
+        └── config.pbtxt
+
+3 directories, 2 files
+```
+
+## Start Triton Inference Server
+
+```
+docker run --gpus all --rm --ipc=host --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 -p8000:8000 -p8001:8001 -p8002:8002 -v$(pwd)/triton-deploy/models:/models nvcr.io/nvidia/tritonserver:22.06-py3 tritonserver --model-repository=/models --strict-model-config=false --log-verbose 1
+```
+
+In the log you should see:
+
+```
++--------+---------+--------+
+| Model  | Version | Status |
++--------+---------+--------+
+| yolov7 | 1       | READY  |
++--------+---------+--------+
+```
+
+## Performance with Model Analyzer
+
+See [Triton Model Analyzer Documentation](https://github.com/triton-inference-server/server/blob/main/docs/model_analyzer.md#model-analyzer) for more info.
+
+Performance numbers @ RTX 3090 + AMD Ryzen 9 5950X
+
+Example test for 16 concurrent clients using shared memory, each sending requests with batch size 1:
+
+```bash
+docker run -it --ipc=host --net=host nvcr.io/nvidia/tritonserver:22.06-py3-sdk /bin/bash
+
+./install/bin/perf_analyzer -m yolov7 -u 127.0.0.1:8001 -i grpc --shared-memory system --concurrency-range 16
+
+# Result (truncated)
+Concurrency: 16, throughput: 590.119 infer/sec, latency 27080 usec
+```
+
+Throughput for 16 clients with batch size 1 matches a single thread running the engine locally at batch size 16, thanks to the Triton [Dynamic Batching Strategy](https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md#dynamic-batcher). The result without dynamic batching (disabled in the model configuration) is considerably worse:
+
+```bash
+# Result (truncated)
+Concurrency: 16, throughput: 335.587 infer/sec, latency 47616 usec
+```
+
+## How to run model in your code
+
+An example client can be found in client.py. It can run dummy input, images and videos.
+
+```bash
+pip3 install tritonclient[all] opencv-python
+python3 client.py image data/dog.jpg
+```
+
+
+```
+$ python3 client.py --help
+usage: client.py [-h] [-m MODEL] [--width WIDTH] [--height HEIGHT] [-u URL] [-o OUT] [-f FPS] [-i] [-v] [-t CLIENT_TIMEOUT] [-s] [-r ROOT_CERTIFICATES] [-p PRIVATE_KEY] [-x CERTIFICATE_CHAIN] {dummy,image,video} [input]
+
+positional arguments:
+  {dummy,image,video}   Run mode. 'dummy' will send an empty buffer to the server to test if inference works. 'image' will process an image. 'video' will process a video.
+  input                 Input file to load from in image or video mode
+
+optional arguments:
+  -h, --help            show this help message and exit
+  -m MODEL, --model MODEL
+                        Inference model name, default yolov7
+  --width WIDTH         Inference model input width, default 640
+  --height HEIGHT       Inference model input height, default 640
+  -u URL, --url URL     Inference server URL, default localhost:8001
+  -o OUT, --out OUT     Write output into file instead of displaying it
+  -f FPS, --fps FPS     Video output fps, default 24.0 FPS
+  -i, --model-info      Print model status, configuration and statistics
+  -v, --verbose         Enable verbose client output
+  -t CLIENT_TIMEOUT, --client-timeout CLIENT_TIMEOUT
+                        Client timeout in seconds, default no timeout
+  -s, --ssl             Enable SSL encrypted channel to the server
+  -r ROOT_CERTIFICATES, --root-certificates ROOT_CERTIFICATES
+                        File holding PEM-encoded root certificates, default none
+  -p PRIVATE_KEY, --private-key PRIVATE_KEY
+                        File holding PEM-encoded private key, default is none
+  -x CERTIFICATE_CHAIN, --certificate-chain CERTIFICATE_CHAIN
+                        File holding PEM-encoded certificate chain, default is none
+```
diff --git a/deploy/triton-inference-server/boundingbox.py b/deploy/triton-inference-server/boundingbox.py
new file mode 100644
index 0000000..8b95330
--- /dev/null
+++ b/deploy/triton-inference-server/boundingbox.py
@@ -0,0 +1,33 @@
+class BoundingBox:
+    def __init__(self, classID, confidence, x1, x2, y1, y2, image_width, image_height):
+        self.classID = classID
+        self.confidence = confidence
+        self.x1 = x1
+        self.x2 = x2
+        self.y1 = y1
+        self.y2 = y2
+        self.u1 = x1 / image_width
+        self.u2 = x2 / image_width
+        self.v1 = y1 / image_height
+        self.v2 = y2 / image_height
+
+    def box(self):
+        return (self.x1, self.y1, self.x2, self.y2)
+
+    def width(self):
+        return self.x2 - self.x1
+
+    def height(self):
+        return self.y2 - self.y1
+
+    def center_absolute(self):
+        return (0.5 * (self.x1 + self.x2), 0.5 * (self.y1 + self.y2))
+
+    def center_normalized(self):
+        return (0.5 * (self.u1 + self.u2), 0.5 * (self.v1 + self.v2))
+
+    def size_absolute(self):
+        return (self.x2 - self.x1, self.y2 - self.y1)
+
+    def size_normalized(self):
+        return (self.u2 - self.u1, self.v2 - self.v1)
diff --git a/deploy/triton-inference-server/client.py b/deploy/triton-inference-server/client.py
new file mode 100644
index 0000000..aedca11
--- /dev/null
+++ b/deploy/triton-inference-server/client.py
@@ -0,0 +1,334 @@
+#!/usr/bin/env python
+
+import argparse
+import numpy as np
+import sys
+import cv2
+
+import tritonclient.grpc as grpcclient
+from tritonclient.utils import InferenceServerException
+
+from processing import preprocess, postprocess
+from render import render_box, render_filled_box, get_text_size, render_text, RAND_COLORS
+from labels import COCOLabels
+
+INPUT_NAMES = ["images"]
+OUTPUT_NAMES = ["num_dets", "det_boxes", "det_scores", "det_classes"]
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('mode',
+                        choices=['dummy', 'image', 'video'],
+                        default='dummy',
+                        help='Run mode. \'dummy\' will send an empty buffer to the server to test if inference works. \'image\' will process an image. \'video\' will process a video.')
+    parser.add_argument('input',
+                        type=str,
+                        nargs='?',
+                        help='Input file to load from in image or video mode')
+    parser.add_argument('-m',
+                        '--model',
+                        type=str,
+                        required=False,
+                        default='yolov7',
+                        help='Inference model name, default yolov7')
+    parser.add_argument('--width',
+                        type=int,
+                        required=False,
+                        default=640,
+                        help='Inference model input width, default 640')
+    parser.add_argument('--height',
+                        type=int,
+                        required=False,
+                        default=640,
+                        help='Inference model input height, default 640')
+    parser.add_argument('-u',
+                        '--url',
+                        type=str,
+                        required=False,
+                        default='localhost:8001',
+                        help='Inference server URL, default localhost:8001')
+    parser.add_argument('-o',
+                        '--out',
+                        type=str,
+                        required=False,
+                        default='',
+                        help='Write output into file instead of displaying it')
+    parser.add_argument('-f',
+                        '--fps',
+                        type=float,
+                        required=False,
+                        default=24.0,
+                        help='Video output fps, default 24.0 FPS')
+    parser.add_argument('-i',
+                        '--model-info',
+                        action="store_true",
+                        required=False,
+                        default=False,
+                        help='Print model status, configuration and statistics')
+    parser.add_argument('-v',
+                        '--verbose',
+                        action="store_true",
+                        required=False,
+                        default=False,
+                        help='Enable verbose client output')
+    parser.add_argument('-t',
+                        '--client-timeout',
+                        type=float,
+                        required=False,
+                        default=None,
+                        help='Client timeout in seconds, default no timeout')
+    parser.add_argument('-s',
+                        '--ssl',
+                        action="store_true",
+                        required=False,
+                        default=False,
+                        help='Enable SSL encrypted channel to the server')
+    parser.add_argument('-r',
+                        '--root-certificates',
+                        type=str,
+                        required=False,
+                        default=None,
+                        help='File holding PEM-encoded root certificates, default none')
+    parser.add_argument('-p',
+                        '--private-key',
+                        type=str,
+                        required=False,
+                        default=None,
+                        help='File holding PEM-encoded private key, default is none')
+    parser.add_argument('-x',
+                        '--certificate-chain',
+                        type=str,
+                        required=False,
+                        default=None,
+                        help='File holding PEM-encoded certificate chain, default is none')
+
+    FLAGS = parser.parse_args()
+
+    # Create server context
+    try:
+        triton_client = grpcclient.InferenceServerClient(
+            url=FLAGS.url,
+            verbose=FLAGS.verbose,
+            ssl=FLAGS.ssl,
+            root_certificates=FLAGS.root_certificates,
+            private_key=FLAGS.private_key,
+            certificate_chain=FLAGS.certificate_chain)
+    except Exception as e:
+        print("context creation failed: " + str(e))
+        sys.exit()
+
+    # Health check
+    if not triton_client.is_server_live():
+        print("FAILED : is_server_live")
+        sys.exit(1)
+
+    if not triton_client.is_server_ready():
+        print("FAILED : is_server_ready")
+        sys.exit(1)
+
+    if not triton_client.is_model_ready(FLAGS.model):
+        print("FAILED : is_model_ready")
+        sys.exit(1)
+
+    if FLAGS.model_info:
+        # Model metadata
+        try:
+            metadata = triton_client.get_model_metadata(FLAGS.model)
+            print(metadata)
+        except InferenceServerException as ex:
+            if "Request for unknown model" not in ex.message():
+                print("FAILED : get_model_metadata")
+                print("Got: {}".format(ex.message()))
+                sys.exit(1)
+            else:
+                print("FAILED : get_model_metadata")
+                sys.exit(1)
+
+        # Model configuration
+        try:
+            config = triton_client.get_model_config(FLAGS.model)
+            if not (config.config.name == FLAGS.model):
+                print("FAILED: get_model_config")
+                sys.exit(1)
+            print(config)
+        except InferenceServerException as ex:
+            print("FAILED : get_model_config")
+            print("Got: {}".format(ex.message()))
+            sys.exit(1)
+
+    # DUMMY MODE
+    if FLAGS.mode == 'dummy':
+        print("Running in 'dummy' mode")
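+        # The exported engine expects a single FP32 input tensor named "images" in
+        # NCHW layout, so a [1, 3, width, height] buffer of ones is enough to
+        # exercise the full request/response path without real image data.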
print("Creating emtpy buffer filled with ones...") + inputs = [] + outputs = [] + inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32")) + inputs[0].set_data_from_numpy(np.ones(shape=(1, 3, FLAGS.width, FLAGS.height), dtype=np.float32)) + outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0])) + outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1])) + outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2])) + outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3])) + + print("Invoking inference...") + results = triton_client.infer(model_name=FLAGS.model, + inputs=inputs, + outputs=outputs, + client_timeout=FLAGS.client_timeout) + if FLAGS.model_info: + statistics = triton_client.get_inference_statistics(model_name=FLAGS.model) + if len(statistics.model_stats) != 1: + print("FAILED: get_inference_statistics") + sys.exit(1) + print(statistics) + print("Done") + + for output in OUTPUT_NAMES: + result = results.as_numpy(output) + print(f"Received result buffer \"{output}\" of size {result.shape}") + print(f"Naive buffer sum: {np.sum(result)}") + + # IMAGE MODE + if FLAGS.mode == 'image': + print("Running in 'image' mode") + if not FLAGS.input: + print("FAILED: no input image") + sys.exit(1) + + inputs = [] + outputs = [] + inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32")) + outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0])) + outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1])) + outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2])) + outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3])) + + print("Creating buffer from image file...") + input_image = cv2.imread(str(FLAGS.input)) + if input_image is None: + print(f"FAILED: could not load input image {str(FLAGS.input)}") + sys.exit(1) + input_image_buffer = preprocess(input_image, [FLAGS.width, FLAGS.height]) + input_image_buffer = np.expand_dims(input_image_buffer, axis=0) + + inputs[0].set_data_from_numpy(input_image_buffer) + + print("Invoking inference...") + results = triton_client.infer(model_name=FLAGS.model, + inputs=inputs, + outputs=outputs, + client_timeout=FLAGS.client_timeout) + if FLAGS.model_info: + statistics = triton_client.get_inference_statistics(model_name=FLAGS.model) + if len(statistics.model_stats) != 1: + print("FAILED: get_inference_statistics") + sys.exit(1) + print(statistics) + print("Done") + + for output in OUTPUT_NAMES: + result = results.as_numpy(output) + print(f"Received result buffer \"{output}\" of size {result.shape}") + print(f"Naive buffer sum: {np.sum(result)}") + + num_dets = results.as_numpy(OUTPUT_NAMES[0]) + det_boxes = results.as_numpy(OUTPUT_NAMES[1]) + det_scores = results.as_numpy(OUTPUT_NAMES[2]) + det_classes = results.as_numpy(OUTPUT_NAMES[3]) + detected_objects = postprocess(num_dets, det_boxes, det_scores, det_classes, input_image.shape[1], input_image.shape[0], [FLAGS.width, FLAGS.height]) + print(f"Detected objects: {len(detected_objects)}") + + for box in detected_objects: + print(f"{COCOLabels(box.classID).name}: {box.confidence}") + input_image = render_box(input_image, box.box(), color=tuple(RAND_COLORS[box.classID % 64].tolist())) + size = get_text_size(input_image, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", normalised_scaling=0.6) + input_image = render_filled_box(input_image, (box.x1 - 3, box.y1 - 3, box.x1 + size[0], box.y1 + size[1]), color=(220, 220, 220)) + input_image = render_text(input_image, 
f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", (box.x1, box.y1), color=(30, 30, 30), normalised_scaling=0.5) + + if FLAGS.out: + cv2.imwrite(FLAGS.out, input_image) + print(f"Saved result to {FLAGS.out}") + else: + cv2.imshow('image', input_image) + cv2.waitKey(0) + cv2.destroyAllWindows() + + # VIDEO MODE + if FLAGS.mode == 'video': + print("Running in 'video' mode") + if not FLAGS.input: + print("FAILED: no input video") + sys.exit(1) + + inputs = [] + outputs = [] + inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32")) + outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0])) + outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1])) + outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2])) + outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3])) + + print("Opening input video stream...") + cap = cv2.VideoCapture(FLAGS.input) + if not cap.isOpened(): + print(f"FAILED: cannot open video {FLAGS.input}") + sys.exit(1) + + counter = 0 + out = None + print("Invoking inference...") + while True: + ret, frame = cap.read() + if not ret: + print("failed to fetch next frame") + break + + if counter == 0 and FLAGS.out: + print("Opening output video stream...") + fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V') + out = cv2.VideoWriter(FLAGS.out, fourcc, FLAGS.fps, (frame.shape[1], frame.shape[0])) + + input_image_buffer = preprocess(frame, [FLAGS.width, FLAGS.height]) + input_image_buffer = np.expand_dims(input_image_buffer, axis=0) + + inputs[0].set_data_from_numpy(input_image_buffer) + + results = triton_client.infer(model_name=FLAGS.model, + inputs=inputs, + outputs=outputs, + client_timeout=FLAGS.client_timeout) + + num_dets = results.as_numpy("num_dets") + det_boxes = results.as_numpy("det_boxes") + det_scores = results.as_numpy("det_scores") + det_classes = results.as_numpy("det_classes") + detected_objects = postprocess(num_dets, det_boxes, det_scores, det_classes, frame.shape[1], frame.shape[0], [FLAGS.width, FLAGS.height]) + print(f"Frame {counter}: {len(detected_objects)} objects") + counter += 1 + + for box in detected_objects: + print(f"{COCOLabels(box.classID).name}: {box.confidence}") + frame = render_box(frame, box.box(), color=tuple(RAND_COLORS[box.classID % 64].tolist())) + size = get_text_size(frame, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", normalised_scaling=0.6) + frame = render_filled_box(frame, (box.x1 - 3, box.y1 - 3, box.x1 + size[0], box.y1 + size[1]), color=(220, 220, 220)) + frame = render_text(frame, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", (box.x1, box.y1), color=(30, 30, 30), normalised_scaling=0.5) + + if FLAGS.out: + out.write(frame) + else: + cv2.imshow('image', frame) + if cv2.waitKey(1) == ord('q'): + break + + if FLAGS.model_info: + statistics = triton_client.get_inference_statistics(model_name=FLAGS.model) + if len(statistics.model_stats) != 1: + print("FAILED: get_inference_statistics") + sys.exit(1) + print(statistics) + print("Done") + + cap.release() + if FLAGS.out: + out.release() + else: + cv2.destroyAllWindows() diff --git a/deploy/triton-inference-server/data/dog.jpg b/deploy/triton-inference-server/data/dog.jpg new file mode 100644 index 0000000..77b0381 Binary files /dev/null and b/deploy/triton-inference-server/data/dog.jpg differ diff --git a/deploy/triton-inference-server/data/dog_result.jpg b/deploy/triton-inference-server/data/dog_result.jpg new file mode 100644 index 0000000..6f380ef Binary files /dev/null and 
diff --git a/deploy/triton-inference-server/labels.py b/deploy/triton-inference-server/labels.py
new file mode 100644
index 0000000..ba6c5c5
--- /dev/null
+++ b/deploy/triton-inference-server/labels.py
@@ -0,0 +1,83 @@
+from enum import Enum
+
+class COCOLabels(Enum):
+    PERSON = 0
+    BICYCLE = 1
+    CAR = 2
+    MOTORBIKE = 3
+    AEROPLANE = 4
+    BUS = 5
+    TRAIN = 6
+    TRUCK = 7
+    BOAT = 8
+    TRAFFIC_LIGHT = 9
+    FIRE_HYDRANT = 10
+    STOP_SIGN = 11
+    PARKING_METER = 12
+    BENCH = 13
+    BIRD = 14
+    CAT = 15
+    DOG = 16
+    HORSE = 17
+    SHEEP = 18
+    COW = 19
+    ELEPHANT = 20
+    BEAR = 21
+    ZEBRA = 22
+    GIRAFFE = 23
+    BACKPACK = 24
+    UMBRELLA = 25
+    HANDBAG = 26
+    TIE = 27
+    SUITCASE = 28
+    FRISBEE = 29
+    SKIS = 30
+    SNOWBOARD = 31
+    SPORTS_BALL = 32
+    KITE = 33
+    BASEBALL_BAT = 34
+    BASEBALL_GLOVE = 35
+    SKATEBOARD = 36
+    SURFBOARD = 37
+    TENNIS_RACKET = 38
+    BOTTLE = 39
+    WINE_GLASS = 40
+    CUP = 41
+    FORK = 42
+    KNIFE = 43
+    SPOON = 44
+    BOWL = 45
+    BANANA = 46
+    APPLE = 47
+    SANDWICH = 48
+    ORANGE = 49
+    BROCCOLI = 50
+    CARROT = 51
+    HOT_DOG = 52
+    PIZZA = 53
+    DONUT = 54
+    CAKE = 55
+    CHAIR = 56
+    SOFA = 57
+    POTTEDPLANT = 58
+    BED = 59
+    DININGTABLE = 60
+    TOILET = 61
+    TVMONITOR = 62
+    LAPTOP = 63
+    MOUSE = 64
+    REMOTE = 65
+    KEYBOARD = 66
+    CELL_PHONE = 67
+    MICROWAVE = 68
+    OVEN = 69
+    TOASTER = 70
+    SINK = 71
+    REFRIGERATOR = 72
+    BOOK = 73
+    CLOCK = 74
+    VASE = 75
+    SCISSORS = 76
+    TEDDY_BEAR = 77
+    HAIR_DRIER = 78
+    TOOTHBRUSH = 79
diff --git a/deploy/triton-inference-server/processing.py b/deploy/triton-inference-server/processing.py
new file mode 100644
index 0000000..3d51c50
--- /dev/null
+++ b/deploy/triton-inference-server/processing.py
@@ -0,0 +1,51 @@
+from boundingbox import BoundingBox
+
+import cv2
+import numpy as np
+
+def preprocess(img, input_shape, letter_box=True):
+    if letter_box:
+        img_h, img_w, _ = img.shape
+        new_h, new_w = input_shape[0], input_shape[1]
+        offset_h, offset_w = 0, 0
+        if (new_w / img_w) <= (new_h / img_h):
+            new_h = int(img_h * new_w / img_w)
+            offset_h = (input_shape[0] - new_h) // 2
+        else:
+            new_w = int(img_w * new_h / img_h)
+            offset_w = (input_shape[1] - new_w) // 2
+        resized = cv2.resize(img, (new_w, new_h))
+        img = np.full((input_shape[0], input_shape[1], 3), 127, dtype=np.uint8)
+        img[offset_h:(offset_h + new_h), offset_w:(offset_w + new_w), :] = resized
+    else:
+        img = cv2.resize(img, (input_shape[1], input_shape[0]))
+
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    img = img.transpose((2, 0, 1)).astype(np.float32)
+    img /= 255.0
+    return img
+
+def postprocess(num_dets, det_boxes, det_scores, det_classes, img_w, img_h, input_shape, letter_box=True):
+    boxes = det_boxes[0, :num_dets[0][0]] / np.array([input_shape[0], input_shape[1], input_shape[0], input_shape[1]], dtype=np.float32)
+    scores = det_scores[0, :num_dets[0][0]]
+    classes = det_classes[0, :num_dets[0][0]].astype(int)
+
+    old_h, old_w = img_h, img_w
+    offset_h, offset_w = 0, 0
+    if letter_box:
+        if (img_w / input_shape[1]) >= (img_h / input_shape[0]):
+            old_h = int(input_shape[0] * img_w / input_shape[1])
+            offset_h = (old_h - img_h) // 2
+        else:
+            old_w = int(input_shape[1] * img_h / input_shape[0])
+            offset_w = (old_w - img_w) // 2
+
+    boxes = boxes * np.array([old_w, old_h, old_w, old_h], dtype=np.float32)
+    if letter_box:
+        boxes -= np.array([offset_w, offset_h, offset_w, offset_h], dtype=np.float32)
+    boxes = boxes.astype(int)
+
+    detected_objects = []
+    for box, score, label in zip(boxes, scores, classes):
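+        # det_boxes come back as (x1, y1, x2, y2); BoundingBox expects the x pair
+        # before the y pair and also keeps coordinates normalized by the image size.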
+        detected_objects.append(BoundingBox(label, score, box[0], box[2], box[1], box[3], img_w, img_h))
+    return detected_objects
diff --git a/deploy/triton-inference-server/render.py b/deploy/triton-inference-server/render.py
new file mode 100644
index 0000000..dea0401
--- /dev/null
+++ b/deploy/triton-inference-server/render.py
@@ -0,0 +1,110 @@
+import numpy as np
+
+import cv2
+
+from math import sqrt
+
+_LINE_THICKNESS_SCALING = 500.0
+
+np.random.seed(0)
+RAND_COLORS = np.random.randint(50, 255, (64, 3), "int")  # used for class visualization
+RAND_COLORS[0] = [220, 220, 220]
+
+def render_box(img, box, color=(200, 200, 200)):
+    """
+    Render a box. Calculates scaling and thickness automatically.
+    :param img: image to render into
+    :param box: (x1, y1, x2, y2) - box coordinates
+    :param color: (b, g, r) - box color
+    :return: updated image
+    """
+    x1, y1, x2, y2 = box
+    thickness = int(
+        round(
+            (img.shape[0] * img.shape[1])
+            / (_LINE_THICKNESS_SCALING * _LINE_THICKNESS_SCALING)
+        )
+    )
+    thickness = max(1, thickness)
+    img = cv2.rectangle(
+        img,
+        (int(x1), int(y1)),
+        (int(x2), int(y2)),
+        color,
+        thickness=thickness
+    )
+    return img
+
+def render_filled_box(img, box, color=(200, 200, 200)):
+    """
+    Render a filled box. Calculates scaling and thickness automatically.
+    :param img: image to render into
+    :param box: (x1, y1, x2, y2) - box coordinates
+    :param color: (b, g, r) - box color
+    :return: updated image
+    """
+    x1, y1, x2, y2 = box
+    img = cv2.rectangle(
+        img,
+        (int(x1), int(y1)),
+        (int(x2), int(y2)),
+        color,
+        thickness=cv2.FILLED
+    )
+    return img
+
+_TEXT_THICKNESS_SCALING = 700.0
+_TEXT_SCALING = 520.0
+
+
+def get_text_size(img, text, normalised_scaling=1.0):
+    """
+    Get calculated text size (as box width and height)
+    :param img: image reference, used to determine appropriate text scaling
+    :param text: text to display
+    :param normalised_scaling: additional normalised scaling. Default 1.0.
+    :return: (width, height) - width and height of text box
+    """
+    thickness = int(
+        round(
+            (img.shape[0] * img.shape[1])
+            / (_TEXT_THICKNESS_SCALING * _TEXT_THICKNESS_SCALING)
+        )
+        * normalised_scaling
+    )
+    thickness = max(1, thickness)
+    scaling = img.shape[0] / _TEXT_SCALING * normalised_scaling
+    return cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, scaling, thickness)[0]
+
+
+def render_text(img, text, pos, color=(200, 200, 200), normalised_scaling=1.0):
+    """
+    Render a text into the image. Calculates scaling and thickness automatically.
+    :param img: image to render into
+    :param text: text to display
+    :param pos: (x, y) - upper left coordinates of render position
+    :param color: (b, g, r) - text color
+    :param normalised_scaling: additional normalised scaling. Default 1.0.
+    :return: updated image
+    """
+    x, y = pos
+    thickness = int(
+        round(
+            (img.shape[0] * img.shape[1])
+            / (_TEXT_THICKNESS_SCALING * _TEXT_THICKNESS_SCALING)
+        )
+        * normalised_scaling
+    )
+    thickness = max(1, thickness)
+    scaling = img.shape[0] / _TEXT_SCALING * normalised_scaling
+    size = get_text_size(img, text, normalised_scaling)
+    cv2.putText(
+        img,
+        text,
+        (int(x), int(y + size[1])),
+        cv2.FONT_HERSHEY_SIMPLEX,
+        scaling,
+        color,
+        thickness=thickness,
+    )
+    return img