Skip to content

Commit

Permalink
Add support for atomics and atomic operations (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge authored Aug 14, 2024
1 parent 034f667 commit b09821d
Show file tree
Hide file tree
Showing 20 changed files with 963 additions and 24 deletions.
4 changes: 4 additions & 0 deletions crates/cubecl-core/src/codegen/execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ fn execute_dynamic<R, K, E1, E2, E3>(
let settings = execute_settings(
inputs, outputs, scalars_1, scalars_2, scalars_3, launch, &client,
);

let mut handles = settings.handles_tensors;

handles.push(settings.handle_info.binding());
Expand All @@ -224,6 +225,7 @@ struct ExecuteSettings<R: Runtime> {
cube_count: CubeCount<R::Server>,
}

#[allow(clippy::too_many_arguments)]
fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeElement>(
inputs: &'a [TensorHandleRef<R>],
outputs: &'a [TensorHandleRef<R>],
Expand Down Expand Up @@ -316,7 +318,9 @@ fn create_scalar_handles<R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeE
let element_priority = |elem: Elem| match elem {
Elem::Float(_) => 0,
Elem::Int(_) => 1,
Elem::AtomicInt(_) => 1,
Elem::UInt => 2,
Elem::AtomicUInt => 2,
Elem::Bool => panic!("Bool scalars are not supported"),
};
let scalar_priorities: [usize; 3] = [
Expand Down
5 changes: 5 additions & 0 deletions crates/cubecl-core/src/compute/launcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,12 @@ impl<R: Runtime> KernelLauncher<R> {
IntKind::I32 => self.scalar_i32.register::<R>(client, &mut bindings),
IntKind::I64 => self.scalar_i64.register::<R>(client, &mut bindings),
},
Elem::AtomicInt(kind) => match kind {
IntKind::I32 => self.scalar_i32.register::<R>(client, &mut bindings),
IntKind::I64 => self.scalar_i64.register::<R>(client, &mut bindings),
},
Elem::UInt => self.scalar_u32.register::<R>(client, &mut bindings),
Elem::AtomicUInt => self.scalar_u32.register::<R>(client, &mut bindings),
Elem::Bool => panic!("Bool can't be passed as bindings."),
}
}
Expand Down
5 changes: 5 additions & 0 deletions crates/cubecl-core/src/frontend/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ impl CubeContext {
/// When a new variable is required, we check if we can reuse an old one
/// Otherwise we create a new one.
pub fn create_local(&mut self, item: Item) -> ExpandElement {
if item.elem.is_atomic() {
let new = self.scope.borrow_mut().create_local_undeclared(item);
return ExpandElement::Plain(new);
}

// Reuse an old variable if possible
if let Some(var) = self.pool.reuse(item) {
return var;
Expand Down
Loading

0 comments on commit b09821d

Please sign in to comment.