yolov5/utils/triton.py

# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""Utils to interact with the Triton Inference Server."""

import typing
from urllib.parse import urlparse

import torch


class TritonRemoteModel:
    """A wrapper over a model served by the Triton Inference Server.

    It can be configured to communicate over GRPC or HTTP. It accepts Torch
    Tensors as input and returns them as outputs.
    """

    def __init__(self, url: str):
        """
        Keyword arguments:
        url: Fully qualified address of the Triton server, e.g. grpc://localhost:8001
        """
        parsed_url = urlparse(url)
        if parsed_url.scheme == 'grpc':
            from tritonclient.grpc import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton GRPC client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository.models[0].name
            self.metadata = self.client.get_model_metadata(self.model_name, as_json=True)

            def create_input_placeholders() -> typing.List[InferInput]:
                return [
                    InferInput(i['name'], [int(s) for s in i['shape']], i['datatype'])
                    for i in self.metadata['inputs']]

        else:
            from tritonclient.http import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton HTTP client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository[0]['name']
            self.metadata = self.client.get_model_metadata(self.model_name)

            def create_input_placeholders() -> typing.List[InferInput]:
                return [
                    InferInput(i['name'], [int(s) for s in i['shape']], i['datatype'])
                    for i in self.metadata['inputs']]

        # Deferred so that fresh InferInput placeholders are created on every call
        self._create_input_placeholders_fn = create_input_placeholders

    @property
    def runtime(self):
        """Returns the model runtime."""
        return self.metadata.get('backend', self.metadata.get('platform'))
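
    # Note: depending on the server version, model metadata exposes either a
    # 'backend' field (e.g. 'onnxruntime', 'tensorrt') or a legacy 'platform'
    # field (e.g. 'onnxruntime_onnx', 'tensorrt_plan'); values depend on the
    # deployed model.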

    def __call__(self, *args, **kwargs) -> typing.Union[torch.Tensor, typing.List[torch.Tensor]]:
        """Invokes the model. Parameters can be provided via args or kwargs.

        args, if provided, are assumed to match the order of inputs of the model.
        kwargs are matched with the model input names.
        """
        inputs = self._create_inputs(*args, **kwargs)
        response = self.client.infer(model_name=self.model_name, inputs=inputs)
        result = []
        for output in self.metadata['outputs']:
            tensor = torch.as_tensor(response.as_numpy(output['name']))  # fetch each named output as a tensor
            result.append(tensor)
        return result[0] if len(result) == 1 else result
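
    # Example (sketch, hypothetical input name 'images'): for a single-input
    # model the two call styles below are equivalent:
    #   y = model(im)           # positional, matched by input order
    #   y = model(images=im)    # keyword, matched by input name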

    def _create_inputs(self, *args, **kwargs):
        args_len, kwargs_len = len(args), len(kwargs)
        if not args_len and not kwargs_len:
            raise RuntimeError('No inputs provided.')
        if args_len and kwargs_len:
            raise RuntimeError('Cannot specify args and kwargs at the same time')

        placeholders = self._create_input_placeholders_fn()
        if args_len:
            if args_len != len(placeholders):
                raise RuntimeError(f'Expected {len(placeholders)} inputs, got {args_len}.')
            for placeholder, value in zip(placeholders, args):
                placeholder.set_data_from_numpy(value.cpu().numpy())
        else:
            for placeholder in placeholders:
                value = kwargs[placeholder.name()]  # InferInput.name() is a method in tritonclient
                placeholder.set_data_from_numpy(value.cpu().numpy())
        return placeholders
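

# Minimal usage sketch (assumptions: a Triton server is reachable at the
# hypothetical URL below and serves a single model taking one FP32 BCHW image
# input; the URL and the 1x3x640x640 shape are placeholders).
if __name__ == '__main__':
    model = TritonRemoteModel('grpc://localhost:8001')  # hypothetical server address
    print(f'model={model.model_name} runtime={model.runtime}')
    im = torch.zeros(1, 3, 640, 640)  # dummy float32 input batch
    y = model(im)  # positional call; tensors are matched to model inputs by order
    outputs = y if isinstance(y, list) else [y]
    print([tuple(t.shape) for t in outputs])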