Require `paddlepaddle>=3.0.0` with new `*.pdiparams`

pull/13552/head
Glenn Jocher 2025-03-27 15:08:00 +01:00
parent 324bcfd6d7
commit 1ceffbd8c3
2 changed files with 19 additions and 7 deletions
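Context (inferred from the diffs below rather than stated in the commit message): before 3.0, an exported Paddle model directory held a `*.pdmodel` program file alongside `*.pdiparams` weights, while PaddlePaddle >= 3.0 writes the program description as a `*.json` file instead, so both the export requirement and the `DetectMultiBackend` loader are updated. A minimal hedged sketch of resolving either layout from an export directory (the directory and helper names here are illustrative, not part of the commit):

```python
from pathlib import Path


def resolve_paddle_files(export_dir: str) -> tuple[Path, Path]:
    """Return (program_file, params_file) for an exported Paddle model, old or new layout."""
    d = Path(export_dir)
    params = next(d.rglob("*.pdiparams"))  # weights file, common to both layouts
    # Prefer the PaddlePaddle >= 3.0 *.json program file, fall back to the legacy *.pdmodel
    program = next(d.rglob("*.json"), None) or next(d.rglob("*.pdmodel"))
    return program, params


# e.g. resolve_paddle_files("yolov5s_paddle_model")  # hypothetical export directory name
```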

export.py

@@ -510,7 +510,7 @@ def export_paddle(model, im, file, metadata, prefix=colorstr("PaddlePaddle:")):
         $ pip install paddlepaddle x2paddle
         ```
     """
-    check_requirements(("paddlepaddle<3.0.0", "x2paddle"))
+    check_requirements(("paddlepaddle>=3.0.0", "x2paddle"))
     import x2paddle
     from x2paddle.convert import pytorch2paddle
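With the requirement bumped to `paddlepaddle>=3.0.0`, the exporter still drives the conversion through x2paddle's `pytorch2paddle`. A minimal sketch of that call, assuming the `model`, `im`, and `file` arguments from the `export_paddle` signature in the hunk header above (the output-directory naming is an assumption for illustration, not the repo's exact code):

```python
import torch
from x2paddle.convert import pytorch2paddle


def export_paddle_sketch(model: torch.nn.Module, im: torch.Tensor, file: str) -> str:
    """Trace a PyTorch module and write a PaddlePaddle inference model (sketch only)."""
    f = str(file).replace(".pt", "_paddle_model")  # assumed output directory name
    # Under PaddlePaddle >= 3.0 the saved model is the *.json + *.pdiparams pair
    # that the updated DetectMultiBackend loader below searches for.
    pytorch2paddle(module=model, save_dir=f, jit_type="trace", input_examples=[im])
    return f
```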

models/common.py

@@ -644,20 +644,32 @@ class DetectMultiBackend(nn.Module):
                     stride, names = int(meta["stride"]), meta["names"]
         elif tfjs:  # TF.js
             raise NotImplementedError("ERROR: YOLOv5 TF.js inference is not supported")
-        elif paddle:  # PaddlePaddle
+        # PaddlePaddle
+        elif paddle:
             LOGGER.info(f"Loading {w} for PaddlePaddle inference...")
-            check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle<3.0.0")
+            check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle>=3.0.0")
             import paddle.inference as pdi

-            if not Path(w).is_file():  # if not *.pdmodel
-                w = next(Path(w).rglob("*.pdmodel"))  # get *.pdmodel file from *_paddle_model dir
-            weights = Path(w).with_suffix(".pdiparams")
-            config = pdi.Config(str(w), str(weights))
+            w = Path(w)
+            if w.is_dir():
+                model_file = next(w.rglob("*.json"), None)
+                params_file = next(w.rglob("*.pdiparams"), None)
+            elif w.suffix == ".pdiparams":
+                model_file = w.with_name("model.json")
+                params_file = w
+            else:
+                raise ValueError(f"Invalid model path {w}. Provide model directory or a .pdiparams file.")
+            if not (model_file and params_file and model_file.is_file() and params_file.is_file()):
+                raise FileNotFoundError(f"Model files not found in {w}. Both .json and .pdiparams files are required.")
+            config = pdi.Config(str(model_file), str(params_file))
             if cuda:
                 config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
             predictor = pdi.create_predictor(config)
             input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
             output_names = predictor.get_output_names()
         elif triton:  # NVIDIA Triton Inference Server
             LOGGER.info(f"Using {w} as Triton Inference Server...")
             check_requirements("tritonclient[all]")