From a61d605021176a5714551a9dc610000ca02a07a6 Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 26 Aug 2024 22:39:28 +0200 Subject: [PATCH] Fix wgpu memory corruption --- crates/cubecl-wgpu/src/compute/server.rs | 43 ++++++++++++++++++------ 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/crates/cubecl-wgpu/src/compute/server.rs b/crates/cubecl-wgpu/src/compute/server.rs index 28146ed6e..fa8a87a6d 100644 --- a/crates/cubecl-wgpu/src/compute/server.rs +++ b/crates/cubecl-wgpu/src/compute/server.rs @@ -6,7 +6,7 @@ use cubecl_common::{reader::Reader, sync_type::SyncType}; use cubecl_core::{compute::DebugInformation, prelude::*, server::Handle, FeatureSet, KernelId}; use cubecl_runtime::{ debug::DebugLogger, - memory_management::MemoryManagement, + memory_management::{MemoryHandle, MemoryManagement}, server::{self, ComputeServer}, storage::{ComputeStorage, StorageId}, ExecutionMode, @@ -24,6 +24,7 @@ pub struct WgpuServer> { current_pass: Option>, tasks_count: usize, compute_storage_used: Vec, + copy_handles_used: Vec<(StorageId, u32)>, pipelines: HashMap>, tasks_max: usize, logger: DebugLogger, @@ -54,6 +55,7 @@ where current_pass: None, tasks_count: 0, compute_storage_used: Vec::new(), + copy_handles_used: Vec::new(), pipelines: HashMap::new(), tasks_max, logger: DebugLogger::new(), @@ -188,8 +190,7 @@ where &mut self, binding: server::Binding, ) -> ::Resource { - let handle = self.memory_management.get(binding.memory); - self.memory_management.storage().get(&handle) + self.memory_management.get_resource(binding.memory) } /// When we create a new handle from existing data, we use custom allocations so that we don't @@ -198,15 +199,24 @@ where /// This is important, otherwise the compute passes are going to be too small and we won't be able to /// fully utilize the GPU. fn create(&mut self, data: &[u8]) -> server::Handle { - // Reserve memory on some storage we haven't yet used this command queue. 
- let memory = self - .memory_management - .reserve(data.len(), &self.compute_storage_used); - - let handle = Handle::new(memory); + // Reserve memory on some storage we haven't yet used this command queue for compute + // or copying. + let total_handles = self + .compute_storage_used + .iter() + .copied() + .chain(self.copy_handles_used.iter().map(|x| x.0)) + .collect::>(); + let memory = self.memory_management.reserve(data.len(), &total_handles); if let Some(len) = NonZero::new(data.len() as u64) { - let resource = self.get_resource(handle.clone().binding()); + let resource_handle = self.memory_management.get(memory.clone().binding()); + + // Don't re-use this handle for writing until the queue is flushed. All writes + // would happen at the start of the submission. + self.copy_handles_used.push((resource_handle.id, 0)); + + let resource = self.memory_management.storage().get(&resource_handle); // Write to the staging buffer. Next queue submission this will copy the data to the GPU. self.queue @@ -215,7 +225,7 @@ where .copy_from_slice(data); } - handle + Handle::new(memory) } fn empty(&mut self, size: usize) -> server::Handle { @@ -309,6 +319,17 @@ where self.tasks_count = 0; self.compute_storage_used.clear(); + self.copy_handles_used.retain_mut(|x| { + // For some unknown reason, we have to make sure + // a buffer isn't used more than once not just in the current + // submission, but also in the next one. + // + // This really needs a better explanation of why this is, or + // some investigation, maybe it's a wgpu bug. + x.1 += 1; + x.1 < 2 + }); + if sync_type == SyncType::Wait { self.device.poll(wgpu::Maintain::Wait); }