mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
parent
3f533ad62e
commit
befb23c2d5
2
.github/workflows/lint.yaml
vendored
2
.github/workflows/lint.yaml
vendored
@ -8,7 +8,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python 3.6
|
||||
|
2
.github/workflows/publish.yaml
vendored
2
.github/workflows/publish.yaml
vendored
@ -8,7 +8,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
build-n-publish:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-20.04
|
||||
if: startsWith(github.event.ref, 'refs/tags')
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
@ -93,7 +93,7 @@ Please refer to [quick_start.md](docs/source/quick_start.md) for quick start. We
|
||||
* [using torchacc](docs/source/tutorials/torchacc.md)
|
||||
* [file io for local and oss files](docs/source/tutorials/file.md)
|
||||
* [using mmdetection model in EasyCV](docs/source/tutorials/mmdet_models_usage_guide.md)
|
||||
|
||||
* [batch prediction tools][docs/source/tutorials/predict.md]
|
||||
|
||||
|
||||
|
||||
|
@ -93,6 +93,7 @@ EasyCV是一个涵盖多个领域的基于Pytorch的计算机视觉工具箱,
|
||||
* [torchacc使用](docs/source/tutorials/torchacc.md)
|
||||
* [本地/oss文件读取](docs/source/tutorials/file.md)
|
||||
* [mmdetection模型使用](docs/source/tutorials/mmdet_models_usage_guide.md)
|
||||
* [批量推理工具][docs/source/tutorials/predict.md]
|
||||
|
||||
## 模型库
|
||||
|
||||
|
186
docs/source/tutorials/predict.md
Normal file
186
docs/source/tutorials/predict.md
Normal file
@ -0,0 +1,186 @@
|
||||
# 批量推理
|
||||
EasyCV提供了工具支持本地大规模图片推理能力,该工具支持读取本地图片、图片http链接、MaxCompute表数据,使用EasyCV提供的各类Predictor进行预测。
|
||||
|
||||
## 依赖安装
|
||||
安装easy_predict, easy_predict把图片预测过程中的数据读取/下载、图片解码、模型推理各个部分抽象成了独立的处理单元,每个处理单元支持多线程并发执行,能够大大加速任务整体的吞吐量。
|
||||
```
|
||||
pip install https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/easy_predict-0.4.2-py2.py3-none-any.whl
|
||||
```
|
||||
|
||||
|
||||
## 数据格式
|
||||
### 输入文件列表
|
||||
当输入为一个文件时,文件每行可以是一个本地图片路径,也可以是图片url地址
|
||||
|
||||
本地文件路径
|
||||
```shell
|
||||
predict/test_data/000000289059.jpg
|
||||
predict/test_data/000000289059.jpg
|
||||
predict/test_data/000000289059.jpg
|
||||
predict/test_data/000000289059.jpg
|
||||
predict/test_data/000000289059.jpg
|
||||
predict/test_data/000000289059.jpg
|
||||
predict/test_data/000000289059.jpg
|
||||
predict/test_data/000000289059.jpg
|
||||
predict/test_data/000000289059.jpg
|
||||
predict/test_data/000000289059.jpg
|
||||
predict/test_data/000000289059.jpg
|
||||
predict/test_data/000000289059.jpg
|
||||
predict/test_data/000000289059.jpg
|
||||
predict/test_data/000000289059.jpg
|
||||
predict/test_data/000000289059.jpg
|
||||
```
|
||||
|
||||
图片url
|
||||
```shell
|
||||
https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/ant%2Bhill_14_33.jpg
|
||||
https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/ant%2Bhill_14_33.jpg
|
||||
https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/ant%2Bhill_14_33.jpg
|
||||
https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/ant%2Bhill_14_33.jpg
|
||||
https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/ant%2Bhill_14_33.jpg
|
||||
https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/ant%2Bhill_14_33.jpg
|
||||
```
|
||||
|
||||
|
||||
### 输入MaxCompute Table
|
||||
输入表可以是一列或者多列, 其中一列需要是图像的url或者图像文件的二进制数据经过base64编码后的字符串(image_base64)
|
||||
|
||||
输入表schema示例如下:
|
||||
url数据
|
||||
```shell
|
||||
+------------------------------------------------------------------------------------+
|
||||
| Field | Type | Label | Comment |
|
||||
+------------------------------------------------------------------------------------+
|
||||
| id | string | | |
|
||||
| url | string | | |
|
||||
+------------------------------------------------------------------------------------+
|
||||
```
|
||||
|
||||
base64数据
|
||||
输入表可以是一列或者多列, 其中一列需要是图像的url或者图像编码后的二进制数据经过base64编码的数据,type为str
|
||||
schema示例如下
|
||||
```shell
|
||||
+------------------------------------------------------------------------------------+
|
||||
| Field | Type | Label | Comment |
|
||||
+------------------------------------------------------------------------------------+
|
||||
| id | string | | |
|
||||
| base64 | string | | |
|
||||
+------------------------------------------------------------------------------------+
|
||||
```
|
||||
|
||||
|
||||
## 运行
|
||||
|
||||
### 读取本地文件
|
||||
|
||||
单卡运行
|
||||
```shell
|
||||
PYTHONPATH=. python tools/predict.py \
|
||||
--input_file predict/test.list \
|
||||
--output_file predict/output.txt \
|
||||
--model_type YoloXPredictor \
|
||||
--model_path predict/test_data/yolox/epoch_300.pt
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>参数说明</summary>
|
||||
|
||||
- `input_file`: 输入文件路径
|
||||
|
||||
- `output_file`: 输出文件路径
|
||||
|
||||
- `model_type`: 模型类型, 对应easycv/predictors/下的不同Predictor类名, 例如YoloXPredictor
|
||||
|
||||
- `model_path`: 模型文件路径/模型目录
|
||||
</details>
|
||||
|
||||
多机多卡运行
|
||||
|
||||
这里多机多卡启动方式复用pytorch DDP方式, 需要在GPU环境下使用
|
||||
```shell
|
||||
PYTHONPATH=. python -m torch.distributed.launch --nproc_per_node=2 --master_port=29527 \
|
||||
tools/predict.py \
|
||||
--input_file predict/test.list \
|
||||
--output_file predict/output.txt \
|
||||
--model_type YoloXPredictor \
|
||||
--model_path predict/test_data/yolox/epoch_300.pt\
|
||||
--launcher pytorch
|
||||
```
|
||||
<details>
|
||||
<summary>参数说明</summary>
|
||||
|
||||
- `nproc_per_node`: 每个节点的gpu数
|
||||
|
||||
- `master_port`: master节点端口
|
||||
|
||||
- `master_addr`: master IP
|
||||
|
||||
- `input_file`: 输入文件路径
|
||||
|
||||
- `output_file`: 输出文件路径
|
||||
|
||||
- `model_type`: 模型类型, 对应easycv/predictors/下的不同Predictor类名, 例如YoloXPredictor
|
||||
|
||||
- `model_path`: 模型文件路径/模型目录
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
### 读取MaxComputeTable
|
||||
|
||||
单卡示例
|
||||
```shell
|
||||
#创建输出表分区
|
||||
odpscmd -e "alter table 表名 add partition (ds=分区名);"
|
||||
|
||||
PYTHONPATH=. python tools/predict.py \
|
||||
--model_type YoloXPredictor \
|
||||
--model_path predict/test_data/yolox/epoch_300.pt \
|
||||
--input_table odps://项目名/tables/表名/ds=分区信息 \
|
||||
--output_table odps://项目名/tables/表名/ds=分区信息\
|
||||
--image_col url\
|
||||
--image_type url\
|
||||
--reserved_columns id\
|
||||
--result_column result \
|
||||
--odps_config /path/to/odps.config
|
||||
```
|
||||
|
||||
<details>
|
||||
- `model_type`: 模型类型, 对应easycv/predictors/下的不同Predictor类名, 例如YoloXPredictor
|
||||
|
||||
- `model_path`: 模型文件路径/模型目录
|
||||
|
||||
- `input_table`: 输入表
|
||||
|
||||
- `output_table`: 输出表
|
||||
|
||||
- `image_col`: 图片数据所在列
|
||||
|
||||
- `image_type`: 图片类型, url or base64
|
||||
- `reserved_columns`: 输入表保留列名,英文逗号分割
|
||||
- `result_column`: 结果列名
|
||||
- `odps_config`: MaxCompute配置文件
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
多卡示例
|
||||
|
||||
这里多机多卡启动方式复用pytorch DDP方式, 需要在GPU环境下使用
|
||||
```shell
|
||||
#创建输出表分区
|
||||
odpscmd -e "alter table 表名 add partition (ds=分区名);"
|
||||
|
||||
PYTHONPATH=. python -m torch.distributed.launch --nproc_per_node=3 --master_port=29527 \
|
||||
tools/predict.py \
|
||||
--model_type YoloXPredictor \
|
||||
--model_path predict/test_data/yolox/epoch_300.pt \
|
||||
--input_table odps://项目名/tables/表名/ds=分区名 \
|
||||
--output_table odps://项目名/tables/表名/ds=分区名\
|
||||
--image_col url\
|
||||
--image_type url\
|
||||
--reserved_columns id\
|
||||
--result_column result \
|
||||
--odps_config /path/to/odps.config \
|
||||
--launcher pytorch
|
||||
```
|
@ -332,7 +332,8 @@ class TorchYoloXPredictor(YoloXPredictor):
|
||||
model_config: config string for model to init, in json format
|
||||
"""
|
||||
if model_config:
|
||||
model_config = json.loads(model_config)
|
||||
if isinstance(model_config, str):
|
||||
model_config = json.loads(model_config)
|
||||
else:
|
||||
model_config = {}
|
||||
|
||||
|
80
tests/tools/test_predict.py
Normal file
80
tests/tools/test_predict.py
Normal file
@ -0,0 +1,80 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import copy
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from mmcv import Config
|
||||
from tests.ut_config import (PRETRAINED_MODEL_SEGFORMER,
|
||||
PRETRAINED_MODEL_YOLOXS_EXPORT, TEST_IMAGES_DIR)
|
||||
|
||||
from easycv.file import io
|
||||
from easycv.utils.test_util import run_in_subprocess
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
class PredictTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
|
||||
def _base_predict(self, model_type, model_path, dist=False):
|
||||
input_file = tempfile.NamedTemporaryFile('w').name
|
||||
input_line_num = 10
|
||||
with open(input_file, 'w') as ofile:
|
||||
for _ in range(input_line_num):
|
||||
ofile.write(
|
||||
os.path.join(TEST_IMAGES_DIR, '000000289059.jpg') + '\n')
|
||||
output_file = tempfile.NamedTemporaryFile('w').name
|
||||
|
||||
if dist:
|
||||
cmd = f'PYTHONPATH=. python -m torch.distributed.launch --nproc_per_node=2 --master_port=29527 \
|
||||
tools/predict.py \
|
||||
--input_file {input_file} \
|
||||
--output_file {output_file} \
|
||||
--model_type {model_type} \
|
||||
--model_path {model_path} \
|
||||
--launcher pytorch'
|
||||
|
||||
else:
|
||||
cmd = f'PYTHONPATH=. python tools/predict.py \
|
||||
--input_file {input_file} \
|
||||
--output_file {output_file} \
|
||||
--model_type {model_type} \
|
||||
--model_path {model_path} '
|
||||
|
||||
logging.info('run command: %s' % cmd)
|
||||
run_in_subprocess(cmd)
|
||||
|
||||
with open(output_file, 'r') as infile:
|
||||
output_line_num = len(infile.readlines())
|
||||
self.assertEqual(input_line_num, output_line_num)
|
||||
|
||||
io.remove(input_file)
|
||||
io.remove(output_file)
|
||||
|
||||
def test_predict(self):
|
||||
model_type = 'YoloXPredictor'
|
||||
model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
|
||||
self._base_predict(model_type, model_path)
|
||||
|
||||
@unittest.skipIf(torch.cuda.device_count() <= 1, 'distributed unittest')
|
||||
def test_predict_dist(self):
|
||||
model_type = 'YoloXPredictor'
|
||||
model_path = PRETRAINED_MODEL_YOLOXS_EXPORT
|
||||
self._base_predict(model_type, model_path, dist=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
9
tools/dist_predict.sh
Executable file
9
tools/dist_predict.sh
Executable file
@ -0,0 +1,9 @@
|
||||
#!/usr/bin/env bash
|
||||
PYTHON=${PYTHON:-"python"}
|
||||
|
||||
GPUS=$1
|
||||
PY_ARGS=${@:2}
|
||||
PORT=${PORT:-29527}
|
||||
|
||||
PYTHONPATH=./ $PYTHON -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
|
||||
tools/predict.py $PY_ARGS --launcher pytorch \
|
456
tools/predict.py
Normal file
456
tools/predict.py
Normal file
@ -0,0 +1,456 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
# isort:skip_file
|
||||
import argparse
|
||||
import copy
|
||||
import functools
|
||||
import glob
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
import easy_predict
|
||||
except ModuleNotFoundError:
|
||||
print('please install easy_predict first using following instruction')
|
||||
print(
|
||||
'pip install https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/easy_predict-0.4.2-py2.py3-none-any.whl'
|
||||
)
|
||||
exit()
|
||||
|
||||
from easy_predict import (Base64DecodeProcess, DataFields,
|
||||
DefaultResultFormatProcess, DownloadProcess,
|
||||
FileReadProcess, FileWriteProcess, Process,
|
||||
ProcessExecutor, ResultGatherProcess,
|
||||
TableReadProcess, TableWriteProcess)
|
||||
from mmcv.runner import init_dist
|
||||
|
||||
from easycv.utils.dist_utils import get_dist_info
|
||||
from easycv.utils.logger import get_root_logger
|
||||
|
||||
|
||||
def define_args():
|
||||
parser = argparse.ArgumentParser('easycv prediction')
|
||||
parser.add_argument(
|
||||
'--model_type',
|
||||
default='',
|
||||
help='model type, classifier/detector/segmentor/yolox')
|
||||
parser.add_argument('--model_path', default='', help='path to model')
|
||||
parser.add_argument(
|
||||
'--model_config',
|
||||
default='',
|
||||
help='model config str, predictor v1 param')
|
||||
|
||||
# oss input output
|
||||
parser.add_argument(
|
||||
'--input_file',
|
||||
default='',
|
||||
help='filelist for images, eash line is a oss path or a local path')
|
||||
parser.add_argument(
|
||||
'--output_file',
|
||||
default='',
|
||||
help='oss file or local file to save predict info')
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
default='',
|
||||
help='output_directory to save image and video results')
|
||||
parser.add_argument(
|
||||
'--oss_prefix',
|
||||
default='',
|
||||
help='oss_prefix will be replaced with local_prefix in input_file')
|
||||
parser.add_argument(
|
||||
'--local_prefix',
|
||||
default='',
|
||||
help='oss_prefix will be replaced with local_prefix in input_file')
|
||||
|
||||
# table input output
|
||||
parser.add_argument('--input_table', default='', help='input table name')
|
||||
parser.add_argument('--output_table', default='', help='output table name')
|
||||
parser.add_argument('--image_col', default='', help='input image column')
|
||||
parser.add_argument(
|
||||
'--reserved_columns',
|
||||
default='',
|
||||
help=
|
||||
'columns from input table to be saved to output table, comma seperated'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--result_column',
|
||||
default='',
|
||||
help='result columns to be saved to output table, comma seperated')
|
||||
parser.add_argument(
|
||||
'--odps_config',
|
||||
default='./odps.config',
|
||||
help='path to your odps config file')
|
||||
parser.add_argument(
|
||||
'--image_type', default='url', help='image data type, url or base64')
|
||||
|
||||
# common args
|
||||
parser.add_argument(
|
||||
'--queue_size',
|
||||
type=int,
|
||||
default=1024,
|
||||
help='length of queues used for each process')
|
||||
parser.add_argument(
|
||||
'--predict_thread_num',
|
||||
type=int,
|
||||
default=1,
|
||||
help='number of threads used for prediction')
|
||||
parser.add_argument(
|
||||
'--preprocess_thread_num',
|
||||
type=int,
|
||||
default=1,
|
||||
help='number of threads used for preprocessing and downloading')
|
||||
parser.add_argument(
|
||||
'--batch_size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='batch size used for prediction')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
parser.add_argument(
|
||||
'--launcher',
|
||||
type=str,
|
||||
choices=[None, 'pytorch'],
|
||||
help='if assigned pytorch, should be used in gpu environment')
|
||||
args = parser.parse_args()
|
||||
if 'LOCAL_RANK' not in os.environ:
|
||||
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
class PredictorProcess(Process):
|
||||
|
||||
def __init__(self,
|
||||
predict_fn,
|
||||
batch_size,
|
||||
thread_num,
|
||||
local_rank=0,
|
||||
input_queue=None,
|
||||
output_queue=None):
|
||||
job_name = 'Predictor'
|
||||
if torch.cuda.is_available():
|
||||
thread_init_fn = functools.partial(torch.cuda.set_device,
|
||||
local_rank)
|
||||
else:
|
||||
thread_init_fn = None
|
||||
super(PredictorProcess, self).__init__(
|
||||
job_name,
|
||||
thread_num,
|
||||
input_queue,
|
||||
output_queue,
|
||||
batch_size=batch_size,
|
||||
thread_init_fn=thread_init_fn)
|
||||
self.predict_fn = predict_fn
|
||||
self.data_lock = threading.Lock()
|
||||
self.all_data_failed = True
|
||||
self.input_empty = True
|
||||
self.local_rank = 0
|
||||
|
||||
def process(self, input_batch):
|
||||
"""
|
||||
Read a batch of image from input_queue and predict
|
||||
|
||||
Args:
|
||||
input_batch: a batch of input data
|
||||
Returns:
|
||||
output_queue: unstak batch, push input data and prediction result into queue
|
||||
"""
|
||||
valid_input = []
|
||||
valid_indices = []
|
||||
valid_frame_ids = []
|
||||
if self.batch_size == 1:
|
||||
input_batch = [input_batch]
|
||||
if self.input_empty and len(input_batch) > 0:
|
||||
self.data_lock.acquire()
|
||||
self.input_empty = False
|
||||
self.data_lock.release()
|
||||
|
||||
output_data_list = input_batch
|
||||
for out in output_data_list:
|
||||
out[DataFields.prediction_result] = None
|
||||
|
||||
for idx, input_data in enumerate(input_batch):
|
||||
if DataFields.image in input_data \
|
||||
and input_data[DataFields.image] is not None:
|
||||
valid_input.append(input_data[DataFields.image])
|
||||
valid_indices.append(idx)
|
||||
|
||||
if len(valid_input) > 0:
|
||||
try:
|
||||
# flatten video_clip to images, use image predictor to predict
|
||||
# then regroup the result to a list for one video_clip
|
||||
output_list = self.predict_fn(valid_input)
|
||||
|
||||
if len(output_list) > 0:
|
||||
assert isinstance(output_list[0], dict), \
|
||||
'the element in predictor output must be a dict'
|
||||
|
||||
if self.all_data_failed:
|
||||
self.data_lock.acquire()
|
||||
self.all_data_failed = False
|
||||
self.data_lock.release()
|
||||
|
||||
except Exception:
|
||||
logging.error(traceback.format_exc())
|
||||
output_list = [None for i in range(len(valid_input))]
|
||||
|
||||
for idx, result_dict in zip(valid_indices, output_list):
|
||||
output_data = output_data_list[idx]
|
||||
output_data[DataFields.prediction_result] = result_dict
|
||||
if result_dict is None:
|
||||
output_data[DataFields.error_msg] = 'prediction error'
|
||||
|
||||
output_data_list[idx] = output_data
|
||||
|
||||
for output_data in output_data_list:
|
||||
self.put(output_data)
|
||||
|
||||
def destroy(self):
|
||||
if not self.input_empty and self.all_data_failed:
|
||||
raise RuntimeError(
|
||||
'failed to predict all the input data, please see exception throwed above in the log'
|
||||
)
|
||||
|
||||
|
||||
def init_predictor(args):
|
||||
model_type = args.model_type
|
||||
model_path = args.model_path
|
||||
batch_size = args.batch_size
|
||||
from easycv.predictors.builder import build_predictor
|
||||
|
||||
ori_model_path = model_path
|
||||
if not ori_model_path.endswith('pth') and not ori_model_path.endswith(
|
||||
'pt'):
|
||||
model_path = glob.glob('%s/**/*.pt*' % ori_model_path, recursive=True)
|
||||
assert len(model_path) > 0, f'model not found in {ori_model_path}'
|
||||
assert len(
|
||||
model_path) == 1, f'more than one model file is found {model_path}'
|
||||
model_path = model_path[0]
|
||||
logging.info(f'model found: {model_path}')
|
||||
|
||||
config_path = glob.glob('%s/**/*.py' % ori_model_path, recursive=True)
|
||||
if len(config_path) == 0:
|
||||
config_path = None
|
||||
else:
|
||||
assert len(
|
||||
config_path
|
||||
) == 1, f'more than one config file is found {config_path}'
|
||||
config_path = config_path[0]
|
||||
logging.info(f'config found: {config_path}')
|
||||
else:
|
||||
model_path = ori_model_path
|
||||
config_path = None
|
||||
|
||||
predictor_cfg = dict(type=model_type, model_path=model_path)
|
||||
if config_path:
|
||||
predictor_cfg['config_file'] = config_path
|
||||
if args.model_config != '':
|
||||
predictor_cfg['model_config'] = args.model_config
|
||||
predictor = build_predictor(predictor_cfg)
|
||||
return predictor
|
||||
|
||||
|
||||
def replace_oss_with_local_path(ori_file, dst_file, bucket_prefix,
|
||||
local_prefix):
|
||||
bucket_prefix = bucket_prefix.rstrip('/') + '/'
|
||||
local_prefix = local_prefix.rstrip('/') + '/'
|
||||
with open(ori_file, 'r') as infile:
|
||||
with open(dst_file, 'w') as ofile:
|
||||
for l in infile:
|
||||
if l.startswith('oss://'):
|
||||
l = l.replace(bucket_prefix, local_prefix)
|
||||
ofile.write(l)
|
||||
|
||||
|
||||
def build_and_run_file_io(args):
|
||||
# distribute info
|
||||
rank, world_size = get_dist_info()
|
||||
worker_id = rank
|
||||
|
||||
input_oss_file_new_host = args.input_file + '.tmp%d' % worker_id
|
||||
replace_oss_with_local_path(args.input_file, input_oss_file_new_host,
|
||||
args.oss_prefix, args.local_prefix)
|
||||
args.input_file = input_oss_file_new_host
|
||||
num_worker = world_size
|
||||
print(f'worker num {num_worker}')
|
||||
print(f'worker_id {worker_id}')
|
||||
batch_size = args.batch_size
|
||||
print(f'Local rank {args.local_rank}')
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
predictor = init_predictor(args)
|
||||
predict_fn = predictor.__call__ if hasattr(
|
||||
predictor, '__call__') else predictor.predict
|
||||
# create proc executor
|
||||
proc_exec = ProcessExecutor(args.queue_size)
|
||||
|
||||
# create oss read process to read file path from filelist
|
||||
proc_exec.add(
|
||||
FileReadProcess(
|
||||
args.input_file,
|
||||
slice_id=worker_id,
|
||||
slice_count=num_worker,
|
||||
output_queue=proc_exec.get_output_queue()))
|
||||
|
||||
# download and decode image data
|
||||
proc_exec.add(
|
||||
DownloadProcess(
|
||||
thread_num=args.predict_thread_num,
|
||||
input_queue=proc_exec.get_input_queue(),
|
||||
output_queue=proc_exec.get_output_queue(),
|
||||
is_video_url=False))
|
||||
|
||||
# transform image data
|
||||
proc_exec.add(
|
||||
PredictorProcess(
|
||||
input_queue=proc_exec.get_input_queue(),
|
||||
output_queue=proc_exec.get_output_queue(),
|
||||
predict_fn=predict_fn,
|
||||
batch_size=batch_size,
|
||||
local_rank=args.local_rank,
|
||||
thread_num=args.predict_thread_num))
|
||||
|
||||
proc_exec.add(
|
||||
DefaultResultFormatProcess(
|
||||
input_queue=proc_exec.get_input_queue(),
|
||||
output_queue=proc_exec.get_output_queue()))
|
||||
|
||||
# Gather result to different dict of different type
|
||||
proc_exec.add(
|
||||
ResultGatherProcess(
|
||||
output_type_dict={},
|
||||
input_queue=proc_exec.get_input_queue(),
|
||||
output_queue=proc_exec.get_output_queue()))
|
||||
|
||||
# Write result
|
||||
proc_exec.add(
|
||||
FileWriteProcess(
|
||||
output_file=args.output_file,
|
||||
output_dir=args.output_dir,
|
||||
slice_id=worker_id,
|
||||
slice_count=num_worker,
|
||||
input_queue=proc_exec.get_input_queue()))
|
||||
|
||||
proc_exec.run()
|
||||
proc_exec.wait()
|
||||
|
||||
|
||||
def build_and_run_table_io(args):
|
||||
os.environ['ODPS_CONFIG_FILE_PATH'] = args.odps_config
|
||||
|
||||
rank, world_size = get_dist_info()
|
||||
worker_id = rank
|
||||
num_worker = world_size
|
||||
print(f'worker num {num_worker}')
|
||||
print(f'worker_id {worker_id}')
|
||||
|
||||
batch_size = args.batch_size
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
predictor = init_predictor(args)
|
||||
predict_fn = predictor.__call__ if hasattr(
|
||||
predictor, '__call__') else predictor.predict
|
||||
# batch size should be less than the total number of data in input table
|
||||
table_read_batch_size = 1
|
||||
table_read_thread_num = 4
|
||||
|
||||
# create proc executor
|
||||
proc_exec = ProcessExecutor(args.queue_size)
|
||||
|
||||
# create oss read process to read file path from filelist
|
||||
selected_cols = list(
|
||||
set(args.image_col.split(',') + args.reserved_columns.split(',')))
|
||||
if args.image_col not in selected_cols:
|
||||
selected_cols.append(args.image_col)
|
||||
image_col_idx = selected_cols.index(args.image_col)
|
||||
proc_exec.add(
|
||||
TableReadProcess(
|
||||
args.input_table,
|
||||
selected_cols=selected_cols,
|
||||
slice_id=worker_id,
|
||||
slice_count=num_worker,
|
||||
output_queue=proc_exec.get_output_queue(),
|
||||
image_col_idx=image_col_idx,
|
||||
image_type=args.image_type,
|
||||
batch_size=table_read_batch_size,
|
||||
num_threads=table_read_thread_num))
|
||||
|
||||
if args.image_type == 'base64':
|
||||
base64_thread_num = args.preprocess_thread_num
|
||||
proc_exec.add(
|
||||
Base64DecodeProcess(
|
||||
thread_num=base64_thread_num,
|
||||
input_queue=proc_exec.get_input_queue(),
|
||||
output_queue=proc_exec.get_output_queue()))
|
||||
elif args.image_type == 'url':
|
||||
download_thread_num = args.preprocess_thread_num
|
||||
proc_exec.add(
|
||||
DownloadProcess(
|
||||
thread_num=download_thread_num,
|
||||
input_queue=proc_exec.get_input_queue(),
|
||||
output_queue=proc_exec.get_output_queue(),
|
||||
use_pil_decode=False))
|
||||
|
||||
# transform image data
|
||||
proc_exec.add(
|
||||
PredictorProcess(
|
||||
input_queue=proc_exec.get_input_queue(),
|
||||
output_queue=proc_exec.get_output_queue(),
|
||||
predict_fn=predict_fn,
|
||||
batch_size=batch_size,
|
||||
local_rank=args.local_rank,
|
||||
thread_num=args.predict_thread_num))
|
||||
|
||||
proc_exec.add(
|
||||
DefaultResultFormatProcess(
|
||||
input_queue=proc_exec.get_input_queue(),
|
||||
output_queue=proc_exec.get_output_queue(),
|
||||
reserved_col_names=args.reserved_columns.split(','),
|
||||
output_col_names=args.result_column.split(',')))
|
||||
|
||||
# Write result
|
||||
output_cols = args.reserved_columns.split(',') + args.result_column.split(
|
||||
',')
|
||||
proc_exec.add(
|
||||
TableWriteProcess(
|
||||
args.output_table,
|
||||
output_col_names=output_cols,
|
||||
slice_id=worker_id,
|
||||
input_queue=proc_exec.get_input_queue()))
|
||||
|
||||
proc_exec.run()
|
||||
proc_exec.wait()
|
||||
|
||||
|
||||
def check_args(args, arg_name, default_value=''):
|
||||
assert getattr(args, arg_name) != '', f'{arg_name} should not be empty'
|
||||
|
||||
|
||||
def patch_logging():
|
||||
# after get_root_logger, logging will not take effect because
|
||||
# it sets all other handler to level logging.INFO
|
||||
logger = get_root_logger()
|
||||
for handler in logger.root.handlers:
|
||||
if type(handler) is logging.StreamHandler:
|
||||
handler.setLevel(logging.INFO)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = define_args()
|
||||
patch_logging()
|
||||
if args.launcher:
|
||||
init_dist(args.launcher, backend='nccl')
|
||||
if args.input_file != '':
|
||||
check_args(args, 'output_file')
|
||||
build_and_run_file_io(args)
|
||||
else:
|
||||
check_args(args, 'input_table')
|
||||
check_args(args, 'output_table')
|
||||
check_args(args, 'image_col')
|
||||
check_args(args, 'reserved_columns')
|
||||
check_args(args, 'result_column')
|
||||
build_and_run_table_io(args)
|
Loading…
x
Reference in New Issue
Block a user