From 3ada2873e1b7563f9f6eb50200b1a0bf00805741 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Mon, 12 Aug 2024 13:12:27 -0400 Subject: [PATCH] Fix: CUDA segfault when slice ptr is dropped before executed --- crates/cubecl-cuda/src/compute/server.rs | 3 +-- crates/cubecl-cuda/src/compute/storage.rs | 22 +++------------------- crates/cubecl-cuda/src/runtime.rs | 2 +- 3 files changed, 5 insertions(+), 22 deletions(-) diff --git a/crates/cubecl-cuda/src/compute/server.rs b/crates/cubecl-cuda/src/compute/server.rs index 6ea4d209..6f824e20 100644 --- a/crates/cubecl-cuda/src/compute/server.rs +++ b/crates/cubecl-cuda/src/compute/server.rs @@ -190,6 +190,7 @@ impl> CudaContext { unsafe { cudarc::driver::result::stream::synchronize(self.stream).unwrap(); }; + self.memory_management.storage().flush(); } fn compile_kernel( @@ -279,8 +280,6 @@ impl> CudaContext { ) .unwrap(); }; - - self.memory_management.storage().flush(resources) } } diff --git a/crates/cubecl-cuda/src/compute/storage.rs b/crates/cubecl-cuda/src/compute/storage.rs index de7d1521..c59a3a0f 100644 --- a/crates/cubecl-cuda/src/compute/storage.rs +++ b/crates/cubecl-cuda/src/compute/storage.rs @@ -8,7 +8,6 @@ pub struct CudaStorage { deallocations: Vec, stream: cudarc::driver::sys::CUstream, activate_slices: HashMap, - activate_slices_count: HashMap, } #[derive(new, Debug, Hash, PartialEq, Eq, Clone)] @@ -34,7 +33,6 @@ impl CudaStorage { deallocations: Vec::new(), stream, activate_slices: HashMap::new(), - activate_slices_count: HashMap::new(), } } @@ -49,17 +47,8 @@ impl CudaStorage { } } - pub fn flush(&mut self, resources: Vec) { - for resource in resources { - let key = ActiveResource::new(resource.ptr, resource.kind); - if let Some(count) = self.activate_slices_count.remove(&key) { - if count == 1 { - self.activate_slices.remove(&key); - } else { - self.activate_slices_count.insert(key, count - 1); - } - } - } + pub fn flush(&mut self) { + self.activate_slices.clear(); } } @@ -129,12 +118,7 @@ impl ComputeStorage for CudaStorage { let kind = CudaResourceKind::Slice { size, offset }; let key = ActiveResource::new(ptr, kind.clone()); - if let Some(count) = self.activate_slices_count.get_mut(&key) { - *count += 1; - } else { - self.activate_slices.insert(key.clone(), ptr); - self.activate_slices_count.insert(key.clone(), 1); - } + self.activate_slices.insert(key.clone(), ptr); // The ptr needs to stay alive until we send the task to the server. let ptr = self.activate_slices.get(&key).unwrap(); diff --git a/crates/cubecl-cuda/src/runtime.rs b/crates/cubecl-cuda/src/runtime.rs index 1a5cf7dc..4c5caa78 100644 --- a/crates/cubecl-cuda/src/runtime.rs +++ b/crates/cubecl-cuda/src/runtime.rs @@ -47,7 +47,7 @@ impl Runtime for CudaRuntime { ) .unwrap(); let storage = CudaStorage::new(stream); - let options = DynamicMemoryManagementOptions::preset(2048 * 1024 * 1024, 32); + let options = DynamicMemoryManagementOptions::preset(2048 + 512 * 1024 * 1024, 32); let memory_management = DynamicMemoryManagement::new(storage, options); CudaContext::new(memory_management, stream, ctx) }