# Copyright (c) Alibaba, Inc. and its affiliates.
# isort:skip_file
import argparse
import copy
import functools
import glob
import inspect
import logging
import os
import threading
import traceback

import torch

try:
    import easy_predict
except ModuleNotFoundError:
    print('please install easy_predict first using the following instruction')
    print(
        'pip install https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/easy_predict-0.4.2-py2.py3-none-any.whl'
    )
    exit()
from easy_predict import (Base64DecodeProcess, DataFields,
                          DefaultResultFormatProcess, DownloadProcess,
                          FileReadProcess, FileWriteProcess, Process,
                          ProcessExecutor, ResultGatherProcess,
                          TableReadProcess, TableWriteProcess)
from mmcv.runner import init_dist

from easycv.utils.dist_utils import get_dist_info
from easycv.utils.logger import get_root_logger


def define_args():
    parser = argparse.ArgumentParser('easycv prediction')
    parser.add_argument(
        '--model_type',
        default='',
        help='model type, classifier/detector/segmentor/yolox')
    parser.add_argument('--model_path', default='', help='path to model')
    parser.add_argument(
        '--model_config',
        default='',
        help='model config str, predictor v1 param')
    # oss input output
    parser.add_argument(
        '--input_file',
        default='',
        help='filelist for images, each line is an oss path or a local path')
    parser.add_argument(
        '--output_file',
        default='',
        help='oss file or local file to save predict info')
    parser.add_argument(
        '--output_dir',
        default='',
        help='output directory to save image and video results')
    parser.add_argument(
        '--oss_prefix',
        default='',
        help='oss_prefix will be replaced with local_prefix in input_file')
    parser.add_argument(
        '--local_prefix',
        default='',
        help='oss_prefix will be replaced with local_prefix in input_file')
    # table input output
    parser.add_argument('--input_table', default='', help='input table name')
    parser.add_argument(
        '--output_table', default='', help='output table name')
    parser.add_argument('--image_col', default='', help='input image column')
    parser.add_argument(
        '--reserved_columns',
        default='',
        help='columns from input table to be saved to output table, '
        'comma separated')
    parser.add_argument(
        '--result_column',
        default='',
        help='result columns to be saved to output table, comma separated')
    parser.add_argument(
        '--odps_config',
        default='./odps.config',
        help='path to your odps config file')
    parser.add_argument(
        '--image_type', default='url', help='image data type, url or base64')
    # common args
    parser.add_argument(
        '--queue_size',
        type=int,
        default=1024,
        help='length of queues used for each process')
    parser.add_argument(
        '--predict_thread_num',
        type=int,
        default=1,
        help='number of threads used for prediction')
    parser.add_argument(
        '--preprocess_thread_num',
        type=int,
        default=1,
        help='number of threads used for preprocessing and downloading')
    parser.add_argument(
        '--batch_size',
        type=int,
        default=1,
        help='batch size used for prediction')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--launcher',
        type=str,
        choices=[None, 'pytorch'],
        help='if assigned pytorch, should be used in gpu environment')
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args
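# Example file-mode invocation using the flags defined above (a minimal
# sketch; the script path, model type and file paths are illustrative
# assumptions, not shipped defaults):
#
#   python -m torch.distributed.launch --nproc_per_node=2 predict.py \
#       --model_type YoloXPredictor \
#       --model_path ./export_dir \
#       --input_file ./filelist.txt \
#       --output_file ./results.txt \
#       --launcher pytorch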
class PredictorProcess(Process):

    def __init__(self,
                 predict_fn,
                 batch_size,
                 thread_num,
                 local_rank=0,
                 input_queue=None,
                 output_queue=None):
        job_name = 'Predictor'
        if torch.cuda.is_available():
            # bind every worker thread to this rank's GPU
            thread_init_fn = functools.partial(torch.cuda.set_device,
                                               local_rank)
        else:
            thread_init_fn = None
        super(PredictorProcess, self).__init__(
            job_name,
            thread_num,
            input_queue,
            output_queue,
            batch_size=batch_size,
            thread_init_fn=thread_init_fn)
        self.predict_fn = predict_fn
        self.data_lock = threading.Lock()
        self.all_data_failed = True
        self.input_empty = True
        # fix: keep the local_rank passed in instead of hard-coding 0
        self.local_rank = local_rank

    def process(self, input_batch):
        """Read a batch of images from the input queue and run prediction.

        Args:
            input_batch: a batch of input data

        Returns:
            output_queue: unstack the batch, push input data and prediction
                results into the queue
        """
        valid_input = []
        valid_indices = []
        valid_frame_ids = []
        if self.batch_size == 1:
            input_batch = [input_batch]
        if self.input_empty and len(input_batch) > 0:
            with self.data_lock:
                self.input_empty = False

        output_data_list = input_batch
        for out in output_data_list:
            out[DataFields.prediction_result] = None

        # collect the inputs that actually carry decoded image data
        for idx, input_data in enumerate(input_batch):
            if DataFields.image in input_data \
                    and input_data[DataFields.image] is not None:
                valid_input.append(input_data[DataFields.image])
                valid_indices.append(idx)

        if len(valid_input) > 0:
            try:
                # flatten video_clip to images, use the image predictor to
                # predict, then regroup the results into a list per video_clip
                output_list = self.predict_fn(valid_input)
                if len(output_list) > 0:
                    assert isinstance(output_list[0], dict), \
                        'the element in predictor output must be a dict'
                if self.all_data_failed:
                    with self.data_lock:
                        self.all_data_failed = False
            except Exception:
                logging.error(traceback.format_exc())
                output_list = [None for i in range(len(valid_input))]

            for idx, result_dict in zip(valid_indices, output_list):
                output_data = output_data_list[idx]
                output_data[DataFields.prediction_result] = result_dict
                if result_dict is None:
                    output_data[DataFields.error_msg] = 'prediction error'
                output_data_list[idx] = output_data

        for output_data in output_data_list:
            self.put(output_data)

    def destroy(self):
        if not self.input_empty and self.all_data_failed:
            raise RuntimeError(
                'failed to predict all the input data, please see the '
                'exception thrown above in the log')


def create_yolox_predictor_kwargs(model_dir):
    jit_models = glob.glob('%s/**/*.jit' % model_dir, recursive=True)
    raw_models = glob.glob('%s/**/*.pt' % model_dir, recursive=True)
    if len(jit_models) > 0:
        assert len(jit_models) == 1, \
            f'more than one jit script model file is found in {model_dir}'
        config_path = jit_models[0] + '.config.json'
        if not os.path.exists(config_path):
            raise ValueError(
                f'config json file {config_path} for inference with the '
                'jit script model is not found')
        return {'model_path': jit_models[0], 'config_file': config_path}
    else:
        assert len(raw_models) > 0, f'export model not found in {model_dir}'
        assert len(raw_models) == 1, \
            f'more than one model file is found in {model_dir}'
        return {'model_path': raw_models[0]}


def create_default_predictor_kwargs(model_dir):
    model_path = glob.glob('%s/**/*.pt*' % model_dir, recursive=True)
    assert len(model_path) > 0, f'model not found in {model_dir}'
    assert len(model_path) == 1, \
        f'more than one model file is found: {model_path}'
    model_path = model_path[0]
    logging.info(f'model found: {model_path}')
    config_path = glob.glob('%s/**/*.py' % model_dir, recursive=True)
    if len(config_path) == 0:
        config_path = None
    else:
        assert len(config_path) == 1, \
            f'more than one config file is found: {config_path}'
        config_path = config_path[0]
        logging.info(f'config found: {config_path}')
    if config_path:
        return {'model_path': model_path, 'config_file': config_path}
    else:
        return {'model_path': model_path}
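# The two helpers above assume a self-contained export directory. A layout
# like the following would satisfy them (file names are illustrative
# assumptions; only the extensions and the *.config.json suffix matter):
#
#   export_dir/                      # YoloXPredictor, torchscript export
#   ├── model.jit
#   └── model.jit.config.json
#
#   export_dir/                      # default predictors
#   ├── model.pt
#   └── config.py                    # optional predictor config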
def create_predictor_kwargs(model_type, model_dir):
    if model_type == 'YoloXPredictor':
        return create_yolox_predictor_kwargs(model_dir)
    else:
        return create_default_predictor_kwargs(model_dir)


def init_predictor(args):
    model_type = args.model_type
    model_path = args.model_path
    batch_size = args.batch_size
    from easycv.predictors.builder import build_predictor
    ori_model_path = model_path
    if os.path.isdir(ori_model_path):
        # a directory is searched for exported model/config files
        predictor_kwargs = create_predictor_kwargs(model_type, ori_model_path)
    else:
        predictor_kwargs = {'model_path': ori_model_path}
    predictor_cfg = dict(type=model_type, **predictor_kwargs)
    if args.model_config != '':
        predictor_cfg['model_config'] = args.model_config
    predictor = build_predictor(predictor_cfg)
    return predictor


def replace_oss_with_local_path(ori_file, dst_file, bucket_prefix,
                                local_prefix):
    bucket_prefix = bucket_prefix.rstrip('/') + '/'
    local_prefix = local_prefix.rstrip('/') + '/'
    with open(ori_file, 'r') as infile:
        with open(dst_file, 'w') as ofile:
            for l in infile:
                if l.startswith('oss://'):
                    l = l.replace(bucket_prefix, local_prefix)
                ofile.write(l)


def build_and_run_file_io(args):
    # distributed info
    rank, world_size = get_dist_info()
    worker_id = rank
    input_oss_file_new_host = args.input_file + '.tmp%d' % worker_id
    replace_oss_with_local_path(args.input_file, input_oss_file_new_host,
                                args.oss_prefix, args.local_prefix)
    args.input_file = input_oss_file_new_host
    num_worker = world_size
    print(f'worker num {num_worker}')
    print(f'worker_id {worker_id}')
    batch_size = args.batch_size
    print(f'Local rank {args.local_rank}')
    if torch.cuda.is_available():
        torch.cuda.set_device(args.local_rank)
    predictor = init_predictor(args)
    predict_fn = predictor.__call__ if hasattr(
        predictor, '__call__') else predictor.predict

    # create proc executor
    proc_exec = ProcessExecutor(args.queue_size)

    # read image paths from the filelist
    proc_exec.add(
        FileReadProcess(
            args.input_file,
            slice_id=worker_id,
            slice_count=num_worker,
            output_queue=proc_exec.get_output_queue()))
    # download and decode image data
    proc_exec.add(
        DownloadProcess(
            thread_num=args.predict_thread_num,
            input_queue=proc_exec.get_input_queue(),
            output_queue=proc_exec.get_output_queue(),
            is_video_url=False))
    # run prediction on the decoded images
    proc_exec.add(
        PredictorProcess(
            input_queue=proc_exec.get_input_queue(),
            output_queue=proc_exec.get_output_queue(),
            predict_fn=predict_fn,
            batch_size=batch_size,
            local_rank=args.local_rank,
            thread_num=args.predict_thread_num))
    proc_exec.add(
        DefaultResultFormatProcess(
            input_queue=proc_exec.get_input_queue(),
            output_queue=proc_exec.get_output_queue()))
    # gather results into different dicts by output type
    proc_exec.add(
        ResultGatherProcess(
            output_type_dict={},
            input_queue=proc_exec.get_input_queue(),
            output_queue=proc_exec.get_output_queue()))
    # write results
    proc_exec.add(
        FileWriteProcess(
            output_file=args.output_file,
            output_dir=args.output_dir,
            slice_id=worker_id,
            slice_count=num_worker,
            input_queue=proc_exec.get_input_queue()))
    proc_exec.run()
    proc_exec.wait()
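# Note: PredictorProcess only assumes that ``predict_fn`` maps a list of
# decoded images to an equally long list of dicts (enforced by the assert
# in ``PredictorProcess.process``). A minimal stand-in for testing the
# pipeline wiring locally could therefore look like this (hypothetical
# example, not part of easy_predict or easycv):
#
#   def dummy_predict_fn(images):
#       return [{'dummy_score': 1.0} for _ in images]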
def build_and_run_table_io(args):
    os.environ['ODPS_CONFIG_FILE_PATH'] = args.odps_config
    rank, world_size = get_dist_info()
    worker_id = rank
    num_worker = world_size
    print(f'worker num {num_worker}')
    print(f'worker_id {worker_id}')
    batch_size = args.batch_size
    if torch.cuda.is_available():
        torch.cuda.set_device(args.local_rank)
    predictor = init_predictor(args)
    predict_fn = predictor.__call__ if hasattr(
        predictor, '__call__') else predictor.predict

    # batch size should be less than the total number of rows in the
    # input table
    table_read_batch_size = 1
    table_read_thread_num = 4

    # create proc executor
    proc_exec = ProcessExecutor(args.queue_size)

    # read image data from the input table
    selected_cols = list(
        set(args.image_col.split(',') + args.reserved_columns.split(',')))
    if args.image_col not in selected_cols:
        selected_cols.append(args.image_col)
    image_col_idx = selected_cols.index(args.image_col)
    proc_exec.add(
        TableReadProcess(
            args.input_table,
            selected_cols=selected_cols,
            slice_id=worker_id,
            slice_count=num_worker,
            output_queue=proc_exec.get_output_queue(),
            image_col_idx=image_col_idx,
            image_type=args.image_type,
            batch_size=table_read_batch_size,
            num_threads=table_read_thread_num))
    # decode base64 image data or download image urls
    if args.image_type == 'base64':
        base64_thread_num = args.preprocess_thread_num
        proc_exec.add(
            Base64DecodeProcess(
                thread_num=base64_thread_num,
                input_queue=proc_exec.get_input_queue(),
                output_queue=proc_exec.get_output_queue()))
    elif args.image_type == 'url':
        download_thread_num = args.preprocess_thread_num
        proc_exec.add(
            DownloadProcess(
                thread_num=download_thread_num,
                input_queue=proc_exec.get_input_queue(),
                output_queue=proc_exec.get_output_queue(),
                use_pil_decode=False))
    # run prediction on the decoded images
    proc_exec.add(
        PredictorProcess(
            input_queue=proc_exec.get_input_queue(),
            output_queue=proc_exec.get_output_queue(),
            predict_fn=predict_fn,
            batch_size=batch_size,
            local_rank=args.local_rank,
            thread_num=args.predict_thread_num))
    proc_exec.add(
        DefaultResultFormatProcess(
            input_queue=proc_exec.get_input_queue(),
            output_queue=proc_exec.get_output_queue(),
            reserved_col_names=args.reserved_columns.split(','),
            output_col_names=args.result_column.split(',')))
    # write results
    output_cols = args.reserved_columns.split(',') + args.result_column.split(
        ',')
    proc_exec.add(
        TableWriteProcess(
            args.output_table,
            output_col_names=output_cols,
            slice_id=worker_id,
            input_queue=proc_exec.get_input_queue()))
    proc_exec.run()
    proc_exec.wait()


def check_args(args, arg_name, default_value=''):
    # fix: compare against default_value instead of ignoring the parameter
    assert getattr(args, arg_name) != default_value, \
        f'{arg_name} should not be empty'


def patch_logging():
    # get_root_logger lowers the level of all other handlers, so plain
    # logging would not show up afterwards; reset stream handlers to INFO
    logger = get_root_logger()
    for handler in logger.root.handlers:
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.INFO)


if __name__ == '__main__':
    args = define_args()
    patch_logging()
    if args.launcher:
        init_dist(args.launcher, backend='nccl')
    if args.input_file != '':
        # file mode: read paths from a filelist, write results to a file
        check_args(args, 'output_file')
        build_and_run_file_io(args)
    else:
        # table mode: read from and write to odps tables
        check_args(args, 'input_table')
        check_args(args, 'output_table')
        check_args(args, 'image_col')
        check_args(args, 'reserved_columns')
        check_args(args, 'result_column')
        build_and_run_table_io(args)
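# Example table-mode invocation (a sketch; table names, column names and
# the odps table path format are placeholders for your own project, and
# odps.config must point at your credentials):
#
#   python predict.py \
#       --model_type YoloXPredictor \
#       --model_path ./export_dir \
#       --input_table odps://my_project/tables/input \
#       --output_table odps://my_project/tables/output \
#       --image_col image_url \
#       --reserved_columns id \
#       --result_column prediction \
#       --odps_config ./odps.config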