mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
* minor changes * support windows * fix GCC build * fix lint * reformat * fix Windows build * fix GCC build * search backend ops for onnxruntime * fix lint * fix lint * code clean-up * code clean-up * fix clang build * fix trt support * fix cmake for ncnn * fix cmake for openvino * fix SDK Python API * handle ops for other backends (ncnn, trt) * handle SDK Python API library location * robustify linkage * fix cuda * minor fix for openvino & ncnn * use CMAKE_CUDA_ARCHITECTURES if set * fix cuda preprocessor * fix misc * fix pplnn & pplcv, drop support for pplcv<0.6.0 * robustify cmake * update build.md (#2) * build dynamic modules as module library & fix demo (partially) * fix candidate path for mmdeploy_python * move "enable CUDA" to cmake config for demo * refine demo cmake * add comment * fix ubuntu build * revert docs/en/build.md * fix C API * fix lint * Windows build doc (#3) * check in docs related to mmdeploy build on windows * update build guide on windows platform * update build guide on windows platform * make path of thirdparty libraries consistent * make path consistency * correct build command for custom ops * correct build command for sdk * update sdk build instructions * update doc * correct build command * fix lint * correct build command and fix lint Co-authored-by: lvhan <lvhan@pjlab.org> * trailing whitespace (#4) * minor fix * fix sr sdk model * fix type deduction * fix cudaFree after driver shutting down * update ppl.cv installation warning (#5) * fix device allocator threshold & fix lint * update doc (#6) * update ppl.cv installation warning * missing 'git clone' Co-authored-by: chenxin <chenxin2@sensetime.com> Co-authored-by: zhangli <zhangli@sensetime.com> Co-authored-by: lvhan028 <lvhan_028@163.com> Co-authored-by: lvhan <lvhan@pjlab.org>
70 lines
2.1 KiB
C++
70 lines
2.1 KiB
C++
// Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
#ifndef MMDEPLOY_SRC_DEVICE_CUDA_DEFAULT_ALLOCATOR_H_
|
|
#define MMDEPLOY_SRC_DEVICE_CUDA_DEFAULT_ALLOCATOR_H_
|
|
|
|
#include <cuda_runtime.h>
|
|
|
|
#include <atomic>
|
|
#include <chrono>
|
|
|
|
#include "core/logger.h"
|
|
|
|
namespace mmdeploy::cuda {
|
|
|
|
class DefaultAllocator {
|
|
public:
|
|
DefaultAllocator() = default;
|
|
~DefaultAllocator() {
|
|
MMDEPLOY_ERROR("=== CUDA Default Allocator ===");
|
|
MMDEPLOY_ERROR(" Allocation: count={}, size={}MB, time={}ms", alloc_count_,
|
|
alloc_size_ / (1024 * 1024.f), alloc_time_ / 1000000.f);
|
|
MMDEPLOY_ERROR("Deallocation: count={}, size={}MB, time={}ms", dealloc_count_,
|
|
dealloc_size_ / (1024 * 1024.f), dealloc_time_ / 1000000.f);
|
|
}
|
|
[[nodiscard]] void* Allocate(std::size_t n) {
|
|
void* p{};
|
|
auto t0 = std::chrono::high_resolution_clock::now();
|
|
auto ret = cudaMalloc(&p, n);
|
|
auto t1 = std::chrono::high_resolution_clock::now();
|
|
alloc_time_ += (int64_t)std::chrono::duration<double, std::nano>(t1 - t0).count();
|
|
if (ret != cudaSuccess) {
|
|
MMDEPLOY_ERROR("error allocating cuda memory: {}", cudaGetErrorString(ret));
|
|
return nullptr;
|
|
}
|
|
alloc_count_ += 1;
|
|
alloc_size_ += n;
|
|
return p;
|
|
}
|
|
void Deallocate(void* p, std::size_t n) {
|
|
(void)n;
|
|
auto t0 = std::chrono::high_resolution_clock::now();
|
|
auto ret = cudaFree(p);
|
|
auto t1 = std::chrono::high_resolution_clock::now();
|
|
dealloc_time_ += (int64_t)std::chrono::duration<double, std::nano>(t1 - t0).count();
|
|
if (ret != cudaSuccess) {
|
|
MMDEPLOY_ERROR("error deallocating cuda memory: {}", cudaGetErrorString(ret));
|
|
return;
|
|
}
|
|
dealloc_count_ += 1;
|
|
dealloc_size_ += n;
|
|
}
|
|
|
|
private:
|
|
std::atomic<std::size_t> alloc_count_;
|
|
std::atomic<std::size_t> alloc_size_;
|
|
std::atomic<std::size_t> alloc_time_;
|
|
std::atomic<std::size_t> dealloc_count_;
|
|
std::atomic<std::size_t> dealloc_size_;
|
|
std::atomic<std::size_t> dealloc_time_;
|
|
};
|
|
|
|
/// Returns the process-wide DefaultAllocator instance.
/// The instance is a function-local static: it is constructed on first call
/// (initialization is thread-safe since C++11) and destroyed at program exit.
inline DefaultAllocator& gDefaultAllocator() {
  static DefaultAllocator instance;
  return instance;
}
|
|
|
|
} // namespace mmdeploy::cuda
|
|
|
|
#endif // MMDEPLOY_SRC_DEVICE_CUDA_DEFAULT_ALLOCATOR_H_
|