diff --git a/deploy/triton-inference-server/README.md b/deploy/triton-inference-server/README.md
new file mode 100644
index 0000000..22b01e3
--- /dev/null
+++ b/deploy/triton-inference-server/README.md
@@ -0,0 +1,161 @@
+# YOLOv7 on Triton Inference Server
+
+Instructions to deploy YOLOv7 as TensorRT engine to [Triton Inference Server](https://github.com/NVIDIA/triton-inference-server).
+
+Triton Inference Server takes care of model deployment with many out-of-the-box benefits, like a GRPC and HTTP interface, automatic scheduling on multiple GPUs, shared memory (even on GPU), dynamic server-side batching, health metrics and memory resource management.
+
+There are no additional dependencies needed to run this deployment, except a working docker daemon with GPU support.
+
+## Export TensorRT
+
+See https://github.com/WongKinYiu/yolov7#export for more info.
+
+```bash
+# Pytorch Yolov7 -> ONNX with grid, EfficientNMS plugin and dynamic batch size
+python export.py --weights ./yolov7.pt --grid --end2end --dynamic-batch --simplify --topk-all 100 --iou-thres 0.65 --conf-thres 0.35 --img-size 640 640
+# ONNX -> TensorRT with trtexec and docker
+docker run -it --rm --gpus=all nvcr.io/nvidia/tensorrt:22.06-py3
+# Copy onnx -> container: docker cp yolov7.onnx <container-id>:/workspace/
+# Export with FP16 precision, min batch 1, opt batch 8 and max batch 8
+./tensorrt/bin/trtexec --onnx=yolov7.onnx --minShapes=images:1x3x640x640 --optShapes=images:8x3x640x640 --maxShapes=images:8x3x640x640 --fp16 --workspace=4096 --saveEngine=yolov7-fp16-1x8x8.engine --timingCacheFile=timing.cache
+# Test engine
+./tensorrt/bin/trtexec --loadEngine=yolov7-fp16-1x8x8.engine
+# Copy engine -> host: docker cp <container-id>:/workspace/yolov7-fp16-1x8x8.engine .
+```
+
+Example output of test with RTX 3090.
+
+```
+[I] === Performance summary ===
+[I] Throughput: 73.4985 qps
+[I] Latency: min = 14.8578 ms, max = 15.8344 ms, mean = 15.07 ms, median = 15.0422 ms, percentile(99%) = 15.7443 ms
+[I] End-to-End Host Latency: min = 25.8715 ms, max = 28.4102 ms, mean = 26.672 ms, median = 26.6082 ms, percentile(99%) = 27.8314 ms
+[I] Enqueue Time: min = 0.793701 ms, max = 1.47144 ms, mean = 1.2008 ms, median = 1.28644 ms, percentile(99%) = 1.38965 ms
+[I] H2D Latency: min = 1.50073 ms, max = 1.52454 ms, mean = 1.51225 ms, median = 1.51404 ms, percentile(99%) = 1.51941 ms
+[I] GPU Compute Time: min = 13.3386 ms, max = 14.3186 ms, mean = 13.5448 ms, median = 13.5178 ms, percentile(99%) = 14.2151 ms
+[I] D2H Latency: min = 0.00878906 ms, max = 0.0172729 ms, mean = 0.0128844 ms, median = 0.0125732 ms, percentile(99%) = 0.0166016 ms
+[I] Total Host Walltime: 3.04768 s
+[I] Total GPU Compute Time: 3.03404 s
+[I] Explanations of the performance metrics are printed in the verbose logs.
+```
+Note: 73.5 qps x batch 8 = 588 fps @ ~15ms latency.
+
+## Model Repository
+
+See [Triton Model Repository Documentation](https://github.com/triton-inference-server/server/blob/main/docs/model_repository.md#model-repository) for more info.
+
+```bash
+# Create folder structure
+mkdir -p triton-deploy/models/yolov7/1/
+touch triton-deploy/models/yolov7/config.pbtxt
+# Place model
+mv yolov7-fp16-1x8x8.engine triton-deploy/models/yolov7/1/model.plan
+```
+
+## Model Configuration
+
+See [Triton Model Configuration Documentation](https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md#model-configuration) for more info.
+
+Minimal configuration for `triton-deploy/models/yolov7/config.pbtxt`:
+
+```
+name: "yolov7"
+platform: "tensorrt_plan"
+max_batch_size: 8
+dynamic_batching { }
+```
+
+Example repository:
+
+```bash
+$ tree triton-deploy/
+triton-deploy/
+└── models
+    └── yolov7
+        ├── 1
+        │   └── model.plan
+        └── config.pbtxt
+
+3 directories, 2 files
+```
+
+## Start Triton Inference Server
+
+```
+docker run --gpus all --rm --ipc=host --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 -p8000:8000 -p8001:8001 -p8002:8002 -v$(pwd)/triton-deploy/models:/models nvcr.io/nvidia/tritonserver:22.06-py3 tritonserver --model-repository=/models --strict-model-config=false --log-verbose 1
+```
+
+In the log you should see:
+
+```
++--------+---------+--------+
+| Model  | Version | Status |
++--------+---------+--------+
+| yolov7 | 1       | READY  |
++--------+---------+--------+
+```
+
+## Performance with Model Analyzer
+
+See [Triton Model Analyzer Documentation](https://github.com/triton-inference-server/server/blob/main/docs/model_analyzer.md#model-analyzer) for more info.
+
+Performance numbers @ RTX 3090 + AMD Ryzen 9 5950X
+
+Example test for 16 concurrent clients using shared memory, each with batch size 1 requests:
+
+```bash
+docker run -it --ipc=host --net=host nvcr.io/nvidia/tritonserver:22.06-py3-sdk /bin/bash
+
+./install/bin/perf_analyzer -m yolov7 -u 127.0.0.1:8001 -i grpc --shared-memory system --concurrency-range 16
+
+# Result (truncated)
+Concurrency: 16, throughput: 590.119 infer/sec, latency 27080 usec
+```
+
+Throughput for 16 clients with batch size 1 is the same as for a single thread running the engine at 16 batch size locally thanks to Triton [Dynamic Batching Strategy](https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md#dynamic-batcher). Result without dynamic batching (disable in model configuration) considerably worse:
+
+```bash
+# Result (truncated)
+Concurrency: 16, throughput: 335.587 infer/sec, latency 47616 usec
+```
+
+## How to run model in your code
+
+Example client can be found in client.py. It can run dummy input, images and videos.
+
+```bash
+pip3 install tritonclient[all] opencv-python
+python3 client.py image data/dog.jpg
+```
+
+![exemplary output result](data/dog_result.jpg)
+
+```
+$ python3 client.py --help
+usage: client.py [-h] [-m MODEL] [--width WIDTH] [--height HEIGHT] [-u URL] [-o OUT] [-f FPS] [-i] [-v] [-t CLIENT_TIMEOUT] [-s] [-r ROOT_CERTIFICATES] [-p PRIVATE_KEY] [-x CERTIFICATE_CHAIN] {dummy,image,video} [input]
+
+positional arguments:
+  {dummy,image,video}   Run mode. 'dummy' will send an emtpy buffer to the server to test if inference works. 'image' will process an image. 'video' will process a video.
+  input                 Input file to load from in image or video mode
+
+optional arguments:
+  -h, --help            show this help message and exit
+  -m MODEL, --model MODEL
+                        Inference model name, default yolov7
+  --width WIDTH         Inference model input width, default 640
+  --height HEIGHT       Inference model input height, default 640
+  -u URL, --url URL     Inference server URL, default localhost:8001
+  -o OUT, --out OUT     Write output into file instead of displaying it
+  -f FPS, --fps FPS     Video output fps, default 24.0 FPS
+  -i, --model-info      Print model status, configuration and statistics
+  -v, --verbose         Enable verbose client output
+  -t CLIENT_TIMEOUT, --client-timeout CLIENT_TIMEOUT
+                        Client timeout in seconds, default no timeout
+  -s, --ssl             Enable SSL encrypted channel to the server
+  -r ROOT_CERTIFICATES, --root-certificates ROOT_CERTIFICATES
+                        File holding PEM-encoded root certificates, default none
+  -p PRIVATE_KEY, --private-key PRIVATE_KEY
+                        File holding PEM-encoded private key, default is none
+  -x CERTIFICATE_CHAIN, --certificate-chain CERTIFICATE_CHAIN
+                        File holding PEM-encoded certicate chain default is none
+```
diff --git a/deploy/triton-inference-server/boundingbox.py b/deploy/triton-inference-server/boundingbox.py
new file mode 100644
index 0000000..8b95330
--- /dev/null
+++ b/deploy/triton-inference-server/boundingbox.py
@@ -0,0 +1,33 @@
+class BoundingBox:
+    def __init__(self, classID, confidence, x1, x2, y1, y2, image_width, image_height):
+        self.classID = classID
+        self.confidence = confidence
+        self.x1 = x1
+        self.x2 = x2
+        self.y1 = y1
+        self.y2 = y2
+        self.u1 = x1 / image_width
+        self.u2 = x2 / image_width
+        self.v1 = y1 / image_height
+        self.v2 = y2 / image_height
+
+    def box(self):
+        return (self.x1, self.y1, self.x2, self.y2)
+
+    def width(self):
+        return self.x2 - self.x1
+
+    def height(self):
+        return self.y2 - self.y1
+
+    def center_absolute(self):
+        return (0.5 * (self.x1 + self.x2), 0.5 * (self.y1 + self.y2))
+
+    def center_normalized(self):
+        return (0.5 * (self.u1 + self.u2), 0.5 * (self.v1 + self.v2))
+
+    def size_absolute(self):
+        return (self.x2 - self.x1, self.y2 - self.y1)
+
+    def size_normalized(self):
+        return (self.u2 - self.u1, self.v2 - self.v1)
diff --git a/deploy/triton-inference-server/client.py b/deploy/triton-inference-server/client.py
new file mode 100644
index 0000000..aedca11
--- /dev/null
+++ b/deploy/triton-inference-server/client.py
@@ -0,0 +1,334 @@
+#!/usr/bin/env python
+
+import argparse
+import numpy as np
+import sys
+import cv2
+
+import tritonclient.grpc as grpcclient
+from tritonclient.utils import InferenceServerException
+
+from processing import preprocess, postprocess
+from render import render_box, render_filled_box, get_text_size, render_text, RAND_COLORS
+from labels import COCOLabels
+
+INPUT_NAMES = ["images"]
+OUTPUT_NAMES = ["num_dets", "det_boxes", "det_scores", "det_classes"]
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('mode',
+                        choices=['dummy', 'image', 'video'],
+                        default='dummy',
+                        help='Run mode. \'dummy\' will send an emtpy buffer to the server to test if inference works. \'image\' will process an image. \'video\' will process a video.')
+    parser.add_argument('input',
+                        type=str,
+                        nargs='?',
+                        help='Input file to load from in image or video mode')
+    parser.add_argument('-m',
+                        '--model',
+                        type=str,
+                        required=False,
+                        default='yolov7',
+                        help='Inference model name, default yolov7')
+    parser.add_argument('--width',
+                        type=int,
+                        required=False,
+                        default=640,
+                        help='Inference model input width, default 640')
+    parser.add_argument('--height',
+                        type=int,
+                        required=False,
+                        default=640,
+                        help='Inference model input height, default 640')
+    parser.add_argument('-u',
+                        '--url',
+                        type=str,
+                        required=False,
+                        default='localhost:8001',
+                        help='Inference server URL, default localhost:8001')
+    parser.add_argument('-o',
+                        '--out',
+                        type=str,
+                        required=False,
+                        default='',
+                        help='Write output into file instead of displaying it')
+    parser.add_argument('-f',
+                        '--fps',
+                        type=float,
+                        required=False,
+                        default=24.0,
+                        help='Video output fps, default 24.0 FPS')
+    parser.add_argument('-i',
+                        '--model-info',
+                        action="store_true",
+                        required=False,
+                        default=False,
+                        help='Print model status, configuration and statistics')
+    parser.add_argument('-v',
+                        '--verbose',
+                        action="store_true",
+                        required=False,
+                        default=False,
+                        help='Enable verbose client output')
+    parser.add_argument('-t',
+                        '--client-timeout',
+                        type=float,
+                        required=False,
+                        default=None,
+                        help='Client timeout in seconds, default no timeout')
+    parser.add_argument('-s',
+                        '--ssl',
+                        action="store_true",
+                        required=False,
+                        default=False,
+                        help='Enable SSL encrypted channel to the server')
+    parser.add_argument('-r',
+                        '--root-certificates',
+                        type=str,
+                        required=False,
+                        default=None,
+                        help='File holding PEM-encoded root certificates, default none')
+    parser.add_argument('-p',
+                        '--private-key',
+                        type=str,
+                        required=False,
+                        default=None,
+                        help='File holding PEM-encoded private key, default is none')
+    parser.add_argument('-x',
+                        '--certificate-chain',
+                        type=str,
+                        required=False,
+                        default=None,
+                        help='File holding PEM-encoded certicate chain default is none')
+
+    FLAGS = parser.parse_args()
+
+    # Create server context
+    try:
+        triton_client = grpcclient.InferenceServerClient(
+            url=FLAGS.url,
+            verbose=FLAGS.verbose,
+            ssl=FLAGS.ssl,
+            root_certificates=FLAGS.root_certificates,
+            private_key=FLAGS.private_key,
+            certificate_chain=FLAGS.certificate_chain)
+    except Exception as e:
+        print("context creation failed: " + str(e))
+        sys.exit()
+
+    # Health check
+    if not triton_client.is_server_live():
+        print("FAILED : is_server_live")
+        sys.exit(1)
+
+    if not triton_client.is_server_ready():
+        print("FAILED : is_server_ready")
+        sys.exit(1)
+
+    if not triton_client.is_model_ready(FLAGS.model):
+        print("FAILED : is_model_ready")
+        sys.exit(1)
+
+    if FLAGS.model_info:
+        # Model metadata
+        try:
+            metadata = triton_client.get_model_metadata(FLAGS.model)
+            print(metadata)
+        except InferenceServerException as ex:
+            if "Request for unknown model" not in ex.message():
+                print("FAILED : get_model_metadata")
+                print("Got: {}".format(ex.message()))
+                sys.exit(1)
+            else:
+                print("FAILED : get_model_metadata")
+                sys.exit(1)
+
+        # Model configuration
+        try:
+            config = triton_client.get_model_config(FLAGS.model)
+            if not (config.config.name == FLAGS.model):
+                print("FAILED: get_model_config")
+                sys.exit(1)
+            print(config)
+        except InferenceServerException as ex:
+            print("FAILED : get_model_config")
+            print("Got: {}".format(ex.message()))
+            sys.exit(1)
+
+    # DUMMY MODE
+    if FLAGS.mode == 'dummy':
+        print("Running in 'dummy' mode")
+        print("Creating emtpy buffer filled with ones...")
+        inputs = []
+        outputs = []
+        inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32"))
+        inputs[0].set_data_from_numpy(np.ones(shape=(1, 3, FLAGS.width, FLAGS.height), dtype=np.float32))
+        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0]))
+        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1]))
+        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2]))
+        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3]))
+
+        print("Invoking inference...")
+        results = triton_client.infer(model_name=FLAGS.model,
+                                      inputs=inputs,
+                                      outputs=outputs,
+                                      client_timeout=FLAGS.client_timeout)
+        if FLAGS.model_info:
+            statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
+            if len(statistics.model_stats) != 1:
+                print("FAILED: get_inference_statistics")
+                sys.exit(1)
+            print(statistics)
+        print("Done")
+
+        for output in OUTPUT_NAMES:
+            result = results.as_numpy(output)
+            print(f"Received result buffer \"{output}\" of size {result.shape}")
+            print(f"Naive buffer sum: {np.sum(result)}")
+
+    # IMAGE MODE
+    if FLAGS.mode == 'image':
+        print("Running in 'image' mode")
+        if not FLAGS.input:
+            print("FAILED: no input image")
+            sys.exit(1)
+
+        inputs = []
+        outputs = []
+        inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32"))
+        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0]))
+        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1]))
+        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2]))
+        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3]))
+
+        print("Creating buffer from image file...")
+        input_image = cv2.imread(str(FLAGS.input))
+        if input_image is None:
+            print(f"FAILED: could not load input image {str(FLAGS.input)}")
+            sys.exit(1)
+        input_image_buffer = preprocess(input_image, [FLAGS.width, FLAGS.height])
+        input_image_buffer = np.expand_dims(input_image_buffer, axis=0)
+
+        inputs[0].set_data_from_numpy(input_image_buffer)
+
+        print("Invoking inference...")
+        results = triton_client.infer(model_name=FLAGS.model,
+                                      inputs=inputs,
+                                      outputs=outputs,
+                                      client_timeout=FLAGS.client_timeout)
+        if FLAGS.model_info:
+            statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
+            if len(statistics.model_stats) != 1:
+                print("FAILED: get_inference_statistics")
+                sys.exit(1)
+            print(statistics)
+        print("Done")
+
+        for output in OUTPUT_NAMES:
+            result = results.as_numpy(output)
+            print(f"Received result buffer \"{output}\" of size {result.shape}")
+            print(f"Naive buffer sum: {np.sum(result)}")
+
+        num_dets = results.as_numpy(OUTPUT_NAMES[0])
+        det_boxes = results.as_numpy(OUTPUT_NAMES[1])
+        det_scores = results.as_numpy(OUTPUT_NAMES[2])
+        det_classes = results.as_numpy(OUTPUT_NAMES[3])
+        detected_objects = postprocess(num_dets, det_boxes, det_scores, det_classes, input_image.shape[1], input_image.shape[0], [FLAGS.width, FLAGS.height])
+        print(f"Detected objects: {len(detected_objects)}")
+
+        for box in detected_objects:
+            print(f"{COCOLabels(box.classID).name}: {box.confidence}")
+            input_image = render_box(input_image, box.box(), color=tuple(RAND_COLORS[box.classID % 64].tolist()))
+            size = get_text_size(input_image, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", normalised_scaling=0.6)
+            input_image = render_filled_box(input_image, (box.x1 - 3, box.y1 - 3, box.x1 + size[0], box.y1 + size[1]), color=(220, 220, 220))
+            input_image = render_text(input_image, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", (box.x1, box.y1), color=(30, 30, 30), normalised_scaling=0.5)
+
+        if FLAGS.out:
+            cv2.imwrite(FLAGS.out, input_image)
+            print(f"Saved result to {FLAGS.out}")
+        else:
+            cv2.imshow('image', input_image)
+            cv2.waitKey(0)
+            cv2.destroyAllWindows()
+
+    # VIDEO MODE
+    if FLAGS.mode == 'video':
+        print("Running in 'video' mode")
+        if not FLAGS.input:
+            print("FAILED: no input video")
+            sys.exit(1)
+
+        inputs = []
+        outputs = []
+        inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32"))
+        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0]))
+        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1]))
+        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2]))
+        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3]))
+
+        print("Opening input video stream...")
+        cap = cv2.VideoCapture(FLAGS.input)
+        if not cap.isOpened():
+            print(f"FAILED: cannot open video {FLAGS.input}")
+            sys.exit(1)
+
+        counter = 0
+        out = None
+        print("Invoking inference...")
+        while True:
+            ret, frame = cap.read()
+            if not ret:
+                print("failed to fetch next frame")
+                break
+
+            if counter == 0 and FLAGS.out:
+                print("Opening output video stream...")
+                fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')
+                out = cv2.VideoWriter(FLAGS.out, fourcc, FLAGS.fps, (frame.shape[1], frame.shape[0]))
+
+            input_image_buffer = preprocess(frame, [FLAGS.width, FLAGS.height])
+            input_image_buffer = np.expand_dims(input_image_buffer, axis=0)
+
+            inputs[0].set_data_from_numpy(input_image_buffer)
+
+            results = triton_client.infer(model_name=FLAGS.model,
+                                          inputs=inputs,
+                                          outputs=outputs,
+                                          client_timeout=FLAGS.client_timeout)
+
+            num_dets = results.as_numpy("num_dets")
+            det_boxes = results.as_numpy("det_boxes")
+            det_scores = results.as_numpy("det_scores")
+            det_classes = results.as_numpy("det_classes")
+            detected_objects = postprocess(num_dets, det_boxes, det_scores, det_classes, frame.shape[1], frame.shape[0], [FLAGS.width, FLAGS.height])
+            print(f"Frame {counter}: {len(detected_objects)} objects")
+            counter += 1
+
+            for box in detected_objects:
+                print(f"{COCOLabels(box.classID).name}: {box.confidence}")
+                frame = render_box(frame, box.box(), color=tuple(RAND_COLORS[box.classID % 64].tolist()))
+                size = get_text_size(frame, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", normalised_scaling=0.6)
+                frame = render_filled_box(frame, (box.x1 - 3, box.y1 - 3, box.x1 + size[0], box.y1 + size[1]), color=(220, 220, 220))
+                frame = render_text(frame, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", (box.x1, box.y1), color=(30, 30, 30), normalised_scaling=0.5)
+
+            if FLAGS.out:
+                out.write(frame)
+            else:
+                cv2.imshow('image', frame)
+                if cv2.waitKey(1) == ord('q'):
+                    break
+
+        if FLAGS.model_info:
+            statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
+            if len(statistics.model_stats) != 1:
+                print("FAILED: get_inference_statistics")
+                sys.exit(1)
+            print(statistics)
+        print("Done")
+
+        cap.release()
+        if FLAGS.out:
+            out.release()
+        else:
+            cv2.destroyAllWindows()
diff --git a/deploy/triton-inference-server/data/dog.jpg b/deploy/triton-inference-server/data/dog.jpg
new file mode 100644
index 0000000..77b0381
Binary files /dev/null and b/deploy/triton-inference-server/data/dog.jpg differ
diff --git a/deploy/triton-inference-server/data/dog_result.jpg b/deploy/triton-inference-server/data/dog_result.jpg
new file mode 100644
index 0000000..6f380ef
Binary files /dev/null and b/deploy/triton-inference-server/data/dog_result.jpg differ
diff --git a/deploy/triton-inference-server/labels.py b/deploy/triton-inference-server/labels.py
new file mode 100644
index 0000000..ba6c5c5
--- /dev/null
+++ b/deploy/triton-inference-server/labels.py
@@ -0,0 +1,83 @@
+from enum import Enum
+
+class COCOLabels(Enum):
+    PERSON = 0
+    BICYCLE = 1
+    CAR = 2
+    MOTORBIKE = 3
+    AEROPLANE = 4
+    BUS = 5
+    TRAIN = 6
+    TRUCK = 7
+    BOAT = 8
+    TRAFFIC_LIGHT = 9
+    FIRE_HYDRANT = 10
+    STOP_SIGN = 11
+    PARKING_METER = 12
+    BENCH = 13
+    BIRD = 14
+    CAT = 15
+    DOG = 16
+    HORSE = 17
+    SHEEP = 18
+    COW = 19
+    ELEPHANT = 20
+    BEAR = 21
+    ZEBRA = 22
+    GIRAFFE = 23
+    BACKPACK = 24
+    UMBRELLA = 25
+    HANDBAG = 26
+    TIE = 27
+    SUITCASE = 28
+    FRISBEE = 29
+    SKIS = 30
+    SNOWBOARD = 31
+    SPORTS_BALL = 32
+    KITE = 33
+    BASEBALL_BAT = 34
+    BASEBALL_GLOVE = 35
+    SKATEBOARD = 36
+    SURFBOARD = 37
+    TENNIS_RACKET = 38
+    BOTTLE = 39
+    WINE_GLASS = 40
+    CUP = 41
+    FORK = 42
+    KNIFE = 43
+    SPOON = 44
+    BOWL = 45
+    BANANA = 46
+    APPLE = 47
+    SANDWICH = 48
+    ORANGE = 49
+    BROCCOLI = 50
+    CARROT = 51
+    HOT_DOG = 52
+    PIZZA = 53
+    DONUT = 54
+    CAKE = 55
+    CHAIR = 56
+    SOFA = 57
+    POTTEDPLANT = 58
+    BED = 59
+    DININGTABLE = 60
+    TOILET = 61
+    TVMONITOR = 62
+    LAPTOP = 63
+    MOUSE = 64
+    REMOTE = 65
+    KEYBOARD = 66
+    CELL_PHONE = 67
+    MICROWAVE = 68
+    OVEN = 69
+    TOASTER = 70
+    SINK = 71
+    REFRIGERATOR = 72
+    BOOK = 73
+    CLOCK = 74
+    VASE = 75
+    SCISSORS = 76
+    TEDDY_BEAR = 77
+    HAIR_DRIER = 78
+    TOOTHBRUSH = 79
diff --git a/deploy/triton-inference-server/processing.py b/deploy/triton-inference-server/processing.py
new file mode 100644
index 0000000..3d51c50
--- /dev/null
+++ b/deploy/triton-inference-server/processing.py
@@ -0,0 +1,51 @@
+from boundingbox import BoundingBox
+
+import cv2
+import numpy as np
+
+def preprocess(img, input_shape, letter_box=True):
+    if letter_box:
+        img_h, img_w, _ = img.shape
+        new_h, new_w = input_shape[0], input_shape[1]
+        offset_h, offset_w = 0, 0
+        if (new_w / img_w) <= (new_h / img_h):
+            new_h = int(img_h * new_w / img_w)
+            offset_h = (input_shape[0] - new_h) // 2
+        else:
+            new_w = int(img_w * new_h / img_h)
+            offset_w = (input_shape[1] - new_w) // 2
+        resized = cv2.resize(img, (new_w, new_h))
+        img = np.full((input_shape[0], input_shape[1], 3), 127, dtype=np.uint8)
+        img[offset_h:(offset_h + new_h), offset_w:(offset_w + new_w), :] = resized
+    else:
+        img = cv2.resize(img, (input_shape[1], input_shape[0]))
+
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    img = img.transpose((2, 0, 1)).astype(np.float32)
+    img /= 255.0
+    return img
+
+def postprocess(num_dets, det_boxes, det_scores, det_classes, img_w, img_h, input_shape, letter_box=True):
+    boxes = det_boxes[0, :num_dets[0][0]] / np.array([input_shape[0], input_shape[1], input_shape[0], input_shape[1]], dtype=np.float32)
+    scores = det_scores[0, :num_dets[0][0]]
+    classes = det_classes[0, :num_dets[0][0]].astype(np.int)
+
+    old_h, old_w = img_h, img_w
+    offset_h, offset_w = 0, 0
+    if letter_box:
+        if (img_w / input_shape[1]) >= (img_h / input_shape[0]):
+            old_h = int(input_shape[0] * img_w / input_shape[1])
+            offset_h = (old_h - img_h) // 2
+        else:
+            old_w = int(input_shape[1] * img_h / input_shape[0])
+            offset_w = (old_w - img_w) // 2
+
+    boxes = boxes * np.array([old_w, old_h, old_w, old_h], dtype=np.float32)
+    if letter_box:
+        boxes -= np.array([offset_w, offset_h, offset_w, offset_h], dtype=np.float32)
+    boxes = boxes.astype(np.int)
+
+    detected_objects = []
+    for box, score, label in zip(boxes, scores, classes):
+        detected_objects.append(BoundingBox(label, score, box[0], box[2], box[1], box[3], img_w, img_h))
+    return detected_objects
diff --git a/deploy/triton-inference-server/render.py b/deploy/triton-inference-server/render.py
new file mode 100644
index 0000000..dea0401
--- /dev/null
+++ b/deploy/triton-inference-server/render.py
@@ -0,0 +1,110 @@
+import numpy as np
+
+import cv2
+
+from math import sqrt
+
+_LINE_THICKNESS_SCALING = 500.0
+
+np.random.seed(0)
+RAND_COLORS = np.random.randint(50, 255, (64, 3), "int")  # used for class visu
+RAND_COLORS[0] = [220, 220, 220]
+
+def render_box(img, box, color=(200, 200, 200)):
+    """
+    Render a box. Calculates scaling and thickness automatically.
+    :param img: image to render into
+    :param box: (x1, y1, x2, y2) - box coordinates
+    :param color: (b, g, r) - box color
+    :return: updated image
+    """
+    x1, y1, x2, y2 = box
+    thickness = int(
+        round(
+            (img.shape[0] * img.shape[1])
+            / (_LINE_THICKNESS_SCALING * _LINE_THICKNESS_SCALING)
+        )
+    )
+    thickness = max(1, thickness)
+    img = cv2.rectangle(
+        img,
+        (int(x1), int(y1)),
+        (int(x2), int(y2)),
+        color,
+        thickness=thickness
+    )
+    return img
+
+def render_filled_box(img, box, color=(200, 200, 200)):
+    """
+    Render a box. Calculates scaling and thickness automatically.
+    :param img: image to render into
+    :param box: (x1, y1, x2, y2) - box coordinates
+    :param color: (b, g, r) - box color
+    :return: updated image
+    """
+    x1, y1, x2, y2 = box
+    img = cv2.rectangle(
+        img,
+        (int(x1), int(y1)),
+        (int(x2), int(y2)),
+        color,
+        thickness=cv2.FILLED
+    )
+    return img
+
+_TEXT_THICKNESS_SCALING = 700.0
+_TEXT_SCALING = 520.0
+
+
+def get_text_size(img, text, normalised_scaling=1.0):
+    """
+    Get calculated text size (as box width and height)
+    :param img: image reference, used to determine appropriate text scaling
+    :param text: text to display
+    :param normalised_scaling: additional normalised scaling. Default 1.0.
+    :return: (width, height) - width and height of text box
+    """
+    thickness = int(
+        round(
+            (img.shape[0] * img.shape[1])
+            / (_TEXT_THICKNESS_SCALING * _TEXT_THICKNESS_SCALING)
+        )
+        * normalised_scaling
+    )
+    thickness = max(1, thickness)
+    scaling = img.shape[0] / _TEXT_SCALING * normalised_scaling
+    return cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, scaling, thickness)[0]
+
+
+def render_text(img, text, pos, color=(200, 200, 200), normalised_scaling=1.0):
+    """
+    Render a text into the image. Calculates scaling and thickness automatically.
+    :param img: image to render into
+    :param text: text to display
+    :param pos: (x, y) - upper left coordinates of render position
+    :param color: (b, g, r) - text color
+    :param normalised_scaling: additional normalised scaling. Default 1.0.
+    :return: updated image
+    """
+    x, y = pos
+    thickness = int(
+        round(
+            (img.shape[0] * img.shape[1])
+            / (_TEXT_THICKNESS_SCALING * _TEXT_THICKNESS_SCALING)
+        )
+        * normalised_scaling
+    )
+    thickness = max(1, thickness)
+    scaling = img.shape[0] / _TEXT_SCALING * normalised_scaling
+    size = get_text_size(img, text, normalised_scaling)
+    cv2.putText(
+        img,
+        text,
+        (int(x), int(y + size[1])),
+        cv2.FONT_HERSHEY_SIMPLEX,
+        scaling,
+        color,
+        thickness=thickness,
+    )
+    return img