mirror of https://github.com/JDAI-CV/fast-reid.git
update deployment toolchain (#428)
Summary: Remove the tiny-tensorrt dependency and rewrite the TensorRT inference API. The new TRT inference code can pad the input to the fixed batch size automatically, so you don't need to worry about dynamic batch sizes.
pull/429/head
parent d7c1294d9e
commit cb7a1cb3e1
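For context, the auto-padding amounts to zero-filling the incoming batch up to the engine's fixed batch size and slicing the real rows back out of the output. A minimal numpy sketch of the idea (the helper name is hypothetical; the actual logic lives in `TrtEngine.inference_on_images` in `trt_inference.py` below):

```python
import numpy as np

def pad_to_fixed_batch(batch, fixed_bsz):
    # Zero-pad an (N, C, H, W) batch up to the engine's fixed batch size and
    # remember how many rows are real so the caller can slice them back out.
    valid_bsz = batch.shape[0]
    if valid_bsz < fixed_bsz:
        pad = np.zeros((fixed_bsz - valid_bsz, *batch.shape[1:]), dtype=batch.dtype)
        batch = np.vstack([batch, pad])
    return batch, valid_bsz

padded, valid_bsz = pad_to_fixed_batch(np.random.rand(3, 3, 256, 128).astype(np.float32), 8)
# run `padded` through the engine, then keep only out[:valid_bsz]
```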
@ -18,32 +18,14 @@ This is a tiny example for converting fastreid-baseline in `meta_arch` to Caffe

1. Run `caffe_export.py` to get the converted Caffe model,

```bash
python caffe_export.py --config-file root-path/market1501/bagtricks_R50/config.yml --name "baseline_R50" --output outputs/caffe_model --opts MODEL.WEIGHTS root-path/logs/market1501/bagtricks_R50/model_final.pth
python caffe_export.py --config-file root-path/market1501/bagtricks_R50/config.yml --name baseline_R50 --output outputs/caffe_model --opts MODEL.WEIGHTS root-path/logs/market1501/bagtricks_R50/model_final.pth
```

then you can check the Caffe model and prototxt in `outputs/caffe_model`.
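If pycaffe is available, a quick way to confirm the exported files load is a bare `caffe.Net` call (a hedged sketch; the exact file names depend on the `--name` you passed, assumed here to be `baseline_R50`):

```python
import caffe  # requires a working pycaffe build

# Load the exported prototxt + caffemodel in test mode to make sure they parse.
net = caffe.Net(
    "outputs/caffe_model/baseline_R50.prototxt",
    "outputs/caffe_model/baseline_R50.caffemodel",
    caffe.TEST,
)
print({name: blob.data.shape for name, blob in net.blobs.items()})
```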
2. Change `prototxt` following next three steps:
2. Change `prototxt` following next two steps:

1) Edit `max_pooling` in `baseline_R50.prototxt` like this

```prototxt
layer {
  name: "max_pool1"
  type: "Pooling"
  bottom: "relu_blob1"
  top: "max_pool_blob1"
  pooling_param {
    pool: MAX
    kernel_size: 3
    stride: 2
    pad: 0 # 1
    # ceil_mode: false
  }
}
```

2) Add `avg_pooling` right place in `baseline_R50.prototxt`
1) Modify `avg_pooling` in `baseline_R50.prototxt`

```prototxt
layer {
@ -58,7 +40,7 @@ This is a tiny example for converting fastreid-baseline in `meta_arch` to Caffe
}
```

3) Change the last layer `top` name to `output`
2) Change the last layer `top` name to `output`

```prototxt
layer {

@ -100,7 +82,7 @@ This is a tiny example for converting fastreid-baseline in `meta_arch` to ONNX m

1. Run `onnx_export.py` to get the converted ONNX model,

```bash
python onnx_export.py --config-file root-path/bagtricks_R50/config.yml --name "baseline_R50" --output outputs/onnx_model --opts MODEL.WEIGHTS root-path/logs/market1501/bagtricks_R50/model_final.pth
python onnx_export.py --config-file root-path/bagtricks_R50/config.yml --name baseline_R50 --output outputs/onnx_model --opts MODEL.WEIGHTS root-path/logs/market1501/bagtricks_R50/model_final.pth
```

then you can check the ONNX model in `outputs/onnx_model`.
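As a quick sanity check on the exported file, you can run a dummy batch through onnxruntime (a sketch, assuming the `onnx` and `onnxruntime` packages are installed and that the file name follows your `--name` argument):

```python
import numpy as np
import onnx
import onnxruntime as ort

model_path = "outputs/onnx_model/baseline_R50.onnx"  # adjust to the --name you used

# structural check of the exported graph
onnx.checker.check_model(onnx.load(model_path))

# run one dummy batch through onnxruntime
sess = ort.InferenceSession(model_path)
inp = sess.get_inputs()[0]
dummy = np.random.rand(1, 3, 256, 128).astype(np.float32)
outputs = sess.run(None, {inp.name: dummy})
print(inp.name, "->", [o.shape for o in outputs])
```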
@ -127,15 +109,16 @@ This is a tiny example for converting fastreid-baseline in `meta_arch` to ONNX m

<details>
<summary>step-by-step pipeline for trt convert</summary>

This is a tiny example for converting fastreid-baseline in `meta_arch` to TRT model. We use [tiny-tensorrt](https://github.com/zerollzeng/tiny-tensorrt) which is a simple and easy-to-use nvidia TensorRt warpper, to get the model converted to tensorRT.
This is a tiny example for converting fastreid-baseline in `meta_arch` to TRT model.

First you need to convert the pytorch model to ONNX format following [ONNX Convert](https://github.com/JDAI-CV/fast-reid#fastreid), and you need to remember your `output` name. Then you can convert ONNX model to TensorRT following instructions below.
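If you are not sure what the `output` name is, you can read it straight from the ONNX graph (a small sketch, assuming the `onnx` package is installed and the file path from the previous step):

```python
import onnx

# path produced by onnx_export.py in the previous step
model = onnx.load("outputs/onnx_model/baseline.onnx")
# these are the output tensor names you will need for the TRT steps
print([out.name for out in model.graph.output])
```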
1. Run command line below to get the converted TRT model from ONNX model,

```bash
python trt_export.py --name "baseline_R50" --output outputs/trt_model --onnx-model outputs/onnx_model/baseline.onnx --heighi 256 --width 128
python trt_export.py --name baseline_R50 --output outputs/trt_model \
--mode fp32 --batch-size 8 --height 256 --width 128 \
--onnx-model outputs/onnx_model/baseline.onnx
```

then you can check the TRT model in `outputs/trt_model`.
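To confirm the serialized engine deserializes cleanly, a minimal load test looks like this (a sketch using the TensorRT 7.x-style Python API that the scripts in this commit target; the file name follows the `--name` you passed to `trt_export.py`):

```python
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
# file name is assumed to follow the --name argument (baseline_R50 here)
with open("outputs/trt_model/baseline_R50.engine", "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
assert engine is not None, "engine failed to deserialize"
print("engine loaded, max_batch_size =", engine.max_batch_size)
```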
@ -143,16 +126,18 @@ First you need to convert the pytorch model to ONNX format following [ONNX Conve

2. Run `trt_inference.py` to save TRT model features with input images

```bash
python onnx_inference.py --model-path outputs/trt_model/baseline.engine \
--input test_data/*.jpg --output trt_output --output-name trt_model_outputname
python3 trt_inference.py --model-path outputs/trt_model/baseline.engine \
--input test_data/*.jpg --batch-size 8 --height 256 --width 128 --output trt_output
```
3. Run `demo/demo.py` to get fastreid model features with the same input images, then verify that TensorRT and PyTorch are computing the same value for the network.

```python
np.testing.assert_allclose(torch_out, ort_out, rtol=1e-3, atol=1e-6)
np.testing.assert_allclose(torch_out, trt_out, rtol=1e-3, atol=1e-6)
```
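In practice `torch_out` and `trt_out` are the saved per-image features from the two runs; a small comparison sketch (file names and output directories here are illustrative, `trt_inference.py` writes one `<image name>.npy` per input image):

```python
import numpy as np

# illustrative paths: one feature file per image from each pipeline
trt_out = np.load("trt_output/0001_c1s1_001051_00.jpg.npy")      # from trt_inference.py
torch_out = np.load("demo_output/0001_c1s1_001051_00.jpg.npy")   # feature saved from demo/demo.py
np.testing.assert_allclose(torch_out, trt_out, rtol=1e-3, atol=1e-6)
print("TensorRT features match the PyTorch features within tolerance")
```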
Notice: int8 mode is not supported in the TensorRT runtime yet and there are still some bugs in the calibrator. Help wanted!

</details>

## Acknowledgements
@ -19,12 +19,17 @@ from fastreid.utils.file_io import PathManager
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils.logger import setup_logger

# import some modules added in project like this below
# sys.path.append('../projects/FastCls')
# from fastcls import *

setup_logger(name='fastreid')
logger = logging.getLogger("fastreid.caffe_export")


def setup_cfg(args):
    cfg = get_cfg()
    # add_cls_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

@ -64,7 +69,8 @@ if __name__ == '__main__':

    cfg.defrost()
    cfg.MODEL.BACKBONE.PRETRAIN = False
    cfg.MODEL.HEADS.POOL_LAYER = "identity"
    if cfg.MODEL.HEADS.POOL_LAYER == 'fastavgpool':
        cfg.MODEL.HEADS.POOL_LAYER = 'avgpool'
    cfg.MODEL.BACKBONE.WITH_NL = False

    model = build_model(cfg)
@ -21,6 +21,10 @@ from fastreid.utils.file_io import PathManager
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils.logger import setup_logger

# import some modules added in project like this below
# sys.path.append('../../projects/FastDistill')
# from fastdistill import *

logger = setup_logger(name='onnx_export')


@ -50,6 +54,12 @@ def get_parser():
        default='onnx_model',
        help='path to save converted onnx model'
    )
    parser.add_argument(
        '--batch-size',
        default=1,
        type=int,
        help="the maximum batch size of onnx runtime"
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",

@ -130,7 +140,7 @@ if __name__ == '__main__':
    model.eval()
    logger.info(model)

    inputs = torch.randn(1, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1])
    inputs = torch.randn(args.batch_size, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1])
    onnx_model = export_onnx_model(model, inputs)

    model_simp, check = simplify(onnx_model)
@ -1,4 +0,0 @@
python3 caffe_export.py --config-file /export/home/lxy/cvpalgo-fast-reid/projects/bjzProject/configs/r34.yml \
--name "r34-0603" \
--output logs/caffe_r34-0603 \
--opts MODEL.WEIGHTS /export/home/lxy/cvpalgo-fast-reid/logs/bjz/sbs_R34_bjz_0603_8x32/model_final.pth
@ -0,0 +1,102 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com

Create custom calibrator, use to calibrate int8 TensorRT model.
Need to override some methods of trt.IInt8EntropyCalibrator2, such as get_batch_size, get_batch,
read_calibration_cache, write_calibration_cache.
"""

# based on:
# https://github.com/qq995431104/Pytorch2TensorRT/blob/master/myCalibrator.py

import os
import sys

import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit

import numpy as np
import torchvision.transforms as T

sys.path.append('../..')

from fastreid.data.build import _root
from fastreid.data.data_utils import read_image
from fastreid.data.datasets import DATASET_REGISTRY
import logging

from fastreid.data.transforms import ToTensor


logger = logging.getLogger('trt_export.calibrator')


class FeatEntropyCalibrator(trt.IInt8EntropyCalibrator2):

    def __init__(self, args):
        trt.IInt8EntropyCalibrator2.__init__(self)

        self.cache_file = 'reid_feat.cache'

        self.batch_size = args.batch_size
        self.channel = args.channel
        self.height = args.height
        self.width = args.width
        self.transform = T.Compose([
            T.Resize((self.height, self.width), interpolation=3),  # [h,w]
            ToTensor(),
        ])

        dataset = DATASET_REGISTRY.get(args.calib_data)(root=_root)
        self._data_items = dataset.train + dataset.query + dataset.gallery
        np.random.shuffle(self._data_items)
        self.imgs = [item[0] for item in self._data_items]

        self.batch_idx = 0
        self.max_batch_idx = len(self.imgs) // self.batch_size

        self.data_size = self.batch_size * self.channel * self.height * self.width * trt.float32.itemsize
        self.device_input = cuda.mem_alloc(self.data_size)

    def next_batch(self):
        if self.batch_idx < self.max_batch_idx:
            batch_files = self.imgs[self.batch_idx * self.batch_size:(self.batch_idx + 1) * self.batch_size]
            batch_imgs = np.zeros((self.batch_size, self.channel, self.height, self.width),
                                  dtype=np.float32)
            for i, f in enumerate(batch_files):
                img = read_image(f)
                img = self.transform(img).numpy()
                assert (img.nbytes == self.data_size // self.batch_size), 'not valid img!' + f
                batch_imgs[i] = img
            self.batch_idx += 1
            logger.info("batch:[{}/{}]".format(self.batch_idx, self.max_batch_idx))
            return np.ascontiguousarray(batch_imgs)
        else:
            return np.array([])

    def get_batch_size(self):
        return self.batch_size

    def get_batch(self, names, p_str=None):
        try:
            batch_imgs = self.next_batch()
            batch_imgs = batch_imgs.ravel()
            if batch_imgs.size == 0 or batch_imgs.size != self.batch_size * self.channel * self.height * self.width:
                return None
            cuda.memcpy_htod(self.device_input, batch_imgs.astype(np.float32))
            return [int(self.device_input)]
        except:
            return None

    def read_calibration_cache(self):
        # If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None.
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)
@ -6,18 +6,17 @@

import argparse
import os
import numpy as np
import sys

import tensorrt as trt

sys.path.append('../../')
sys.path.append("/export/home/lxy/runtimelib-tensorrt-tiny/build")
from trt_calibrator import FeatEntropyCalibrator

sys.path.append('../../')

import pytrt
from fastreid.utils.logger import setup_logger
from fastreid.utils.file_io import PathManager


logger = setup_logger(name='trt_export')
@ -25,79 +24,103 @@ def get_parser():
    parser = argparse.ArgumentParser(description="Convert ONNX to TRT model")

    parser.add_argument(
        "--name",
        default="baseline",
        '--name',
        default='baseline',
        help="name for converted model"
    )
    parser.add_argument(
        "--output",
        '--output',
        default='outputs/trt_model',
        help='path to save converted trt model'
        help="path to save converted trt model"
    )
    parser.add_argument(
        '--mode',
        default='fp32',
        help="which mode is used in tensorRT engine, mode can be ['fp32', 'fp16' 'int8']"
    )
    parser.add_argument(
        '--batch-size',
        default=1,
        type=int,
        help="the maximum batch size of trt module"
    )
    parser.add_argument(
        '--height',
        default=256,
        type=int,
        help="input image height"
    )
    parser.add_argument(
        '--width',
        default=128,
        type=int,
        help="input image width"
    )
    parser.add_argument(
        '--channel',
        default=3,
        type=int,
        help="input image channel"
    )
    parser.add_argument(
        '--calib-data',
        default='Market1501',
        help="int8 calibrator dataset name"
    )
    parser.add_argument(
        "--onnx-model",
        default='outputs/onnx_model/baseline.onnx',
        help='path to onnx model'
    )
    parser.add_argument(
        "--height",
        type=int,
        default=256,
        help="height of image"
    )
    parser.add_argument(
        "--width",
        type=int,
        default=128,
        help="width of image"
    )
    return parser


def onnx2trt(
        model,
        onnx_file_path,
        save_path,
        mode,
        log_level='ERROR',
        max_batch_size=1,
        max_workspace_size=1,
        fp16_mode=False,
        strict_type_constraints=False,
        int8_mode=False,
        int8_calibrator=None,
):
    """build TensorRT model from onnx model.
    Args:
        model (string or io object): onnx model name
        onnx_file_path (string or io object): onnx model name
        save_path (string): tensortRT serialization save path
        mode (string): Whether or not FP16 or Int8 kernels are permitted during engine build.
        log_level (string, default is ERROR): tensorrt logger level, now
            INTERNAL_ERROR, ERROR, WARNING, INFO, VERBOSE are support.
        max_batch_size (int, default=1): The maximum batch size which can be used at execution time, and also the
            batch size for which the ICudaEngine will be optimized.
        max_workspace_size (int, default is 1): The maximum GPU temporary memory which the ICudaEngine can use at
            execution time. default is 1GB.
        fp16_mode (bool, default is False): Whether or not 16-bit kernels are permitted. During engine build
            fp16 kernels will also be tried when this mode is enabled.
        strict_type_constraints (bool, default is False): When strict type constraints is set, TensorRT will choose
            the type constraints that conforms to type constraints. If the flag is not enabled higher precision
            implementation may be chosen if it results in higher performance.
        int8_mode (bool, default is False): Whether Int8 mode is used.
        int8_calibrator (volksdep.calibrators.base.BaseCalibrator, default is None): calibrator for int8 mode,
            if None, default calibrator will be used as calibration data.
    """
    mode = mode.lower()
    assert mode in ['fp32', 'fp16', 'int8'], "mode should be in ['fp32', 'fp16', 'int8'], " \
                                             "but got {}".format(mode)

    logger = trt.Logger(getattr(trt.Logger, log_level))
    builder = trt.Builder(logger)
    trt_logger = trt.Logger(getattr(trt.Logger, log_level))
    builder = trt.Builder(trt_logger)

    network = builder.create_network(1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)
    if isinstance(model, str):
        with open(model, 'rb') as f:
            logger.info("Loading ONNX file from path {}...".format(onnx_file_path))
    EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(EXPLICIT_BATCH)
    parser = trt.OnnxParser(network, trt_logger)
    if isinstance(onnx_file_path, str):
        with open(onnx_file_path, 'rb') as f:
            logger.info("Beginning ONNX file parsing")
            flag = parser.parse(f.read())
    else:
        flag = parser.parse(model.read())
        flag = parser.parse(onnx_file_path.read())
    if not flag:
        for error in range(parser.num_errors):
            print(parser.get_error(error))
            logger.info(parser.get_error(error))

    logger.info("Completed parsing of ONNX file.")
    # re-order output tensor
    output_tensors = [network.get_output(i) for i in range(network.num_outputs)]
    [network.unmark_output(tensor) for tensor in output_tensors]
@ -106,68 +129,39 @@ def onnx2trt(
        identity_out_tensor.name = 'identity_{}'.format(tensor.name)
        network.mark_output(tensor=identity_out_tensor)

    builder.max_batch_size = max_batch_size

    config = builder.create_builder_config()
    config.max_workspace_size = max_workspace_size * (1 << 25)
    if fp16_mode:
        config.set_flag(trt.BuilderFlag.FP16)
    if mode == 'fp16':
        assert builder.platform_has_fast_fp16, "not support fp16"
        builder.fp16_mode = True
    if mode == 'int8':
        assert builder.platform_has_fast_int8, "not support int8"
        builder.int8_mode = True
        builder.int8_calibrator = int8_calibrator

    if strict_type_constraints:
        config.set_flag(trt.BuilderFlag.STRICT_TYPES)
    # if int8_mode:
    #     config.set_flag(trt.BuilderFlag.INT8)
    #     if int8_calibrator is None:
    #         shapes = [(1,) + network.get_input(i).shape[1:] for i in range(network.num_inputs)]
    #         dummy_data = utils.gen_ones_data(shapes)
    #         int8_calibrator = EntropyCalibrator2(CustomDataset(dummy_data))
    #     config.int8_calibrator = int8_calibrator

    # set dynamic batch size profile
    profile = builder.create_optimization_profile()
    for i in range(network.num_inputs):
        tensor = network.get_input(i)
        name = tensor.name
        shape = tensor.shape[1:]
        min_shape = (1,) + shape
        opt_shape = ((1 + max_batch_size) // 2,) + shape
        max_shape = (max_batch_size,) + shape
        profile.set_shape(name, min_shape, opt_shape, max_shape)
    config.add_optimization_profile(profile)

    engine = builder.build_engine(network, config)
    logger.info("Building an engine from file {}; this may take a while...".format(onnx_file_path))
    engine = builder.build_cuda_engine(network)
    logger.info("Create engine successfully!")

    logger.info("Saving TRT engine file to path {}".format(save_path))
    with open(save_path, 'wb') as f:
        f.write(engine.serialize())
    # trt_model = TRTModel(engine)

    # return trt_model


def export_trt_model(onnxModel, engineFile, input_numpy_array):
    r"""
    Export a model to trt format.
    """

    trt = pytrt.Trt()

    customOutput = []
    maxBatchSize = 8
    calibratorData = []
    mode = 0
    trt.CreateEngine(onnxModel, engineFile, customOutput, maxBatchSize, mode, calibratorData)
    trt.DoInference(input_numpy_array)  # slightly different from c++
    return 0
    logger.info("Engine file has already saved to {}!".format(save_path))


if __name__ == '__main__':
    args = get_parser().parse_args()

    inputs = np.zeros(shape=(1, args.height, args.width, 3))
    onnxModel = args.onnx_model
    engineFile = os.path.join(args.output, args.name+'.engine')
    onnx_file_path = args.onnx_model
    engineFile = os.path.join(args.output, args.name + '.engine')

    if args.mode.lower() == 'int8':
        int8_calib = FeatEntropyCalibrator(args)
    else:
        int8_calib = None

    PathManager.mkdirs(args.output)
    onnx2trt(onnxModel, engineFile)
    # export_trt_model(onnxModel, engineFile, inputs)

    logger.info(f"Export trt model in {args.output} successfully!")
    onnx2trt(onnx_file_path, engineFile, args.mode, int8_calibrator=int8_calib)
@ -6,15 +6,14 @@
import argparse
import glob
import os
import sys

import cv2
import numpy as np
import pycuda.driver as cuda
import tensorrt as trt
import tqdm

sys.path.append("/export/home/lxy/runtimelib-tensorrt-tiny/build")

import pytrt
TRT_LOGGER = trt.Logger()


def get_parser():
@ -37,8 +36,10 @@ def get_parser():
        help="path to save trt model inference results"
    )
    parser.add_argument(
        "--output-name",
        help="tensorRT model output name"
        '--batch-size',
        default=1,
        type=int,
        help='the maximum batch size of trt module'
    )
    parser.add_argument(
        "--height",
@ -55,35 +56,134 @@ def get_parser():
    return parser


def preprocess(image_path, image_height, image_width):
    original_image = cv2.imread(image_path)
    # the model expects RGB inputs
    original_image = original_image[:, :, ::-1]
class HostDeviceMem(object):
    """ Host and Device Memory Package """

    # Apply pre-processing to image.
    img = cv2.resize(original_image, (image_width, image_height), interpolation=cv2.INTER_CUBIC)
    img = img.astype("float32").transpose(2, 0, 1)[np.newaxis]  # (1, 3, h, w)
    return img
    def __init__(self, host_mem, device_mem):
        self.host = host_mem
        self.device = device_mem

    def __str__(self):
        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

    def __repr__(self):
        return self.__str__()


def normalize(nparray, order=2, axis=-1):
    """Normalize a N-D numpy array along the specified axis."""
    norm = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True)
    return nparray / (norm + np.finfo(np.float32).eps)
class TrtEngine:

    def __init__(self, trt_file=None, gpu_idx=0, batch_size=1):
        cuda.init()
        self._batch_size = batch_size
        self._device_ctx = cuda.Device(gpu_idx).make_context()
        self._engine = self._load_engine(trt_file)
        self._context = self._engine.create_execution_context()
        self._input, self._output, self._bindings, self._stream = self._allocate_buffers(self._context)

    def _load_engine(self, trt_file):
        """
        Load tensorrt engine.
        :param trt_file: tensorrt file.
        :return:
            ICudaEngine
        """
        with open(trt_file, "rb") as f, \
                trt.Runtime(TRT_LOGGER) as runtime:
            engine = runtime.deserialize_cuda_engine(f.read())
        return engine

    def _allocate_buffers(self, context):
        """
        Allocate device memory space for data.
        :param context:
        :return:
        """
        inputs = []
        outputs = []
        bindings = []
        stream = cuda.Stream()
        for binding in self._engine:
            size = trt.volume(self._engine.get_binding_shape(binding)) * self._engine.max_batch_size
            dtype = trt.nptype(self._engine.get_binding_dtype(binding))
            # Allocate host and device buffers
            host_mem = cuda.pagelocked_empty(size, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            # Append the device buffer to device bindings.
            bindings.append(int(device_mem))
            # Append to the appropriate list.
            if self._engine.binding_is_input(binding):
                inputs.append(HostDeviceMem(host_mem, device_mem))
            else:
                outputs.append(HostDeviceMem(host_mem, device_mem))
        return inputs, outputs, bindings, stream

    def infer(self, data):
        """
        Real inference process.
        :param model: Model objects
        :param data: Preprocessed data
        :return:
            output
        """
        # Copy data to input memory buffer
        [np.copyto(_inp.host, data.ravel()) for _inp in self._input]
        # Push to device
        self._device_ctx.push()
        # Transfer input data to the GPU.
        # cuda.memcpy_htod_async(self._input.device, self._input.host, self._stream)
        [cuda.memcpy_htod_async(inp.device, inp.host, self._stream) for inp in self._input]
        # Run inference.
        self._context.execute_async_v2(bindings=self._bindings, stream_handle=self._stream.handle)
        # Transfer predictions back from the GPU.
        # cuda.memcpy_dtoh_async(self._output.host, self._output.device, self._stream)
        [cuda.memcpy_dtoh_async(out.host, out.device, self._stream) for out in self._output]
        # Synchronize the stream
        self._stream.synchronize()
        # Pop the device
        self._device_ctx.pop()

        return [out.host.reshape(self._batch_size, -1) for out in self._output[::-1]]

    def inference_on_images(self, imgs, new_size=(256, 128)):
        trt_inputs = []
        for img in imgs:
            input_ndarray = self.preprocess(img, *new_size)
            trt_inputs.append(input_ndarray)
        trt_inputs = np.vstack(trt_inputs)

        valid_bsz = trt_inputs.shape[0]
        if valid_bsz < self._batch_size:
            trt_inputs = np.vstack([trt_inputs, np.zeros((self._batch_size - valid_bsz, 3, *new_size))])

        result, = self.infer(trt_inputs)
        result = result[:valid_bsz]
        feat = self.postprocess(result, axis=1)
        return feat

    @classmethod
    def preprocess(cls, img, img_height, img_width):
        # Apply pre-processing to image.
        resize_img = cv2.resize(img, (img_width, img_height), interpolation=cv2.INTER_CUBIC)
        type_img = resize_img.astype("float32").transpose(2, 0, 1)[np.newaxis]  # (1, 3, h, w)
        return type_img

    @classmethod
    def postprocess(cls, nparray, order=2, axis=-1):
        """Normalize a N-D numpy array along the specified axis."""
        norm = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True)
        return nparray / (norm + np.finfo(np.float32).eps)

    def __del__(self):
        del self._input
        del self._output
        del self._stream
        self._device_ctx.detach()  # release device context


if __name__ == "__main__":
    args = get_parser().parse_args()

    trt = pytrt.Trt()

    onnxModel = ""
    engineFile = args.model_path
    customOutput = []
    maxBatchSize = 1
    calibratorData = []
    mode = 2
    trt.CreateEngine(onnxModel, engineFile, customOutput, maxBatchSize, mode, calibratorData)
    trt = TrtEngine(args.model_path, batch_size=args.batch_size)

    if not os.path.exists(args.output): os.makedirs(args.output)

@ -91,9 +191,10 @@ if __name__ == "__main__":
    if os.path.isdir(args.input[0]):
        args.input = glob.glob(os.path.expanduser(args.input[0]))
        assert args.input, "The input path(s) was not found"
    for path in tqdm.tqdm(args.input):
        input_numpy_array = preprocess(path, args.height, args.width)
        trt.DoInference(input_numpy_array)
        feat = trt.GetOutput(args.output_name)
        feat = normalize(feat, axis=1)
        np.save(os.path.join(args.output, path.replace('.jpg', '.npy').split('/')[-1]), feat)
    inputs = []
    for img_path in tqdm.tqdm(args.input):
        img = cv2.imread(img_path)
        # the model expects RGB inputs
        cvt_img = img[:, :, ::-1]
        feat = trt.inference_on_images([cvt_img])
        np.save(os.path.join(args.output, os.path.basename(img_path) + '.npy'), feat[0])