diff --git a/.cargo/config.toml b/.cargo/config.toml
index 275927460..84a3562ce 100644
--- a/.cargo/config.toml
+++ b/.cargo/config.toml
@@ -1,2 +1,2 @@
[alias]
-xtask = "run --target-dir target/xtask/debug --package xtask --bin xtask --"
+xtask = "run --target-dir target/xtask --color always --package xtask --bin xtask --"
\ No newline at end of file
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 0c1ac743d..390326dd3 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -1,9 +1,5 @@
name: CI
-env:
- CARGO_TERM_COLOR: always
-
-# For now we execute CI only on PR to save on CI time
on:
push:
branches:
@@ -27,44 +23,163 @@ on:
- '!LICENSE-APACHE'
- '!LICENSE-MIT'
+env:
+ CARGO_TERM_COLOR: always
+ RUST_PREVIOUS_VERSION: 1.79.0
+
+ # Sourced from https://vulkan.lunarg.com/sdk/home#linux
+ VULKAN_SDK_VERSION: "1.3.268"
+
+ # Sourced from https://archive.mesa3d.org/. Bumping this requires
+ # updating the mesa build in https://github.com/gfx-rs/ci-build and creating a new release.
+ MESA_VERSION: "23.3.1"
+ # Corresponds to https://github.com/gfx-rs/ci-build/releases
+ MESA_CI_BINARY_BUILD: "build18"
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}
+ cancel-in-progress: true
+
jobs:
- crates:
- runs-on: ubuntu-latest
+ prepare-checks:
+ runs-on: ubuntu-22.04
+ outputs:
+ rust-prev-version: ${{ env.RUST_PREVIOUS_VERSION }}
+ steps:
+ - name: Do Nothing
+ if: false
+ run: echo
+
+ code-quality:
+ runs-on: ubuntu-22.04
+ needs: prepare-checks
+ strategy:
+ matrix:
+ rust: [stable]
+ include:
+ - rust: stable
+ toolchain: stable
steps:
- - uses: actions/checkout@v4
- - uses: dtolnay/rust-toolchain@master
+ - name: Setup Rust
+ uses: tracel-ai/github-actions/setup-rust@v1
with:
- components: clippy, rustfmt
- toolchain: stable
- - uses: Swatinem/rust-cache@v2
+ rust-toolchain: ${{ matrix.toolchain }}
+ cache-key: ${{ matrix.rust }}-linux
+ # --------------------------------------------------------------------------------
+ - name: Audit
+ run: cargo xtask check audit
+ # --------------------------------------------------------------------------------
- name: Format
- run: cargo xtask ci --target crates format
+ shell: bash
+ env:
+ # workaround for colors
+ # see: https://github.com/rust-lang/rustfmt/issues/3385
+ TERM: xterm-256color
+ run: cargo xtask check format
+ # --------------------------------------------------------------------------------
- name: Lint
- run: cargo xtask ci --target crates lint
- - name: Audit
- run: cargo xtask ci --target crates audit
- - name: Unit Tests
- run: cargo xtask ci --target crates unit-tests
- - name: Integration Tests
- run: cargo xtask ci --target crates integration-tests
- - name: Documentation Tests
- run: cargo xtask ci --target crates doc-tests
- examples:
- runs-on: ubuntu-latest
+ run: cargo xtask check lint
+ # --------------------------------------------------------------------------------
+ - name: Typos
+ uses: tracel-ai/github-actions/check-typos@v1
+
+ documentation:
+ runs-on: ubuntu-22.04
+ needs: prepare-checks
+ strategy:
+ matrix:
+ rust: [stable]
+ include:
+ - rust: stable
+ toolchain: stable
steps:
- - uses: actions/checkout@v4
- - uses: dtolnay/rust-toolchain@master
+ - name: Setup Rust
+ uses: tracel-ai/github-actions/setup-rust@v1
with:
- components: clippy, rustfmt
- toolchain: stable
- - uses: Swatinem/rust-cache@v2
- - name: Format
- run: cargo xtask ci --target examples format
- - name: Lint
- run: cargo xtask ci --target examples lint
- - name: Unit Tests
- run: cargo xtask ci --target examples unit-tests
- - name: Integration Tests
- run: cargo xtask ci --target examples integration-tests
+ rust-toolchain: ${{ matrix.toolchain }}
+ cache-key: ${{ matrix.rust }}-linux
+ # --------------------------------------------------------------------------------
+ - name: Documentation Build
+ run: cargo xtask doc build
+ # --------------------------------------------------------------------------------
- name: Documentation Tests
- run: cargo xtask ci --target examples doc-tests
+ run: cargo xtask doc tests
+
+ linux-std-tests:
+ runs-on: ubuntu-22.04
+ needs: prepare-checks
+ strategy:
+ matrix:
+ rust: [stable, prev]
+ include:
+ - rust: stable
+ toolchain: stable
+ - rust: prev
+ toolchain: ${{ needs.prepare-checks.outputs.rust-prev-version }}
+ steps:
+ - name: Setup Rust
+ uses: tracel-ai/github-actions/setup-rust@v1
+ with:
+ rust-toolchain: ${{ matrix.toolchain }}
+ cache-key: ${{ matrix.rust }}-linux
+ # --------------------------------------------------------------------------------
+ - name: Setup Linux runner
+ uses: tracel-ai/github-actions/setup-linux@v1
+ with:
+ vulkan-sdk-version: ${{ env.VULKAN_SDK_VERSION }}
+ mesa-version: ${{ env.MESA_VERSION }}
+ mesa-ci-build-version: ${{ env.MESA_CI_BINARY_BUILD }}
+ # --------------------------------------------------------------------------------
+ - name: Tests
+ run: cargo xtask test --ci
+
+ windows-std-tests:
+ runs-on: windows-2022
+ needs: prepare-checks
+ env:
+ DISABLE_WGPU: '1'
+ # Keep the strategy to be able to easily add new rust versions if required
+ strategy:
+ matrix:
+ rust: [stable]
+ include:
+ - rust: stable
+ toolchain: stable
+ steps:
+ - name: Setup Rust
+ uses: tracel-ai/github-actions/setup-rust@v1
+ with:
+ rust-toolchain: ${{ matrix.toolchain }}
+ cache-key: ${{ matrix.rust }}-windows
+ # --------------------------------------------------------------------------------
+ - name: Setup Windows runner
+ if: env.DISABLE_WGPU != '1'
+ uses: tracel-ai/github-actions/setup-windows@v1
+ with:
+ dxc-release: ${{ env.DXC_RELEASE }}
+ dxc-filename: ${{ env.DXC_FILENAME }}
+ mesa-version: ${{ env.MESA_VERSION }}
+ warp-version: ${{ env.WARP_VERSION }}
+ # --------------------------------------------------------------------------------
+ - name: Tests
+ run: cargo xtask test --ci
+
+ macos-std-tests:
+ runs-on: macos-14
+ needs: prepare-checks
+ # Keep the strategy to be able to easily add new rust versions if required
+ strategy:
+ matrix:
+ rust: [stable]
+ include:
+ - rust: stable
+ toolchain: stable
+ steps:
+ - name: Setup Rust
+ uses: tracel-ai/github-actions/setup-rust@v1
+ with:
+ rust-toolchain: ${{ matrix.toolchain }}
+ cache-key: ${{ matrix.rust }}-macos
+ # --------------------------------------------------------------------------------
+ - name: Tests
+ run: cargo xtask test --ci
diff --git a/Cargo.toml b/Cargo.toml
index 92834909b..0b78e5633 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -61,12 +61,9 @@ proc-macro2 = "1.0.86"
syn = { version = "2.0.69", features = ["full", "extra-traits"] }
quote = "1.0.36"
-# xtask
-anyhow = "1.0.86"
-clap = { version = "4.5.9", features = ["derive"] }
-derive_more = { version = "0.99.18", features = ["display"], default-features = false }
-env_logger = "0.11.3"
+### For xtask crate ###
strum = {version = "0.26.3", features = ["derive"]}
+tracel-xtask = { version = "~1.0" }
portable-atomic-util = { version = "0.2.2", features = ["alloc"] } # alloc is for no_std
diff --git a/README.md b/README.md
index f7de0379c..420ddb4b5 100644
--- a/README.md
+++ b/README.md
@@ -6,8 +6,8 @@
[![Discord](https://img.shields.io/discord/1038839012602941528.svg?color=7289da&&logo=discord)](https://discord.gg/KSBSPhAUCc)
[![Current Crates.io Version](https://img.shields.io/crates/v/cubecl.svg)](https://crates.io/crates/cubecl)
+[![Minimum Supported Rust Version](https://img.shields.io/crates/msrv/cubecl)](https://crates.io/crates/cubecl)
[![Test Status](https://github.com/tracel-ai/cubecl/actions/workflows/ci.yml/badge.svg)](https://github.com/tracel-ai/cubecl/actions/workflows/test.yml)
-[![Rust Version](https://img.shields.io/badge/Rust-1.79.0+-blue)](https://releases.rs/docs/1.79.0)
![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)
[![NVIDIA](https://img.shields.io/badge/nvidia-cuda-84b629)](https://github.com/tracel-ai/cubecl/tree/main/crates/cubecl-cuda)
diff --git a/crates/cubecl-core/src/compute/kernel.rs b/crates/cubecl-core/src/compute/kernel.rs
index 3e3175631..32586bf5b 100644
--- a/crates/cubecl-core/src/compute/kernel.rs
+++ b/crates/cubecl-core/src/compute/kernel.rs
@@ -84,7 +84,7 @@ fn format_str(kernel_id: &str, markers: &[(char, char)], include_space: bool) ->
let kernel_id = kernel_id.to_string();
let mut result = String::new();
let mut depth = 0;
- let indendation = 4;
+ let indentation = 4;
let mut prev = ' ';
@@ -105,7 +105,7 @@ fn format_str(kernel_id: &str, markers: &[(char, char)], include_space: bool) ->
}
result.push(start);
result.push('\n');
- result.push_str(&" ".repeat(indendation * depth));
+ result.push_str(&" ".repeat(indentation * depth));
found_marker = true;
} else if c == end {
depth -= 1;
@@ -114,10 +114,10 @@ fn format_str(kernel_id: &str, markers: &[(char, char)], include_space: bool) ->
result.pop();
}
result.push_str(",\n");
- result.push_str(&" ".repeat(indendation * depth));
+ result.push_str(&" ".repeat(indentation * depth));
result.push(end);
} else {
- for _ in 0..(&" ".repeat(indendation * depth).len()) + 1 + indendation {
+ for _ in 0..(&" ".repeat(indentation * depth).len()) + 1 + indentation {
result.pop();
}
result.push(end);
@@ -137,7 +137,7 @@ fn format_str(kernel_id: &str, markers: &[(char, char)], include_space: bool) ->
}
result.push_str(",\n");
- result.push_str(&" ".repeat(indendation * depth));
+ result.push_str(&" ".repeat(indentation * depth));
continue;
}
diff --git a/crates/cubecl-core/src/frontend/element/array.rs b/crates/cubecl-core/src/frontend/element/array.rs
index d3cad4bde..8e2ba9c5f 100644
--- a/crates/cubecl-core/src/frontend/element/array.rs
+++ b/crates/cubecl-core/src/frontend/element/array.rs
@@ -207,7 +207,7 @@ impl<'a, R: Runtime> ArrayArg<'a, R> {
///
/// # Safety
///
- /// Specifying the wrong lenght may lead to out-of-bounds reads and writes.
+ /// Specifying the wrong length may lead to out-of-bounds reads and writes.
pub unsafe fn from_raw_parts(
handle: &'a cubecl_runtime::server::Handle,
length: usize,
@@ -225,7 +225,7 @@ impl<'a, R: Runtime> ArrayHandleRef<'a, R> {
///
/// # Safety
///
- /// Specifying the wrong lenght may lead to out-of-bounds reads and writes.
+ /// Specifying the wrong length may lead to out-of-bounds reads and writes.
pub unsafe fn from_raw_parts(
handle: &'a cubecl_runtime::server::Handle,
length: usize,
diff --git a/crates/cubecl-core/src/frontend/element/base.rs b/crates/cubecl-core/src/frontend/element/base.rs
index e98911cfe..663d2c102 100644
--- a/crates/cubecl-core/src/frontend/element/base.rs
+++ b/crates/cubecl-core/src/frontend/element/base.rs
@@ -236,7 +236,7 @@ impl From> for ExpandElement {
}
impl ExpandElementTyped {
- /// Create an [ExpandElementTyped] from a value that is normaly a literal.
+ /// Create an [ExpandElementTyped] from a value that is normally a literal.
pub fn from_lit>(lit: L) -> Self {
let variable: Variable = lit.into();
let variable = T::as_elem().from_constant(variable);
diff --git a/crates/cubecl-core/src/frontend/element/float.rs b/crates/cubecl-core/src/frontend/element/float.rs
index 0163ca2bd..ebce72cd9 100644
--- a/crates/cubecl-core/src/frontend/element/float.rs
+++ b/crates/cubecl-core/src/frontend/element/float.rs
@@ -1,6 +1,8 @@
use half::{bf16, f16};
-use crate::frontend::{Ceil, Cos, Erf, Exp, Floor, Log, Log1p, Powf, Recip, Sin, Sqrt, Tanh};
+use crate::frontend::{
+ Ceil, Cos, Erf, Exp, Floor, Log, Log1p, Powf, Recip, Round, Sin, Sqrt, Tanh,
+};
use crate::frontend::{
ComptimeType, CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementBaseInit,
ExpandElementTyped, Numeric,
@@ -25,6 +27,7 @@ pub trait Float:
+ Tanh
+ Powf
+ Sqrt
+ + Round
+ Floor
+ Ceil
+ Erf
diff --git a/crates/cubecl-core/src/frontend/operation/binary.rs b/crates/cubecl-core/src/frontend/operation/binary.rs
index 7632a5e88..dabe8001c 100644
--- a/crates/cubecl-core/src/frontend/operation/binary.rs
+++ b/crates/cubecl-core/src/frontend/operation/binary.rs
@@ -192,6 +192,26 @@ pub mod bitand {
}
}
+pub mod bitor {
+ use super::*;
+
+ pub fn expand(
+ context: &mut CubeContext,
+ lhs: ExpandElementTyped,
+ rhs: ExpandElementTyped,
+ ) -> ExpandElementTyped {
+ binary_expand(context, lhs.into(), rhs.into(), Operator::BitwiseOr).into()
+ }
+
+ impl core::ops::BitOr for UInt {
+ type Output = UInt;
+
+ fn bitor(self, _rhs: Self) -> Self::Output {
+ unexpanded!()
+ }
+ }
+}
+
pub mod or {
use super::*;
diff --git a/crates/cubecl-core/src/frontend/operation/unary.rs b/crates/cubecl-core/src/frontend/operation/unary.rs
index 40569e447..3f2954ed4 100644
--- a/crates/cubecl-core/src/frontend/operation/unary.rs
+++ b/crates/cubecl-core/src/frontend/operation/unary.rs
@@ -82,6 +82,16 @@ impl_unary_func!(
F32,
F64
);
+impl_unary_func!(
+ Round,
+ round,
+ __expand_round,
+ Operator::Round,
+ F16,
+ BF16,
+ F32,
+ F64
+);
impl_unary_func!(
Floor,
floor,
diff --git a/crates/cubecl-core/src/ir/operation.rs b/crates/cubecl-core/src/ir/operation.rs
index 0a22814a8..3302b0053 100644
--- a/crates/cubecl-core/src/ir/operation.rs
+++ b/crates/cubecl-core/src/ir/operation.rs
@@ -39,6 +39,7 @@ pub enum Operator {
Tanh(UnaryOperator),
Powf(BinaryOperator),
Sqrt(UnaryOperator),
+ Round(UnaryOperator),
Floor(UnaryOperator),
Ceil(UnaryOperator),
Erf(UnaryOperator),
@@ -63,6 +64,7 @@ pub enum Operator {
Max(BinaryOperator),
Min(BinaryOperator),
BitwiseAnd(BinaryOperator),
+ BitwiseOr(BinaryOperator),
BitwiseXor(BinaryOperator),
ShiftLeft(BinaryOperator),
ShiftRight(BinaryOperator),
diff --git a/crates/cubecl-core/src/ir/processing.rs b/crates/cubecl-core/src/ir/processing.rs
index 3d2ba51c0..d06d12ba4 100644
--- a/crates/cubecl-core/src/ir/processing.rs
+++ b/crates/cubecl-core/src/ir/processing.rs
@@ -76,6 +76,9 @@ impl ScopeProcessing {
Operator::Sqrt(op) => {
sanitize_constant_scalar_ref_var(&mut op.input, &op.out);
}
+ Operator::Round(op) => {
+ sanitize_constant_scalar_ref_var(&mut op.input, &op.out);
+ }
Operator::Floor(op) => {
sanitize_constant_scalar_ref_var(&mut op.input, &op.out);
}
@@ -168,6 +171,10 @@ impl ScopeProcessing {
sanitize_constant_scalar_ref_var(&mut op.lhs, &op.out);
sanitize_constant_scalar_ref_var(&mut op.rhs, &op.out);
}
+ Operator::BitwiseOr(op) => {
+ sanitize_constant_scalar_ref_var(&mut op.lhs, &op.out);
+ sanitize_constant_scalar_ref_var(&mut op.rhs, &op.out);
+ }
Operator::BitwiseXor(op) => {
sanitize_constant_scalar_ref_var(&mut op.lhs, &op.out);
sanitize_constant_scalar_ref_var(&mut op.rhs, &op.out);
diff --git a/crates/cubecl-core/src/ir/vectorization.rs b/crates/cubecl-core/src/ir/vectorization.rs
index eb7f4396c..bed3c3405 100644
--- a/crates/cubecl-core/src/ir/vectorization.rs
+++ b/crates/cubecl-core/src/ir/vectorization.rs
@@ -39,6 +39,7 @@ impl Operator {
Operator::Sub(op) => Operator::Sub(op.vectorize(vectorization)),
Operator::Mul(op) => Operator::Mul(op.vectorize(vectorization)),
Operator::Div(op) => Operator::Div(op.vectorize(vectorization)),
+ Operator::Round(op) => Operator::Round(op.vectorize(vectorization)),
Operator::Floor(op) => Operator::Floor(op.vectorize(vectorization)),
Operator::Ceil(op) => Operator::Ceil(op.vectorize(vectorization)),
Operator::Abs(op) => Operator::Abs(op.vectorize(vectorization)),
@@ -76,6 +77,7 @@ impl Operator {
Operator::And(op) => Operator::And(op.vectorize(vectorization)),
Operator::Or(op) => Operator::Or(op.vectorize(vectorization)),
Operator::Not(op) => Operator::Not(op.vectorize(vectorization)),
+ Operator::BitwiseOr(op) => Operator::BitwiseOr(op.vectorize(vectorization)),
Operator::BitwiseAnd(op) => Operator::BitwiseAnd(op.vectorize(vectorization)),
Operator::BitwiseXor(op) => Operator::BitwiseXor(op.vectorize(vectorization)),
Operator::ShiftLeft(op) => Operator::ShiftLeft(op.vectorize(vectorization)),
diff --git a/crates/cubecl-core/src/lib.rs b/crates/cubecl-core/src/lib.rs
index 59cb1a312..d8918dbf1 100644
--- a/crates/cubecl-core/src/lib.rs
+++ b/crates/cubecl-core/src/lib.rs
@@ -70,22 +70,38 @@ pub fn tensor_vectorization_factor(
strides: &[usize],
dim: usize,
) -> u8 {
- if let Some(val) = strides.get(dim) {
- if *val != 1 {
- return 1;
+ match strides.get(dim) {
+ Some(val) => {
+ if *val != 1 {
+ return 1;
+ }
}
- } else {
- return 1;
+ None => return 1,
}
- let dim_size = match shape.get(dim) {
+ let shape_check = match shape.get(dim) {
Some(val) => val,
None => return 1,
};
+ let stride_check = if let Some(dim) = dim.checked_sub(1) {
+ strides.get(dim)
+ } else {
+ None
+ };
+
for factor in factors {
- if dim_size % *factor as usize == 0 {
- return *factor;
+ let factor = *factor as usize;
+
+ if shape_check % factor == 0 {
+ match stride_check {
+ Some(check) => {
+ if check % factor == 0 {
+ return factor as u8;
+ }
+ }
+ None => return factor as u8,
+ }
}
}
diff --git a/crates/cubecl-core/src/runtime.rs b/crates/cubecl-core/src/runtime.rs
index b9a42cbf5..0841163cf 100644
--- a/crates/cubecl-core/src/runtime.rs
+++ b/crates/cubecl-core/src/runtime.rs
@@ -20,6 +20,7 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
Kernel = Box,
DispatchOptions = CubeCount,
FeatureSet = FeatureSet,
+ Properties = Properties,
>;
/// The channel used to communicate with the compute server.
type Channel: ComputeChannel;
@@ -44,6 +45,13 @@ pub struct FeatureSet {
set: alloc::collections::BTreeSet,
}
+/// The [runtime](Runtime) properties.
+#[derive(Default, Debug)]
+pub struct Properties {
+ /// The memory offset alignment in bytes.
+ pub memory_offset_alignment: u32,
+}
+
impl FeatureSet {
pub fn new(features: &[Feature]) -> Self {
let mut this = Self::default();
diff --git a/crates/cubecl-core/tests/frontend/ops.rs b/crates/cubecl-core/tests/frontend/ops.rs
index d5c9a63d1..064e8b9f6 100644
--- a/crates/cubecl-core/tests/frontend/ops.rs
+++ b/crates/cubecl-core/tests/frontend/ops.rs
@@ -66,6 +66,11 @@ pub fn sqrt_op(a: F) -> F {
F::sqrt(a)
}
+#[cube]
+pub fn round_op(a: F) -> F {
+ F::round(a)
+}
+
#[cube]
pub fn floor_op(a: F) -> F {
F::floor(a)
@@ -156,6 +161,11 @@ pub fn bitand_op(a: UInt, b: UInt) -> UInt {
a & b
}
+#[cube]
+pub fn bitor_op(a: UInt, b: UInt) -> UInt {
+ a | b
+}
+
#[cube]
pub fn bitxor_op(a: UInt, b: UInt) -> UInt {
a ^ b
@@ -286,6 +296,7 @@ mod tests {
unary_test!(cube_can_sqrt, sqrt_op::__expand::, "Sqrt");
unary_test!(cube_can_erf, erf_op::__expand::, "Erf");
unary_test!(cube_can_recip, recip_op::__expand::, "Recip");
+ unary_test!(cube_can_round, round_op::__expand::, "Round");
unary_test!(cube_can_floor, floor_op::__expand::, "Floor");
unary_test!(cube_can_ceil, ceil_op::__expand::, "Ceil");
binary_test!(cube_can_eq, equal_op::__expand::, "Equal", ref_ops_cmp);
@@ -343,6 +354,7 @@ mod tests {
binary_boolean_test!(cube_can_and, and_op::__expand, "And");
binary_boolean_test!(cube_can_or, or_op::__expand, "Or");
binary_uint_test!(cube_can_bitand, bitand_op::__expand, "BitwiseAnd");
+ binary_uint_test!(cube_can_bitor, bitor_op::__expand, "BitwiseOr");
binary_uint_test!(cube_can_bitxor, bitxor_op::__expand, "BitwiseXor");
binary_uint_test!(cube_can_shl, shl_op::__expand, "ShiftLeft");
binary_uint_test!(cube_can_shr, shr_op::__expand, "ShiftRight");
diff --git a/crates/cubecl-cuda/src/compiler/base.rs b/crates/cubecl-cuda/src/compiler/base.rs
index 42c68b1dd..5787fe5d3 100644
--- a/crates/cubecl-cuda/src/compiler/base.rs
+++ b/crates/cubecl-cuda/src/compiler/base.rs
@@ -451,6 +451,9 @@ impl CudaCompiler {
gpu::Operator::NotEqual(op) => {
instructions.push(Instruction::NotEqual(self.compile_binary(op)))
}
+ gpu::Operator::BitwiseOr(op) => {
+ instructions.push(Instruction::BitwiseOr(self.compile_binary(op)))
+ }
gpu::Operator::BitwiseAnd(op) => {
instructions.push(Instruction::BitwiseAnd(self.compile_binary(op)))
}
@@ -487,6 +490,9 @@ impl CudaCompiler {
out: self.compile_variable(op.out),
}))
}
+ gpu::Operator::Round(op) => {
+ instructions.push(Instruction::Round(self.compile_unary(op)))
+ }
gpu::Operator::Floor(op) => {
instructions.push(Instruction::Floor(self.compile_unary(op)))
}
diff --git a/crates/cubecl-cuda/src/compiler/binary.rs b/crates/cubecl-cuda/src/compiler/binary.rs
index a6016d740..75e7e8076 100644
--- a/crates/cubecl-cuda/src/compiler/binary.rs
+++ b/crates/cubecl-cuda/src/compiler/binary.rs
@@ -100,6 +100,7 @@ operator!(Greater, ">");
operator!(GreaterEqual, ">=");
operator!(ShiftLeft, "<<");
operator!(ShiftRight, ">>");
+operator!(BitwiseOr, "|");
operator!(BitwiseAnd, "&");
operator!(BitwiseXor, "^");
operator!(Or, "||");
@@ -128,29 +129,11 @@ impl Binary for IndexAssign {
let item_rhs = rhs.item();
let format_vec = |f: &mut Formatter<'_>, cast: bool| {
- let is_vec_native = item_out.is_vec_native();
f.write_str("{\n")?;
let var = "broadcasted";
f.write_fmt(format_args!("{item_out} {var};\n"))?;
for i in 0..item_out.vectorization {
- if is_vec_native {
- let char = match i {
- 0 => 'x',
- 1 => 'y',
- 2 => 'z',
- 3 => 'w',
- _ => panic!("Invalid"),
- };
- if cast {
- f.write_fmt(format_args!(
- "{var}.{char} = {}({});\n",
- item_out.elem,
- rhs.index(i)
- ))?;
- } else {
- f.write_fmt(format_args!("{var}.{char} = {};\n", rhs.index(i)))?;
- }
- } else if cast {
+ if cast {
f.write_fmt(format_args!(
"{var}.i_{i} = {}({});\n",
item_out.elem,
@@ -254,29 +237,14 @@ impl Binary for Index {
let item_lhs = lhs.item();
let format_vec = |f: &mut Formatter<'_>| {
- let is_vec_native = item_out.is_vec_native();
f.write_str("{\n")?;
let var = "broadcasted";
f.write_fmt(format_args!("{item_out} {var};\n"))?;
for i in 0..item_out.vectorization {
- if is_vec_native {
- let char = match i {
- 0 => 'x',
- 1 => 'y',
- 2 => 'z',
- 3 => 'w',
- _ => panic!("Invalid"),
- };
- f.write_fmt(format_args!(
- "{var}.{char} = {}({lhs}[{rhs}].i_{i});\n",
- item_out.elem
- ))?;
- } else {
- f.write_fmt(format_args!(
- "{var}.i_{i} = {}({lhs}[{rhs}].i_{i});\n",
- item_out.elem
- ))?;
- }
+ f.write_fmt(format_args!(
+ "{var}.i_{i} = {}({lhs}[{rhs}].i_{i});\n",
+ item_out.elem
+ ))?;
}
f.write_fmt(format_args!("{out} = {var};\n"))?;
f.write_str("}")?;
diff --git a/crates/cubecl-cuda/src/compiler/element.rs b/crates/cubecl-cuda/src/compiler/element.rs
index 689e150db..c24e3f2ad 100644
--- a/crates/cubecl-cuda/src/compiler/element.rs
+++ b/crates/cubecl-cuda/src/compiler/element.rs
@@ -43,11 +43,6 @@ impl Display for Item {
return f.write_fmt(format_args!("{}", self.elem));
}
- if self.is_vec_native() {
- let elem = self.optimized().elem;
- return f.write_fmt(format_args!("{elem}"));
- }
-
return f.write_fmt(format_args!("{}_{}", self.elem, self.vectorization));
}
}
@@ -397,12 +392,15 @@ impl Display for IndexedVariable {
if self.optimized {
let item = self.var.item();
f.write_fmt(format_args!(
- "(reinterpret_cast<{item}*>(&{var}))->i_{}",
+ "(reinterpret_cast<{item}&>({var})).i_{}",
self.index
))
} else {
f.write_fmt(format_args!("{var}.i_{}", self.index))
}
+ } else if self.optimized {
+ let item = self.var.item();
+ f.write_fmt(format_args!("reinterpret_cast<{item}&>({var})"))
} else {
f.write_fmt(format_args!("{var}"))
}
@@ -438,16 +436,6 @@ impl Item {
matches!(self.elem, Elem::F162 | Elem::BF162)
}
- pub fn is_vec_native(&self) -> bool {
- match &self.elem {
- Elem::F16 => self.vectorization == 2,
- Elem::BF16 => self.vectorization == 2,
- Elem::F162 => self.vectorization == 1,
- Elem::BF162 => self.vectorization == 1,
- _ => false,
- }
- }
-
pub fn optimized(&self) -> Item {
if self.vectorization == 1 {
return *self;
diff --git a/crates/cubecl-cuda/src/compiler/instruction.rs b/crates/cubecl-cuda/src/compiler/instruction.rs
index 97b35bba5..7c4bbda72 100644
--- a/crates/cubecl-cuda/src/compiler/instruction.rs
+++ b/crates/cubecl-cuda/src/compiler/instruction.rs
@@ -89,6 +89,7 @@ pub enum Instruction {
LowerEqual(BinaryInstruction),
GreaterEqual(BinaryInstruction),
Erf(UnaryInstruction),
+ BitwiseOr(BinaryInstruction),
BitwiseAnd(BinaryInstruction),
BitwiseXor(BinaryInstruction),
ShiftLeft(BinaryInstruction),
@@ -114,6 +115,8 @@ pub enum Instruction {
out: Variable,
},
SyncThreads,
+ ThreadFence,
+ Round(UnaryInstruction),
Ceil(UnaryInstruction),
Floor(UnaryInstruction),
Wrap(WarpInstruction),
@@ -168,6 +171,7 @@ impl Display for Instruction {
Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
+ Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out),
Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out),
Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out),
@@ -264,6 +268,8 @@ for (uint {i} = {start}; {i} < {end}; {increment}) {{
out,
} => Clamp::format(f, input, min_value, max_value, out),
Instruction::SyncThreads => f.write_str("__syncthreads();\n"),
+ Instruction::ThreadFence => f.write_str("__threadfence();\n"),
+ Instruction::Round(it) => Round::format(f, &it.input, &it.out),
Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
Instruction::SliceLength { input, out } => {
diff --git a/crates/cubecl-cuda/src/compiler/kernel.rs b/crates/cubecl-cuda/src/compiler/kernel.rs
index a32f8d19d..6fcfeb2d5 100644
--- a/crates/cubecl-cuda/src/compiler/kernel.rs
+++ b/crates/cubecl-cuda/src/compiler/kernel.rs
@@ -87,10 +87,6 @@ impl Display for ComputeKernel {
f.write_str("typedef unsigned int uint;\n")?;
for item in self.items.iter() {
- if item.is_vec_native() {
- continue;
- }
-
let elem = item.elem;
let size = item.vectorization;
let alignment = elem.size() * size;
diff --git a/crates/cubecl-cuda/src/compiler/unary.rs b/crates/cubecl-cuda/src/compiler/unary.rs
index 6529701a4..9f45c6c5e 100644
--- a/crates/cubecl-cuda/src/compiler/unary.rs
+++ b/crates/cubecl-cuda/src/compiler/unary.rs
@@ -82,6 +82,7 @@ function!(Exp, "exp");
function!(Erf, "erf");
function!(Ceil, "ceil");
function!(Floor, "floor");
+function!(Round, "rint");
pub struct Not;
diff --git a/crates/cubecl-cuda/src/compute/server.rs b/crates/cubecl-cuda/src/compute/server.rs
index 6f824e200..afb4cfab4 100644
--- a/crates/cubecl-cuda/src/compute/server.rs
+++ b/crates/cubecl-cuda/src/compute/server.rs
@@ -6,8 +6,8 @@ use cubecl_common::reader::{reader_from_concrete, Reader};
use cubecl_common::sync_type::SyncType;
use cubecl_core::compute::DebugInformation;
use cubecl_core::ir::CubeDim;
-use cubecl_core::FeatureSet;
use cubecl_core::{prelude::*, KernelId};
+use cubecl_core::{FeatureSet, Properties};
use cubecl_runtime::debug::DebugLogger;
use cubecl_runtime::ExecutionMode;
use cubecl_runtime::{
@@ -25,8 +25,6 @@ use std::path::PathBuf;
pub struct CudaServer> {
state: CudaServerState,
logger: DebugLogger,
- pub(crate) archs: Vec,
- pub(crate) minimum_arch_version: i32,
}
pub(crate) enum CudaServerState> {
@@ -51,6 +49,7 @@ pub(crate) struct CudaContext> {
stream: cudarc::driver::sys::CUstream,
memory_management: MM,
module_names: HashMap,
+ pub(crate) arch: u32,
}
#[derive(Debug)]
@@ -63,9 +62,18 @@ struct CompiledKernel {
unsafe impl> Send for CudaServer {}
impl> CudaServer {
+ pub(crate) fn arch_version(&mut self) -> u32 {
+ let ctx = self.get_context();
+ ctx.arch
+ }
+
fn read_sync(&mut self, binding: server::Binding) -> Vec {
let ctx = self.get_context();
- let resource = ctx.memory_management.get_resource(binding.memory);
+ let resource = ctx.memory_management.get_resource(
+ binding.memory,
+ binding.offset_start,
+ binding.offset_end,
+ );
// TODO: Check if it is possible to make this faster
let mut data = vec![0; resource.size() as usize];
@@ -83,6 +91,7 @@ impl> ComputeServer for CudaServer {
type Storage = CudaStorage;
type MemoryManagement = MM;
type FeatureSet = FeatureSet;
+ type Properties = Properties;
fn read(&mut self, binding: server::Binding) -> Reader {
reader_from_concrete(self.read_sync(binding))
@@ -92,9 +101,12 @@ impl> ComputeServer for CudaServer {
let handle = self.empty(data.len());
let ctx = self.get_context();
- let resource = ctx
- .memory_management
- .get_resource(handle.clone().binding().memory);
+ let binding = handle.clone().binding();
+ let resource = ctx.memory_management.get_resource(
+ binding.memory,
+ binding.offset_start,
+ binding.offset_end,
+ );
unsafe {
cudarc::driver::result::memcpy_htod_async(resource.ptr, data, ctx.stream).unwrap();
@@ -106,7 +118,7 @@ impl> ComputeServer for CudaServer {
fn empty(&mut self, size: usize) -> server::Handle {
let ctx = self.get_context();
let handle = ctx.memory_management.reserve(size, &[]);
- server::Handle::new(handle)
+ server::Handle::new(handle, None, None)
}
unsafe fn execute(
@@ -116,8 +128,6 @@ impl> ComputeServer for CudaServer {
bindings: Vec>,
mode: ExecutionMode,
) {
- let arch = self.minimum_arch_version;
-
let mut kernel_id = kernel.id();
kernel_id.mode(mode);
@@ -140,12 +150,18 @@ impl> ComputeServer for CudaServer {
let (ctx, logger) = self.get_context_with_logger();
if !ctx.module_names.contains_key(&kernel_id) {
- ctx.compile_kernel(&kernel_id, kernel, arch, logger, mode);
+ ctx.compile_kernel(&kernel_id, kernel, logger, mode);
}
let resources = bindings
.into_iter()
- .map(|binding| ctx.memory_management.get_resource(binding.memory))
+ .map(|binding| {
+ ctx.memory_management.get_resource(
+ binding.memory,
+ binding.offset_start,
+ binding.offset_end,
+ )
+ })
.collect::>();
ctx.execute_task(kernel_id, count, resources);
@@ -168,7 +184,8 @@ impl> ComputeServer for CudaServer {
binding: server::Binding,
) -> ::Resource {
let ctx = self.get_context();
- ctx.memory_management.get_resource(binding.memory)
+ ctx.memory_management
+ .get_resource(binding.memory, binding.offset_start, binding.offset_end)
}
}
@@ -177,12 +194,14 @@ impl> CudaContext {
memory_management: MM,
stream: cudarc::driver::sys::CUstream,
context: *mut CUctx_st,
+ arch: u32,
) -> Self {
Self {
context,
memory_management,
module_names: HashMap::new(),
stream,
+ arch,
}
}
@@ -197,7 +216,6 @@ impl> CudaContext {
&mut self,
kernel_id: &KernelId,
kernel: Box,
- arch: i32,
logger: &mut DebugLogger,
mode: ExecutionMode,
) {
@@ -213,7 +231,7 @@ impl> CudaContext {
let shared_mem_bytes = kernel_compiled.shared_mem_bytes;
let cube_dim = kernel_compiled.cube_dim;
- let arch = format!("--gpu-architecture=sm_{}", arch);
+ let arch = format!("--gpu-architecture=sm_{}", self.arch);
let include_path = include_path();
let include_option = format!("--include-path={}", include_path.to_str().unwrap());
@@ -286,26 +304,12 @@ impl> CudaContext {
impl> CudaServer {
/// Create a new cuda server.
pub(crate) fn new(index: usize, init: Box CudaContext>) -> Self {
- let archs = unsafe {
- let mut num_supported_arg: core::ffi::c_int = 0;
- cudarc::nvrtc::sys::lib()
- .nvrtcGetNumSupportedArchs(core::ptr::from_mut(&mut num_supported_arg));
-
- let mut archs: Vec = vec![0; num_supported_arg as usize];
- cudarc::nvrtc::sys::lib().nvrtcGetSupportedArchs(core::ptr::from_mut(&mut archs[0]));
- archs
- };
-
- let minimum_arch_version = archs[0];
-
Self {
state: CudaServerState::Uninitialized {
device_index: index,
init,
},
logger: DebugLogger::new(),
- archs,
- minimum_arch_version,
}
}
@@ -316,6 +320,7 @@ impl> CudaServer {
fn get_context_with_logger(&mut self) -> (&mut CudaContext, &mut DebugLogger) {
if let CudaServerState::Uninitialized { device_index, init } = &self.state {
let ctx = init(*device_index);
+
self.state = CudaServerState::Initialized { ctx };
}
if let CudaServerState::Initialized { ctx } = &mut self.state {
diff --git a/crates/cubecl-cuda/src/runtime.rs b/crates/cubecl-cuda/src/runtime.rs
index 4c5caa788..5895dbdd1 100644
--- a/crates/cubecl-cuda/src/runtime.rs
+++ b/crates/cubecl-cuda/src/runtime.rs
@@ -1,6 +1,6 @@
use cubecl_core::{
ir::{Elem, FloatKind},
- Feature, FeatureSet, Runtime,
+ Feature, FeatureSet, Properties, Runtime,
};
use cubecl_runtime::{
channel::MutexComputeChannel,
@@ -8,7 +8,6 @@ use cubecl_runtime::{
memory_management::dynamic::{DynamicMemoryManagement, DynamicMemoryManagementOptions},
ComputeRuntime,
};
-use std::sync::Arc;
use crate::{
compiler::CudaCompiler,
@@ -35,6 +34,11 @@ impl Runtime for CudaRuntime {
fn init(index: usize) -> CudaContext> {
cudarc::driver::result::init().unwrap();
let device_ptr = cudarc::driver::result::device::get(index as i32).unwrap();
+ let arch = unsafe {
+ let major = cudarc::driver::result::device::get_attribute(device_ptr, cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR).unwrap();
+ let minor = cudarc::driver::result::device::get_attribute(device_ptr, cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR).unwrap();
+ major * 10 + minor
+ } as u32;
let ctx = unsafe {
let ctx = cudarc::driver::result::primary_ctx::retain(device_ptr).unwrap();
@@ -49,20 +53,21 @@ impl Runtime for CudaRuntime {
let storage = CudaStorage::new(stream);
let options = DynamicMemoryManagementOptions::preset(2048 + 512 * 1024 * 1024, 32);
let memory_management = DynamicMemoryManagement::new(storage, options);
- CudaContext::new(memory_management, stream, ctx)
+ CudaContext::new(memory_management, stream, ctx, arch)
}
RUNTIME.client(device, move || {
let mut server = CudaServer::new(device.index, Box::new(init));
let mut features = FeatureSet::new(&[Feature::Subcube]);
- if let Some(wmma_minimum_version) = register_wmma_features(&mut features, &server.archs)
- {
- server.minimum_arch_version =
- i32::max(server.minimum_arch_version, wmma_minimum_version);
- }
-
- ComputeClient::new(MutexComputeChannel::new(server), Arc::new(features))
+ register_wmma_features(&mut features, server.arch_version());
+ ComputeClient::new(
+ MutexComputeChannel::new(server),
+ features,
+ Properties {
+ memory_offset_alignment: 4,
+ },
+ )
})
}
@@ -75,15 +80,12 @@ impl Runtime for CudaRuntime {
}
}
-fn register_wmma_features(features: &mut FeatureSet, archs: &[i32]) -> Option {
+fn register_wmma_features(features: &mut FeatureSet, arch: u32) {
let wmma_minimum_version = 70;
let mut wmma = false;
- for arch in archs {
- if *arch >= wmma_minimum_version {
- wmma = true;
- break;
- }
+ if arch >= wmma_minimum_version {
+ wmma = true;
}
if wmma {
@@ -130,8 +132,5 @@ fn register_wmma_features(features: &mut FeatureSet, archs: &[i32]) -> Option(
#[derive(Debug)]
pub enum UnavailabilityReason {
NotMultipleOf4, // TODO: Support that case.
- HiglyPermutatedInput,
+ HighlyPermutatedInput,
ShapeMemoryLimitBusted,
InvalidConfig(String),
CmmaInstructionsUnsupported,
diff --git a/crates/cubecl-linalg/src/tensor/contiguous.rs b/crates/cubecl-linalg/src/tensor/contiguous.rs
index e26b3afa0..081f9d5d7 100644
--- a/crates/cubecl-linalg/src/tensor/contiguous.rs
+++ b/crates/cubecl-linalg/src/tensor/contiguous.rs
@@ -28,7 +28,7 @@ pub fn index_offset_with_layout(
offset / vectorization_factor_runtime
}
-#[cube(launch)]
+#[cube(launch_unchecked)]
fn into_contiguous_kernel(
input: &Tensor,
output: &mut Tensor,
@@ -69,14 +69,16 @@ pub fn into_contiguous(
let handle = client.empty(num_elems * E::as_elem().size());
let output = TensorHandle::new_contiguous(input.shape.to_vec(), handle);
- into_contiguous_kernel::launch::(
- client,
- cube_count,
- cube_dim,
- input.as_tensor_arg(vectorization_factor),
- output.as_ref().as_tensor_arg(vectorization_factor),
- Some(UInt::new(rank as u32)),
- );
+ unsafe {
+ into_contiguous_kernel::launch_unchecked::(
+ client,
+ cube_count,
+ cube_dim,
+ input.as_tensor_arg(vectorization_factor),
+ output.as_ref().as_tensor_arg(vectorization_factor),
+ Some(UInt::new(rank as u32)),
+ );
+ }
output
}
diff --git a/crates/cubecl-macros/src/codegen_function/operation.rs b/crates/cubecl-macros/src/codegen_function/operation.rs
index ede4aeeb0..18b6ee804 100644
--- a/crates/cubecl-macros/src/codegen_function/operation.rs
+++ b/crates/cubecl-macros/src/codegen_function/operation.rs
@@ -206,6 +206,14 @@ pub(crate) fn codegen_binary(
cubecl::frontend::or::expand(context, _lhs, _rhs)
}
},
+ syn::BinOp::BitOr(_) => quote::quote! {
+ {
+
+ let _lhs = #lhs;
+ let _rhs = #rhs;
+ cubecl::frontend::bitor::expand(context, _lhs, _rhs)
+ }
+ },
syn::BinOp::BitAnd(_) => quote::quote! {
{
diff --git a/crates/cubecl-runtime/src/client.rs b/crates/cubecl-runtime/src/client.rs
index 25cf9b7b7..cce4ced15 100644
--- a/crates/cubecl-runtime/src/client.rs
+++ b/crates/cubecl-runtime/src/client.rs
@@ -14,7 +14,7 @@ pub use cubecl_common::sync_type::SyncType;
#[derive(Debug)]
pub struct ComputeClient {
channel: Channel,
- features: Arc,
+ settings: Arc<(Server::FeatureSet, Server::Properties)>,
}
impl Clone for ComputeClient
@@ -25,7 +25,7 @@ where
fn clone(&self) -> Self {
Self {
channel: self.channel.clone(),
- features: self.features.clone(),
+ settings: self.settings.clone(),
}
}
}
@@ -36,8 +36,15 @@ where
Channel: ComputeChannel,
{
/// Create a new client.
- pub fn new(channel: Channel, features: Arc) -> Self {
- Self { channel, features }
+ pub fn new(
+ channel: Channel,
+ features: Server::FeatureSet,
+ properties: Server::Properties,
+ ) -> Self {
+ Self {
+ channel,
+ settings: Arc::new((features, properties)),
+ }
}
/// Given a binding, returns owned resource as bytes.
@@ -106,6 +113,11 @@ where
/// Get the features supported by the compute server.
pub fn features(&self) -> &Server::FeatureSet {
- self.features.as_ref()
+ &self.settings.as_ref().0
+ }
+
+ /// Get the properties of the compute server.
+ pub fn properties(&self) -> &Server::Properties {
+ &self.settings.as_ref().1
}
}
diff --git a/crates/cubecl-runtime/src/memory_management/base.rs b/crates/cubecl-runtime/src/memory_management/base.rs
index deeb4b239..3f763290a 100644
--- a/crates/cubecl-runtime/src/memory_management/base.rs
+++ b/crates/cubecl-runtime/src/memory_management/base.rs
@@ -27,8 +27,21 @@ pub trait MemoryManagement: Send + core::fmt::Debug {
fn get(&mut self, binding: Self::Binding) -> StorageHandle;
/// Returns the resource from the storage at the specified handle
- fn get_resource(&mut self, binding: Self::Binding) -> Storage::Resource {
+ fn get_resource(
+ &mut self,
+ binding: Self::Binding,
+ offset_start: Option,
+ offset_end: Option,
+ ) -> Storage::Resource {
let handle = self.get(binding);
+ let handle = match offset_start {
+ Some(offset) => handle.offset_start(offset),
+ None => handle,
+ };
+ let handle = match offset_end {
+ Some(offset) => handle.offset_end(offset),
+ None => handle,
+ };
self.storage().get(&handle)
}
diff --git a/crates/cubecl-runtime/src/server.rs b/crates/cubecl-runtime/src/server.rs
index dbc104f5e..1edee0e78 100644
--- a/crates/cubecl-runtime/src/server.rs
+++ b/crates/cubecl-runtime/src/server.rs
@@ -25,6 +25,8 @@ where
type MemoryManagement: MemoryManagement;
/// Features supported by the compute server.
type FeatureSet: Send + Sync;
+ /// Properties of the compute server.
+ type Properties: Send + Sync;
/// Given a handle, returns the owned resource as bytes.
fn read(&mut self, binding: Binding) -> Reader;
@@ -66,6 +68,33 @@ where
pub struct Handle {
/// Memory handle.
pub memory: >::Handle,
+ /// Offset in bytes applied at the start of the memory.
+ pub offset_start: Option,
+ /// Offset in bytes applied at the end of the memory.
+ pub offset_end: Option,
+}
+
+impl Handle {
+ /// Add to the current start offset in bytes.
+ pub fn offset_start(mut self, offset: usize) -> Self {
+ if let Some(val) = &mut self.offset_start {
+ *val += offset;
+ } else {
+ self.offset_start = Some(offset);
+ }
+
+ self
+ }
+ /// Add to the current end offset in bytes.
+ pub fn offset_end(mut self, offset: usize) -> Self {
+ if let Some(val) = &mut self.offset_end {
+ *val += offset;
+ } else {
+ self.offset_end = Some(offset);
+ }
+
+ self
+ }
}
/// Binding of a [tensor handle](Handle) to execute a kernel.
@@ -73,6 +102,10 @@ pub struct Handle {
pub struct Binding {
/// Memory binding.
pub memory: >::Binding,
+ /// Offset in bytes applied at the start of the memory.
+ pub offset_start: Option,
+ /// Offset in bytes applied at the end of the memory.
+ pub offset_end: Option,
}
impl Handle {
@@ -87,6 +120,8 @@ impl Handle {
pub fn binding(self) -> Binding {
Binding {
memory: MemoryHandle::binding(self.memory),
+ offset_start: self.offset_start,
+ offset_end: self.offset_end,
}
}
}
@@ -95,6 +130,8 @@ impl Clone for Handle {
fn clone(&self) -> Self {
Self {
memory: self.memory.clone(),
+ offset_start: self.offset_start,
+ offset_end: self.offset_end,
}
}
}
@@ -103,6 +140,8 @@ impl Clone for Binding {
fn clone(&self) -> Self {
Self {
memory: self.memory.clone(),
+ offset_start: self.offset_start,
+ offset_end: self.offset_end,
}
}
}
diff --git a/crates/cubecl-runtime/src/storage/base.rs b/crates/cubecl-runtime/src/storage/base.rs
index 968aadbcd..bae7fcede 100644
--- a/crates/cubecl-runtime/src/storage/base.rs
+++ b/crates/cubecl-runtime/src/storage/base.rs
@@ -42,6 +42,44 @@ impl StorageHandle {
StorageUtilization::Slice { offset, .. } => offset,
}
}
+
+ /// Increase the current offset by the given value in bytes, reducing the size by the same amount.
+ pub fn offset_start(&self, offset_bytes: usize) -> Self {
+ let utilization = match self.utilization {
+ StorageUtilization::Full(size) => StorageUtilization::Slice {
+ offset: offset_bytes,
+ size: size - offset_bytes,
+ },
+ StorageUtilization::Slice { offset, size } => StorageUtilization::Slice {
+ offset: offset + offset_bytes,
+ size: size - offset_bytes,
+ },
+ };
+
+ Self {
+ id: self.id,
+ utilization,
+ }
+ }
+
+ /// Reduce the size of the memory handle by the given value in bytes.
+ pub fn offset_end(&self, offset_bytes: usize) -> Self {
+ let utilization = match self.utilization {
+ StorageUtilization::Full(size) => StorageUtilization::Slice {
+ offset: 0,
+ size: size - offset_bytes,
+ },
+ StorageUtilization::Slice { offset, size } => StorageUtilization::Slice {
+ offset,
+ size: size - offset_bytes,
+ },
+ };
+
+ Self {
+ id: self.id,
+ utilization,
+ }
+ }
}
/// Storage types are responsible for allocating and deallocating memory.
diff --git a/crates/cubecl-runtime/tests/dummy/compute.rs b/crates/cubecl-runtime/tests/dummy/compute.rs
index be071d10e..7654a43b6 100644
--- a/crates/cubecl-runtime/tests/dummy/compute.rs
+++ b/crates/cubecl-runtime/tests/dummy/compute.rs
@@ -1,5 +1,3 @@
-use std::sync::Arc;
-
use super::DummyServer;
use cubecl_runtime::channel::MutexComputeChannel;
use cubecl_runtime::client::ComputeClient;
@@ -35,7 +33,7 @@ pub fn init_client() -> ComputeClient DummyClient {
diff --git a/crates/cubecl-runtime/tests/dummy/server.rs b/crates/cubecl-runtime/tests/dummy/server.rs
index 8c6ee178a..33bdf8bc3 100644
--- a/crates/cubecl-runtime/tests/dummy/server.rs
+++ b/crates/cubecl-runtime/tests/dummy/server.rs
@@ -28,6 +28,7 @@ where
type Storage = BytesStorage;
type MemoryManagement = MM;
type FeatureSet = ();
+ type Properties = ();
fn read(&mut self, binding: Binding) -> cubecl_common::reader::Reader {
let bytes_handle = self.memory_management.get(binding.memory);
@@ -54,7 +55,7 @@ where
}
fn empty(&mut self, size: usize) -> Handle {
- Handle::new(self.memory_management.reserve(size, &[]))
+ Handle::new(self.memory_management.reserve(size, &[]), None, None)
}
unsafe fn execute(
diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs
index b2c695b27..1a03ada62 100644
--- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs
+++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs
@@ -569,6 +569,10 @@ impl WgslCompiler {
input: self.compile_variable(op.input),
out: self.compile_variable(op.out),
},
+ cube::Operator::Round(op) => wgsl::Instruction::Round {
+ input: self.compile_variable(op.input),
+ out: self.compile_variable(op.out),
+ },
cube::Operator::Floor(op) => wgsl::Instruction::Floor {
input: self.compile_variable(op.input),
out: self.compile_variable(op.out),
@@ -649,6 +653,11 @@ impl WgslCompiler {
input: self.compile_variable(op.input),
out: self.compile_variable(op.out),
},
+ cube::Operator::BitwiseOr(op) => wgsl::Instruction::BitwiseOr {
+ lhs: self.compile_variable(op.lhs),
+ rhs: self.compile_variable(op.rhs),
+ out: self.compile_variable(op.out),
+ },
cube::Operator::BitwiseAnd(op) => wgsl::Instruction::BitwiseAnd {
lhs: self.compile_variable(op.lhs),
rhs: self.compile_variable(op.rhs),
diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs
index 39a47ccec..84112445b 100644
--- a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs
+++ b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs
@@ -201,6 +201,11 @@ pub enum Instruction {
Loop {
instructions: Vec,
},
+ BitwiseOr {
+ lhs: Variable,
+ rhs: Variable,
+ out: Variable,
+ },
BitwiseAnd {
lhs: Variable,
rhs: Variable,
@@ -221,6 +226,10 @@ pub enum Instruction {
rhs: Variable,
out: Variable,
},
+ Round {
+ input: Variable,
+ out: Variable,
+ },
Floor {
input: Variable,
out: Variable,
@@ -591,6 +600,9 @@ for (var {i}: u32 = {start}; {i} < {end}; {increment}) {{
}
f.write_str("}\n")
}
+ Instruction::BitwiseOr { lhs, rhs, out } => {
+ f.write_fmt(format_args!("{out} = {lhs} | {rhs};\n"))
+ }
Instruction::BitwiseAnd { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = {lhs} & {rhs};\n"))
}
@@ -603,6 +615,9 @@ for (var {i}: u32 = {start}; {i} < {end}; {increment}) {{
Instruction::ShiftRight { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = {lhs} >> {rhs};\n"))
}
+ Instruction::Round { input, out } => {
+ f.write_fmt(format_args!("{out} = round({input});\n"))
+ }
Instruction::Floor { input, out } => {
f.write_fmt(format_args!("{out} = floor({input});\n"))
}
diff --git a/crates/cubecl-wgpu/src/compute/server.rs b/crates/cubecl-wgpu/src/compute/server.rs
index a5373b4fc..b8ea2e022 100644
--- a/crates/cubecl-wgpu/src/compute/server.rs
+++ b/crates/cubecl-wgpu/src/compute/server.rs
@@ -3,7 +3,9 @@ use std::num::NonZero;
use super::WgpuStorage;
use alloc::{borrow::Cow, sync::Arc};
use cubecl_common::{reader::Reader, sync_type::SyncType};
-use cubecl_core::{compute::DebugInformation, prelude::*, server::Handle, FeatureSet, KernelId};
+use cubecl_core::{
+ compute::DebugInformation, prelude::*, server::Handle, FeatureSet, KernelId, Properties,
+};
use cubecl_runtime::{
debug::DebugLogger,
memory_management::{MemoryHandle, MemoryManagement},
@@ -129,6 +131,7 @@ where
type Storage = WgpuStorage;
type MemoryManagement = MM;
type FeatureSet = FeatureSet;
+ type Properties = Properties;
fn read(&mut self, binding: server::Binding) -> Reader {
let resource = self.get_resource(binding);
@@ -190,7 +193,11 @@ where
&mut self,
binding: server::Binding,
) -> ::Resource {
- self.memory_management.get_resource(binding.memory)
+ self.memory_management.get_resource(
+ binding.memory,
+ binding.offset_start,
+ binding.offset_end,
+ )
}
/// When we create a new handle from existing data, we use custom allocations so that we don't
@@ -229,11 +236,11 @@ where
.copy_from_slice(data);
}
- Handle::new(memory)
+ Handle::new(memory, None, None)
}
fn empty(&mut self, size: usize) -> server::Handle {
- server::Handle::new(self.memory_management.reserve(size, &[]))
+ server::Handle::new(self.memory_management.reserve(size, &[]), None, None)
}
unsafe fn execute(
@@ -255,7 +262,15 @@ where
// Keep track of the storage we've used so far.
self.compute_storage_used.push(resource_handle.id);
- self.memory_management.storage().get(&resource_handle)
+ let handle = match binding.offset_start {
+ Some(offset) => resource_handle.offset_start(offset),
+ None => resource_handle.clone(),
+ };
+ let handle = match binding.offset_end {
+ Some(offset) => handle.offset_end(offset),
+ None => handle,
+ };
+ self.memory_management.storage().get(&handle)
})
.collect();
diff --git a/crates/cubecl-wgpu/src/runtime.rs b/crates/cubecl-wgpu/src/runtime.rs
index 450cde12e..185c466ab 100644
--- a/crates/cubecl-wgpu/src/runtime.rs
+++ b/crates/cubecl-wgpu/src/runtime.rs
@@ -4,7 +4,7 @@ use crate::{
AutoGraphicsApi, GraphicsApi, WgpuDevice,
};
use alloc::sync::Arc;
-use cubecl_core::{Feature, FeatureSet, Runtime};
+use cubecl_core::{Feature, FeatureSet, Properties, Runtime};
use cubecl_runtime::memory_management;
use cubecl_runtime::{channel::MutexComputeChannel, client::ComputeClient, ComputeRuntime};
use wgpu::{DeviceDescriptor, Limits};
@@ -150,8 +150,11 @@ fn create_client(
if features.contains(wgpu::Features::SUBGROUP) {
features_cube.register(Feature::Subcube);
}
+ let properties = Properties {
+ memory_offset_alignment: limits.min_storage_buffer_offset_alignment,
+ };
- ComputeClient::new(channel, Arc::new(features_cube))
+ ComputeClient::new(channel, features_cube, properties)
}
/// Select the wgpu device and queue based on the provided [device](WgpuDevice).
diff --git a/profiling/matmul-example/README.md b/profiling/matmul-example/README.md
index e2d5c94cd..7ff757bf6 100644
--- a/profiling/matmul-example/README.md
+++ b/profiling/matmul-example/README.md
@@ -13,7 +13,7 @@ NVIDIA Nsight Compute is a powerful tool for GPU profiling. Make sure it is inst
For effective profiling, isolate the kernel you want to profile into a main function. This allows you to focus on the performance of a specific kernel without interference from other parts of your code.
## 4. Use the CUDA device/runtime
-Make sure your code uses the CUDA runtime API and device for lauching the kernel.
+Make sure your code uses the CUDA runtime API and device for launching the kernel.
```rust
#[cfg(feature = "cube-cuda")]
@@ -29,19 +29,19 @@ mod cube_cuda {
let client = CudaRuntime::client(&device);
let num_of_batch = 12;
- let heigth = 1024;
+ let height = 1024;
let width = 1024;
- let tensor_values: Vec = (0..num_of_batch * heigth * width)
+ let tensor_values: Vec = (0..num_of_batch * height * width)
.map(|x| x as f32)
.collect();
let tensor_a_handle = client.create(f32::as_bytes(&tensor_values));
let tensor_b_handle = client.create(f32::as_bytes(&tensor_values));
let tensor_c_handle = client.empty(12 * 1024 * 1024 * core::mem::size_of::());
- let tensor_a_shape = vec![num_of_batch, heigth, width];
- let tensor_b_shape = vec![num_of_batch, heigth, width];
- let tensor_c_shape = vec![num_of_batch, heigth, width];
+ let tensor_a_shape = vec![num_of_batch, height, width];
+ let tensor_b_shape = vec![num_of_batch, height, width];
+ let tensor_c_shape = vec![num_of_batch, height, width];
let tensor_a: TensorHandle =
TensorHandle::new_contiguous(tensor_a_shape, tensor_a_handle);
diff --git a/profiling/matmul-example/src/main.rs b/profiling/matmul-example/src/main.rs
index 77cd68277..08a28e428 100644
--- a/profiling/matmul-example/src/main.rs
+++ b/profiling/matmul-example/src/main.rs
@@ -33,19 +33,19 @@ mod cube_cuda {
let client = CudaRuntime::client(&device);
let num_of_batch = 12;
- let heigth = 1024;
+ let height = 1024;
let width = 1024;
- let tensor_values: Vec = (0..num_of_batch * heigth * width)
+ let tensor_values: Vec = (0..num_of_batch * height * width)
.map(|x| x as f32)
.collect();
let tensor_a_handle = client.create(f32::as_bytes(&tensor_values));
let tensor_b_handle = client.create(f32::as_bytes(&tensor_values));
let tensor_c_handle = client.empty(12 * 1024 * 1024 * core::mem::size_of::());
- let tensor_a_shape = vec![num_of_batch, heigth, width];
- let tensor_b_shape = vec![num_of_batch, heigth, width];
- let tensor_c_shape = vec![num_of_batch, heigth, width];
+ let tensor_a_shape = vec![num_of_batch, height, width];
+ let tensor_b_shape = vec![num_of_batch, height, width];
+ let tensor_c_shape = vec![num_of_batch, height, width];
let tensor_a: TensorHandle =
TensorHandle::new_contiguous(tensor_a_shape, tensor_a_handle);
diff --git a/xtask/Cargo.toml b/xtask/Cargo.toml
index 6a086a11f..7a787ca94 100644
--- a/xtask/Cargo.toml
+++ b/xtask/Cargo.toml
@@ -1,19 +1,14 @@
[package]
name = "xtask"
-version = "0.2.0"
+version = "1.0.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
-anyhow = { workspace = true }
-clap = { workspace = true }
-derive_more = { workspace = true }
-env_logger = { workspace = true }
log = { workspace = true }
-rand = { workspace = true, features = ["std"] }
-serde_json = { version = "1.0.119" }
strum = { workspace = true }
+tracel-xtask = { workspace = true }
[dev-dependencies]
rstest = { workspace = true }
diff --git a/xtask/src/commands/build.rs b/xtask/src/commands/build.rs
new file mode 100644
index 000000000..081d268b5
--- /dev/null
+++ b/xtask/src/commands/build.rs
@@ -0,0 +1,17 @@
+use tracel_xtask::prelude::*;
+
+#[macros::extend_command_args(BuildCmdArgs, Target, None)]
+pub struct CubeCLBuildCmdArgs {
+ /// Build in CI mode which excludes unsupported crates.
+ #[arg(long)]
+ pub ci: bool,
+}
+
+pub(crate) fn handle_command(mut args: CubeCLBuildCmdArgs) -> anyhow::Result<()> {
+ if args.ci {
+ // Exclude crates that are not supported on CI
+ args.exclude.extend(vec!["cubecl-cuda".to_string()]);
+ }
+ base_commands::build::handle_command(args.try_into().unwrap())?;
+ Ok(())
+}
diff --git a/xtask/src/commands/bump.rs b/xtask/src/commands/bump.rs
deleted file mode 100644
index 7d1c60a62..000000000
--- a/xtask/src/commands/bump.rs
+++ /dev/null
@@ -1,42 +0,0 @@
-use std::process::Command;
-
-use anyhow::{anyhow, Ok};
-use clap::{Args, Subcommand};
-use strum::{Display, EnumIter, EnumString};
-
-use crate::{endgroup, group, utils::cargo::ensure_cargo_crate_is_installed};
-
-#[derive(Args)]
-pub(crate) struct BumpCmdArgs {
- #[command(subcommand)]
- command: BumpCommand,
-}
-
-#[derive(EnumString, EnumIter, Display, Clone, PartialEq, Subcommand)]
-#[strum(serialize_all = "lowercase")]
-enum BumpCommand {
- /// Run unit tests.
- Major,
- /// Run integration tests.
- Minor,
- /// Run documentation tests.
- Patch,
-}
-
-pub(crate) fn handle_command(args: BumpCmdArgs) -> anyhow::Result<()> {
- bump(&args.command)
-}
-
-fn bump(command: &BumpCommand) -> anyhow::Result<()> {
- group!("Bump version: {command}");
- ensure_cargo_crate_is_installed("cargo-edit", None, false)?;
- let status = Command::new("cargo")
- .args(["set-version", "--bump", &command.to_string()])
- .status()
- .map_err(|e| anyhow!("Failed to execute cargo set-version: {}", e))?;
- if !status.success() {
- return Err(anyhow!("Cannot set new {command} version"));
- }
- endgroup!();
- Ok(())
-}
diff --git a/xtask/src/commands/check.rs b/xtask/src/commands/check.rs
deleted file mode 100644
index 12bba7f99..000000000
--- a/xtask/src/commands/check.rs
+++ /dev/null
@@ -1,211 +0,0 @@
-use std::process::Command;
-
-use anyhow::{anyhow, Ok, Result};
-use clap::{Args, Subcommand};
-use strum::{Display, EnumIter, EnumString, IntoEnumIterator};
-
-use crate::{
- endgroup, group,
- utils::{
- cargo::ensure_cargo_crate_is_installed,
- prompt::ask_once,
- workspace::{get_workspace_members, WorkspaceMemberType},
- },
-};
-
-use super::Target;
-
-#[derive(Args)]
-pub(crate) struct CheckCmdArgs {
- /// Target to check for.
- #[arg(short, long, value_enum)]
- target: Target,
- #[command(subcommand)]
- command: CheckCommand,
-}
-
-#[derive(EnumString, EnumIter, Display, Clone, PartialEq, Subcommand)]
-#[strum(serialize_all = "lowercase")]
-enum CheckCommand {
- /// Run audit command.
- Audit,
- /// Run format command.
- Format,
- /// Run ling command.
- Lint,
- /// Run all the checks.
- All,
-}
-
-pub(crate) fn handle_command(args: CheckCmdArgs, answer: Option) -> anyhow::Result<()> {
- match args.command {
- CheckCommand::Audit => run_audit(&args.target, answer),
- CheckCommand::Format => run_format(&args.target, answer),
- CheckCommand::Lint => run_lint(&args.target, answer),
- CheckCommand::All => {
- let answer = ask_once(
- "This will run all the checks with autofix on all members of the workspace.",
- );
- CheckCommand::iter()
- .filter(|c| *c != CheckCommand::All)
- .try_for_each(|c| {
- handle_command(
- CheckCmdArgs {
- command: c,
- target: args.target.clone(),
- },
- Some(answer),
- )
- })
- }
- }
-}
-
-pub(crate) fn run_audit(target: &Target, mut answer: Option) -> anyhow::Result<()> {
- match target {
- Target::Crates | Target::Examples => {
- if answer.is_none() {
- answer = Some(ask_once(
- "This will run the audit check with autofix mode enabled.",
- ));
- };
- if answer.unwrap() {
- ensure_cargo_crate_is_installed("cargo-audit", Some("fix"), false)?;
- group!("Audit: Crates and Examples");
- info!("Command line: cargo audit fix");
- let status = Command::new("cargo")
- .args(["audit", "-q", "--color", "always", "fix"])
- .status()
- .map_err(|e| anyhow!("Failed to execute cargo audit: {}", e))?;
- if !status.success() {
- return Err(anyhow!("Audit check execution failed"));
- }
- endgroup!();
- }
- }
- Target::All => {
- if answer.is_none() {
- answer = Some(ask_once("This will run audit checks on all targets."));
- };
- Target::iter()
- .filter(|p| *p != Target::All && *p != Target::Examples)
- .try_for_each(|p| run_audit(&p, answer))?;
- }
- }
- Ok(())
-}
-
-fn run_format(target: &Target, mut answer: Option) -> Result<()> {
- match target {
- Target::Crates | Target::Examples => {
- let members = match target {
- Target::Crates => get_workspace_members(WorkspaceMemberType::Crate),
- Target::Examples => get_workspace_members(WorkspaceMemberType::Example),
- _ => unreachable!(),
- };
-
- if answer.is_none() {
- answer = Some(ask_once(&format!(
- "This will run format checks on all {} of the workspace.",
- if *target == Target::Crates {
- "crates"
- } else {
- "examples"
- }
- )));
- }
-
- if answer.unwrap() {
- for member in members {
- group!("Format: {}", member.name);
- info!("Command line: cargo fmt -p {}", &member.name);
- let status = Command::new("cargo")
- .args(["fmt", "-p", &member.name])
- .status()
- .map_err(|e| anyhow!("Failed to execute cargo fmt: {}", e))?;
- if !status.success() {
- return Err(anyhow!(
- "Format check execution failed for {}",
- &member.name
- ));
- }
- endgroup!();
- }
- }
- }
- Target::All => {
- if answer.is_none() {
- answer = Some(ask_once(
- "This will run format check on all members of the workspace.",
- ));
- }
- if answer.unwrap() {
- Target::iter()
- .filter(|t| *t != Target::All)
- .try_for_each(|t| run_format(&t, answer))?;
- }
- }
- }
- Ok(())
-}
-
-fn run_lint(target: &Target, mut answer: Option) -> anyhow::Result<()> {
- match target {
- Target::Crates | Target::Examples => {
- let members = match target {
- Target::Crates => get_workspace_members(WorkspaceMemberType::Crate),
- Target::Examples => get_workspace_members(WorkspaceMemberType::Example),
- _ => unreachable!(),
- };
-
- if answer.is_none() {
- answer = Some(ask_once(&format!(
- "This will run lint fix on all {} of the workspace.",
- if *target == Target::Crates {
- "crates"
- } else {
- "examples"
- }
- )));
- }
-
- if answer.unwrap() {
- for member in members {
- group!("Lint: {}", member.name);
- info!(
- "Command line: cargo clippy --no-deps --fix --allow-dirty -p {}",
- &member.name
- );
- let status = Command::new("cargo")
- .args([
- "clippy",
- "--no-deps",
- "--fix",
- "--allow-dirty",
- "-p",
- &member.name,
- ])
- .status()
- .map_err(|e| anyhow!("Failed to execute cargo clippy: {}", e))?;
- if !status.success() {
- return Err(anyhow!("Lint fix execution failed for {}", &member.name));
- }
- endgroup!();
- }
- }
- }
- Target::All => {
- if answer.is_none() {
- answer = Some(ask_once(
- "This will run lint fix on all members of the workspace.",
- ));
- }
- if answer.unwrap() {
- Target::iter()
- .filter(|t| *t != Target::All)
- .try_for_each(|t| run_lint(&t, answer))?;
- }
- }
- }
- Ok(())
-}
diff --git a/xtask/src/commands/ci.rs b/xtask/src/commands/ci.rs
deleted file mode 100644
index fba5be9dc..000000000
--- a/xtask/src/commands/ci.rs
+++ /dev/null
@@ -1,187 +0,0 @@
-use std::process::Command;
-
-use anyhow::{anyhow, Ok, Result};
-use clap::{Args, Subcommand};
-use strum::{Display, EnumIter, EnumString, IntoEnumIterator};
-
-use crate::{
- endgroup, group,
- utils::{
- cargo::ensure_cargo_crate_is_installed,
- workspace::{get_workspace_members, WorkspaceMemberType},
- },
-};
-
-use super::{
- test::{run_documentation, run_integration, run_unit},
- Target,
-};
-
-#[derive(Args)]
-pub(crate) struct CICmdArgs {
- /// Target to check for.
- #[arg(short, long, value_enum)]
- pub target: Target,
- #[command(subcommand)]
- pub command: CICommand,
-}
-
-#[derive(EnumString, EnumIter, Display, Clone, PartialEq, Subcommand)]
-#[strum(serialize_all = "lowercase")]
-pub(crate) enum CICommand {
- /// Run audit command.
- Audit,
- /// Run format command.
- Format,
- /// Run lint command.
- Lint,
- /// Run unit tests.
- UnitTests,
- /// Run integration tests.
- IntegrationTests,
- /// Run documentation tests.
- DocTests,
- /// Run all tests.
- AllTests,
- /// Run all the checks.
- All,
-}
-
-pub(crate) fn handle_command(args: CICmdArgs) -> anyhow::Result<()> {
- match args.command {
- CICommand::Audit => run_audit(&args.target),
- CICommand::Format => run_format(&args.target),
- CICommand::Lint => run_lint(&args.target),
- CICommand::UnitTests => run_unit_tests(&args.target),
- CICommand::IntegrationTests => run_integration_tests(&args.target),
- CICommand::DocTests => run_doc_tests(&args.target),
- CICommand::AllTests => run_all_tests(&args.target),
- CICommand::All => CICommand::iter()
- .filter(|c| *c != CICommand::All && *c != CICommand::AllTests)
- .try_for_each(|c| {
- handle_command(CICmdArgs {
- command: c,
- target: args.target.clone(),
- })
- }),
- }
-}
-
-fn run_audit(target: &Target) -> anyhow::Result<()> {
- match target {
- Target::Crates | Target::Examples => {
- group!("Audit: Crates and Examples");
- ensure_cargo_crate_is_installed("cargo-audit", Some("fix"), false)?;
- info!("Command line: cargo audit");
- let status = Command::new("cargo")
- .args(["audit", "-q", "--color", "always"])
- .status()
- .map_err(|e| anyhow!("Failed to execute cargo audit: {}", e))?;
- if !status.success() {
- return Err(anyhow!("Audit check execution failed"));
- }
- endgroup!();
- }
- Target::All => {
- Target::iter()
- .filter(|t| *t != Target::All && *t != Target::Examples)
- .try_for_each(|t| run_audit(&t))?;
- }
- }
- Ok(())
-}
-
-fn run_format(target: &Target) -> Result<()> {
- match target {
- Target::Crates | Target::Examples => {
- let members = match target {
- Target::Crates => get_workspace_members(WorkspaceMemberType::Crate),
- Target::Examples => get_workspace_members(WorkspaceMemberType::Example),
- _ => unreachable!(),
- };
-
- for member in members {
- group!("Format: {}", member.name);
- info!("Command line: cargo fmt --check -p {}", &member.name);
- let status = Command::new("cargo")
- .args(["fmt", "--check", "-p", &member.name])
- .status()
- .map_err(|e| anyhow!("Failed to execute cargo fmt: {}", e))?;
- if !status.success() {
- return Err(anyhow!(
- "Format check execution failed for {}",
- &member.name
- ));
- }
- endgroup!();
- }
- }
- Target::All => {
- Target::iter()
- .filter(|t| *t != Target::All)
- .try_for_each(|t| run_format(&t))?;
- }
- }
- Ok(())
-}
-
-fn run_lint(target: &Target) -> anyhow::Result<()> {
- match target {
- Target::Crates | Target::Examples => {
- let members = match target {
- Target::Crates => get_workspace_members(WorkspaceMemberType::Crate),
- Target::Examples => get_workspace_members(WorkspaceMemberType::Example),
- _ => unreachable!(),
- };
-
- for member in members {
- group!("Lint: {}", member.name);
- info!(
- "Command line: cargo clippy --no-deps -p {} -- --deny warnings",
- &member.name
- );
- let status = Command::new("cargo")
- .args([
- "clippy",
- "--no-deps",
- "-p",
- &member.name,
- "--",
- "--deny",
- "warnings",
- ])
- .status()
- .map_err(|e| anyhow!("Failed to execute cargo clippy: {}", e))?;
- if !status.success() {
- return Err(anyhow!("Lint fix execution failed for {}", &member.name));
- }
- endgroup!();
- }
- }
- Target::All => {
- Target::iter()
- .filter(|t| *t != Target::All)
- .try_for_each(|t| run_lint(&t))?;
- }
- }
- Ok(())
-}
-
-fn run_unit_tests(target: &Target) -> anyhow::Result<()> {
- run_unit(target)
-}
-
-fn run_integration_tests(target: &Target) -> anyhow::Result<()> {
- run_integration(target)
-}
-
-fn run_doc_tests(target: &Target) -> anyhow::Result<()> {
- run_documentation(target)
-}
-
-fn run_all_tests(target: &Target) -> anyhow::Result<()> {
- run_unit_tests(target)?;
- run_integration_tests(target)?;
- run_doc_tests(target)?;
- Ok(())
-}
diff --git a/xtask/src/commands/mod.rs b/xtask/src/commands/mod.rs
index 00d1263f4..726f3ad4c 100644
--- a/xtask/src/commands/mod.rs
+++ b/xtask/src/commands/mod.rs
@@ -1,17 +1,2 @@
-pub(crate) mod bump;
-pub(crate) mod check;
-pub(crate) mod ci;
-pub(crate) mod publish;
-pub(crate) mod pull_request_checks;
+pub(crate) mod build;
pub(crate) mod test;
-
-use clap::ValueEnum;
-use strum::{Display, EnumIter, EnumString};
-
-#[derive(EnumString, EnumIter, Display, Clone, PartialEq, ValueEnum)]
-#[strum(serialize_all = "lowercase")]
-pub(crate) enum Target {
- All,
- Crates,
- Examples,
-}
diff --git a/xtask/src/commands/publish.rs b/xtask/src/commands/publish.rs
deleted file mode 100644
index e2ee4ef26..000000000
--- a/xtask/src/commands/publish.rs
+++ /dev/null
@@ -1,107 +0,0 @@
-use std::{env, process::Command, str};
-
-use anyhow::{anyhow, Ok};
-use clap::Args;
-
-use crate::{endgroup, group};
-
-// Crates.io API token
-const CRATES_IO_API_TOKEN: &str = "CRATES_IO_API_TOKEN";
-
-#[derive(Args)]
-pub(crate) struct PublishCmdArgs {
- /// The name of the crate to publish on crates.io
- name: String,
-}
-
-pub(crate) fn handle_command(args: PublishCmdArgs) -> anyhow::Result<()> {
- let crate_name = args.name;
-
- group!("Publishing crate '{}'...", &crate_name);
- // Retrieve local version for crate
- let local_version = local_version(&crate_name)?;
- info!("Local version: {local_version}");
- // Retrieve remote version for crate if it exists
- match remote_version(&crate_name)? {
- Some(remote_version) => {
- info!("Found remote version: {remote_version}");
- // Early return if we don't need to publish the crate
- if local_version == remote_version {
- info!("Remote version is up to date, skipping publishing!");
- return Ok(());
- }
- }
- None => info!("This is the first version to be published on crates.io!"),
- }
- // Publish the crate
- publish(crate_name)?;
- endgroup!();
-
- Ok(())
-}
-
-// Obtain local crate version
-fn local_version(crate_name: &str) -> anyhow::Result<String> {
- // Obtain local crate version contained in cargo pkgid data
- let cargo_pkgid_output = Command::new("cargo")
- .args(["pkgid", "-p", crate_name])
- .output()
- .map_err(|e| anyhow!("Failed to execute cargo pkgid: {}", e))?;
- // Convert cargo pkgid output into a str
- let cargo_pkgid_str = str::from_utf8(&cargo_pkgid_output.stdout)
- .expect("Failed to convert pkgid output into a str");
- // Extract only the local crate version from str
- let (_, local_version) = cargo_pkgid_str
- .split_once('#')
- .expect("Failed to get local crate version");
- Ok(local_version.trim_end().to_string())
-}
-
-// Obtain remote crate version
-fn remote_version(crate_name: &str) -> anyhow::Result