add prediction script (#239)

* add prediction script

* update doc
wenmeng zhou 2022-12-05 18:06:00 +08:00 committed by GitHub
parent 3f533ad62e
commit befb23c2d5
9 changed files with 737 additions and 4 deletions


@ -8,7 +8,7 @@ concurrency:
jobs:
lint:
runs-on: ubuntu-latest
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.6


@ -8,7 +8,7 @@ concurrency:
jobs:
build-n-publish:
runs-on: ubuntu-latest
runs-on: ubuntu-20.04
if: startsWith(github.event.ref, 'refs/tags')
steps:
- uses: actions/checkout@v2


@ -93,7 +93,7 @@ Please refer to [quick_start.md](docs/source/quick_start.md) for quick start. We
* [using torchacc](docs/source/tutorials/torchacc.md)
* [file io for local and oss files](docs/source/tutorials/file.md)
* [using mmdetection model in EasyCV](docs/source/tutorials/mmdet_models_usage_guide.md)
* [batch prediction tools](docs/source/tutorials/predict.md)


@ -93,6 +93,7 @@ EasyCV is a PyTorch-based computer vision toolbox covering multiple domains
* [using torchacc](docs/source/tutorials/torchacc.md)
* [reading local/oss files](docs/source/tutorials/file.md)
* [using mmdetection models](docs/source/tutorials/mmdet_models_usage_guide.md)
* [batch prediction tool](docs/source/tutorials/predict.md)
## Model Zoo


@ -0,0 +1,186 @@
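# Batch prediction
EasyCV provides a tool for large-scale image prediction on a local machine. The tool reads local images, image HTTP links, or MaxCompute table data, and runs prediction with any of the Predictors that EasyCV provides.
## Installing dependencies
Install easy_predict. easy_predict abstracts the stages of image prediction (data reading/downloading, image decoding, and model inference) into independent processing units. Each unit supports concurrent multi-threaded execution, which can substantially increase the overall throughput of a job.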
```
pip install https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/easy_predict-0.4.2-py2.py3-none-any.whl
```
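The idea of independent, multi-threaded processing units connected by queues can be illustrated with a small sketch. This is not the easy_predict implementation: `run_stage`, `run_pipeline`, and the toy stage functions are hypothetical names that use only the Python standard library.

```python
import queue
import threading

_STOP = object()  # sentinel that tells a stage's workers to shut down


def run_stage(fn, in_q, out_q, thread_num):
    """Start `thread_num` worker threads that apply `fn` to items from in_q."""

    def worker():
        while True:
            item = in_q.get()
            if item is _STOP:
                break
            out_q.put(fn(item))

    threads = [threading.Thread(target=worker) for _ in range(thread_num)]
    for t in threads:
        t.start()
    return threads


def run_pipeline(items, stages):
    """Push `items` through `stages`, a list of (function, thread_num) pairs."""
    queues = [queue.Queue() for _ in range(len(stages) + 1)]
    all_threads = [
        run_stage(fn, queues[i], queues[i + 1], n)
        for i, (fn, n) in enumerate(stages)
    ]
    for item in items:
        queues[0].put(item)
    # Shut the stages down in order so that no item is dropped.
    for i, threads in enumerate(all_threads):
        for _ in threads:
            queues[i].put(_STOP)
        for t in threads:
            t.join()
    results = []
    while not queues[-1].empty():
        results.append(queues[-1].get())
    return results


# Three toy stages standing in for read, decode, and predict.
toy_stages = [
    (lambda path: (path, b'raw-bytes'), 2),
    (lambda pair: (pair[0], len(pair[1])), 2),
    (lambda pair: {'input': pair[0], 'size': pair[1]}, 1),
]
outputs = run_pipeline(['a.jpg', 'b.jpg', 'c.jpg'], toy_stages)
print(sorted(o['input'] for o in outputs))
```

Because every stage has its own worker threads, a slow stage such as downloading can be given more threads without touching the others. Result order is not preserved when a stage uses more than one thread.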
## Data format
### Input file list
When the input is a file, each line can be either a local image path or an image URL.
Local file paths
```shell
predict/test_data/000000289059.jpg
predict/test_data/000000289059.jpg
predict/test_data/000000289059.jpg
predict/test_data/000000289059.jpg
predict/test_data/000000289059.jpg
predict/test_data/000000289059.jpg
predict/test_data/000000289059.jpg
predict/test_data/000000289059.jpg
predict/test_data/000000289059.jpg
predict/test_data/000000289059.jpg
predict/test_data/000000289059.jpg
predict/test_data/000000289059.jpg
predict/test_data/000000289059.jpg
predict/test_data/000000289059.jpg
predict/test_data/000000289059.jpg
```
Image URLs
```shell
https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/ant%2Bhill_14_33.jpg
https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/ant%2Bhill_14_33.jpg
https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/ant%2Bhill_14_33.jpg
https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/ant%2Bhill_14_33.jpg
https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/ant%2Bhill_14_33.jpg
https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/ant%2Bhill_14_33.jpg
```
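As a rough illustration of the two line formats, the sketch below tells URL lines apart from local paths. The helpers `classify_input_line` and `read_input_list` are hypothetical and are not part of EasyCV or easy_predict. The prefix check is an assumption about how such lines could be distinguished.

```python
def classify_input_line(line):
    """Return ('url', value) or ('local', value) for one line of the list file."""
    value = line.strip()
    # Assumption: anything starting with http:// or https:// is an image URL.
    if value.startswith(('http://', 'https://')):
        return 'url', value
    return 'local', value


def read_input_list(lines):
    """Classify every non-empty line of an input list."""
    return [classify_input_line(l) for l in lines if l.strip()]


sample = [
    'predict/test_data/000000289059.jpg\n',
    'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/ant%2Bhill_14_33.jpg\n',
    '\n',
]
entries = read_input_list(sample)
for kind, value in entries:
    print(kind, value)
```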
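### MaxCompute table input
The input table can have one or more columns. One of them must contain either the image URL or the base64-encoded string of the image file's binary data (image_base64).
An example input table schema is shown below.
URL data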
```shell
+------------------------------------------------------------------------------------+
| Field | Type | Label | Comment |
+------------------------------------------------------------------------------------+
| id | string | | |
| url | string | | |
+------------------------------------------------------------------------------------+
```
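base64 data
The input table can have one or more columns. One of them must contain either the image URL or the base64-encoded binary data of the encoded image, stored with type str.
An example schema is shown below.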
```shell
+------------------------------------------------------------------------------------+
| Field | Type | Label | Comment |
+------------------------------------------------------------------------------------+
| id | string | | |
| base64 | string | | |
+------------------------------------------------------------------------------------+
```
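A value for the base64 column could be produced along these lines. This is a minimal sketch using only the standard `base64` module. `encode_image_bytes` and `decode_image_string` are hypothetical helper names, and the short byte string merely stands in for the contents of a real image file.

```python
import base64


def encode_image_bytes(image_bytes):
    """Encode raw image file bytes as a base64 text string."""
    return base64.b64encode(image_bytes).decode('ascii')


def decode_image_string(image_base64):
    """Recover the raw image file bytes from the base64 string."""
    return base64.b64decode(image_base64)


# The first bytes of a JPEG file stand in for a full image here.
fake_jpeg = b'\xff\xd8\xff\xe0' + b'\x00' * 8
cell_value = encode_image_bytes(fake_jpeg)
print(type(cell_value).__name__, cell_value)
```

In practice `image_bytes` would come from reading the image file in binary mode, so the encoded string carries the compressed file rather than decoded pixels.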
## Running
### Reading local files
Single-GPU run
```shell
PYTHONPATH=. python tools/predict.py \
--input_file predict/test.list \
--output_file predict/output.txt \
--model_type YoloXPredictor \
--model_path predict/test_data/yolox/epoch_300.pt
```
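<details>
<summary>Parameter description</summary>
- `input_file`: path to the input file
- `output_file`: path to the output file
- `model_type`: model type, i.e. the class name of a Predictor under easycv/predictors/, for example YoloXPredictor
- `model_path`: path to the model file or model directory
</details>
Multi-node, multi-GPU run
Multi-node, multi-GPU launch reuses the PyTorch DDP launch method and requires a GPU environment.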
```shell
PYTHONPATH=. python -m torch.distributed.launch --nproc_per_node=2 --master_port=29527 \
tools/predict.py \
--input_file predict/test.list \
--output_file predict/output.txt \
--model_type YoloXPredictor \
--model_path predict/test_data/yolox/epoch_300.pt \
--launcher pytorch
```
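<details>
<summary>Parameter description</summary>
- `nproc_per_node`: number of GPUs per node
- `master_port`: port of the master node
- `master_addr`: IP address of the master node
- `input_file`: path to the input file
- `output_file`: path to the output file
- `model_type`: model type, i.e. the class name of a Predictor under easycv/predictors/, for example YoloXPredictor
- `model_path`: path to the model file or model directory
</details>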
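In a distributed run, tools/predict.py passes each worker's rank and the world size to the reader as `slice_id` and `slice_count`. How easy_predict actually splits the input is not shown in this commit, so the round-robin scheme below is only an assumption for illustration, and `shard_lines` is a hypothetical helper.

```python
def shard_lines(lines, slice_id, slice_count):
    """Return the lines that worker `slice_id` out of `slice_count` handles.

    Assumption: lines are dealt out round-robin, one per worker in turn.
    """
    return [l for i, l in enumerate(lines) if i % slice_count == slice_id]


lines = ['img_%d.jpg' % i for i in range(10)]
shards = [shard_lines(lines, rank, 2) for rank in range(2)]
for rank, shard in enumerate(shards):
    print(rank, shard)
```

Whatever the real scheme is, the shards should cover every input line exactly once, which is what the unit test in this commit checks by comparing input and output line counts.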
### Reading a MaxCompute table
Single-GPU example
```shell
# Create the output table partition
odpscmd -e "alter table table_name add partition (ds=partition_name);"
PYTHONPATH=. python tools/predict.py \
--model_type YoloXPredictor \
--model_path predict/test_data/yolox/epoch_300.pt \
--input_table odps://project_name/tables/table_name/ds=partition_info \
--output_table odps://project_name/tables/table_name/ds=partition_info \
--image_col url \
--image_type url \
--reserved_columns id \
--result_column result \
--odps_config /path/to/odps.config
```
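<details>
<summary>Parameter description</summary>
- `model_type`: model type, i.e. the class name of a Predictor under easycv/predictors/, for example YoloXPredictor
- `model_path`: path to the model file or model directory
- `input_table`: input table
- `output_table`: output table
- `image_col`: column that holds the image data
- `image_type`: image data type, url or base64
- `reserved_columns`: names of the input table columns to keep, separated by commas
- `result_column`: name of the result column
- `odps_config`: MaxCompute configuration file
</details>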
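The column arguments are comma-separated strings. The sketch below mirrors how tools/predict.py in this commit splits them into the columns to read and the columns to write. `build_columns` is a hypothetical helper, and sorting the selected columns is a choice made here for a stable order (the script itself uses an unordered set).

```python
def build_columns(image_col, reserved_columns, result_column):
    """Split the comma-separated arguments into read and write column lists."""
    reserved = reserved_columns.split(',')
    results = result_column.split(',')
    # Columns read from the input table: the image column plus reserved ones.
    selected = sorted(set([image_col] + reserved))
    # Columns written to the output table: reserved ones, then the results.
    return selected, reserved + results


selected, output_cols = build_columns('url', 'id', 'result')
print(selected)
print(output_cols)
```

With the values from the example command, the script reads the `id` and `url` columns and writes `id` and `result`, so the output table needs exactly those two columns.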
Multi-GPU example
Multi-node, multi-GPU launch reuses the PyTorch DDP launch method and requires a GPU environment.
```shell
# Create the output table partition
odpscmd -e "alter table table_name add partition (ds=partition_name);"
PYTHONPATH=. python -m torch.distributed.launch --nproc_per_node=3 --master_port=29527 \
tools/predict.py \
--model_type YoloXPredictor \
--model_path predict/test_data/yolox/epoch_300.pt \
--input_table odps://project_name/tables/table_name/ds=partition_name \
--output_table odps://project_name/tables/table_name/ds=partition_name \
--image_col url \
--image_type url \
--reserved_columns id \
--result_column result \
--odps_config /path/to/odps.config \
--launcher pytorch
```


@ -332,7 +332,8 @@ class TorchYoloXPredictor(YoloXPredictor):
model_config: config string for model to init, in json format
"""
if model_config:
model_config = json.loads(model_config)
if isinstance(model_config, str):
model_config = json.loads(model_config)
else:
model_config = {}


@ -0,0 +1,80 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import glob
import json
import logging
import os
import sys
import tempfile
import unittest
import torch
from mmcv import Config
from tests.ut_config import (PRETRAINED_MODEL_SEGFORMER,
PRETRAINED_MODEL_YOLOXS_EXPORT, TEST_IMAGES_DIR)
from easycv.file import io
from easycv.utils.test_util import run_in_subprocess
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
logging.basicConfig(level=logging.INFO)
class PredictTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def tearDown(self):
super().tearDown()
def _base_predict(self, model_type, model_path, dist=False):
input_file = tempfile.NamedTemporaryFile('w').name
input_line_num = 10
with open(input_file, 'w') as ofile:
for _ in range(input_line_num):
ofile.write(
os.path.join(TEST_IMAGES_DIR, '000000289059.jpg') + '\n')
output_file = tempfile.NamedTemporaryFile('w').name
if dist:
cmd = f'PYTHONPATH=. python -m torch.distributed.launch --nproc_per_node=2 --master_port=29527 \
tools/predict.py \
--input_file {input_file} \
--output_file {output_file} \
--model_type {model_type} \
--model_path {model_path} \
--launcher pytorch'
else:
cmd = f'PYTHONPATH=. python tools/predict.py \
--input_file {input_file} \
--output_file {output_file} \
--model_type {model_type} \
--model_path {model_path} '
logging.info('run command: %s' % cmd)
run_in_subprocess(cmd)
with open(output_file, 'r') as infile:
output_line_num = len(infile.readlines())
self.assertEqual(input_line_num, output_line_num)
io.remove(input_file)
io.remove(output_file)
def test_predict(self):
model_type = 'YoloXPredictor'
model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
self._base_predict(model_type, model_path)
@unittest.skipIf(torch.cuda.device_count() <= 1, 'distributed unittest')
def test_predict_dist(self):
model_type = 'YoloXPredictor'
model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
self._base_predict(model_type, model_path, dist=True)
if __name__ == '__main__':
unittest.main()

9
tools/dist_predict.sh Executable file

@ -0,0 +1,9 @@
#!/usr/bin/env bash
PYTHON=${PYTHON:-"python"}
GPUS=$1
PY_ARGS=${@:2}
PORT=${PORT:-29527}
PYTHONPATH=./ $PYTHON -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
tools/predict.py $PY_ARGS --launcher pytorch

456
tools/predict.py Normal file

@ -0,0 +1,456 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# isort:skip_file
import argparse
import copy
import functools
import glob
import inspect
import logging
import os
import threading
import traceback
import torch
try:
import easy_predict
except ModuleNotFoundError:
print('please install easy_predict first using following instruction')
print(
'pip install https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/easy_predict-0.4.2-py2.py3-none-any.whl'
)
exit()
from easy_predict import (Base64DecodeProcess, DataFields,
DefaultResultFormatProcess, DownloadProcess,
FileReadProcess, FileWriteProcess, Process,
ProcessExecutor, ResultGatherProcess,
TableReadProcess, TableWriteProcess)
from mmcv.runner import init_dist
from easycv.utils.dist_utils import get_dist_info
from easycv.utils.logger import get_root_logger
def define_args():
parser = argparse.ArgumentParser('easycv prediction')
parser.add_argument(
'--model_type',
default='',
help='model type, classifier/detector/segmentor/yolox')
parser.add_argument('--model_path', default='', help='path to model')
parser.add_argument(
'--model_config',
default='',
help='model config str, predictor v1 param')
# oss input output
parser.add_argument(
'--input_file',
default='',
help='filelist for images, each line is an oss path or a local path')
parser.add_argument(
'--output_file',
default='',
help='oss file or local file to save predict info')
parser.add_argument(
'--output_dir',
default='',
help='output_directory to save image and video results')
parser.add_argument(
'--oss_prefix',
default='',
help='oss_prefix will be replaced with local_prefix in input_file')
parser.add_argument(
'--local_prefix',
default='',
help='local path prefix that replaces oss_prefix in input_file')
# table input output
parser.add_argument('--input_table', default='', help='input table name')
parser.add_argument('--output_table', default='', help='output table name')
parser.add_argument('--image_col', default='', help='input image column')
parser.add_argument(
'--reserved_columns',
default='',
help=
'columns from input table to be saved to output table, comma separated'
)
parser.add_argument(
'--result_column',
default='',
help='result columns to be saved to output table, comma separated')
parser.add_argument(
'--odps_config',
default='./odps.config',
help='path to your odps config file')
parser.add_argument(
'--image_type', default='url', help='image data type, url or base64')
# common args
parser.add_argument(
'--queue_size',
type=int,
default=1024,
help='length of queues used for each process')
parser.add_argument(
'--predict_thread_num',
type=int,
default=1,
help='number of threads used for prediction')
parser.add_argument(
'--preprocess_thread_num',
type=int,
default=1,
help='number of threads used for preprocessing and downloading')
parser.add_argument(
'--batch_size',
type=int,
default=1,
help='batch size used for prediction')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument(
'--launcher',
type=str,
choices=[None, 'pytorch'],
help='if assigned pytorch, should be used in gpu environment')
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
class PredictorProcess(Process):
def __init__(self,
predict_fn,
batch_size,
thread_num,
local_rank=0,
input_queue=None,
output_queue=None):
job_name = 'Predictor'
if torch.cuda.is_available():
thread_init_fn = functools.partial(torch.cuda.set_device,
local_rank)
else:
thread_init_fn = None
super(PredictorProcess, self).__init__(
job_name,
thread_num,
input_queue,
output_queue,
batch_size=batch_size,
thread_init_fn=thread_init_fn)
self.predict_fn = predict_fn
self.data_lock = threading.Lock()
self.all_data_failed = True
self.input_empty = True
self.local_rank = local_rank
def process(self, input_batch):
"""
Predict a batch of images taken from input_queue.
Args:
input_batch: a batch of input data
Returns:
None. The batch is unstacked, and each input item is pushed to
output_queue together with its prediction result.
"""
valid_input = []
valid_indices = []
valid_frame_ids = []
if self.batch_size == 1:
input_batch = [input_batch]
if self.input_empty and len(input_batch) > 0:
self.data_lock.acquire()
self.input_empty = False
self.data_lock.release()
output_data_list = input_batch
for out in output_data_list:
out[DataFields.prediction_result] = None
for idx, input_data in enumerate(input_batch):
if DataFields.image in input_data \
and input_data[DataFields.image] is not None:
valid_input.append(input_data[DataFields.image])
valid_indices.append(idx)
if len(valid_input) > 0:
try:
# flatten video_clip to images, use image predictor to predict
# then regroup the result to a list for one video_clip
output_list = self.predict_fn(valid_input)
if len(output_list) > 0:
assert isinstance(output_list[0], dict), \
'the element in predictor output must be a dict'
if self.all_data_failed:
self.data_lock.acquire()
self.all_data_failed = False
self.data_lock.release()
except Exception:
logging.error(traceback.format_exc())
output_list = [None for i in range(len(valid_input))]
for idx, result_dict in zip(valid_indices, output_list):
output_data = output_data_list[idx]
output_data[DataFields.prediction_result] = result_dict
if result_dict is None:
output_data[DataFields.error_msg] = 'prediction error'
output_data_list[idx] = output_data
for output_data in output_data_list:
self.put(output_data)
def destroy(self):
if not self.input_empty and self.all_data_failed:
raise RuntimeError(
'failed to predict all the input data, please see the exception thrown above in the log'
)
def init_predictor(args):
model_type = args.model_type
model_path = args.model_path
batch_size = args.batch_size
from easycv.predictors.builder import build_predictor
ori_model_path = model_path
if not ori_model_path.endswith('pth') and not ori_model_path.endswith(
'pt'):
model_path = glob.glob('%s/**/*.pt*' % ori_model_path, recursive=True)
assert len(model_path) > 0, f'model not found in {ori_model_path}'
assert len(
model_path) == 1, f'more than one model file is found {model_path}'
model_path = model_path[0]
logging.info(f'model found: {model_path}')
config_path = glob.glob('%s/**/*.py' % ori_model_path, recursive=True)
if len(config_path) == 0:
config_path = None
else:
assert len(
config_path
) == 1, f'more than one config file is found {config_path}'
config_path = config_path[0]
logging.info(f'config found: {config_path}')
else:
model_path = ori_model_path
config_path = None
predictor_cfg = dict(type=model_type, model_path=model_path)
if config_path:
predictor_cfg['config_file'] = config_path
if args.model_config != '':
predictor_cfg['model_config'] = args.model_config
predictor = build_predictor(predictor_cfg)
return predictor
def replace_oss_with_local_path(ori_file, dst_file, bucket_prefix,
local_prefix):
bucket_prefix = bucket_prefix.rstrip('/') + '/'
local_prefix = local_prefix.rstrip('/') + '/'
with open(ori_file, 'r') as infile:
with open(dst_file, 'w') as ofile:
for l in infile:
if l.startswith('oss://'):
l = l.replace(bucket_prefix, local_prefix)
ofile.write(l)
def build_and_run_file_io(args):
# distribute info
rank, world_size = get_dist_info()
worker_id = rank
input_oss_file_new_host = args.input_file + '.tmp%d' % worker_id
replace_oss_with_local_path(args.input_file, input_oss_file_new_host,
args.oss_prefix, args.local_prefix)
args.input_file = input_oss_file_new_host
num_worker = world_size
print(f'worker num {num_worker}')
print(f'worker_id {worker_id}')
batch_size = args.batch_size
print(f'Local rank {args.local_rank}')
if torch.cuda.is_available():
torch.cuda.set_device(args.local_rank)
predictor = init_predictor(args)
predict_fn = predictor.__call__ if hasattr(
predictor, '__call__') else predictor.predict
# create proc executor
proc_exec = ProcessExecutor(args.queue_size)
# create oss read process to read file path from filelist
proc_exec.add(
FileReadProcess(
args.input_file,
slice_id=worker_id,
slice_count=num_worker,
output_queue=proc_exec.get_output_queue()))
# download and decode image data
proc_exec.add(
DownloadProcess(
thread_num=args.preprocess_thread_num,
input_queue=proc_exec.get_input_queue(),
output_queue=proc_exec.get_output_queue(),
is_video_url=False))
# run the predictor on the decoded image data
proc_exec.add(
PredictorProcess(
input_queue=proc_exec.get_input_queue(),
output_queue=proc_exec.get_output_queue(),
predict_fn=predict_fn,
batch_size=batch_size,
local_rank=args.local_rank,
thread_num=args.predict_thread_num))
proc_exec.add(
DefaultResultFormatProcess(
input_queue=proc_exec.get_input_queue(),
output_queue=proc_exec.get_output_queue()))
# Gather result to different dict of different type
proc_exec.add(
ResultGatherProcess(
output_type_dict={},
input_queue=proc_exec.get_input_queue(),
output_queue=proc_exec.get_output_queue()))
# Write result
proc_exec.add(
FileWriteProcess(
output_file=args.output_file,
output_dir=args.output_dir,
slice_id=worker_id,
slice_count=num_worker,
input_queue=proc_exec.get_input_queue()))
proc_exec.run()
proc_exec.wait()
def build_and_run_table_io(args):
os.environ['ODPS_CONFIG_FILE_PATH'] = args.odps_config
rank, world_size = get_dist_info()
worker_id = rank
num_worker = world_size
print(f'worker num {num_worker}')
print(f'worker_id {worker_id}')
batch_size = args.batch_size
if torch.cuda.is_available():
torch.cuda.set_device(args.local_rank)
predictor = init_predictor(args)
predict_fn = predictor.__call__ if hasattr(
predictor, '__call__') else predictor.predict
# batch size should be less than the total number of data in input table
table_read_batch_size = 1
table_read_thread_num = 4
# create proc executor
proc_exec = ProcessExecutor(args.queue_size)
# select the columns to read from the input table
selected_cols = list(
set(args.image_col.split(',') + args.reserved_columns.split(',')))
if args.image_col not in selected_cols:
selected_cols.append(args.image_col)
image_col_idx = selected_cols.index(args.image_col)
proc_exec.add(
TableReadProcess(
args.input_table,
selected_cols=selected_cols,
slice_id=worker_id,
slice_count=num_worker,
output_queue=proc_exec.get_output_queue(),
image_col_idx=image_col_idx,
image_type=args.image_type,
batch_size=table_read_batch_size,
num_threads=table_read_thread_num))
if args.image_type == 'base64':
base64_thread_num = args.preprocess_thread_num
proc_exec.add(
Base64DecodeProcess(
thread_num=base64_thread_num,
input_queue=proc_exec.get_input_queue(),
output_queue=proc_exec.get_output_queue()))
elif args.image_type == 'url':
download_thread_num = args.preprocess_thread_num
proc_exec.add(
DownloadProcess(
thread_num=download_thread_num,
input_queue=proc_exec.get_input_queue(),
output_queue=proc_exec.get_output_queue(),
use_pil_decode=False))
# run the predictor on the decoded image data
proc_exec.add(
PredictorProcess(
input_queue=proc_exec.get_input_queue(),
output_queue=proc_exec.get_output_queue(),
predict_fn=predict_fn,
batch_size=batch_size,
local_rank=args.local_rank,
thread_num=args.predict_thread_num))
proc_exec.add(
DefaultResultFormatProcess(
input_queue=proc_exec.get_input_queue(),
output_queue=proc_exec.get_output_queue(),
reserved_col_names=args.reserved_columns.split(','),
output_col_names=args.result_column.split(',')))
# Write result
output_cols = args.reserved_columns.split(',') + args.result_column.split(
',')
proc_exec.add(
TableWriteProcess(
args.output_table,
output_col_names=output_cols,
slice_id=worker_id,
input_queue=proc_exec.get_input_queue()))
proc_exec.run()
proc_exec.wait()
def check_args(args, arg_name, default_value=''):
# an argument still equal to its default value was not provided by the user
assert getattr(
args, arg_name) != default_value, f'{arg_name} should not be empty'
def patch_logging():
# after get_root_logger, logging will not take effect because
# it sets all other handler to level logging.INFO
logger = get_root_logger()
for handler in logger.root.handlers:
if type(handler) is logging.StreamHandler:
handler.setLevel(logging.INFO)
if __name__ == '__main__':
args = define_args()
patch_logging()
if args.launcher:
init_dist(args.launcher, backend='nccl')
if args.input_file != '':
check_args(args, 'output_file')
build_and_run_file_io(args)
else:
check_args(args, 'input_table')
check_args(args, 'output_table')
check_args(args, 'image_col')
check_args(args, 'reserved_columns')
check_args(args, 'result_column')
build_and_run_table_io(args)