Skip to content

Commit

Permalink
Use simple memory management with wasm (#81)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Aug 25, 2024
1 parent 4572305 commit d41dd0c
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 25 deletions.
5 changes: 5 additions & 0 deletions crates/cubecl-core/src/codegen/integrator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,11 @@ impl KernelIntegrator {
let output = match self.expansion.outputs.get_mut(mapping.pos_output) {
Some(output) => output,
None => {
if let Some(binding) = self.input_bindings.get_mut(mapping.pos_input) {
// Update input visibility.
binding.visibility = Visibility::ReadWrite;
}

// The mapping is handled differently, normally by cube itself.
return;
}
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-core/tests/error/for_loop_range.stderr
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
error: Invalid for loop: use [range](cubecl::prelude::range] instead.
error: Invalid for loop: use [range](cubecl::prelude::range] or [range_stepped](cubecl::prelude::range_stepped) instead.
--> tests/error/for_loop_range.rs:6:14
|
6 | for _ in 0..10 {}
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-macros/src/codegen_function/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ use super::{

/// Codegen of for loops
/// Supports range:
/// ```norun
/// ```ignore
/// for i in range(start, end, unroll) {...}
/// ```
/// and range_stepped:
/// ```norun
/// ```ignore
/// for i in range_stepped(start, end, step, unroll) {...}
/// ```
pub(crate) fn codegen_for_loop(
Expand Down
4 changes: 4 additions & 0 deletions crates/cubecl-wgpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ default = [
"cubecl-core/default",
]
std = ["cubecl-runtime/std", "cubecl-common/std", "cubecl-core/std"]
simple-memory-management = []

[dependencies]
cubecl-runtime = { path = "../cubecl-runtime", version = "0.1.1", default-features = false, features = [
Expand All @@ -41,3 +42,6 @@ cubecl-core = { path = "../cubecl-core", version = "0.1.1", features = [
cubecl-linalg = { path = "../cubecl-linalg", version = "0.1.1", features = [
"export_tests",
] }

[build-dependencies]
cfg_aliases = "0.2.1"
8 changes: 8 additions & 0 deletions crates/cubecl-wgpu/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
use cfg_aliases::cfg_aliases;

fn main() {
// Setup cfg aliases
cfg_aliases! {
simple_memory_management: { any(feature = "simple-memory-management", target_family = "wasm") },
}
}
5 changes: 5 additions & 0 deletions crates/cubecl-wgpu/src/compiler/wgsl/shader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,12 @@ impl Display for Location {
impl Display for Visibility {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
// With the dynamic memory strategy we have to put everything read_write.
#[cfg(not(simple_memory_management))]
Visibility::Read => f.write_str("read_write"),
// With the simple memory strategy we can use the correct visibility.
#[cfg(simple_memory_management)]
Visibility::Read => f.write_str("read"),
Visibility::ReadWrite => f.write_str("read_write"),
}
}
Expand Down
60 changes: 38 additions & 22 deletions crates/cubecl-wgpu/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,9 @@ use crate::{
};
use alloc::sync::Arc;
use cubecl_core::{Feature, FeatureSet, Runtime};
use cubecl_runtime::{
channel::MutexComputeChannel,
client::ComputeClient,
memory_management::dynamic::{DynamicMemoryManagement, DynamicMemoryManagementOptions},
ComputeRuntime,
};
use wgpu::DeviceDescriptor;
use cubecl_runtime::memory_management;
use cubecl_runtime::{channel::MutexComputeChannel, client::ComputeClient, ComputeRuntime};
use wgpu::{DeviceDescriptor, Limits};

/// Runtime that uses the [wgpu] crate with the wgsl compiler. This is used in the Wgpu backend.
/// For advanced configuration, use [`init_sync`] to pass in runtime options or to select a
Expand All @@ -23,13 +19,42 @@ pub struct WgpuRuntime;
static RUNTIME: ComputeRuntime<WgpuDevice, Server, MutexComputeChannel<Server>> =
ComputeRuntime::new();

type Server = WgpuServer<DynamicMemoryManagement<WgpuStorage>>;
type Server = WgpuServer<MemoryManagement>;

#[cfg(not(simple_memory_management))]
type MemoryManagement = memory_management::dynamic::DynamicMemoryManagement<WgpuStorage>;
#[cfg(simple_memory_management)]
type MemoryManagement = memory_management::simple::SimpleMemoryManagement<WgpuStorage>;

#[cfg(not(simple_memory_management))]
fn init_memory_management(device: Arc<wgpu::Device>, limits: &Limits) -> MemoryManagement {
let storage = WgpuStorage::new(device.clone());

memory_management::dynamic::DynamicMemoryManagement::new(
storage,
memory_management::dynamic::DynamicMemoryManagementOptions::preset(
limits.max_storage_buffer_binding_size as usize,
limits.min_storage_buffer_offset_alignment as usize,
),
)
}

#[cfg(simple_memory_management)]
fn init_memory_management(device: Arc<wgpu::Device>, _limits: &Limits) -> MemoryManagement {
let storage = WgpuStorage::new(device.clone());

memory_management::simple::SimpleMemoryManagement::new(
storage,
memory_management::simple::DeallocStrategy::new_period_tick(32),
memory_management::simple::SliceStrategy::Ratio(0.8),
)
}

impl Runtime for WgpuRuntime {
type Compiler = wgsl::WgslCompiler;
type Server = WgpuServer<DynamicMemoryManagement<WgpuStorage>>;
type Server = WgpuServer<MemoryManagement>;

type Channel = MutexComputeChannel<WgpuServer<DynamicMemoryManagement<WgpuStorage>>>;
type Channel = MutexComputeChannel<WgpuServer<MemoryManagement>>;
type Device = WgpuDevice;

fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
Expand Down Expand Up @@ -112,19 +137,10 @@ fn create_client(
device_wgpu: Arc<wgpu::Device>,
queue: Arc<wgpu::Queue>,
options: RuntimeOptions,
) -> ComputeClient<
WgpuServer<DynamicMemoryManagement<WgpuStorage>>,
MutexComputeChannel<WgpuServer<DynamicMemoryManagement<WgpuStorage>>>,
> {
) -> ComputeClient<WgpuServer<MemoryManagement>, MutexComputeChannel<WgpuServer<MemoryManagement>>>
{
let limits = device_wgpu.limits();
let storage = WgpuStorage::new(device_wgpu.clone());
let memory_management = DynamicMemoryManagement::new(
storage,
DynamicMemoryManagementOptions::preset(
limits.max_storage_buffer_binding_size as usize,
limits.min_storage_buffer_offset_alignment as usize,
),
);
let memory_management = init_memory_management(device_wgpu.clone(), &limits);
let server = WgpuServer::new(memory_management, device_wgpu, queue, options.tasks_max);
let channel = MutexComputeChannel::new(server);

Expand Down
1 change: 1 addition & 0 deletions crates/cubecl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ default = ["std", "linalg", "cubecl-core/default", "cubecl-wgpu?/default", "cube
std = ["cubecl-core/std", "cubecl-wgpu?/std", "cubecl-cuda?/std"]
template = ["cubecl-core/template"]
linalg = ["dep:cubecl-linalg"]
simple-memory-management = ["cubecl-wgpu?/simple-memory-management"]

# Runtimes
wgpu = ["cubecl-wgpu"]
Expand Down

0 comments on commit d41dd0c

Please sign in to comment.