Require `paddlepaddle>=3.0.0` with new ``*.pdiparams`
parent
324bcfd6d7
commit
1ceffbd8c3
|
@ -510,7 +510,7 @@ def export_paddle(model, im, file, metadata, prefix=colorstr("PaddlePaddle:")):
|
|||
$ pip install paddlepaddle x2paddle
|
||||
```
|
||||
"""
|
||||
check_requirements(("paddlepaddle<3.0.0", "x2paddle"))
|
||||
check_requirements(("paddlepaddle>=3.0.0", "x2paddle"))
|
||||
import x2paddle
|
||||
from x2paddle.convert import pytorch2paddle
|
||||
|
||||
|
|
|
@ -644,20 +644,32 @@ class DetectMultiBackend(nn.Module):
|
|||
stride, names = int(meta["stride"]), meta["names"]
|
||||
elif tfjs: # TF.js
|
||||
raise NotImplementedError("ERROR: YOLOv5 TF.js inference is not supported")
|
||||
elif paddle: # PaddlePaddle
|
||||
# PaddlePaddle
|
||||
elif paddle:
|
||||
LOGGER.info(f"Loading {w} for PaddlePaddle inference...")
|
||||
check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle<3.0.0")
|
||||
check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle>=3.0.0")
|
||||
import paddle.inference as pdi
|
||||
|
||||
if not Path(w).is_file(): # if not *.pdmodel
|
||||
w = next(Path(w).rglob("*.pdmodel")) # get *.pdmodel file from *_paddle_model dir
|
||||
weights = Path(w).with_suffix(".pdiparams")
|
||||
config = pdi.Config(str(w), str(weights))
|
||||
w = Path(w)
|
||||
if w.is_dir():
|
||||
model_file = next(w.rglob("*.json"), None)
|
||||
params_file = next(w.rglob("*.pdiparams"), None)
|
||||
elif w.suffix == ".pdiparams":
|
||||
model_file = w.with_name("model.json")
|
||||
params_file = w
|
||||
else:
|
||||
raise ValueError(f"Invalid model path {w}. Provide model directory or a .pdiparams file.")
|
||||
|
||||
if not (model_file and params_file and model_file.is_file() and params_file.is_file()):
|
||||
raise FileNotFoundError(f"Model files not found in {w}. Both .json and .pdiparams files are required.")
|
||||
|
||||
config = pdi.Config(str(model_file), str(params_file))
|
||||
if cuda:
|
||||
config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
|
||||
predictor = pdi.create_predictor(config)
|
||||
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
|
||||
output_names = predictor.get_output_names()
|
||||
|
||||
elif triton: # NVIDIA Triton Inference Server
|
||||
LOGGER.info(f"Using {w} as Triton Inference Server...")
|
||||
check_requirements("tritonclient[all]")
|
||||
|
|
Loading…
Reference in New Issue