From 97895129e5f2bde94d13dc01ca41ee79e9b629f2 Mon Sep 17 00:00:00 2001 From: leonardHONG <2695316095@qq.com> Date: Tue, 21 Apr 2026 05:30:38 +0800 Subject: [PATCH] ggml-cuda: flush legacy pool on OOM and retry (#22155) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml-cuda: flush legacy pool on OOM and retry Signed-off-by: 梁厚宏 <2695316095@qq.com> * Address review comments: add explicit sync, update destructor, clean up MUSA macros Signed-off-by: 梁厚宏 <2695316095@qq.com> --------- Signed-off-by: 梁厚宏 <2695316095@qq.com> --- ggml/src/ggml-cuda/ggml-cuda.cu | 23 +++++++++++++++++++++-- ggml/src/ggml-cuda/vendors/hip.h | 1 + ggml/src/ggml-cuda/vendors/musa.h | 1 + 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index ecd12b80df..185956317e 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -368,15 +368,21 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { } ~ggml_cuda_pool_leg() { + clear_pool(); + GGML_ASSERT(pool_size == 0); + } + + void clear_pool() { ggml_cuda_set_device(device); for (int i = 0; i < MAX_BUFFERS; ++i) { ggml_cuda_buffer & b = buffer_pool[i]; if (b.ptr != nullptr) { CUDA_CHECK(cudaFree(b.ptr)); pool_size -= b.size; + b.ptr = nullptr; + b.size = 0; } } - GGML_ASSERT(pool_size == 0); } void * alloc(size_t size, size_t * actual_size) override { @@ -421,7 +427,20 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { size_t look_ahead_size = (size_t) (1.05 * size); look_ahead_size = 256 * ((look_ahead_size + 255)/256); ggml_cuda_set_device(device); - CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device)); + cudaError_t err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device); + if (err == cudaErrorMemoryAllocation) { + (void)cudaGetLastError(); + const size_t cached_bytes = pool_size; + GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: alloc of %.2f MiB failed, flushing %.2f MiB of cached buffers and retrying\n", + device, look_ahead_size/1024.0/1024.0, cached_bytes/1024.0/1024.0); + CUDA_CHECK(cudaDeviceSynchronize()); + clear_pool(); + err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device); + if (err == cudaSuccess) { + GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: retry succeeded\n", device); + } + } + CUDA_CHECK(err); *actual_size = look_ahead_size; pool_size += look_ahead_size; #ifdef DEBUG_CUDA_MALLOC diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 52c38908e0..78ca364d38 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -58,6 +58,7 @@ #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceSynchronize hipDeviceSynchronize #define cudaError_t hipError_t +#define cudaErrorMemoryAllocation hipErrorOutOfMemory #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled #define cudaEventCreateWithFlags hipEventCreateWithFlags diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 1abb8acfd4..8aa056e917 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -42,6 +42,7 @@ #define cudaDeviceProp musaDeviceProp #define cudaDeviceSynchronize musaDeviceSynchronize #define cudaError_t musaError_t +#define cudaErrorMemoryAllocation musaErrorMemoryAllocation #define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled #define cudaEventCreateWithFlags musaEventCreateWithFlags