Merge branch 'PaddlePaddle:develop_reg' into develop_reg
|
@ -0,0 +1 @@
|
|||
from . import utils
|
|
@ -0,0 +1,30 @@
|
|||
|
||||
Global:
|
||||
infer_imgs: "../docs/images/whl/demo.jpg"
|
||||
inference_model_dir: "./MobileNetV1_infer/"
|
||||
batch_size: 1
|
||||
use_gpu: True
|
||||
enable_mkldnn: True
|
||||
cpu_num_threads: 100
|
||||
enable_benchmark: True
|
||||
use_fp16: False
|
||||
ir_optim: True
|
||||
use_tensorrt: False
|
||||
gpu_mem: 8000
|
||||
enable_profile: False
|
||||
PreProcess:
|
||||
transform_ops:
|
||||
- ResizeImage:
|
||||
resize_short: 256
|
||||
- CropImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 0.00392157
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
- ToCHWImage:
|
||||
PostProcess:
|
||||
name: Topk
|
||||
topk: 5
|
||||
class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"
|
|
@ -0,0 +1,69 @@
|
|||
Global:
|
||||
infer_imgs: "images/coco_000000570688.jpg"
|
||||
# infer_imgs: "../docs/images/whl/demo.jpg"
|
||||
det_inference_model_dir: "./ppyolov2_r50vd_dcn_365e_mainbody_infer/"
|
||||
rec_inference_model_dir: "./MobileNetV1_infer/"
|
||||
batch_size: 1
|
||||
image_shape: [3, 640, 640]
|
||||
threshold: 0.5
|
||||
max_det_results: 1
|
||||
labe_list:
|
||||
- foreground
|
||||
|
||||
# inference engine config
|
||||
use_gpu: False
|
||||
enable_mkldnn: True
|
||||
cpu_num_threads: 100
|
||||
enable_benchmark: True
|
||||
use_fp16: False
|
||||
ir_optim: True
|
||||
use_tensorrt: False
|
||||
gpu_mem: 8000
|
||||
enable_profile: False
|
||||
|
||||
DetPreProcess:
|
||||
transform_ops:
|
||||
- DetResize:
|
||||
interp: 2
|
||||
keep_ratio: false
|
||||
target_size: [640, 640]
|
||||
- DetNormalizeImage:
|
||||
is_scale: true
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
- DetPermute: {}
|
||||
|
||||
DetPostProcess: {}
|
||||
|
||||
|
||||
RecPreProcess:
|
||||
transform_ops:
|
||||
- ResizeImage:
|
||||
resize_short: 256
|
||||
- CropImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 0.00392157
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
- ToCHWImage:
|
||||
|
||||
RecPostProcess: null
|
||||
|
||||
|
||||
# indexing engine config
|
||||
IndexProcess:
|
||||
build:
|
||||
enable: True
|
||||
index_path: "./logo_index/"
|
||||
image_root: "dataset/LogoDet-3K-crop/train"
|
||||
data_file: "dataset/LogoDet-3K-crop/LogoDet-3K+train.txt"
|
||||
spacer: " "
|
||||
dist_type: "IP"
|
||||
pq_size: 100
|
||||
embedding_size: 1000
|
||||
infer:
|
||||
index_path: "./logo_index/"
|
||||
search_budget: 100
|
||||
return_k: 10
|
Before Width: | Height: | Size: 126 KiB After Width: | Height: | Size: 126 KiB |
Before Width: | Height: | Size: 42 KiB After Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 75 KiB After Width: | Height: | Size: 75 KiB |
Before Width: | Height: | Size: 84 KiB After Width: | Height: | Size: 84 KiB |
Before Width: | Height: | Size: 57 KiB After Width: | Height: | Size: 57 KiB |
Before Width: | Height: | Size: 45 KiB After Width: | Height: | Size: 45 KiB |
Before Width: | Height: | Size: 62 KiB After Width: | Height: | Size: 62 KiB |
Before Width: | Height: | Size: 83 KiB After Width: | Height: | Size: 83 KiB |
After Width: | Height: | Size: 135 KiB |
|
@ -0,0 +1,205 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def decode_image(im_file, im_info):
|
||||
"""read rgb image
|
||||
Args:
|
||||
im_file (str|np.ndarray): input can be image path or np.ndarray
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
if isinstance(im_file, str):
|
||||
with open(im_file, 'rb') as f:
|
||||
im_read = f.read()
|
||||
data = np.frombuffer(im_read, dtype='uint8')
|
||||
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
|
||||
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
|
||||
else:
|
||||
im = im_file
|
||||
im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
|
||||
im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32)
|
||||
return im, im_info
|
||||
|
||||
|
||||
class DetResize(object):
|
||||
"""resize image by target_size and max_size
|
||||
Args:
|
||||
target_size (int): the target size of image
|
||||
keep_ratio (bool): whether keep_ratio or not, default true
|
||||
interp (int): method of resize
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_size,
|
||||
keep_ratio=True,
|
||||
interp=cv2.INTER_LINEAR, ):
|
||||
if isinstance(target_size, int):
|
||||
target_size = [target_size, target_size]
|
||||
self.target_size = target_size
|
||||
self.keep_ratio = keep_ratio
|
||||
self.interp = interp
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
assert len(self.target_size) == 2
|
||||
assert self.target_size[0] > 0 and self.target_size[1] > 0
|
||||
im_channel = im.shape[2]
|
||||
im_scale_y, im_scale_x = self.generate_scale(im)
|
||||
# set image_shape
|
||||
im_info['input_shape'][1] = int(im_scale_y * im.shape[0])
|
||||
im_info['input_shape'][2] = int(im_scale_x * im.shape[1])
|
||||
im = cv2.resize(
|
||||
im,
|
||||
None,
|
||||
None,
|
||||
fx=im_scale_x,
|
||||
fy=im_scale_y,
|
||||
interpolation=self.interp)
|
||||
im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
|
||||
im_info['scale_factor'] = np.array(
|
||||
[im_scale_y, im_scale_x]).astype('float32')
|
||||
return im, im_info
|
||||
|
||||
def generate_scale(self, im):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
Returns:
|
||||
im_scale_x: the resize ratio of X
|
||||
im_scale_y: the resize ratio of Y
|
||||
"""
|
||||
origin_shape = im.shape[:2]
|
||||
im_c = im.shape[2]
|
||||
if self.keep_ratio:
|
||||
im_size_min = np.min(origin_shape)
|
||||
im_size_max = np.max(origin_shape)
|
||||
target_size_min = np.min(self.target_size)
|
||||
target_size_max = np.max(self.target_size)
|
||||
im_scale = float(target_size_min) / float(im_size_min)
|
||||
if np.round(im_scale * im_size_max) > target_size_max:
|
||||
im_scale = float(target_size_max) / float(im_size_max)
|
||||
im_scale_x = im_scale
|
||||
im_scale_y = im_scale
|
||||
else:
|
||||
resize_h, resize_w = self.target_size
|
||||
im_scale_y = resize_h / float(origin_shape[0])
|
||||
im_scale_x = resize_w / float(origin_shape[1])
|
||||
return im_scale_y, im_scale_x
|
||||
|
||||
|
||||
class DetNormalizeImage(object):
|
||||
"""normalize image
|
||||
Args:
|
||||
mean (list): im - mean
|
||||
std (list): im / std
|
||||
is_scale (bool): whether need im / 255
|
||||
is_channel_first (bool): if True: image shape is CHW, else: HWC
|
||||
"""
|
||||
|
||||
def __init__(self, mean, std, is_scale=True):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.is_scale = is_scale
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
im = im.astype(np.float32, copy=False)
|
||||
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
|
||||
std = np.array(self.std)[np.newaxis, np.newaxis, :]
|
||||
|
||||
if self.is_scale:
|
||||
im = im / 255.0
|
||||
|
||||
im -= mean
|
||||
im /= std
|
||||
return im, im_info
|
||||
|
||||
|
||||
class DetPermute(object):
|
||||
"""permute image
|
||||
Args:
|
||||
to_bgr (bool): whether convert RGB to BGR
|
||||
channel_first (bool): whether convert HWC to CHW
|
||||
"""
|
||||
|
||||
def __init__(self, ):
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
im = im.transpose((2, 0, 1)).copy()
|
||||
return im, im_info
|
||||
|
||||
|
||||
class DetPadStride(object):
|
||||
""" padding image for model with FPN , instead PadBatch(pad_to_stride, pad_gt) in original config
|
||||
Args:
|
||||
stride (bool): model with FPN need image shape % stride == 0
|
||||
"""
|
||||
|
||||
def __init__(self, stride=0):
|
||||
self.coarsest_stride = stride
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
coarsest_stride = self.coarsest_stride
|
||||
if coarsest_stride <= 0:
|
||||
return im, im_info
|
||||
im_c, im_h, im_w = im.shape
|
||||
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
|
||||
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
|
||||
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
|
||||
padding_im[:, :im_h, :im_w] = im
|
||||
return padding_im, im_info
|
||||
|
||||
|
||||
def det_preprocess(im, im_info, preprocess_ops):
|
||||
for operator in preprocess_ops:
|
||||
im, im_info = operator(im, im_info)
|
||||
return im, im_info
|
|
@ -0,0 +1,84 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import copy
|
||||
import importlib
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
|
||||
|
||||
def build_postprocess(config):
|
||||
if config is None:
|
||||
return None
|
||||
config = copy.deepcopy(config)
|
||||
model_name = config.pop("name")
|
||||
mod = importlib.import_module(__name__)
|
||||
postprocess_func = getattr(mod, model_name)(**config)
|
||||
return postprocess_func
|
||||
|
||||
|
||||
class Topk(object):
|
||||
def __init__(self, topk=1, class_id_map_file=None):
|
||||
assert isinstance(topk, (int, ))
|
||||
self.class_id_map = self.parse_class_id_map(class_id_map_file)
|
||||
self.topk = topk
|
||||
|
||||
def parse_class_id_map(self, class_id_map_file):
|
||||
if class_id_map_file is None:
|
||||
return None
|
||||
if not os.path.exists(class_id_map_file):
|
||||
print(
|
||||
"Warning: If want to use your own label_dict, please input legal path!\nOtherwise label_names will be empty!"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
class_id_map = {}
|
||||
with open(class_id_map_file, "r") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
partition = line.split("\n")[0].partition(" ")
|
||||
class_id_map[int(partition[0])] = str(partition[-1])
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
class_id_map = None
|
||||
return class_id_map
|
||||
|
||||
def __call__(self, x, file_names=None):
|
||||
if file_names is not None:
|
||||
assert x.shape[0] == len(file_names)
|
||||
y = []
|
||||
for idx, probs in enumerate(x):
|
||||
index = probs.argsort(axis=0)[-self.topk:][::-1].astype("int32")
|
||||
clas_id_list = []
|
||||
score_list = []
|
||||
label_name_list = []
|
||||
for i in index:
|
||||
clas_id_list.append(i.item())
|
||||
score_list.append(probs[i].item())
|
||||
if self.class_id_map is not None:
|
||||
label_name_list.append(self.class_id_map[i.item()])
|
||||
result = {
|
||||
"class_ids": clas_id_list,
|
||||
"scores": np.around(
|
||||
score_list, decimals=5).tolist(),
|
||||
}
|
||||
if file_names is not None:
|
||||
result["file_name"] = file_names[idx]
|
||||
if label_name_list is not None:
|
||||
result["label_names"] = label_name_list
|
||||
y.append(result)
|
||||
return y
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from utils import logger
|
||||
from utils import config
|
||||
from utils.predictor import Predictor
|
||||
from utils.get_image_list import get_image_list
|
||||
from preprocess import create_operators
|
||||
from postprocess import build_postprocess
|
||||
|
||||
|
||||
class ClsPredictor(Predictor):
|
||||
def __init__(self, config):
|
||||
super().__init__(config["Global"])
|
||||
self.preprocess_ops = create_operators(config["PreProcess"][
|
||||
"transform_ops"])
|
||||
self.postprocess = build_postprocess(config["PostProcess"])
|
||||
|
||||
def predict(self, images):
|
||||
input_names = self.paddle_predictor.get_input_names()
|
||||
input_tensor = self.paddle_predictor.get_input_handle(input_names[0])
|
||||
|
||||
output_names = self.paddle_predictor.get_output_names()
|
||||
output_tensor = self.paddle_predictor.get_output_handle(output_names[
|
||||
0])
|
||||
|
||||
if not isinstance(images, (list, )):
|
||||
images = [images]
|
||||
for idx in range(len(images)):
|
||||
for ops in self.preprocess_ops:
|
||||
images[idx] = ops(images[idx])
|
||||
image = np.array(images)
|
||||
|
||||
input_tensor.copy_from_cpu(image)
|
||||
self.paddle_predictor.run()
|
||||
batch_output = output_tensor.copy_to_cpu()
|
||||
return batch_output
|
||||
|
||||
|
||||
def main(config):
|
||||
cls_predictor = ClsPredictor(config)
|
||||
image_list = get_image_list(config["Global"]["infer_imgs"])
|
||||
|
||||
assert config["Global"]["batch_size"] == 1
|
||||
for idx, image_file in enumerate(image_list):
|
||||
img = cv2.imread(image_file)[:, :, ::-1]
|
||||
output = cls_predictor.predict(img)
|
||||
output = cls_predictor.postprocess(output)
|
||||
print(output)
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = config.parse_args()
|
||||
config = config.get_config(args.config, overrides=args.override, show=True)
|
||||
main(config)
|
|
@ -0,0 +1,158 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
|
||||
|
||||
from utils import logger
|
||||
from utils import config
|
||||
from utils.predictor import Predictor
|
||||
from utils.get_image_list import get_image_list
|
||||
from det_preprocess import det_preprocess
|
||||
from preprocess import create_operators
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import time
|
||||
import yaml
|
||||
import ast
|
||||
from functools import reduce
|
||||
import cv2
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
|
||||
class DetPredictor(Predictor):
|
||||
def __init__(self, config):
|
||||
super().__init__(config["Global"],
|
||||
config["Global"]["det_inference_model_dir"])
|
||||
|
||||
self.preprocess_ops = create_operators(config["DetPreProcess"][
|
||||
"transform_ops"])
|
||||
self.config = config
|
||||
|
||||
def preprocess(self, img):
|
||||
im_info = {
|
||||
'scale_factor': np.array(
|
||||
[1., 1.], dtype=np.float32),
|
||||
'im_shape': np.array(
|
||||
img.shape[:2], dtype=np.float32),
|
||||
'input_shape': self.config["Global"]["image_shape"],
|
||||
"scale_factor": np.array(
|
||||
[1., 1.], dtype=np.float32)
|
||||
}
|
||||
im, im_info = det_preprocess(img, im_info, self.preprocess_ops)
|
||||
inputs = self.create_inputs(im, im_info)
|
||||
return inputs
|
||||
|
||||
def create_inputs(self, im, im_info):
|
||||
"""generate input for different model type
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
model_arch (str): model type
|
||||
Returns:
|
||||
inputs (dict): input of model
|
||||
"""
|
||||
inputs = {}
|
||||
inputs['image'] = np.array((im, )).astype('float32')
|
||||
inputs['im_shape'] = np.array(
|
||||
(im_info['im_shape'], )).astype('float32')
|
||||
inputs['scale_factor'] = np.array(
|
||||
(im_info['scale_factor'], )).astype('float32')
|
||||
|
||||
return inputs
|
||||
|
||||
def parse_det_results(self, pred, threshold, label_list):
|
||||
max_det_results = self.config["Global"]["max_det_results"]
|
||||
keep_indexes = pred[:, 1].argsort()[::-1][:max_det_results]
|
||||
results = []
|
||||
for idx in keep_indexes:
|
||||
single_res = pred[idx]
|
||||
class_id = int(single_res[0])
|
||||
score = single_res[1]
|
||||
bbox = single_res[2:]
|
||||
if score < threshold:
|
||||
continue
|
||||
label_name = label_list[class_id]
|
||||
results.append({
|
||||
"class_id": class_id,
|
||||
"score": score,
|
||||
"bbox": bbox,
|
||||
"label_name": label_name,
|
||||
})
|
||||
return results
|
||||
|
||||
def predict(self, image, threshold=0.5, run_benchmark=False):
|
||||
'''
|
||||
Args:
|
||||
image (str/np.ndarray): path of image/ np.ndarray read by cv2
|
||||
threshold (float): threshold of predicted box' score
|
||||
Returns:
|
||||
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
|
||||
matix element:[class, score, x_min, y_min, x_max, y_max]
|
||||
MaskRCNN's results include 'masks': np.ndarray:
|
||||
shape: [N, im_h, im_w]
|
||||
'''
|
||||
inputs = self.preprocess(image)
|
||||
np_boxes = None
|
||||
input_names = self.paddle_predictor.get_input_names()
|
||||
|
||||
for i in range(len(input_names)):
|
||||
input_tensor = self.paddle_predictor.get_input_handle(input_names[
|
||||
i])
|
||||
input_tensor.copy_from_cpu(inputs[input_names[i]])
|
||||
|
||||
t1 = time.time()
|
||||
self.paddle_predictor.run()
|
||||
output_names = self.paddle_predictor.get_output_names()
|
||||
boxes_tensor = self.paddle_predictor.get_output_handle(output_names[0])
|
||||
np_boxes = boxes_tensor.copy_to_cpu()
|
||||
t2 = time.time()
|
||||
|
||||
print("Inference: {} ms per batch image".format((t2 - t1) * 1000.0))
|
||||
|
||||
# do not perform postprocess in benchmark mode
|
||||
results = []
|
||||
if reduce(lambda x, y: x * y, np_boxes.shape) < 6:
|
||||
print('[WARNNING] No object detected.')
|
||||
results = np.array([])
|
||||
else:
|
||||
results = np_boxes
|
||||
|
||||
results = self.parse_det_results(results,
|
||||
self.config["Global"]["threshold"],
|
||||
self.config["Global"]["labe_list"])
|
||||
return results
|
||||
|
||||
|
||||
def main(config):
|
||||
det_predictor = DetPredictor(config)
|
||||
image_list = get_image_list(config["Global"]["infer_imgs"])
|
||||
|
||||
assert config["Global"]["batch_size"] == 1
|
||||
for idx, image_file in enumerate(image_list):
|
||||
img = cv2.imread(image_file)[:, :, ::-1]
|
||||
output = det_predictor.predict(img)
|
||||
print(output)
|
||||
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = config.parse_args()
|
||||
config = config.get_config(args.config, overrides=args.override, show=True)
|
||||
main(config)
|
|
@ -0,0 +1,78 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from utils import logger
|
||||
from utils import config
|
||||
from utils.predictor import Predictor
|
||||
from utils.get_image_list import get_image_list
|
||||
from preprocess import create_operators
|
||||
from postprocess import build_postprocess
|
||||
|
||||
|
||||
class RecPredictor(Predictor):
|
||||
def __init__(self, config):
|
||||
super().__init__(config["Global"],
|
||||
config["Global"]["rec_inference_model_dir"])
|
||||
self.preprocess_ops = create_operators(config["RecPreProcess"][
|
||||
"transform_ops"])
|
||||
self.postprocess = build_postprocess(config["RecPostProcess"])
|
||||
|
||||
def predict(self, images):
|
||||
input_names = self.paddle_predictor.get_input_names()
|
||||
input_tensor = self.paddle_predictor.get_input_handle(input_names[0])
|
||||
|
||||
output_names = self.paddle_predictor.get_output_names()
|
||||
output_tensor = self.paddle_predictor.get_output_handle(output_names[
|
||||
0])
|
||||
|
||||
if not isinstance(images, (list, )):
|
||||
images = [images]
|
||||
for idx in range(len(images)):
|
||||
for ops in self.preprocess_ops:
|
||||
images[idx] = ops(images[idx])
|
||||
image = np.array(images)
|
||||
|
||||
input_tensor.copy_from_cpu(image)
|
||||
self.paddle_predictor.run()
|
||||
batch_output = output_tensor.copy_to_cpu()
|
||||
return batch_output
|
||||
|
||||
|
||||
def main(config):
|
||||
rec_predictor = RecPredictor(config)
|
||||
image_list = get_image_list(config["Global"]["infer_imgs"])
|
||||
|
||||
assert config["Global"]["batch_size"] == 1
|
||||
for idx, image_file in enumerate(image_list):
|
||||
batch_input = []
|
||||
img = cv2.imread(image_file)[:, :, ::-1]
|
||||
output = rec_predictor.predict(img)
|
||||
if rec_predictor.postprocess is not None:
|
||||
output = rec_predictor.postprocess(output)
|
||||
print(output.shape)
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = config.parse_args()
|
||||
config = config.get_config(args.config, overrides=args.override, show=True)
|
||||
main(config)
|
|
@ -0,0 +1,114 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
|
||||
|
||||
import copy
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from python.predict_rec import RecPredictor
|
||||
from python.predict_det import DetPredictor
|
||||
from vector_search import Graph_Index
|
||||
|
||||
from utils import logger
|
||||
from utils import config
|
||||
from utils.get_image_list import get_image_list
|
||||
|
||||
|
||||
def split_datafile(data_file, image_root):
|
||||
gallery_images = []
|
||||
gallery_docs = []
|
||||
with open(datafile) as f:
|
||||
lines = f.readlines()
|
||||
for i, line in enumerate(lines):
|
||||
line = line.strip().split("\t")
|
||||
if line[0] == 'image_id':
|
||||
continue
|
||||
image_file = os.path.join(image_root, line[3])
|
||||
image_doc = line[1]
|
||||
gallery_images.append(image_file)
|
||||
gallery_docs.append(image_doc)
|
||||
return gallery_images, gallery_docs
|
||||
|
||||
|
||||
class SystemPredictor(object):
|
||||
def __init__(self, config):
|
||||
|
||||
self.config = config
|
||||
self.rec_predictor = RecPredictor(config)
|
||||
self.det_predictor = DetPredictor(config)
|
||||
|
||||
assert 'IndexProcess' in config.keys(), "Index config not found ... "
|
||||
self.indexer(config['IndexProcess'])
|
||||
self.return_k = self.config['IndexProcess']['infer']['return_k']
|
||||
self.search_budget = self.config['IndexProcess']['infer']['search_budget']
|
||||
|
||||
def indexer(self, config):
|
||||
if 'build' in config.keys() and config['build']['enable']: # build the index from scratch
|
||||
with open(config['build']['datafile']) as f:
|
||||
lines = f.readlines()
|
||||
gallery_images, gallery_docs = split_datafile(config['build']['data_file'], config['build']['image_root'])
|
||||
# extract gallery features
|
||||
gallery_features = np.zeros([len(gallery_images), config['build']['embedding_size']], dtype=np.float32)
|
||||
for i, image_file in enumerate(gallery_images):
|
||||
img = cv2.imread(image_file)[:, :, ::-1]
|
||||
rec_feat = self.rec_predictor.predict(img)
|
||||
gallery_features[i,:] = rec_feat
|
||||
# train index
|
||||
self.Searcher = Graph_Index(dist_type=config['build']['dist_type'])
|
||||
self.Searcher.build(gallery_vectors=gallery_features, gallery_docs=gallery_docs,
|
||||
pq_size=config['build']['pq_size'], index_path=config['build']['index_path'])
|
||||
|
||||
else: # load local index
|
||||
self.Searcher = Graph_Index(dist_type=config['build']['dist_type'])
|
||||
self.Searcher.load(config['infer']['index_path'])
|
||||
|
||||
def predict(self, img):
|
||||
output = []
|
||||
results = self.det_predictor.predict(img)
|
||||
for result in results:
|
||||
xmin, ymin, xmax, ymax = result["bbox"].astype("int")
|
||||
crop_img = img[xmin:xmax, ymin:ymax, :].copy()
|
||||
rec_results = self.rec_predictor.predict(crop_img)
|
||||
result["featrue"] = rec_results
|
||||
|
||||
scores, docs = self.Searcher.search(query=rec_results, return_k=self.return_k, search_budget=self.search_budget)
|
||||
result["ret_docs"] = docs
|
||||
result["ret_scores"] = scores
|
||||
|
||||
output.append(result)
|
||||
return output
|
||||
|
||||
|
||||
|
||||
def main(config):
|
||||
system_predictor = SystemPredictor(config)
|
||||
image_list = get_image_list(config["Global"]["infer_imgs"])
|
||||
|
||||
assert config["Global"]["batch_size"] == 1
|
||||
for idx, image_file in enumerate(image_list):
|
||||
img = cv2.imread(image_file)[:, :, ::-1]
|
||||
output = system_predictor.predict(img)
|
||||
#print(output)
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = config.parse_args()
|
||||
config = config.get_config(args.config, overrides=args.override, show=True)
|
||||
main(config)
|
|
@ -0,0 +1,278 @@
|
|||
"""
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import six
|
||||
import math
|
||||
import random
|
||||
import cv2
|
||||
import numpy as np
|
||||
import importlib
|
||||
|
||||
from det_preprocess import DetNormalizeImage, DetPadStride, DetPermute, DetResize
|
||||
|
||||
|
||||
def create_operators(params):
|
||||
"""
|
||||
create operators based on the config
|
||||
|
||||
Args:
|
||||
params(list): a dict list, used to create some operators
|
||||
"""
|
||||
assert isinstance(params, list), ('operator config should be a list')
|
||||
mod = importlib.import_module(__name__)
|
||||
ops = []
|
||||
for operator in params:
|
||||
assert isinstance(operator,
|
||||
dict) and len(operator) == 1, "yaml format error"
|
||||
op_name = list(operator)[0]
|
||||
param = {} if operator[op_name] is None else operator[op_name]
|
||||
op = getattr(mod, op_name)(**param)
|
||||
ops.append(op)
|
||||
|
||||
return ops
|
||||
|
||||
|
||||
class OperatorParamError(ValueError):
|
||||
""" OperatorParamError
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class DecodeImage(object):
|
||||
""" decode image """
|
||||
|
||||
def __init__(self, to_rgb=True, to_np=False, channel_first=False):
|
||||
self.to_rgb = to_rgb
|
||||
self.to_np = to_np # to numpy
|
||||
self.channel_first = channel_first # only enabled when to_np is True
|
||||
|
||||
def __call__(self, img):
|
||||
if six.PY2:
|
||||
assert type(img) is str and len(
|
||||
img) > 0, "invalid input 'img' in DecodeImage"
|
||||
else:
|
||||
assert type(img) is bytes and len(
|
||||
img) > 0, "invalid input 'img' in DecodeImage"
|
||||
data = np.frombuffer(img, dtype='uint8')
|
||||
img = cv2.imdecode(data, 1)
|
||||
if self.to_rgb:
|
||||
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
|
||||
img.shape)
|
||||
img = img[:, :, ::-1]
|
||||
|
||||
if self.channel_first:
|
||||
img = img.transpose((2, 0, 1))
|
||||
|
||||
return img
|
||||
|
||||
|
||||
class ResizeImage(object):
|
||||
""" resize image """
|
||||
|
||||
def __init__(self, size=None, resize_short=None, interpolation=-1):
|
||||
self.interpolation = interpolation if interpolation >= 0 else None
|
||||
if resize_short is not None and resize_short > 0:
|
||||
self.resize_short = resize_short
|
||||
self.w = None
|
||||
self.h = None
|
||||
elif size is not None:
|
||||
self.resize_short = None
|
||||
self.w = size if type(size) is int else size[0]
|
||||
self.h = size if type(size) is int else size[1]
|
||||
else:
|
||||
raise OperatorParamError("invalid params for ReisizeImage for '\
|
||||
'both 'size' and 'resize_short' are None")
|
||||
|
||||
def __call__(self, img):
|
||||
img_h, img_w = img.shape[:2]
|
||||
if self.resize_short is not None:
|
||||
percent = float(self.resize_short) / min(img_w, img_h)
|
||||
w = int(round(img_w * percent))
|
||||
h = int(round(img_h * percent))
|
||||
else:
|
||||
w = self.w
|
||||
h = self.h
|
||||
if self.interpolation is None:
|
||||
return cv2.resize(img, (w, h))
|
||||
else:
|
||||
return cv2.resize(img, (w, h), interpolation=self.interpolation)
|
||||
|
||||
|
||||
class CropImage(object):
|
||||
""" crop image """
|
||||
|
||||
def __init__(self, size):
|
||||
if type(size) is int:
|
||||
self.size = (size, size)
|
||||
else:
|
||||
self.size = size # (h, w)
|
||||
|
||||
def __call__(self, img):
|
||||
w, h = self.size
|
||||
img_h, img_w = img.shape[:2]
|
||||
w_start = (img_w - w) // 2
|
||||
h_start = (img_h - h) // 2
|
||||
|
||||
w_end = w_start + w
|
||||
h_end = h_start + h
|
||||
return img[h_start:h_end, w_start:w_end, :]
|
||||
|
||||
|
||||
class RandCropImage(object):
|
||||
""" random crop image """
|
||||
|
||||
def __init__(self, size, scale=None, ratio=None, interpolation=-1):
|
||||
|
||||
self.interpolation = interpolation if interpolation >= 0 else None
|
||||
if type(size) is int:
|
||||
self.size = (size, size) # (h, w)
|
||||
else:
|
||||
self.size = size
|
||||
|
||||
self.scale = [0.08, 1.0] if scale is None else scale
|
||||
self.ratio = [3. / 4., 4. / 3.] if ratio is None else ratio
|
||||
|
||||
def __call__(self, img):
|
||||
size = self.size
|
||||
scale = self.scale
|
||||
ratio = self.ratio
|
||||
|
||||
aspect_ratio = math.sqrt(random.uniform(*ratio))
|
||||
w = 1. * aspect_ratio
|
||||
h = 1. / aspect_ratio
|
||||
|
||||
img_h, img_w = img.shape[:2]
|
||||
|
||||
bound = min((float(img_w) / img_h) / (w**2),
|
||||
(float(img_h) / img_w) / (h**2))
|
||||
scale_max = min(scale[1], bound)
|
||||
scale_min = min(scale[0], bound)
|
||||
|
||||
target_area = img_w * img_h * random.uniform(scale_min, scale_max)
|
||||
target_size = math.sqrt(target_area)
|
||||
w = int(target_size * w)
|
||||
h = int(target_size * h)
|
||||
|
||||
i = random.randint(0, img_w - w)
|
||||
j = random.randint(0, img_h - h)
|
||||
|
||||
img = img[j:j + h, i:i + w, :]
|
||||
if self.interpolation is None:
|
||||
return cv2.resize(img, size)
|
||||
else:
|
||||
return cv2.resize(img, size, interpolation=self.interpolation)
|
||||
|
||||
|
||||
class RandFlipImage(object):
|
||||
""" random flip image
|
||||
flip_code:
|
||||
1: Flipped Horizontally
|
||||
0: Flipped Vertically
|
||||
-1: Flipped Horizontally & Vertically
|
||||
"""
|
||||
|
||||
def __init__(self, flip_code=1):
|
||||
assert flip_code in [-1, 0, 1
|
||||
], "flip_code should be a value in [-1, 0, 1]"
|
||||
self.flip_code = flip_code
|
||||
|
||||
def __call__(self, img):
|
||||
if random.randint(0, 1) == 1:
|
||||
return cv2.flip(img, self.flip_code)
|
||||
else:
|
||||
return img
|
||||
|
||||
|
||||
class AutoAugment(object):
|
||||
def __init__(self):
|
||||
self.policy = ImageNetPolicy()
|
||||
|
||||
def __call__(self, img):
|
||||
from PIL import Image
|
||||
img = np.ascontiguousarray(img)
|
||||
img = Image.fromarray(img)
|
||||
img = self.policy(img)
|
||||
img = np.asarray(img)
|
||||
|
||||
|
||||
class NormalizeImage(object):
|
||||
""" normalize image such as substract mean, divide std
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
scale=None,
|
||||
mean=None,
|
||||
std=None,
|
||||
order='chw',
|
||||
output_fp16=False,
|
||||
channel_num=3):
|
||||
if isinstance(scale, str):
|
||||
scale = eval(scale)
|
||||
assert channel_num in [
|
||||
3, 4
|
||||
], "channel number of input image should be set to 3 or 4."
|
||||
self.channel_num = channel_num
|
||||
self.output_dtype = 'float16' if output_fp16 else 'float32'
|
||||
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
|
||||
self.order = order
|
||||
mean = mean if mean is not None else [0.485, 0.456, 0.406]
|
||||
std = std if std is not None else [0.229, 0.224, 0.225]
|
||||
|
||||
shape = (3, 1, 1) if self.order == 'chw' else (1, 1, 3)
|
||||
self.mean = np.array(mean).reshape(shape).astype('float32')
|
||||
self.std = np.array(std).reshape(shape).astype('float32')
|
||||
|
||||
def __call__(self, img):
|
||||
from PIL import Image
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.array(img)
|
||||
|
||||
assert isinstance(img,
|
||||
np.ndarray), "invalid input 'img' in NormalizeImage"
|
||||
|
||||
img = (img.astype('float32') * self.scale - self.mean) / self.std
|
||||
|
||||
if self.channel_num == 4:
|
||||
img_h = img.shape[1] if self.order == 'chw' else img.shape[0]
|
||||
img_w = img.shape[2] if self.order == 'chw' else img.shape[1]
|
||||
pad_zeros = np.zeros(
|
||||
(1, img_h, img_w)) if self.order == 'chw' else np.zeros(
|
||||
(img_h, img_w, 1))
|
||||
img = (np.concatenate(
|
||||
(img, pad_zeros), axis=0)
|
||||
if self.order == 'chw' else np.concatenate(
|
||||
(img, pad_zeros), axis=2))
|
||||
return img.astype(self.output_dtype)
|
||||
|
||||
|
||||
class ToCHWImage(object):
|
||||
""" convert hwc image to chw image
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, img):
|
||||
from PIL import Image
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.array(img)
|
||||
|
||||
return img.transpose((2, 0, 1))
|
|
@ -0,0 +1,11 @@
|
|||
# classification
|
||||
python3.7 python/predict_cls.py -c configs/inference_cls.yaml
|
||||
|
||||
# feature extractor
|
||||
# python3.7 python/predict_rec.py -c configs/inference_rec.yaml
|
||||
|
||||
# detection
|
||||
# python3.7 python/predict_det.py -c configs/inference_rec.yaml
|
||||
|
||||
# mainbody detection + feature extractor + retrieval
|
||||
# python3.7 python/predict_system.py -c configs/inference_rec.yaml
|
|
@ -0,0 +1,4 @@
|
|||
from . import logger
|
||||
from . import config
|
||||
from . import get_image_list
|
||||
from . import predictor
|
|
@ -0,0 +1,186 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import copy
|
||||
import argparse
|
||||
import yaml
|
||||
from utils import logger
|
||||
__all__ = ['get_config']
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __getattr__(self, key):
|
||||
return self[key]
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key in self.__dict__:
|
||||
self.__dict__[key] = value
|
||||
else:
|
||||
self[key] = value
|
||||
|
||||
def __deepcopy__(self, content):
|
||||
return copy.deepcopy(dict(self))
|
||||
|
||||
|
||||
def create_attr_dict(yaml_config):
|
||||
from ast import literal_eval
|
||||
for key, value in yaml_config.items():
|
||||
if type(value) is dict:
|
||||
yaml_config[key] = value = AttrDict(value)
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
value = literal_eval(value)
|
||||
except BaseException:
|
||||
pass
|
||||
if isinstance(value, AttrDict):
|
||||
create_attr_dict(yaml_config[key])
|
||||
else:
|
||||
yaml_config[key] = value
|
||||
|
||||
|
||||
def parse_config(cfg_file):
|
||||
"""Load a config file into AttrDict"""
|
||||
with open(cfg_file, 'r') as fopen:
|
||||
yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.SafeLoader))
|
||||
create_attr_dict(yaml_config)
|
||||
return yaml_config
|
||||
|
||||
|
||||
def print_dict(d, delimiter=0):
|
||||
"""
|
||||
Recursively visualize a dict and
|
||||
indenting acrrording by the relationship of keys.
|
||||
"""
|
||||
placeholder = "-" * 60
|
||||
for k, v in sorted(d.items()):
|
||||
if isinstance(v, dict):
|
||||
logger.info("{}{} : ".format(delimiter * " ",
|
||||
logger.coloring(k, "HEADER")))
|
||||
print_dict(v, delimiter + 4)
|
||||
elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
|
||||
logger.info("{}{} : ".format(delimiter * " ",
|
||||
logger.coloring(str(k), "HEADER")))
|
||||
for value in v:
|
||||
print_dict(value, delimiter + 4)
|
||||
else:
|
||||
logger.info("{}{} : {}".format(delimiter * " ",
|
||||
logger.coloring(k, "HEADER"),
|
||||
logger.coloring(v, "OKGREEN")))
|
||||
if k.isupper():
|
||||
logger.info(placeholder)
|
||||
|
||||
|
||||
def print_config(config):
|
||||
"""
|
||||
visualize configs
|
||||
Arguments:
|
||||
config: configs
|
||||
"""
|
||||
logger.advertise()
|
||||
print_dict(config)
|
||||
|
||||
|
||||
def override(dl, ks, v):
|
||||
"""
|
||||
Recursively replace dict of list
|
||||
Args:
|
||||
dl(dict or list): dict or list to be replaced
|
||||
ks(list): list of keys
|
||||
v(str): value to be replaced
|
||||
"""
|
||||
|
||||
def str2num(v):
|
||||
try:
|
||||
return eval(v)
|
||||
except Exception:
|
||||
return v
|
||||
|
||||
assert isinstance(dl, (list, dict)), ("{} should be a list or a dict")
|
||||
assert len(ks) > 0, ('lenght of keys should larger than 0')
|
||||
if isinstance(dl, list):
|
||||
k = str2num(ks[0])
|
||||
if len(ks) == 1:
|
||||
assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
|
||||
dl[k] = str2num(v)
|
||||
else:
|
||||
override(dl[k], ks[1:], v)
|
||||
else:
|
||||
if len(ks) == 1:
|
||||
# assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
|
||||
if not ks[0] in dl:
|
||||
logger.warning('A new filed ({}) detected!'.format(ks[0], dl))
|
||||
dl[ks[0]] = str2num(v)
|
||||
else:
|
||||
override(dl[ks[0]], ks[1:], v)
|
||||
|
||||
|
||||
def override_config(config, options=None):
|
||||
"""
|
||||
Recursively override the config
|
||||
Args:
|
||||
config(dict): dict to be replaced
|
||||
options(list): list of pairs(key0.key1.idx.key2=value)
|
||||
such as: [
|
||||
'topk=2',
|
||||
'VALID.transforms.1.ResizeImage.resize_short=300'
|
||||
]
|
||||
Returns:
|
||||
config(dict): replaced config
|
||||
"""
|
||||
if options is not None:
|
||||
for opt in options:
|
||||
assert isinstance(opt, str), (
|
||||
"option({}) should be a str".format(opt))
|
||||
assert "=" in opt, (
|
||||
"option({}) should contain a ="
|
||||
"to distinguish between key and value".format(opt))
|
||||
pair = opt.split('=')
|
||||
assert len(pair) == 2, ("there can be only a = in the option")
|
||||
key, value = pair
|
||||
keys = key.split('.')
|
||||
override(config, keys, value)
|
||||
return config
|
||||
|
||||
|
||||
def get_config(fname, overrides=None, show=True):
|
||||
"""
|
||||
Read config from file
|
||||
"""
|
||||
assert os.path.exists(fname), (
|
||||
'config file({}) is not exist'.format(fname))
|
||||
config = parse_config(fname)
|
||||
override_config(config, overrides)
|
||||
if show:
|
||||
print_config(config)
|
||||
# check_config(config)
|
||||
return config
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser("generic-image-rec train script")
|
||||
parser.add_argument(
|
||||
'-c',
|
||||
'--config',
|
||||
type=str,
|
||||
default='configs/config.yaml',
|
||||
help='config file path')
|
||||
parser.add_argument(
|
||||
'-o',
|
||||
'--override',
|
||||
action='append',
|
||||
default=[],
|
||||
help='config options to be overridden')
|
||||
args = parser.parse_args()
|
||||
return args
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import base64
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_image_list(img_file):
|
||||
imgs_lists = []
|
||||
if img_file is None or not os.path.exists(img_file):
|
||||
raise Exception("not found any img file in {}".format(img_file))
|
||||
|
||||
img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp']
|
||||
if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end:
|
||||
imgs_lists.append(img_file)
|
||||
elif os.path.isdir(img_file):
|
||||
for single_file in os.listdir(img_file):
|
||||
if single_file.split('.')[-1] in img_end:
|
||||
imgs_lists.append(os.path.join(img_file, single_file))
|
||||
if len(imgs_lists) == 0:
|
||||
raise Exception("not found any img file in {}".format(img_file))
|
||||
imgs_lists = sorted(imgs_lists)
|
||||
return imgs_lists
|
||||
|
||||
|
||||
def get_image_list_from_label_file(image_path, label_file_path):
|
||||
imgs_lists = []
|
||||
gt_labels = []
|
||||
with open(label_file_path, "r") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
image_name, label = line.strip("\n").split()
|
||||
label = int(label)
|
||||
imgs_lists.append(os.path.join(image_path, image_name))
|
||||
gt_labels.append(int(label))
|
||||
return imgs_lists, gt_labels
|
|
@ -0,0 +1,120 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import datetime
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s: %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def time_zone(sec, fmt):
|
||||
real_time = datetime.datetime.now()
|
||||
return real_time.timetuple()
|
||||
|
||||
|
||||
logging.Formatter.converter = time_zone
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
Color = {
|
||||
'RED': '\033[31m',
|
||||
'HEADER': '\033[35m', # deep purple
|
||||
'PURPLE': '\033[95m', # purple
|
||||
'OKBLUE': '\033[94m',
|
||||
'OKGREEN': '\033[92m',
|
||||
'WARNING': '\033[93m',
|
||||
'FAIL': '\033[91m',
|
||||
'ENDC': '\033[0m'
|
||||
}
|
||||
|
||||
|
||||
def coloring(message, color="OKGREEN"):
|
||||
assert color in Color.keys()
|
||||
if os.environ.get('PADDLECLAS_COLORING', False):
|
||||
return Color[color] + str(message) + Color["ENDC"]
|
||||
else:
|
||||
return message
|
||||
|
||||
|
||||
def anti_fleet(log):
|
||||
"""
|
||||
logs will print multi-times when calling Fleet API.
|
||||
Only display single log and ignore the others.
|
||||
"""
|
||||
|
||||
def wrapper(fmt, *args):
|
||||
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
|
||||
log(fmt, *args)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@anti_fleet
|
||||
def info(fmt, *args):
|
||||
_logger.info(fmt, *args)
|
||||
|
||||
|
||||
@anti_fleet
|
||||
def warning(fmt, *args):
|
||||
_logger.warning(coloring(fmt, "RED"), *args)
|
||||
|
||||
|
||||
@anti_fleet
|
||||
def error(fmt, *args):
|
||||
_logger.error(coloring(fmt, "FAIL"), *args)
|
||||
|
||||
|
||||
def scaler(name, value, step, writer):
|
||||
"""
|
||||
This function will draw a scalar curve generated by the visualdl.
|
||||
Usage: Install visualdl: pip3 install visualdl==2.0.0b4
|
||||
and then:
|
||||
visualdl --logdir ./scalar --host 0.0.0.0 --port 8830
|
||||
to preview loss corve in real time.
|
||||
"""
|
||||
writer.add_scalar(tag=name, step=step, value=value)
|
||||
|
||||
|
||||
def advertise():
|
||||
"""
|
||||
Show the advertising message like the following:
|
||||
|
||||
===========================================================
|
||||
== PaddleClas is powered by PaddlePaddle ! ==
|
||||
===========================================================
|
||||
== ==
|
||||
== For more info please go to the following website. ==
|
||||
== ==
|
||||
== https://github.com/PaddlePaddle/PaddleClas ==
|
||||
===========================================================
|
||||
|
||||
"""
|
||||
copyright = "PaddleClas is powered by PaddlePaddle !"
|
||||
ad = "For more info please go to the following website."
|
||||
website = "https://github.com/PaddlePaddle/PaddleClas"
|
||||
AD_LEN = 6 + len(max([copyright, ad, website], key=len))
|
||||
|
||||
info(
|
||||
coloring("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
|
||||
"=" * (AD_LEN + 4),
|
||||
"=={}==".format(copyright.center(AD_LEN)),
|
||||
"=" * (AD_LEN + 4),
|
||||
"=={}==".format(' ' * AD_LEN),
|
||||
"=={}==".format(ad.center(AD_LEN)),
|
||||
"=={}==".format(' ' * AD_LEN),
|
||||
"=={}==".format(website.center(AD_LEN)),
|
||||
"=" * (AD_LEN + 4), ), "RED"))
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import argparse
|
||||
import base64
|
||||
import shutil
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from paddle.inference import Config
|
||||
from paddle.inference import create_predictor
|
||||
|
||||
|
||||
class Predictor(object):
|
||||
def __init__(self, args, inference_model_dir=None):
|
||||
# HALF precission predict only work when using tensorrt
|
||||
if args.use_fp16 is True:
|
||||
assert args.use_tensorrt is True
|
||||
self.args = args
|
||||
self.paddle_predictor = self.create_paddle_predictor(
|
||||
args, inference_model_dir)
|
||||
|
||||
def predict(self, image):
|
||||
raise NotImplementedError
|
||||
|
||||
def create_paddle_predictor(self, args, inference_model_dir=None):
|
||||
if inference_model_dir is None:
|
||||
inference_model_dir = args.inference_model_dir
|
||||
params_file = os.path.join(inference_model_dir, "inference.pdiparams")
|
||||
model_file = os.path.join(inference_model_dir, "inference.pdmodel")
|
||||
config = Config(model_file, params_file)
|
||||
|
||||
if args.use_gpu:
|
||||
config.enable_use_gpu(args.gpu_mem, 0)
|
||||
else:
|
||||
config.disable_gpu()
|
||||
if args.enable_mkldnn:
|
||||
# cache 10 different shapes for mkldnn to avoid memory leak
|
||||
config.set_mkldnn_cache_capacity(10)
|
||||
config.enable_mkldnn()
|
||||
config.set_cpu_math_library_num_threads(args.cpu_num_threads)
|
||||
|
||||
if args.enable_profile:
|
||||
config.enable_profile()
|
||||
config.disable_glog_info()
|
||||
config.switch_ir_optim(args.ir_optim) # default true
|
||||
if args.use_tensorrt:
|
||||
config.enable_tensorrt_engine(
|
||||
precision_mode=Config.Precision.Half
|
||||
if args.use_fp16 else Config.Precision.Float32,
|
||||
max_batch_size=args.batch_size)
|
||||
|
||||
config.enable_memory_optim()
|
||||
# use zero copy
|
||||
config.switch_use_feed_fetch_ops(False)
|
||||
predictor = create_predictor(config)
|
||||
|
||||
return predictor
|
|
@ -22,7 +22,9 @@ import json
|
|||
from ctypes import *
|
||||
from numpy.ctypeslib import ndpointer
|
||||
|
||||
lib = ctypes.cdll.LoadLibrary("./index.so")
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
so_path = os.path.join(__dir__, "index.so")
|
||||
lib = ctypes.cdll.LoadLibrary(so_path)
|
||||
|
||||
class IndexContext(Structure):
|
||||
_fields_=[("graph",c_void_p),
|
||||
|
|
|
@ -1,140 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
import time
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, ".")
|
||||
from ppcls.utils import logger
|
||||
from tools.infer.utils import parse_args, create_paddle_predictor, preprocess, postprocess
|
||||
from tools.infer.utils import get_image_list, get_image_list_from_label_file, calc_topk_acc
|
||||
|
||||
|
||||
class Predictor(object):
|
||||
def __init__(self, args):
|
||||
# HALF precission predict only work when using tensorrt
|
||||
if args.use_fp16 is True:
|
||||
assert args.use_tensorrt is True
|
||||
self.args = args
|
||||
|
||||
self.paddle_predictor = create_paddle_predictor(args)
|
||||
input_names = self.paddle_predictor.get_input_names()
|
||||
self.input_tensor = self.paddle_predictor.get_input_handle(input_names[
|
||||
0])
|
||||
|
||||
output_names = self.paddle_predictor.get_output_names()
|
||||
self.output_tensor = self.paddle_predictor.get_output_handle(
|
||||
output_names[0])
|
||||
|
||||
def predict(self, batch_input):
|
||||
self.input_tensor.copy_from_cpu(batch_input)
|
||||
self.paddle_predictor.run()
|
||||
batch_output = self.output_tensor.copy_to_cpu()
|
||||
return batch_output
|
||||
|
||||
def normal_predict(self):
|
||||
if self.args.enable_calc_topk:
|
||||
assert self.args.gt_label_path is not None and os.path.exists(self.args.gt_label_path), \
|
||||
"gt_label_path shoule not be None and must exist, please check its path."
|
||||
image_list, gt_labels = get_image_list_from_label_file(
|
||||
self.args.image_file, self.args.gt_label_path)
|
||||
predicts_map = {
|
||||
"prediction": [],
|
||||
"gt_label": [],
|
||||
}
|
||||
else:
|
||||
image_list = get_image_list(self.args.image_file)
|
||||
gt_labels = None
|
||||
|
||||
batch_input_list = []
|
||||
img_name_list = []
|
||||
cnt = 0
|
||||
for idx, img_path in enumerate(image_list):
|
||||
img = cv2.imread(img_path)
|
||||
if img is None:
|
||||
logger.warning(
|
||||
"Image file failed to read and has been skipped. The path: {}".
|
||||
format(img_path))
|
||||
continue
|
||||
else:
|
||||
img = img[:, :, ::-1]
|
||||
img = preprocess(img, args)
|
||||
batch_input_list.append(img)
|
||||
img_name = img_path.split("/")[-1]
|
||||
img_name_list.append(img_name)
|
||||
cnt += 1
|
||||
if self.args.enable_calc_topk:
|
||||
predicts_map["gt_label"].append(gt_labels[idx])
|
||||
|
||||
if cnt % args.batch_size == 0 or (idx + 1) == len(image_list):
|
||||
batch_outputs = self.predict(np.array(batch_input_list))
|
||||
batch_result_list = postprocess(batch_outputs, self.args.top_k)
|
||||
|
||||
for number, result_dict in enumerate(batch_result_list):
|
||||
filename = img_name_list[number]
|
||||
clas_ids = result_dict["clas_ids"]
|
||||
scores_str = "[{}]".format(", ".join("{:.2f}".format(
|
||||
r) for r in result_dict["scores"]))
|
||||
logger.info(
|
||||
"File:{}, Top-{} result: class id(s): {}, score(s): {}".
|
||||
format(filename, self.args.top_k, clas_ids,
|
||||
scores_str))
|
||||
|
||||
if self.args.enable_calc_topk:
|
||||
predicts_map["prediction"].append(clas_ids)
|
||||
|
||||
batch_input_list = []
|
||||
img_name_list = []
|
||||
if self.args.enable_calc_topk:
|
||||
topk_acc = calc_topk_acc(predicts_map)
|
||||
for idx, acc in enumerate(topk_acc):
|
||||
logger.info("Top-{} acc: {:.5f}".format(idx + 1, acc))
|
||||
|
||||
def benchmark_predict(self):
|
||||
test_num = 500
|
||||
test_time = 0.0
|
||||
for i in range(0, test_num + 10):
|
||||
inputs = np.random.rand(args.batch_size, 3, 224,
|
||||
224).astype(np.float32)
|
||||
start_time = time.time()
|
||||
batch_output = self.predict(inputs).flatten()
|
||||
if i >= 10:
|
||||
test_time += time.time() - start_time
|
||||
time.sleep(0.01) # sleep for T4 GPU
|
||||
|
||||
fp_message = "FP16" if args.use_fp16 else "FP32"
|
||||
trt_msg = "using tensorrt" if args.use_tensorrt else "not using tensorrt"
|
||||
print("{0}\t{1}\t{2}\tbatch size: {3}\ttime(ms): {4}".format(
|
||||
args.model, trt_msg, fp_message, args.batch_size, 1000 * test_time
|
||||
/ test_num))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
assert os.path.exists(
|
||||
args.model_file), "The path of 'model_file' does not exist: {}".format(
|
||||
args.model_file)
|
||||
assert os.path.exists(
|
||||
args.params_file
|
||||
), "The path of 'params_file' does not exist: {}".format(args.params_file)
|
||||
|
||||
predictor = Predictor(args)
|
||||
if not args.enable_benchmark:
|
||||
predictor.normal_predict()
|
||||
else:
|
||||
assert args.model is not None
|
||||
predictor.benchmark_predict()
|
|
@ -1,274 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import base64
|
||||
import shutil
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from paddle.inference import Config
|
||||
from paddle.inference import create_predictor
|
||||
|
||||
|
||||
def parse_args():
|
||||
def str2bool(v):
|
||||
return v.lower() in ("true", "t", "1")
|
||||
|
||||
# general params
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-i", "--image_file", type=str)
|
||||
parser.add_argument("--use_gpu", type=str2bool, default=True)
|
||||
parser.add_argument("--multilabel", type=str2bool, default=False)
|
||||
|
||||
# params for preprocess
|
||||
parser.add_argument("--resize_short", type=int, default=256)
|
||||
parser.add_argument("--resize", type=int, default=224)
|
||||
parser.add_argument("--normalize", type=str2bool, default=True)
|
||||
|
||||
# params for predict
|
||||
parser.add_argument("--model_file", type=str)
|
||||
parser.add_argument("--params_file", type=str)
|
||||
parser.add_argument("-b", "--batch_size", type=int, default=1)
|
||||
parser.add_argument("--use_fp16", type=str2bool, default=False)
|
||||
parser.add_argument("--ir_optim", type=str2bool, default=True)
|
||||
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
|
||||
parser.add_argument("--gpu_mem", type=int, default=8000)
|
||||
parser.add_argument("--enable_profile", type=str2bool, default=False)
|
||||
parser.add_argument("--enable_benchmark", type=str2bool, default=False)
|
||||
parser.add_argument("--top_k", type=int, default=1)
|
||||
parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
|
||||
parser.add_argument("--cpu_num_threads", type=int, default=10)
|
||||
parser.add_argument("--hubserving", type=str2bool, default=False)
|
||||
|
||||
# params for infer
|
||||
parser.add_argument("--model", type=str)
|
||||
parser.add_argument("--pretrained_model", type=str)
|
||||
parser.add_argument("--class_num", type=int, default=1000)
|
||||
parser.add_argument(
|
||||
"--load_static_weights",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help='Whether to load the pretrained weights saved in static mode')
|
||||
|
||||
# parameters for pre-label the images
|
||||
parser.add_argument(
|
||||
"--pre_label_image",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to pre-label the images using the loaded weights")
|
||||
parser.add_argument("--pre_label_out_idr", type=str, default=None)
|
||||
|
||||
# parameters for test hubserving
|
||||
parser.add_argument("--server_url", type=str)
|
||||
|
||||
# enable_calc_metric, when set as true, topk acc will be calculated
|
||||
parser.add_argument("--enable_calc_topk", type=str2bool, default=False)
|
||||
# groudtruth label path
|
||||
# data format for each line: $image_name $class_id
|
||||
parser.add_argument("--gt_label_path", type=str, default=None)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def create_paddle_predictor(args):
|
||||
config = Config(args.model_file, args.params_file)
|
||||
|
||||
if args.use_gpu:
|
||||
config.enable_use_gpu(args.gpu_mem, 0)
|
||||
else:
|
||||
config.disable_gpu()
|
||||
if args.enable_mkldnn:
|
||||
# cache 10 different shapes for mkldnn to avoid memory leak
|
||||
config.set_mkldnn_cache_capacity(10)
|
||||
config.enable_mkldnn()
|
||||
config.set_cpu_math_library_num_threads(args.cpu_num_threads)
|
||||
|
||||
if args.enable_profile:
|
||||
config.enable_profile()
|
||||
config.disable_glog_info()
|
||||
config.switch_ir_optim(args.ir_optim) # default true
|
||||
if args.use_tensorrt:
|
||||
config.enable_tensorrt_engine(
|
||||
precision_mode=Config.Precision.Half
|
||||
if args.use_fp16 else Config.Precision.Float32,
|
||||
max_batch_size=args.batch_size)
|
||||
|
||||
config.enable_memory_optim()
|
||||
# use zero copy
|
||||
config.switch_use_feed_fetch_ops(False)
|
||||
predictor = create_predictor(config)
|
||||
|
||||
return predictor
|
||||
|
||||
|
||||
def preprocess(img, args):
|
||||
resize_op = ResizeImage(resize_short=args.resize_short)
|
||||
img = resize_op(img)
|
||||
crop_op = CropImage(size=(args.resize, args.resize))
|
||||
img = crop_op(img)
|
||||
if args.normalize:
|
||||
img_mean = [0.485, 0.456, 0.406]
|
||||
img_std = [0.229, 0.224, 0.225]
|
||||
img_scale = 1.0 / 255.0
|
||||
normalize_op = NormalizeImage(
|
||||
scale=img_scale, mean=img_mean, std=img_std)
|
||||
img = normalize_op(img)
|
||||
tensor_op = ToTensor()
|
||||
img = tensor_op(img)
|
||||
return img
|
||||
|
||||
|
||||
def postprocess(batch_outputs, topk=5, multilabel=False):
|
||||
batch_results = []
|
||||
for probs in batch_outputs:
|
||||
if multilabel:
|
||||
index = np.where(probs >= 0.5)[0].astype('int32')
|
||||
else:
|
||||
index = probs.argsort(axis=0)[-topk:][::-1].astype("int32")
|
||||
clas_id_list = []
|
||||
score_list = []
|
||||
for i in index:
|
||||
clas_id_list.append(i.item())
|
||||
score_list.append(probs[i].item())
|
||||
batch_results.append({"clas_ids": clas_id_list, "scores": score_list})
|
||||
return batch_results
|
||||
|
||||
|
||||
def get_image_list(img_file):
|
||||
imgs_lists = []
|
||||
if img_file is None or not os.path.exists(img_file):
|
||||
raise Exception("not found any img file in {}".format(img_file))
|
||||
|
||||
img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp']
|
||||
if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end:
|
||||
imgs_lists.append(img_file)
|
||||
elif os.path.isdir(img_file):
|
||||
for single_file in os.listdir(img_file):
|
||||
if single_file.split('.')[-1] in img_end:
|
||||
imgs_lists.append(os.path.join(img_file, single_file))
|
||||
if len(imgs_lists) == 0:
|
||||
raise Exception("not found any img file in {}".format(img_file))
|
||||
return imgs_lists
|
||||
|
||||
|
||||
def get_image_list_from_label_file(image_path, label_file_path):
|
||||
imgs_lists = []
|
||||
gt_labels = []
|
||||
with open(label_file_path, "r") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
image_name, label = line.strip("\n").split()
|
||||
label = int(label)
|
||||
imgs_lists.append(os.path.join(image_path, image_name))
|
||||
gt_labels.append(int(label))
|
||||
return imgs_lists, gt_labels
|
||||
|
||||
|
||||
def calc_topk_acc(info_map):
|
||||
'''
|
||||
calc_topk_acc
|
||||
input:
|
||||
info_map(dict): keys are prediction and gt_label
|
||||
output:
|
||||
topk_acc(list): top-k accuracy list
|
||||
'''
|
||||
gt_label = np.array(info_map["gt_label"])
|
||||
prediction = np.array(info_map["prediction"])
|
||||
|
||||
gt_label = np.reshape(gt_label, (-1, 1)).repeat(
|
||||
prediction.shape[1], axis=1)
|
||||
correct = np.equal(prediction, gt_label)
|
||||
topk_acc = []
|
||||
for idx in range(prediction.shape[1]):
|
||||
if idx > 0:
|
||||
correct[:, idx] = np.logical_or(correct[:, idx],
|
||||
correct[:, idx - 1])
|
||||
topk_acc.append(1.0 * np.sum(correct[:, idx]) / correct.shape[0])
|
||||
return topk_acc
|
||||
|
||||
|
||||
def save_prelabel_results(class_id, input_file_path, output_dir):
|
||||
output_dir = os.path.join(output_dir, str(class_id))
|
||||
if not os.path.isdir(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
shutil.copy(input_file_path, output_dir)
|
||||
|
||||
|
||||
class ResizeImage(object):
|
||||
def __init__(self, resize_short=None):
|
||||
self.resize_short = resize_short
|
||||
|
||||
def __call__(self, img):
|
||||
img_h, img_w = img.shape[:2]
|
||||
percent = float(self.resize_short) / min(img_w, img_h)
|
||||
w = int(round(img_w * percent))
|
||||
h = int(round(img_h * percent))
|
||||
return cv2.resize(img, (w, h))
|
||||
|
||||
|
||||
class CropImage(object):
|
||||
def __init__(self, size):
|
||||
if type(size) is int:
|
||||
self.size = (size, size)
|
||||
else:
|
||||
self.size = size
|
||||
|
||||
def __call__(self, img):
|
||||
w, h = self.size
|
||||
img_h, img_w = img.shape[:2]
|
||||
w_start = (img_w - w) // 2
|
||||
h_start = (img_h - h) // 2
|
||||
|
||||
w_end = w_start + w
|
||||
h_end = h_start + h
|
||||
return img[h_start:h_end, w_start:w_end, :]
|
||||
|
||||
|
||||
class NormalizeImage(object):
|
||||
def __init__(self, scale=None, mean=None, std=None):
|
||||
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
|
||||
mean = mean if mean is not None else [0.485, 0.456, 0.406]
|
||||
std = std if std is not None else [0.229, 0.224, 0.225]
|
||||
|
||||
shape = (1, 1, 3)
|
||||
self.mean = np.array(mean).reshape(shape).astype('float32')
|
||||
self.std = np.array(std).reshape(shape).astype('float32')
|
||||
|
||||
def __call__(self, img):
|
||||
return (img.astype('float32') * self.scale - self.mean) / self.std
|
||||
|
||||
|
||||
class ToTensor(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, img):
|
||||
img = img.transpose((2, 0, 1))
|
||||
return img
|
||||
|
||||
|
||||
def b64_to_np(b64str, revert_params):
|
||||
shape = revert_params["shape"]
|
||||
dtype = revert_params["dtype"]
|
||||
dtype = getattr(np, dtype) if isinstance(str, type(dtype)) else dtype
|
||||
data = base64.b64decode(b64str.encode('utf8'))
|
||||
data = np.fromstring(data, dtype).reshape(shape)
|
||||
return data
|
||||
|
||||
|
||||
def np_to_b64(images):
|
||||
img_str = base64.b64encode(images).decode('utf8')
|
||||
return img_str, images.shape
|