Merge pull request #1987 from TingquanGao/dev/pulc_whl_deploy
[WIP] feat: support PULC to deploy with whlpull/2007/head
commit
d5173bf1a7
deploy/configs/PULC/traffic_sign
docs
en/inference_deployment
zh_CN/inference_deployment
ppcls/utils
cls_demo
|
@ -1,7 +1,8 @@
|
|||
include LICENSE.txt
|
||||
include README.md
|
||||
include docs/en/whl_en.md
|
||||
recursive-include deploy/python predict_cls.py preprocess.py postprocess.py det_preprocess.py
|
||||
recursive-include deploy/python *.py
|
||||
recursive-include deploy/configs *.yaml
|
||||
recursive-include deploy/utils get_image_list.py config.py logger.py predictor.py
|
||||
|
||||
recursive-include ppcls/ *.py *.txt
|
|
@ -30,6 +30,6 @@ PostProcess:
|
|||
main_indicator: Topk
|
||||
Topk:
|
||||
topk: 5
|
||||
class_id_map_file: "../dataset/traffic_sign/label_name_id.txt"
|
||||
class_id_map_file: "../ppcls/utils/PULC_label_list/traffic_sign_label_list.txt"
|
||||
SavePreLabel:
|
||||
save_dir: ./pre_label/
|
||||
|
|
|
@ -212,14 +212,14 @@ You can save the prediction result(s) as pre-label, only need to use `pre_label_
|
|||
```python
|
||||
from paddleclas import PaddleClas
|
||||
clas = PaddleClas(model_name='ResNet50', save_dir='./output_pre_label/')
|
||||
infer_imgs = 'docs/images/inference_deployment/whl_' # it can be infer_imgs folder path which contains all of images you want to predict.
|
||||
infer_imgs = 'docs/images/' # it can be infer_imgs folder path which contains all of images you want to predict.
|
||||
result=clas.predict(infer_imgs)
|
||||
print(next(result))
|
||||
```
|
||||
|
||||
* CLI
|
||||
```bash
|
||||
paddleclas --model_name='ResNet50' --infer_imgs='docs/images/inference_deployment/whl_' --save_dir='./output_pre_label/'
|
||||
paddleclas --model_name='ResNet50' --infer_imgs='docs/images/' --save_dir='./output_pre_label/'
|
||||
```
|
||||
|
||||
<a name="4.8"></a>
|
||||
|
|
|
@ -18,7 +18,7 @@ PaddleClas 支持 Python Whl 包方式进行预测,目前 Whl 包方式仅支
|
|||
- [4.6 对 `NumPy.ndarray` 格式数据进行预测](#4.6)
|
||||
- [4.7 保存预测结果](#4.7)
|
||||
- [4.8 指定 label name](#4.8)
|
||||
|
||||
|
||||
|
||||
<a name="1"></a>
|
||||
## 1. 安装 paddleclas
|
||||
|
@ -212,14 +212,14 @@ print(next(result))
|
|||
```python
|
||||
from paddleclas import PaddleClas
|
||||
clas = PaddleClas(model_name='ResNet50', save_dir='./output_pre_label/')
|
||||
infer_imgs = 'docs/images/whl/' # it can be infer_imgs folder path which contains all of images you want to predict.
|
||||
infer_imgs = 'docs/images/' # it can be infer_imgs folder path which contains all of images you want to predict.
|
||||
result=clas.predict(infer_imgs)
|
||||
print(next(result))
|
||||
```
|
||||
|
||||
* CLI
|
||||
```bash
|
||||
paddleclas --model_name='ResNet50' --infer_imgs='docs/images/whl/' --save_dir='./output_pre_label/'
|
||||
paddleclas --model_name='ResNet50' --infer_imgs='docs/images/' --save_dir='./output_pre_label/'
|
||||
```
|
||||
|
||||
<a name="4.8"></a>
|
||||
|
|
329
paddleclas.py
329
paddleclas.py
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -24,7 +24,6 @@ import shutil
|
|||
import textwrap
|
||||
import tarfile
|
||||
import requests
|
||||
import warnings
|
||||
from functools import partial
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
|
@ -32,24 +31,25 @@ import cv2
|
|||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from prettytable import PrettyTable
|
||||
import paddle
|
||||
|
||||
from deploy.python.predict_cls import ClsPredictor
|
||||
from deploy.utils.get_image_list import get_image_list
|
||||
from deploy.utils import config
|
||||
|
||||
from ppcls.arch.backbone import *
|
||||
from ppcls.utils.logger import init_logger
|
||||
import ppcls.arch.backbone as backbone
|
||||
from ppcls.utils import logger
|
||||
|
||||
# for building model with loading pretrained weights from backbone
|
||||
init_logger()
|
||||
logger.init_logger()
|
||||
|
||||
__all__ = ["PaddleClas"]
|
||||
|
||||
BASE_DIR = os.path.expanduser("~/.paddleclas/")
|
||||
BASE_INFERENCE_MODEL_DIR = os.path.join(BASE_DIR, "inference_model")
|
||||
BASE_IMAGES_DIR = os.path.join(BASE_DIR, "images")
|
||||
BASE_DOWNLOAD_URL = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/{}_infer.tar"
|
||||
MODEL_SERIES = {
|
||||
IMN_MODEL_BASE_DOWNLOAD_URL = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/{}_infer.tar"
|
||||
IMN_MODEL_SERIES = {
|
||||
"AlexNet": ["AlexNet"],
|
||||
"DarkNet": ["DarkNet53"],
|
||||
"DeiT": [
|
||||
|
@ -100,10 +100,17 @@ MODEL_SERIES = {
|
|||
"MobileNetV3_large_x1_0", "MobileNetV3_large_x1_25",
|
||||
"MobileNetV3_small_x1_0_ssld", "MobileNetV3_large_x1_0_ssld"
|
||||
],
|
||||
"PPHGNet": [
|
||||
"PPHGNet_tiny",
|
||||
"PPHGNet_small",
|
||||
"PPHGNet_tiny_ssld",
|
||||
"PPHGNet_small_ssld",
|
||||
],
|
||||
"PPLCNet": [
|
||||
"PPLCNet_x0_25", "PPLCNet_x0_35", "PPLCNet_x0_5", "PPLCNet_x0_75",
|
||||
"PPLCNet_x1_0", "PPLCNet_x1_5", "PPLCNet_x2_0", "PPLCNet_x2_5"
|
||||
],
|
||||
"PPLCNetV2": ["PPLCNetV2_base"],
|
||||
"RedNet": ["RedNet26", "RedNet38", "RedNet50", "RedNet101", "RedNet152"],
|
||||
"RegNet": ["RegNetX_4GF"],
|
||||
"Res2Net": [
|
||||
|
@ -168,6 +175,13 @@ MODEL_SERIES = {
|
|||
]
|
||||
}
|
||||
|
||||
PULC_MODEL_BASE_DOWNLOAD_URL = "https://paddleclas.bj.bcebos.com/models/PULC/{}_infer.tar"
|
||||
PULC_MODELS = [
|
||||
"person_exists", "person_attribute", "safety_helmet", "traffic_sign",
|
||||
"vehicle_exists", "vehicle_attr", "textline_orientation",
|
||||
"text_image_orientation", "language_classification"
|
||||
]
|
||||
|
||||
|
||||
class ImageTypeError(Exception):
|
||||
"""ImageTypeError.
|
||||
|
@ -185,76 +199,67 @@ class InputModelError(Exception):
|
|||
super().__init__(message)
|
||||
|
||||
|
||||
def init_config(model_name,
|
||||
inference_model_dir,
|
||||
use_gpu=True,
|
||||
batch_size=1,
|
||||
topk=5,
|
||||
**kwargs):
|
||||
imagenet1k_map_path = os.path.join(
|
||||
os.path.abspath(__dir__), "ppcls/utils/imagenet1k_label_list.txt")
|
||||
cfg = {
|
||||
"Global": {
|
||||
"infer_imgs": kwargs["infer_imgs"]
|
||||
if "infer_imgs" in kwargs else False,
|
||||
"model_name": model_name,
|
||||
"inference_model_dir": inference_model_dir,
|
||||
"batch_size": batch_size,
|
||||
"use_gpu": use_gpu,
|
||||
"enable_mkldnn": kwargs["enable_mkldnn"]
|
||||
if "enable_mkldnn" in kwargs else False,
|
||||
"cpu_num_threads": kwargs["cpu_num_threads"]
|
||||
if "cpu_num_threads" in kwargs else 1,
|
||||
"enable_benchmark": False,
|
||||
"use_fp16": kwargs["use_fp16"] if "use_fp16" in kwargs else False,
|
||||
"ir_optim": True,
|
||||
"use_tensorrt": kwargs["use_tensorrt"]
|
||||
if "use_tensorrt" in kwargs else False,
|
||||
"gpu_mem": kwargs["gpu_mem"] if "gpu_mem" in kwargs else 8000,
|
||||
"enable_profile": False
|
||||
},
|
||||
"PreProcess": {
|
||||
"transform_ops": [{
|
||||
"ResizeImage": {
|
||||
"resize_short": kwargs["resize_short"]
|
||||
if "resize_short" in kwargs else 256
|
||||
}
|
||||
}, {
|
||||
"CropImage": {
|
||||
"size": kwargs["crop_size"]
|
||||
if "crop_size" in kwargs else 224
|
||||
}
|
||||
}, {
|
||||
"NormalizeImage": {
|
||||
"scale": 0.00392157,
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
"order": ''
|
||||
}
|
||||
}, {
|
||||
"ToCHWImage": None
|
||||
}]
|
||||
},
|
||||
"PostProcess": {
|
||||
"main_indicator": "Topk",
|
||||
"Topk": {
|
||||
"topk": topk,
|
||||
"class_id_map_file": imagenet1k_map_path
|
||||
}
|
||||
}
|
||||
}
|
||||
if "save_dir" in kwargs:
|
||||
if kwargs["save_dir"] is not None:
|
||||
cfg["PostProcess"]["SavePreLabel"] = {
|
||||
"save_dir": kwargs["save_dir"]
|
||||
}
|
||||
if "class_id_map_file" in kwargs:
|
||||
if kwargs["class_id_map_file"] is not None:
|
||||
cfg["PostProcess"]["Topk"]["class_id_map_file"] = kwargs[
|
||||
"class_id_map_file"]
|
||||
def init_config(model_type, model_name, inference_model_dir, **kwargs):
|
||||
|
||||
cfg_path = f"deploy/configs/PULC/{model_name}/inference_{model_name}.yaml" if model_type == "pulc" else "deploy/configs/inference_cls.yaml"
|
||||
cfg_path = os.path.join(__dir__, cfg_path)
|
||||
cfg = config.get_config(cfg_path, show=False)
|
||||
|
||||
cfg.Global.inference_model_dir = inference_model_dir
|
||||
|
||||
if "batch_size" in kwargs and kwargs["batch_size"]:
|
||||
cfg.Global.batch_size = kwargs["batch_size"]
|
||||
|
||||
if "use_gpu" in kwargs and kwargs["use_gpu"]:
|
||||
cfg.Global.use_gpu = kwargs["use_gpu"]
|
||||
if cfg.Global.use_gpu and not paddle.device.is_compiled_with_cuda():
|
||||
msg = "The current running environment does not support the use of GPU. CPU has been used instead."
|
||||
logger.warning(msg)
|
||||
cfg.Global.use_gpu = False
|
||||
|
||||
if "infer_imgs" in kwargs and kwargs["infer_imgs"]:
|
||||
cfg.Global.infer_imgs = kwargs["infer_imgs"]
|
||||
if "enable_mkldnn" in kwargs and kwargs["enable_mkldnn"]:
|
||||
cfg.Global.enable_mkldnn = kwargs["enable_mkldnn"]
|
||||
if "cpu_num_threads" in kwargs and kwargs["cpu_num_threads"]:
|
||||
cfg.Global.cpu_num_threads = kwargs["cpu_num_threads"]
|
||||
if "use_fp16" in kwargs and kwargs["use_fp16"]:
|
||||
cfg.Global.use_fp16 = kwargs["use_fp16"]
|
||||
if "use_tensorrt" in kwargs and kwargs["use_tensorrt"]:
|
||||
cfg.Global.use_tensorrt = kwargs["use_tensorrt"]
|
||||
if "gpu_mem" in kwargs and kwargs["gpu_mem"]:
|
||||
cfg.Global.gpu_mem = kwargs["gpu_mem"]
|
||||
if "resize_short" in kwargs and kwargs["resize_short"]:
|
||||
cfg.PreProcess.transform_ops[0]["ResizeImage"][
|
||||
"resize_short"] = kwargs["resize_short"]
|
||||
if "crop_size" in kwargs and kwargs["crop_size"]:
|
||||
cfg.PreProcess.transform_ops[1]["CropImage"]["size"] = kwargs[
|
||||
"crop_size"]
|
||||
|
||||
# TODO(gaotingquan): not robust
|
||||
if "thresh" in kwargs and kwargs[
|
||||
"thresh"] and "ThreshOutput" in cfg.PostProcess:
|
||||
cfg.PostProcess.ThreshOutput.thresh = kwargs["thresh"]
|
||||
if "Topk" in cfg.PostProcess:
|
||||
if "topk" in kwargs and kwargs["topk"]:
|
||||
cfg.PostProcess.Topk.topk = kwargs["topk"]
|
||||
if "class_id_map_file" in kwargs and kwargs["class_id_map_file"]:
|
||||
cfg.PostProcess.Topk.class_id_map_file = kwargs[
|
||||
"class_id_map_file"]
|
||||
else:
|
||||
cfg.PostProcess.Topk.class_id_map_file = os.path.relpath(
|
||||
cfg.PostProcess.Topk.class_id_map_file, "../")
|
||||
if "VehicleAttribute" in cfg.PostProcess:
|
||||
if "color_threshold" in kwargs and kwargs["color_threshold"]:
|
||||
cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
|
||||
"color_threshold"]
|
||||
if "type_threshold" in kwargs and kwargs["type_threshold"]:
|
||||
cfg.PostProcess.VehicleAttribute.type_threshold = kwargs[
|
||||
"type_threshold"]
|
||||
|
||||
if "save_dir" in kwargs and kwargs["save_dir"]:
|
||||
cfg.PostProcess.SavePreLabel.save_dir = kwargs["save_dir"]
|
||||
|
||||
cfg = config.AttrDict(cfg)
|
||||
config.create_attr_dict(cfg)
|
||||
return cfg
|
||||
|
||||
|
||||
|
@ -275,40 +280,48 @@ def args_cfg():
|
|||
type=str,
|
||||
help="The directory of model files. Valid when model_name not specifed."
|
||||
)
|
||||
parser.add_argument("--use_gpu", type=str2bool, help="Whether use GPU.")
|
||||
parser.add_argument(
|
||||
"--use_gpu", type=str, default=True, help="Whether use GPU.")
|
||||
parser.add_argument("--gpu_mem", type=int, default=8000, help="")
|
||||
"--gpu_mem",
|
||||
type=int,
|
||||
help="The memory size of GPU allocated to predict.")
|
||||
parser.add_argument(
|
||||
"--enable_mkldnn",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether use MKLDNN. Valid when use_gpu is False")
|
||||
parser.add_argument("--cpu_num_threads", type=int, default=1, help="")
|
||||
parser.add_argument(
|
||||
"--use_tensorrt", type=str2bool, default=False, help="")
|
||||
parser.add_argument("--use_fp16", type=str2bool, default=False, help="")
|
||||
"--cpu_num_threads",
|
||||
type=int,
|
||||
help="The threads number when predicting on CPU.")
|
||||
parser.add_argument(
|
||||
"--batch_size", type=int, default=1, help="Batch size. Default by 1.")
|
||||
"--use_tensorrt",
|
||||
type=str2bool,
|
||||
help="Whether use TensorRT to accelerate. ")
|
||||
parser.add_argument(
|
||||
"--use_fp16", type=str2bool, help="Whether use FP16 to predict.")
|
||||
parser.add_argument("--batch_size", type=int, help="Batch size.")
|
||||
parser.add_argument(
|
||||
"--topk",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Return topk score(s) and corresponding results. Default by 5.")
|
||||
help="Return topk score(s) and corresponding results when Topk postprocess is used."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--class_id_map_file",
|
||||
type=str,
|
||||
help="The path of file that map class_id and label.")
|
||||
parser.add_argument(
|
||||
"--threshold",
|
||||
type=float,
|
||||
help="The threshold of ThreshOutput when postprocess is used.")
|
||||
parser.add_argument("--color_threshold", type=float, help="")
|
||||
parser.add_argument("--type_threshold", type=float, help="")
|
||||
parser.add_argument(
|
||||
"--save_dir",
|
||||
type=str,
|
||||
help="The directory to save prediction results as pre-label.")
|
||||
parser.add_argument(
|
||||
"--resize_short",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Resize according to short size.")
|
||||
parser.add_argument(
|
||||
"--crop_size", type=int, default=224, help="Centor crop size.")
|
||||
"--resize_short", type=int, help="Resize according to short size.")
|
||||
parser.add_argument("--crop_size", type=int, help="Centor crop size.")
|
||||
|
||||
args = parser.parse_args()
|
||||
return vars(args)
|
||||
|
@ -317,33 +330,44 @@ def args_cfg():
|
|||
def print_info():
|
||||
"""Print list of supported models in formatted.
|
||||
"""
|
||||
table = PrettyTable(["Series", "Name"])
|
||||
imn_table = PrettyTable(["IMN Model Series", "Model Name"])
|
||||
pulc_table = PrettyTable(["PULC Models"])
|
||||
try:
|
||||
sz = os.get_terminal_size()
|
||||
width = sz.columns - 30 if sz.columns > 50 else 10
|
||||
total_width = sz.columns
|
||||
first_width = 30
|
||||
second_width = total_width - first_width if total_width > 50 else 10
|
||||
except OSError:
|
||||
width = 100
|
||||
for series in MODEL_SERIES:
|
||||
names = textwrap.fill(" ".join(MODEL_SERIES[series]), width=width)
|
||||
table.add_row([series, names])
|
||||
width = len(str(table).split("\n")[0])
|
||||
print("{}".format("-" * width))
|
||||
print("Models supported by PaddleClas".center(width))
|
||||
print(table)
|
||||
print("Powered by PaddlePaddle!".rjust(width))
|
||||
print("{}".format("-" * width))
|
||||
second_width = 100
|
||||
for series in IMN_MODEL_SERIES:
|
||||
names = textwrap.fill(
|
||||
" ".join(IMN_MODEL_SERIES[series]), width=second_width)
|
||||
imn_table.add_row([series, names])
|
||||
|
||||
table_width = len(str(imn_table).split("\n")[0])
|
||||
pulc_table.add_row([
|
||||
textwrap.fill(
|
||||
" ".join(PULC_MODELS), width=total_width).center(table_width - 4)
|
||||
])
|
||||
|
||||
print("{}".format("-" * table_width))
|
||||
print("Models supported by PaddleClas".center(table_width))
|
||||
print(imn_table)
|
||||
print(pulc_table)
|
||||
print("Powered by PaddlePaddle!".rjust(table_width))
|
||||
print("{}".format("-" * table_width))
|
||||
|
||||
|
||||
def get_model_names():
|
||||
def get_imn_model_names():
|
||||
"""Get the model names list.
|
||||
"""
|
||||
model_names = []
|
||||
for series in MODEL_SERIES:
|
||||
model_names += (MODEL_SERIES[series])
|
||||
for series in IMN_MODEL_SERIES:
|
||||
model_names += (IMN_MODEL_SERIES[series])
|
||||
return model_names
|
||||
|
||||
|
||||
def similar_architectures(name="", names=[], thresh=0.1, topk=10):
|
||||
def similar_model_names(name="", names=[], thresh=0.1, topk=5):
|
||||
"""Find the most similar topk model names.
|
||||
"""
|
||||
scores = []
|
||||
|
@ -378,12 +402,17 @@ def download_with_progressbar(url, save_path):
|
|||
f"Something went wrong while downloading file from {url}")
|
||||
|
||||
|
||||
def check_model_file(model_name):
|
||||
def check_model_file(model_type, model_name):
|
||||
"""Check the model files exist and download and untar when no exist.
|
||||
"""
|
||||
storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
|
||||
model_name)
|
||||
url = BASE_DOWNLOAD_URL.format(model_name)
|
||||
if model_type == "pulc":
|
||||
storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
|
||||
"PULC", model_name)
|
||||
url = PULC_MODEL_BASE_DOWNLOAD_URL.format(model_name)
|
||||
else:
|
||||
storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
|
||||
"IMN", model_name)
|
||||
url = IMN_MODEL_BASE_DOWNLOAD_URL.format(model_name)
|
||||
|
||||
tar_file_name_list = [
|
||||
"inference.pdiparams", "inference.pdiparams.info", "inference.pdmodel"
|
||||
|
@ -393,7 +422,7 @@ def check_model_file(model_name):
|
|||
if not os.path.exists(model_file_path) or not os.path.exists(
|
||||
params_file_path):
|
||||
tmp_path = storage_directory(url.split("/")[-1])
|
||||
print(f"download {url} to {tmp_path}")
|
||||
logger.info(f"download {url} to {tmp_path}")
|
||||
os.makedirs(storage_directory(), exist_ok=True)
|
||||
download_with_progressbar(url, tmp_path)
|
||||
with tarfile.open(tmp_path, "r") as tarObj:
|
||||
|
@ -426,9 +455,6 @@ class PaddleClas(object):
|
|||
def __init__(self,
|
||||
model_name: str=None,
|
||||
inference_model_dir: str=None,
|
||||
use_gpu: bool=True,
|
||||
batch_size: int=1,
|
||||
topk: int=5,
|
||||
**kwargs):
|
||||
"""Init PaddleClas with config.
|
||||
|
||||
|
@ -440,9 +466,11 @@ class PaddleClas(object):
|
|||
topk (int, optional): Return the top k prediction results with the highest score. Defaults to 5.
|
||||
"""
|
||||
super().__init__()
|
||||
self._config = init_config(model_name, inference_model_dir, use_gpu,
|
||||
batch_size, topk, **kwargs)
|
||||
self._check_input_model()
|
||||
self.model_type, inference_model_dir = self._check_input_model(
|
||||
model_name, inference_model_dir)
|
||||
self._config = init_config(self.model_type, model_name,
|
||||
inference_model_dir, **kwargs)
|
||||
|
||||
self.cls_predictor = ClsPredictor(self._config)
|
||||
|
||||
def get_config(self):
|
||||
|
@ -450,24 +478,29 @@ class PaddleClas(object):
|
|||
"""
|
||||
return self._config
|
||||
|
||||
def _check_input_model(self):
|
||||
def _check_input_model(self, model_name, inference_model_dir):
|
||||
"""Check input model name or model files.
|
||||
"""
|
||||
candidate_model_names = get_model_names()
|
||||
input_model_name = self._config.Global.get("model_name", None)
|
||||
inference_model_dir = self._config.Global.get("inference_model_dir",
|
||||
None)
|
||||
if input_model_name is not None:
|
||||
similar_names = similar_architectures(input_model_name,
|
||||
candidate_model_names)
|
||||
similar_names_str = ", ".join(similar_names)
|
||||
if input_model_name not in candidate_model_names:
|
||||
err = f"{input_model_name} is not provided by PaddleClas. \nMaybe you want: [{similar_names_str}]. \nIf you want to use your own model, please specify inference_model_dir!"
|
||||
all_imn_model_names = get_imn_model_names()
|
||||
all_pulc_model_names = PULC_MODELS
|
||||
|
||||
if model_name:
|
||||
if model_name in all_imn_model_names:
|
||||
inference_model_dir = check_model_file("imn", model_name)
|
||||
return "imn", inference_model_dir
|
||||
elif model_name in all_pulc_model_names:
|
||||
inference_model_dir = check_model_file("pulc", model_name)
|
||||
return "pulc", inference_model_dir
|
||||
else:
|
||||
similar_imn_names = similar_model_names(model_name,
|
||||
all_imn_model_names)
|
||||
similar_pulc_names = similar_model_names(model_name,
|
||||
all_pulc_model_names)
|
||||
similar_names_str = ", ".join(similar_imn_names +
|
||||
similar_pulc_names)
|
||||
err = f"{model_name} is not provided by PaddleClas. \nMaybe you want the : [{similar_names_str}]. \nIf you want to use your own model, please specify inference_model_dir!"
|
||||
raise InputModelError(err)
|
||||
self._config.Global.inference_model_dir = check_model_file(
|
||||
input_model_name)
|
||||
return
|
||||
elif inference_model_dir is not None:
|
||||
elif inference_model_dir:
|
||||
model_file_path = os.path.join(inference_model_dir,
|
||||
"inference.pdmodel")
|
||||
params_file_path = os.path.join(inference_model_dir,
|
||||
|
@ -476,11 +509,11 @@ class PaddleClas(object):
|
|||
params_file_path):
|
||||
err = f"There is no model file or params file in this directory: {inference_model_dir}"
|
||||
raise InputModelError(err)
|
||||
return
|
||||
return "custom", inference_model_dir
|
||||
else:
|
||||
err = f"Please specify the model name supported by PaddleClas or directory contained model files(inference.pdmodel, inference.pdiparams)."
|
||||
raise InputModelError(err)
|
||||
return
|
||||
return None
|
||||
|
||||
def predict(self, input_data: Union[str, np.array],
|
||||
print_pred: bool=False) -> Generator[list, None, None]:
|
||||
|
@ -511,22 +544,21 @@ class PaddleClas(object):
|
|||
os.makedirs(image_storage_dir())
|
||||
image_save_path = image_storage_dir("tmp.jpg")
|
||||
download_with_progressbar(input_data, image_save_path)
|
||||
input_data = image_save_path
|
||||
warnings.warn(
|
||||
logger.info(
|
||||
f"Image to be predicted from Internet: {input_data}, has been saved to: {image_save_path}"
|
||||
)
|
||||
input_data = image_save_path
|
||||
image_list = get_image_list(input_data)
|
||||
|
||||
batch_size = self._config.Global.get("batch_size", 1)
|
||||
topk = self._config.PostProcess.Topk.get('topk', 1)
|
||||
|
||||
img_list = []
|
||||
img_path_list = []
|
||||
cnt = 0
|
||||
for idx, img_path in enumerate(image_list):
|
||||
for idx_img, img_path in enumerate(image_list):
|
||||
img = cv2.imread(img_path)
|
||||
if img is None:
|
||||
warnings.warn(
|
||||
logger.warning(
|
||||
f"Image file failed to read and has been skipped. The path: {img_path}"
|
||||
)
|
||||
continue
|
||||
|
@ -535,16 +567,15 @@ class PaddleClas(object):
|
|||
img_path_list.append(img_path)
|
||||
cnt += 1
|
||||
|
||||
if cnt % batch_size == 0 or (idx + 1) == len(image_list):
|
||||
if cnt % batch_size == 0 or (idx_img + 1) == len(image_list):
|
||||
preds = self.cls_predictor.predict(img_list)
|
||||
|
||||
if print_pred and preds:
|
||||
for idx, pred in enumerate(preds):
|
||||
pred_str = ", ".join(
|
||||
[f"{k}: {pred[k]}" for k in pred])
|
||||
print(
|
||||
f"filename: {img_path_list[idx]}, top-{topk}, {pred_str}"
|
||||
)
|
||||
if preds:
|
||||
for idx_pred, pred in enumerate(preds):
|
||||
pred["filename"] = img_path_list[idx_pred]
|
||||
if print_pred:
|
||||
logger.info(", ".join(
|
||||
[f"{k}: {pred[k]}" for k in pred]))
|
||||
|
||||
img_list = []
|
||||
img_path_list = []
|
||||
|
@ -564,7 +595,7 @@ def main():
|
|||
res = clas_engine.predict(cfg["infer_imgs"], print_pred=True)
|
||||
for _ in res:
|
||||
pass
|
||||
print("Predict complete!")
|
||||
logger.info("Predict complete!")
|
||||
return
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
0 0
|
||||
1 90
|
||||
2 180
|
||||
3 270
|
|
@ -0,0 +1,232 @@
|
|||
0 pl80
|
||||
1 w9
|
||||
2 p6
|
||||
3 ph4.2
|
||||
4 i8
|
||||
5 w14
|
||||
6 w33
|
||||
7 pa13
|
||||
8 im
|
||||
9 w58
|
||||
10 pl90
|
||||
11 il70
|
||||
12 p5
|
||||
13 pm55
|
||||
14 pl60
|
||||
15 ip
|
||||
16 p11
|
||||
17 pdd
|
||||
18 wc
|
||||
19 i2r
|
||||
20 w30
|
||||
21 pmr
|
||||
22 p23
|
||||
23 pl15
|
||||
24 pm10
|
||||
25 pss
|
||||
26 w1
|
||||
27 p4
|
||||
28 w38
|
||||
29 w50
|
||||
30 w34
|
||||
31 pw3.5
|
||||
32 iz
|
||||
33 w39
|
||||
34 w11
|
||||
35 p1n
|
||||
36 pr70
|
||||
37 pd
|
||||
38 pnl
|
||||
39 pg
|
||||
40 ph5.3
|
||||
41 w66
|
||||
42 il80
|
||||
43 pb
|
||||
44 pbm
|
||||
45 pm5
|
||||
46 w24
|
||||
47 w67
|
||||
48 w49
|
||||
49 pm40
|
||||
50 ph4
|
||||
51 w45
|
||||
52 i4
|
||||
53 w37
|
||||
54 ph2.6
|
||||
55 pl70
|
||||
56 ph5.5
|
||||
57 i14
|
||||
58 i11
|
||||
59 p7
|
||||
60 p29
|
||||
61 pne
|
||||
62 pr60
|
||||
63 pm13
|
||||
64 ph4.5
|
||||
65 p12
|
||||
66 p3
|
||||
67 w40
|
||||
68 pl5
|
||||
69 w13
|
||||
70 pr10
|
||||
71 p14
|
||||
72 i4l
|
||||
73 pr30
|
||||
74 pw4.2
|
||||
75 w16
|
||||
76 p17
|
||||
77 ph3
|
||||
78 i9
|
||||
79 w15
|
||||
80 w35
|
||||
81 pa8
|
||||
82 pt
|
||||
83 pr45
|
||||
84 w17
|
||||
85 pl30
|
||||
86 pcs
|
||||
87 pctl
|
||||
88 pr50
|
||||
89 ph4.4
|
||||
90 pm46
|
||||
91 pm35
|
||||
92 i15
|
||||
93 pa12
|
||||
94 pclr
|
||||
95 i1
|
||||
96 pcd
|
||||
97 pbp
|
||||
98 pcr
|
||||
99 w28
|
||||
100 ps
|
||||
101 pm8
|
||||
102 w18
|
||||
103 w2
|
||||
104 w52
|
||||
105 ph2.9
|
||||
106 ph1.8
|
||||
107 pe
|
||||
108 p20
|
||||
109 w36
|
||||
110 p10
|
||||
111 pn
|
||||
112 pa14
|
||||
113 w54
|
||||
114 ph3.2
|
||||
115 p2
|
||||
116 ph2.5
|
||||
117 w62
|
||||
118 w55
|
||||
119 pw3
|
||||
120 pw4.5
|
||||
121 i12
|
||||
122 ph4.3
|
||||
123 phclr
|
||||
124 i10
|
||||
125 pr5
|
||||
126 i13
|
||||
127 w10
|
||||
128 p26
|
||||
129 w26
|
||||
130 p8
|
||||
131 w5
|
||||
132 w42
|
||||
133 il50
|
||||
134 p13
|
||||
135 pr40
|
||||
136 p25
|
||||
137 w41
|
||||
138 pl20
|
||||
139 ph4.8
|
||||
140 pnlc
|
||||
141 ph3.3
|
||||
142 w29
|
||||
143 ph2.1
|
||||
144 w53
|
||||
145 pm30
|
||||
146 p24
|
||||
147 p21
|
||||
148 pl40
|
||||
149 w27
|
||||
150 pmb
|
||||
151 pc
|
||||
152 i6
|
||||
153 pr20
|
||||
154 p18
|
||||
155 ph3.8
|
||||
156 pm50
|
||||
157 pm25
|
||||
158 i2
|
||||
159 w22
|
||||
160 w47
|
||||
161 w56
|
||||
162 pl120
|
||||
163 ph2.8
|
||||
164 i7
|
||||
165 w12
|
||||
166 pm1.5
|
||||
167 pm2.5
|
||||
168 w32
|
||||
169 pm15
|
||||
170 ph5
|
||||
171 w19
|
||||
172 pw3.2
|
||||
173 pw2.5
|
||||
174 pl10
|
||||
175 il60
|
||||
176 w57
|
||||
177 w48
|
||||
178 w60
|
||||
179 pl100
|
||||
180 pr80
|
||||
181 p16
|
||||
182 pl110
|
||||
183 w59
|
||||
184 w64
|
||||
185 w20
|
||||
186 ph2
|
||||
187 p9
|
||||
188 il100
|
||||
189 w31
|
||||
190 w65
|
||||
191 ph2.4
|
||||
192 pr100
|
||||
193 p19
|
||||
194 ph3.5
|
||||
195 pa10
|
||||
196 pcl
|
||||
197 pl35
|
||||
198 p15
|
||||
199 w7
|
||||
200 pa6
|
||||
201 phcs
|
||||
202 w43
|
||||
203 p28
|
||||
204 w6
|
||||
205 w3
|
||||
206 w25
|
||||
207 pl25
|
||||
208 il110
|
||||
209 p1
|
||||
210 w46
|
||||
211 pn-2
|
||||
212 w51
|
||||
213 w44
|
||||
214 w63
|
||||
215 w23
|
||||
216 pm20
|
||||
217 w8
|
||||
218 pmblr
|
||||
219 w4
|
||||
220 i5
|
||||
221 il90
|
||||
222 w21
|
||||
223 p27
|
||||
224 pl50
|
||||
225 pl65
|
||||
226 w61
|
||||
227 ph2.2
|
||||
228 pm2
|
||||
229 i3
|
||||
230 pa18
|
||||
231 pw4
|
|
@ -1,2 +0,0 @@
|
|||
0 nobody
|
||||
1 someone
|
Loading…
Reference in New Issue