Merge branch 'develop' into add_table_attribute
commit a6a08acb5b

paddleclas.py (227 changed lines)
@@ -32,6 +32,7 @@ from .ppcls.arch import backbone
 from .ppcls.utils import logger
 from .deploy.python.predict_cls import ClsPredictor
+from .deploy.python.predict_system import SystemPredictor
 from .deploy.utils.get_image_list import get_image_list
 from .deploy.utils import config
@@ -195,6 +196,14 @@ PULC_MODELS = [
     "table_attribute"
 ]
 
+SHITU_MODEL_BASE_DOWNLOAD_URL = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/{}_infer.tar"
+SHITU_MODELS = [
+    # "picodet_PPLCNet_x2_5_mainbody_lite_v1.0",  # ShiTuV1(V2)_mainbody_det
+    # "general_PPLCNet_x2_5_lite_v1.0"  # ShiTuV1_general_rec
+    # "PP-ShiTuV2/general_PPLCNetV2_base_pretrained_v1.0",  # ShiTuV2_general_rec TODO(hesensen): add lite model
+    "PP-ShiTuV2"
+]
+
 
 class ImageTypeError(Exception):
     """ImageTypeError.
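
Note (usage sketch, not part of the diff): SHITU_MODEL_BASE_DOWNLOAD_URL is a plain str.format template keyed by the model name, mirroring the existing PULC download URL; check_model_file below fills it in. A minimal sketch of the expansion, with the model name taken from the comments above:

    # Sketch only: expanding the ShiTu download URL template for one model name.
    SHITU_MODEL_BASE_DOWNLOAD_URL = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/{}_infer.tar"
    url = SHITU_MODEL_BASE_DOWNLOAD_URL.format("picodet_PPLCNet_x2_5_mainbody_lite_v1.0")
    # -> ".../rec/models/inference/picodet_PPLCNet_x2_5_mainbody_lite_v1.0_infer.tar"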
@@ -214,12 +223,24 @@ class InputModelError(Exception):
 
 def init_config(model_type, model_name, inference_model_dir, **kwargs):
-    cfg_path = f"deploy/configs/PULC/{model_name}/inference_{model_name}.yaml" if model_type == "pulc" else "deploy/configs/inference_cls.yaml"
+    if model_type == "pulc":
+        cfg_path = f"deploy/configs/PULC/{model_name}/inference_{model_name}.yaml"
+    elif model_type == "shitu":
+        cfg_path = "deploy/configs/inference_general.yaml"
+    else:
+        cfg_path = "deploy/configs/inference_cls.yaml"
+
     __dir__ = os.path.dirname(__file__)
     cfg_path = os.path.join(__dir__, cfg_path)
     cfg = config.get_config(cfg_path, show=False)
 
-    cfg.Global.inference_model_dir = inference_model_dir
+    if cfg.Global.get("inference_model_dir"):
+        cfg.Global.inference_model_dir = inference_model_dir
+    else:
+        cfg.Global.rec_inference_model_dir = os.path.join(
+            inference_model_dir,
+            "PP-ShiTuV2/general_PPLCNetV2_base_pretrained_v1.0")
+        cfg.Global.det_inference_model_dir = os.path.join(
+            inference_model_dir, "picodet_PPLCNet_x2_5_mainbody_lite_v1.0")
 
     if "batch_size" in kwargs and kwargs["batch_size"]:
         cfg.Global.batch_size = kwargs["batch_size"]
@@ -233,6 +254,10 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs):
 
     if "infer_imgs" in kwargs and kwargs["infer_imgs"]:
         cfg.Global.infer_imgs = kwargs["infer_imgs"]
+    if "index_dir" in kwargs and kwargs["index_dir"]:
+        cfg.IndexProcess.index_dir = kwargs["index_dir"]
+    if "data_file" in kwargs and kwargs["data_file"]:
+        cfg.IndexProcess.data_file = kwargs["data_file"]
     if "enable_mkldnn" in kwargs and kwargs["enable_mkldnn"]:
         cfg.Global.enable_mkldnn = kwargs["enable_mkldnn"]
     if "cpu_num_threads" in kwargs and kwargs["cpu_num_threads"]:
@@ -254,43 +279,45 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs):
     if "thresh" in kwargs and kwargs[
             "thresh"] and "ThreshOutput" in cfg.PostProcess:
         cfg.PostProcess.ThreshOutput.thresh = kwargs["thresh"]
-    if "Topk" in cfg.PostProcess:
-        if "topk" in kwargs and kwargs["topk"]:
-            cfg.PostProcess.Topk.topk = kwargs["topk"]
-        if "class_id_map_file" in kwargs and kwargs["class_id_map_file"]:
-            cfg.PostProcess.Topk.class_id_map_file = kwargs[
-                "class_id_map_file"]
-        else:
-            class_id_map_file_path = os.path.relpath(
-                cfg.PostProcess.Topk.class_id_map_file, "../")
-            cfg.PostProcess.Topk.class_id_map_file = os.path.join(
-                __dir__, class_id_map_file_path)
-    if "VehicleAttribute" in cfg.PostProcess:
-        if "color_threshold" in kwargs and kwargs["color_threshold"]:
-            cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
-                "color_threshold"]
-        if "type_threshold" in kwargs and kwargs["type_threshold"]:
-            cfg.PostProcess.VehicleAttribute.type_threshold = kwargs[
-                "type_threshold"]
-    if "TableAttribute" in cfg.PostProcess:
-        if "source_threshold" in kwargs and kwargs["source_threshold"]:
-            cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
-                "source_threshold"]
-        if "number_threshold" in kwargs and kwargs["number_threshold"]:
-            cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
-                "number_threshold"]
-        if "color_threshold" in kwargs and kwargs["color_threshold"]:
-            cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
-                "color_threshold"]
-        if "clarity_threshold" in kwargs and kwargs["clarity_threshold"]:
-            cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
-                "clarity_threshold"]
-        if "obstruction_threshold" in kwargs and kwargs["obstruction_threshold"]:
-            cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
-                "obstruction_threshold"]
-        if "angle_threshold" in kwargs and kwargs["angle_threshold"]:
-            cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
-                "angle_threshold"]
+    if cfg.get("PostProcess"):
+        if "Topk" in cfg.PostProcess:
+            if "topk" in kwargs and kwargs["topk"]:
+                cfg.PostProcess.Topk.topk = kwargs["topk"]
+            if "class_id_map_file" in kwargs and kwargs["class_id_map_file"]:
+                cfg.PostProcess.Topk.class_id_map_file = kwargs[
+                    "class_id_map_file"]
+            else:
+                class_id_map_file_path = os.path.relpath(
+                    cfg.PostProcess.Topk.class_id_map_file, "../")
+                cfg.PostProcess.Topk.class_id_map_file = os.path.join(
+                    __dir__, class_id_map_file_path)
+        if "VehicleAttribute" in cfg.PostProcess:
+            if "color_threshold" in kwargs and kwargs["color_threshold"]:
+                cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
+                    "color_threshold"]
+            if "type_threshold" in kwargs and kwargs["type_threshold"]:
+                cfg.PostProcess.VehicleAttribute.type_threshold = kwargs[
+                    "type_threshold"]
+        if "TableAttribute" in cfg.PostProcess:
+            if "source_threshold" in kwargs and kwargs["source_threshold"]:
+                cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
+                    "source_threshold"]
+            if "number_threshold" in kwargs and kwargs["number_threshold"]:
+                cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
+                    "number_threshold"]
+            if "color_threshold" in kwargs and kwargs["color_threshold"]:
+                cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
+                    "color_threshold"]
+            if "clarity_threshold" in kwargs and kwargs["clarity_threshold"]:
+                cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
+                    "clarity_threshold"]
+            if "obstruction_threshold" in kwargs and kwargs["obstruction_threshold"]:
+                cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
+                    "obstruction_threshold"]
+            if "angle_threshold" in kwargs and kwargs["angle_threshold"]:
+                cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
+                    "angle_threshold"]
+        if "save_dir" in kwargs and kwargs["save_dir"]:
+            cfg.PostProcess.SavePreLabel.save_dir = kwargs["save_dir"]
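
Note (usage sketch, not part of the diff): these *_threshold keys are ordinary keyword arguments; the PaddleClas constructor forwards **kwargs to init_config (see the __init__ change further down), which is how they reach cfg.PostProcess. A hedged sketch, with threshold values chosen purely for illustration:

    # Sketch only: passing the attribute-threshold kwargs handled above through PaddleClas.
    from paddleclas import PaddleClas

    model = PaddleClas(
        model_name="table_attribute",  # PULC model added on this branch
        source_threshold=0.5,          # illustrative values, not documented defaults
        number_threshold=0.5)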
@@ -314,6 +341,13 @@ def args_cfg():
         type=str,
         help="The directory of model files. Valid when model_name not specifed."
     )
+    parser.add_argument(
+        "--index_dir",
+        type=str,
+        required=False,
+        help="The index directory path.")
+    parser.add_argument(
+        "--data_file", type=str, required=False, help="The label file path.")
     parser.add_argument("--use_gpu", type=str2bool, help="Whether use GPU.")
     parser.add_argument(
         "--gpu_mem",
@@ -366,6 +400,7 @@ def print_info():
     """
     imn_table = PrettyTable(["IMN Model Series", "Model Name"])
     pulc_table = PrettyTable(["PULC Models"])
+    shitu_table = PrettyTable(["PP-ShiTu Models"])
     try:
         sz = os.get_terminal_size()
         total_width = sz.columns
@@ -384,11 +419,16 @@ def print_info():
         textwrap.fill(
             " ".join(PULC_MODELS), width=total_width).center(table_width - 4)
     ])
+    shitu_table.add_row([
+        textwrap.fill(
+            " ".join(SHITU_MODELS), width=total_width).center(table_width - 4)
+    ])
 
     print("{}".format("-" * table_width))
     print("Models supported by PaddleClas".center(table_width))
     print(imn_table)
     print(pulc_table)
+    print(shitu_table)
     print("Powered by PaddlePaddle!".rjust(table_width))
     print("{}".format("-" * table_width))
@@ -444,6 +484,10 @@ def check_model_file(model_type, model_name):
         storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
                                     "PULC", model_name)
         url = PULC_MODEL_BASE_DOWNLOAD_URL.format(model_name)
+    elif model_type == "shitu":
+        storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
+                                    "PP-ShiTu", model_name)
+        url = SHITU_MODEL_BASE_DOWNLOAD_URL.format(model_name)
     else:
         storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
                                     "IMN", model_name)
@@ -504,8 +548,10 @@ class PaddleClas(object):
             model_name, inference_model_dir)
         self._config = init_config(self.model_type, model_name,
                                    inference_model_dir, **kwargs)
-        self.cls_predictor = ClsPredictor(self._config)
+        if self.model_type == "shitu":
+            self.predictor = SystemPredictor(self._config)
+        else:
+            self.predictor = ClsPredictor(self._config)
 
     def get_config(self):
         """Get the config.
@@ -517,6 +563,7 @@ class PaddleClas(object):
         """
         all_imn_model_names = get_imn_model_names()
         all_pulc_model_names = PULC_MODELS
+        all_shitu_model_names = SHITU_MODELS
 
         if model_name:
             if model_name in all_imn_model_names:
@@ -525,6 +572,15 @@ class PaddleClas(object):
             elif model_name in all_pulc_model_names:
                 inference_model_dir = check_model_file("pulc", model_name)
                 return "pulc", inference_model_dir
+            elif model_name in all_shitu_model_names:
+                inference_model_dir = check_model_file(
+                    "shitu",
+                    "PP-ShiTuV2/general_PPLCNetV2_base_pretrained_v1.0")
+                inference_model_dir = check_model_file(
+                    "shitu", "picodet_PPLCNet_x2_5_mainbody_lite_v1.0")
+                inference_model_dir = os.path.abspath(
+                    os.path.dirname(inference_model_dir))
+                return "shitu", inference_model_dir
             else:
                 similar_imn_names = similar_model_names(model_name,
                                                         all_imn_model_names)
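
Note (usage sketch, not part of the diff): the branch above downloads both ShiTu sub-models and then keeps only the parent directory, which is what init_config later joins with the rec/det sub-paths. A small sketch of that resolution, assuming check_model_file returns the per-model storage directory and using an illustrative base path:

    import os

    # Sketch only: reducing the detector model path to the shared PP-ShiTu parent directory.
    det_dir = os.path.join("./inference_model/PP-ShiTu",  # assumed base layout, for illustration
                           "picodet_PPLCNet_x2_5_mainbody_lite_v1.0")
    shitu_root = os.path.abspath(os.path.dirname(det_dir))  # -> .../inference_model/PP-ShiTu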
@@ -545,12 +601,13 @@ class PaddleClas(object):
                 raise InputModelError(err)
             return "custom", inference_model_dir
         else:
-            err = f"Please specify the model name supported by PaddleClas or directory contained model files(inference.pdmodel, inference.pdiparams)."
+            err = "Please specify the model name supported by PaddleClas or directory contained model files(inference.pdmodel, inference.pdiparams)."
             raise InputModelError(err)
         return None
 
-    def predict(self, input_data: Union[str, np.array],
-                print_pred: bool=False) -> Generator[list, None, None]:
+    def predict_cls(self,
+                    input_data: Union[str, np.array],
+                    print_pred: bool=False) -> Generator[list, None, None]:
         """Predict input_data.
 
         Args:
@@ -570,7 +627,7 @@ class PaddleClas(object):
         """
 
         if isinstance(input_data, np.ndarray):
-            yield self.cls_predictor.predict(input_data)
+            yield self.predictor.predict(input_data)
         elif isinstance(input_data, str):
             if input_data.startswith("http") or input_data.startswith("https"):
                 image_storage_dir = partial(os.path.join, BASE_IMAGES_DIR)
@@ -602,7 +659,7 @@ class PaddleClas(object):
                 cnt += 1
 
                 if cnt % batch_size == 0 or (idx_img + 1) == len(image_list):
-                    preds = self.cls_predictor.predict(img_list)
+                    preds = self.predictor.predict(img_list)
 
                     if preds:
                         for idx_pred, pred in enumerate(preds):
@@ -619,6 +676,77 @@ class PaddleClas(object):
             raise ImageTypeError(err)
         return
 
+    def predict_shitu(self,
+                      input_data: Union[str, np.array],
+                      print_pred: bool=False) -> Generator[list, None, None]:
+        """Predict input_data.
+        Args:
+            input_data (Union[str, np.array]):
+                When the type is str, it is the path of image, or the directory containing images, or the URL of image from Internet.
+                When the type is np.array, it is the image data whose channel order is RGB.
+            print_pred (bool, optional): Whether print the prediction result. Defaults to False.
+
+        Raises:
+            ImageTypeError: Illegal input_data.
+
+        Yields:
+            Generator[list, None, None]:
+                The prediction result(s) of input_data by batch_size. For every one image,
+                prediction result(s) is zipped as a dict, that includs topk "class_ids", "scores" and "label_names".
+                The format of batch prediction result(s) is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...]
+        """
+        if isinstance(input_data, np.ndarray):
+            yield self.predictor.predict(input_data)
+        elif isinstance(input_data, str):
+            if input_data.startswith("http") or input_data.startswith("https"):
+                image_storage_dir = partial(os.path.join, BASE_IMAGES_DIR)
+                if not os.path.exists(image_storage_dir()):
+                    os.makedirs(image_storage_dir())
+                image_save_path = image_storage_dir("tmp.jpg")
+                download_with_progressbar(input_data, image_save_path)
+                logger.info(
+                    f"Image to be predicted from Internet: {input_data}, has been saved to: {image_save_path}"
+                )
+                input_data = image_save_path
+            image_list = get_image_list(input_data)
+
+            cnt = 0
+            for idx_img, img_path in enumerate(image_list):
+                img = cv2.imread(img_path)
+                if img is None:
+                    logger.warning(
+                        f"Image file failed to read and has been skipped. The path: {img_path}"
+                    )
+                    continue
+                img = img[:, :, ::-1]
+                cnt += 1
+
+                preds = self.predictor.predict(
+                    img)  # [dict1, dict2, ..., dictn]
+                if preds:
+                    if print_pred:
+                        logger.info(f"{preds}, filename: {img_path}")
+
+                yield preds
+        else:
+            err = "Please input legal image! The type of image supported by PaddleClas are: NumPy.ndarray and string of local path or Ineternet URL"
+            raise ImageTypeError(err)
+        return
+
+    def predict(self,
+                input_data: Union[str, np.array],
+                print_pred: bool=False,
+                predict_type="cls"):
+        if predict_type == "cls":
+            return self.predict_cls(input_data, print_pred)
+        elif predict_type == "shitu":
+            assert not isinstance(input_data, (
+                list, tuple
+            )), "PP-ShiTu predictor only support single image as input now."
+            return self.predict_shitu(input_data, print_pred)
+        else:
+            raise ModuleNotFoundError
 
 
 # for CLI
 def main():
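
Note (usage sketch, not part of the diff): an end-to-end sketch of the new PP-ShiTu path through the Python API; the index directory and query image are placeholders:

    # Sketch only: using the predict_type="shitu" dispatch added above.
    from paddleclas import PaddleClas

    engine = PaddleClas(model_name="PP-ShiTuV2", index_dir="./index")  # index_dir is consumed by init_config
    for result in engine.predict("./query.jpg", print_pred=True, predict_type="shitu"):
        print(result)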
@@ -627,7 +755,10 @@ def main():
     print_info()
     cfg = args_cfg()
     clas_engine = PaddleClas(**cfg)
-    res = clas_engine.predict(cfg["infer_imgs"], print_pred=True)
+    res = clas_engine.predict(
+        cfg["infer_imgs"],
+        print_pred=True,
+        predict_type="cls" if "PP-ShiTu" not in cfg["model_name"] else "shitu")
     for _ in res:
         pass
     logger.info("Predict complete!")