mirror of https://github.com/WongKinYiu/yolov7.git
Add Triton Inference Server deployment (#346)
* Add client code
* Add README.md

Co-authored-by: Philipp Schmidt <philipp.schmidt@isarsoft.com>

parent a7c00297d5
commit 8eee99fcc5

README.md
@ -0,0 +1,161 @@
# YOLOv7 on Triton Inference Server

Instructions to deploy YOLOv7 as a TensorRT engine to [Triton Inference Server](https://github.com/NVIDIA/triton-inference-server).

Triton Inference Server takes care of model deployment with many out-of-the-box benefits, such as a GRPC and HTTP interface, automatic scheduling on multiple GPUs, shared memory (even on GPU), dynamic server-side batching, health metrics and memory resource management.

No additional dependencies are needed to run this deployment, except a working docker daemon with GPU support.

## Export TensorRT

See https://github.com/WongKinYiu/yolov7#export for more info.

```bash
# PyTorch YOLOv7 -> ONNX with grid, EfficientNMS plugin and dynamic batch size
python export.py --weights ./yolov7.pt --grid --end2end --dynamic-batch --simplify --topk-all 100 --iou-thres 0.65 --conf-thres 0.35 --img-size 640 640

# ONNX -> TensorRT with trtexec and docker
docker run -it --rm --gpus=all nvcr.io/nvidia/tensorrt:22.06-py3

# Copy onnx -> container: docker cp yolov7.onnx <container-id>:/workspace/

# Export with FP16 precision, min batch 1, opt batch 8 and max batch 8
./tensorrt/bin/trtexec --onnx=yolov7.onnx --minShapes=images:1x3x640x640 --optShapes=images:8x3x640x640 --maxShapes=images:8x3x640x640 --fp16 --workspace=4096 --saveEngine=yolov7-fp16-1x8x8.engine --timingCacheFile=timing.cache

# Test engine
./tensorrt/bin/trtexec --loadEngine=yolov7-fp16-1x8x8.engine

# Copy engine -> host: docker cp <container-id>:/workspace/yolov7-fp16-1x8x8.engine .
```

Example output of the engine test on an RTX 3090:

```
[I] === Performance summary ===
[I] Throughput: 73.4985 qps
[I] Latency: min = 14.8578 ms, max = 15.8344 ms, mean = 15.07 ms, median = 15.0422 ms, percentile(99%) = 15.7443 ms
[I] End-to-End Host Latency: min = 25.8715 ms, max = 28.4102 ms, mean = 26.672 ms, median = 26.6082 ms, percentile(99%) = 27.8314 ms
[I] Enqueue Time: min = 0.793701 ms, max = 1.47144 ms, mean = 1.2008 ms, median = 1.28644 ms, percentile(99%) = 1.38965 ms
[I] H2D Latency: min = 1.50073 ms, max = 1.52454 ms, mean = 1.51225 ms, median = 1.51404 ms, percentile(99%) = 1.51941 ms
[I] GPU Compute Time: min = 13.3386 ms, max = 14.3186 ms, mean = 13.5448 ms, median = 13.5178 ms, percentile(99%) = 14.2151 ms
[I] D2H Latency: min = 0.00878906 ms, max = 0.0172729 ms, mean = 0.0128844 ms, median = 0.0125732 ms, percentile(99%) = 0.0166016 ms
[I] Total Host Walltime: 3.04768 s
[I] Total GPU Compute Time: 3.03404 s
[I] Explanations of the performance metrics are printed in the verbose logs.
```

Note: 73.5 qps x batch 8 = 588 fps @ ~15 ms latency.
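
If you want to double-check the exported engine before wiring it into Triton, you can inspect its bindings from Python. This is only a sketch and assumes the `tensorrt` Python package from the same 22.06 container (TensorRT 8.x binding API); the engine exported above should expose a single `images` input with a dynamic (-1) batch dimension and the `num_dets`, `det_boxes`, `det_scores` and `det_classes` outputs that the client below expects.

```python
import tensorrt as trt

ENGINE_PATH = "yolov7-fp16-1x8x8.engine"  # engine produced by the trtexec step above

logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)
with open(ENGINE_PATH, "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())

# Dynamic-batch engines report -1 for the batch dimension
for i in range(engine.num_bindings):
    kind = "input " if engine.binding_is_input(i) else "output"
    print(kind, engine.get_binding_name(i), engine.get_binding_shape(i))
```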

## Model Repository

See [Triton Model Repository Documentation](https://github.com/triton-inference-server/server/blob/main/docs/model_repository.md#model-repository) for more info.

```bash
# Create folder structure
mkdir -p triton-deploy/models/yolov7/1/
touch triton-deploy/models/yolov7/config.pbtxt

# Place model
mv yolov7-fp16-1x8x8.engine triton-deploy/models/yolov7/1/model.plan
```

## Model Configuration

See [Triton Model Configuration Documentation](https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md#model-configuration) for more info.

Minimal configuration for `triton-deploy/models/yolov7/config.pbtxt`:

```
name: "yolov7"
platform: "tensorrt_plan"
max_batch_size: 8
dynamic_batching { }
```

Example repository:

```bash
$ tree triton-deploy/
triton-deploy/
└── models
    └── yolov7
        ├── 1
        │   └── model.plan
        └── config.pbtxt

3 directories, 2 files
```

## Start Triton Inference Server

```
docker run --gpus all --rm --ipc=host --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 -p8000:8000 -p8001:8001 -p8002:8002 -v$(pwd)/triton-deploy/models:/models nvcr.io/nvidia/tritonserver:22.06-py3 tritonserver --model-repository=/models --strict-model-config=false --log-verbose 1
```

In the log you should see:

```
+--------+---------+--------+
| Model  | Version | Status |
+--------+---------+--------+
| yolov7 | 1       | READY  |
+--------+---------+--------+
```
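
Once the server is up, you can verify the deployment from Python with the same `tritonclient` calls the example client uses. A minimal sketch, assuming the gRPC endpoint exposed above at `localhost:8001`:

```python
import tritonclient.grpc as grpcclient

# Connect to the gRPC endpoint published by the container (-p8001:8001)
triton_client = grpcclient.InferenceServerClient(url="localhost:8001")

# Server and model health checks
print("server live :", triton_client.is_server_live())
print("server ready:", triton_client.is_server_ready())
print("model ready :", triton_client.is_model_ready("yolov7"))

# Inspect what Triton parsed from config.pbtxt and the engine itself
print(triton_client.get_model_metadata("yolov7"))
print(triton_client.get_model_config("yolov7"))
```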

## Performance with Model Analyzer

See [Triton Model Analyzer Documentation](https://github.com/triton-inference-server/server/blob/main/docs/model_analyzer.md#model-analyzer) for more info.

Performance numbers @ RTX 3090 + AMD Ryzen 9 5950X

Example test with 16 concurrent clients using shared memory, each sending requests with batch size 1:

```bash
docker run -it --ipc=host --net=host nvcr.io/nvidia/tritonserver:22.06-py3-sdk /bin/bash

./install/bin/perf_analyzer -m yolov7 -u 127.0.0.1:8001 -i grpc --shared-memory system --concurrency-range 16

# Result (truncated)
Concurrency: 16, throughput: 590.119 infer/sec, latency 27080 usec
```

Throughput for 16 clients with batch size 1 is the same as for a single thread running the engine locally at batch size 16, thanks to Triton's [Dynamic Batching Strategy](https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md#dynamic-batcher). The result without dynamic batching (disabled in the model configuration) is considerably worse:

```bash
# Result (truncated)
Concurrency: 16, throughput: 335.587 infer/sec, latency 47616 usec
```
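
The same effect can be reproduced with a rough Python sketch: fire batch-1 requests from 16 threads and let the server-side dynamic batcher merge them. This is not a substitute for perf_analyzer (no shared memory, so expect somewhat lower numbers); the server address, thread count and request count below are illustrative assumptions, and the dummy all-ones input mirrors the client's dummy mode.

```python
import threading
import time

import numpy as np
import tritonclient.grpc as grpcclient

CONCURRENCY = 16          # matches --concurrency-range 16 above
REQUESTS_PER_THREAD = 32  # arbitrary, just long enough to measure

def worker(results, idx):
    # One client connection per thread, each sending batch-1 requests
    client = grpcclient.InferenceServerClient(url="localhost:8001")
    inp = grpcclient.InferInput("images", [1, 3, 640, 640], "FP32")
    inp.set_data_from_numpy(np.ones((1, 3, 640, 640), dtype=np.float32))
    for _ in range(REQUESTS_PER_THREAD):
        client.infer(model_name="yolov7", inputs=[inp])
    results[idx] = REQUESTS_PER_THREAD

results = [0] * CONCURRENCY
threads = [threading.Thread(target=worker, args=(results, i)) for i in range(CONCURRENCY)]
start = time.time()
for t in threads:
    t.start()
for t in threads:
    t.join()
elapsed = time.time() - start
print(f"{sum(results) / elapsed:.1f} infer/sec with {CONCURRENCY} concurrent clients")
```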

## How to run model in your code

An example client can be found in client.py. It can run dummy inputs, images and videos.

```bash
pip3 install tritonclient[all] opencv-python
python3 client.py image data/dog.jpg
```
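
For reference, the core of client.py boils down to the following sketch: preprocess the image, send one gRPC request, then rescale the EfficientNMS outputs back to the original image. It assumes the helper modules from this directory (`processing.py`, `labels.py`) are importable and the server from the previous section is running at `localhost:8001`.

```python
import cv2
import numpy as np
import tritonclient.grpc as grpcclient

from processing import preprocess, postprocess
from labels import COCOLabels

client = grpcclient.InferenceServerClient(url="localhost:8001")

# Read an image and convert it to the 1x3x640x640 FP32 tensor the engine expects
image = cv2.imread("data/dog.jpg")
batch = np.expand_dims(preprocess(image, [640, 640]), axis=0)

inp = grpcclient.InferInput("images", [1, 3, 640, 640], "FP32")
inp.set_data_from_numpy(batch)
outputs = [grpcclient.InferRequestedOutput(name)
           for name in ("num_dets", "det_boxes", "det_scores", "det_classes")]

results = client.infer(model_name="yolov7", inputs=[inp], outputs=outputs)

# Map the detections back to the original image size
detected_objects = postprocess(results.as_numpy("num_dets"), results.as_numpy("det_boxes"),
                               results.as_numpy("det_scores"), results.as_numpy("det_classes"),
                               image.shape[1], image.shape[0], [640, 640])
for box in detected_objects:
    print(COCOLabels(box.classID).name, box.confidence, box.box())
```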



```
$ python3 client.py --help
usage: client.py [-h] [-m MODEL] [--width WIDTH] [--height HEIGHT] [-u URL] [-o OUT] [-f FPS] [-i] [-v] [-t CLIENT_TIMEOUT] [-s] [-r ROOT_CERTIFICATES] [-p PRIVATE_KEY] [-x CERTIFICATE_CHAIN] {dummy,image,video} [input]

positional arguments:
  {dummy,image,video}   Run mode. 'dummy' will send an empty buffer to the server to test if inference works. 'image' will process an image. 'video' will process a video.
  input                 Input file to load from in image or video mode

optional arguments:
  -h, --help            show this help message and exit
  -m MODEL, --model MODEL
                        Inference model name, default yolov7
  --width WIDTH         Inference model input width, default 640
  --height HEIGHT       Inference model input height, default 640
  -u URL, --url URL     Inference server URL, default localhost:8001
  -o OUT, --out OUT     Write output into file instead of displaying it
  -f FPS, --fps FPS     Video output fps, default 24.0 FPS
  -i, --model-info      Print model status, configuration and statistics
  -v, --verbose         Enable verbose client output
  -t CLIENT_TIMEOUT, --client-timeout CLIENT_TIMEOUT
                        Client timeout in seconds, default no timeout
  -s, --ssl             Enable SSL encrypted channel to the server
  -r ROOT_CERTIFICATES, --root-certificates ROOT_CERTIFICATES
                        File holding PEM-encoded root certificates, default none
  -p PRIVATE_KEY, --private-key PRIVATE_KEY
                        File holding PEM-encoded private key, default is none
  -x CERTIFICATE_CHAIN, --certificate-chain CERTIFICATE_CHAIN
                        File holding PEM-encoded certificate chain, default is none
```

boundingbox.py
@ -0,0 +1,33 @@
class BoundingBox:
    def __init__(self, classID, confidence, x1, x2, y1, y2, image_width, image_height):
        self.classID = classID
        self.confidence = confidence
        self.x1 = x1
        self.x2 = x2
        self.y1 = y1
        self.y2 = y2
        self.u1 = x1 / image_width
        self.u2 = x2 / image_width
        self.v1 = y1 / image_height
        self.v2 = y2 / image_height

    def box(self):
        return (self.x1, self.y1, self.x2, self.y2)

    def width(self):
        return self.x2 - self.x1

    def height(self):
        return self.y2 - self.y1

    def center_absolute(self):
        return (0.5 * (self.x1 + self.x2), 0.5 * (self.y1 + self.y2))

    def center_normalized(self):
        return (0.5 * (self.u1 + self.u2), 0.5 * (self.v1 + self.v2))

    def size_absolute(self):
        return (self.x2 - self.x1, self.y2 - self.y1)

    def size_normalized(self):
        return (self.u2 - self.u1, self.v2 - self.v1)

client.py
@ -0,0 +1,334 @@
#!/usr/bin/env python

import argparse
import numpy as np
import sys
import cv2

import tritonclient.grpc as grpcclient
from tritonclient.utils import InferenceServerException

from processing import preprocess, postprocess
from render import render_box, render_filled_box, get_text_size, render_text, RAND_COLORS
from labels import COCOLabels

INPUT_NAMES = ["images"]
OUTPUT_NAMES = ["num_dets", "det_boxes", "det_scores", "det_classes"]

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('mode',
                        choices=['dummy', 'image', 'video'],
                        default='dummy',
                        help='Run mode. \'dummy\' will send an empty buffer to the server to test if inference works. \'image\' will process an image. \'video\' will process a video.')
    parser.add_argument('input',
                        type=str,
                        nargs='?',
                        help='Input file to load from in image or video mode')
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        required=False,
                        default='yolov7',
                        help='Inference model name, default yolov7')
    parser.add_argument('--width',
                        type=int,
                        required=False,
                        default=640,
                        help='Inference model input width, default 640')
    parser.add_argument('--height',
                        type=int,
                        required=False,
                        default=640,
                        help='Inference model input height, default 640')
    parser.add_argument('-u',
                        '--url',
                        type=str,
                        required=False,
                        default='localhost:8001',
                        help='Inference server URL, default localhost:8001')
    parser.add_argument('-o',
                        '--out',
                        type=str,
                        required=False,
                        default='',
                        help='Write output into file instead of displaying it')
    parser.add_argument('-f',
                        '--fps',
                        type=float,
                        required=False,
                        default=24.0,
                        help='Video output fps, default 24.0 FPS')
    parser.add_argument('-i',
                        '--model-info',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Print model status, configuration and statistics')
    parser.add_argument('-v',
                        '--verbose',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable verbose client output')
    parser.add_argument('-t',
                        '--client-timeout',
                        type=float,
                        required=False,
                        default=None,
                        help='Client timeout in seconds, default no timeout')
    parser.add_argument('-s',
                        '--ssl',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable SSL encrypted channel to the server')
    parser.add_argument('-r',
                        '--root-certificates',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded root certificates, default none')
    parser.add_argument('-p',
                        '--private-key',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded private key, default is none')
    parser.add_argument('-x',
                        '--certificate-chain',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded certificate chain, default is none')

    FLAGS = parser.parse_args()

    # Create server context
    try:
        triton_client = grpcclient.InferenceServerClient(
            url=FLAGS.url,
            verbose=FLAGS.verbose,
            ssl=FLAGS.ssl,
            root_certificates=FLAGS.root_certificates,
            private_key=FLAGS.private_key,
            certificate_chain=FLAGS.certificate_chain)
    except Exception as e:
        print("context creation failed: " + str(e))
        sys.exit(1)

    # Health check
    if not triton_client.is_server_live():
        print("FAILED : is_server_live")
        sys.exit(1)

    if not triton_client.is_server_ready():
        print("FAILED : is_server_ready")
        sys.exit(1)

    if not triton_client.is_model_ready(FLAGS.model):
        print("FAILED : is_model_ready")
        sys.exit(1)

    if FLAGS.model_info:
        # Model metadata
        try:
            metadata = triton_client.get_model_metadata(FLAGS.model)
            print(metadata)
        except InferenceServerException as ex:
            if "Request for unknown model" not in ex.message():
                print("FAILED : get_model_metadata")
                print("Got: {}".format(ex.message()))
                sys.exit(1)
            else:
                print("FAILED : get_model_metadata")
                sys.exit(1)

        # Model configuration
        try:
            config = triton_client.get_model_config(FLAGS.model)
            if not (config.config.name == FLAGS.model):
                print("FAILED: get_model_config")
                sys.exit(1)
            print(config)
        except InferenceServerException as ex:
            print("FAILED : get_model_config")
            print("Got: {}".format(ex.message()))
            sys.exit(1)

    # DUMMY MODE
    if FLAGS.mode == 'dummy':
        print("Running in 'dummy' mode")
        print("Creating empty buffer filled with ones...")
        inputs = []
        outputs = []
        inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32"))
        inputs[0].set_data_from_numpy(np.ones(shape=(1, 3, FLAGS.width, FLAGS.height), dtype=np.float32))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3]))

        print("Invoking inference...")
        results = triton_client.infer(model_name=FLAGS.model,
                                      inputs=inputs,
                                      outputs=outputs,
                                      client_timeout=FLAGS.client_timeout)
        if FLAGS.model_info:
            statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
            if len(statistics.model_stats) != 1:
                print("FAILED: get_inference_statistics")
                sys.exit(1)
            print(statistics)
        print("Done")

        for output in OUTPUT_NAMES:
            result = results.as_numpy(output)
            print(f"Received result buffer \"{output}\" of size {result.shape}")
            print(f"Naive buffer sum: {np.sum(result)}")

    # IMAGE MODE
    if FLAGS.mode == 'image':
        print("Running in 'image' mode")
        if not FLAGS.input:
            print("FAILED: no input image")
            sys.exit(1)

        inputs = []
        outputs = []
        inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32"))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3]))

        print("Creating buffer from image file...")
        input_image = cv2.imread(str(FLAGS.input))
        if input_image is None:
            print(f"FAILED: could not load input image {str(FLAGS.input)}")
            sys.exit(1)
        input_image_buffer = preprocess(input_image, [FLAGS.width, FLAGS.height])
        input_image_buffer = np.expand_dims(input_image_buffer, axis=0)

        inputs[0].set_data_from_numpy(input_image_buffer)

        print("Invoking inference...")
        results = triton_client.infer(model_name=FLAGS.model,
                                      inputs=inputs,
                                      outputs=outputs,
                                      client_timeout=FLAGS.client_timeout)
        if FLAGS.model_info:
            statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
            if len(statistics.model_stats) != 1:
                print("FAILED: get_inference_statistics")
                sys.exit(1)
            print(statistics)
        print("Done")

        for output in OUTPUT_NAMES:
            result = results.as_numpy(output)
            print(f"Received result buffer \"{output}\" of size {result.shape}")
            print(f"Naive buffer sum: {np.sum(result)}")

        num_dets = results.as_numpy(OUTPUT_NAMES[0])
        det_boxes = results.as_numpy(OUTPUT_NAMES[1])
        det_scores = results.as_numpy(OUTPUT_NAMES[2])
        det_classes = results.as_numpy(OUTPUT_NAMES[3])
        detected_objects = postprocess(num_dets, det_boxes, det_scores, det_classes, input_image.shape[1], input_image.shape[0], [FLAGS.width, FLAGS.height])
        print(f"Detected objects: {len(detected_objects)}")

        for box in detected_objects:
            print(f"{COCOLabels(box.classID).name}: {box.confidence}")
            input_image = render_box(input_image, box.box(), color=tuple(RAND_COLORS[box.classID % 64].tolist()))
            size = get_text_size(input_image, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", normalised_scaling=0.6)
            input_image = render_filled_box(input_image, (box.x1 - 3, box.y1 - 3, box.x1 + size[0], box.y1 + size[1]), color=(220, 220, 220))
            input_image = render_text(input_image, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", (box.x1, box.y1), color=(30, 30, 30), normalised_scaling=0.5)

        if FLAGS.out:
            cv2.imwrite(FLAGS.out, input_image)
            print(f"Saved result to {FLAGS.out}")
        else:
            cv2.imshow('image', input_image)
            cv2.waitKey(0)
            cv2.destroyAllWindows()

    # VIDEO MODE
    if FLAGS.mode == 'video':
        print("Running in 'video' mode")
        if not FLAGS.input:
            print("FAILED: no input video")
            sys.exit(1)

        inputs = []
        outputs = []
        inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32"))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3]))

        print("Opening input video stream...")
        cap = cv2.VideoCapture(FLAGS.input)
        if not cap.isOpened():
            print(f"FAILED: cannot open video {FLAGS.input}")
            sys.exit(1)

        counter = 0
        out = None
        print("Invoking inference...")
        while True:
            ret, frame = cap.read()
            if not ret:
                print("failed to fetch next frame")
                break

            if counter == 0 and FLAGS.out:
                print("Opening output video stream...")
                fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')
                out = cv2.VideoWriter(FLAGS.out, fourcc, FLAGS.fps, (frame.shape[1], frame.shape[0]))

            input_image_buffer = preprocess(frame, [FLAGS.width, FLAGS.height])
            input_image_buffer = np.expand_dims(input_image_buffer, axis=0)

            inputs[0].set_data_from_numpy(input_image_buffer)

            results = triton_client.infer(model_name=FLAGS.model,
                                          inputs=inputs,
                                          outputs=outputs,
                                          client_timeout=FLAGS.client_timeout)

            num_dets = results.as_numpy("num_dets")
            det_boxes = results.as_numpy("det_boxes")
            det_scores = results.as_numpy("det_scores")
            det_classes = results.as_numpy("det_classes")
            detected_objects = postprocess(num_dets, det_boxes, det_scores, det_classes, frame.shape[1], frame.shape[0], [FLAGS.width, FLAGS.height])
            print(f"Frame {counter}: {len(detected_objects)} objects")
            counter += 1

            for box in detected_objects:
                print(f"{COCOLabels(box.classID).name}: {box.confidence}")
                frame = render_box(frame, box.box(), color=tuple(RAND_COLORS[box.classID % 64].tolist()))
                size = get_text_size(frame, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", normalised_scaling=0.6)
                frame = render_filled_box(frame, (box.x1 - 3, box.y1 - 3, box.x1 + size[0], box.y1 + size[1]), color=(220, 220, 220))
                frame = render_text(frame, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", (box.x1, box.y1), color=(30, 30, 30), normalised_scaling=0.5)

            if FLAGS.out:
                out.write(frame)
            else:
                cv2.imshow('image', frame)
                if cv2.waitKey(1) == ord('q'):
                    break

        if FLAGS.model_info:
            statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
            if len(statistics.model_stats) != 1:
                print("FAILED: get_inference_statistics")
                sys.exit(1)
            print(statistics)
        print("Done")

        cap.release()
        if FLAGS.out:
            out.release()
        else:
            cv2.destroyAllWindows()

Binary file not shown. (Image, 160 KiB)
Binary file not shown. (Image, 180 KiB)

labels.py
@ -0,0 +1,83 @@
from enum import Enum

class COCOLabels(Enum):
    PERSON = 0
    BICYCLE = 1
    CAR = 2
    MOTORBIKE = 3
    AEROPLANE = 4
    BUS = 5
    TRAIN = 6
    TRUCK = 7
    BOAT = 8
    TRAFFIC_LIGHT = 9
    FIRE_HYDRANT = 10
    STOP_SIGN = 11
    PARKING_METER = 12
    BENCH = 13
    BIRD = 14
    CAT = 15
    DOG = 16
    HORSE = 17
    SHEEP = 18
    COW = 19
    ELEPHANT = 20
    BEAR = 21
    ZEBRA = 22
    GIRAFFE = 23
    BACKPACK = 24
    UMBRELLA = 25
    HANDBAG = 26
    TIE = 27
    SUITCASE = 28
    FRISBEE = 29
    SKIS = 30
    SNOWBOARD = 31
    SPORTS_BALL = 32
    KITE = 33
    BASEBALL_BAT = 34
    BASEBALL_GLOVE = 35
    SKATEBOARD = 36
    SURFBOARD = 37
    TENNIS_RACKET = 38
    BOTTLE = 39
    WINE_GLASS = 40
    CUP = 41
    FORK = 42
    KNIFE = 43
    SPOON = 44
    BOWL = 45
    BANANA = 46
    APPLE = 47
    SANDWICH = 48
    ORANGE = 49
    BROCCOLI = 50
    CARROT = 51
    HOT_DOG = 52
    PIZZA = 53
    DONUT = 54
    CAKE = 55
    CHAIR = 56
    SOFA = 57
    POTTEDPLANT = 58
    BED = 59
    DININGTABLE = 60
    TOILET = 61
    TVMONITOR = 62
    LAPTOP = 63
    MOUSE = 64
    REMOTE = 65
    KEYBOARD = 66
    CELL_PHONE = 67
    MICROWAVE = 68
    OVEN = 69
    TOASTER = 70
    SINK = 71
    REFRIGERATOR = 72
    BOOK = 73
    CLOCK = 74
    VASE = 75
    SCISSORS = 76
    TEDDY_BEAR = 77
    HAIR_DRIER = 78
    TOOTHBRUSH = 79

processing.py
@ -0,0 +1,51 @@
from boundingbox import BoundingBox

import cv2
import numpy as np

def preprocess(img, input_shape, letter_box=True):
    if letter_box:
        img_h, img_w, _ = img.shape
        new_h, new_w = input_shape[0], input_shape[1]
        offset_h, offset_w = 0, 0
        if (new_w / img_w) <= (new_h / img_h):
            new_h = int(img_h * new_w / img_w)
            offset_h = (input_shape[0] - new_h) // 2
        else:
            new_w = int(img_w * new_h / img_h)
            offset_w = (input_shape[1] - new_w) // 2
        resized = cv2.resize(img, (new_w, new_h))
        img = np.full((input_shape[0], input_shape[1], 3), 127, dtype=np.uint8)
        img[offset_h:(offset_h + new_h), offset_w:(offset_w + new_w), :] = resized
    else:
        img = cv2.resize(img, (input_shape[1], input_shape[0]))

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.transpose((2, 0, 1)).astype(np.float32)
    img /= 255.0
    return img

def postprocess(num_dets, det_boxes, det_scores, det_classes, img_w, img_h, input_shape, letter_box=True):
    boxes = det_boxes[0, :num_dets[0][0]] / np.array([input_shape[0], input_shape[1], input_shape[0], input_shape[1]], dtype=np.float32)
    scores = det_scores[0, :num_dets[0][0]]
    classes = det_classes[0, :num_dets[0][0]].astype(int)

    old_h, old_w = img_h, img_w
    offset_h, offset_w = 0, 0
    if letter_box:
        if (img_w / input_shape[1]) >= (img_h / input_shape[0]):
            old_h = int(input_shape[0] * img_w / input_shape[1])
            offset_h = (old_h - img_h) // 2
        else:
            old_w = int(input_shape[1] * img_h / input_shape[0])
            offset_w = (old_w - img_w) // 2

    boxes = boxes * np.array([old_w, old_h, old_w, old_h], dtype=np.float32)
    if letter_box:
        boxes -= np.array([offset_w, offset_h, offset_w, offset_h], dtype=np.float32)
    boxes = boxes.astype(int)

    detected_objects = []
    for box, score, label in zip(boxes, scores, classes):
        detected_objects.append(BoundingBox(label, score, box[0], box[2], box[1], box[3], img_w, img_h))
    return detected_objects

render.py
@ -0,0 +1,110 @@
import numpy as np

import cv2

from math import sqrt

_LINE_THICKNESS_SCALING = 500.0

np.random.seed(0)
RAND_COLORS = np.random.randint(50, 255, (64, 3), "int")  # used for class visu
RAND_COLORS[0] = [220, 220, 220]

def render_box(img, box, color=(200, 200, 200)):
    """
    Render a box. Calculates scaling and thickness automatically.
    :param img: image to render into
    :param box: (x1, y1, x2, y2) - box coordinates
    :param color: (b, g, r) - box color
    :return: updated image
    """
    x1, y1, x2, y2 = box
    thickness = int(
        round(
            (img.shape[0] * img.shape[1])
            / (_LINE_THICKNESS_SCALING * _LINE_THICKNESS_SCALING)
        )
    )
    thickness = max(1, thickness)
    img = cv2.rectangle(
        img,
        (int(x1), int(y1)),
        (int(x2), int(y2)),
        color,
        thickness=thickness
    )
    return img

def render_filled_box(img, box, color=(200, 200, 200)):
    """
    Render a filled box. Calculates scaling and thickness automatically.
    :param img: image to render into
    :param box: (x1, y1, x2, y2) - box coordinates
    :param color: (b, g, r) - box color
    :return: updated image
    """
    x1, y1, x2, y2 = box
    img = cv2.rectangle(
        img,
        (int(x1), int(y1)),
        (int(x2), int(y2)),
        color,
        thickness=cv2.FILLED
    )
    return img

_TEXT_THICKNESS_SCALING = 700.0
_TEXT_SCALING = 520.0


def get_text_size(img, text, normalised_scaling=1.0):
    """
    Get calculated text size (as box width and height)
    :param img: image reference, used to determine appropriate text scaling
    :param text: text to display
    :param normalised_scaling: additional normalised scaling. Default 1.0.
    :return: (width, height) - width and height of text box
    """
    thickness = int(
        round(
            (img.shape[0] * img.shape[1])
            / (_TEXT_THICKNESS_SCALING * _TEXT_THICKNESS_SCALING)
        )
        * normalised_scaling
    )
    thickness = max(1, thickness)
    scaling = img.shape[0] / _TEXT_SCALING * normalised_scaling
    return cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, scaling, thickness)[0]


def render_text(img, text, pos, color=(200, 200, 200), normalised_scaling=1.0):
    """
    Render a text into the image. Calculates scaling and thickness automatically.
    :param img: image to render into
    :param text: text to display
    :param pos: (x, y) - upper left coordinates of render position
    :param color: (b, g, r) - text color
    :param normalised_scaling: additional normalised scaling. Default 1.0.
    :return: updated image
    """
    x, y = pos
    thickness = int(
        round(
            (img.shape[0] * img.shape[1])
            / (_TEXT_THICKNESS_SCALING * _TEXT_THICKNESS_SCALING)
        )
        * normalised_scaling
    )
    thickness = max(1, thickness)
    scaling = img.shape[0] / _TEXT_SCALING * normalised_scaling
    size = get_text_size(img, text, normalised_scaling)
    cv2.putText(
        img,
        text,
        (int(x), int(y + size[1])),
        cv2.FONT_HERSHEY_SIMPLEX,
        scaling,
        color,
        thickness=thickness,
    )
    return img