EasyCV/tools/predict.py

487 lines
16 KiB
Python
Raw Normal View History

# Copyright (c) Alibaba, Inc. and its affiliates.
# isort:skip_file
import argparse
import copy
import functools
import glob
import inspect
import logging
import os
import threading
import traceback
import torch
try:
import easy_predict
except ModuleNotFoundError:
print('please install easy_predict first using following instruction')
print(
'pip install https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/easy_predict-0.4.2-py2.py3-none-any.whl'
)
exit()
from easy_predict import (Base64DecodeProcess, DataFields,
DefaultResultFormatProcess, DownloadProcess,
FileReadProcess, FileWriteProcess, Process,
ProcessExecutor, ResultGatherProcess,
TableReadProcess, TableWriteProcess)
from mmcv.runner import init_dist
from easycv.utils.dist_utils import get_dist_info
from easycv.utils.logger import get_root_logger
def define_args():
    """Parse the command line arguments of the prediction tool.

    Besides parsing, this keeps ``args.local_rank`` and the ``LOCAL_RANK``
    environment variable in agreement.  When ``LOCAL_RANK`` is already set
    (``torchrun`` and ``torch.distributed.launch --use_env`` export it without
    passing ``--local_rank``), the environment value wins.  Otherwise the
    value of ``--local_rank`` / ``--local-rank`` is exported to the
    environment.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser('easycv prediction')
    parser.add_argument(
        '--model_type',
        default='',
        help='model type, classifier/detector/segmentor/yolox')
    parser.add_argument('--model_path', default='', help='path to model')
    parser.add_argument(
        '--model_config',
        default='',
        help='model config str, predictor v1 param')

    # oss input output
    parser.add_argument(
        '--input_file',
        default='',
        help='filelist for images, eash line is a oss path or a local path')
    parser.add_argument(
        '--output_file',
        default='',
        help='oss file or local file to save predict info')
    parser.add_argument(
        '--output_dir',
        default='',
        help='output_directory to save image and video results')
    parser.add_argument(
        '--oss_prefix',
        default='',
        help='oss_prefix will be replaced with local_prefix in input_file')
    parser.add_argument(
        '--local_prefix',
        default='',
        help='oss_prefix will be replaced with local_prefix in input_file')

    # table input output
    parser.add_argument('--input_table', default='', help='input table name')
    parser.add_argument('--output_table', default='', help='output table name')
    parser.add_argument('--image_col', default='', help='input image column')
    parser.add_argument(
        '--reserved_columns',
        default='',
        help=
        'columns from input table to be saved to output table, comma seperated'
    )
    parser.add_argument(
        '--result_column',
        default='',
        help='result columns to be saved to output table, comma seperated')
    parser.add_argument(
        '--odps_config',
        default='./odps.config',
        help='path to your odps config file')
    parser.add_argument(
        '--image_type', default='url', help='image data type, url or base64')

    # common args
    parser.add_argument(
        '--queue_size',
        type=int,
        default=1024,
        help='length of queues used for each process')
    parser.add_argument(
        '--predict_thread_num',
        type=int,
        default=1,
        help='number of threads used for prediction')
    parser.add_argument(
        '--preprocess_thread_num',
        type=int,
        default=1,
        help='number of threads used for preprocessing and downloading')
    parser.add_argument(
        '--batch_size',
        type=int,
        default=1,
        help='batch size used for prediction')
    # PyTorch >= 2.0 `torch.distributed.launch` passes `--local-rank`;
    # older launchers pass `--local_rank`.  dest stays `local_rank`.
    parser.add_argument(
        '--local_rank',
        '--local-rank',
        type=int,
        default=0,
        help='local process rank, normally filled in by the launcher')
    parser.add_argument(
        '--launcher',
        type=str,
        choices=[None, 'pytorch'],
        help='if assigned pytorch, should be used in gpu environment')
    args = parser.parse_args()

    if 'LOCAL_RANK' in os.environ:
        # Launchers that only export LOCAL_RANK would otherwise leave
        # args.local_rank at its default 0 and every worker would bind to
        # GPU 0 through torch.cuda.set_device(args.local_rank).
        args.local_rank = int(os.environ['LOCAL_RANK'])
    else:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args
class PredictorProcess(Process):
    """easy_predict pipeline stage that runs ``predict_fn`` on images.

    Every incoming item (a mapping) gets ``DataFields.prediction_result``
    set. It holds the result dict when the item carried a non-None
    ``DataFields.image`` and prediction succeeded, and None otherwise.
    Items that carried an image but could not be predicted also get
    ``DataFields.error_msg = 'prediction error'``.  All items are forwarded
    to the output queue in their input order.
    """

    def __init__(self,
                 predict_fn,
                 batch_size,
                 thread_num,
                 local_rank=0,
                 input_queue=None,
                 output_queue=None):
        """
        Args:
            predict_fn: callable taking a list of images and returning a list
                holding one result dict per image.
            batch_size: batch size handed to the base ``Process``; when it is
                1, ``process`` receives a single item instead of a list.
            thread_num: thread count handed to the base ``Process``.
            local_rank: CUDA device each worker thread binds to (via
                ``torch.cuda.set_device``) when CUDA is available.
            input_queue: queue the items are read from.
            output_queue: queue the processed items are pushed to.
        """
        job_name = 'Predictor'
        if torch.cuda.is_available():
            # run once per worker thread so each thread uses this rank's GPU
            thread_init_fn = functools.partial(torch.cuda.set_device,
                                               local_rank)
        else:
            thread_init_fn = None
        super(PredictorProcess, self).__init__(
            job_name,
            thread_num,
            input_queue,
            output_queue,
            batch_size=batch_size,
            thread_init_fn=thread_init_fn)
        self.predict_fn = predict_fn
        # guards the two status flags below
        self.data_lock = threading.Lock()
        # stays True until one predict_fn call succeeds
        self.all_data_failed = True
        # stays True until one non-empty batch is received
        self.input_empty = True
        self.local_rank = local_rank

    def process(self, input_batch):
        """Predict one batch and push every item to the output queue.

        Args:
            input_batch: list of items, or a single item when
                ``self.batch_size == 1``.
        """
        if self.batch_size == 1:
            input_batch = [input_batch]

        if self.input_empty and len(input_batch) > 0:
            with self.data_lock:
                self.input_empty = False

        for item in input_batch:
            item[DataFields.prediction_result] = None

        valid_indices = [
            idx for idx, item in enumerate(input_batch)
            if DataFields.image in item and item[DataFields.image] is not None
        ]
        valid_input = [
            input_batch[idx][DataFields.image] for idx in valid_indices
        ]

        if valid_input:
            try:
                output_list = self.predict_fn(valid_input)
                # zip() below would silently drop unmatched items, leaving
                # them with neither a result nor an error message
                if len(output_list) != len(valid_input):
                    raise ValueError(
                        f'predictor returned {len(output_list)} results '
                        f'for {len(valid_input)} inputs')
                if not isinstance(output_list[0], dict):
                    raise TypeError(
                        'the element in predictor output must be a dict')
                if self.all_data_failed:
                    with self.data_lock:
                        self.all_data_failed = False
            except Exception:
                # best effort: log, mark this batch as failed, keep the
                # pipeline running; destroy() reports total failure
                logging.error(traceback.format_exc())
                output_list = [None] * len(valid_input)

            for idx, result_dict in zip(valid_indices, output_list):
                item = input_batch[idx]
                item[DataFields.prediction_result] = result_dict
                if result_dict is None:
                    item[DataFields.error_msg] = 'prediction error'

        for item in input_batch:
            self.put(item)

    def destroy(self):
        """Fail the run if input arrived but no prediction ever succeeded.

        Raises:
            RuntimeError: at least one non-empty batch was received and every
                prediction attempt failed.
        """
        if not self.input_empty and self.all_data_failed:
            raise RuntimeError(
                'failed to predict all the input data, please see exception throwed above in the log'
            )
def create_yolox_predictor_kwargs(model_dir):
    """Locate the exported YOLOX model below ``model_dir``.

    A TorchScript export (``*.jit``, searched recursively) takes precedence
    and must sit next to ``<model>.jit.config.json``.  Otherwise exactly one
    ``*.pt`` file is expected.

    Args:
        model_dir: directory searched recursively.

    Returns:
        dict: ``{'model_path': ..., 'config_file': ...}`` for a jit export,
        ``{'model_path': ...}`` for a ``.pt`` export.

    Raises:
        ValueError: no model found, more than one candidate found, or a jit
            model without its config json.
    """
    # escape so glob metacharacters in the directory name ("[", "*", "?")
    # are matched literally instead of as patterns
    search_root = glob.escape(model_dir)

    jit_models = glob.glob(
        os.path.join(search_root, '**', '*.jit'), recursive=True)
    if jit_models:
        if len(jit_models) != 1:
            raise ValueError(
                f'more than one jit script model files is found in {model_dir}'
            )
        model_path = jit_models[0]
        config_path = model_path + '.config.json'
        if not os.path.exists(config_path):
            raise ValueError(
                f'Not find config json file {config_path} for inference with jit script model'
            )
        return {'model_path': model_path, 'config_file': config_path}

    raw_models = glob.glob(
        os.path.join(search_root, '**', '*.pt'), recursive=True)
    if not raw_models:
        raise ValueError(f'export model not found in {model_dir}')
    if len(raw_models) != 1:
        raise ValueError(f'more than one model files is found in {model_dir}')
    return {'model_path': raw_models[0]}
def create_default_predictor_kwargs(model_dir):
    """Locate the model file (and optional python config) below ``model_dir``.

    Exactly one file matching ``*.pt*`` (e.g. ``.pt``/``.pth``) must exist,
    searched recursively.  At most one ``*.py`` file may exist; if present it
    is returned as the config.

    Args:
        model_dir: directory searched recursively.

    Returns:
        dict: ``{'model_path': ...}`` plus ``'config_file'`` when a ``.py``
        config was found.

    Raises:
        ValueError: no model file, several model files, or several config
            files were found.
    """
    # escape so glob metacharacters in the directory name ("[", "*", "?")
    # are matched literally instead of as patterns
    search_root = glob.escape(model_dir)

    model_paths = glob.glob(
        os.path.join(search_root, '**', '*.pt*'), recursive=True)
    if not model_paths:
        raise ValueError(f'model not found in {model_dir}')
    if len(model_paths) != 1:
        raise ValueError(f'more than one model file is found {model_paths}')
    model_path = model_paths[0]
    logging.info(f'model found: {model_path}')

    config_paths = glob.glob(
        os.path.join(search_root, '**', '*.py'), recursive=True)
    if not config_paths:
        return {'model_path': model_path}
    if len(config_paths) != 1:
        raise ValueError(
            f'more than one config file is found {config_paths}')
    config_path = config_paths[0]
    logging.info(f'config found: {config_path}')
    return {'model_path': model_path, 'config_file': config_path}
def create_predictor_kwargs(model_type, model_dir):
    """Choose the model-file lookup for ``model_type`` and run it on ``model_dir``.

    ``'YoloXPredictor'`` uses the YOLOX lookup (jit export first, then
    ``*.pt``); every other predictor type uses the default lookup.

    Returns:
        dict of predictor keyword arguments (``model_path`` and, when found,
        ``config_file``).
    """
    is_yolox = model_type == 'YoloXPredictor'
    find_kwargs = (create_yolox_predictor_kwargs
                   if is_yolox else create_default_predictor_kwargs)
    return find_kwargs(model_dir)
def init_predictor(args):
    """Build the predictor described by the command line arguments.

    When ``args.model_path`` is a directory, the model (and optional config)
    files are discovered with ``create_predictor_kwargs``.  Otherwise the
    path is passed through unchanged as ``model_path``.

    Args:
        args: parsed arguments; reads ``model_type``, ``model_path`` and
            ``model_config``.

    Returns:
        Whatever ``build_predictor`` returns for
        ``dict(type=args.model_type, ...)``.
    """
    from easycv.predictors.builder import build_predictor

    model_type = args.model_type
    model_path = args.model_path
    if os.path.isdir(model_path):
        predictor_kwargs = create_predictor_kwargs(model_type, model_path)
    else:
        predictor_kwargs = {'model_path': model_path}

    predictor_cfg = dict(type=model_type, **predictor_kwargs)
    # an empty string means "not given"
    if args.model_config != '':
        predictor_cfg['model_config'] = args.model_config
    return build_predictor(predictor_cfg)
def replace_oss_with_local_path(ori_file, dst_file, bucket_prefix,
                                local_prefix):
    """Copy a file list, mapping OSS paths under ``bucket_prefix`` to local paths.

    Both prefixes are normalized to end with exactly one ``/``.  A line that
    starts with ``oss://`` and with the normalized ``bucket_prefix`` has that
    leading prefix replaced by the normalized ``local_prefix``.  Every other
    line is copied unchanged.

    Args:
        ori_file: input file list, one path per line.
        dst_file: file the rewritten list is written to (overwritten).
        bucket_prefix: OSS prefix to map, e.g. ``oss://bucket/data``.
        local_prefix: local directory substituted for ``bucket_prefix``.
    """
    bucket_prefix = bucket_prefix.rstrip('/') + '/'
    local_prefix = local_prefix.rstrip('/') + '/'
    with open(ori_file, 'r') as infile, open(dst_file, 'w') as ofile:
        for line in infile:
            # Rewrite only a leading match.  str.replace would also rewrite
            # occurrences elsewhere in the line and, for a prefix given
            # without the scheme, corrupt the URL (e.g. 'oss:///mnt/...').
            if line.startswith('oss://') and line.startswith(bucket_prefix):
                line = local_prefix + line[len(bucket_prefix):]
            ofile.write(line)
def build_and_run_file_io(args):
    """Run the prediction pipeline over a local file list.

    Stages: FileReadProcess -> DownloadProcess -> PredictorProcess ->
    DefaultResultFormatProcess -> ResultGatherProcess -> FileWriteProcess.
    Each distributed worker reads and writes slice ``rank`` of
    ``world_size``.

    Args:
        args: parsed arguments (see ``define_args``).  ``args.input_file`` is
            rebound to a per-worker copy of the list in which OSS paths under
            ``args.oss_prefix`` are mapped to ``args.local_prefix``.  That
            scratch copy is removed once the pipeline finishes.
    """
    rank, world_size = get_dist_info()
    worker_id = rank
    num_worker = world_size

    # per-worker copy so concurrent workers never write the same file
    local_input_file = args.input_file + '.tmp%d' % worker_id
    replace_oss_with_local_path(args.input_file, local_input_file,
                                args.oss_prefix, args.local_prefix)
    args.input_file = local_input_file
    try:
        print(f'worker num {num_worker}')
        print(f'worker_id {worker_id}')
        batch_size = args.batch_size
        print(f'Local rank {args.local_rank}')

        if torch.cuda.is_available():
            torch.cuda.set_device(args.local_rank)
        predictor = init_predictor(args)
        predict_fn = predictor.__call__ if hasattr(
            predictor, '__call__') else predictor.predict

        proc_exec = ProcessExecutor(args.queue_size)

        # read this worker's slice of the file list
        proc_exec.add(
            FileReadProcess(
                args.input_file,
                slice_id=worker_id,
                slice_count=num_worker,
                output_queue=proc_exec.get_output_queue()))

        # download and decode image data.  --preprocess_thread_num is the
        # documented knob for download threads (the table pipeline uses it
        # too); --predict_thread_num only sizes the predictor stage.
        proc_exec.add(
            DownloadProcess(
                thread_num=args.preprocess_thread_num,
                input_queue=proc_exec.get_input_queue(),
                output_queue=proc_exec.get_output_queue(),
                is_video_url=False))

        # run the model
        proc_exec.add(
            PredictorProcess(
                input_queue=proc_exec.get_input_queue(),
                output_queue=proc_exec.get_output_queue(),
                predict_fn=predict_fn,
                batch_size=batch_size,
                local_rank=args.local_rank,
                thread_num=args.predict_thread_num))

        proc_exec.add(
            DefaultResultFormatProcess(
                input_queue=proc_exec.get_input_queue(),
                output_queue=proc_exec.get_output_queue()))

        # Gather result to different dict of different type
        proc_exec.add(
            ResultGatherProcess(
                output_type_dict={},
                input_queue=proc_exec.get_input_queue(),
                output_queue=proc_exec.get_output_queue()))

        # Write result
        proc_exec.add(
            FileWriteProcess(
                output_file=args.output_file,
                output_dir=args.output_dir,
                slice_id=worker_id,
                slice_count=num_worker,
                input_queue=proc_exec.get_input_queue()))

        proc_exec.run()
        proc_exec.wait()
    finally:
        # the rewritten list is scratch data; don't leave it next to the
        # user's input file
        try:
            os.remove(local_input_file)
        except OSError:
            logging.warning(
                f'failed to remove temporary file {local_input_file}')
def build_and_run_table_io(args):
    """Run the prediction pipeline over an ODPS table.

    Stages: TableReadProcess -> Base64DecodeProcess (``image_type ==
    'base64'``) or DownloadProcess (``image_type == 'url'``) ->
    PredictorProcess -> DefaultResultFormatProcess -> TableWriteProcess.
    Each distributed worker handles slice ``rank`` of ``world_size``.

    Args:
        args: parsed arguments (see ``define_args``).  ``args.image_col``,
            ``args.reserved_columns`` and ``args.result_column`` are
            comma-separated column names.
    """
    os.environ['ODPS_CONFIG_FILE_PATH'] = args.odps_config
    rank, world_size = get_dist_info()
    worker_id = rank
    num_worker = world_size
    print(f'worker num {num_worker}')
    print(f'worker_id {worker_id}')
    batch_size = args.batch_size

    if torch.cuda.is_available():
        torch.cuda.set_device(args.local_rank)
    predictor = init_predictor(args)
    predict_fn = predictor.__call__ if hasattr(
        predictor, '__call__') else predictor.predict

    reserved_cols = args.reserved_columns.split(',')
    result_cols = args.result_column.split(',')
    output_cols = reserved_cols + result_cols

    # batch size should be less than the total number of data in input table
    table_read_batch_size = 1
    table_read_thread_num = 4

    proc_exec = ProcessExecutor(args.queue_size)

    # Columns to read: image column(s) plus reserved columns, de-duplicated.
    # dict.fromkeys keeps first-seen order, so the column order is the same
    # in every worker and every run (set order depends on per-process
    # string hash randomization).
    selected_cols = list(
        dict.fromkeys(args.image_col.split(',') + reserved_cols))
    if args.image_col not in selected_cols:
        selected_cols.append(args.image_col)
    image_col_idx = selected_cols.index(args.image_col)

    # read this worker's slice of the input table
    proc_exec.add(
        TableReadProcess(
            args.input_table,
            selected_cols=selected_cols,
            slice_id=worker_id,
            slice_count=num_worker,
            output_queue=proc_exec.get_output_queue(),
            image_col_idx=image_col_idx,
            image_type=args.image_type,
            batch_size=table_read_batch_size,
            num_threads=table_read_thread_num))

    # decode inline base64 images, or download url images
    if args.image_type == 'base64':
        proc_exec.add(
            Base64DecodeProcess(
                thread_num=args.preprocess_thread_num,
                input_queue=proc_exec.get_input_queue(),
                output_queue=proc_exec.get_output_queue()))
    elif args.image_type == 'url':
        proc_exec.add(
            DownloadProcess(
                thread_num=args.preprocess_thread_num,
                input_queue=proc_exec.get_input_queue(),
                output_queue=proc_exec.get_output_queue(),
                use_pil_decode=False))

    # run the model
    proc_exec.add(
        PredictorProcess(
            input_queue=proc_exec.get_input_queue(),
            output_queue=proc_exec.get_output_queue(),
            predict_fn=predict_fn,
            batch_size=batch_size,
            local_rank=args.local_rank,
            thread_num=args.predict_thread_num))

    proc_exec.add(
        DefaultResultFormatProcess(
            input_queue=proc_exec.get_input_queue(),
            output_queue=proc_exec.get_output_queue(),
            reserved_col_names=reserved_cols,
            output_col_names=result_cols))

    # Write result
    proc_exec.add(
        TableWriteProcess(
            args.output_table,
            output_col_names=output_cols,
            slice_id=worker_id,
            input_queue=proc_exec.get_input_queue()))

    proc_exec.run()
    proc_exec.wait()
def check_args(args, arg_name, default_value=''):
    """Ensure that a required command line argument was actually provided.

    Args:
        args: parsed argument namespace.
        arg_name: attribute name of the argument to check.
        default_value: value meaning "not provided"; defaults to ``''``.

    Raises:
        ValueError: ``getattr(args, arg_name) == default_value``.
    """
    # raise instead of assert: asserts are stripped under `python -O` and
    # the run would silently continue with a missing argument
    if getattr(args, arg_name) == default_value:
        raise ValueError(f'{arg_name} should not be empty')
def patch_logging():
    """Set the root logger's plain ``StreamHandler``s back to INFO level.

    ``get_root_logger`` is called first so the easycv logger is set up.
    Its setup changes the level of the root logger's other handlers, so
    plain ``logging`` calls no longer show up.  NOTE(review): mmcv's
    ``get_logger`` sets them to ERROR; confirm for easycv's version.
    Only handlers whose exact type is ``logging.StreamHandler`` are touched.
    Subclasses such as ``FileHandler`` are deliberately left alone.
    """
    easycv_logger = get_root_logger()
    plain_stream_handlers = [
        handler for handler in easycv_logger.root.handlers
        if type(handler) is logging.StreamHandler
    ]
    for handler in plain_stream_handlers:
        handler.setLevel(logging.INFO)
if __name__ == '__main__':
    # Entry point: a non-empty --input_file selects the file-list pipeline;
    # otherwise the ODPS table pipeline runs and needs all table options.
    args = define_args()
    patch_logging()
    if args.launcher:
        init_dist(args.launcher, backend='nccl')

    use_file_io = args.input_file != ''
    if use_file_io:
        check_args(args, 'output_file')
        build_and_run_file_io(args)
    else:
        for required_arg in ('input_table', 'output_table', 'image_col',
                             'reserved_columns', 'result_column'):
            check_args(args, required_arg)
        build_and_run_table_io(args)