diff --git a/crates/cubecl-wgpu/Cargo.toml b/crates/cubecl-wgpu/Cargo.toml index 944e900f3..da00cde61 100644 --- a/crates/cubecl-wgpu/Cargo.toml +++ b/crates/cubecl-wgpu/Cargo.toml @@ -15,9 +15,10 @@ default = [ "cubecl-runtime/default", "cubecl-common/default", "cubecl-core/default", + "spirv", ] exclusive-memory-only = ["cubecl-runtime/exclusive-memory-only"] -spirv = ["cubecl-spirv", "ash"] +spirv = ["cubecl-spirv", "ash", "wgpu/spirv", "wgpu-core/vulkan"] std = ["cubecl-runtime/std", "cubecl-common/std", "cubecl-core/std"] spirv-dump = ["sanitize-filename"] @@ -35,7 +36,9 @@ ash = { version = "0.38", optional = true } cubecl-spirv = { path = "../cubecl-spirv", version = "0.4.0", optional = true } bytemuck = { workspace = true } -wgpu = { version = "22.0.0", features = ["fragile-send-sync-non-atomic-wasm"] } +wgpu = { version = "23.0.0", features = ["fragile-send-sync-non-atomic-wasm"] } +wgpu-core = { version = "23.0.0" } +wgpu-hal = { version = "23.0.0" } async-channel = { workspace = true } derive-new = { workspace = true } diff --git a/crates/cubecl-wgpu/src/compiler/spirv.rs b/crates/cubecl-wgpu/src/compiler/spirv.rs index cec5bc7a8..40200c574 100644 --- a/crates/cubecl-wgpu/src/compiler/spirv.rs +++ b/crates/cubecl-wgpu/src/compiler/spirv.rs @@ -126,7 +126,7 @@ impl WgpuCompiler for SpirvCompiler { label: None, layout: layout.as_ref(), module: &module, - entry_point: &kernel.entrypoint_name, + entry_point: Some(&kernel.entrypoint_name), compilation_options: wgpu::PipelineCompilationOptions { zero_initialize_workgroup_memory: false, ..Default::default() @@ -280,7 +280,7 @@ fn request_device( adapter .device_from_raw( vk_device, - true, + None, &device_extensions, features, &wgpu::MemoryHints::MemoryUsage, diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index 1be50e5c5..7af4ca699 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs 
@@ -158,7 +158,7 @@ impl WgpuCompiler for WgslCompiler { label: None, layout: layout.as_ref(), module: &module, - entry_point: &kernel.entrypoint_name, + entry_point: Some(&kernel.entrypoint_name), compilation_options: wgpu::PipelineCompilationOptions { zero_initialize_workgroup_memory: false, ..Default::default() } diff --git a/crates/cubecl-wgpu/src/device.rs b/crates/cubecl-wgpu/src/device.rs index a878ad8cb..b268f9a27 100644 --- a/crates/cubecl-wgpu/src/device.rs +++ b/crates/cubecl-wgpu/src/device.rs @@ -41,8 +41,10 @@ pub enum WgpuDevice { /// Use an externally created, existing, wgpu setup. This is helpful when using CubeCL in conjunction /// with some existing wgpu setup (eg. egui or bevy), as resources can be transferred in & out of CubeCL. /// - /// The device is indexed by the global wgpu [adapter ID](wgpu::Device::global_id). - Existing(wgpu::Id), + /// # Notes + /// + /// This can be initialized with [`init_device`](crate::runtime::init_device). + Existing(u32), } impl Default for WgpuDevice { diff --git a/crates/cubecl-wgpu/src/runtime.rs b/crates/cubecl-wgpu/src/runtime.rs index 26208007e..550306ff2 100644 --- a/crates/cubecl-wgpu/src/runtime.rs +++ b/crates/cubecl-wgpu/src/runtime.rs @@ -109,8 +109,24 @@ pub struct WgpuSetup { /// Create a [`WgpuDevice`] on an existing [`WgpuSetup`]. /// Useful when you want to share a device between CubeCL and other wgpu-dependent libraries. +/// +/// # Note +/// +/// Please **do not** call this on the same [`setup`](WgpuSetup) more than once. +/// +/// This function generates a new, globally unique ID for the device every time it is called, +/// even if called on the same device multiple times.
pub fn init_device(setup: WgpuSetup, options: RuntimeOptions) -> WgpuDevice { - let device_id = WgpuDevice::Existing(setup.device.as_ref().global_id()); + use core::sync::atomic::{AtomicU32, Ordering}; + + static COUNTER: AtomicU32 = AtomicU32::new(0); + + let device_id = COUNTER.fetch_add(1, Ordering::Relaxed); + if device_id == u32::MAX { + core::panic!("Memory ID overflowed"); + } + + let device_id = WgpuDevice::Existing(device_id); let client = create_client_on_setup(setup, options); RUNTIME.register(&device_id, client); device_id diff --git a/examples/device_sharing/Cargo.toml b/examples/device_sharing/Cargo.toml new file mode 100644 index 000000000..fb8c135cf --- /dev/null +++ b/examples/device_sharing/Cargo.toml @@ -0,0 +1,19 @@ +[package] +authors = [] +edition.workspace = true +license.workspace = true +name = "device_sharing" +publish = false +version.workspace = true + +[features] +default = [] +wgpu = ["cubecl/wgpu"] +cuda = ["cubecl/cuda"] + +[dependencies] +cubecl = { path = "../../crates/cubecl", version = "0.4.0" } +half = { workspace = true } + +sum_things = { path = "../sum_things" } +wgpu = { version = "23.0.0", features = ["fragile-send-sync-non-atomic-wasm"] } diff --git a/examples/device_sharing/examples/device_sharing.rs b/examples/device_sharing/examples/device_sharing.rs new file mode 100644 index 000000000..3b76446ce --- /dev/null +++ b/examples/device_sharing/examples/device_sharing.rs @@ -0,0 +1,9 @@ +fn main() { + #[cfg(feature = "wgpu")] + { + let setup_shared = device_sharing::create_wgpu_setup_from_raw(); + let device_cubecl = cubecl::wgpu::init_device(setup_shared.clone(), Default::default()); + device_sharing::assert_wgpu_device_existing(&device_cubecl); + sum_things::launch::<cubecl::wgpu::WgpuRuntime>(&device_cubecl); + } +} diff --git a/examples/device_sharing/src/lib.rs b/examples/device_sharing/src/lib.rs new file mode 100644 index 000000000..f789df7b5 --- /dev/null +++ b/examples/device_sharing/src/lib.rs @@ -0,0 +1,45 @@ +#[cfg(feature = "wgpu")] +mod 
device_sharing_wgpu { + use cubecl::wgpu::{WgpuDevice, WgpuSetup}; + + pub fn create_wgpu_setup_from_raw() -> WgpuSetup { + cubecl::future::block_on(create_wgpu_setup_from_raw_async()) + } + + pub async fn create_wgpu_setup_from_raw_async() -> WgpuSetup { + let instance = wgpu::Instance::default(); + let adapter = instance + .request_adapter(&Default::default()) + .await + .expect("Failed to create wgpu adapter from instance"); + let (device, queue) = adapter + .request_device( + &wgpu::DeviceDescriptor { + label: Some("Raw"), + required_features: adapter.features(), + required_limits: adapter.limits(), + memory_hints: wgpu::MemoryHints::MemoryUsage, + }, + None, + ) + .await + .expect("Failed to create wgpu device from adapter"); + + WgpuSetup { + instance: instance.into(), + adapter: adapter.into(), + device: device.into(), + queue: queue.into(), + } + } + + pub fn assert_wgpu_device_existing(device: &WgpuDevice) { + assert!( + matches!(device, cubecl::wgpu::WgpuDevice::Existing(_)), + "device should be WgpuDevice::Existing" + ); + } +} + +#[cfg(feature = "wgpu")] +pub use device_sharing_wgpu::*;