Use simple memory management with wasm #81

Merged 4 commits on Aug 25, 2024
5 changes: 5 additions & 0 deletions crates/cubecl-core/src/codegen/integrator.rs
@@ -495,6 +495,11 @@ impl KernelIntegrator {
let output = match self.expansion.outputs.get_mut(mapping.pos_output) {
Some(output) => output,
None => {
if let Some(binding) = self.input_bindings.get_mut(mapping.pos_input) {
// Update input visibility.
binding.visibility = Visibility::ReadWrite;
}

// The mapping is handled differently, normally by cube itself.
return;
}
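The hunk above promotes an input binding to read-write when an output is mapped in place onto it and no separate output binding exists. A minimal standalone sketch of that idea, using simplified stand-in types rather than the crate's real `KernelIntegrator` structures:

```rust
// Simplified stand-ins for the integrator's binding types (illustration only).
#[derive(Debug, PartialEq)]
enum Visibility {
    Read,
    ReadWrite,
}

struct Binding {
    visibility: Visibility,
}

/// When an output is computed in place over an input buffer, the input
/// binding must become writable, otherwise the generated kernel cannot
/// store results into it.
fn promote_inplace_input(input_bindings: &mut [Binding], pos_input: usize) {
    if let Some(binding) = input_bindings.get_mut(pos_input) {
        binding.visibility = Visibility::ReadWrite;
    }
}

fn main() {
    let mut inputs = vec![Binding { visibility: Visibility::Read }];
    promote_inplace_input(&mut inputs, 0);
    assert_eq!(inputs[0].visibility, Visibility::ReadWrite);
}
```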
2 changes: 1 addition & 1 deletion crates/cubecl-core/tests/error/for_loop_range.stderr
@@ -1,4 +1,4 @@
-error: Invalid for loop: use [range](cubecl::prelude::range] instead.
+error: Invalid for loop: use [range](cubecl::prelude::range] or [range_stepped](cubecl::prelude::range_stepped) instead.
--> tests/error/for_loop_range.rs:6:14
|
6 | for _ in 0..10 {}
4 changes: 2 additions & 2 deletions crates/cubecl-macros/src/codegen_function/branch.rs
@@ -15,11 +15,11 @@ use super::{

/// Codegen of for loops
/// Supports range:
-/// ```norun
+/// ```ignore
/// for i in range(start, end, unroll) {...}
/// ```
/// and range_stepped:
-/// ```norun
+/// ```ignore
/// for i in range_stepped(start, end, step, unroll) {...}
/// ```
pub(crate) fn codegen_for_loop(
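The updated error message and doc comments point users at `range` and `range_stepped` as the loop constructs supported inside cube functions. A rough usage sketch based only on the signatures documented above, `range(start, end, unroll)` and `range_stepped(start, end, step, unroll)`; the function name, argument types, and `Comptime` handling are assumptions about the era's API, not taken from this PR:

```rust
use cubecl::prelude::*;

// Hypothetical kernel helper: sum every second element below `len`,
// written with `range_stepped` instead of a native `0..10` loop, which
// would trigger the "Invalid for loop" error shown in the .stderr above.
#[cube]
fn stepped_sum(input: &Array<F32>, len: UInt) -> F32 {
    let mut total = F32::new(0.0);
    // range_stepped(start, end, step, unroll) as documented in branch.rs.
    for i in range_stepped(0u32, len, 2u32, Comptime::new(false)) {
        total += input[i];
    }
    total
}
```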
4 changes: 4 additions & 0 deletions crates/cubecl-wgpu/Cargo.toml
@@ -17,6 +17,7 @@ default = [
"cubecl-core/default",
]
std = ["cubecl-runtime/std", "cubecl-common/std", "cubecl-core/std"]
simple-memory-management = []

[dependencies]
cubecl-runtime = { path = "../cubecl-runtime", version = "0.1.1", default-features = false, features = [
@@ -41,3 +42,6 @@ cubecl-core = { path = "../cubecl-core", version = "0.1.1", features = [
cubecl-linalg = { path = "../cubecl-linalg", version = "0.1.1", features = [
"export_tests",
] }

[build-dependencies]
cfg_aliases = "0.2.1"
8 changes: 8 additions & 0 deletions crates/cubecl-wgpu/build.rs
@@ -0,0 +1,8 @@
use cfg_aliases::cfg_aliases;

fn main() {
// Setup cfg aliases
cfg_aliases! {
simple_memory_management: { any(feature = "simple-memory-management", target_family = "wasm") },
}
}
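Once the build script defines the alias, code in the crate can branch on it like any other `cfg` flag, true either when the `simple-memory-management` feature is enabled or when targeting wasm. A trivial standalone illustration (the function name is made up for the example):

```rust
// With the cfg_aliases setup above, call sites only need one flag instead of
// repeating `any(feature = "simple-memory-management", target_family = "wasm")`.
#[cfg(simple_memory_management)]
fn memory_strategy_name() -> &'static str {
    "simple"
}

#[cfg(not(simple_memory_management))]
fn memory_strategy_name() -> &'static str {
    "dynamic"
}

fn main() {
    println!("memory management strategy: {}", memory_strategy_name());
}
```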
5 changes: 5 additions & 0 deletions crates/cubecl-wgpu/src/compiler/wgsl/shader.rs
@@ -228,7 +228,12 @@ impl Display for Location {
impl Display for Visibility {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
// With the dynamic memory strategy we have to put everything read_write.
#[cfg(not(simple_memory_management))]
Visibility::Read => f.write_str("read_write"),
// With the simple memory strategy we can use the correct visibility.
#[cfg(simple_memory_management)]
Visibility::Read => f.write_str("read"),
Visibility::ReadWrite => f.write_str("read_write"),
}
}
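The practical effect shows up in the generated WGSL storage-buffer declarations: the dynamic strategy has to declare even read-only inputs with the `read_write` access mode, while the simple strategy can keep them as `read`. A rough illustration of the emitted binding text; the buffer name, element type, and exact decoration layout are assumptions, only the access mode comes from the hunk above:

```rust
// Sketch of what the Display impl above ends up contributing to a binding
// declaration such as `var<storage, read> input_0: array<f32>;`.
fn binding_decl(access: &str) -> String {
    format!("@group(0) @binding(0) var<storage, {}> input_0: array<f32>;", access)
}

fn main() {
    // Simple memory management (or wasm): read-only inputs stay `read`.
    println!("{}", binding_decl("read"));
    // Dynamic memory management: everything is declared `read_write`.
    println!("{}", binding_decl("read_write"));
}
```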
60 changes: 38 additions & 22 deletions crates/cubecl-wgpu/src/runtime.rs
@@ -5,13 +5,9 @@ use crate::{
};
use alloc::sync::Arc;
use cubecl_core::{Feature, FeatureSet, Runtime};
-use cubecl_runtime::{
-channel::MutexComputeChannel,
-client::ComputeClient,
-memory_management::dynamic::{DynamicMemoryManagement, DynamicMemoryManagementOptions},
-ComputeRuntime,
-};
-use wgpu::DeviceDescriptor;
+use cubecl_runtime::memory_management;
+use cubecl_runtime::{channel::MutexComputeChannel, client::ComputeClient, ComputeRuntime};
+use wgpu::{DeviceDescriptor, Limits};

/// Runtime that uses the [wgpu] crate with the wgsl compiler. This is used in the Wgpu backend.
/// For advanced configuration, use [`init_sync`] to pass in runtime options or to select a
@@ -23,13 +19,42 @@ pub struct WgpuRuntime;
static RUNTIME: ComputeRuntime<WgpuDevice, Server, MutexComputeChannel<Server>> =
ComputeRuntime::new();

-type Server = WgpuServer<DynamicMemoryManagement<WgpuStorage>>;
+type Server = WgpuServer<MemoryManagement>;

#[cfg(not(simple_memory_management))]
type MemoryManagement = memory_management::dynamic::DynamicMemoryManagement<WgpuStorage>;
#[cfg(simple_memory_management)]
type MemoryManagement = memory_management::simple::SimpleMemoryManagement<WgpuStorage>;

#[cfg(not(simple_memory_management))]
fn init_memory_management(device: Arc<wgpu::Device>, limits: &Limits) -> MemoryManagement {
let storage = WgpuStorage::new(device.clone());

memory_management::dynamic::DynamicMemoryManagement::new(
storage,
memory_management::dynamic::DynamicMemoryManagementOptions::preset(
limits.max_storage_buffer_binding_size as usize,
limits.min_storage_buffer_offset_alignment as usize,
),
)
}

#[cfg(simple_memory_management)]
fn init_memory_management(device: Arc<wgpu::Device>, _limits: &Limits) -> MemoryManagement {
let storage = WgpuStorage::new(device.clone());

memory_management::simple::SimpleMemoryManagement::new(
storage,
memory_management::simple::DeallocStrategy::new_period_tick(32),
memory_management::simple::SliceStrategy::Ratio(0.8),
)
}

impl Runtime for WgpuRuntime {
type Compiler = wgsl::WgslCompiler;
-type Server = WgpuServer<DynamicMemoryManagement<WgpuStorage>>;
+type Server = WgpuServer<MemoryManagement>;

-type Channel = MutexComputeChannel<WgpuServer<DynamicMemoryManagement<WgpuStorage>>>;
+type Channel = MutexComputeChannel<WgpuServer<MemoryManagement>>;
type Device = WgpuDevice;

fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
@@ -112,19 +137,10 @@ fn create_client(
device_wgpu: Arc<wgpu::Device>,
queue: Arc<wgpu::Queue>,
options: RuntimeOptions,
-) -> ComputeClient<
-WgpuServer<DynamicMemoryManagement<WgpuStorage>>,
-MutexComputeChannel<WgpuServer<DynamicMemoryManagement<WgpuStorage>>>,
-> {
+) -> ComputeClient<WgpuServer<MemoryManagement>, MutexComputeChannel<WgpuServer<MemoryManagement>>>
+{
let limits = device_wgpu.limits();
-let storage = WgpuStorage::new(device_wgpu.clone());
-let memory_management = DynamicMemoryManagement::new(
-storage,
-DynamicMemoryManagementOptions::preset(
-limits.max_storage_buffer_binding_size as usize,
-limits.min_storage_buffer_offset_alignment as usize,
-),
-);
+let memory_management = init_memory_management(device_wgpu.clone(), &limits);
let server = WgpuServer::new(memory_management, device_wgpu, queue, options.tasks_max);
let channel = MutexComputeChannel::new(server);

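From the caller's perspective nothing changes: the memory-management type is selected at compile time behind the `MemoryManagement` alias, and clients are still obtained through the same entry point. A minimal usage sketch, assuming `WgpuRuntime` and `WgpuDevice` are re-exported from `cubecl_wgpu` and that `WgpuDevice` implements `Default`, neither of which is shown in this diff:

```rust
use cubecl_core::Runtime;
use cubecl_wgpu::{WgpuDevice, WgpuRuntime};

fn main() {
    // Whether the simple or the dynamic memory strategy was compiled in,
    // the public API is the same `Runtime::client` call.
    let device = WgpuDevice::default();
    let _client = WgpuRuntime::client(&device);
}
```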
1 change: 1 addition & 0 deletions crates/cubecl/Cargo.toml
@@ -22,6 +22,7 @@ default = ["std", "linalg", "cubecl-core/default", "cubecl-wgpu?/default", "cube
std = ["cubecl-core/std", "cubecl-wgpu?/std", "cubecl-cuda?/std"]
template = ["cubecl-core/template"]
linalg = ["dep:cubecl-linalg"]
simple-memory-management = ["cubecl-wgpu?/simple-memory-management"]

# Runtimes
wgpu = ["cubecl-wgpu"]