# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""Utils to interact with the Triton Inference Server."""

import typing
from urllib.parse import urlparse

import torch


class TritonRemoteModel:
    """A wrapper over a model served by the Triton Inference Server.

    It can be configured to communicate over GRPC or HTTP. It accepts Torch
    Tensors as input and returns them as outputs.
    """

    def __init__(self, url: str):
        """
        Keyword arguments:
        url: Fully qualified address of the Triton server, e.g. grpc://localhost:8000
        """
        parsed_url = urlparse(url)
        if parsed_url.scheme == 'grpc':
            # import the GRPC flavour of the client only when it is actually needed
            from tritonclient.grpc import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton GRPC client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository.models[0].name  # use the first model in the repository
            self.metadata = self.client.get_model_metadata(self.model_name, as_json=True)

            def create_input_placeholders() -> typing.List[InferInput]:
                return [
                    InferInput(i['name'], [int(s) for s in i['shape']], i['datatype']) for i in self.metadata['inputs']]
        else:
            from tritonclient.http import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton HTTP client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository[0]['name']  # the HTTP client returns a list of dicts
            self.metadata = self.client.get_model_metadata(self.model_name)

            def create_input_placeholders() -> typing.List[InferInput]:
                return [
                    InferInput(i['name'], [int(s) for s in i['shape']], i['datatype']) for i in self.metadata['inputs']]

        self._create_input_placeholders_fn = create_input_placeholders  # invoked per inference to build fresh inputs

    @property
    def runtime(self):
        """Returns the model runtime."""
        return self.metadata.get('backend', self.metadata.get('platform'))

    def __call__(self, *args, **kwargs) -> typing.Union[torch.Tensor, typing.Tuple[torch.Tensor, ...]]:
        """Invokes the model. Parameters can be provided via args or kwargs.

        args, if provided, are assumed to match the order of inputs of the model.
        kwargs are matched with the model input names.
        """
        inputs = self._create_inputs(*args, **kwargs)
        response = self.client.infer(model_name=self.model_name, inputs=inputs)
        result = []
        for output in self.metadata['outputs']:
            tensor = torch.as_tensor(response.as_numpy(output['name']))
            result.append(tensor)
        # single-output models yield a bare tensor; otherwise return a tuple to match the annotation
        return result[0] if len(result) == 1 else tuple(result)

    def _create_inputs(self, *args, **kwargs):
        args_len, kwargs_len = len(args), len(kwargs)
        if not args_len and not kwargs_len:
            raise RuntimeError('No inputs provided.')
        if args_len and kwargs_len:
            raise RuntimeError('Cannot specify args and kwargs at the same time.')

        placeholders = self._create_input_placeholders_fn()
        if args_len:
            if args_len != len(placeholders):
                raise RuntimeError(f'Expected {len(placeholders)} inputs, got {args_len}.')
            for placeholder, value in zip(placeholders, args):
                placeholder.set_data_from_numpy(value.cpu().numpy())
        else:
            for placeholder in placeholders:
                value = kwargs[placeholder.name()]  # InferInput.name() is a method in the tritonclient API
                placeholder.set_data_from_numpy(value.cpu().numpy())
        return placeholders
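

# ----------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal example of driving
# TritonRemoteModel end to end. It assumes a Triton server is reachable on
# localhost and serves a single model with one FP32 image input; the address,
# input shape and dtype below are hypothetical placeholders.
if __name__ == '__main__':
    model = TritonRemoteModel('http://localhost:8000')  # or 'grpc://localhost:8001'
    print(f'Model: {model.model_name}, runtime: {model.runtime}')
    x = torch.zeros(1, 3, 640, 640, dtype=torch.float32)  # hypothetical input shape
    y = model(x)  # positional args follow the model's declared input order
    print(y.shape if isinstance(y, torch.Tensor) else [t.shape for t in y])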