From 8c92ca86e16169dab38cb2649af745492a42cf01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?= Date: Wed, 17 Jul 2024 16:37:10 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=9F=A5=E8=AF=A2=E5=89=A9?= =?UTF-8?q?=E4=BD=99=E6=98=BE=E5=AD=98=E7=9A=84=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/devices/cuda/fastllm-cuda.cuh | 2 ++ src/devices/cuda/fastllm-cuda.cu | 31 +++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/include/devices/cuda/fastllm-cuda.cuh b/include/devices/cuda/fastllm-cuda.cuh index 0a30742b..080f547b 100644 --- a/include/devices/cuda/fastllm-cuda.cuh +++ b/include/devices/cuda/fastllm-cuda.cuh @@ -5,6 +5,8 @@ extern "C" { #endif void FastllmInitCublas(void); +std::vector FastllmCudaGetFreeSizes(); + void FastllmCudaMallocBigBuffer(size_t size); void FastllmCudaClearBigBuffer(); void *FastllmCudaMalloc(size_t size); diff --git a/src/devices/cuda/fastllm-cuda.cu b/src/devices/cuda/fastllm-cuda.cu index 0d6e9070..bfb3e09c 100644 --- a/src/devices/cuda/fastllm-cuda.cu +++ b/src/devices/cuda/fastllm-cuda.cu @@ -110,6 +110,37 @@ cublasHandle_t getFastllmCublasHandle() { return handler; } +std::vector FastllmCudaGetFreeSizes() { + int deviceCount; + auto error = cudaGetDeviceCount(&deviceCount); + if (error != cudaSuccess) { + printf("cudaGetDeviceCount returned %d\n-> %s\n", (int)error, cudaGetErrorString(error)); + return {}; + } + std::vector ret; + + // 遍历所有设备 + for (int i = 0; i < deviceCount; ++i) { + cudaDeviceProp prop; + error = cudaGetDeviceProperties(&prop, i); + if (error == cudaSuccess) { + // printf("Device %d: \"%s\"\n", i, prop.name); + // printf(" Compute capability: %d.%d\n", prop.major, prop.minor); + // printf(" Total global memory: %zu bytes\n", prop.totalGlobalMem); + + // 获取当前设备的显存使用情况 + size_t free = 0, total = 0; + cudaMemGetInfo(&free, &total); + ret.push_back(free); + // printf(" Free memory: %zu bytes\n", free); + // printf(" Remaining memory: %zu bytes\n", total - free); + } else { + printf("cudaGetDeviceProperties returned %d\n-> %s\n", (int)error, cudaGetErrorString(error)); + } + } + return ret; +} + __global__ void GetCudaInfoKernel(int *infos) { #if defined(__CUDA_ARCH__) infos[0] = __CUDA_ARCH__;