mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
* Add license line to .github/ISSUE_TEMPLATE/bug-report.yml * Add license line to .github/ISSUE_TEMPLATE/config.yml * Add license line to .github/ISSUE_TEMPLATE/feature-request.yml * Add license line to .github/ISSUE_TEMPLATE/question.yml * Add license line to .github/dependabot.yml * Add license line to .github/workflows/ci-testing.yml * Add license line to .github/workflows/cla.yml * Add license line to .github/workflows/codeql-analysis.yml * Add license line to .github/workflows/docker.yml * Add license line to .github/workflows/format.yml * Add license line to .github/workflows/greetings.yml * Add license line to .github/workflows/links.yml * Add license line to .github/workflows/merge-main-into-prs.yml * Add license line to .github/workflows/stale.yml * Add license line to benchmarks.py * Add license line to classify/predict.py * Add license line to classify/train.py * Add license line to classify/val.py * Add license line to data/Argoverse.yaml * Add license line to data/GlobalWheat2020.yaml * Add license line to data/ImageNet.yaml * Add license line to data/ImageNet10.yaml * Add license line to data/ImageNet100.yaml * Add license line to data/ImageNet1000.yaml * Add license line to data/Objects365.yaml * Add license line to data/SKU-110K.yaml * Add license line to data/VOC.yaml * Add license line to data/VisDrone.yaml * Add license line to data/coco.yaml * Add license line to data/coco128-seg.yaml * Add license line to data/coco128.yaml * Add license line to data/hyps/hyp.Objects365.yaml * Add license line to data/hyps/hyp.VOC.yaml * Add license line to data/hyps/hyp.no-augmentation.yaml * Add license line to data/hyps/hyp.scratch-high.yaml * Add license line to data/hyps/hyp.scratch-low.yaml * Add license line to data/hyps/hyp.scratch-med.yaml * Add license line to data/xView.yaml * Add license line to detect.py * Add license line to export.py * Add license line to hubconf.py * Add license line to models/common.py * Add license line to 
models/experimental.py * Add license line to models/hub/anchors.yaml * Add license line to models/hub/yolov3-spp.yaml * Add license line to models/hub/yolov3-tiny.yaml * Add license line to models/hub/yolov3.yaml * Add license line to models/hub/yolov5-bifpn.yaml * Add license line to models/hub/yolov5-fpn.yaml * Add license line to models/hub/yolov5-p2.yaml * Add license line to models/hub/yolov5-p34.yaml * Add license line to models/hub/yolov5-p6.yaml * Add license line to models/hub/yolov5-p7.yaml * Add license line to models/hub/yolov5-panet.yaml * Add license line to models/hub/yolov5l6.yaml * Add license line to models/hub/yolov5m6.yaml * Add license line to models/hub/yolov5n6.yaml * Add license line to models/hub/yolov5s-LeakyReLU.yaml * Add license line to models/hub/yolov5s-ghost.yaml * Add license line to models/hub/yolov5s-transformer.yaml * Add license line to models/hub/yolov5s6.yaml * Add license line to models/hub/yolov5x6.yaml * Add license line to models/segment/yolov5l-seg.yaml * Add license line to models/segment/yolov5m-seg.yaml * Add license line to models/segment/yolov5n-seg.yaml * Add license line to models/segment/yolov5s-seg.yaml * Add license line to models/segment/yolov5x-seg.yaml * Add license line to models/tf.py * Add license line to models/yolo.py * Add license line to models/yolov5l.yaml * Add license line to models/yolov5m.yaml * Add license line to models/yolov5n.yaml * Add license line to models/yolov5s.yaml * Add license line to models/yolov5x.yaml * Add license line to pyproject.toml * Add license line to segment/predict.py * Add license line to segment/train.py * Add license line to segment/val.py * Add license line to train.py * Add license line to utils/__init__.py * Add license line to utils/activations.py * Add license line to utils/augmentations.py * Add license line to utils/autoanchor.py * Add license line to utils/autobatch.py * Add license line to utils/aws/resume.py * Add license line to utils/callbacks.py * Add 
license line to utils/dataloaders.py * Add license line to utils/downloads.py * Add license line to utils/flask_rest_api/example_request.py * Add license line to utils/flask_rest_api/restapi.py * Add license line to utils/general.py * Add license line to utils/google_app_engine/app.yaml * Add license line to utils/loggers/__init__.py * Add license line to utils/loggers/clearml/clearml_utils.py * Add license line to utils/loggers/clearml/hpo.py * Add license line to utils/loggers/comet/__init__.py * Add license line to utils/loggers/comet/comet_utils.py * Add license line to utils/loggers/comet/hpo.py * Add license line to utils/loggers/wandb/wandb_utils.py * Add license line to utils/loss.py * Add license line to utils/metrics.py * Add license line to utils/plots.py * Add license line to utils/segment/augmentations.py * Add license line to utils/segment/dataloaders.py * Add license line to utils/segment/general.py * Add license line to utils/segment/loss.py * Add license line to utils/segment/metrics.py * Add license line to utils/segment/plots.py * Add license line to utils/torch_utils.py * Add license line to utils/triton.py * Add license line to val.py * Auto-format by https://ultralytics.com/actions * Update ImageNet1000.yaml Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Auto-format by https://ultralytics.com/actions --------- Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
92 lines
3.7 KiB
Python
92 lines
3.7 KiB
Python
# Ultralytics YOLOv5 🚀, AGPL-3.0 license
|
|
"""Utils to interact with the Triton Inference Server."""
|
|
|
|
import typing
|
|
from urllib.parse import urlparse
|
|
|
|
import torch
|
|
|
|
|
|
class TritonRemoteModel:
    """
    A wrapper over a model served by the Triton Inference Server.

    It can be configured to communicate over GRPC or HTTP. It accepts Torch Tensors as input and returns them as
    outputs.

    Attributes set by ``__init__``:
        client: The tritonclient ``InferenceServerClient`` (GRPC or HTTP flavour, chosen from the URL scheme).
        model_name: Name of the first model listed in the server's model repository index.
        metadata: Metadata of that model; its ``inputs`` and ``outputs`` entries drive request/response handling.
    """

    def __init__(self, url: str):
        """
        Connects to the Triton server and reads the metadata of the first model it lists.

        Keyword arguments:
        url: Fully qualified address of the Triton server - for e.g. grpc://localhost:8001 or
            http://localhost:8000 (Triton's default ports). A ``grpc`` scheme selects the GRPC client, any other
            scheme falls back to the HTTP client.
        """
        parsed_url = urlparse(url)
        if parsed_url.scheme == "grpc":
            from tritonclient.grpc import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton GRPC client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository.models[0].name
            self.metadata = self.client.get_model_metadata(self.model_name, as_json=True)
        else:
            from tritonclient.http import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton HTTP client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository[0]["name"]
            self.metadata = self.client.get_model_metadata(self.model_name)

        # Both clients expose the same InferInput constructor, so one factory (closing over whichever InferInput
        # class was imported above) serves both transports. Fresh placeholders are built for every request.
        def create_input_placeholders() -> typing.List[InferInput]:
            return [
                InferInput(spec["name"], [int(dim) for dim in spec["shape"]], spec["datatype"])
                for spec in self.metadata["inputs"]
            ]

        self._create_input_placeholders_fn = create_input_placeholders

    @property
    def runtime(self):
        """Returns the model runtime: the metadata's ``backend`` entry, else ``platform``, else None."""
        return self.metadata.get("backend", self.metadata.get("platform"))

    def __call__(self, *args, **kwargs) -> typing.Union[torch.Tensor, typing.List[torch.Tensor]]:
        """
        Invokes the model.

        Parameters can be provided via args or kwargs. args, if provided, are assumed to match the order of inputs of
        the model. kwargs are matched with the model input names.

        Returns:
            A single tensor when the model declares exactly one output, otherwise a list of tensors ordered like
            the ``outputs`` entries of the model metadata.
        """
        inputs = self._create_inputs(*args, **kwargs)
        response = self.client.infer(model_name=self.model_name, inputs=inputs)
        result = [torch.as_tensor(response.as_numpy(output["name"])) for output in self.metadata["outputs"]]
        return result[0] if len(result) == 1 else result

    @staticmethod
    def _placeholder_name(placeholder) -> str:
        """Returns the input name of a placeholder, whether ``name`` is exposed as a method or as an attribute."""
        name = placeholder.name
        # tritonclient's InferInput exposes ``name()`` as a method; indexing kwargs with the bound method itself
        # could never match a keyword argument.
        return name() if callable(name) else name

    def _create_inputs(self, *args, **kwargs):
        """
        Creates input tensors from args or kwargs, not both; raises error if none or both are provided.

        Returns:
            The list of input placeholders with their data set from the given tensors.

        Raises:
            RuntimeError: If no inputs are given, if args and kwargs are mixed, or if the number of positional
                inputs differs from the number of model inputs.
            KeyError: If a model input has no matching keyword argument.
        """
        args_len, kwargs_len = len(args), len(kwargs)
        if not args_len and not kwargs_len:
            raise RuntimeError("No inputs provided.")
        if args_len and kwargs_len:
            raise RuntimeError("Cannot specify args and kwargs at the same time")

        placeholders = self._create_input_placeholders_fn()
        if args_len:
            if args_len != len(placeholders):
                raise RuntimeError(f"Expected {len(placeholders)} inputs, got {args_len}.")
            values = args
        else:
            values = [kwargs[self._placeholder_name(placeholder)] for placeholder in placeholders]

        for placeholder, value in zip(placeholders, values):
            # detach() so tensors that still track gradients can be converted to numpy.
            placeholder.set_data_from_numpy(value.detach().cpu().numpy())
        return placeholders
|