Reimplement macro to make writing kernels more ergonomic #80

Merged on Sep 10, 2024 (67 commits)
Commits
0a784af
Initial WIP draft
wingertge Aug 22, 2024
4f53194
Don't check const arguments for `SquareType`
wingertge Aug 22, 2024
3abe68e
Add vectorization tracing
wingertge Aug 22, 2024
601d046
Refactor field expansion, add method expansion
wingertge Aug 23, 2024
4140435
Add macros to simplify expansion impl
wingertge Aug 24, 2024
6a0d34a
Merge MethodExpand and FieldExpand since they don't conflict and have…
wingertge Aug 24, 2024
a371b12
Clean up code
wingertge Aug 24, 2024
f76e6c6
Add support for associated functions
wingertge Aug 24, 2024
8f06255
Implement `Expr` for values to make `Literal` superfluous
wingertge Aug 25, 2024
f1750f5
Implement for loop
wingertge Aug 25, 2024
39453ce
Fix comptime bounds
wingertge Aug 25, 2024
3be1aba
Implement while and loop
wingertge Aug 25, 2024
3bee11f
Implement if
wingertge Aug 25, 2024
84a8506
Implement explicit return
wingertge Aug 25, 2024
543cd64
Implement index
wingertge Aug 26, 2024
36ae70f
Implement slices
wingertge Aug 26, 2024
30660df
Fix up tensors and other complex types to work with references
wingertge Aug 26, 2024
c8f2088
Improve IDE handling when proc macro fails
wingertge Aug 26, 2024
a673f6e
Code cleanup
wingertge Aug 27, 2024
763f94a
Add struct destructuring
wingertge Aug 27, 2024
c0e96fc
Start implementing flattening
wingertge Aug 28, 2024
1ef15e7
Get first test working on new macro
wingertge Aug 29, 2024
42ea6b0
Implement for loop unrolling
wingertge Aug 29, 2024
4bbbdb1
Convert more tests
wingertge Aug 29, 2024
1741f96
More testing
wingertge Aug 30, 2024
ec2d511
Merge remote-tracking branch 'origin/main' into new-ir
wingertge Aug 30, 2024
cc119d6
Finish implementing runtime tests
wingertge Aug 30, 2024
a70ee40
More testing and some fixes to `if` generation. Also make sure to fre…
wingertge Aug 30, 2024
18e3daa
Intermediate commit
wingertge Sep 1, 2024
379b1f9
Remove old macro
wingertge Sep 1, 2024
b587f26
Fix traits
wingertge Sep 1, 2024
3e5930f
Commit before expand rework
wingertge Sep 2, 2024
e018e33
Temp commit
wingertge Sep 3, 2024
871f45a
Temp commit
wingertge Sep 4, 2024
dc35ab2
Temp commit
wingertge Sep 5, 2024
c34922d
Revert to old IR and clean up `CubeType` macro
wingertge Sep 5, 2024
3f11196
Start backport to old IR
wingertge Sep 6, 2024
e230a45
More implementation stuff
wingertge Sep 6, 2024
1543230
remove leftover macro code
wingertge Sep 6, 2024
1fd19d6
Cleanup
wingertge Sep 6, 2024
82ef4c2
Finish backporting
wingertge Sep 7, 2024
4a7bfb3
Fix several bugs and try to improve codegen spans
wingertge Sep 8, 2024
6139cd1
Merge remote-tracking branch 'origin/main' into new-ir
wingertge Sep 8, 2024
8eaa2c2
Fix array and trybuild tests
wingertge Sep 8, 2024
79433a1
Fix assign tests
wingertge Sep 8, 2024
84ad6a4
Fix cast tests
wingertge Sep 8, 2024
22ab844
Fix comptime tests
wingertge Sep 8, 2024
df8c970
Fix more frontend tests
wingertge Sep 8, 2024
02bc447
Fix remaining tests
wingertge Sep 8, 2024
6a94008
Implement inclusive ranges
wingertge Sep 9, 2024
fe69647
Insert index import for trait functions with default and trait impls
wingertge Sep 9, 2024
994215e
Fix `KernelLauncher` path
wingertge Sep 9, 2024
4dd31f3
Track variable use in user-defined closures
wingertge Sep 9, 2024
3e7b379
Implement missing assign ops
wingertge Sep 9, 2024
48fa1df
Fix bugs and edge cases encountered in burn
wingertge Sep 9, 2024
7fd3cf9
Remove leftover println
wingertge Sep 9, 2024
a4fb8c7
Merge remote-tracking branch 'origin/main' into new-ir
wingertge Sep 9, 2024
0bbaeef
Remove commented out frontend tests in favor of the existing ones.
wingertge Sep 9, 2024
efac371
Remove non-error spans for now since they're useless on stable
wingertge Sep 10, 2024
a57380f
Fix concerns from review
wingertge Sep 10, 2024
3bbc019
Normalize line endings on test comparison files
wingertge Sep 10, 2024
c8ef9f6
Merge branch 'main' into new-ir
wingertge Sep 10, 2024
6a6b809
Replace panics in codegen with `compile_error!`
wingertge Sep 10, 2024
3184f13
Allow using qualified `vectorization_of`, add infra for potential fut…
wingertge Sep 10, 2024
cfaa510
Add vectorization_of intrinsic test
wingertge Sep 10, 2024
0757330
Add `comptime` macro and tests for it
wingertge Sep 10, 2024
c3ef9dc
Remove compiletest_rs dependency
wingertge Sep 10, 2024
2 changes: 2 additions & 0 deletions .gitignore
@@ -12,3 +12,5 @@ Cargo.lock

# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb
**/out
.clangd
3 changes: 3 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,3 @@
{
"files.eol": "\n"
}
32 changes: 17 additions & 15 deletions Cargo.toml
@@ -4,17 +4,13 @@
# https://doc.rust-lang.org/cargo/reference/resolver.html#feature-resolver-version-2
resolver = "2"

members = [
"crates/*",
"examples/*",
"xtask",
]
members = ["crates/*", "examples/*", "xtask"]

[workspace.package]
edition = "2021"
version = "0.2.0"
license = "MIT OR Apache-2.0"
readme = "README.md"
version = "0.2.0"


[workspace.dependencies]
@@ -29,23 +25,24 @@ serde = { version = "1.0.204", default-features = false, features = [
serde_json = { version = "1.0.119", default-features = false }

dashmap = "5.5.3"
spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }
hashbrown = "0.14.5"
spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }

getrandom = { version = "0.2.15", default-features = false }
rand = { version = "0.8.5", default-features = false, features = [
"std_rng",
] } # std_rng is for no_std
getrandom = { version = "0.2.15", default-features = false }

pollster = "0.3"
async-channel = "2.3"
dirs = "5.0.1"
web-time = "1.1.0"
md5 = "0.7.0"
async-channel = "2.3"
pollster = "0.3"
weak-table = "0.3"
web-time = "1.1.0"

# Testing
serial_test = "3.1.1"
rstest = "0.19.0"
serial_test = "3.1.1"

bytemuck = "1.16.1"
half = { version = "2.4.1", features = [
@@ -57,15 +54,20 @@ num-traits = { version = "0.2.19", default-features = false, features = [
"libm",
] } # libm is for no_std

darling = "0.20.10"
ident_case = "1"
proc-macro2 = "1.0.86"
syn = { version = "2.0.69", features = ["full", "extra-traits"] }
quote = "1.0.36"
syn = { version = "2", features = ["full", "extra-traits", "visit-mut"] }

### For xtask crate ###
strum = {version = "0.26.3", features = ["derive"]}
strum = { version = "0.26.3", features = ["derive"] }
tracel-xtask = { version = "~1.0" }

portable-atomic-util = { version = "0.2.2", features = ["alloc"] } # alloc is for no_std
portable-atomic-util = { version = "0.2.2", features = [
"alloc",
] } # alloc is for no_std
pretty_assertions = "1.4"

[profile.dev]
opt-level = 2
17 changes: 12 additions & 5 deletions crates/cubecl-common/Cargo.toml
@@ -1,5 +1,8 @@
[package]
authors = ["Dilshod Tadjibaev (@antimora)", "Nathaniel Simard (@nathanielsimard)"]
authors = [
"Dilshod Tadjibaev (@antimora)",
"Nathaniel Simard (@nathanielsimard)",
]
categories = ["science", "mathematics", "algorithms"]
description = "Common crate for CubeCL"
edition.workspace = true
@@ -20,18 +23,22 @@ web-time = { version = "1.1.0" }

[dependencies]
# ** Please make sure all dependencies support no_std when std is disabled **
spin = { workspace = true } # using in place of use std::sync::Mutex;
derive-new = { workspace = true }
serde = { workspace = true }
rand = { workspace = true }
pollster = { workspace = true, optional = true }
rand = { workspace = true }
serde = { workspace = true }
spin = { workspace = true } # using in place of use std::sync::Mutex;

[target.'cfg(target_has_atomic = "ptr")'.dependencies]
spin = { workspace = true, features = ["mutex", "spin_mutex"] }

[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
portable-atomic-util = { workspace = true }
spin = { workspace = true, features = ["mutex", "spin_mutex", "portable_atomic"] }
spin = { workspace = true, features = [
"mutex",
"spin_mutex",
"portable_atomic",
] }

[dev-dependencies]
dashmap = { workspace = true }
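
The two `[target.'cfg(...)'.dependencies]` tables above pick `spin` features depending on whether the target has pointer-width atomics. The same predicate is available in Rust source. A minimal sketch, where `atomics_flavour` is a hypothetical function used only for illustration and is not part of CubeCL:

```rust
/// Hypothetical helper: reports which branch of the cfg gate was compiled in.
/// Targets with pointer-width atomics take this branch.
#[cfg(target_has_atomic = "ptr")]
pub fn atomics_flavour() -> &'static str {
    "native"
}

/// Targets without pointer-width atomics take this branch instead, which is
/// where the manifest additionally enables the `portable_atomic` feature.
#[cfg(not(target_has_atomic = "ptr"))]
pub fn atomics_flavour() -> &'static str {
    "portable_atomic"
}
```

Exactly one of the two definitions is compiled for any given target, so callers never see a duplicate symbol.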
2 changes: 2 additions & 0 deletions crates/cubecl-common/src/lib.rs
@@ -22,6 +22,8 @@ pub mod benchmark;
/// notation.
pub mod reader;

/// Operators used by macro and IR
pub mod operator;
/// Synchronization type module, used both by ComputeServer and Backends.
pub mod sync_type;

1 change: 1 addition & 0 deletions crates/cubecl-common/src/operator.rs
@@ -0,0 +1 @@

8 changes: 5 additions & 3 deletions crates/cubecl-core/Cargo.toml
@@ -15,21 +15,23 @@ version.workspace = true

[features]
default = ["cubecl-runtime/default"]
export_tests = []
std = ["cubecl-runtime/std"]
template = []
export_tests = []

[dependencies]
cubecl-runtime = { path = "../cubecl-runtime", version = "0.2.0", default-features = false }

bytemuck = { workspace = true }
half = { workspace = true, features = ["bytemuck"] }
serde = { workspace = true }
cubecl-common = { path = "../cubecl-common", version = "0.2.0" }
cubecl-macros = { path = "../cubecl-macros", version = "0.2.0" }
derive-new = { workspace = true }
half = { workspace = true, features = ["bytemuck"] }
num-traits = { workspace = true }
serde = { workspace = true }

log = { workspace = true }

[dev-dependencies]
pretty_assertions = { workspace = true }
trybuild = "1"
28 changes: 17 additions & 11 deletions crates/cubecl-core/src/codegen/integrator.rs
@@ -1,3 +1,5 @@
use std::num::NonZero;

use super::Compiler;
use crate::{
ir::{
@@ -95,18 +97,22 @@ impl core::fmt::Display for KernelSettings {
}

match self.vectorization_global {
Some(vectorization) => f.write_fmt(format_args!("vg{}", vectorization))?,
Some(vectorization) => f.write_fmt(format_args!(
"vg{}",
vectorization.map(NonZero::get).unwrap_or(1)
))?,
None => f.write_str("vn")?,
};

for vectorization in self.vectorization_partial.iter() {
match vectorization {
VectorizationPartial::Input { pos, vectorization } => {
f.write_fmt(format_args!("v{vectorization}i{pos}"))?
}
VectorizationPartial::Output { pos, vectorization } => {
f.write_fmt(format_args!("v{vectorization}o{pos}"))?
}
VectorizationPartial::Input { pos, vectorization } => f.write_fmt(format_args!(
"v{}i{pos}",
vectorization.map(NonZero::get).unwrap_or(1)
))?,
VectorizationPartial::Output { pos, vectorization } => f.write_fmt(
format_args!("v{}o{pos}", vectorization.map(NonZero::get).unwrap_or(1)),
)?,
};
}

@@ -130,7 +136,7 @@ impl KernelSettings {
pub fn vectorize_input(mut self, position: usize, vectorization: Vectorization) -> Self {
// Not setting the vectorization factor when it's the default value reduces the kernel id
// size.
if vectorization == 1 {
if vectorization.is_none() {
return self;
}

@@ -147,7 +153,7 @@ impl KernelSettings {
pub fn vectorize_output(mut self, position: usize, vectorization: Vectorization) -> Self {
// Not setting the vectorization factor when it's the default value reduces the kernel id
// size.
if vectorization == 1 {
if vectorization.is_none() {
return self;
}

@@ -173,7 +179,7 @@ impl KernelSettings {
}
}

1
None
}

/// Fetch the vectorization for the provided output position.
@@ -190,7 +196,7 @@ impl KernelSettings {
}
}

1
None
}

/// Compile the shader with inplace enabled by the given [mapping](InplaceMapping).
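
The `integrator.rs` hunks above change the vectorization factor from a plain integer, where `1` meant "not vectorized", to an optional non-zero integer, where `None` plays that role. Below is a minimal sketch of how the kernel ID fragments come out under that encoding. It assumes `Vectorization` is an alias for `Option<NonZeroU8>` (the diff shows only `NonZero`, not the integer width). The helpers `global_tag` and `input_tag` are hypothetical and not part of the PR:

```rust
use std::num::NonZeroU8;

/// Assumption: the diff only shows an `Option` around a `NonZero` integer;
/// the `u8` width is chosen here for illustration.
type Vectorization = Option<NonZeroU8>;

/// Hypothetical helper mirroring the `vg{}` / `vn` branch of the `Display` impl.
fn global_tag(vectorization_global: Option<Vectorization>) -> String {
    match vectorization_global {
        // `None` inside the setting falls back to a factor of 1, as in the diff.
        Some(v) => format!("vg{}", v.map(NonZeroU8::get).unwrap_or(1)),
        None => "vn".to_string(),
    }
}

/// Hypothetical helper combining the early return in `vectorize_input`
/// with the `v{}i{pos}` fragment written by `Display`.
fn input_tag(pos: usize, vectorization: Vectorization) -> Option<String> {
    // An unset factor is the default, so it contributes nothing to the ID.
    let factor = vectorization?;
    Some(format!("v{}i{pos}", factor.get()))
}
```

Under these assumptions `global_tag(Some(None))` produces `vg1`, matching the `unwrap_or(1)` fallback, and `input_tag(0, None)` produces no fragment at all, which is the ID-shortening behaviour described in the comment inside `vectorize_input`.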