# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
2021-07-23 14:27:09 +08:00
from typing import Union , Generator
2021-03-26 18:52:50 +08:00
import argparse
import shutil
2021-05-17 14:55:38 +08:00
import textwrap
2021-06-22 11:12:26 +08:00
import tarfile
import requests
from functools import partial
2021-05-17 14:55:38 +08:00
from difflib import SequenceMatcher
2021-01-07 17:36:29 +08:00
import cv2
import numpy as np
from tqdm import tqdm
2021-06-22 11:12:26 +08:00
from prettytable import PrettyTable
2022-06-08 22:01:51 +08:00
import paddle
2021-06-22 11:12:26 +08:00
2022-08-17 16:47:55 +08:00
from . ppcls . arch import backbone
from . ppcls . utils import logger
2022-08-16 21:24:57 +08:00
2022-08-17 16:47:55 +08:00
from . deploy . python . predict_cls import ClsPredictor
from . deploy . utils . get_image_list import get_image_list
from . deploy . utils import config
2021-06-22 11:12:26 +08:00
2022-08-17 16:47:55 +08:00
# for the PaddleClas Project
from . import deploy
from . import ppcls
2021-06-22 11:12:26 +08:00
2021-11-15 14:52:38 +08:00
# for building model with loading pretrained weights from backbone
# Initialize the package-level logger once at import time so that all later
# download/predict messages are formatted consistently.
logger.init_logger()

# Public API of this module: only the PaddleClas class is exported.
__all__ = ["PaddleClas"]
# Per-user cache root for everything PaddleClas downloads.
BASE_DIR = os.path.expanduser("~/.paddleclas/")

# Downloaded inference models are unpacked under this directory.
BASE_INFERENCE_MODEL_DIR = os.path.join(BASE_DIR, "inference_model")
# Images fetched from URLs are saved here before prediction.
BASE_IMAGES_DIR = os.path.join(BASE_DIR, "images")
2022-06-06 16:51:56 +08:00
# URL template for ImageNet (IMN) inference model tarballs; formatted with the
# model name by check_model_file().
IMN_MODEL_BASE_DOWNLOAD_URL = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/{}_infer.tar"
# Registry of released ImageNet models, grouped by architecture series.
# Used for CLI help tables, model-name validation and typo suggestions.
IMN_MODEL_SERIES = {
    "AlexNet": ["AlexNet"],
    "CSWinTransformer": [
        "CSWinTransformer_tiny_224", "CSWinTransformer_small_224",
        "CSWinTransformer_base_224", "CSWinTransformer_base_384",
        "CSWinTransformer_large_224", "CSWinTransformer_large_384"
    ],
    "DarkNet": ["DarkNet53"],
    "DeiT": [
        "DeiT_base_distilled_patch16_224", "DeiT_base_distilled_patch16_384",
        "DeiT_base_patch16_224", "DeiT_base_patch16_384",
        "DeiT_small_distilled_patch16_224", "DeiT_small_patch16_224",
        "DeiT_tiny_distilled_patch16_224", "DeiT_tiny_patch16_224"
    ],
    "DenseNet": [
        "DenseNet121", "DenseNet161", "DenseNet169", "DenseNet201",
        "DenseNet264"
    ],
    "DLA": [
        "DLA46_c", "DLA60x_c", "DLA34", "DLA60", "DLA60x", "DLA102", "DLA102x",
        "DLA102x2", "DLA169"
    ],
    "DPN": ["DPN68", "DPN92", "DPN98", "DPN107", "DPN131"],
    "EfficientNet": [
        "EfficientNetB0", "EfficientNetB0_small", "EfficientNetB1",
        "EfficientNetB2", "EfficientNetB3", "EfficientNetB4", "EfficientNetB5",
        "EfficientNetB6", "EfficientNetB7"
    ],
    "ESNet": ["ESNet_x0_25", "ESNet_x0_5", "ESNet_x0_75", "ESNet_x1_0"],
    "GhostNet":
    ["GhostNet_x0_5", "GhostNet_x1_0", "GhostNet_x1_3", "GhostNet_x1_3_ssld"],
    "HarDNet": ["HarDNet39_ds", "HarDNet68_ds", "HarDNet68", "HarDNet85"],
    "HRNet": [
        "HRNet_W18_C", "HRNet_W30_C", "HRNet_W32_C", "HRNet_W40_C",
        "HRNet_W44_C", "HRNet_W48_C", "HRNet_W64_C", "HRNet_W18_C_ssld",
        "HRNet_W48_C_ssld"
    ],
    "Inception": ["GoogLeNet", "InceptionV3", "InceptionV4"],
    "LeViT":
    ["LeViT_128S", "LeViT_128", "LeViT_192", "LeViT_256", "LeViT_384"],
    "MixNet": ["MixNet_S", "MixNet_M", "MixNet_L"],
    "MobileNetV1": [
        "MobileNetV1_x0_25", "MobileNetV1_x0_5", "MobileNetV1_x0_75",
        "MobileNetV1", "MobileNetV1_ssld"
    ],
    "MobileNetV2": [
        "MobileNetV2_x0_25", "MobileNetV2_x0_5", "MobileNetV2_x0_75",
        "MobileNetV2", "MobileNetV2_x1_5", "MobileNetV2_x2_0",
        "MobileNetV2_ssld"
    ],
    "MobileNetV3": [
        "MobileNetV3_small_x0_35", "MobileNetV3_small_x0_5",
        "MobileNetV3_small_x0_75", "MobileNetV3_small_x1_0",
        "MobileNetV3_small_x1_25", "MobileNetV3_large_x0_35",
        "MobileNetV3_large_x0_5", "MobileNetV3_large_x0_75",
        "MobileNetV3_large_x1_0", "MobileNetV3_large_x1_25",
        "MobileNetV3_small_x1_0_ssld", "MobileNetV3_large_x1_0_ssld"
    ],
    "MobileViT": ["MobileViT_XXS", "MobileViT_XS", "MobileViT_S"],
    "PPHGNet": [
        "PPHGNet_tiny",
        "PPHGNet_small",
        "PPHGNet_tiny_ssld",
        "PPHGNet_small_ssld",
    ],
    "PPLCNet": [
        "PPLCNet_x0_25", "PPLCNet_x0_35", "PPLCNet_x0_5", "PPLCNet_x0_75",
        "PPLCNet_x1_0", "PPLCNet_x1_5", "PPLCNet_x2_0", "PPLCNet_x2_5"
    ],
    "PPLCNetV2": ["PPLCNetV2_base"],
    "PVTV2": [
        "PVT_V2_B0", "PVT_V2_B1", "PVT_V2_B2", "PVT_V2_B2_Linear", "PVT_V2_B3",
        "PVT_V2_B4", "PVT_V2_B5"
    ],
    "RedNet": ["RedNet26", "RedNet38", "RedNet50", "RedNet101", "RedNet152"],
    "RegNet": ["RegNetX_4GF"],
    "Res2Net": [
        "Res2Net50_14w_8s", "Res2Net50_26w_4s", "Res2Net50_vd_26w_4s",
        "Res2Net200_vd_26w_4s", "Res2Net101_vd_26w_4s",
        "Res2Net50_vd_26w_4s_ssld", "Res2Net101_vd_26w_4s_ssld",
        "Res2Net200_vd_26w_4s_ssld"
    ],
    "ResNeSt": ["ResNeSt50", "ResNeSt50_fast_1s1x64d"],
    "ResNet": [
        "ResNet18", "ResNet18_vd", "ResNet34", "ResNet34_vd", "ResNet50",
        "ResNet50_vc", "ResNet50_vd", "ResNet50_vd_v2", "ResNet101",
        "ResNet101_vd", "ResNet152", "ResNet152_vd", "ResNet200_vd",
        "ResNet34_vd_ssld", "ResNet50_vd_ssld", "ResNet50_vd_ssld_v2",
        "ResNet101_vd_ssld", "Fix_ResNet50_vd_ssld_v2", "ResNet50_ACNet_deploy"
    ],
    "ResNeXt": [
        "ResNeXt50_32x4d", "ResNeXt50_vd_32x4d", "ResNeXt50_64x4d",
        "ResNeXt50_vd_64x4d", "ResNeXt101_32x4d", "ResNeXt101_vd_32x4d",
        "ResNeXt101_32x8d_wsl", "ResNeXt101_32x16d_wsl",
        "ResNeXt101_32x32d_wsl", "ResNeXt101_32x48d_wsl",
        "Fix_ResNeXt101_32x48d_wsl", "ResNeXt101_64x4d", "ResNeXt101_vd_64x4d",
        "ResNeXt152_32x4d", "ResNeXt152_vd_32x4d", "ResNeXt152_64x4d",
        "ResNeXt152_vd_64x4d"
    ],
    "ReXNet":
    ["ReXNet_1_0", "ReXNet_1_3", "ReXNet_1_5", "ReXNet_2_0", "ReXNet_3_0"],
    "SENet": [
        "SENet154_vd", "SE_HRNet_W64_C_ssld", "SE_ResNet18_vd",
        "SE_ResNet34_vd", "SE_ResNet50_vd", "SE_ResNeXt50_32x4d",
        "SE_ResNeXt50_vd_32x4d", "SE_ResNeXt101_32x4d"
    ],
    "ShuffleNetV2": [
        "ShuffleNetV2_swish", "ShuffleNetV2_x0_25", "ShuffleNetV2_x0_33",
        "ShuffleNetV2_x0_5", "ShuffleNetV2_x1_0", "ShuffleNetV2_x1_5",
        "ShuffleNetV2_x2_0"
    ],
    "SqueezeNet": ["SqueezeNet1_0", "SqueezeNet1_1"],
    "SwinTransformer": [
        "SwinTransformer_large_patch4_window7_224_22kto1k",
        "SwinTransformer_large_patch4_window12_384_22kto1k",
        "SwinTransformer_base_patch4_window7_224_22kto1k",
        "SwinTransformer_base_patch4_window12_384_22kto1k",
        "SwinTransformer_base_patch4_window12_384",
        "SwinTransformer_base_patch4_window7_224",
        "SwinTransformer_small_patch4_window7_224",
        "SwinTransformer_tiny_patch4_window7_224"
    ],
    "Twins": [
        "pcpvt_small", "pcpvt_base", "pcpvt_large", "alt_gvt_small",
        "alt_gvt_base", "alt_gvt_large"
    ],
    "TNT": ["TNT_small"],
    "VGG": ["VGG11", "VGG13", "VGG16", "VGG19"],
    "VisionTransformer": [
        "ViT_base_patch16_224", "ViT_base_patch16_384", "ViT_base_patch32_384",
        "ViT_large_patch16_224", "ViT_large_patch16_384",
        "ViT_large_patch32_384", "ViT_small_patch16_224"
    ],
    "Xception": [
        "Xception41", "Xception41_deeplab", "Xception65", "Xception65_deeplab",
        "Xception71"
    ]
}
2022-06-13 16:49:49 +08:00
# URL template for PULC inference model tarballs; formatted with the model name
# by check_model_file().
PULC_MODEL_BASE_DOWNLOAD_URL = "https://paddleclas.bj.bcebos.com/models/PULC/inference/{}_infer.tar"
# Names of the released PULC task-specific models.
PULC_MODELS = [
    "car_exists", "language_classification", "person_attribute",
    "person_exists", "safety_helmet", "text_image_orientation",
    "textline_orientation", "traffic_sign", "vehicle_attribute"
]
2021-01-07 17:36:29 +08:00
2021-06-22 11:12:26 +08:00
class ImageTypeError(Exception):
    """Raised when the input image is of an unsupported type.

    PaddleClas accepts a ``numpy.ndarray`` or a ``str`` (local path, directory
    or URL); any other input type triggers this exception.
    """

    def __init__(self, message=""):
        super().__init__(message)
class InputModelError(Exception):
    """Raised when neither a known model name nor a valid model directory is given."""

    def __init__(self, message=""):
        super().__init__(message)
def init_config(model_type, model_name, inference_model_dir, **kwargs):
    """Load the bundled inference YAML config and apply keyword overrides.

    Args:
        model_type (str): "pulc" selects the per-model PULC config template;
            anything else uses the generic classification config.
        model_name (str): Model name; used to locate the PULC config file.
        inference_model_dir (str): Directory containing the inference model.
        **kwargs: Optional overrides, e.g. use_gpu, batch_size, topk, thresh,
            resize_short, crop_size, save_dir, ...

    Returns:
        The parsed config object with all applicable overrides applied.
    """
    cfg_path = (f"deploy/configs/PULC/{model_name}/inference_{model_name}.yaml"
                if model_type == "pulc" else "deploy/configs/inference_cls.yaml")
    # Config templates ship inside this package; resolve relative to it.
    __dir__ = os.path.dirname(__file__)
    cfg_path = os.path.join(__dir__, cfg_path)
    cfg = config.get_config(cfg_path, show=False)

    cfg.Global.inference_model_dir = inference_model_dir

    if kwargs.get("batch_size"):
        cfg.Global.batch_size = kwargs["batch_size"]

    # Boolean flags are compared against None so that an explicit False from
    # the caller (e.g. --use_gpu false) is honored instead of silently ignored
    # (the old truthiness check dropped False values).
    if kwargs.get("use_gpu") is not None:
        cfg.Global.use_gpu = kwargs["use_gpu"]
    if cfg.Global.use_gpu and not paddle.device.is_compiled_with_cuda():
        msg = "The current running environment does not support the use of GPU. CPU has been used instead."
        logger.warning(msg)
        cfg.Global.use_gpu = False

    if kwargs.get("infer_imgs"):
        cfg.Global.infer_imgs = kwargs["infer_imgs"]
    if kwargs.get("enable_mkldnn") is not None:
        cfg.Global.enable_mkldnn = kwargs["enable_mkldnn"]
    if kwargs.get("cpu_num_threads"):
        cfg.Global.cpu_num_threads = kwargs["cpu_num_threads"]
    if kwargs.get("use_fp16") is not None:
        cfg.Global.use_fp16 = kwargs["use_fp16"]
    if kwargs.get("use_tensorrt") is not None:
        cfg.Global.use_tensorrt = kwargs["use_tensorrt"]
    if kwargs.get("gpu_mem"):
        cfg.Global.gpu_mem = kwargs["gpu_mem"]
    if kwargs.get("resize_short"):
        cfg.PreProcess.transform_ops[0]["ResizeImage"][
            "resize_short"] = kwargs["resize_short"]
    if kwargs.get("crop_size"):
        cfg.PreProcess.transform_ops[1]["CropImage"]["size"] = kwargs[
            "crop_size"]

    # TODO(gaotingquan): not robust
    # Accept both "thresh" and the CLI spelling "threshold": args_cfg exposes
    # --threshold, which previously never reached this option.
    thresh = kwargs.get("thresh") or kwargs.get("threshold")
    if thresh and "ThreshOutput" in cfg.PostProcess:
        cfg.PostProcess.ThreshOutput.thresh = thresh

    if "Topk" in cfg.PostProcess:
        if kwargs.get("topk"):
            cfg.PostProcess.Topk.topk = kwargs["topk"]
        if kwargs.get("class_id_map_file"):
            cfg.PostProcess.Topk.class_id_map_file = kwargs[
                "class_id_map_file"]
        else:
            # The bundled config stores the label-map path relative to the
            # repo root; rebase it onto this package's directory.
            class_id_map_file_path = os.path.relpath(
                cfg.PostProcess.Topk.class_id_map_file, "../")
            cfg.PostProcess.Topk.class_id_map_file = os.path.join(
                __dir__, class_id_map_file_path)
    if "VehicleAttribute" in cfg.PostProcess:
        if kwargs.get("color_threshold"):
            cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
                "color_threshold"]
        if kwargs.get("type_threshold"):
            cfg.PostProcess.VehicleAttribute.type_threshold = kwargs[
                "type_threshold"]

    if kwargs.get("save_dir"):
        cfg.PostProcess.SavePreLabel.save_dir = kwargs["save_dir"]

    return cfg
def args_cfg():
    """Parse command line arguments into a kwargs dict for PaddleClas.

    Returns:
        dict: Parsed options keyed by argument name. When --threshold is
        given, the value is mirrored under "thresh" as well, since the
        config initializer reads that key.
    """

    def str2bool(v):
        # argparse hands over the raw string; "true"/"t"/"1" (any case) → True.
        return v.lower() in ("true", "t", "1")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--infer_imgs",
        type=str,
        required=True,
        help="The image(s) to be predicted.")
    parser.add_argument(
        "--model_name", type=str, help="The model name to be used.")
    parser.add_argument(
        "--inference_model_dir",
        type=str,
        help="The directory of model files. Valid when model_name not specified."
    )
    parser.add_argument("--use_gpu", type=str2bool, help="Whether use GPU.")
    parser.add_argument(
        "--gpu_mem",
        type=int,
        help="The memory size of GPU allocated to predict.")
    parser.add_argument(
        "--enable_mkldnn",
        type=str2bool,
        help="Whether use MKLDNN. Valid when use_gpu is False")
    parser.add_argument(
        "--cpu_num_threads",
        type=int,
        help="The threads number when predicting on CPU.")
    parser.add_argument(
        "--use_tensorrt",
        type=str2bool,
        help="Whether use TensorRT to accelerate.")
    parser.add_argument(
        "--use_fp16", type=str2bool, help="Whether use FP16 to predict.")
    parser.add_argument("--batch_size", type=int, help="Batch size.")
    parser.add_argument(
        "--topk",
        type=int,
        help="Return topk score(s) and corresponding results when Topk postprocess is used."
    )
    parser.add_argument(
        "--class_id_map_file",
        type=str,
        help="The path of file that map class_id and label.")
    parser.add_argument(
        "--threshold",
        type=float,
        help="The threshold of ThreshOutput when postprocess is used.")
    parser.add_argument("--color_threshold", type=float, help="")
    parser.add_argument("--type_threshold", type=float, help="")
    parser.add_argument(
        "--save_dir",
        type=str,
        help="The directory to save prediction results as pre-label.")
    parser.add_argument(
        "--resize_short", type=int, help="Resize according to short size.")
    parser.add_argument("--crop_size", type=int, help="Center crop size.")

    args = vars(parser.parse_args())
    # init_config() reads this option under the key "thresh"; mirror the CLI
    # value so that --threshold actually takes effect.
    if args.get("threshold") is not None:
        args.setdefault("thresh", args["threshold"])
    return args
2021-06-22 11:12:26 +08:00
2021-05-17 14:55:38 +08:00
def print_info():
    """Print the supported model names as pretty tables on stdout."""
    imn_table = PrettyTable(["IMN Model Series", "Model Name"])
    pulc_table = PrettyTable(["PULC Models"])
    try:
        term = os.get_terminal_size()
        total_width = term.columns
        first_width = 30
        second_width = total_width - first_width if total_width > 50 else 10
    except OSError:
        # No terminal attached (e.g. output piped): fall back to a fixed width.
        total_width = 100
        second_width = 100
    for series_name, series_models in IMN_MODEL_SERIES.items():
        wrapped_names = textwrap.fill(" ".join(series_models), width=second_width)
        imn_table.add_row([series_name, wrapped_names])
    table_width = len(str(imn_table).split("\n")[0])
    pulc_table.add_row([
        textwrap.fill(
            " ".join(PULC_MODELS), width=total_width).center(table_width - 4)
    ])
    print("{}".format("-" * table_width))
    print("Models supported by PaddleClas".center(table_width))
    print(imn_table)
    print(pulc_table)
    print("Powered by PaddlePaddle!".rjust(table_width))
    print("{}".format("-" * table_width))
def get_imn_model_names():
    """Return a flat list of every supported IMN model name."""
    names = []
    for series_models in IMN_MODEL_SERIES.values():
        names.extend(series_models)
    return names
2022-06-06 16:51:56 +08:00
def similar_model_names(name="", names=None, thresh=0.1, topk=5):
    """Find up to ``topk`` candidates in ``names`` most similar to ``name``.

    Args:
        name (str): The (possibly misspelled) model name to match against.
        names (list[str], optional): Candidate model names. Defaults to None,
            treated as an empty list.
        thresh (float): Minimum SequenceMatcher quick_ratio for a candidate
            to be considered similar.
        topk (int): Maximum number of names to return.

    Returns:
        list[str]: Matching names, most similar first.
    """
    # Fix: the default used to be a mutable list ([]) shared across calls;
    # use None as the sentinel instead.
    if names is None:
        names = []
    scores = []
    for idx, candidate in enumerate(names):
        # Skip dunder/private entries that may leak in from module listings.
        if candidate.startswith("__"):
            continue
        score = SequenceMatcher(None, candidate.lower(),
                                name.lower()).quick_ratio()
        if score > thresh:
            scores.append((idx, score))
    scores.sort(key=lambda x: x[1], reverse=True)
    return [names[i] for i, _ in scores[:topk]]
2021-01-07 17:36:29 +08:00
def download_with_progressbar(url, save_path):
    """Download ``url`` to ``save_path``, showing a tqdm progress bar.

    Any pre-existing file at ``save_path`` is removed first.

    Raises:
        Exception: When the downloaded size does not match the server-reported
            Content-Length, or the file is missing afterwards.
    """
    if os.path.isfile(save_path):
        os.remove(save_path)
    response = requests.get(url, stream=True)
    total_size_in_bytes = int(response.headers.get("content-length", 0))
    block_size = 1024  # 1 Kibibyte
    progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
    with open(save_path, "wb") as file:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
            file.write(data)
    progress_bar.close()
    # A zero Content-Length or a size mismatch indicates a failed or partial
    # download. NOTE(review): a server that omits Content-Length always fails
    # this check even when the download succeeded — confirm intended.
    if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes or not os.path.isfile(
            save_path):
        raise Exception(
            f"Something went wrong while downloading file from {url}")
2021-01-07 17:36:29 +08:00
2022-06-06 16:51:56 +08:00
def check_model_file(model_type, model_name):
    """Check the model files exist and download and untar when no exist.

    Args:
        model_type (str): "pulc" selects the PULC cache directory and URL;
            any other value uses the ImageNet (IMN) ones.
        model_name (str): Name of the model to fetch.

    Returns:
        str: Directory containing inference.pdmodel / inference.pdiparams.

    Raises:
        Exception: If the model files are still missing after download.
    """
    if model_type == "pulc":
        storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
                                    "PULC", model_name)
        url = PULC_MODEL_BASE_DOWNLOAD_URL.format(model_name)
    else:
        storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR,
                                    "IMN", model_name)
        url = IMN_MODEL_BASE_DOWNLOAD_URL.format(model_name)

    # Only members whose names contain one of these filenames are extracted.
    tar_file_name_list = [
        "inference.pdiparams", "inference.pdiparams.info", "inference.pdmodel"
    ]
    model_file_path = storage_directory("inference.pdmodel")
    params_file_path = storage_directory("inference.pdiparams")
    if not os.path.exists(model_file_path) or not os.path.exists(
            params_file_path):
        tmp_path = storage_directory(url.split("/")[-1])
        logger.info(f"download {url} to {tmp_path}")
        os.makedirs(storage_directory(), exist_ok=True)
        download_with_progressbar(url, tmp_path)
        # Flatten matching members out of the archive into the cache dir.
        # NOTE(review): matching is by substring and the member is written to a
        # fixed filename, so nested paths inside the tar are deliberately
        # flattened — no untrusted paths are ever used for extraction.
        with tarfile.open(tmp_path, "r") as tarObj:
            for member in tarObj.getmembers():
                filename = None
                for tar_file_name in tar_file_name_list:
                    if tar_file_name in member.name:
                        filename = tar_file_name
                if filename is None:
                    continue
                file = tarObj.extractfile(member)
                with open(storage_directory(filename), "wb") as f:
                    f.write(file.read())
        os.remove(tmp_path)
    if not os.path.exists(model_file_path) or not os.path.exists(
            params_file_path):
        raise Exception(
            f"Something went wrong while praparing the model[{model_name}] files!"
        )

    return storage_directory()
2021-01-07 17:36:29 +08:00
2021-06-22 11:12:26 +08:00
2021-01-07 17:36:29 +08:00
class PaddleClas(object):
    """PaddleClas: a Python API for image classification inference."""

    def __init__(self,
                 model_name: str = None,
                 inference_model_dir: str = None,
                 **kwargs):
        """Init PaddleClas with config.

        Args:
            model_name (str, optional): The model name supported by PaddleClas. If specified, override config. Defaults to None.
            inference_model_dir (str, optional): The directory that contained model file and params file to be used. If specified, override config. Defaults to None.
            use_gpu (bool, optional): Whether use GPU. If specified, override config. Defaults to True.
            batch_size (int, optional): The batch size to predict. If specified, override config. Defaults to 1.
            topk (int, optional): Return the top k prediction results with the highest score. Defaults to 5.
        """
        super().__init__()
        self.model_type, inference_model_dir = self._check_input_model(
            model_name, inference_model_dir)
        self._config = init_config(self.model_type, model_name,
                                   inference_model_dir, **kwargs)
        self.cls_predictor = ClsPredictor(self._config)

    def get_config(self):
        """Get the config."""
        return self._config

    def _check_input_model(self, model_name, inference_model_dir):
        """Check input model name or model files.

        Returns:
            tuple[str, str]: (model_type, inference_model_dir) where
            model_type is "imn", "pulc" or "custom".

        Raises:
            InputModelError: Neither a known model name nor a directory with
                valid model files was given.
        """
        all_imn_model_names = get_imn_model_names()
        all_pulc_model_names = PULC_MODELS
        if model_name:
            if model_name in all_imn_model_names:
                inference_model_dir = check_model_file("imn", model_name)
                return "imn", inference_model_dir
            elif model_name in all_pulc_model_names:
                inference_model_dir = check_model_file("pulc", model_name)
                return "pulc", inference_model_dir
            else:
                # Unknown name: suggest close matches from both registries.
                similar_imn_names = similar_model_names(model_name,
                                                        all_imn_model_names)
                similar_pulc_names = similar_model_names(model_name,
                                                         all_pulc_model_names)
                similar_names_str = ", ".join(similar_imn_names +
                                              similar_pulc_names)
                err = f"{model_name} is not provided by PaddleClas.\nMaybe you want the: [{similar_names_str}].\nIf you want to use your own model, please specify inference_model_dir!"
                raise InputModelError(err)
        elif inference_model_dir:
            model_file_path = os.path.join(inference_model_dir,
                                           "inference.pdmodel")
            params_file_path = os.path.join(inference_model_dir,
                                            "inference.pdiparams")
            if not os.path.isfile(model_file_path) or not os.path.isfile(
                    params_file_path):
                err = f"There is no model file or params file in this directory: {inference_model_dir}"
                raise InputModelError(err)
            return "custom", inference_model_dir
        else:
            # Fix: dropped the needless f-prefix (no placeholders) and the
            # unreachable trailing `return None` after this raise.
            err = "Please specify the model name supported by PaddleClas or directory contained model files(inference.pdmodel, inference.pdiparams)."
            raise InputModelError(err)

    def predict(self, input_data: Union[str, np.ndarray],
                print_pred: bool = False) -> Generator[list, None, None]:
        """Predict input_data.

        Args:
            input_data (Union[str, np.ndarray]):
                When the type is str, it is the path of image, or the directory containing images, or the URL of image from Internet.
                When the type is np.ndarray, it is the image data whose channel order is RGB.
            print_pred (bool, optional): Whether print the prediction result. Defaults to False.

        Raises:
            ImageTypeError: Illegal input_data.

        Yields:
            Generator[list, None, None]:
                The prediction result(s) of input_data by batch_size. For every one image,
                prediction result(s) is zipped as a dict, that includes topk "class_ids", "scores" and "label_names".
                The format of batch prediction result(s) is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...]
        """
        if isinstance(input_data, np.ndarray):
            yield self.cls_predictor.predict(input_data)
        elif isinstance(input_data, str):
            # "http" also covers "https" — the previous second check was
            # redundant.
            if input_data.startswith("http"):
                image_storage_dir = partial(os.path.join, BASE_IMAGES_DIR)
                os.makedirs(image_storage_dir(), exist_ok=True)
                image_save_path = image_storage_dir("tmp.jpg")
                download_with_progressbar(input_data, image_save_path)
                logger.info(
                    f"Image to be predicted from Internet: {input_data}, has been saved to: {image_save_path}"
                )
                input_data = image_save_path
            image_list = get_image_list(input_data)

            batch_size = self._config.Global.get("batch_size", 1)
            img_list = []
            img_path_list = []
            cnt = 0
            for idx_img, img_path in enumerate(image_list):
                img = cv2.imread(img_path)
                if img is None:
                    logger.warning(
                        f"Image file failed to read and has been skipped. The path: {img_path}"
                    )
                    continue
                # OpenCV reads BGR; reverse the channel axis to get RGB.
                img = img[:, :, ::-1]
                img_list.append(img)
                img_path_list.append(img_path)
                cnt += 1

                # Flush a batch when it is full, or at the last image.
                if cnt % batch_size == 0 or (idx_img + 1) == len(image_list):
                    preds = self.cls_predictor.predict(img_list)

                    if preds:
                        for idx_pred, pred in enumerate(preds):
                            pred["filename"] = img_path_list[idx_pred]
                            if print_pred:
                                logger.info(", ".join(
                                    [f"{k}: {pred[k]}" for k in pred]))

                    img_list = []
                    img_path_list = []
                    yield preds
        else:
            err = "Please input legal image! The type of image supported by PaddleClas are: NumPy.ndarray and string of local path or Internet URL"
            raise ImageTypeError(err)
        return
2021-01-07 17:36:29 +08:00
2021-06-22 11:12:26 +08:00
# for CLI
2021-01-07 17:36:29 +08:00
def main():
    """Entry point for the command line interface.

    Prints the supported-model tables, builds a PaddleClas engine from the
    parsed CLI arguments, then runs prediction over the input images with
    results logged as they are produced.
    """
    print_info()
    cfg = args_cfg()
    clas_engine = PaddleClas(**cfg)
    res = clas_engine.predict(cfg["infer_imgs"], print_pred=True)
    # predict() is a generator; drain it so every batch actually runs.
    for _ in res:
        pass
    logger.info("Predict complete!")
    return
2021-01-07 17:36:29 +08:00
2021-06-22 11:12:26 +08:00
if __name__ == " __main__ " :
2021-01-07 17:36:29 +08:00
main ( )