Skip to content

Commit

Permalink
Initital WIP draft
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge committed Aug 22, 2024
1 parent e158a93 commit 0a784af
Show file tree
Hide file tree
Showing 45 changed files with 2,801 additions and 29 deletions.
34 changes: 18 additions & 16 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,13 @@
# https://doc.rust-lang.org/cargo/reference/resolver.html#feature-resolver-version-2
resolver = "2"

members = [
"crates/*",
"examples/*", "profiling/matmul-example",
"xtask",
]
members = ["crates/*", "examples/*", "profiling/matmul-example", "xtask"]

[workspace.package]
edition = "2021"
version = "0.1.1"
license = "MIT OR Apache-2.0"
readme = "README.md"
version = "0.1.1"


[workspace.dependencies]
Expand All @@ -29,23 +25,23 @@ serde = { version = "1.0.204", default-features = false, features = [
serde_json = { version = "1.0.119", default-features = false }

dashmap = "5.5.3"
spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }
hashbrown = "0.14.5"
spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }

getrandom = { version = "0.2.15", default-features = false }
rand = { version = "0.8.5", default-features = false, features = [
"std_rng",
] } # std_rng is for no_std
getrandom = { version = "0.2.15", default-features = false }

pollster = "0.3"
async-channel = "2.3"
dirs = "5.0.1"
web-time = "1.1.0"
md5 = "0.7.0"
async-channel = "2.3"
pollster = "0.3"
web-time = "1.1.0"

# Testing
serial_test = "3.1.1"
rstest = "0.19.0"
serial_test = "3.1.1"

bytemuck = "1.16.1"
half = { version = "2.4.1", features = [
Expand All @@ -58,17 +54,23 @@ num-traits = { version = "0.2.19", default-features = false, features = [
] } # libm is for no_std

proc-macro2 = "1.0.86"
syn = { version = "2.0.69", features = ["full", "extra-traits"] }
quote = "1.0.36"
syn = { version = "2.0.69", features = ["full", "extra-traits"] }

# xtask
anyhow = "1.0.86"
clap = { version = "4.5.9", features = ["derive"] }
derive_more = { version = "0.99.18", features = ["display"], default-features = false }
derive_more = { version = "1", features = [
"display",
"add",
"mul",
], default-features = false }
env_logger = "0.11.3"
strum = {version = "0.26.3", features = ["derive"]}
strum = { version = "0.26.3", features = ["derive"] }

portable-atomic-util = { version = "0.2.2", features = ["alloc"] } # alloc is for no_std
portable-atomic-util = { version = "0.2.2", features = [
"alloc",
] } # alloc is for no_std

[profile.dev]
opt-level = 2
18 changes: 13 additions & 5 deletions crates/cubecl-common/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
[package]
authors = ["Dilshod Tadjibaev (@antimora)", "Nathaniel Simard (@nathanielsimard)"]
authors = [
"Dilshod Tadjibaev (@antimora)",
"Nathaniel Simard (@nathanielsimard)",
]
categories = ["science", "mathematics", "algorithms"]
description = "Common crate for CubeCL"
edition.workspace = true
Expand All @@ -20,18 +23,23 @@ web-time = { version = "1.1.0" }

[dependencies]
# ** Please make sure all dependencies support no_std when std is disabled **
spin = { workspace = true } # using in place of use std::sync::Mutex;
derive-new = { workspace = true }
serde = { workspace = true }
rand = { workspace = true }
derive_more = { workspace = true }
pollster = { workspace = true, optional = true }
rand = { workspace = true }
serde = { workspace = true }
spin = { workspace = true } # using in place of use std::sync::Mutex;

[target.'cfg(target_has_atomic = "ptr")'.dependencies]
spin = { workspace = true, features = ["mutex", "spin_mutex"] }

[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
portable-atomic-util = { workspace = true }
spin = { workspace = true, features = ["mutex", "spin_mutex", "portable_atomic"] }
spin = { workspace = true, features = [
"mutex",
"spin_mutex",
"portable_atomic",
] }

[dev-dependencies]
dashmap = { workspace = true }
2 changes: 2 additions & 0 deletions crates/cubecl-common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ pub mod benchmark;
/// notation.
pub mod reader;

/// Operators used by macro and IR
pub mod operator;
/// Synchronization type module, used both by ComputeServer and Backends.
pub mod sync_type;

Expand Down
80 changes: 80 additions & 0 deletions crates/cubecl-common/src/operator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use derive_more::derive::Display;

/// An operator used in the intermediate representaion
#[derive(Debug, Clone, Copy, PartialEq, Eq, Display)]
pub enum Operator {
// Arithmetic
/// Add (+) operator
Add,
/// Sub (-) operator
Sub,
/// Mul (*) operator
Mul,
/// Div (/) operator
Div,
/// Rem (%) operator
Rem,

// Arithmetic Assign
/// Add assign (+=) operator
AddAssign,
/// Sub assign (-=) operator
SubAssign,
/// Mul assing (*=) operator
MulAssign,
/// Div assign (/=) operator
DivAssign,
/// Rem assign (%=) operator
RemAssign,

// Comparison
/// Equals (==) operator
Eq,
/// Not equal (!=) operator
Ne,
/// Less than (<) operator
Lt,
/// Less than equals (<=) operator
Le,
/// Greater than equal (>=) operator
Ge,
/// Greater than (>) operator
Gt,

// Boolean
/// And (&&) operator
And,
/// Or (||) operator
Or,
/// Bitwise XOR (^) operator
BitXor,
/// Bitwise And (&) operator
BitAnd,
/// Bitwise Or (|) operator
BitOr,

// Boolean assign
/// Bitwise xor assign (^=) operator
BitXorAssign,
/// Bitwise and assign (&=) operator
BitAndAssign,
/// Bitwise or assign (|=) operator
BitOrAssign,

/// Shift left (<<) operator
Shl,
/// Shift right (>>) operator
Shr,
/// Shift left assign (<<=) operator
ShlAssign,
/// Shift right assign (>>= operator)
ShrAssign,

// Unary
/// Dereference operator (*)
Deref,
/// Not operator (!)
Not,
/// Negation unary operator (-)
Neg,
}
8 changes: 5 additions & 3 deletions crates/cubecl-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,21 @@ version.workspace = true

[features]
default = ["cubecl-runtime/default"]
export_tests = []
std = ["cubecl-runtime/std"]
template = []
export_tests = []

[dependencies]
cubecl-common = { path = "../cubecl-common", version = "0.1.1", default-features = false }
cubecl-runtime = { path = "../cubecl-runtime", version = "0.1.1", default-features = false }

bytemuck = { workspace = true }
half = { workspace = true, features = ["bytemuck"] }
serde = { workspace = true }
cubecl-macros = { path = "../cubecl-macros", version = "0.1.1" }
derive-new = { workspace = true }
derive_more = { workspace = true }
half = { workspace = true, features = ["bytemuck"] }
num-traits = { workspace = true }
serde = { workspace = true }

log = { workspace = true }

Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-core/src/codegen/execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ fn create_scalar_handles<R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeE
Elem::UInt => 2,
Elem::AtomicUInt => 2,
Elem::Bool => panic!("Bool scalars are not supported"),
Elem::Pointer => panic!("Pointer scalars are not supported"),
};
let scalar_priorities: [usize; 3] = [
element_priority(E1::cube_elem()),
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-core/src/compute/launcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ impl<R: Runtime> KernelLauncher<R> {
Elem::UInt => self.scalar_u32.register::<R>(client, &mut bindings),
Elem::AtomicUInt => self.scalar_u32.register::<R>(client, &mut bindings),
Elem::Bool => panic!("Bool can't be passed as bindings."),
Elem::Pointer => panic!("Pointer can't be passed as bindings."),
}
}

Expand Down
7 changes: 7 additions & 0 deletions crates/cubecl-core/src/ir/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ pub enum Elem {
UInt,
AtomicUInt,
Bool,
Pointer,
}

impl Elem {
Expand All @@ -66,6 +67,7 @@ impl Elem {
Elem::Bool => ConstantScalarValue::Bool(val > 0.0),
Elem::AtomicInt(kind) => ConstantScalarValue::Int(val as i64, *kind),
Elem::AtomicUInt => ConstantScalarValue::UInt(val as u64),
Elem::Pointer => panic!("Can't create pointer from constant"),
})
}
/// Create a constant scalar from a signed integer.
Expand All @@ -79,6 +81,7 @@ impl Elem {
Elem::Bool => ConstantScalarValue::Bool(val > 0),
Elem::AtomicInt(kind) => ConstantScalarValue::Int(val, *kind),
Elem::AtomicUInt => ConstantScalarValue::UInt(val as u64),
Elem::Pointer => panic!("Can't create pointer from constant"),
})
}
/// Create a constant scalar from a unsigned integer.
Expand All @@ -92,6 +95,7 @@ impl Elem {
Elem::Bool => ConstantScalarValue::Bool(val > 0),
Elem::AtomicInt(kind) => ConstantScalarValue::Int(val as i64, *kind),
Elem::AtomicUInt => ConstantScalarValue::UInt(val),
Elem::Pointer => panic!("Can't create pointer from constant"),
})
}
/// Create a constant scalar from a boolean.
Expand All @@ -105,6 +109,7 @@ impl Elem {
Elem::UInt => ConstantScalarValue::UInt(val as u64),
Elem::AtomicUInt => ConstantScalarValue::UInt(val as u64),
Elem::Bool => ConstantScalarValue::Bool(val),
Elem::Pointer => panic!("Can't create pointer from constant"),
})
}

Expand Down Expand Up @@ -142,6 +147,7 @@ impl Elem {
Elem::UInt => core::mem::size_of::<u32>(),
Elem::AtomicUInt => core::mem::size_of::<u32>(),
Elem::Bool => core::mem::size_of::<bool>(),
Elem::Pointer => core::mem::size_of::<usize>(),
}
}

Expand Down Expand Up @@ -176,6 +182,7 @@ impl Display for Elem {
Self::UInt => f.write_str("uint"),
Self::AtomicUInt => f.write_str("atomic<uint>"),
Self::Bool => f.write_str("bool"),
Self::Pointer => f.write_str("ptr"),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-core/src/ir/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ impl Scope {
Elem::UInt => ConstantScalarValue::UInt(value.to_u64().unwrap()),
Elem::AtomicUInt => ConstantScalarValue::UInt(value.to_u64().unwrap()),
Elem::Bool => ConstantScalarValue::Bool(value.to_u32().unwrap() == 1),
Elem::Pointer => panic!("Can't initialize pointer with a value"),
};
let local = self.create_local(item);
let value = Variable::ConstantScalar(value);
Expand Down
2 changes: 2 additions & 0 deletions crates/cubecl-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ pub mod prelude;
mod pod;
mod runtime;

pub mod new_ir;

pub use codegen::*;
pub use pod::*;
pub use runtime::*;
Expand Down
49 changes: 49 additions & 0 deletions crates/cubecl-core/src/new_ir/branch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use super::{Block, Expr, Expression, SquareType, Variable};

pub struct Break;

impl Expr for Break {
type Output = ();

fn expression_untyped(&self) -> super::Expression {
Expression::Break
}
}

pub struct Continue;

impl Expr for Continue {
type Output = ();

fn expression_untyped(&self) -> Expression {
Expression::Continue
}
}

pub struct ForLoop<TNum: SquareType> {
pub from: Box<dyn Expr<Output = TNum>>,
pub to: Box<dyn Expr<Output = TNum>>,
pub step: Option<Box<dyn Expr<Output = TNum>>>,
pub unroll: bool,
pub variable: Variable<TNum>,

pub block: Block<()>,
}

impl<TNum: SquareType> Expr for ForLoop<TNum> {
type Output = ();

fn expression_untyped(&self) -> Expression {
Expression::ForLoop {
from: Box::new(self.from.expression_untyped()),
to: Box::new(self.to.expression_untyped()),
step: self
.step
.as_ref()
.map(|step| Box::new(step.expression_untyped())),
unroll: self.unroll,
variable: Box::new(self.variable.expression_untyped()),
block: self.block.statements.iter().cloned().collect(),
}
}
}
Loading

0 comments on commit 0a784af

Please sign in to comment.