update deployment toolchain (#428)
Summary: Remove the tiny-tensorrt dependency and rewrite the TensorRT inference API. The new TRT inference wrapper automatically pads the input to a fixed batch size, so you no longer need to worry about dynamic batch sizes.
parent d7c1294d9e
commit cb7a1cb3e1
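The committed wrapper (`TrtEngine.inference_on_images` in `trt_inference.py`) implements this padding; the sketch below is a simplified, hypothetical illustration of the same idea, not the committed API:

```python
# Minimal sketch of the fixed-batch padding idea used by the new TRT wrapper
# (illustrative only; the committed implementation lives in trt_inference.py).
import numpy as np

def pad_to_fixed_batch(batch: np.ndarray, batch_size: int):
    """Pad a (N, C, H, W) batch with zeros up to `batch_size`,
    returning the padded batch and the number of valid rows."""
    valid = batch.shape[0]
    if valid < batch_size:
        pad = np.zeros((batch_size - valid, *batch.shape[1:]), dtype=batch.dtype)
        batch = np.vstack([batch, pad])
    return batch, valid

# After inference, keep only the rows that correspond to real images:
# outputs = engine_outputs[:valid]
```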
tools/deploy/README.md

@@ -18,32 +18,14 @@ This is a tiny example for converting fastreid-baseline in `meta_arch` to Caffe
 1. Run `caffe_export.py` to get the converted Caffe model,

 ```bash
-python caffe_export.py --config-file root-path/market1501/bagtricks_R50/config.yml --name "baseline_R50" --output outputs/caffe_model --opts MODEL.WEIGHTS root-path/logs/market1501/bagtricks_R50/model_final.pth
+python caffe_export.py --config-file root-path/market1501/bagtricks_R50/config.yml --name baseline_R50 --output outputs/caffe_model --opts MODEL.WEIGHTS root-path/logs/market1501/bagtricks_R50/model_final.pth
 ```

 then you can check the Caffe model and prototxt in `outputs/caffe_model`.

-2. Change `prototxt` following next three steps:
+2. Change `prototxt` following next two steps:

-1) Edit `max_pooling` in `baseline_R50.prototxt` like this
+1) Modify `avg_pooling` in `baseline_R50.prototxt`

-```prototxt
-layer {
-  name: "max_pool1"
-  type: "Pooling"
-  bottom: "relu_blob1"
-  top: "max_pool_blob1"
-  pooling_param {
-    pool: MAX
-    kernel_size: 3
-    stride: 2
-    pad: 0 # 1
-    # ceil_mode: false
-  }
-}
-```
-
-2) Add `avg_pooling` right place in `baseline_R50.prototxt`
-
 ```prototxt
 layer {
@@ -58,7 +40,7 @@ This is a tiny example for converting fastreid-baseline in `meta_arch` to Caffe
 }
 ```

-3) Change the last layer `top` name to `output`
+2) Change the last layer `top` name to `output`

 ```prototxt
 layer {
@@ -100,7 +82,7 @@ This is a tiny example for converting fastreid-baseline in `meta_arch` to ONNX m
 1. Run `onnx_export.py` to get the converted ONNX model,

 ```bash
-python onnx_export.py --config-file root-path/bagtricks_R50/config.yml --name "baseline_R50" --output outputs/onnx_model --opts MODEL.WEIGHTS root-path/logs/market1501/bagtricks_R50/model_final.pth
+python onnx_export.py --config-file root-path/bagtricks_R50/config.yml --name baseline_R50 --output outputs/onnx_model --opts MODEL.WEIGHTS root-path/logs/market1501/bagtricks_R50/model_final.pth
 ```

 then you can check the ONNX model in `outputs/onnx_model`.
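As an optional sanity check (not part of the scripts in this commit; it assumes the `onnx` Python package is installed), you can confirm the exported file loads cleanly:

```python
import onnx

# adjust the path to the file actually written by onnx_export.py
onnx_model = onnx.load("outputs/onnx_model/baseline.onnx")
onnx.checker.check_model(onnx_model)  # raises if the graph is malformed
print(onnx.helper.printable_graph(onnx_model.graph))  # human-readable graph dump
```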
@@ -127,15 +109,16 @@ This is a tiny example for converting fastreid-baseline in `meta_arch` to ONNX m
 <details>
 <summary>step-to-step pipeline for trt convert</summary>

-This is a tiny example for converting fastreid-baseline in `meta_arch` to TRT model. We use [tiny-tensorrt](https://github.com/zerollzeng/tiny-tensorrt), a simple and easy-to-use NVIDIA TensorRT wrapper, to get the model converted to TensorRT.
+This is a tiny example for converting fastreid-baseline in `meta_arch` to TRT model.

 First you need to convert the pytorch model to ONNX format following [ONNX Convert](https://github.com/JDAI-CV/fast-reid#fastreid), and you need to remember your `output` name. Then you can convert ONNX model to TensorRT following instructions below.

 1. Run command line below to get the converted TRT model from ONNX model,

 ```bash
-python trt_export.py --name "baseline_R50" --output outputs/trt_model --onnx-model outputs/onnx_model/baseline.onnx --height 256 --width 128
+python trt_export.py --name baseline_R50 --output outputs/trt_model \
+--mode fp32 --batch-size 8 --height 256 --width 128 \
+--onnx-model outputs/onnx_model/baseline.onnx
 ```

 then you can check the TRT model in `outputs/trt_model`.

@@ -143,16 +126,18 @@ First you need to convert the pytorch model to ONNX format following [ONNX Conve
 2. Run `trt_inference.py` to save TRT model features with input images

 ```bash
-python onnx_inference.py --model-path outputs/trt_model/baseline.engine \
---input test_data/*.jpg --output trt_output --output-name trt_model_outputname
+python3 trt_inference.py --model-path outputs/trt_model/baseline.engine \
+--input test_data/*.jpg --batch-size 8 --height 256 --width 128 --output trt_output
 ```

 3. Run `demo/demo.py` to get fastreid model features with the same input images, then verify that TensorRT and PyTorch are computing the same value for the network.

 ```python
-np.testing.assert_allclose(torch_out, ort_out, rtol=1e-3, atol=1e-6)
+np.testing.assert_allclose(torch_out, trt_out, rtol=1e-3, atol=1e-6)
 ```

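For example, if both runs dump per-image features as `.npy` files, the comparison can be scripted like this (`trt_output/` is what `trt_inference.py` writes; the `pytorch_output/` directory for the demo features is an assumption):

```python
import glob
import os

import numpy as np

# trt_output/ holds features saved by trt_inference.py; pytorch_output/ is assumed
# to hold features exported from demo/demo.py for the same images.
for trt_file in glob.glob("trt_output/*.npy"):
    torch_file = os.path.join("pytorch_output", os.path.basename(trt_file))
    trt_out = np.load(trt_file)
    torch_out = np.load(torch_file)
    np.testing.assert_allclose(torch_out, trt_out, rtol=1e-3, atol=1e-6)
print("TensorRT and PyTorch features match.")
```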
+Notice: The int8 mode in the TensorRT runtime is not supported yet and there are some bugs in the calibrator. Need help!

 </details>

 ## Acknowledgements
tools/deploy/caffe_export.py

@@ -19,12 +19,17 @@ from fastreid.utils.file_io import PathManager
 from fastreid.utils.checkpoint import Checkpointer
 from fastreid.utils.logger import setup_logger

+# import some modules added in project like this below
+# sys.path.append('../projects/FastCls')
+# from fastcls import *
+
 setup_logger(name='fastreid')
 logger = logging.getLogger("fastreid.caffe_export")


 def setup_cfg(args):
     cfg = get_cfg()
+    # add_cls_config(cfg)
     cfg.merge_from_file(args.config_file)
     cfg.merge_from_list(args.opts)
     cfg.freeze()
@@ -64,7 +69,8 @@ if __name__ == '__main__':

     cfg.defrost()
     cfg.MODEL.BACKBONE.PRETRAIN = False
-    cfg.MODEL.HEADS.POOL_LAYER = "identity"
+    if cfg.MODEL.HEADS.POOL_LAYER == 'fastavgpool':
+        cfg.MODEL.HEADS.POOL_LAYER = 'avgpool'
     cfg.MODEL.BACKBONE.WITH_NL = False

     model = build_model(cfg)
tools/deploy/onnx_export.py

@@ -21,6 +21,10 @@ from fastreid.utils.file_io import PathManager
 from fastreid.utils.checkpoint import Checkpointer
 from fastreid.utils.logger import setup_logger

+# import some modules added in project like this below
+# sys.path.append('../../projects/FastDistill')
+# from fastdistill import *
+
 logger = setup_logger(name='onnx_export')


@@ -50,6 +54,12 @@ def get_parser():
         default='onnx_model',
         help='path to save converted onnx model'
     )
+    parser.add_argument(
+        '--batch-size',
+        default=1,
+        type=int,
+        help="the maximum batch size of onnx runtime"
+    )
     parser.add_argument(
         "--opts",
         help="Modify config options using the command-line 'KEY VALUE' pairs",
@@ -130,7 +140,7 @@ if __name__ == '__main__':
     model.eval()
     logger.info(model)

-    inputs = torch.randn(1, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1])
+    inputs = torch.randn(args.batch_size, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1])
     onnx_model = export_onnx_model(model, inputs)

     model_simp, check = simplify(onnx_model)
(deleted file)

@@ -1,4 +0,0 @@
-python3 caffe_export.py --config-file /export/home/lxy/cvpalgo-fast-reid/projects/bjzProject/configs/r34.yml \
---name "r34-0603" \
---output logs/caffe_r34-0603 \
---opts MODEL.WEIGHTS /export/home/lxy/cvpalgo-fast-reid/logs/bjz/sbs_R34_bjz_0603_8x32/model_final.pth
tools/deploy/trt_calibrator.py (new file, 102 lines)

@@ -0,0 +1,102 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: sherlockliao01@gmail.com
+
+Create custom calibrator, use to calibrate int8 TensorRT model.
+Need to override some methods of trt.IInt8EntropyCalibrator2, such as get_batch_size, get_batch,
+read_calibration_cache, write_calibration_cache.
+"""
+
+# based on:
+# https://github.com/qq995431104/Pytorch2TensorRT/blob/master/myCalibrator.py
+
+import os
+import sys
+
+import tensorrt as trt
+import pycuda.driver as cuda
+import pycuda.autoinit
+
+import numpy as np
+import torchvision.transforms as T
+
+sys.path.append('../..')
+
+from fastreid.data.build import _root
+from fastreid.data.data_utils import read_image
+from fastreid.data.datasets import DATASET_REGISTRY
+import logging
+
+from fastreid.data.transforms import ToTensor
+
+
+logger = logging.getLogger('trt_export.calibrator')
+
+
+class FeatEntropyCalibrator(trt.IInt8EntropyCalibrator2):
+
+    def __init__(self, args):
+        trt.IInt8EntropyCalibrator2.__init__(self)
+
+        self.cache_file = 'reid_feat.cache'
+
+        self.batch_size = args.batch_size
+        self.channel = args.channel
+        self.height = args.height
+        self.width = args.width
+        self.transform = T.Compose([
+            T.Resize((self.height, self.width), interpolation=3),  # [h,w]
+            ToTensor(),
+        ])
+
+        dataset = DATASET_REGISTRY.get(args.calib_data)(root=_root)
+        self._data_items = dataset.train + dataset.query + dataset.gallery
+        np.random.shuffle(self._data_items)
+        self.imgs = [item[0] for item in self._data_items]
+
+        self.batch_idx = 0
+        self.max_batch_idx = len(self.imgs) // self.batch_size
+
+        self.data_size = self.batch_size * self.channel * self.height * self.width * trt.float32.itemsize
+        self.device_input = cuda.mem_alloc(self.data_size)
+
+    def next_batch(self):
+        if self.batch_idx < self.max_batch_idx:
+            batch_files = self.imgs[self.batch_idx * self.batch_size:(self.batch_idx + 1) * self.batch_size]
+            batch_imgs = np.zeros((self.batch_size, self.channel, self.height, self.width),
+                                  dtype=np.float32)
+            for i, f in enumerate(batch_files):
+                img = read_image(f)
+                img = self.transform(img).numpy()
+                assert (img.nbytes == self.data_size // self.batch_size), 'not valid img!' + f
+                batch_imgs[i] = img
+            self.batch_idx += 1
+            logger.info("batch:[{}/{}]".format(self.batch_idx, self.max_batch_idx))
+            return np.ascontiguousarray(batch_imgs)
+        else:
+            return np.array([])
+
+    def get_batch_size(self):
+        return self.batch_size
+
+    def get_batch(self, names, p_str=None):
+        try:
+            batch_imgs = self.next_batch()
+            batch_imgs = batch_imgs.ravel()
+            if batch_imgs.size == 0 or batch_imgs.size != self.batch_size * self.channel * self.height * self.width:
+                return None
+            cuda.memcpy_htod(self.device_input, batch_imgs.astype(np.float32))
+            return [int(self.device_input)]
+        except:
+            return None
+
+    def read_calibration_cache(self):
+        # If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None.
+        if os.path.exists(self.cache_file):
+            with open(self.cache_file, "rb") as f:
+                return f.read()
+
+    def write_calibration_cache(self, cache):
+        with open(self.cache_file, "wb") as f:
+            f.write(cache)
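A usage sketch of the calibrator (mirroring how `trt_export.py` below wires it in when `--mode int8` is selected; the `SimpleNamespace` stand-in for the parsed arguments is illustrative only):

```python
# Illustrative only: trt_export.py builds `args` with argparse; a stand-in is used here.
from types import SimpleNamespace

from trt_calibrator import FeatEntropyCalibrator

args = SimpleNamespace(batch_size=8, channel=3, height=256, width=128,
                       calib_data='Market1501')
int8_calib = FeatEntropyCalibrator(args)  # feeds real ReID images to TensorRT during calibration
# onnx2trt('outputs/onnx_model/baseline.onnx', 'outputs/trt_model/baseline.engine',
#          'int8', int8_calibrator=int8_calib)
```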
tools/deploy/trt_export.py

@@ -6,18 +6,17 @@

 import argparse
 import os
-import numpy as np
 import sys

 import tensorrt as trt

-sys.path.append('../../')
-sys.path.append("/export/home/lxy/runtimelib-tensorrt-tiny/build")
+from trt_calibrator import FeatEntropyCalibrator
+
+sys.path.append('../../')

-import pytrt
 from fastreid.utils.logger import setup_logger
 from fastreid.utils.file_io import PathManager


 logger = setup_logger(name='trt_export')

@@ -25,79 +24,103 @@ def get_parser():
     parser = argparse.ArgumentParser(description="Convert ONNX to TRT model")

     parser.add_argument(
-        "--name",
-        default="baseline",
+        '--name',
+        default='baseline',
         help="name for converted model"
     )
     parser.add_argument(
-        "--output",
+        '--output',
         default='outputs/trt_model',
-        help='path to save converted trt model'
+        help="path to save converted trt model"
+    )
+    parser.add_argument(
+        '--mode',
+        default='fp32',
+        help="which mode is used in tensorRT engine, mode can be ['fp32', 'fp16' 'int8']"
+    )
+    parser.add_argument(
+        '--batch-size',
+        default=1,
+        type=int,
+        help="the maximum batch size of trt module"
+    )
+    parser.add_argument(
+        '--height',
+        default=256,
+        type=int,
+        help="input image height"
+    )
+    parser.add_argument(
+        '--width',
+        default=128,
+        type=int,
+        help="input image width"
+    )
+    parser.add_argument(
+        '--channel',
+        default=3,
+        type=int,
+        help="input image channel"
+    )
+    parser.add_argument(
+        '--calib-data',
+        default='Market1501',
+        help="int8 calibrator dataset name"
     )
     parser.add_argument(
         "--onnx-model",
         default='outputs/onnx_model/baseline.onnx',
         help='path to onnx model'
     )
-    parser.add_argument(
-        "--height",
-        type=int,
-        default=256,
-        help="height of image"
-    )
-    parser.add_argument(
-        "--width",
-        type=int,
-        default=128,
-        help="width of image"
-    )
     return parser


 def onnx2trt(
-        model,
+        onnx_file_path,
         save_path,
+        mode,
         log_level='ERROR',
-        max_batch_size=1,
         max_workspace_size=1,
-        fp16_mode=False,
         strict_type_constraints=False,
-        int8_mode=False,
         int8_calibrator=None,
 ):
     """build TensorRT model from onnx model.
     Args:
-        model (string or io object): onnx model name
+        onnx_file_path (string or io object): onnx model name
+        save_path (string): tensortRT serialization save path
+        mode (string): Whether or not FP16 or Int8 kernels are permitted during engine build.
         log_level (string, default is ERROR): tensorrt logger level, now
             INTERNAL_ERROR, ERROR, WARNING, INFO, VERBOSE are support.
-        max_batch_size (int, default=1): The maximum batch size which can be used at execution time, and also the
-            batch size for which the ICudaEngine will be optimized.
         max_workspace_size (int, default is 1): The maximum GPU temporary memory which the ICudaEngine can use at
             execution time. default is 1GB.
-        fp16_mode (bool, default is False): Whether or not 16-bit kernels are permitted. During engine build
-            fp16 kernels will also be tried when this mode is enabled.
         strict_type_constraints (bool, default is False): When strict type constraints is set, TensorRT will choose
             the type constraints that conforms to type constraints. If the flag is not enabled higher precision
             implementation may be chosen if it results in higher performance.
-        int8_mode (bool, default is False): Whether Int8 mode is used.
         int8_calibrator (volksdep.calibrators.base.BaseCalibrator, default is None): calibrator for int8 mode,
             if None, default calibrator will be used as calibration data.
     """
-    logger = trt.Logger(getattr(trt.Logger, log_level))
-    builder = trt.Builder(logger)
+    mode = mode.lower()
+    assert mode in ['fp32', 'fp16', 'int8'], "mode should be in ['fp32', 'fp16', 'int8'], " \
+                                             "but got {}".format(mode)
+
+    trt_logger = trt.Logger(getattr(trt.Logger, log_level))
+    builder = trt.Builder(trt_logger)

-    network = builder.create_network(1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
-    parser = trt.OnnxParser(network, logger)
-    if isinstance(model, str):
-        with open(model, 'rb') as f:
+    logger.info("Loading ONNX file from path {}...".format(onnx_file_path))
+    EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    network = builder.create_network(EXPLICIT_BATCH)
+    parser = trt.OnnxParser(network, trt_logger)
+    if isinstance(onnx_file_path, str):
+        with open(onnx_file_path, 'rb') as f:
+            logger.info("Beginning ONNX file parsing")
             flag = parser.parse(f.read())
     else:
-        flag = parser.parse(model.read())
+        flag = parser.parse(onnx_file_path.read())
     if not flag:
         for error in range(parser.num_errors):
-            print(parser.get_error(error))
+            logger.info(parser.get_error(error))

+    logger.info("Completed parsing of ONNX file.")
     # re-order output tensor
     output_tensors = [network.get_output(i) for i in range(network.num_outputs)]
     [network.unmark_output(tensor) for tensor in output_tensors]
@@ -106,68 +129,39 @@ def onnx2trt(
         identity_out_tensor.name = 'identity_{}'.format(tensor.name)
         network.mark_output(tensor=identity_out_tensor)

-    builder.max_batch_size = max_batch_size

     config = builder.create_builder_config()
     config.max_workspace_size = max_workspace_size * (1 << 25)
-    if fp16_mode:
-        config.set_flag(trt.BuilderFlag.FP16)
+    if mode == 'fp16':
+        assert builder.platform_has_fast_fp16, "not support fp16"
+        builder.fp16_mode = True
+    if mode == 'int8':
+        assert builder.platform_has_fast_int8, "not support int8"
+        builder.int8_mode = True
+        builder.int8_calibrator = int8_calibrator
+
     if strict_type_constraints:
         config.set_flag(trt.BuilderFlag.STRICT_TYPES)
-    # if int8_mode:
-    #     config.set_flag(trt.BuilderFlag.INT8)
-    #     if int8_calibrator is None:
-    #         shapes = [(1,) + network.get_input(i).shape[1:] for i in range(network.num_inputs)]
-    #         dummy_data = utils.gen_ones_data(shapes)
-    #         int8_calibrator = EntropyCalibrator2(CustomDataset(dummy_data))
-    #     config.int8_calibrator = int8_calibrator

-    # set dynamic batch size profile
-    profile = builder.create_optimization_profile()
-    for i in range(network.num_inputs):
-        tensor = network.get_input(i)
-        name = tensor.name
-        shape = tensor.shape[1:]
-        min_shape = (1,) + shape
-        opt_shape = ((1 + max_batch_size) // 2,) + shape
-        max_shape = (max_batch_size,) + shape
-        profile.set_shape(name, min_shape, opt_shape, max_shape)
-    config.add_optimization_profile(profile)
-
-    engine = builder.build_engine(network, config)
+    logger.info("Building an engine from file {}; this may take a while...".format(onnx_file_path))
+    engine = builder.build_cuda_engine(network)
+    logger.info("Create engine successfully!")

+    logger.info("Saving TRT engine file to path {}".format(save_path))
     with open(save_path, 'wb') as f:
         f.write(engine.serialize())
-    # trt_model = TRTModel(engine)
-
-    # return trt_model
-
-
-def export_trt_model(onnxModel, engineFile, input_numpy_array):
-    r"""
-    Export a model to trt format.
-    """
-
-    trt = pytrt.Trt()
-
-    customOutput = []
-    maxBatchSize = 8
-    calibratorData = []
-    mode = 0
-    trt.CreateEngine(onnxModel, engineFile, customOutput, maxBatchSize, mode, calibratorData)
-    trt.DoInference(input_numpy_array)  # slightly different from c++
-    return 0
+    logger.info("Engine file has already saved to {}!".format(save_path))


 if __name__ == '__main__':
     args = get_parser().parse_args()

-    inputs = np.zeros(shape=(1, args.height, args.width, 3))
-    onnxModel = args.onnx_model
-    engineFile = os.path.join(args.output, args.name+'.engine')
+    onnx_file_path = args.onnx_model
+    engineFile = os.path.join(args.output, args.name + '.engine')
+
+    if args.mode.lower() == 'int8':
+        int8_calib = FeatEntropyCalibrator(args)
+    else:
+        int8_calib = None

     PathManager.mkdirs(args.output)
-    onnx2trt(onnxModel, engineFile)
-    # export_trt_model(onnxModel, engineFile, inputs)
+    onnx2trt(onnx_file_path, engineFile, args.mode, int8_calibrator=int8_calib)
+
+    logger.info(f"Export trt model in {args.output} successfully!")
tools/deploy/trt_inference.py

@@ -6,15 +6,14 @@
 import argparse
 import glob
 import os
-import sys

 import cv2
 import numpy as np
+import pycuda.driver as cuda
+import tensorrt as trt
 import tqdm

-sys.path.append("/export/home/lxy/runtimelib-tensorrt-tiny/build")
-
-import pytrt
+TRT_LOGGER = trt.Logger()


 def get_parser():
@@ -37,8 +36,10 @@ def get_parser():
         help="path to save trt model inference results"
     )
     parser.add_argument(
-        "--output-name",
-        help="tensorRT model output name"
+        '--batch-size',
+        default=1,
+        type=int,
+        help='the maximum batch size of trt module'
     )
     parser.add_argument(
         "--height",
@@ -55,35 +56,134 @@ def get_parser():
     return parser


-def preprocess(image_path, image_height, image_width):
-    original_image = cv2.imread(image_path)
-    # the model expects RGB inputs
-    original_image = original_image[:, :, ::-1]
-
-    # Apply pre-processing to image.
-    img = cv2.resize(original_image, (image_width, image_height), interpolation=cv2.INTER_CUBIC)
-    img = img.astype("float32").transpose(2, 0, 1)[np.newaxis]  # (1, 3, h, w)
-    return img
-
-
-def normalize(nparray, order=2, axis=-1):
-    """Normalize a N-D numpy array along the specified axis."""
-    norm = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True)
-    return nparray / (norm + np.finfo(np.float32).eps)
+class HostDeviceMem(object):
+    """ Host and Device Memory Package """
+
+    def __init__(self, host_mem, device_mem):
+        self.host = host_mem
+        self.device = device_mem
+
+    def __str__(self):
+        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)
+
+    def __repr__(self):
+        return self.__str__()
+
+
+class TrtEngine:
+
+    def __init__(self, trt_file=None, gpu_idx=0, batch_size=1):
+        cuda.init()
+        self._batch_size = batch_size
+        self._device_ctx = cuda.Device(gpu_idx).make_context()
+        self._engine = self._load_engine(trt_file)
+        self._context = self._engine.create_execution_context()
+        self._input, self._output, self._bindings, self._stream = self._allocate_buffers(self._context)
+
+    def _load_engine(self, trt_file):
+        """
+        Load tensorrt engine.
+        :param trt_file: tensorrt file.
+        :return:
+            ICudaEngine
+        """
+        with open(trt_file, "rb") as f, \
+                trt.Runtime(TRT_LOGGER) as runtime:
+            engine = runtime.deserialize_cuda_engine(f.read())
+        return engine
+
+    def _allocate_buffers(self, context):
+        """
+        Allocate device memory space for data.
+        :param context:
+        :return:
+        """
+        inputs = []
+        outputs = []
+        bindings = []
+        stream = cuda.Stream()
+        for binding in self._engine:
+            size = trt.volume(self._engine.get_binding_shape(binding)) * self._engine.max_batch_size
+            dtype = trt.nptype(self._engine.get_binding_dtype(binding))
+            # Allocate host and device buffers
+            host_mem = cuda.pagelocked_empty(size, dtype)
+            device_mem = cuda.mem_alloc(host_mem.nbytes)
+            # Append the device buffer to device bindings.
+            bindings.append(int(device_mem))
+            # Append to the appropriate list.
+            if self._engine.binding_is_input(binding):
+                inputs.append(HostDeviceMem(host_mem, device_mem))
+            else:
+                outputs.append(HostDeviceMem(host_mem, device_mem))
+        return inputs, outputs, bindings, stream
+
+    def infer(self, data):
+        """
+        Real inference process.
+        :param model: Model objects
+        :param data: Preprocessed data
+        :return:
+            output
+        """
+        # Copy data to input memory buffer
+        [np.copyto(_inp.host, data.ravel()) for _inp in self._input]
+        # Push to device
+        self._device_ctx.push()
+        # Transfer input data to the GPU.
+        # cuda.memcpy_htod_async(self._input.device, self._input.host, self._stream)
+        [cuda.memcpy_htod_async(inp.device, inp.host, self._stream) for inp in self._input]
+        # Run inference.
+        self._context.execute_async_v2(bindings=self._bindings, stream_handle=self._stream.handle)
+        # Transfer predictions back from the GPU.
+        # cuda.memcpy_dtoh_async(self._output.host, self._output.device, self._stream)
+        [cuda.memcpy_dtoh_async(out.host, out.device, self._stream) for out in self._output]
+        # Synchronize the stream
+        self._stream.synchronize()
+        # Pop the device
+        self._device_ctx.pop()
+
+        return [out.host.reshape(self._batch_size, -1) for out in self._output[::-1]]
+
+    def inference_on_images(self, imgs, new_size=(256, 128)):
+        trt_inputs = []
+        for img in imgs:
+            input_ndarray = self.preprocess(img, *new_size)
+            trt_inputs.append(input_ndarray)
+        trt_inputs = np.vstack(trt_inputs)
+
+        valid_bsz = trt_inputs.shape[0]
+        if valid_bsz < self._batch_size:
+            trt_inputs = np.vstack([trt_inputs, np.zeros((self._batch_size - valid_bsz, 3, *new_size))])
+
+        result, = self.infer(trt_inputs)
+        result = result[:valid_bsz]
+        feat = self.postprocess(result, axis=1)
+        return feat
+
+    @classmethod
+    def preprocess(cls, img, img_height, img_width):
+        # Apply pre-processing to image.
+        resize_img = cv2.resize(img, (img_width, img_height), interpolation=cv2.INTER_CUBIC)
+        type_img = resize_img.astype("float32").transpose(2, 0, 1)[np.newaxis]  # (1, 3, h, w)
+        return type_img
+
+    @classmethod
+    def postprocess(cls, nparray, order=2, axis=-1):
+        """Normalize a N-D numpy array along the specified axis."""
+        norm = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True)
+        return nparray / (norm + np.finfo(np.float32).eps)
+
+    def __del__(self):
+        del self._input
+        del self._output
+        del self._stream
+        self._device_ctx.detach()  # release device context


 if __name__ == "__main__":
     args = get_parser().parse_args()

-    trt = pytrt.Trt()
-
-    onnxModel = ""
-    engineFile = args.model_path
-    customOutput = []
-    maxBatchSize = 1
-    calibratorData = []
-    mode = 2
-    trt.CreateEngine(onnxModel, engineFile, customOutput, maxBatchSize, mode, calibratorData)
+    trt = TrtEngine(args.model_path, batch_size=args.batch_size)

     if not os.path.exists(args.output): os.makedirs(args.output)

@@ -91,9 +191,10 @@ if __name__ == "__main__":
     if os.path.isdir(args.input[0]):
         args.input = glob.glob(os.path.expanduser(args.input[0]))
         assert args.input, "The input path(s) was not found"
-    for path in tqdm.tqdm(args.input):
-        input_numpy_array = preprocess(path, args.height, args.width)
-        trt.DoInference(input_numpy_array)
-        feat = trt.GetOutput(args.output_name)
-        feat = normalize(feat, axis=1)
-        np.save(os.path.join(args.output, path.replace('.jpg', '.npy').split('/')[-1]), feat)
+    inputs = []
+    for img_path in tqdm.tqdm(args.input):
+        img = cv2.imread(img_path)
+        # the model expects RGB inputs
+        cvt_img = img[:, :, ::-1]
+        feat = trt.inference_on_images([cvt_img])
+        np.save(os.path.join(args.output, os.path.basename(img_path) + '.npy'), feat[0])