2021-06-22 11:12:26 +08:00
|
|
|
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
2021-01-07 17:36:29 +08:00
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
import sys
|
|
|
|
|
# Directory that contains this file.
__dir__ = os.path.dirname(__file__)
# Joining with "" yields this directory with a trailing path separator;
# appending it makes modules next to this file importable.
sys.path.append(os.path.join(__dir__, ""))
# Also put the "deploy" sub-directory on the import path.
sys.path.append(os.path.join(__dir__, "deploy"))
|
|
|
|
|
|
2021-03-26 18:52:50 +08:00
|
|
|
|
import argparse
|
|
|
|
|
import shutil
|
2021-05-17 14:55:38 +08:00
|
|
|
|
import textwrap
|
2021-06-22 11:12:26 +08:00
|
|
|
|
import tarfile
|
|
|
|
|
import requests
|
|
|
|
|
import warnings
|
|
|
|
|
from functools import partial
|
2021-05-17 14:55:38 +08:00
|
|
|
|
from difflib import SequenceMatcher
|
2021-01-07 17:36:29 +08:00
|
|
|
|
|
|
|
|
|
import cv2
|
|
|
|
|
import numpy as np
|
|
|
|
|
from tqdm import tqdm
|
2021-06-22 11:12:26 +08:00
|
|
|
|
from prettytable import PrettyTable
|
|
|
|
|
|
|
|
|
|
from deploy.python.predict_cls import ClsPredictor
|
|
|
|
|
from deploy.utils.get_image_list import get_image_list
|
|
|
|
|
from deploy.utils import config
|
|
|
|
|
|
|
|
|
|
from ppcls.arch.backbone import *
|
|
|
|
|
|
|
|
|
|
# Names exported by `from <module> import *`.
__all__ = ["PaddleClas"]

# Per-user directory (below the home directory) used by this module.
BASE_DIR = os.path.expanduser("~/.paddleclas/")
# check_model_file() stores downloaded inference models below this directory.
BASE_INFERENCE_MODEL_DIR = os.path.join(BASE_DIR, "inference_model")
# PaddleClas.predict() saves images fetched from a URL below this directory.
BASE_IMAGES_DIR = os.path.join(BASE_DIR, "images")
# URL template of a model archive; "{}" is filled with the model name.
BASE_DOWNLOAD_URL = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/{}_infer.tar"
|
|
|
|
|
# Supported model names, grouped by series. The insertion order is the order
# in which print_info() lists the series and get_model_names() flattens them.
MODEL_SERIES = {
    "AlexNet": ["AlexNet"],
    "DarkNet": ["DarkNet53"],
    "DeiT": [
        "DeiT_base_distilled_patch16_224",
        "DeiT_base_distilled_patch16_384",
        "DeiT_base_patch16_224",
        "DeiT_base_patch16_384",
        "DeiT_small_distilled_patch16_224",
        "DeiT_small_patch16_224",
        "DeiT_tiny_distilled_patch16_224",
        "DeiT_tiny_patch16_224",
    ],
    "DenseNet": [
        "DenseNet121",
        "DenseNet161",
        "DenseNet169",
        "DenseNet201",
        "DenseNet264",
    ],
    "DPN": [
        "DPN68",
        "DPN92",
        "DPN98",
        "DPN107",
        "DPN131",
    ],
    "EfficientNet": [
        "EfficientNetB0",
        "EfficientNetB0_small",
        "EfficientNetB1",
        "EfficientNetB2",
        "EfficientNetB3",
        "EfficientNetB4",
        "EfficientNetB5",
        "EfficientNetB6",
        "EfficientNetB7",
    ],
    "GhostNet": [
        "GhostNet_x0_5",
        "GhostNet_x1_0",
        "GhostNet_x1_3",
        "GhostNet_x1_3_ssld",
    ],
    "HRNet": [
        "HRNet_W18_C",
        "HRNet_W30_C",
        "HRNet_W32_C",
        "HRNet_W40_C",
        "HRNet_W44_C",
        "HRNet_W48_C",
        "HRNet_W64_C",
        "HRNet_W18_C_ssld",
        "HRNet_W48_C_ssld",
    ],
    "Inception": [
        "GoogLeNet",
        "InceptionV3",
        "InceptionV4",
    ],
    "MobileNetV1": [
        "MobileNetV1_x0_25",
        "MobileNetV1_x0_5",
        "MobileNetV1_x0_75",
        "MobileNetV1",
        "MobileNetV1_ssld",
    ],
    "MobileNetV2": [
        "MobileNetV2_x0_25",
        "MobileNetV2_x0_5",
        "MobileNetV2_x0_75",
        "MobileNetV2",
        "MobileNetV2_x1_5",
        "MobileNetV2_x2_0",
        "MobileNetV2_ssld",
    ],
    "MobileNetV3": [
        "MobileNetV3_small_x0_35",
        "MobileNetV3_small_x0_5",
        "MobileNetV3_small_x0_75",
        "MobileNetV3_small_x1_0",
        "MobileNetV3_small_x1_25",
        "MobileNetV3_large_x0_35",
        "MobileNetV3_large_x0_5",
        "MobileNetV3_large_x0_75",
        "MobileNetV3_large_x1_0",
        "MobileNetV3_large_x1_25",
        "MobileNetV3_small_x1_0_ssld",
        "MobileNetV3_large_x1_0_ssld",
    ],
    "RegNet": ["RegNetX_4GF"],
    "Res2Net": [
        "Res2Net50_14w_8s",
        "Res2Net50_26w_4s",
        "Res2Net50_vd_26w_4s",
        "Res2Net200_vd_26w_4s",
        "Res2Net101_vd_26w_4s",
        "Res2Net50_vd_26w_4s_ssld",
        "Res2Net101_vd_26w_4s_ssld",
        "Res2Net200_vd_26w_4s_ssld",
    ],
    "ResNeSt": [
        "ResNeSt50",
        "ResNeSt50_fast_1s1x64d",
    ],
    "ResNet": [
        "ResNet18",
        "ResNet18_vd",
        "ResNet34",
        "ResNet34_vd",
        "ResNet50",
        "ResNet50_vc",
        "ResNet50_vd",
        "ResNet50_vd_v2",
        "ResNet101",
        "ResNet101_vd",
        "ResNet152",
        "ResNet152_vd",
        "ResNet200_vd",
        "ResNet34_vd_ssld",
        "ResNet50_vd_ssld",
        "ResNet50_vd_ssld_v2",
        "ResNet101_vd_ssld",
        "Fix_ResNet50_vd_ssld_v2",
        "ResNet50_ACNet_deploy",
    ],
    "ResNeXt": [
        "ResNeXt50_32x4d",
        "ResNeXt50_vd_32x4d",
        "ResNeXt50_64x4d",
        "ResNeXt50_vd_64x4d",
        "ResNeXt101_32x4d",
        "ResNeXt101_vd_32x4d",
        "ResNeXt101_32x8d_wsl",
        "ResNeXt101_32x16d_wsl",
        "ResNeXt101_32x32d_wsl",
        "ResNeXt101_32x48d_wsl",
        "Fix_ResNeXt101_32x48d_wsl",
        "ResNeXt101_64x4d",
        "ResNeXt101_vd_64x4d",
        "ResNeXt152_32x4d",
        "ResNeXt152_vd_32x4d",
        "ResNeXt152_64x4d",
        "ResNeXt152_vd_64x4d",
    ],
    "SENet": [
        "SENet154_vd",
        "SE_HRNet_W64_C_ssld",
        "SE_ResNet18_vd",
        "SE_ResNet34_vd",
        "SE_ResNet50_vd",
        "SE_ResNeXt50_32x4d",
        "SE_ResNeXt50_vd_32x4d",
        "SE_ResNeXt101_32x4d",
    ],
    "ShuffleNetV2": [
        "ShuffleNetV2_swish",
        "ShuffleNetV2_x0_25",
        "ShuffleNetV2_x0_33",
        "ShuffleNetV2_x0_5",
        "ShuffleNetV2_x1_0",
        "ShuffleNetV2_x1_5",
        "ShuffleNetV2_x2_0",
    ],
    "SqueezeNet": [
        "SqueezeNet1_0",
        "SqueezeNet1_1",
    ],
    "SwinTransformer": [
        "SwinTransformer_large_patch4_window7_224_22kto1k",
        "SwinTransformer_large_patch4_window12_384_22kto1k",
        "SwinTransformer_base_patch4_window7_224_22kto1k",
        "SwinTransformer_base_patch4_window12_384_22kto1k",
        "SwinTransformer_base_patch4_window12_384",
        "SwinTransformer_base_patch4_window7_224",
        "SwinTransformer_small_patch4_window7_224",
        "SwinTransformer_tiny_patch4_window7_224",
    ],
    "VGG": [
        "VGG11",
        "VGG13",
        "VGG16",
        "VGG19",
    ],
    "VisionTransformer": [
        "ViT_base_patch16_224",
        "ViT_base_patch16_384",
        "ViT_base_patch32_384",
        "ViT_large_patch16_224",
        "ViT_large_patch16_384",
        "ViT_large_patch32_384",
        "ViT_small_patch16_224",
    ],
    "Xception": [
        "Xception41",
        "Xception41_deeplab",
        "Xception65",
        "Xception65_deeplab",
        "Xception71",
    ],
}
|
|
|
|
|
|
|
|
|
|
|
2021-06-22 11:12:26 +08:00
|
|
|
|
class ImageTypeError(Exception):
    """Error raised when the data handed over for prediction is of a type
    that is not supported.
    """

    def __init__(self, message=""):
        # Forward the (possibly empty) message to the base Exception.
        Exception.__init__(self, message)
|
|
|
|
|
|
|
|
|
|
|
2021-06-22 11:12:26 +08:00
|
|
|
|
class InputModelError(Exception):
    """Error raised when the requested model name or model directory cannot
    be used.
    """

    def __init__(self, message=""):
        # Forward the (possibly empty) message to the base Exception.
        Exception.__init__(self, message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def args_cfg():
    """Build the predictor config from the command line.

    Extends the parser returned by ``config.parser()`` with the options
    listed in ``other_options``, parses ``sys.argv`` and turns every option
    whose value differs from its default into a ``Global.<name>=<value>``
    override that is passed to ``config.get_config``.

    Returns:
        The config object returned by ``config.get_config``.
    """

    def str2bool(value):
        """Parse a boolean option value.

        ``type=bool`` must not be used with argparse: ``bool("False")`` is
        True, so every non-empty string would enable the option.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(
            f"Boolean value expected, got {value!r}.")

    parser = config.parser()
    # (option name, type, default, help text)
    other_options = [
        ("infer_imgs", str, None, "The image(s) to be predicted."),
        ("model_name", str, None, "The model name to be used."),
        ("inference_model_dir", str, None, "The directory of model files."),
        ("use_gpu", bool, True, "Whether use GPU. Default by True."),
        ("enable_mkldnn", bool, False,
         "Whether use MKLDNN. Default by False."),
        ("batch_size", int, 1, "Batch size. Default by 1."),
    ]
    for name, opt_type, default, description in other_options:
        parser.add_argument(
            "--" + name,
            type=str2bool if opt_type is bool else opt_type,
            default=default,
            help=description)

    args = parser.parse_args()

    for name, _, default, _ in other_options:
        val = getattr(args, name)
        # Only values that were changed on the command line override the
        # config file; compare by value, not by object identity.
        if val != default:
            args.override.append(f"Global.{name}={val}")

    cfg = config.get_config(
        args.config, overrides=args.override, show=args.verbose)

    return cfg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_default_confg():
    """Return a newly built dict holding the default predictor config.

    The dict has the three sections "Global", "PreProcess" and
    "PostProcess"; every call creates fresh objects.
    """
    global_section = {
        "model_name": "MobileNetV3_small_x0_35",
        "use_gpu": False,
        "use_fp16": False,
        "enable_mkldnn": False,
        "cpu_num_threads": 1,
        "use_tensorrt": False,
        "ir_optim": False,
        "enable_profile": False,
    }

    # Pre-processing operators, in the order they are listed in the config.
    resize_op = {"ResizeImage": {"resize_short": 256}}
    crop_op = {"CropImage": {"size": 224}}
    normalize_op = {
        "NormalizeImage": {
            "scale": 0.00392157,
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
            "order": "",
        }
    }
    to_chw_op = {"ToCHWImage": None}
    preprocess_section = {
        "transform_ops": [resize_op, crop_op, normalize_op, to_chw_op]
    }

    postprocess_section = {
        "name": "Topk",
        "topk": 5,
        "class_id_map_file": "./ppcls/utils/imagenet1k_label_list.txt",
    }

    default_config = {}
    default_config["Global"] = global_section
    default_config["PreProcess"] = preprocess_section
    default_config["PostProcess"] = postprocess_section
    return default_config
|
|
|
|
|
|
|
|
|
|
|
2021-05-17 14:55:38 +08:00
|
|
|
|
def print_info():
    """Print a table of every model series in MODEL_SERIES with its model
    names, framed by a title line and a footer line.
    """
    table = PrettyTable(["Series", "Name"])

    # Width used to wrap the model names; falls back to 100 when the
    # terminal size cannot be queried.
    try:
        columns = os.get_terminal_size().columns
    except OSError:
        wrap_width = 100
    else:
        if columns > 50:
            wrap_width = columns - 30
        else:
            wrap_width = 10

    for series, model_names in MODEL_SERIES.items():
        wrapped_names = textwrap.fill(" ".join(model_names), width=wrap_width)
        table.add_row([series, wrapped_names])

    # The first rendered line of the table gives the overall banner width.
    banner_width = len(str(table).split("\n")[0])
    rule = "-" * banner_width

    print(rule)
    print("Models supported by PaddleClas".center(banner_width))
    print(table)
    print("Powered by PaddlePaddle!".rjust(banner_width))
    print(rule)
|
2021-05-17 14:55:38 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_model_names():
    """Return one flat list with the model names of all series, in the
    order in which they appear in MODEL_SERIES.
    """
    return [
        model_name
        for series_models in MODEL_SERIES.values()
        for model_name in series_models
    ]
|
|
|
|
|
|
|
|
|
|
|
2021-06-22 11:12:26 +08:00
|
|
|
|
def similar_architectures(name="", names=None, thresh=0.1, topk=10):
    """Find the candidates that are most similar to ``name``.

    Similarity is ``SequenceMatcher.quick_ratio()`` of the lower-cased
    strings. Candidates starting with ``"__"`` are ignored.

    Args:
        name: The name to look up.
        names: Iterable of candidate names; ``None`` means no candidates.
        thresh: Only candidates scoring strictly above this are returned.
        topk: Maximum number of names to return.

    Returns:
        list: The matching candidates, best score first; candidates with
        equal scores keep their input order.
    """
    # Avoid a mutable default argument; None stands for "no candidates".
    candidates = [] if names is None else names
    target = name.lower()

    scored = []
    for candidate in candidates:
        if candidate.startswith("__"):
            continue
        score = SequenceMatcher(None, candidate.lower(), target).quick_ratio()
        if score > thresh:
            scored.append((candidate, score))

    # list.sort is stable, also with reverse=True, so ties keep input order.
    scored.sort(key=lambda item: item[1], reverse=True)
    # Slicing already clamps to the list length.
    return [candidate for candidate, _ in scored[:topk]]
|
|
|
|
|
|
|
|
|
|
|
2021-01-07 17:36:29 +08:00
|
|
|
|
def download_with_progressbar(url, save_path):
    """Download ``url`` to ``save_path`` while showing a progress bar.

    An existing file at ``save_path`` is removed first.

    Args:
        url: The address to download from.
        save_path: The path of the file to write.

    Raises:
        Exception: If the server answers with an error status, reports no
            "content-length", the number of received bytes differs from
            the reported length, or the file does not exist afterwards.
    """
    if os.path.isfile(save_path):
        os.remove(save_path)

    error_message = f"Something went wrong while downloading model/image from {url}"
    block_size = 1024  # 1 Kibibyte

    # The context managers release the connection, the progress bar and the
    # file even when the transfer fails half-way.
    with requests.get(url, stream=True) as response:
        # Without this check an HTTP error page would be saved as if it
        # were the requested file.
        if not response.ok:
            raise Exception(error_message)
        total_size_in_bytes = int(response.headers.get("content-length", 0))
        received_bytes = 0
        with tqdm(total=total_size_in_bytes, unit="iB",
                  unit_scale=True) as progress_bar, open(save_path,
                                                         "wb") as file:
            for data in response.iter_content(block_size):
                received_bytes += len(data)
                progress_bar.update(len(data))
                file.write(data)

    if (total_size_in_bytes == 0 or received_bytes != total_size_in_bytes or
            not os.path.isfile(save_path)):
        raise Exception(error_message)
|
2021-01-07 17:36:29 +08:00
|
|
|
|
|
|
|
|
|
|
2021-06-22 11:12:26 +08:00
|
|
|
|
def check_model_file(model_name):
    """Make sure the inference files of ``model_name`` exist locally.

    When "inference.pdmodel" or "inference.pdiparams" is missing from the
    model's directory below BASE_INFERENCE_MODEL_DIR, the model archive is
    downloaded from BASE_DOWNLOAD_URL and the inference files are extracted
    from it into that directory.

    Args:
        model_name: The name of the model, used for both the download URL
            and the local directory name.

    Returns:
        str: The directory that contains the model files.

    Raises:
        Exception: If the model files are still missing afterwards.
    """
    storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
                                model_name)
    url = BASE_DOWNLOAD_URL.format(model_name)

    wanted_file_names = {
        "inference.pdiparams", "inference.pdiparams.info", "inference.pdmodel"
    }
    model_file_path = storage_directory("inference.pdmodel")
    params_file_path = storage_directory("inference.pdiparams")

    def model_files_exist():
        return os.path.exists(model_file_path) and os.path.exists(
            params_file_path)

    if not model_files_exist():
        tmp_path = storage_directory(url.split("/")[-1])
        print(f"download {url} to {tmp_path}")
        os.makedirs(storage_directory(), exist_ok=True)
        try:
            download_with_progressbar(url, tmp_path)
            with tarfile.open(tmp_path, "r") as tar_obj:
                for member in tar_obj.getmembers():
                    # Compare the exact base name: a substring test would
                    # also accept e.g. "inference.pdmodel.bak" and cannot
                    # tell "inference.pdiparams" from
                    # "inference.pdiparams.info".
                    filename = os.path.basename(member.name)
                    if not member.isfile() or filename not in wanted_file_names:
                        continue
                    # Stream the member instead of loading it into memory.
                    with tar_obj.extractfile(member) as src, open(
                            storage_directory(filename), "wb") as dst:
                        shutil.copyfileobj(src, dst)
        finally:
            # Never leave the (possibly partial) archive behind.
            if os.path.isfile(tmp_path):
                os.remove(tmp_path)

    if not model_files_exist():
        raise Exception(
            f"Something went wrong while praparing the model[{model_name}] files!"
        )

    return storage_directory()
|
2021-01-07 17:36:29 +08:00
|
|
|
|
|
2021-06-22 11:12:26 +08:00
|
|
|
|
|
|
|
|
|
def save_prelabel_results(class_id, input_file_path, output_dir):
    """Copy an image into the sub-directory named after its predicted class.

    Args:
        class_id: The predicted class id; ``str(class_id)`` is the name of
            the sub-directory.
        input_file_path: The path of the file to copy.
        output_dir: The directory below which the class sub-directories
            are created.
    """
    class_dir = os.path.join(output_dir, str(class_id))
    # exist_ok avoids the check-then-create race and matches how
    # directories are created elsewhere in this file.
    os.makedirs(class_dir, exist_ok=True)
    shutil.copy(input_file_path, class_dir)
|
2021-01-07 17:36:29 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PaddleClas(object):
    """Image classifier that wraps a ClsPredictor.

    The model is selected either by a model name listed in MODEL_SERIES
    (its files are downloaded when missing) or by a directory containing
    "inference.pdmodel" and "inference.pdiparams".
    """

    # Executed once, when the class body is evaluated: prints the table of
    # supported models.
    print_info()

    def __init__(self,
                 config: dict=None,
                 model_name: str=None,
                 inference_model_dir: str=None,
                 use_gpu: bool=None,
                 batch_size: int=None):
        """Init PaddleClas with config.

        Args:
            config: The config of PaddleClas's predictor, default by None. If default, the default configuration is used. Please refer doc for more information.
            model_name: The model name supported by PaddleClas, default by None. If specified, override config.
            inference_model_dir: The directory that contained model file and params file to be used, default by None. If specified, override config.
            use_gpu: Whether use GPU, default by None. If specified, override config.
            batch_size: The batch size to predict, default by None. If specified, override config.

        Raises:
            InputModelError: If neither a supported model name nor a usable
                model directory is configured.
        """
        super().__init__()
        self._config = config
        self._check_config(model_name, inference_model_dir, use_gpu,
                           batch_size)
        self._check_input_model()
        self.cls_predictor = ClsPredictor(self._config)

    def get_config(self):
        """Get the config.
        """
        return self._config

    def _check_config(self,
                      model_name=None,
                      inference_model_dir=None,
                      use_gpu=None,
                      batch_size=None):
        """Fill in the default config when none was given and apply the
        explicitly passed arguments on top of the "Global" section.
        """
        if self._config is None:
            self._config = get_default_confg()
            warnings.warn("config is not provided, use default!")
        self._config = config.AttrDict(self._config)
        config.create_attr_dict(self._config)

        if model_name is not None:
            self._config.Global["model_name"] = model_name
        if inference_model_dir is not None:
            self._config.Global["inference_model_dir"] = inference_model_dir
            if model_name is None:
                # _check_input_model() prefers the model name, and the
                # default config always carries one. Drop it so that an
                # explicitly passed directory really overrides the config.
                self._config.Global["model_name"] = None
        if use_gpu is not None:
            self._config.Global["use_gpu"] = use_gpu
        if batch_size is not None:
            self._config.Global["batch_size"] = batch_size

    def _check_input_model(self):
        """Check input model name or model files.

        A configured model name takes precedence over a model directory;
        for a valid name the model files are fetched via check_model_file()
        and their directory is stored as ``Global.inference_model_dir``.

        Raises:
            InputModelError: If the model name is unknown, the directory
                lacks the model files, or neither is configured.
        """
        input_model_name = self._config.Global.get("model_name", None)
        inference_model_dir = self._config.Global.get("inference_model_dir",
                                                      None)

        if input_model_name is not None:
            candidate_model_names = get_model_names()
            # Exact membership test; a substring test on the joined
            # suggestions would treat e.g. "ResNet" as a known name.
            if input_model_name not in candidate_model_names:
                similar_names = similar_architectures(input_model_name,
                                                      candidate_model_names)
                if similar_names:
                    similar_names_str = ", ".join(similar_names)
                    err = f"{input_model_name} is not exist! Maybe you want: [{similar_names_str}]"
                else:
                    err = f"{input_model_name} is not provided by PaddleClas. If you want to use your own model, please input model_file as model path!"
                raise InputModelError(err)
            self._config.Global.inference_model_dir = check_model_file(
                input_model_name)
            return

        if inference_model_dir is not None:
            model_file_path = os.path.join(inference_model_dir,
                                           "inference.pdmodel")
            params_file_path = os.path.join(inference_model_dir,
                                            "inference.pdiparams")
            if not os.path.isfile(model_file_path) or not os.path.isfile(
                    params_file_path):
                err = f"There is no model file or params file in this directory: {inference_model_dir}"
                raise InputModelError(err)
            return

        err = "Please specify the model name supported by PaddleClas or directory contained model file and params file."
        raise InputModelError(err)

    def predict(self, input_data, print_pred=True):
        """Predict label of img with paddleclas.

        Args:
            input_data(str, NumPy.ndarray):
                image to be classified, support: str(local path of image file, internet URL, directory containing series of images) and NumPy.ndarray(preprocessed image data that has 3 channels and accords with [C, H, W], or raw image data that has 3 channels and accords with [H, W, C]).
            print_pred: Whether to print the post-processed prediction of
                every image (only used for str input).

        Returns:
            For NumPy.ndarray input, the result of the predictor's
            ``predict``. For str input, a list holding the first element
            of the predictor output of every batch.

        Raises:
            ImageTypeError: If input_data is neither a NumPy.ndarray nor
                a str.
        """
        if isinstance(input_data, np.ndarray):
            return self.cls_predictor.predict(input_data)

        if not isinstance(input_data, str):
            err = "Please input legal image! The type of image supported by PaddleClas are: NumPy.ndarray and string of local path or Ineternet URL"
            raise ImageTypeError(err)

        if input_data.startswith("http"):
            os.makedirs(BASE_IMAGES_DIR, exist_ok=True)
            image_save_path = os.path.join(BASE_IMAGES_DIR, "tmp.jpg")
            download_with_progressbar(input_data, image_save_path)
            # Warn before input_data is replaced, so that the message shows
            # the URL and not the local path twice.
            warnings.warn(
                f"Image to be predicted from Internet: {input_data}, has been saved to: {image_save_path}"
            )
            input_data = image_save_path
        image_list = get_image_list(input_data)

        batch_size = self._config.Global.get("batch_size", 1)
        # NOTE(review): the key is spelled "pre_label_out_idr" (not "..._dir");
        # kept as is because existing configs may rely on it.
        pre_label_out_idr = self._config.Global.get("pre_label_out_idr",
                                                    False)

        img_list = []
        img_path_list = []
        output_list = []
        for idx, img_path in enumerate(image_list):
            img = cv2.imread(img_path)
            if img is None:
                warnings.warn(
                    f"Image file failed to read and has been skipped. The path: {img_path}"
                )
            else:
                img_list.append(img)
                img_path_list.append(img_path)

            # The flush check must also run for unreadable images, otherwise
            # a pending partial batch is lost when the last image fails.
            is_last_image = (idx + 1) == len(image_list)
            batch_is_full = len(img_list) == batch_size
            if not img_list or not (batch_is_full or is_last_image):
                continue

            outputs = self.cls_predictor.predict(img_list)
            # NOTE(review): only the first element of each batch output is
            # kept; for batch_size > 1 this presumably drops the results of
            # the other images - confirm the intended return value.
            output_list.append(outputs[0])
            preds = self.cls_predictor.postprocess(outputs)
            for nu, pred in enumerate(preds):
                if pre_label_out_idr:
                    save_prelabel_results(pred["class_ids"][0],
                                          img_path_list[nu],
                                          pre_label_out_idr)
                if print_pred:
                    pred_str_list = [
                        f"filename: {img_path_list[nu]}",
                        f"top-{self._config.PostProcess.get('topk', 1)}"
                    ]
                    for k in pred:
                        pred_str_list.append(f"{k}: {pred[k]}")
                    print(", ".join(pred_str_list))
            img_list = []
            img_path_list = []
        return output_list
|
2021-01-07 17:36:29 +08:00
|
|
|
|
|
|
|
|
|
|
2021-06-22 11:12:26 +08:00
|
|
|
|
# for CLI
|
2021-01-07 17:36:29 +08:00
|
|
|
|
def main():
    """Command-line entry point: build the config from the arguments,
    create the classifier and predict the images named by
    ``Global.infer_imgs``, printing the predictions.
    """
    cli_config = args_cfg()
    images_to_predict = cli_config["Global"]["infer_imgs"]
    PaddleClas(cli_config).predict(images_to_predict, print_pred=True)
|
2021-01-07 17:36:29 +08:00
|
|
|
|
|
|
|
|
|
|
2021-06-22 11:12:26 +08:00
|
|
|
|
# Run the command-line interface only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
|