# Copyright (c) Alibaba, Inc. and its affiliates.
# isort:skip_file
import argparse
import copy
import functools
import glob
import inspect
import logging
import os
import threading
import traceback

import torch

try:
    import easy_predict
except ModuleNotFoundError:
    print('please install easy_predict first using the following instruction')
    print(
        'pip install https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/easy_predict-0.4.2-py2.py3-none-any.whl'
    )
    exit()

from easy_predict import (Base64DecodeProcess, DataFields,
                          DefaultResultFormatProcess, DownloadProcess,
                          FileReadProcess, FileWriteProcess, Process,
                          ProcessExecutor, ResultGatherProcess,
                          TableReadProcess, TableWriteProcess)
from mmcv.runner import init_dist

from easycv.utils.dist_utils import get_dist_info
from easycv.utils.logger import get_root_logger
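
# This script runs batch prediction with an EasyCV predictor over either a
# filelist of image paths (oss or local) or an ODPS table, wiring the steps
# together as multi-threaded pipeline processes from easy_predict.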


def define_args():
    parser = argparse.ArgumentParser('easycv prediction')
    parser.add_argument(
        '--model_type',
        default='',
        help='model type, classifier/detector/segmentor/yolox')
    parser.add_argument('--model_path', default='', help='path to model')
    parser.add_argument(
        '--model_config',
        default='',
        help='model config str, predictor v1 param')

    # oss input output
    parser.add_argument(
        '--input_file',
        default='',
        help='filelist of images, each line is an oss path or a local path')
    parser.add_argument(
        '--output_file',
        default='',
        help='oss file or local file to save predict info')
    parser.add_argument(
        '--output_dir',
        default='',
        help='output directory to save image and video results')
    parser.add_argument(
        '--oss_prefix',
        default='',
        help='oss_prefix will be replaced with local_prefix in input_file')
    parser.add_argument(
        '--local_prefix',
        default='',
        help='oss_prefix will be replaced with local_prefix in input_file')

    # table input output
    parser.add_argument('--input_table', default='', help='input table name')
    parser.add_argument('--output_table', default='', help='output table name')
    parser.add_argument('--image_col', default='', help='input image column')
    parser.add_argument(
        '--reserved_columns',
        default='',
        help=
        'columns from input table to be saved to output table, comma separated'
    )
    parser.add_argument(
        '--result_column',
        default='',
        help='result columns to be saved to output table, comma separated')
    parser.add_argument(
        '--odps_config',
        default='./odps.config',
        help='path to your odps config file')
    parser.add_argument(
        '--image_type', default='url', help='image data type, url or base64')

    # common args
    parser.add_argument(
        '--queue_size',
        type=int,
        default=1024,
        help='length of the queue used by each process')
    parser.add_argument(
        '--predict_thread_num',
        type=int,
        default=1,
        help='number of threads used for prediction')
    parser.add_argument(
        '--preprocess_thread_num',
        type=int,
        default=1,
        help='number of threads used for preprocessing and downloading')
    parser.add_argument(
        '--batch_size',
        type=int,
        default=1,
        help='batch size used for prediction')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--launcher',
        type=str,
        choices=[None, 'pytorch'],
        help='if assigned pytorch, should be used in a gpu environment')
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args
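

# Example invocation for the filelist mode (illustrative only: the script
# location, paths and flag values below are placeholders):
#   python tools/predict.py \
#       --model_type YoloXPredictor \
#       --model_path ./export_dir \
#       --input_file ./filelist.txt \
#       --output_file ./result.txt \
#       --batch_size 8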


class PredictorProcess(Process):

    def __init__(self,
                 predict_fn,
                 batch_size,
                 thread_num,
                 local_rank=0,
                 input_queue=None,
                 output_queue=None):
        job_name = 'Predictor'
        if torch.cuda.is_available():
            thread_init_fn = functools.partial(torch.cuda.set_device,
                                               local_rank)
        else:
            thread_init_fn = None
        super(PredictorProcess, self).__init__(
            job_name,
            thread_num,
            input_queue,
            output_queue,
            batch_size=batch_size,
            thread_init_fn=thread_init_fn)
        self.predict_fn = predict_fn
        self.data_lock = threading.Lock()
        self.all_data_failed = True
        self.input_empty = True
        self.local_rank = local_rank

    def process(self, input_batch):
        """Read a batch of images from the input queue and run prediction.

        Args:
            input_batch: a batch of input data

        Returns:
            None; the batch is unstacked and each input record, together
            with its prediction result, is pushed into the output queue.
        """
        valid_input = []
        valid_indices = []
        valid_frame_ids = []
        if self.batch_size == 1:
            input_batch = [input_batch]
        if self.input_empty and len(input_batch) > 0:
            self.data_lock.acquire()
            self.input_empty = False
            self.data_lock.release()

        output_data_list = input_batch
        for out in output_data_list:
            out[DataFields.prediction_result] = None

        for idx, input_data in enumerate(input_batch):
            if DataFields.image in input_data \
                    and input_data[DataFields.image] is not None:
                valid_input.append(input_data[DataFields.image])
                valid_indices.append(idx)

        if len(valid_input) > 0:
            try:
                # flatten video_clip to images, use the image predictor,
                # then regroup the results into one list per video_clip
                output_list = self.predict_fn(valid_input)

                if len(output_list) > 0:
                    assert isinstance(output_list[0], dict), \
                        'the element in predictor output must be a dict'

                if self.all_data_failed:
                    self.data_lock.acquire()
                    self.all_data_failed = False
                    self.data_lock.release()

            except Exception:
                logging.error(traceback.format_exc())
                output_list = [None for i in range(len(valid_input))]

            for idx, result_dict in zip(valid_indices, output_list):
                output_data = output_data_list[idx]
                output_data[DataFields.prediction_result] = result_dict
                if result_dict is None:
                    output_data[DataFields.error_msg] = 'prediction error'

                output_data_list[idx] = output_data

        for output_data in output_data_list:
            self.put(output_data)

    def destroy(self):
        if not self.input_empty and self.all_data_failed:
            raise RuntimeError(
                'failed to predict all the input data, please see the '
                'exception thrown above in the log')
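

# Expected layout of model_dir for YoloXPredictor, as implied by the globs
# in the function below (file names are illustrative):
#   model_dir/epoch_300.jit              <- torchscript model, preferred
#   model_dir/epoch_300.jit.config.json  <- config required for the jit model
#   model_dir/epoch_300.pt               <- raw exported model, fallback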
def create_yolox_predictor_kwargs(model_dir):
    jit_models = glob.glob('%s/**/*.jit' % model_dir, recursive=True)
    raw_models = glob.glob('%s/**/*.pt' % model_dir, recursive=True)
    if len(jit_models) > 0:
        assert len(
            jit_models
        ) == 1, f'more than one jit script model file is found in {model_dir}'
        config_path = jit_models[0] + '.config.json'
        if not os.path.exists(config_path):
            raise ValueError(
                f'config json file {config_path} for inference with the jit '
                'script model is not found')
        return {'model_path': jit_models[0], 'config_file': config_path}
    else:
        assert len(raw_models) > 0, f'export model not found in {model_dir}'
        assert len(raw_models
                   ) == 1, f'more than one model file is found in {model_dir}'
        return {'model_path': raw_models[0]}
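

# For other predictor types the globs below expect model_dir to contain
# exactly one *.pt* checkpoint and at most one *.py config file, e.g.
# (names illustrative):
#   model_dir/model.pth
#   model_dir/config.py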
def create_default_predictor_kwargs(model_dir):
    model_path = glob.glob('%s/**/*.pt*' % model_dir, recursive=True)
    assert len(model_path) > 0, f'model not found in {model_dir}'
    assert len(
        model_path) == 1, f'more than one model file is found: {model_path}'
    model_path = model_path[0]
    logging.info(f'model found: {model_path}')

    config_path = glob.glob('%s/**/*.py' % model_dir, recursive=True)
    if len(config_path) == 0:
        config_path = None
    else:
        assert len(config_path
                   ) == 1, f'more than one config file is found: {config_path}'
        config_path = config_path[0]
        logging.info(f'config found: {config_path}')
    if config_path:
        return {'model_path': model_path, 'config_file': config_path}
    else:
        return {'model_path': model_path}


def create_predictor_kwargs(model_type, model_dir):
    if model_type == 'YoloXPredictor':
        return create_yolox_predictor_kwargs(model_dir)
    else:
        return create_default_predictor_kwargs(model_dir)
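

# init_predictor below builds the predictor from a config dict such as
# (values illustrative):
#   dict(type='YoloXPredictor', model_path='...', config_file='...')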
def init_predictor(args):
    model_type = args.model_type
    model_path = args.model_path
    batch_size = args.batch_size
    from easycv.predictors.builder import build_predictor

    ori_model_path = model_path
    if os.path.isdir(ori_model_path):
        predictor_kwargs = create_predictor_kwargs(model_type, ori_model_path)
    else:
        predictor_kwargs = {'model_path': ori_model_path}

    predictor_cfg = dict(type=model_type, **predictor_kwargs)
    if args.model_config != '':
        predictor_cfg['model_config'] = args.model_config
    predictor = build_predictor(predictor_cfg)
    return predictor
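

# Example of the prefix rewrite performed below (prefixes illustrative):
#   bucket_prefix='oss://my-bucket/data', local_prefix='/mnt/data'
#   'oss://my-bucket/data/images/1.jpg' -> '/mnt/data/images/1.jpg'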
def replace_oss_with_local_path(ori_file, dst_file, bucket_prefix,
                                local_prefix):
    bucket_prefix = bucket_prefix.rstrip('/') + '/'
    local_prefix = local_prefix.rstrip('/') + '/'
    with open(ori_file, 'r') as infile:
        with open(dst_file, 'w') as ofile:
            for l in infile:
                if l.startswith('oss://'):
                    l = l.replace(bucket_prefix, local_prefix)
                ofile.write(l)
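

# File mode pipeline, in stage order:
#   FileReadProcess -> DownloadProcess -> PredictorProcess
#   -> DefaultResultFormatProcess -> ResultGatherProcess -> FileWriteProcess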
def build_and_run_file_io(args):
    # distribute info
    rank, world_size = get_dist_info()
    worker_id = rank

    input_oss_file_new_host = args.input_file + '.tmp%d' % worker_id
    replace_oss_with_local_path(args.input_file, input_oss_file_new_host,
                                args.oss_prefix, args.local_prefix)
    args.input_file = input_oss_file_new_host
    num_worker = world_size
    print(f'worker num {num_worker}')
    print(f'worker_id {worker_id}')
    batch_size = args.batch_size
    print(f'Local rank {args.local_rank}')
    if torch.cuda.is_available():
        torch.cuda.set_device(args.local_rank)
    predictor = init_predictor(args)
    predict_fn = predictor.__call__ if hasattr(
        predictor, '__call__') else predictor.predict
    # create proc executor
    proc_exec = ProcessExecutor(args.queue_size)

    # create file read process to read image paths from the filelist
    proc_exec.add(
        FileReadProcess(
            args.input_file,
            slice_id=worker_id,
            slice_count=num_worker,
            output_queue=proc_exec.get_output_queue()))

    # download and decode image data; downloading is a preprocessing step,
    # so it uses preprocess_thread_num
    proc_exec.add(
        DownloadProcess(
            thread_num=args.preprocess_thread_num,
            input_queue=proc_exec.get_input_queue(),
            output_queue=proc_exec.get_output_queue(),
            is_video_url=False))

    # run prediction on the image data
    proc_exec.add(
        PredictorProcess(
            input_queue=proc_exec.get_input_queue(),
            output_queue=proc_exec.get_output_queue(),
            predict_fn=predict_fn,
            batch_size=batch_size,
            local_rank=args.local_rank,
            thread_num=args.predict_thread_num))

    proc_exec.add(
        DefaultResultFormatProcess(
            input_queue=proc_exec.get_input_queue(),
            output_queue=proc_exec.get_output_queue()))

    # gather results into dicts keyed by output type
    proc_exec.add(
        ResultGatherProcess(
            output_type_dict={},
            input_queue=proc_exec.get_input_queue(),
            output_queue=proc_exec.get_output_queue()))

    # write results
    proc_exec.add(
        FileWriteProcess(
            output_file=args.output_file,
            output_dir=args.output_dir,
            slice_id=worker_id,
            slice_count=num_worker,
            input_queue=proc_exec.get_input_queue()))

    proc_exec.run()
    proc_exec.wait()
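

# Table mode pipeline, in stage order:
#   TableReadProcess -> (Base64DecodeProcess | DownloadProcess)
#   -> PredictorProcess -> DefaultResultFormatProcess -> TableWriteProcess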
def build_and_run_table_io(args):
    os.environ['ODPS_CONFIG_FILE_PATH'] = args.odps_config

    rank, world_size = get_dist_info()
    worker_id = rank
    num_worker = world_size
    print(f'worker num {num_worker}')
    print(f'worker_id {worker_id}')

    batch_size = args.batch_size
    if torch.cuda.is_available():
        torch.cuda.set_device(args.local_rank)
    predictor = init_predictor(args)
    predict_fn = predictor.__call__ if hasattr(
        predictor, '__call__') else predictor.predict
    # batch size should be less than the total number of rows in the input
    # table
    table_read_batch_size = 1
    table_read_thread_num = 4

    # create proc executor
    proc_exec = ProcessExecutor(args.queue_size)

    # create table read process to read image data from the input table
    selected_cols = list(
        set(args.image_col.split(',') + args.reserved_columns.split(',')))
    if args.image_col not in selected_cols:
        selected_cols.append(args.image_col)
    image_col_idx = selected_cols.index(args.image_col)
    proc_exec.add(
        TableReadProcess(
            args.input_table,
            selected_cols=selected_cols,
            slice_id=worker_id,
            slice_count=num_worker,
            output_queue=proc_exec.get_output_queue(),
            image_col_idx=image_col_idx,
            image_type=args.image_type,
            batch_size=table_read_batch_size,
            num_threads=table_read_thread_num))

    if args.image_type == 'base64':
        base64_thread_num = args.preprocess_thread_num
        proc_exec.add(
            Base64DecodeProcess(
                thread_num=base64_thread_num,
                input_queue=proc_exec.get_input_queue(),
                output_queue=proc_exec.get_output_queue()))
    elif args.image_type == 'url':
        download_thread_num = args.preprocess_thread_num
        proc_exec.add(
            DownloadProcess(
                thread_num=download_thread_num,
                input_queue=proc_exec.get_input_queue(),
                output_queue=proc_exec.get_output_queue(),
                use_pil_decode=False))

    # run prediction on the image data
    proc_exec.add(
        PredictorProcess(
            input_queue=proc_exec.get_input_queue(),
            output_queue=proc_exec.get_output_queue(),
            predict_fn=predict_fn,
            batch_size=batch_size,
            local_rank=args.local_rank,
            thread_num=args.predict_thread_num))

    proc_exec.add(
        DefaultResultFormatProcess(
            input_queue=proc_exec.get_input_queue(),
            output_queue=proc_exec.get_output_queue(),
            reserved_col_names=args.reserved_columns.split(','),
            output_col_names=args.result_column.split(',')))

    # write results
    output_cols = args.reserved_columns.split(',') + args.result_column.split(
        ',')
    proc_exec.add(
        TableWriteProcess(
            args.output_table,
            output_col_names=output_cols,
            slice_id=worker_id,
            input_queue=proc_exec.get_input_queue()))

    proc_exec.run()
    proc_exec.wait()
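

# The file passed via --odps_config is assumed to be a standard ODPS client
# config, e.g. (keys illustrative, values redacted):
#   project_name=my_project
#   access_id=xxx
#   access_key=xxx
#   end_point=http://service.odps.aliyun.com/api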


def check_args(args, arg_name, default_value=''):
    assert getattr(args, arg_name) != default_value, \
        f'{arg_name} should not be empty'


def patch_logging():
    # after get_root_logger, plain logging calls will not take effect
    # because it changes the level of all other handlers, so reset the
    # root StreamHandlers back to logging.INFO
    logger = get_root_logger()
    for handler in logger.root.handlers:
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.INFO)
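

# Example invocation for the table mode (illustrative only; table names and
# column names are placeholders):
#   python tools/predict.py \
#       --model_type YoloXPredictor \
#       --model_path ./export_dir \
#       --input_table my_input_table \
#       --output_table my_output_table \
#       --image_col image_url \
#       --reserved_columns id \
#       --result_column prediction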
if __name__ == '__main__':
    args = define_args()
    patch_logging()
    if args.launcher:
        init_dist(args.launcher, backend='nccl')
    if args.input_file != '':
        check_args(args, 'output_file')
        build_and_run_file_io(args)
    else:
        check_args(args, 'input_table')
        check_args(args, 'output_table')
        check_args(args, 'image_col')
        check_args(args, 'reserved_columns')
        check_args(args, 'result_column')
        build_and_run_table_io(args)