Add Triton Inference Server deployment (#346)

* Add client code

* Add README.md

Co-authored-by: Philipp Schmidt <philipp.schmidt@isarsoft.com>
philipp-schmidt 2022-07-29 00:01:59 +02:00 committed by GitHub
parent a7c00297d5
commit 8eee99fcc5
8 changed files with 772 additions and 0 deletions

README.md
@@ -0,0 +1,161 @@
# YOLOv7 on Triton Inference Server
Instructions to deploy YOLOv7 as a TensorRT engine to [Triton Inference Server](https://github.com/NVIDIA/triton-inference-server).
Triton Inference Server takes care of model deployment and provides many benefits out of the box, such as gRPC and HTTP interfaces, automatic scheduling on multiple GPUs, shared memory (even on GPU), dynamic server-side batching, health metrics, and memory resource management.
No additional dependencies are needed to run this deployment, except a working Docker daemon with GPU support.
## Export TensorRT
See https://github.com/WongKinYiu/yolov7#export for more info.
```bash
# PyTorch YOLOv7 -> ONNX with grid, EfficientNMS plugin and dynamic batch size
python export.py --weights ./yolov7.pt --grid --end2end --dynamic-batch --simplify --topk-all 100 --iou-thres 0.65 --conf-thres 0.35 --img-size 640 640
# ONNX -> TensorRT with trtexec and docker
docker run -it --rm --gpus=all nvcr.io/nvidia/tensorrt:22.06-py3
# Copy onnx -> container: docker cp yolov7.onnx <container-id>:/workspace/
# Export with FP16 precision, min batch 1, opt batch 8 and max batch 8
./tensorrt/bin/trtexec --onnx=yolov7.onnx --minShapes=images:1x3x640x640 --optShapes=images:8x3x640x640 --maxShapes=images:8x3x640x640 --fp16 --workspace=4096 --saveEngine=yolov7-fp16-1x8x8.engine --timingCacheFile=timing.cache
# Test engine
./tensorrt/bin/trtexec --loadEngine=yolov7-fp16-1x8x8.engine
# Copy engine -> host: docker cp <container-id>:/workspace/yolov7-fp16-1x8x8.engine .
```
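Before building the engine, you can optionally sanity-check the exported graph with the `onnx` Python package (a small sketch, not part of this deployment; assumes `pip install onnx`):
```python
import onnx

model = onnx.load("yolov7.onnx")

# The export above should yield one input "images" with a dynamic batch dimension...
for inp in model.graph.input:
    dims = [d.dim_param or d.dim_value for d in inp.type.tensor_type.shape.dim]
    print("input:", inp.name, dims)

# ...and the four EfficientNMS outputs the Triton client reads later:
# num_dets, det_boxes, det_scores, det_classes
for out in model.graph.output:
    print("output:", out.name)
```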
Example output of the test on an RTX 3090:
```
[I] === Performance summary ===
[I] Throughput: 73.4985 qps
[I] Latency: min = 14.8578 ms, max = 15.8344 ms, mean = 15.07 ms, median = 15.0422 ms, percentile(99%) = 15.7443 ms
[I] End-to-End Host Latency: min = 25.8715 ms, max = 28.4102 ms, mean = 26.672 ms, median = 26.6082 ms, percentile(99%) = 27.8314 ms
[I] Enqueue Time: min = 0.793701 ms, max = 1.47144 ms, mean = 1.2008 ms, median = 1.28644 ms, percentile(99%) = 1.38965 ms
[I] H2D Latency: min = 1.50073 ms, max = 1.52454 ms, mean = 1.51225 ms, median = 1.51404 ms, percentile(99%) = 1.51941 ms
[I] GPU Compute Time: min = 13.3386 ms, max = 14.3186 ms, mean = 13.5448 ms, median = 13.5178 ms, percentile(99%) = 14.2151 ms
[I] D2H Latency: min = 0.00878906 ms, max = 0.0172729 ms, mean = 0.0128844 ms, median = 0.0125732 ms, percentile(99%) = 0.0166016 ms
[I] Total Host Walltime: 3.04768 s
[I] Total GPU Compute Time: 3.03404 s
[I] Explanations of the performance metrics are printed in the verbose logs.
```
Note: 73.5 qps x batch 8 = 588 fps @ ~15ms latency.
## Model Repository
See [Triton Model Repository Documentation](https://github.com/triton-inference-server/server/blob/main/docs/model_repository.md#model-repository) for more info.
```bash
# Create folder structure
mkdir -p triton-deploy/models/yolov7/1/
touch triton-deploy/models/yolov7/config.pbtxt
# Place model
mv yolov7-fp16-1x8x8.engine triton-deploy/models/yolov7/1/model.plan
```
## Model Configuration
See [Triton Model Configuration Documentation](https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md#model-configuration) for more info.
Minimal configuration for `triton-deploy/models/yolov7/config.pbtxt`:
```
name: "yolov7"
platform: "tensorrt_plan"
max_batch_size: 8
dynamic_batching { }
```
Example repository:
```bash
$ tree triton-deploy/
triton-deploy/
└── models
    └── yolov7
        ├── 1
        │   └── model.plan
        └── config.pbtxt

3 directories, 2 files
```
## Start Triton Inference Server
```
docker run --gpus all --rm --ipc=host --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 -p8000:8000 -p8001:8001 -p8002:8002 -v$(pwd)/triton-deploy/models:/models nvcr.io/nvidia/tritonserver:22.06-py3 tritonserver --model-repository=/models --strict-model-config=false --log-verbose 1
```
In the log you should see:
```
+--------+---------+--------+
| Model | Version | Status |
+--------+---------+--------+
| yolov7 | 1 | READY |
+--------+---------+--------+
```
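Instead of reading the log, you can also check readiness from Python with the same `tritonclient` gRPC API the example client below uses (minimal sketch; assumes `pip install tritonclient[grpc]` and the port mapping from the command above):
```python
import tritonclient.grpc as grpcclient

# gRPC endpoint published with -p8001:8001
triton_client = grpcclient.InferenceServerClient(url="localhost:8001")

print("server live: ", triton_client.is_server_live())
print("server ready:", triton_client.is_server_ready())
print("model ready: ", triton_client.is_model_ready("yolov7"))
```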
## Performance with Model Analyzer
See [Triton Model Analyzer Documentation](https://github.com/triton-inference-server/server/blob/main/docs/model_analyzer.md#model-analyzer) for more info.
Performance numbers @ RTX 3090 + AMD Ryzen 9 5950X
Example test with 16 concurrent clients using shared memory, each sending requests with batch size 1:
```bash
docker run -it --ipc=host --net=host nvcr.io/nvidia/tritonserver:22.06-py3-sdk /bin/bash
./install/bin/perf_analyzer -m yolov7 -u 127.0.0.1:8001 -i grpc --shared-memory system --concurrency-range 16
# Result (truncated)
Concurrency: 16, throughput: 590.119 infer/sec, latency 27080 usec
```
Throughput for 16 clients with batch size 1 is the same as for a single thread running the engine locally at batch size 16, thanks to Triton's [Dynamic Batching Strategy](https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md#dynamic-batcher). The result without dynamic batching (disabled in the model configuration) is considerably worse:
```bash
# Result (truncated)
Concurrency: 16, throughput: 335.587 infer/sec, latency 47616 usec
```
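You can reproduce the effect of dynamic batching without perf_analyzer by sending many single-image requests concurrently and letting the server form batches. A rough sketch using the asynchronous gRPC API (request count and shapes are illustrative only):
```python
import time
import numpy as np
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient(url="localhost:8001")

# One dummy 1x3x640x640 FP32 request, reused for every call
data = np.ones((1, 3, 640, 640), dtype=np.float32)
inp = grpcclient.InferInput("images", [1, 3, 640, 640], "FP32")
inp.set_data_from_numpy(data)

done = []
def on_complete(result, error):
    done.append(error or result)

n = 64
start = time.time()
for _ in range(n):
    client.async_infer("yolov7", [inp], on_complete)
while len(done) < n:  # crude wait for all callbacks
    time.sleep(0.01)
print(f"{n} requests finished in {time.time() - start:.2f} s")
```
With `dynamic_batching { }` enabled, the server groups these requests into larger batches, which is where the throughput gain above comes from.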
## How to run model in your code
An example client can be found in client.py. It can run on dummy input, images, and videos.
```bash
pip3 install tritonclient[all] opencv-python
python3 client.py image data/dog.jpg
```
![exemplary output result](data/dog_result.jpg)
```
$ python3 client.py --help
usage: client.py [-h] [-m MODEL] [--width WIDTH] [--height HEIGHT] [-u URL] [-o OUT] [-f FPS] [-i] [-v] [-t CLIENT_TIMEOUT] [-s] [-r ROOT_CERTIFICATES] [-p PRIVATE_KEY] [-x CERTIFICATE_CHAIN] {dummy,image,video} [input]

positional arguments:
  {dummy,image,video}   Run mode. 'dummy' will send an empty buffer to the server to test if inference works. 'image' will process an image. 'video' will process a video.
  input                 Input file to load from in image or video mode

optional arguments:
  -h, --help            show this help message and exit
  -m MODEL, --model MODEL
                        Inference model name, default yolov7
  --width WIDTH         Inference model input width, default 640
  --height HEIGHT       Inference model input height, default 640
  -u URL, --url URL     Inference server URL, default localhost:8001
  -o OUT, --out OUT     Write output into file instead of displaying it
  -f FPS, --fps FPS     Video output fps, default 24.0 FPS
  -i, --model-info      Print model status, configuration and statistics
  -v, --verbose         Enable verbose client output
  -t CLIENT_TIMEOUT, --client-timeout CLIENT_TIMEOUT
                        Client timeout in seconds, default no timeout
  -s, --ssl             Enable SSL encrypted channel to the server
  -r ROOT_CERTIFICATES, --root-certificates ROOT_CERTIFICATES
                        File holding PEM-encoded root certificates, default none
  -p PRIVATE_KEY, --private-key PRIVATE_KEY
                        File holding PEM-encoded private key, default is none
  -x CERTIFICATE_CHAIN, --certificate-chain CERTIFICATE_CHAIN
                        File holding PEM-encoded certificate chain, default is none
```
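For reference, the request flow in client.py boils down to a few calls. A condensed sketch of the image path (assumes processing.py from this deployment is importable and the server from above is running):
```python
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from processing import preprocess, postprocess

client = grpcclient.InferenceServerClient(url="localhost:8001")

image = cv2.imread("data/dog.jpg")
batch = np.expand_dims(preprocess(image, [640, 640]), axis=0)  # 1x3x640x640 FP32

inp = grpcclient.InferInput("images", [1, 3, 640, 640], "FP32")
inp.set_data_from_numpy(batch)
outs = [grpcclient.InferRequestedOutput(name) for name in
        ["num_dets", "det_boxes", "det_scores", "det_classes"]]

results = client.infer(model_name="yolov7", inputs=[inp], outputs=outs)
boxes = postprocess(results.as_numpy("num_dets"), results.as_numpy("det_boxes"),
                    results.as_numpy("det_scores"), results.as_numpy("det_classes"),
                    image.shape[1], image.shape[0], [640, 640])
for box in boxes:
    print(box.classID, box.confidence, box.box())
```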

boundingbox.py
@@ -0,0 +1,33 @@
class BoundingBox:
    def __init__(self, classID, confidence, x1, x2, y1, y2, image_width, image_height):
        self.classID = classID
        self.confidence = confidence
        self.x1 = x1
        self.x2 = x2
        self.y1 = y1
        self.y2 = y2
        # Normalized coordinates in [0, 1] relative to the original image size
        self.u1 = x1 / image_width
        self.u2 = x2 / image_width
        self.v1 = y1 / image_height
        self.v2 = y2 / image_height

    def box(self):
        return (self.x1, self.y1, self.x2, self.y2)

    def width(self):
        return self.x2 - self.x1

    def height(self):
        return self.y2 - self.y1

    def center_absolute(self):
        return (0.5 * (self.x1 + self.x2), 0.5 * (self.y1 + self.y2))

    def center_normalized(self):
        return (0.5 * (self.u1 + self.u2), 0.5 * (self.v1 + self.v2))

    def size_absolute(self):
        return (self.x2 - self.x1, self.y2 - self.y1)

    def size_normalized(self):
        return (self.u2 - self.u1, self.v2 - self.v1)
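
A quick illustration of how the client uses this helper (values are made up):
```python
# A detection of class 16 (DOG) in a 1280x720 image
box = BoundingBox(classID=16, confidence=0.9, x1=100, x2=200, y1=300, y2=500,
                  image_width=1280, image_height=720)

print(box.box())                # (100, 300, 200, 500)
print(box.size_absolute())      # (100, 200)
print(box.center_normalized())  # (~0.117, ~0.556)
```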

client.py
@@ -0,0 +1,334 @@
#!/usr/bin/env python
import argparse
import numpy as np
import sys
import cv2
import tritonclient.grpc as grpcclient
from tritonclient.utils import InferenceServerException
from processing import preprocess, postprocess
from render import render_box, render_filled_box, get_text_size, render_text, RAND_COLORS
from labels import COCOLabels
INPUT_NAMES = ["images"]
OUTPUT_NAMES = ["num_dets", "det_boxes", "det_scores", "det_classes"]

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('mode',
                        choices=['dummy', 'image', 'video'],
                        default='dummy',
                        help='Run mode. \'dummy\' will send an empty buffer to the server to test if inference works. \'image\' will process an image. \'video\' will process a video.')
    parser.add_argument('input',
                        type=str,
                        nargs='?',
                        help='Input file to load from in image or video mode')
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        required=False,
                        default='yolov7',
                        help='Inference model name, default yolov7')
    parser.add_argument('--width',
                        type=int,
                        required=False,
                        default=640,
                        help='Inference model input width, default 640')
    parser.add_argument('--height',
                        type=int,
                        required=False,
                        default=640,
                        help='Inference model input height, default 640')
    parser.add_argument('-u',
                        '--url',
                        type=str,
                        required=False,
                        default='localhost:8001',
                        help='Inference server URL, default localhost:8001')
    parser.add_argument('-o',
                        '--out',
                        type=str,
                        required=False,
                        default='',
                        help='Write output into file instead of displaying it')
    parser.add_argument('-f',
                        '--fps',
                        type=float,
                        required=False,
                        default=24.0,
                        help='Video output fps, default 24.0 FPS')
    parser.add_argument('-i',
                        '--model-info',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Print model status, configuration and statistics')
    parser.add_argument('-v',
                        '--verbose',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable verbose client output')
    parser.add_argument('-t',
                        '--client-timeout',
                        type=float,
                        required=False,
                        default=None,
                        help='Client timeout in seconds, default no timeout')
    parser.add_argument('-s',
                        '--ssl',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable SSL encrypted channel to the server')
    parser.add_argument('-r',
                        '--root-certificates',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded root certificates, default none')
    parser.add_argument('-p',
                        '--private-key',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded private key, default is none')
    parser.add_argument('-x',
                        '--certificate-chain',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded certificate chain, default is none')
    FLAGS = parser.parse_args()
    # Create server context
    try:
        triton_client = grpcclient.InferenceServerClient(
            url=FLAGS.url,
            verbose=FLAGS.verbose,
            ssl=FLAGS.ssl,
            root_certificates=FLAGS.root_certificates,
            private_key=FLAGS.private_key,
            certificate_chain=FLAGS.certificate_chain)
    except Exception as e:
        print("context creation failed: " + str(e))
        sys.exit()

    # Health check
    if not triton_client.is_server_live():
        print("FAILED : is_server_live")
        sys.exit(1)

    if not triton_client.is_server_ready():
        print("FAILED : is_server_ready")
        sys.exit(1)

    if not triton_client.is_model_ready(FLAGS.model):
        print("FAILED : is_model_ready")
        sys.exit(1)

    if FLAGS.model_info:
        # Model metadata
        try:
            metadata = triton_client.get_model_metadata(FLAGS.model)
            print(metadata)
        except InferenceServerException as ex:
            if "Request for unknown model" not in ex.message():
                print("FAILED : get_model_metadata")
                print("Got: {}".format(ex.message()))
                sys.exit(1)
            else:
                print("FAILED : get_model_metadata")
                sys.exit(1)

        # Model configuration
        try:
            config = triton_client.get_model_config(FLAGS.model)
            if not (config.config.name == FLAGS.model):
                print("FAILED: get_model_config")
                sys.exit(1)
            print(config)
        except InferenceServerException as ex:
            print("FAILED : get_model_config")
            print("Got: {}".format(ex.message()))
            sys.exit(1)
    # DUMMY MODE
    if FLAGS.mode == 'dummy':
        print("Running in 'dummy' mode")
        print("Creating empty buffer filled with ones...")
        inputs = []
        outputs = []
        inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32"))
        inputs[0].set_data_from_numpy(np.ones(shape=(1, 3, FLAGS.width, FLAGS.height), dtype=np.float32))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3]))

        print("Invoking inference...")
        results = triton_client.infer(model_name=FLAGS.model,
                                      inputs=inputs,
                                      outputs=outputs,
                                      client_timeout=FLAGS.client_timeout)
        if FLAGS.model_info:
            statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
            if len(statistics.model_stats) != 1:
                print("FAILED: get_inference_statistics")
                sys.exit(1)
            print(statistics)
        print("Done")

        for output in OUTPUT_NAMES:
            result = results.as_numpy(output)
            print(f"Received result buffer \"{output}\" of size {result.shape}")
            print(f"Naive buffer sum: {np.sum(result)}")
    # IMAGE MODE
    if FLAGS.mode == 'image':
        print("Running in 'image' mode")
        if not FLAGS.input:
            print("FAILED: no input image")
            sys.exit(1)

        inputs = []
        outputs = []
        inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32"))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3]))

        print("Creating buffer from image file...")
        input_image = cv2.imread(str(FLAGS.input))
        if input_image is None:
            print(f"FAILED: could not load input image {str(FLAGS.input)}")
            sys.exit(1)
        input_image_buffer = preprocess(input_image, [FLAGS.width, FLAGS.height])
        input_image_buffer = np.expand_dims(input_image_buffer, axis=0)
        inputs[0].set_data_from_numpy(input_image_buffer)

        print("Invoking inference...")
        results = triton_client.infer(model_name=FLAGS.model,
                                      inputs=inputs,
                                      outputs=outputs,
                                      client_timeout=FLAGS.client_timeout)
        if FLAGS.model_info:
            statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
            if len(statistics.model_stats) != 1:
                print("FAILED: get_inference_statistics")
                sys.exit(1)
            print(statistics)
        print("Done")

        for output in OUTPUT_NAMES:
            result = results.as_numpy(output)
            print(f"Received result buffer \"{output}\" of size {result.shape}")
            print(f"Naive buffer sum: {np.sum(result)}")

        num_dets = results.as_numpy(OUTPUT_NAMES[0])
        det_boxes = results.as_numpy(OUTPUT_NAMES[1])
        det_scores = results.as_numpy(OUTPUT_NAMES[2])
        det_classes = results.as_numpy(OUTPUT_NAMES[3])
        detected_objects = postprocess(num_dets, det_boxes, det_scores, det_classes, input_image.shape[1], input_image.shape[0], [FLAGS.width, FLAGS.height])
        print(f"Detected objects: {len(detected_objects)}")

        for box in detected_objects:
            print(f"{COCOLabels(box.classID).name}: {box.confidence}")
            input_image = render_box(input_image, box.box(), color=tuple(RAND_COLORS[box.classID % 64].tolist()))
            size = get_text_size(input_image, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", normalised_scaling=0.6)
            input_image = render_filled_box(input_image, (box.x1 - 3, box.y1 - 3, box.x1 + size[0], box.y1 + size[1]), color=(220, 220, 220))
            input_image = render_text(input_image, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", (box.x1, box.y1), color=(30, 30, 30), normalised_scaling=0.5)

        if FLAGS.out:
            cv2.imwrite(FLAGS.out, input_image)
            print(f"Saved result to {FLAGS.out}")
        else:
            cv2.imshow('image', input_image)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
    # VIDEO MODE
    if FLAGS.mode == 'video':
        print("Running in 'video' mode")
        if not FLAGS.input:
            print("FAILED: no input video")
            sys.exit(1)

        inputs = []
        outputs = []
        inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32"))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3]))

        print("Opening input video stream...")
        cap = cv2.VideoCapture(FLAGS.input)
        if not cap.isOpened():
            print(f"FAILED: cannot open video {FLAGS.input}")
            sys.exit(1)

        counter = 0
        out = None
        print("Invoking inference...")
        while True:
            ret, frame = cap.read()
            if not ret:
                print("failed to fetch next frame")
                break

            if counter == 0 and FLAGS.out:
                print("Opening output video stream...")
                fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')
                out = cv2.VideoWriter(FLAGS.out, fourcc, FLAGS.fps, (frame.shape[1], frame.shape[0]))

            input_image_buffer = preprocess(frame, [FLAGS.width, FLAGS.height])
            input_image_buffer = np.expand_dims(input_image_buffer, axis=0)
            inputs[0].set_data_from_numpy(input_image_buffer)

            results = triton_client.infer(model_name=FLAGS.model,
                                          inputs=inputs,
                                          outputs=outputs,
                                          client_timeout=FLAGS.client_timeout)

            num_dets = results.as_numpy("num_dets")
            det_boxes = results.as_numpy("det_boxes")
            det_scores = results.as_numpy("det_scores")
            det_classes = results.as_numpy("det_classes")
            detected_objects = postprocess(num_dets, det_boxes, det_scores, det_classes, frame.shape[1], frame.shape[0], [FLAGS.width, FLAGS.height])
            print(f"Frame {counter}: {len(detected_objects)} objects")
            counter += 1

            for box in detected_objects:
                print(f"{COCOLabels(box.classID).name}: {box.confidence}")
                frame = render_box(frame, box.box(), color=tuple(RAND_COLORS[box.classID % 64].tolist()))
                size = get_text_size(frame, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", normalised_scaling=0.6)
                frame = render_filled_box(frame, (box.x1 - 3, box.y1 - 3, box.x1 + size[0], box.y1 + size[1]), color=(220, 220, 220))
                frame = render_text(frame, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", (box.x1, box.y1), color=(30, 30, 30), normalised_scaling=0.5)

            if FLAGS.out:
                out.write(frame)
            else:
                cv2.imshow('image', frame)
                if cv2.waitKey(1) == ord('q'):
                    break

        if FLAGS.model_info:
            statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
            if len(statistics.model_stats) != 1:
                print("FAILED: get_inference_statistics")
                sys.exit(1)
            print(statistics)
        print("Done")

        cap.release()
        if FLAGS.out:
            out.release()
        else:
            cv2.destroyAllWindows()

Binary image files added (data/dog.jpg and data/dog_result.jpg, 160 KiB and 180 KiB), not shown.

labels.py
@@ -0,0 +1,83 @@
from enum import Enum


class COCOLabels(Enum):
    PERSON = 0
    BICYCLE = 1
    CAR = 2
    MOTORBIKE = 3
    AEROPLANE = 4
    BUS = 5
    TRAIN = 6
    TRUCK = 7
    BOAT = 8
    TRAFFIC_LIGHT = 9
    FIRE_HYDRANT = 10
    STOP_SIGN = 11
    PARKING_METER = 12
    BENCH = 13
    BIRD = 14
    CAT = 15
    DOG = 16
    HORSE = 17
    SHEEP = 18
    COW = 19
    ELEPHANT = 20
    BEAR = 21
    ZEBRA = 22
    GIRAFFE = 23
    BACKPACK = 24
    UMBRELLA = 25
    HANDBAG = 26
    TIE = 27
    SUITCASE = 28
    FRISBEE = 29
    SKIS = 30
    SNOWBOARD = 31
    SPORTS_BALL = 32
    KITE = 33
    BASEBALL_BAT = 34
    BASEBALL_GLOVE = 35
    SKATEBOARD = 36
    SURFBOARD = 37
    TENNIS_RACKET = 38
    BOTTLE = 39
    WINE_GLASS = 40
    CUP = 41
    FORK = 42
    KNIFE = 43
    SPOON = 44
    BOWL = 45
    BANANA = 46
    APPLE = 47
    SANDWICH = 48
    ORANGE = 49
    BROCCOLI = 50
    CARROT = 51
    HOT_DOG = 52
    PIZZA = 53
    DONUT = 54
    CAKE = 55
    CHAIR = 56
    SOFA = 57
    POTTEDPLANT = 58
    BED = 59
    DININGTABLE = 60
    TOILET = 61
    TVMONITOR = 62
    LAPTOP = 63
    MOUSE = 64
    REMOTE = 65
    KEYBOARD = 66
    CELL_PHONE = 67
    MICROWAVE = 68
    OVEN = 69
    TOASTER = 70
    SINK = 71
    REFRIGERATOR = 72
    BOOK = 73
    CLOCK = 74
    VASE = 75
    SCISSORS = 76
    TEDDY_BEAR = 77
    HAIR_DRIER = 78
    TOOTHBRUSH = 79

processing.py
@@ -0,0 +1,51 @@
from boundingbox import BoundingBox

import cv2
import numpy as np


def preprocess(img, input_shape, letter_box=True):
    if letter_box:
        img_h, img_w, _ = img.shape
        new_h, new_w = input_shape[0], input_shape[1]
        offset_h, offset_w = 0, 0
        if (new_w / img_w) <= (new_h / img_h):
            new_h = int(img_h * new_w / img_w)
            offset_h = (input_shape[0] - new_h) // 2
        else:
            new_w = int(img_w * new_h / img_h)
            offset_w = (input_shape[1] - new_w) // 2
        resized = cv2.resize(img, (new_w, new_h))
        # Pad the resized image into a grey canvas of the network input size
        img = np.full((input_shape[0], input_shape[1], 3), 127, dtype=np.uint8)
        img[offset_h:(offset_h + new_h), offset_w:(offset_w + new_w), :] = resized
    else:
        img = cv2.resize(img, (input_shape[1], input_shape[0]))

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.transpose((2, 0, 1)).astype(np.float32)
    img /= 255.0
    return img


def postprocess(num_dets, det_boxes, det_scores, det_classes, img_w, img_h, input_shape, letter_box=True):
    boxes = det_boxes[0, :num_dets[0][0]] / np.array([input_shape[0], input_shape[1], input_shape[0], input_shape[1]], dtype=np.float32)
    scores = det_scores[0, :num_dets[0][0]]
    classes = det_classes[0, :num_dets[0][0]].astype(int)

    # Undo the letterbox transform to map boxes back to original image coordinates
    old_h, old_w = img_h, img_w
    offset_h, offset_w = 0, 0
    if letter_box:
        if (img_w / input_shape[1]) >= (img_h / input_shape[0]):
            old_h = int(input_shape[0] * img_w / input_shape[1])
            offset_h = (old_h - img_h) // 2
        else:
            old_w = int(input_shape[1] * img_h / input_shape[0])
            offset_w = (old_w - img_w) // 2

    boxes = boxes * np.array([old_w, old_h, old_w, old_h], dtype=np.float32)
    if letter_box:
        boxes -= np.array([offset_w, offset_h, offset_w, offset_h], dtype=np.float32)
    boxes = boxes.astype(int)

    detected_objects = []
    for box, score, label in zip(boxes, scores, classes):
        detected_objects.append(BoundingBox(label, score, box[0], box[2], box[1], box[3], img_w, img_h))
    return detected_objects
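
A small sanity check of the letterbox preprocessing (assumes it is run next to processing.py and boundingbox.py):
```python
import numpy as np
from processing import preprocess

# Dummy 720p BGR frame; preprocess letterboxes it into a 640x640 CHW float tensor
frame = np.random.randint(0, 255, (720, 1280, 3), dtype=np.uint8)
tensor = preprocess(frame, [640, 640])

print(tensor.shape, tensor.dtype)  # (3, 640, 640) float32
print(tensor.min(), tensor.max())  # values scaled into [0, 1]
```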

render.py
@@ -0,0 +1,110 @@
import numpy as np
import cv2
from math import sqrt

_LINE_THICKNESS_SCALING = 500.0

np.random.seed(0)
RAND_COLORS = np.random.randint(50, 255, (64, 3), "int")  # used for class visu
RAND_COLORS[0] = [220, 220, 220]


def render_box(img, box, color=(200, 200, 200)):
    """
    Render a box. Calculates scaling and thickness automatically.
    :param img: image to render into
    :param box: (x1, y1, x2, y2) - box coordinates
    :param color: (b, g, r) - box color
    :return: updated image
    """
    x1, y1, x2, y2 = box
    thickness = int(
        round(
            (img.shape[0] * img.shape[1])
            / (_LINE_THICKNESS_SCALING * _LINE_THICKNESS_SCALING)
        )
    )
    thickness = max(1, thickness)
    img = cv2.rectangle(
        img,
        (int(x1), int(y1)),
        (int(x2), int(y2)),
        color,
        thickness=thickness
    )
    return img


def render_filled_box(img, box, color=(200, 200, 200)):
    """
    Render a filled box. Calculates scaling and thickness automatically.
    :param img: image to render into
    :param box: (x1, y1, x2, y2) - box coordinates
    :param color: (b, g, r) - box color
    :return: updated image
    """
    x1, y1, x2, y2 = box
    img = cv2.rectangle(
        img,
        (int(x1), int(y1)),
        (int(x2), int(y2)),
        color,
        thickness=cv2.FILLED
    )
    return img


_TEXT_THICKNESS_SCALING = 700.0
_TEXT_SCALING = 520.0


def get_text_size(img, text, normalised_scaling=1.0):
    """
    Get calculated text size (as box width and height)
    :param img: image reference, used to determine appropriate text scaling
    :param text: text to display
    :param normalised_scaling: additional normalised scaling. Default 1.0.
    :return: (width, height) - width and height of text box
    """
    thickness = int(
        round(
            (img.shape[0] * img.shape[1])
            / (_TEXT_THICKNESS_SCALING * _TEXT_THICKNESS_SCALING)
        )
        * normalised_scaling
    )
    thickness = max(1, thickness)
    scaling = img.shape[0] / _TEXT_SCALING * normalised_scaling
    return cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, scaling, thickness)[0]


def render_text(img, text, pos, color=(200, 200, 200), normalised_scaling=1.0):
    """
    Render a text into the image. Calculates scaling and thickness automatically.
    :param img: image to render into
    :param text: text to display
    :param pos: (x, y) - upper left coordinates of render position
    :param color: (b, g, r) - text color
    :param normalised_scaling: additional normalised scaling. Default 1.0.
    :return: updated image
    """
    x, y = pos
    thickness = int(
        round(
            (img.shape[0] * img.shape[1])
            / (_TEXT_THICKNESS_SCALING * _TEXT_THICKNESS_SCALING)
        )
        * normalised_scaling
    )
    thickness = max(1, thickness)
    scaling = img.shape[0] / _TEXT_SCALING * normalised_scaling
    size = get_text_size(img, text, normalised_scaling)
    cv2.putText(
        img,
        text,
        (int(x), int(y + size[1])),
        cv2.FONT_HERSHEY_SIMPLEX,
        scaling,
        color,
        thickness=thickness,
    )
    return img