Refactor for wgpu v23 compatibility with an example of wgpu device sharing #211

Status: Draft, wants to merge 18 commits into base: main

Commits (18)
a3c4e33  Refactor for wgpu v23 compatibility (AsherJingkongChen, Oct 29, 2024)
7f689c0  Add example of using existing wgpu device (AsherJingkongChen, Oct 29, 2024)
13dd59d  Format (AsherJingkongChen, Oct 29, 2024)
bfa5032  Patch changes for SPIR-V compiler (AsherJingkongChen, Oct 29, 2024)
dd04e6c  Make WgpuDevice::Existing unique (AsherJingkongChen, Oct 29, 2024)
e676ccf  Merge commit '1226222ad8be9ac1701d4b40b4c454ffa6d3ee20' into refactor… (AsherJingkongChen, Oct 29, 2024)
3833348  Adjust namings and comments (AsherJingkongChen, Oct 29, 2024)
567e9c3  Merge branch 'refactor/cubecl-wgpu/naming' into refactor/wgpu-v23 (AsherJingkongChen, Oct 29, 2024)
c1203d7  Adjust example for wgpu runtime updates (AsherJingkongChen, Oct 29, 2024)
ddffbb6  Merge branch 'main' into refactor/wgpu-v23 (AsherJingkongChen, Oct 30, 2024)
c8ff1e1  Update wgpu 23.0.0 (AsherJingkongChen, Oct 30, 2024)
038bed6  Fix doc links (AsherJingkongChen, Oct 31, 2024)
57f21ab  Merge remote-tracking branch 'origin/main' into refactor/wgpu-v23 (AsherJingkongChen, Oct 31, 2024)
c6ec950  Merge commit 'd85d503895cadf773bc25a65ce515aee15d17a33' into refactor… (AsherJingkongChen, Nov 12, 2024)
ba78b17  Add wgpu spirv required features (AsherJingkongChen, Nov 12, 2024)
e04d19e  Merge remote-tracking branch 'origin/main' into refactor/wgpu-v23 (AsherJingkongChen, Nov 25, 2024)
1f13590  Merge remote-tracking branch 'origin/main' into refactor/wgpu-v23 (AsherJingkongChen, Nov 28, 2024)
8636b60  Merge branch 'main' of https://github.com/tracel-ai/cubecl into refac… (AsherJingkongChen, Nov 30, 2024)
7 changes: 5 additions & 2 deletions crates/cubecl-wgpu/Cargo.toml
@@ -15,9 +15,10 @@ default = [
     "cubecl-runtime/default",
     "cubecl-common/default",
     "cubecl-core/default",
+    "spirv",
 ]
 exclusive-memory-only = ["cubecl-runtime/exclusive-memory-only"]
-spirv = ["cubecl-spirv", "ash"]
+spirv = ["cubecl-spirv", "ash", "wgpu/spirv", "wgpu-core/vulkan"]
 std = ["cubecl-runtime/std", "cubecl-common/std", "cubecl-core/std"]

 spirv-dump = ["sanitize-filename"]
@@ -35,7 +36,9 @@ ash = { version = "0.38", optional = true }
 cubecl-spirv = { path = "../cubecl-spirv", version = "0.4.0", optional = true }

 bytemuck = { workspace = true }
-wgpu = { version = "22.0.0", features = ["fragile-send-sync-non-atomic-wasm"] }
+wgpu = { version = "23.0.0", features = ["fragile-send-sync-non-atomic-wasm"] }
+wgpu-core = { version = "23.0.0" }
+wgpu-hal = { version = "23.0.0" }

 async-channel = { workspace = true }
 derive-new = { workspace = true }
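Presumably the new direct wgpu-core and wgpu-hal dependencies, together with the forwarded wgpu/spirv and wgpu-core/vulkan features, are there so the SPIR-V path can keep reaching the raw Vulkan device and feed SPIR-V modules through wgpu (see the spirv.rs changes below); this is my inference, not stated in the PR.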
4 changes: 2 additions & 2 deletions crates/cubecl-wgpu/src/compiler/spirv.rs
@@ -126,7 +126,7 @@ impl WgpuCompiler for SpirvCompiler<GLCompute> {
             label: None,
             layout: layout.as_ref(),
             module: &module,
-            entry_point: &kernel.entrypoint_name,
+            entry_point: Some(&kernel.entrypoint_name),
             compilation_options: wgpu::PipelineCompilationOptions {
                 zero_initialize_workgroup_memory: false,
                 ..Default::default()
@@ -280,7 +280,7 @@ fn request_device(
     adapter
         .device_from_raw(
             vk_device,
-            true,
+            None,
             &device_extensions,
             features,
             &wgpu::MemoryHints::MemoryUsage,

Review comment by AsherJingkongChen (Contributor, Author) on the `None` argument:
I think I am doing the equivalent change here.
2 changes: 1 addition & 1 deletion crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs
@@ -158,7 +158,7 @@ impl WgpuCompiler for WgslCompiler {
             label: None,
             layout: layout.as_ref(),
             module: &module,
-            entry_point: &kernel.entrypoint_name,
+            entry_point: Some(&kernel.entrypoint_name),
             compilation_options: wgpu::PipelineCompilationOptions {
                 zero_initialize_workgroup_memory: false,
                 ..Default::default()
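Context for the two entry_point changes above: my reading of the wgpu 23 release notes is that the pipeline descriptors now take entry_point: Option<&str>, where None selects a module's sole entry point, so existing call sites wrap the name in Some(...).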
6 changes: 4 additions & 2 deletions crates/cubecl-wgpu/src/device.rs
@@ -41,8 +41,10 @@ pub enum WgpuDevice {
     /// Use an externally created, existing, wgpu setup. This is helpful when using CubeCL in conjunction
     /// with some existing wgpu setup (eg. egui or bevy), as resources can be transferred in & out of CubeCL.
     ///
-    /// The device is indexed by the global wgpu [adapter ID](wgpu::Device::global_id).
-    Existing(wgpu::Id<wgpu::Device>),
+    /// # Notes
+    ///
+    /// This can be initialized with [`init_device`](crate::runtime::init_device).
+    Existing(u32),
 }

 impl Default for WgpuDevice {
18 changes: 17 additions & 1 deletion crates/cubecl-wgpu/src/runtime.rs
@@ -109,8 +109,24 @@ pub struct WgpuSetup {

 /// Create a [`WgpuDevice`] on an existing [`WgpuSetup`].
 /// Useful when you want to share a device between CubeCL and other wgpu-dependent libraries.
+///
+/// # Note
+///
+/// Please **do not** call this on the same [`setup`](WgpuSetup) more than once.
+///
+/// This function generates a new, globally unique ID for the device every time it is called,
+/// even if called on the same device multiple times.
 pub fn init_device(setup: WgpuSetup, options: RuntimeOptions) -> WgpuDevice {
-    let device_id = WgpuDevice::Existing(setup.device.as_ref().global_id());
+    use core::sync::atomic::{AtomicU32, Ordering};
+
+    static COUNTER: AtomicU32 = AtomicU32::new(0);
+
+    let device_id = COUNTER.fetch_add(1, Ordering::Relaxed);
+    if device_id == u32::MAX {
+        core::panic!("Memory ID overflowed");
+    }
+
+    let device_id = WgpuDevice::Existing(device_id);
     let client = create_client_on_setup(setup, options);
     RUNTIME.register(&device_id, client);
     device_id
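Given the warning above, a caller sharing one setup across several subsystems should stash the returned handle rather than calling init_device twice. A minimal caller-side sketch, not part of this PR (the SHARED_DEVICE/shared_device names are mine; the imports assume the cubecl::wgpu re-exports used in the example below):

use std::sync::OnceLock;

use cubecl::wgpu::{init_device, WgpuDevice, WgpuSetup};

static SHARED_DEVICE: OnceLock<WgpuDevice> = OnceLock::new();

// Register the shared setup exactly once; later calls reuse the same ID.
fn shared_device(setup: WgpuSetup) -> &'static WgpuDevice {
    SHARED_DEVICE.get_or_init(|| init_device(setup, Default::default()))
}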
19 changes: 19 additions & 0 deletions examples/device_sharing/Cargo.toml
@@ -0,0 +1,19 @@
[package]
authors = []
edition.workspace = true
license.workspace = true
name = "device_sharing"
publish = false
version.workspace = true

[features]
default = []
wgpu = ["cubecl/wgpu"]
cuda = ["cubecl/cuda"]

[dependencies]
cubecl = { path = "../../crates/cubecl", version = "0.4.0" }
half = { workspace = true }

sum_things = { path = "../sum_things" }
wgpu = { version = "23.0.0", features = ["fragile-send-sync-non-atomic-wasm"] }
9 changes: 9 additions & 0 deletions examples/device_sharing/examples/device_sharing.rs
@@ -0,0 +1,9 @@
fn main() {
    #[cfg(feature = "wgpu")]
    {
        let setup_shared = device_sharing::create_wgpu_setup_from_raw();
        let device_cubecl = cubecl::wgpu::init_device(setup_shared.clone(), Default::default());
        device_sharing::assert_wgpu_device_existing(&device_cubecl);
        sum_things::launch::<cubecl::wgpu::WgpuRuntime>(&device_cubecl);
    }
}
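The PR does not state how to run this, but from the package and feature names in the example's Cargo.toml above, something like `cargo run -p device_sharing --example device_sharing --features wgpu` from the workspace root should work.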
45 changes: 45 additions & 0 deletions examples/device_sharing/src/lib.rs
@@ -0,0 +1,45 @@
#[cfg(feature = "wgpu")]
mod device_sharing_wgpu {
    use cubecl::wgpu::{WgpuDevice, WgpuSetup};

    pub fn create_wgpu_setup_from_raw() -> WgpuSetup {
        cubecl::future::block_on(create_wgpu_setup_from_raw_async())
    }

    pub async fn create_wgpu_setup_from_raw_async() -> WgpuSetup {
        let instance = wgpu::Instance::default();
        let adapter = instance
            .request_adapter(&Default::default())
            .await
            .expect("Failed to create wgpu adapter from instance");
        let (device, queue) = adapter
            .request_device(
                &wgpu::DeviceDescriptor {
                    label: Some("Raw"),
                    required_features: adapter.features(),
                    required_limits: adapter.limits(),
                    memory_hints: wgpu::MemoryHints::MemoryUsage,
                },
                None,
            )
            .await
            .expect("Failed to create wgpu device from adapter");

        WgpuSetup {
            instance: instance.into(),
            adapter: adapter.into(),
            device: device.into(),
            queue: queue.into(),
        }
    }

    pub fn assert_wgpu_device_existing(device: &WgpuDevice) {
        assert!(
            matches!(device, cubecl::wgpu::WgpuDevice::Existing(_)),
            "device should be WgpuDevice::Existing"
        );
    }
}

#[cfg(feature = "wgpu")]
pub use device_sharing_wgpu::*;
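The helper above requests brand-new wgpu objects for demonstration, but the same WgpuSetup wrapping applies to handles an application already owns (the egui/bevy scenario from the WgpuDevice::Existing docs). A hypothetical variant under that assumption; wrap_existing is my name, and the .into() conversions mirror the ones above:

use cubecl::wgpu::WgpuSetup;

// Hypothetical helper (not part of this PR): wrap an app's existing wgpu
// handles so CubeCL can share them; each `.into()` moves the handle into
// the reference-counted form that `WgpuSetup`'s fields expect.
pub fn wrap_existing(
    instance: wgpu::Instance,
    adapter: wgpu::Adapter,
    device: wgpu::Device,
    queue: wgpu::Queue,
) -> WgpuSetup {
    WgpuSetup {
        instance: instance.into(),
        adapter: adapter.into(),
        device: device.into(),
        queue: queue.into(),
    }
}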