mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
fix(backend): disable cublaslt for cu102 (#947)
* fix(backend): disable cublaslt for cu102 * fix * fix(backend): update * fix(tensorrt/util.py): add find cuda version * fix * fix(CI): first use cmd to get cuda version * docs(tensorrt/utils.py): update docstring
This commit is contained in:
parent
792c27b054
commit
e21cad84e0
@ -323,6 +323,7 @@ pip install -e .
|
||||
- Some dependencies are optional. Simply running `pip install -e .` will only install the minimum runtime requirements.
|
||||
To use optional dependencies, install them manually with `pip install -r requirements/optional.txt` or specify desired extras when calling `pip` (e.g. `pip install -e .[optional]`).
|
||||
Valid keys for the extras field are: `all`, `tests`, `build`, `optional`.
|
||||
- For CUDA 10.2, it is recommended to [install the official patches](https://developer.nvidia.com/cuda-10.2-download-archive?target_os=Linux&target_arch=x86_64&target_distro=Ubuntu&target_version=1804&target_type=runfilelocal); otherwise GEMM-related errors may occur when running models.
|
||||
|
||||
### Build SDK and Demo
|
||||
|
||||
|
@ -319,6 +319,7 @@ pip install -e .
|
||||
|
||||
- 有些依赖项是可选的。运行 `pip install -e .` 将进行最小化依赖安装。 如果需安装其他可选依赖项,请执行`pip install -r requirements/optional.txt`,
|
||||
或者 `pip install -e .[optional]`。其中,`[optional]`可以替换为:`all`、`tests`、`build` 或 `optional`。
|
||||
- cuda10 建议安装[补丁包](https://developer.nvidia.com/cuda-10.2-download-archive?target_os=Linux&target_arch=x86_64&target_distro=Ubuntu&target_version=1804&target_type=runfilelocal),否则模型运行可能出现 GEMM 相关错误
|
||||
|
||||
#### 编译 SDK 和 Demos
|
||||
|
||||
|
@ -1,5 +1,8 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from typing import Dict, Optional, Sequence, Union
|
||||
|
||||
import onnx
|
||||
@ -38,6 +41,55 @@ def load(path: str) -> trt.ICudaEngine:
|
||||
return engine
|
||||
|
||||
|
||||
def search_cuda_version() -> Optional[str]:
    """Find the CUDA version, trying shell commands first, then `torch`.

    The lookup order is `nvcc --version`, then `nvidia-smi`, and finally
    `torch.version.cuda`. Only a value that starts with `<major>.<minor>`
    is accepted, and only that `<major>.<minor>` part is returned, so
    callers can safely split it on `.` and convert the parts to `int`.

    Returns:
        Optional[str]: cuda version, for example `10.2`, or None if no
            valid version could be detected.
    """

    pattern = re.compile(r'[0-9]+\.[0-9]+')
    platform = sys.platform.lower()

    def cmd_result(txt: str) -> str:
        """Run a shell command and return its stripped stdout."""
        # `with` closes the pipe deterministically instead of leaking it
        # until garbage collection.
        with os.popen(txt) as cmd:
            return cmd.read().strip()

    version = None

    # `sys.platform` may carry a version suffix (e.g. `freebsd13`,
    # `linux2`), so compare by prefix rather than by equality.
    if platform.startswith(('linux', 'darwin', 'freebsd')):
        version = cmd_result(
            " nvcc --version | grep release | awk '{print $5}' | awk -F , '{print $1}' "  # noqa E501
        )
        if pattern.match(version) is None:
            version = cmd_result(
                " nvidia-smi | grep CUDA | awk '{print $9}' ")

    elif platform in ('win32', 'cygwin'):
        # nvcc_release = "Cuda compilation tools, release 10.2, V10.2.89"
        nvcc_release = cmd_result(' nvcc --version | find "release" ')
        result = pattern.findall(nvcc_release)
        if len(result) > 0:
            version = result[0]

        if version is None or pattern.match(version) is None:
            # nvidia_smi = "| NVIDIA-SMI 440.33.01 Driver Version: 440.33.01 CUDA Version: 10.2 |"  # noqa E501
            nvidia_smi = cmd_result(' nvidia-smi | find "CUDA Version" ')
            result = pattern.findall(nvidia_smi)
            # The first two matches come from the two driver version
            # fields; the third one is the CUDA version.
            if len(result) > 2:
                version = result[2]

    if version is None or pattern.match(version) is None:
        # Best-effort fallback: importing torch may fail for many reasons
        # (not installed, broken shared libraries, ...). None of them
        # should abort the lookup, but stale shell output must be dropped.
        try:
            import torch
            version = torch.version.cuda
        except Exception:
            version = None

    if version is None:
        return None

    # Never hand an empty or malformed string back to the caller: it would
    # pass an `is not None` check and then fail in `int(...)`.
    matched = pattern.match(version)
    return matched.group(0) if matched is not None else None
|
||||
|
||||
|
||||
def from_onnx(onnx_model: Union[str, onnx.ModelProto],
|
||||
output_file_prefix: str,
|
||||
input_shapes: Dict[str, Sequence[int]],
|
||||
@ -118,6 +170,16 @@ def from_onnx(onnx_model: Union[str, onnx.ModelProto],
|
||||
|
||||
config = builder.create_builder_config()
|
||||
config.max_workspace_size = max_workspace_size
|
||||
|
||||
cuda_version = search_cuda_version()
|
||||
if cuda_version is not None:
|
||||
version_major = int(cuda_version.split('.')[0])
|
||||
if version_major < 11:
|
||||
# cu11 support cublasLt, so cudnn heuristic tactic should disable CUBLAS_LT # noqa E501
|
||||
tactic_source = config.get_tactic_sources() - (
|
||||
1 << int(trt.TacticSource.CUBLAS_LT))
|
||||
config.set_tactic_sources(tactic_source)
|
||||
|
||||
profile = builder.create_optimization_profile()
|
||||
|
||||
for input_name, param in input_shapes.items():
|
||||
|
Loading…
x
Reference in New Issue
Block a user