yolov5/utils/triton.py

86 lines
3.5 KiB
Python
Raw Normal View History

Detect.py supports running against a Triton container (#9228) * update coco128-seg comments * Enables detect.py to use Triton for inference Triton Inference Server is an open source inference serving software that streamlines AI inferencing. https://github.com/triton-inference-server/server The user can now provide a "--triton-url" argument to detect.py to use a local or remote Triton server for inference. For e.g., http://localhost:8000 will use http over port 8000 and grpc://localhost:8001 will use grpc over port 8001. Note, it is not necessary to specify a weights file to use Triton. A Triton container can be created by first exporting the Yolov5 model to a Triton supported runtime. Onnx, Torchscript, TensorRT are supported by both Triton and the export.py script. The exported model can then be containerized via the OctoML CLI. See https://github.com/octoml/octo-cli#getting-started for a guide. * added triton client to requirements * fixed support for TFSavedModels in Triton * reverted change * Test CoreML update Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update ci-testing.yml Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Use pathlib Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Refacto DetectMultiBackend to directly accept triton url as --weights http://... 
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Deploy category Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update detect.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update common.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update common.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update predict.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update predict.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update predict.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update triton.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update triton.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add printout and requirements check * Cleanup Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * triton fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed triton model query over grpc * Update check_requirements('tritonclient[all]') * group imports * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix likely remote URL bug * update comment * Update is_url() * Fix 2x download attempt on http://path/to/model.pt Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: glennjocher <glenn.jocher@ultralytics.com> Co-authored-by: Gaz Iqbal <giqbal@octoml.ai> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2022-09-24 06:56:42 +08:00
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
""" Utils to interact with the Triton Inference Server
"""
import typing
from urllib.parse import urlparse
import torch
class TritonRemoteModel:
    """Wrapper around a model served by the Triton Inference Server.

    The transport is chosen from the URL scheme: ``grpc`` uses the GRPC client,
    anything else uses the HTTP client. Inputs are torch Tensors and outputs are
    returned as torch Tensors.
    """

    def __init__(self, url: str):
        """Connect to the server and load metadata of the first listed model.

        Keyword arguments:
        url: Fully qualified address of the Triton server - e.g. grpc://localhost:8001
        """
        parsed_url = urlparse(url)
        if parsed_url.scheme == "grpc":
            from tritonclient.grpc import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton GRPC client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository.models[0].name
            self.metadata = self.client.get_model_metadata(self.model_name, as_json=True)
        else:
            from tritonclient.http import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton HTTP client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository[0]['name']
            self.metadata = self.client.get_model_metadata(self.model_name)

        # Defined once for both transports; it closes over whichever InferInput
        # class was imported by the branch above.
        def create_input_placeholders() -> typing.List[InferInput]:
            return [
                InferInput(i['name'], [int(s) for s in i["shape"]], i['datatype']) for i in self.metadata['inputs']]

        self._create_input_placeholders_fn = create_input_placeholders

    @property
    def runtime(self):
        """Returns the model runtime: metadata "backend", falling back to "platform"."""
        return self.metadata.get("backend", self.metadata.get("platform"))

    def __call__(self, *args, **kwargs) -> typing.Union[torch.Tensor, typing.List[torch.Tensor]]:
        """Invokes the model. Parameters can be provided via args or kwargs.

        args, if provided, are assumed to match the order of inputs of the model.
        kwargs are matched with the model input names.

        Returns a single Tensor when the model metadata lists exactly one output,
        otherwise a list of Tensors in metadata output order.
        """
        inputs = self._create_inputs(*args, **kwargs)
        response = self.client.infer(model_name=self.model_name, inputs=inputs)
        result = [torch.as_tensor(response.as_numpy(output['name'])) for output in self.metadata['outputs']]
        return result[0] if len(result) == 1 else result

    @staticmethod
    def _input_name(placeholder) -> str:
        """Returns the input name of a placeholder.

        tritonclient's InferInput exposes ``name()`` as a method, so the attribute
        must be called; a plain string attribute is accepted as well.
        """
        name = placeholder.name
        return name() if callable(name) else name

    def _create_inputs(self, *args, **kwargs):
        """Builds input placeholders and fills them with the given tensors.

        Exactly one of args / kwargs must be used. Positional values are paired
        with the model inputs in metadata order; keyword values are looked up by
        input name.

        Raises:
            RuntimeError: no inputs given, args and kwargs mixed, or wrong number of args.
            KeyError: a model input has no matching keyword argument.
        """
        args_len, kwargs_len = len(args), len(kwargs)
        if not args_len and not kwargs_len:
            raise RuntimeError("No inputs provided.")
        if args_len and kwargs_len:
            raise RuntimeError("Cannot specify args and kwargs at the same time")

        placeholders = self._create_input_placeholders_fn()
        if args_len:
            if args_len != len(placeholders):
                raise RuntimeError(f"Expected {len(placeholders)} inputs, got {args_len}.")
            for placeholder, value in zip(placeholders, args):
                placeholder.set_data_from_numpy(value.cpu().numpy())
        else:
            for placeholder in placeholders:
                name = self._input_name(placeholder)
                if name not in kwargs:
                    raise KeyError(f"Missing value for model input '{name}'")
                placeholder.set_data_from_numpy(kwargs[name].cpu().numpy())
        return placeholders