diff --git a/.cargo/config.toml b/.cargo/config.toml index 275927460..84a3562ce 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,2 +1,2 @@ [alias] -xtask = "run --target-dir target/xtask/debug --package xtask --bin xtask --" +xtask = "run --target-dir target/xtask --color always --package xtask --bin xtask --" \ No newline at end of file diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0c1ac743d..390326dd3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,9 +1,5 @@ name: CI -env: - CARGO_TERM_COLOR: always - -# For now we execute CI only on PR to save on CI time on: push: branches: @@ -27,44 +23,163 @@ on: - '!LICENSE-APACHE' - '!LICENSE-MIT' +env: + CARGO_TERM_COLOR: always + RUST_PREVIOUS_VERSION: 1.79.0 + + # Sourced from https://vulkan.lunarg.com/sdk/home#linux + VULKAN_SDK_VERSION: "1.3.268" + + # Sourced from https://archive.mesa3d.org/. Bumping this requires + # updating the mesa build in https://github.com/gfx-rs/ci-build and creating a new release. 
+ MESA_VERSION: "23.3.1" + # Corresponds to https://github.com/gfx-rs/ci-build/releases + MESA_CI_BINARY_BUILD: "build18" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: - crates: - runs-on: ubuntu-latest + prepare-checks: + runs-on: ubuntu-22.04 + outputs: + rust-prev-version: ${{ env.RUST_PREVIOUS_VERSION }} + steps: + - name: Do Nothing + if: false + run: echo + + code-quality: + runs-on: ubuntu-22.04 + needs: prepare-checks + strategy: + matrix: + rust: [stable] + include: + - rust: stable + toolchain: stable steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@master + - name: Setup Rust + uses: tracel-ai/github-actions/setup-rust@v1 with: - components: clippy, rustfmt - toolchain: stable - - uses: Swatinem/rust-cache@v2 + rust-toolchain: ${{ matrix.toolchain }} + cache-key: ${{ matrix.rust }}-linux + # -------------------------------------------------------------------------------- + - name: Audit + run: cargo xtask check audit + # -------------------------------------------------------------------------------- - name: Format - run: cargo xtask ci --target crates format + shell: bash + env: + # work around for colors + # see: https://github.com/rust-lang/rustfmt/issues/3385 + TERM: xterm-256color + run: cargo xtask check format + # -------------------------------------------------------------------------------- - name: Lint - run: cargo xtask ci --target crates lint - - name: Audit - run: cargo xtask ci --target crates audit - - name: Unit Tests - run: cargo xtask ci --target crates unit-tests - - name: Integration Tests - run: cargo xtask ci --target crates integration-tests - - name: Documentation Tests - run: cargo xtask ci --target crates doc-tests - examples: - runs-on: ubuntu-latest + run: cargo xtask check lint + # -------------------------------------------------------------------------------- + - name: Typos + uses: tracel-ai/github-actions/check-typos@v1 + + documentation: + 
runs-on: ubuntu-22.04 + needs: prepare-checks + strategy: + matrix: + rust: [stable] + include: + - rust: stable + toolchain: stable steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@master + - name: Setup Rust + uses: tracel-ai/github-actions/setup-rust@v1 with: - components: clippy, rustfmt - toolchain: stable - - uses: Swatinem/rust-cache@v2 - - name: Format - run: cargo xtask ci --target examples format - - name: Lint - run: cargo xtask ci --target examples lint - - name: Unit Tests - run: cargo xtask ci --target examples unit-tests - - name: Integration Tests - run: cargo xtask ci --target examples integration-tests + rust-toolchain: ${{ matrix.toolchain }} + cache-key: ${{ matrix.rust }}-linux + # -------------------------------------------------------------------------------- + - name: Documentation Build + run: cargo xtask doc build + # -------------------------------------------------------------------------------- - name: Documentation Tests - run: cargo xtask ci --target examples doc-tests + run: cargo xtask doc tests + + linux-std-tests: + runs-on: ubuntu-22.04 + needs: prepare-checks + strategy: + matrix: + rust: [stable, prev] + include: + - rust: stable + toolchain: stable + - rust: prev + toolchain: ${{ needs.prepare-checks.outputs.rust-prev-version }} + steps: + - name: Setup Rust + uses: tracel-ai/github-actions/setup-rust@v1 + with: + rust-toolchain: ${{ matrix.toolchain }} + cache-key: ${{ matrix.rust }}-linux + # -------------------------------------------------------------------------------- + - name: Setup Linux runner + uses: tracel-ai/github-actions/setup-linux@v1 + with: + vulkan-sdk-version: ${{ env.VULKAN_SDK_VERSION }} + mesa-version: ${{ env.MESA_VERSION }} + mesa-ci-build-version: ${{ env.MESA_CI_BINARY_BUILD }} + # -------------------------------------------------------------------------------- + - name: Tests + run: cargo xtask test --ci + + windows-std-tests: + runs-on: windows-2022 + needs: prepare-checks + 
env: + DISABLE_WGPU: '1' + # Keep the strategy to be able to easily add new rust versions if required + strategy: + matrix: + rust: [stable] + include: + - rust: stable + toolchain: stable + steps: + - name: Setup Rust + uses: tracel-ai/github-actions/setup-rust@v1 + with: + rust-toolchain: ${{ matrix.toolchain }} + cache-key: ${{ matrix.rust }}-windows + # -------------------------------------------------------------------------------- + - name: Setup Windows runner + if: env.DISABLE_WGPU != '1' + uses: tracel-ai/github-actions/setup-windows@v1 + with: + dxc-release: ${{ env.DXC_RELEASE }} + dxc-filename: ${{ env.DXC_FILENAME }} + mesa-version: ${{ env.MESA_VERSION }} + warp-version: ${{ env.WARP_VERSION }} + # -------------------------------------------------------------------------------- + - name: Tests + run: cargo xtask test --ci + + macos-std-tests: + runs-on: macos-14 + needs: prepare-checks + # Keep the strategy to be able to easily add new rust versions if required + strategy: + matrix: + rust: [stable] + include: + - rust: stable + toolchain: stable + steps: + - name: Setup Rust + uses: tracel-ai/github-actions/setup-rust@v1 + with: + rust-toolchain: ${{ matrix.toolchain }} + cache-key: ${{ matrix.rust }}-macos + # -------------------------------------------------------------------------------- + - name: Tests + run: cargo xtask test --ci diff --git a/Cargo.toml b/Cargo.toml index 92834909b..0b78e5633 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,12 +61,9 @@ proc-macro2 = "1.0.86" syn = { version = "2.0.69", features = ["full", "extra-traits"] } quote = "1.0.36" -# xtask -anyhow = "1.0.86" -clap = { version = "4.5.9", features = ["derive"] } -derive_more = { version = "0.99.18", features = ["display"], default-features = false } -env_logger = "0.11.3" +### For xtask crate ### strum = {version = "0.26.3", features = ["derive"]} +tracel-xtask = { version = "~1.0" } portable-atomic-util = { version = "0.2.2", features = ["alloc"] } # alloc is for no_std 
diff --git a/README.md b/README.md index f7de0379c..420ddb4b5 100644 --- a/README.md +++ b/README.md @@ -6,8 +6,8 @@ [![Discord](https://img.shields.io/discord/1038839012602941528.svg?color=7289da&&logo=discord)](https://discord.gg/KSBSPhAUCc) [![Current Crates.io Version](https://img.shields.io/crates/v/cubecl.svg)](https://crates.io/crates/cubecl) +[![Minimum Supported Rust Version](https://img.shields.io/crates/msrv/cubecl)](https://crates.io/crates/cubecl) [![Test Status](https://github.com/tracel-ai/cubecl/actions/workflows/ci.yml/badge.svg)](https://github.com/tracel-ai/cubecl/actions/workflows/test.yml) -[![Rust Version](https://img.shields.io/badge/Rust-1.79.0+-blue)](https://releases.rs/docs/1.79.0) ![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)
[![NVIDIA](https://img.shields.io/badge/nvidia-cuda-84b629)](https://github.com/tracel-ai/cubecl/tree/main/crates/cubecl-cuda) diff --git a/crates/cubecl-core/src/compute/kernel.rs b/crates/cubecl-core/src/compute/kernel.rs index 3e3175631..32586bf5b 100644 --- a/crates/cubecl-core/src/compute/kernel.rs +++ b/crates/cubecl-core/src/compute/kernel.rs @@ -84,7 +84,7 @@ fn format_str(kernel_id: &str, markers: &[(char, char)], include_space: bool) -> let kernel_id = kernel_id.to_string(); let mut result = String::new(); let mut depth = 0; - let indendation = 4; + let indentation = 4; let mut prev = ' '; @@ -105,7 +105,7 @@ fn format_str(kernel_id: &str, markers: &[(char, char)], include_space: bool) -> } result.push(start); result.push('\n'); - result.push_str(&" ".repeat(indendation * depth)); + result.push_str(&" ".repeat(indentation * depth)); found_marker = true; } else if c == end { depth -= 1; @@ -114,10 +114,10 @@ fn format_str(kernel_id: &str, markers: &[(char, char)], include_space: bool) -> result.pop(); } result.push_str(",\n"); - result.push_str(&" ".repeat(indendation * depth)); + result.push_str(&" ".repeat(indentation * depth)); result.push(end); } else { - for _ in 0..(&" ".repeat(indendation * depth).len()) + 1 + indendation { + for _ in 0..(&" ".repeat(indentation * depth).len()) + 1 + indentation { result.pop(); } result.push(end); @@ -137,7 +137,7 @@ fn format_str(kernel_id: &str, markers: &[(char, char)], include_space: bool) -> } result.push_str(",\n"); - result.push_str(&" ".repeat(indendation * depth)); + result.push_str(&" ".repeat(indentation * depth)); continue; } diff --git a/crates/cubecl-core/src/frontend/element/array.rs b/crates/cubecl-core/src/frontend/element/array.rs index d3cad4bde..8e2ba9c5f 100644 --- a/crates/cubecl-core/src/frontend/element/array.rs +++ b/crates/cubecl-core/src/frontend/element/array.rs @@ -207,7 +207,7 @@ impl<'a, R: Runtime> ArrayArg<'a, R> { /// /// # Safety /// - /// Specifying the wrong lenght may lead to 
out-of-bounds reads and writes. + /// Specifying the wrong length may lead to out-of-bounds reads and writes. pub unsafe fn from_raw_parts( handle: &'a cubecl_runtime::server::Handle, length: usize, @@ -225,7 +225,7 @@ impl<'a, R: Runtime> ArrayHandleRef<'a, R> { /// /// # Safety /// - /// Specifying the wrong lenght may lead to out-of-bounds reads and writes. + /// Specifying the wrong length may lead to out-of-bounds reads and writes. pub unsafe fn from_raw_parts( handle: &'a cubecl_runtime::server::Handle, length: usize, diff --git a/crates/cubecl-core/src/frontend/element/base.rs b/crates/cubecl-core/src/frontend/element/base.rs index e98911cfe..663d2c102 100644 --- a/crates/cubecl-core/src/frontend/element/base.rs +++ b/crates/cubecl-core/src/frontend/element/base.rs @@ -236,7 +236,7 @@ impl From> for ExpandElement { } impl ExpandElementTyped { - /// Create an [ExpandElementTyped] from a value that is normaly a literal. + /// Create an [ExpandElementTyped] from a value that is normally a literal. 
pub fn from_lit>(lit: L) -> Self { let variable: Variable = lit.into(); let variable = T::as_elem().from_constant(variable); diff --git a/crates/cubecl-core/src/frontend/element/float.rs b/crates/cubecl-core/src/frontend/element/float.rs index 0163ca2bd..ebce72cd9 100644 --- a/crates/cubecl-core/src/frontend/element/float.rs +++ b/crates/cubecl-core/src/frontend/element/float.rs @@ -1,6 +1,8 @@ use half::{bf16, f16}; -use crate::frontend::{Ceil, Cos, Erf, Exp, Floor, Log, Log1p, Powf, Recip, Sin, Sqrt, Tanh}; +use crate::frontend::{ + Ceil, Cos, Erf, Exp, Floor, Log, Log1p, Powf, Recip, Round, Sin, Sqrt, Tanh, +}; use crate::frontend::{ ComptimeType, CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementBaseInit, ExpandElementTyped, Numeric, @@ -25,6 +27,7 @@ pub trait Float: + Tanh + Powf + Sqrt + + Round + Floor + Ceil + Erf diff --git a/crates/cubecl-core/src/frontend/operation/binary.rs b/crates/cubecl-core/src/frontend/operation/binary.rs index 7632a5e88..dabe8001c 100644 --- a/crates/cubecl-core/src/frontend/operation/binary.rs +++ b/crates/cubecl-core/src/frontend/operation/binary.rs @@ -192,6 +192,26 @@ pub mod bitand { } } +pub mod bitor { + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElementTyped, + rhs: ExpandElementTyped, + ) -> ExpandElementTyped { + binary_expand(context, lhs.into(), rhs.into(), Operator::BitwiseOr).into() + } + + impl core::ops::BitOr for UInt { + type Output = UInt; + + fn bitor(self, _rhs: Self) -> Self::Output { + unexpanded!() + } + } +} + pub mod or { use super::*; diff --git a/crates/cubecl-core/src/frontend/operation/unary.rs b/crates/cubecl-core/src/frontend/operation/unary.rs index 40569e447..3f2954ed4 100644 --- a/crates/cubecl-core/src/frontend/operation/unary.rs +++ b/crates/cubecl-core/src/frontend/operation/unary.rs @@ -82,6 +82,16 @@ impl_unary_func!( F32, F64 ); +impl_unary_func!( + Round, + round, + __expand_round, + Operator::Round, + F16, + BF16, + F32, + F64 +); 
impl_unary_func!( Floor, floor, diff --git a/crates/cubecl-core/src/ir/operation.rs b/crates/cubecl-core/src/ir/operation.rs index 0a22814a8..3302b0053 100644 --- a/crates/cubecl-core/src/ir/operation.rs +++ b/crates/cubecl-core/src/ir/operation.rs @@ -39,6 +39,7 @@ pub enum Operator { Tanh(UnaryOperator), Powf(BinaryOperator), Sqrt(UnaryOperator), + Round(UnaryOperator), Floor(UnaryOperator), Ceil(UnaryOperator), Erf(UnaryOperator), @@ -63,6 +64,7 @@ pub enum Operator { Max(BinaryOperator), Min(BinaryOperator), BitwiseAnd(BinaryOperator), + BitwiseOr(BinaryOperator), BitwiseXor(BinaryOperator), ShiftLeft(BinaryOperator), ShiftRight(BinaryOperator), diff --git a/crates/cubecl-core/src/ir/processing.rs b/crates/cubecl-core/src/ir/processing.rs index 3d2ba51c0..d06d12ba4 100644 --- a/crates/cubecl-core/src/ir/processing.rs +++ b/crates/cubecl-core/src/ir/processing.rs @@ -76,6 +76,9 @@ impl ScopeProcessing { Operator::Sqrt(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &op.out); } + Operator::Round(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &op.out); + } Operator::Floor(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &op.out); } @@ -168,6 +171,10 @@ impl ScopeProcessing { sanitize_constant_scalar_ref_var(&mut op.lhs, &op.out); sanitize_constant_scalar_ref_var(&mut op.rhs, &op.out); } + Operator::BitwiseOr(op) => { + sanitize_constant_scalar_ref_var(&mut op.lhs, &op.out); + sanitize_constant_scalar_ref_var(&mut op.rhs, &op.out); + } Operator::BitwiseXor(op) => { sanitize_constant_scalar_ref_var(&mut op.lhs, &op.out); sanitize_constant_scalar_ref_var(&mut op.rhs, &op.out); diff --git a/crates/cubecl-core/src/ir/vectorization.rs b/crates/cubecl-core/src/ir/vectorization.rs index eb7f4396c..bed3c3405 100644 --- a/crates/cubecl-core/src/ir/vectorization.rs +++ b/crates/cubecl-core/src/ir/vectorization.rs @@ -39,6 +39,7 @@ impl Operator { Operator::Sub(op) => Operator::Sub(op.vectorize(vectorization)), Operator::Mul(op) => 
Operator::Mul(op.vectorize(vectorization)), Operator::Div(op) => Operator::Div(op.vectorize(vectorization)), + Operator::Round(op) => Operator::Round(op.vectorize(vectorization)), Operator::Floor(op) => Operator::Floor(op.vectorize(vectorization)), Operator::Ceil(op) => Operator::Ceil(op.vectorize(vectorization)), Operator::Abs(op) => Operator::Abs(op.vectorize(vectorization)), @@ -76,6 +77,7 @@ impl Operator { Operator::And(op) => Operator::And(op.vectorize(vectorization)), Operator::Or(op) => Operator::Or(op.vectorize(vectorization)), Operator::Not(op) => Operator::Not(op.vectorize(vectorization)), + Operator::BitwiseOr(op) => Operator::BitwiseOr(op.vectorize(vectorization)), Operator::BitwiseAnd(op) => Operator::BitwiseAnd(op.vectorize(vectorization)), Operator::BitwiseXor(op) => Operator::BitwiseXor(op.vectorize(vectorization)), Operator::ShiftLeft(op) => Operator::ShiftLeft(op.vectorize(vectorization)), diff --git a/crates/cubecl-core/src/lib.rs b/crates/cubecl-core/src/lib.rs index 59cb1a312..d8918dbf1 100644 --- a/crates/cubecl-core/src/lib.rs +++ b/crates/cubecl-core/src/lib.rs @@ -70,22 +70,38 @@ pub fn tensor_vectorization_factor( strides: &[usize], dim: usize, ) -> u8 { - if let Some(val) = strides.get(dim) { - if *val != 1 { - return 1; + match strides.get(dim) { + Some(val) => { + if *val != 1 { + return 1; + } } - } else { - return 1; + None => return 1, } - let dim_size = match shape.get(dim) { + let shape_check = match shape.get(dim) { Some(val) => val, None => return 1, }; + let stride_check = if let Some(dim) = dim.checked_sub(1) { + strides.get(dim) + } else { + None + }; + for factor in factors { - if dim_size % *factor as usize == 0 { - return *factor; + let factor = *factor as usize; + + if shape_check % factor == 0 { + match stride_check { + Some(check) => { + if check % factor == 0 { + return factor as u8; + } + } + None => return factor as u8, + } } } diff --git a/crates/cubecl-core/src/runtime.rs b/crates/cubecl-core/src/runtime.rs index 
b9a42cbf5..0841163cf 100644 --- a/crates/cubecl-core/src/runtime.rs +++ b/crates/cubecl-core/src/runtime.rs @@ -20,6 +20,7 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug { Kernel = Box, DispatchOptions = CubeCount, FeatureSet = FeatureSet, + Properties = Properties, >; /// The channel used to communicate with the compute server. type Channel: ComputeChannel; @@ -44,6 +45,13 @@ pub struct FeatureSet { set: alloc::collections::BTreeSet, } +/// The [runtime](Runtime) properties. +#[derive(Default, Debug)] +pub struct Properties { + /// The memory offset alignment in bytes. + pub memory_offset_alignment: u32, +} + impl FeatureSet { pub fn new(features: &[Feature]) -> Self { let mut this = Self::default(); diff --git a/crates/cubecl-core/tests/frontend/ops.rs b/crates/cubecl-core/tests/frontend/ops.rs index d5c9a63d1..064e8b9f6 100644 --- a/crates/cubecl-core/tests/frontend/ops.rs +++ b/crates/cubecl-core/tests/frontend/ops.rs @@ -66,6 +66,11 @@ pub fn sqrt_op(a: F) -> F { F::sqrt(a) } +#[cube] +pub fn round_op(a: F) -> F { + F::round(a) +} + #[cube] pub fn floor_op(a: F) -> F { F::floor(a) @@ -156,6 +161,11 @@ pub fn bitand_op(a: UInt, b: UInt) -> UInt { a & b } +#[cube] +pub fn bitor_op(a: UInt, b: UInt) -> UInt { + a | b +} + #[cube] pub fn bitxor_op(a: UInt, b: UInt) -> UInt { a ^ b @@ -286,6 +296,7 @@ mod tests { unary_test!(cube_can_sqrt, sqrt_op::__expand::, "Sqrt"); unary_test!(cube_can_erf, erf_op::__expand::, "Erf"); unary_test!(cube_can_recip, recip_op::__expand::, "Recip"); + unary_test!(cube_can_round, round_op::__expand::, "Round"); unary_test!(cube_can_floor, floor_op::__expand::, "Floor"); unary_test!(cube_can_ceil, ceil_op::__expand::, "Ceil"); binary_test!(cube_can_eq, equal_op::__expand::, "Equal", ref_ops_cmp); @@ -343,6 +354,7 @@ mod tests { binary_boolean_test!(cube_can_and, and_op::__expand, "And"); binary_boolean_test!(cube_can_or, or_op::__expand, "Or"); binary_uint_test!(cube_can_bitand, bitand_op::__expand, "BitwiseAnd"); + 
binary_uint_test!(cube_can_bitor, bitor_op::__expand, "BitwiseOr"); binary_uint_test!(cube_can_bitxor, bitxor_op::__expand, "BitwiseXor"); binary_uint_test!(cube_can_shl, shl_op::__expand, "ShiftLeft"); binary_uint_test!(cube_can_shr, shr_op::__expand, "ShiftRight"); diff --git a/crates/cubecl-cuda/src/compiler/base.rs b/crates/cubecl-cuda/src/compiler/base.rs index 42c68b1dd..5787fe5d3 100644 --- a/crates/cubecl-cuda/src/compiler/base.rs +++ b/crates/cubecl-cuda/src/compiler/base.rs @@ -451,6 +451,9 @@ impl CudaCompiler { gpu::Operator::NotEqual(op) => { instructions.push(Instruction::NotEqual(self.compile_binary(op))) } + gpu::Operator::BitwiseOr(op) => { + instructions.push(Instruction::BitwiseOr(self.compile_binary(op))) + } gpu::Operator::BitwiseAnd(op) => { instructions.push(Instruction::BitwiseAnd(self.compile_binary(op))) } @@ -487,6 +490,9 @@ impl CudaCompiler { out: self.compile_variable(op.out), })) } + gpu::Operator::Round(op) => { + instructions.push(Instruction::Round(self.compile_unary(op))) + } gpu::Operator::Floor(op) => { instructions.push(Instruction::Floor(self.compile_unary(op))) } diff --git a/crates/cubecl-cuda/src/compiler/binary.rs b/crates/cubecl-cuda/src/compiler/binary.rs index a6016d740..75e7e8076 100644 --- a/crates/cubecl-cuda/src/compiler/binary.rs +++ b/crates/cubecl-cuda/src/compiler/binary.rs @@ -100,6 +100,7 @@ operator!(Greater, ">"); operator!(GreaterEqual, ">="); operator!(ShiftLeft, "<<"); operator!(ShiftRight, ">>"); +operator!(BitwiseOr, "|"); operator!(BitwiseAnd, "&"); operator!(BitwiseXor, "^"); operator!(Or, "||"); @@ -128,29 +129,11 @@ impl Binary for IndexAssign { let item_rhs = rhs.item(); let format_vec = |f: &mut Formatter<'_>, cast: bool| { - let is_vec_native = item_out.is_vec_native(); f.write_str("{\n")?; let var = "broadcasted"; f.write_fmt(format_args!("{item_out} {var};\n"))?; for i in 0..item_out.vectorization { - if is_vec_native { - let char = match i { - 0 => 'x', - 1 => 'y', - 2 => 'z', - 3 => 'w', - _ 
=> panic!("Invalid"), - }; - if cast { - f.write_fmt(format_args!( - "{var}.{char} = {}({});\n", - item_out.elem, - rhs.index(i) - ))?; - } else { - f.write_fmt(format_args!("{var}.{char} = {};\n", rhs.index(i)))?; - } - } else if cast { + if cast { f.write_fmt(format_args!( "{var}.i_{i} = {}({});\n", item_out.elem, @@ -254,29 +237,14 @@ impl Binary for Index { let item_lhs = lhs.item(); let format_vec = |f: &mut Formatter<'_>| { - let is_vec_native = item_out.is_vec_native(); f.write_str("{\n")?; let var = "broadcasted"; f.write_fmt(format_args!("{item_out} {var};\n"))?; for i in 0..item_out.vectorization { - if is_vec_native { - let char = match i { - 0 => 'x', - 1 => 'y', - 2 => 'z', - 3 => 'w', - _ => panic!("Invalid"), - }; - f.write_fmt(format_args!( - "{var}.{char} = {}({lhs}[{rhs}].i_{i});\n", - item_out.elem - ))?; - } else { - f.write_fmt(format_args!( - "{var}.i_{i} = {}({lhs}[{rhs}].i_{i});\n", - item_out.elem - ))?; - } + f.write_fmt(format_args!( + "{var}.i_{i} = {}({lhs}[{rhs}].i_{i});\n", + item_out.elem + ))?; } f.write_fmt(format_args!("{out} = {var};\n"))?; f.write_str("}")?; diff --git a/crates/cubecl-cuda/src/compiler/element.rs b/crates/cubecl-cuda/src/compiler/element.rs index 689e150db..c24e3f2ad 100644 --- a/crates/cubecl-cuda/src/compiler/element.rs +++ b/crates/cubecl-cuda/src/compiler/element.rs @@ -43,11 +43,6 @@ impl Display for Item { return f.write_fmt(format_args!("{}", self.elem)); } - if self.is_vec_native() { - let elem = self.optimized().elem; - return f.write_fmt(format_args!("{elem}")); - } - return f.write_fmt(format_args!("{}_{}", self.elem, self.vectorization)); } } @@ -397,12 +392,15 @@ impl Display for IndexedVariable { if self.optimized { let item = self.var.item(); f.write_fmt(format_args!( - "(reinterpret_cast<{item}*>(&{var}))->i_{}", + "(reinterpret_cast<{item}&>({var})).i_{}", self.index )) } else { f.write_fmt(format_args!("{var}.i_{}", self.index)) } + } else if self.optimized { + let item = self.var.item(); + 
f.write_fmt(format_args!("reinterpret_cast<{item}&>({var})")) } else { f.write_fmt(format_args!("{var}")) } @@ -438,16 +436,6 @@ impl Item { matches!(self.elem, Elem::F162 | Elem::BF162) } - pub fn is_vec_native(&self) -> bool { - match &self.elem { - Elem::F16 => self.vectorization == 2, - Elem::BF16 => self.vectorization == 2, - Elem::F162 => self.vectorization == 1, - Elem::BF162 => self.vectorization == 1, - _ => false, - } - } - pub fn optimized(&self) -> Item { if self.vectorization == 1 { return *self; diff --git a/crates/cubecl-cuda/src/compiler/instruction.rs b/crates/cubecl-cuda/src/compiler/instruction.rs index 97b35bba5..7c4bbda72 100644 --- a/crates/cubecl-cuda/src/compiler/instruction.rs +++ b/crates/cubecl-cuda/src/compiler/instruction.rs @@ -89,6 +89,7 @@ pub enum Instruction { LowerEqual(BinaryInstruction), GreaterEqual(BinaryInstruction), Erf(UnaryInstruction), + BitwiseOr(BinaryInstruction), BitwiseAnd(BinaryInstruction), BitwiseXor(BinaryInstruction), ShiftLeft(BinaryInstruction), @@ -114,6 +115,8 @@ pub enum Instruction { out: Variable, }, SyncThreads, + ThreadFence, + Round(UnaryInstruction), Ceil(UnaryInstruction), Floor(UnaryInstruction), Wrap(WarpInstruction), @@ -168,6 +171,7 @@ impl Display for Instruction { Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out), + Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out), Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out), Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out), Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out), @@ -264,6 +268,8 @@ for (uint {i} = {start}; {i} < {end}; {increment}) {{ out, } => Clamp::format(f, input, min_value, max_value, out), Instruction::SyncThreads => f.write_str("__syncthreads();\n"), + 
Instruction::ThreadFence => f.write_str("__threadfence();\n"), + Instruction::Round(it) => Round::format(f, &it.input, &it.out), Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out), Instruction::Floor(it) => Floor::format(f, &it.input, &it.out), Instruction::SliceLength { input, out } => { diff --git a/crates/cubecl-cuda/src/compiler/kernel.rs b/crates/cubecl-cuda/src/compiler/kernel.rs index a32f8d19d..6fcfeb2d5 100644 --- a/crates/cubecl-cuda/src/compiler/kernel.rs +++ b/crates/cubecl-cuda/src/compiler/kernel.rs @@ -87,10 +87,6 @@ impl Display for ComputeKernel { f.write_str("typedef unsigned int uint;\n")?; for item in self.items.iter() { - if item.is_vec_native() { - continue; - } - let elem = item.elem; let size = item.vectorization; let alignment = elem.size() * size; diff --git a/crates/cubecl-cuda/src/compiler/unary.rs b/crates/cubecl-cuda/src/compiler/unary.rs index 6529701a4..9f45c6c5e 100644 --- a/crates/cubecl-cuda/src/compiler/unary.rs +++ b/crates/cubecl-cuda/src/compiler/unary.rs @@ -82,6 +82,7 @@ function!(Exp, "exp"); function!(Erf, "erf"); function!(Ceil, "ceil"); function!(Floor, "floor"); +function!(Round, "rint"); pub struct Not; diff --git a/crates/cubecl-cuda/src/compute/server.rs b/crates/cubecl-cuda/src/compute/server.rs index 6f824e200..afb4cfab4 100644 --- a/crates/cubecl-cuda/src/compute/server.rs +++ b/crates/cubecl-cuda/src/compute/server.rs @@ -6,8 +6,8 @@ use cubecl_common::reader::{reader_from_concrete, Reader}; use cubecl_common::sync_type::SyncType; use cubecl_core::compute::DebugInformation; use cubecl_core::ir::CubeDim; -use cubecl_core::FeatureSet; use cubecl_core::{prelude::*, KernelId}; +use cubecl_core::{FeatureSet, Properties}; use cubecl_runtime::debug::DebugLogger; use cubecl_runtime::ExecutionMode; use cubecl_runtime::{ @@ -25,8 +25,6 @@ use std::path::PathBuf; pub struct CudaServer> { state: CudaServerState, logger: DebugLogger, - pub(crate) archs: Vec, - pub(crate) minimum_arch_version: i32, } pub(crate) enum 
CudaServerState> { @@ -51,6 +49,7 @@ pub(crate) struct CudaContext> { stream: cudarc::driver::sys::CUstream, memory_management: MM, module_names: HashMap, + pub(crate) arch: u32, } #[derive(Debug)] @@ -63,9 +62,18 @@ struct CompiledKernel { unsafe impl> Send for CudaServer {} impl> CudaServer { + pub(crate) fn arch_version(&mut self) -> u32 { + let ctx = self.get_context(); + ctx.arch + } + fn read_sync(&mut self, binding: server::Binding) -> Vec { let ctx = self.get_context(); - let resource = ctx.memory_management.get_resource(binding.memory); + let resource = ctx.memory_management.get_resource( + binding.memory, + binding.offset_start, + binding.offset_end, + ); // TODO: Check if it is possible to make this faster let mut data = vec![0; resource.size() as usize]; @@ -83,6 +91,7 @@ impl> ComputeServer for CudaServer { type Storage = CudaStorage; type MemoryManagement = MM; type FeatureSet = FeatureSet; + type Properties = Properties; fn read(&mut self, binding: server::Binding) -> Reader { reader_from_concrete(self.read_sync(binding)) @@ -92,9 +101,12 @@ impl> ComputeServer for CudaServer { let handle = self.empty(data.len()); let ctx = self.get_context(); - let resource = ctx - .memory_management - .get_resource(handle.clone().binding().memory); + let binding = handle.clone().binding(); + let resource = ctx.memory_management.get_resource( + binding.memory, + binding.offset_start, + binding.offset_end, + ); unsafe { cudarc::driver::result::memcpy_htod_async(resource.ptr, data, ctx.stream).unwrap(); @@ -106,7 +118,7 @@ impl> ComputeServer for CudaServer { fn empty(&mut self, size: usize) -> server::Handle { let ctx = self.get_context(); let handle = ctx.memory_management.reserve(size, &[]); - server::Handle::new(handle) + server::Handle::new(handle, None, None) } unsafe fn execute( @@ -116,8 +128,6 @@ impl> ComputeServer for CudaServer { bindings: Vec>, mode: ExecutionMode, ) { - let arch = self.minimum_arch_version; - let mut kernel_id = kernel.id(); 
kernel_id.mode(mode); @@ -140,12 +150,18 @@ impl> ComputeServer for CudaServer { let (ctx, logger) = self.get_context_with_logger(); if !ctx.module_names.contains_key(&kernel_id) { - ctx.compile_kernel(&kernel_id, kernel, arch, logger, mode); + ctx.compile_kernel(&kernel_id, kernel, logger, mode); } let resources = bindings .into_iter() - .map(|binding| ctx.memory_management.get_resource(binding.memory)) + .map(|binding| { + ctx.memory_management.get_resource( + binding.memory, + binding.offset_start, + binding.offset_end, + ) + }) .collect::>(); ctx.execute_task(kernel_id, count, resources); @@ -168,7 +184,8 @@ impl> ComputeServer for CudaServer { binding: server::Binding, ) -> ::Resource { let ctx = self.get_context(); - ctx.memory_management.get_resource(binding.memory) + ctx.memory_management + .get_resource(binding.memory, binding.offset_start, binding.offset_end) } } @@ -177,12 +194,14 @@ impl> CudaContext { memory_management: MM, stream: cudarc::driver::sys::CUstream, context: *mut CUctx_st, + arch: u32, ) -> Self { Self { context, memory_management, module_names: HashMap::new(), stream, + arch, } } @@ -197,7 +216,6 @@ impl> CudaContext { &mut self, kernel_id: &KernelId, kernel: Box, - arch: i32, logger: &mut DebugLogger, mode: ExecutionMode, ) { @@ -213,7 +231,7 @@ impl> CudaContext { let shared_mem_bytes = kernel_compiled.shared_mem_bytes; let cube_dim = kernel_compiled.cube_dim; - let arch = format!("--gpu-architecture=sm_{}", arch); + let arch = format!("--gpu-architecture=sm_{}", self.arch); let include_path = include_path(); let include_option = format!("--include-path={}", include_path.to_str().unwrap()); @@ -286,26 +304,12 @@ impl> CudaContext { impl> CudaServer { /// Create a new cuda server. 
pub(crate) fn new(index: usize, init: Box CudaContext>) -> Self { - let archs = unsafe { - let mut num_supported_arg: core::ffi::c_int = 0; - cudarc::nvrtc::sys::lib() - .nvrtcGetNumSupportedArchs(core::ptr::from_mut(&mut num_supported_arg)); - - let mut archs: Vec = vec![0; num_supported_arg as usize]; - cudarc::nvrtc::sys::lib().nvrtcGetSupportedArchs(core::ptr::from_mut(&mut archs[0])); - archs - }; - - let minimum_arch_version = archs[0]; - Self { state: CudaServerState::Uninitialized { device_index: index, init, }, logger: DebugLogger::new(), - archs, - minimum_arch_version, } } @@ -316,6 +320,7 @@ impl> CudaServer { fn get_context_with_logger(&mut self) -> (&mut CudaContext, &mut DebugLogger) { if let CudaServerState::Uninitialized { device_index, init } = &self.state { let ctx = init(*device_index); + self.state = CudaServerState::Initialized { ctx }; } if let CudaServerState::Initialized { ctx } = &mut self.state { diff --git a/crates/cubecl-cuda/src/runtime.rs b/crates/cubecl-cuda/src/runtime.rs index 4c5caa788..5895dbdd1 100644 --- a/crates/cubecl-cuda/src/runtime.rs +++ b/crates/cubecl-cuda/src/runtime.rs @@ -1,6 +1,6 @@ use cubecl_core::{ ir::{Elem, FloatKind}, - Feature, FeatureSet, Runtime, + Feature, FeatureSet, Properties, Runtime, }; use cubecl_runtime::{ channel::MutexComputeChannel, @@ -8,7 +8,6 @@ use cubecl_runtime::{ memory_management::dynamic::{DynamicMemoryManagement, DynamicMemoryManagementOptions}, ComputeRuntime, }; -use std::sync::Arc; use crate::{ compiler::CudaCompiler, @@ -35,6 +34,11 @@ impl Runtime for CudaRuntime { fn init(index: usize) -> CudaContext> { cudarc::driver::result::init().unwrap(); let device_ptr = cudarc::driver::result::device::get(index as i32).unwrap(); + let arch = unsafe { + let major = cudarc::driver::result::device::get_attribute(device_ptr, cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR).unwrap(); + let minor = cudarc::driver::result::device::get_attribute(device_ptr, 
cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR).unwrap(); + major * 10 + minor + } as u32; let ctx = unsafe { let ctx = cudarc::driver::result::primary_ctx::retain(device_ptr).unwrap(); @@ -49,20 +53,21 @@ impl Runtime for CudaRuntime { let storage = CudaStorage::new(stream); let options = DynamicMemoryManagementOptions::preset(2048 + 512 * 1024 * 1024, 32); let memory_management = DynamicMemoryManagement::new(storage, options); - CudaContext::new(memory_management, stream, ctx) + CudaContext::new(memory_management, stream, ctx, arch) } RUNTIME.client(device, move || { let mut server = CudaServer::new(device.index, Box::new(init)); let mut features = FeatureSet::new(&[Feature::Subcube]); - if let Some(wmma_minimum_version) = register_wmma_features(&mut features, &server.archs) - { - server.minimum_arch_version = - i32::max(server.minimum_arch_version, wmma_minimum_version); - } - - ComputeClient::new(MutexComputeChannel::new(server), Arc::new(features)) + register_wmma_features(&mut features, server.arch_version()); + ComputeClient::new( + MutexComputeChannel::new(server), + features, + Properties { + memory_offset_alignment: 4, + }, + ) }) } @@ -75,15 +80,12 @@ impl Runtime for CudaRuntime { } } -fn register_wmma_features(features: &mut FeatureSet, archs: &[i32]) -> Option { +fn register_wmma_features(features: &mut FeatureSet, arch: u32) { let wmma_minimum_version = 70; let mut wmma = false; - for arch in archs { - if *arch >= wmma_minimum_version { - wmma = true; - break; - } + if arch >= wmma_minimum_version { + wmma = true; } if wmma { @@ -130,8 +132,5 @@ fn register_wmma_features(features: &mut FeatureSet, archs: &[i32]) -> Option( #[derive(Debug)] pub enum UnavailabilityReason { NotMultipleOf4, // TODO: Support that case. 
- HiglyPermutatedInput, + HighlyPermutatedInput, ShapeMemoryLimitBusted, InvalidConfig(String), CmmaInstructionsUnsupported, diff --git a/crates/cubecl-linalg/src/tensor/contiguous.rs b/crates/cubecl-linalg/src/tensor/contiguous.rs index e26b3afa0..081f9d5d7 100644 --- a/crates/cubecl-linalg/src/tensor/contiguous.rs +++ b/crates/cubecl-linalg/src/tensor/contiguous.rs @@ -28,7 +28,7 @@ pub fn index_offset_with_layout( offset / vectorization_factor_runtime } -#[cube(launch)] +#[cube(launch_unchecked)] fn into_contiguous_kernel( input: &Tensor, output: &mut Tensor, @@ -69,14 +69,16 @@ pub fn into_contiguous( let handle = client.empty(num_elems * E::as_elem().size()); let output = TensorHandle::new_contiguous(input.shape.to_vec(), handle); - into_contiguous_kernel::launch::( - client, - cube_count, - cube_dim, - input.as_tensor_arg(vectorization_factor), - output.as_ref().as_tensor_arg(vectorization_factor), - Some(UInt::new(rank as u32)), - ); + unsafe { + into_contiguous_kernel::launch_unchecked::( + client, + cube_count, + cube_dim, + input.as_tensor_arg(vectorization_factor), + output.as_ref().as_tensor_arg(vectorization_factor), + Some(UInt::new(rank as u32)), + ); + } output } diff --git a/crates/cubecl-macros/src/codegen_function/operation.rs b/crates/cubecl-macros/src/codegen_function/operation.rs index ede4aeeb0..18b6ee804 100644 --- a/crates/cubecl-macros/src/codegen_function/operation.rs +++ b/crates/cubecl-macros/src/codegen_function/operation.rs @@ -206,6 +206,14 @@ pub(crate) fn codegen_binary( cubecl::frontend::or::expand(context, _lhs, _rhs) } }, + syn::BinOp::BitOr(_) => quote::quote! { + { + + let _lhs = #lhs; + let _rhs = #rhs; + cubecl::frontend::bitor::expand(context, _lhs, _rhs) + } + }, syn::BinOp::BitAnd(_) => quote::quote! 
{ { diff --git a/crates/cubecl-runtime/src/client.rs b/crates/cubecl-runtime/src/client.rs index 25cf9b7b7..cce4ced15 100644 --- a/crates/cubecl-runtime/src/client.rs +++ b/crates/cubecl-runtime/src/client.rs @@ -14,7 +14,7 @@ pub use cubecl_common::sync_type::SyncType; #[derive(Debug)] pub struct ComputeClient { channel: Channel, - features: Arc, + settings: Arc<(Server::FeatureSet, Server::Properties)>, } impl Clone for ComputeClient @@ -25,7 +25,7 @@ where fn clone(&self) -> Self { Self { channel: self.channel.clone(), - features: self.features.clone(), + settings: self.settings.clone(), } } } @@ -36,8 +36,15 @@ where Channel: ComputeChannel, { /// Create a new client. - pub fn new(channel: Channel, features: Arc) -> Self { - Self { channel, features } + pub fn new( + channel: Channel, + features: Server::FeatureSet, + properties: Server::Properties, + ) -> Self { + Self { + channel, + settings: Arc::new((features, properties)), + } } /// Given a binding, returns owned resource as bytes. @@ -106,6 +113,11 @@ where /// Get the features supported by the compute server. pub fn features(&self) -> &Server::FeatureSet { - self.features.as_ref() + &self.settings.as_ref().0 + } + + /// Get the properties supported by the compute server. 
+ pub fn properties(&self) -> &Server::Properties { + &self.settings.as_ref().1 } } diff --git a/crates/cubecl-runtime/src/memory_management/base.rs b/crates/cubecl-runtime/src/memory_management/base.rs index deeb4b239..3f763290a 100644 --- a/crates/cubecl-runtime/src/memory_management/base.rs +++ b/crates/cubecl-runtime/src/memory_management/base.rs @@ -27,8 +27,21 @@ pub trait MemoryManagement: Send + core::fmt::Debug { fn get(&mut self, binding: Self::Binding) -> StorageHandle; /// Returns the resource from the storage at the specified handle - fn get_resource(&mut self, binding: Self::Binding) -> Storage::Resource { + fn get_resource( + &mut self, + binding: Self::Binding, + offset_start: Option, + offset_end: Option, + ) -> Storage::Resource { let handle = self.get(binding); + let handle = match offset_start { + Some(offset) => handle.offset_start(offset), + None => handle, + }; + let handle = match offset_end { + Some(offset) => handle.offset_end(offset), + None => handle, + }; self.storage().get(&handle) } diff --git a/crates/cubecl-runtime/src/server.rs b/crates/cubecl-runtime/src/server.rs index dbc104f5e..1edee0e78 100644 --- a/crates/cubecl-runtime/src/server.rs +++ b/crates/cubecl-runtime/src/server.rs @@ -25,6 +25,8 @@ where type MemoryManagement: MemoryManagement; /// Features supported by the compute server. type FeatureSet: Send + Sync; + /// Properties of the compute server. + type Properties: Send + Sync; /// Given a handle, returns the owned resource as bytes. fn read(&mut self, binding: Binding) -> Reader; @@ -66,6 +68,33 @@ where pub struct Handle { /// Memory handle. pub memory: >::Handle, + /// Memory offset in bytes. + pub offset_start: Option, + /// Memory offset in bytes. + pub offset_end: Option, +} + +impl Handle { + /// Add to the current offset in bytes. 
+ pub fn offset_start(mut self, offset: usize) -> Self { + if let Some(val) = &mut self.offset_start { + *val += offset; + } else { + self.offset_start = Some(offset); + } + + self + } + /// Add to the current offset in bytes. + pub fn offset_end(mut self, offset: usize) -> Self { + if let Some(val) = &mut self.offset_end { + *val += offset; + } else { + self.offset_end = Some(offset); + } + + self + } } /// Binding of a [tensor handle](Handle) to execute a kernel. @@ -73,6 +102,10 @@ pub struct Handle { pub struct Binding { /// Memory binding. pub memory: >::Binding, + /// Memory offset in bytes. + pub offset_start: Option, + /// Memory offset in bytes. + pub offset_end: Option, } impl Handle { @@ -87,6 +120,8 @@ impl Handle { pub fn binding(self) -> Binding { Binding { memory: MemoryHandle::binding(self.memory), + offset_start: self.offset_start, + offset_end: self.offset_end, } } } @@ -95,6 +130,8 @@ impl Clone for Handle { fn clone(&self) -> Self { Self { memory: self.memory.clone(), + offset_start: self.offset_start, + offset_end: self.offset_end, } } } @@ -103,6 +140,8 @@ impl Clone for Binding { fn clone(&self) -> Self { Self { memory: self.memory.clone(), + offset_start: self.offset_start, + offset_end: self.offset_end, } } } diff --git a/crates/cubecl-runtime/src/storage/base.rs b/crates/cubecl-runtime/src/storage/base.rs index 968aadbcd..bae7fcede 100644 --- a/crates/cubecl-runtime/src/storage/base.rs +++ b/crates/cubecl-runtime/src/storage/base.rs @@ -42,6 +42,44 @@ impl StorageHandle { StorageUtilization::Slice { offset, .. } => offset, } } + + /// Increase the current offset with the given value in bytes. 
+ pub fn offset_start(&self, offset_bytes: usize) -> Self { + let utilization = match self.utilization { + StorageUtilization::Full(size) => StorageUtilization::Slice { + offset: offset_bytes, + size: size - offset_bytes, + }, + StorageUtilization::Slice { offset, size } => StorageUtilization::Slice { + offset: offset + offset_bytes, + size: size - offset_bytes, + }, + }; + + Self { + id: self.id, + utilization, + } + } + + /// Reduce the size of the memory handle.. + pub fn offset_end(&self, offset_bytes: usize) -> Self { + let utilization = match self.utilization { + StorageUtilization::Full(size) => StorageUtilization::Slice { + offset: 0, + size: size - offset_bytes, + }, + StorageUtilization::Slice { offset, size } => StorageUtilization::Slice { + offset, + size: size - offset_bytes, + }, + }; + + Self { + id: self.id, + utilization, + } + } } /// Storage types are responsible for allocating and deallocating memory. diff --git a/crates/cubecl-runtime/tests/dummy/compute.rs b/crates/cubecl-runtime/tests/dummy/compute.rs index be071d10e..7654a43b6 100644 --- a/crates/cubecl-runtime/tests/dummy/compute.rs +++ b/crates/cubecl-runtime/tests/dummy/compute.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use super::DummyServer; use cubecl_runtime::channel::MutexComputeChannel; use cubecl_runtime::client::ComputeClient; @@ -35,7 +33,7 @@ pub fn init_client() -> ComputeClient DummyClient { diff --git a/crates/cubecl-runtime/tests/dummy/server.rs b/crates/cubecl-runtime/tests/dummy/server.rs index 8c6ee178a..33bdf8bc3 100644 --- a/crates/cubecl-runtime/tests/dummy/server.rs +++ b/crates/cubecl-runtime/tests/dummy/server.rs @@ -28,6 +28,7 @@ where type Storage = BytesStorage; type MemoryManagement = MM; type FeatureSet = (); + type Properties = (); fn read(&mut self, binding: Binding) -> cubecl_common::reader::Reader { let bytes_handle = self.memory_management.get(binding.memory); @@ -54,7 +55,7 @@ where } fn empty(&mut self, size: usize) -> Handle { - 
Handle::new(self.memory_management.reserve(size, &[])) + Handle::new(self.memory_management.reserve(size, &[]), None, None) } unsafe fn execute( diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index b2c695b27..1a03ada62 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -569,6 +569,10 @@ impl WgslCompiler { input: self.compile_variable(op.input), out: self.compile_variable(op.out), }, + cube::Operator::Round(op) => wgsl::Instruction::Round { + input: self.compile_variable(op.input), + out: self.compile_variable(op.out), + }, cube::Operator::Floor(op) => wgsl::Instruction::Floor { input: self.compile_variable(op.input), out: self.compile_variable(op.out), @@ -649,6 +653,11 @@ impl WgslCompiler { input: self.compile_variable(op.input), out: self.compile_variable(op.out), }, + cube::Operator::BitwiseOr(op) => wgsl::Instruction::BitwiseOr { + lhs: self.compile_variable(op.lhs), + rhs: self.compile_variable(op.rhs), + out: self.compile_variable(op.out), + }, cube::Operator::BitwiseAnd(op) => wgsl::Instruction::BitwiseAnd { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs index 39a47ccec..84112445b 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs @@ -201,6 +201,11 @@ pub enum Instruction { Loop { instructions: Vec, }, + BitwiseOr { + lhs: Variable, + rhs: Variable, + out: Variable, + }, BitwiseAnd { lhs: Variable, rhs: Variable, @@ -221,6 +226,10 @@ pub enum Instruction { rhs: Variable, out: Variable, }, + Round { + input: Variable, + out: Variable, + }, Floor { input: Variable, out: Variable, @@ -591,6 +600,9 @@ for (var {i}: u32 = {start}; {i} < {end}; {increment}) {{ } f.write_str("}\n") } + Instruction::BitwiseOr { lhs, 
rhs, out } => { + f.write_fmt(format_args!("{out} = {lhs} | {rhs};\n")) + } Instruction::BitwiseAnd { lhs, rhs, out } => { f.write_fmt(format_args!("{out} = {lhs} & {rhs};\n")) } @@ -603,6 +615,9 @@ for (var {i}: u32 = {start}; {i} < {end}; {increment}) {{ Instruction::ShiftRight { lhs, rhs, out } => { f.write_fmt(format_args!("{out} = {lhs} >> {rhs};\n")) } + Instruction::Round { input, out } => { + f.write_fmt(format_args!("{out} = round({input});\n")) + } Instruction::Floor { input, out } => { f.write_fmt(format_args!("{out} = floor({input});\n")) } diff --git a/crates/cubecl-wgpu/src/compute/server.rs b/crates/cubecl-wgpu/src/compute/server.rs index a5373b4fc..b8ea2e022 100644 --- a/crates/cubecl-wgpu/src/compute/server.rs +++ b/crates/cubecl-wgpu/src/compute/server.rs @@ -3,7 +3,9 @@ use std::num::NonZero; use super::WgpuStorage; use alloc::{borrow::Cow, sync::Arc}; use cubecl_common::{reader::Reader, sync_type::SyncType}; -use cubecl_core::{compute::DebugInformation, prelude::*, server::Handle, FeatureSet, KernelId}; +use cubecl_core::{ + compute::DebugInformation, prelude::*, server::Handle, FeatureSet, KernelId, Properties, +}; use cubecl_runtime::{ debug::DebugLogger, memory_management::{MemoryHandle, MemoryManagement}, @@ -129,6 +131,7 @@ where type Storage = WgpuStorage; type MemoryManagement = MM; type FeatureSet = FeatureSet; + type Properties = Properties; fn read(&mut self, binding: server::Binding) -> Reader { let resource = self.get_resource(binding); @@ -190,7 +193,11 @@ where &mut self, binding: server::Binding, ) -> ::Resource { - self.memory_management.get_resource(binding.memory) + self.memory_management.get_resource( + binding.memory, + binding.offset_start, + binding.offset_end, + ) } /// When we create a new handle from existing data, we use custom allocations so that we don't @@ -229,11 +236,11 @@ where .copy_from_slice(data); } - Handle::new(memory) + Handle::new(memory, None, None) } fn empty(&mut self, size: usize) -> server::Handle { - 
server::Handle::new(self.memory_management.reserve(size, &[])) + server::Handle::new(self.memory_management.reserve(size, &[]), None, None) } unsafe fn execute( @@ -255,7 +262,15 @@ where // Keep track of the storage we've used so far. self.compute_storage_used.push(resource_handle.id); - self.memory_management.storage().get(&resource_handle) + let handle = match binding.offset_start { + Some(offset) => resource_handle.offset_start(offset), + None => resource_handle.clone(), + }; + let handle = match binding.offset_end { + Some(offset) => handle.offset_end(offset), + None => handle, + }; + self.memory_management.storage().get(&handle) }) .collect(); diff --git a/crates/cubecl-wgpu/src/runtime.rs b/crates/cubecl-wgpu/src/runtime.rs index 450cde12e..185c466ab 100644 --- a/crates/cubecl-wgpu/src/runtime.rs +++ b/crates/cubecl-wgpu/src/runtime.rs @@ -4,7 +4,7 @@ use crate::{ AutoGraphicsApi, GraphicsApi, WgpuDevice, }; use alloc::sync::Arc; -use cubecl_core::{Feature, FeatureSet, Runtime}; +use cubecl_core::{Feature, FeatureSet, Properties, Runtime}; use cubecl_runtime::memory_management; use cubecl_runtime::{channel::MutexComputeChannel, client::ComputeClient, ComputeRuntime}; use wgpu::{DeviceDescriptor, Limits}; @@ -150,8 +150,11 @@ fn create_client( if features.contains(wgpu::Features::SUBGROUP) { features_cube.register(Feature::Subcube); } + let properties = Properties { + memory_offset_alignment: limits.min_storage_buffer_offset_alignment, + }; - ComputeClient::new(channel, Arc::new(features_cube)) + ComputeClient::new(channel, features_cube, properties) } /// Select the wgpu device and queue based on the provided [device](WgpuDevice). diff --git a/profiling/matmul-example/README.md b/profiling/matmul-example/README.md index e2d5c94cd..7ff757bf6 100644 --- a/profiling/matmul-example/README.md +++ b/profiling/matmul-example/README.md @@ -13,7 +13,7 @@ NVIDIA Nsight Compute is a powerful tool for GPU profiling. 
Make sure it is inst For effective profiling, isolate the kernel you want to profile into a main function. This allows you to focus on the performance of a specific kernel without interference from other parts of your code. ## 4. Use the CUDA device/runtime -Make sure your code uses the CUDA runtime API and device for lauching the kernel. +Make sure your code uses the CUDA runtime API and device for launching the kernel. ```rust #[cfg(feature = "cube-cuda")] @@ -29,19 +29,19 @@ mod cube_cuda { let client = CudaRuntime::client(&device); let num_of_batch = 12; - let heigth = 1024; + let height = 1024; let width = 1024; - let tensor_values: Vec = (0..num_of_batch * heigth * width) + let tensor_values: Vec = (0..num_of_batch * height * width) .map(|x| x as f32) .collect(); let tensor_a_handle = client.create(f32::as_bytes(&tensor_values)); let tensor_b_handle = client.create(f32::as_bytes(&tensor_values)); let tensor_c_handle = client.empty(12 * 1024 * 1024 * core::mem::size_of::()); - let tensor_a_shape = vec![num_of_batch, heigth, width]; - let tensor_b_shape = vec![num_of_batch, heigth, width]; - let tensor_c_shape = vec![num_of_batch, heigth, width]; + let tensor_a_shape = vec![num_of_batch, height, width]; + let tensor_b_shape = vec![num_of_batch, height, width]; + let tensor_c_shape = vec![num_of_batch, height, width]; let tensor_a: TensorHandle = TensorHandle::new_contiguous(tensor_a_shape, tensor_a_handle); diff --git a/profiling/matmul-example/src/main.rs b/profiling/matmul-example/src/main.rs index 77cd68277..08a28e428 100644 --- a/profiling/matmul-example/src/main.rs +++ b/profiling/matmul-example/src/main.rs @@ -33,19 +33,19 @@ mod cube_cuda { let client = CudaRuntime::client(&device); let num_of_batch = 12; - let heigth = 1024; + let height = 1024; let width = 1024; - let tensor_values: Vec = (0..num_of_batch * heigth * width) + let tensor_values: Vec = (0..num_of_batch * height * width) .map(|x| x as f32) .collect(); let tensor_a_handle = 
client.create(f32::as_bytes(&tensor_values)); let tensor_b_handle = client.create(f32::as_bytes(&tensor_values)); let tensor_c_handle = client.empty(12 * 1024 * 1024 * core::mem::size_of::()); - let tensor_a_shape = vec![num_of_batch, heigth, width]; - let tensor_b_shape = vec![num_of_batch, heigth, width]; - let tensor_c_shape = vec![num_of_batch, heigth, width]; + let tensor_a_shape = vec![num_of_batch, height, width]; + let tensor_b_shape = vec![num_of_batch, height, width]; + let tensor_c_shape = vec![num_of_batch, height, width]; let tensor_a: TensorHandle = TensorHandle::new_contiguous(tensor_a_shape, tensor_a_handle); diff --git a/xtask/Cargo.toml b/xtask/Cargo.toml index 6a086a11f..7a787ca94 100644 --- a/xtask/Cargo.toml +++ b/xtask/Cargo.toml @@ -1,19 +1,14 @@ [package] name = "xtask" -version = "0.2.0" +version = "1.0.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -anyhow = { workspace = true } -clap = { workspace = true } -derive_more = { workspace = true } -env_logger = { workspace = true } log = { workspace = true } -rand = { workspace = true, features = ["std"] } -serde_json = { version = "1.0.119" } strum = { workspace = true } +tracel-xtask = { workspace = true } [dev-dependencies] rstest = { workspace = true } diff --git a/xtask/src/commands/build.rs b/xtask/src/commands/build.rs new file mode 100644 index 000000000..081d268b5 --- /dev/null +++ b/xtask/src/commands/build.rs @@ -0,0 +1,17 @@ +use tracel_xtask::prelude::*; + +#[macros::extend_command_args(BuildCmdArgs, Target, None)] +pub struct CubeCLBuildCmdArgs { + /// Build in CI mode which excludes unsupported crates. 
+ #[arg(long)] + pub ci: bool, +} + +pub(crate) fn handle_command(mut args: CubeCLBuildCmdArgs) -> anyhow::Result<()> { + if args.ci { + // Exclude crates that are not supported on CI + args.exclude.extend(vec!["cubecl-cuda".to_string()]); + } + base_commands::build::handle_command(args.try_into().unwrap())?; + Ok(()) +} diff --git a/xtask/src/commands/bump.rs b/xtask/src/commands/bump.rs deleted file mode 100644 index 7d1c60a62..000000000 --- a/xtask/src/commands/bump.rs +++ /dev/null @@ -1,42 +0,0 @@ -use std::process::Command; - -use anyhow::{anyhow, Ok}; -use clap::{Args, Subcommand}; -use strum::{Display, EnumIter, EnumString}; - -use crate::{endgroup, group, utils::cargo::ensure_cargo_crate_is_installed}; - -#[derive(Args)] -pub(crate) struct BumpCmdArgs { - #[command(subcommand)] - command: BumpCommand, -} - -#[derive(EnumString, EnumIter, Display, Clone, PartialEq, Subcommand)] -#[strum(serialize_all = "lowercase")] -enum BumpCommand { - /// Run unit tests. - Major, - /// Run integration tests. - Minor, - /// Run documentation tests. 
- Patch, -} - -pub(crate) fn handle_command(args: BumpCmdArgs) -> anyhow::Result<()> { - bump(&args.command) -} - -fn bump(command: &BumpCommand) -> anyhow::Result<()> { - group!("Bump version: {command}"); - ensure_cargo_crate_is_installed("cargo-edit", None, false)?; - let status = Command::new("cargo") - .args(["set-version", "--bump", &command.to_string()]) - .status() - .map_err(|e| anyhow!("Failed to execute cargo set-version: {}", e))?; - if !status.success() { - return Err(anyhow!("Cannot set new {command} version")); - } - endgroup!(); - Ok(()) -} diff --git a/xtask/src/commands/check.rs b/xtask/src/commands/check.rs deleted file mode 100644 index 12bba7f99..000000000 --- a/xtask/src/commands/check.rs +++ /dev/null @@ -1,211 +0,0 @@ -use std::process::Command; - -use anyhow::{anyhow, Ok, Result}; -use clap::{Args, Subcommand}; -use strum::{Display, EnumIter, EnumString, IntoEnumIterator}; - -use crate::{ - endgroup, group, - utils::{ - cargo::ensure_cargo_crate_is_installed, - prompt::ask_once, - workspace::{get_workspace_members, WorkspaceMemberType}, - }, -}; - -use super::Target; - -#[derive(Args)] -pub(crate) struct CheckCmdArgs { - /// Target to check for. - #[arg(short, long, value_enum)] - target: Target, - #[command(subcommand)] - command: CheckCommand, -} - -#[derive(EnumString, EnumIter, Display, Clone, PartialEq, Subcommand)] -#[strum(serialize_all = "lowercase")] -enum CheckCommand { - /// Run audit command. - Audit, - /// Run format command. - Format, - /// Run ling command. - Lint, - /// Run all the checks. 
- All, -} - -pub(crate) fn handle_command(args: CheckCmdArgs, answer: Option) -> anyhow::Result<()> { - match args.command { - CheckCommand::Audit => run_audit(&args.target, answer), - CheckCommand::Format => run_format(&args.target, answer), - CheckCommand::Lint => run_lint(&args.target, answer), - CheckCommand::All => { - let answer = ask_once( - "This will run all the checks with autofix on all members of the workspace.", - ); - CheckCommand::iter() - .filter(|c| *c != CheckCommand::All) - .try_for_each(|c| { - handle_command( - CheckCmdArgs { - command: c, - target: args.target.clone(), - }, - Some(answer), - ) - }) - } - } -} - -pub(crate) fn run_audit(target: &Target, mut answer: Option) -> anyhow::Result<()> { - match target { - Target::Crates | Target::Examples => { - if answer.is_none() { - answer = Some(ask_once( - "This will run the audit check with autofix mode enabled.", - )); - }; - if answer.unwrap() { - ensure_cargo_crate_is_installed("cargo-audit", Some("fix"), false)?; - group!("Audit: Crates and Examples"); - info!("Command line: cargo audit fix"); - let status = Command::new("cargo") - .args(["audit", "-q", "--color", "always", "fix"]) - .status() - .map_err(|e| anyhow!("Failed to execute cargo audit: {}", e))?; - if !status.success() { - return Err(anyhow!("Audit check execution failed")); - } - endgroup!(); - } - } - Target::All => { - if answer.is_none() { - answer = Some(ask_once("This will run audit checks on all targets.")); - }; - Target::iter() - .filter(|p| *p != Target::All && *p != Target::Examples) - .try_for_each(|p| run_audit(&p, answer))?; - } - } - Ok(()) -} - -fn run_format(target: &Target, mut answer: Option) -> Result<()> { - match target { - Target::Crates | Target::Examples => { - let members = match target { - Target::Crates => get_workspace_members(WorkspaceMemberType::Crate), - Target::Examples => get_workspace_members(WorkspaceMemberType::Example), - _ => unreachable!(), - }; - - if answer.is_none() { - answer = 
Some(ask_once(&format!( - "This will run format checks on all {} of the workspace.", - if *target == Target::Crates { - "crates" - } else { - "examples" - } - ))); - } - - if answer.unwrap() { - for member in members { - group!("Format: {}", member.name); - info!("Command line: cargo fmt -p {}", &member.name); - let status = Command::new("cargo") - .args(["fmt", "-p", &member.name]) - .status() - .map_err(|e| anyhow!("Failed to execute cargo fmt: {}", e))?; - if !status.success() { - return Err(anyhow!( - "Format check execution failed for {}", - &member.name - )); - } - endgroup!(); - } - } - } - Target::All => { - if answer.is_none() { - answer = Some(ask_once( - "This will run format check on all members of the workspace.", - )); - } - if answer.unwrap() { - Target::iter() - .filter(|t| *t != Target::All) - .try_for_each(|t| run_format(&t, answer))?; - } - } - } - Ok(()) -} - -fn run_lint(target: &Target, mut answer: Option) -> anyhow::Result<()> { - match target { - Target::Crates | Target::Examples => { - let members = match target { - Target::Crates => get_workspace_members(WorkspaceMemberType::Crate), - Target::Examples => get_workspace_members(WorkspaceMemberType::Example), - _ => unreachable!(), - }; - - if answer.is_none() { - answer = Some(ask_once(&format!( - "This will run lint fix on all {} of the workspace.", - if *target == Target::Crates { - "crates" - } else { - "examples" - } - ))); - } - - if answer.unwrap() { - for member in members { - group!("Lint: {}", member.name); - info!( - "Command line: cargo clippy --no-deps --fix --allow-dirty -p {}", - &member.name - ); - let status = Command::new("cargo") - .args([ - "clippy", - "--no-deps", - "--fix", - "--allow-dirty", - "-p", - &member.name, - ]) - .status() - .map_err(|e| anyhow!("Failed to execute cargo clippy: {}", e))?; - if !status.success() { - return Err(anyhow!("Lint fix execution failed for {}", &member.name)); - } - endgroup!(); - } - } - } - Target::All => { - if answer.is_none() { - 
answer = Some(ask_once( - "This will run lint fix on all members of the workspace.", - )); - } - if answer.unwrap() { - Target::iter() - .filter(|t| *t != Target::All) - .try_for_each(|t| run_lint(&t, answer))?; - } - } - } - Ok(()) -} diff --git a/xtask/src/commands/ci.rs b/xtask/src/commands/ci.rs deleted file mode 100644 index fba5be9dc..000000000 --- a/xtask/src/commands/ci.rs +++ /dev/null @@ -1,187 +0,0 @@ -use std::process::Command; - -use anyhow::{anyhow, Ok, Result}; -use clap::{Args, Subcommand}; -use strum::{Display, EnumIter, EnumString, IntoEnumIterator}; - -use crate::{ - endgroup, group, - utils::{ - cargo::ensure_cargo_crate_is_installed, - workspace::{get_workspace_members, WorkspaceMemberType}, - }, -}; - -use super::{ - test::{run_documentation, run_integration, run_unit}, - Target, -}; - -#[derive(Args)] -pub(crate) struct CICmdArgs { - /// Target to check for. - #[arg(short, long, value_enum)] - pub target: Target, - #[command(subcommand)] - pub command: CICommand, -} - -#[derive(EnumString, EnumIter, Display, Clone, PartialEq, Subcommand)] -#[strum(serialize_all = "lowercase")] -pub(crate) enum CICommand { - /// Run audit command. - Audit, - /// Run format command. - Format, - /// Run lint command. - Lint, - /// Run unit tests. - UnitTests, - /// Run integration tests. - IntegrationTests, - /// Run documentation tests. - DocTests, - /// Run all tests. - AllTests, - /// Run all the checks. 
- All, -} - -pub(crate) fn handle_command(args: CICmdArgs) -> anyhow::Result<()> { - match args.command { - CICommand::Audit => run_audit(&args.target), - CICommand::Format => run_format(&args.target), - CICommand::Lint => run_lint(&args.target), - CICommand::UnitTests => run_unit_tests(&args.target), - CICommand::IntegrationTests => run_integration_tests(&args.target), - CICommand::DocTests => run_doc_tests(&args.target), - CICommand::AllTests => run_all_tests(&args.target), - CICommand::All => CICommand::iter() - .filter(|c| *c != CICommand::All && *c != CICommand::AllTests) - .try_for_each(|c| { - handle_command(CICmdArgs { - command: c, - target: args.target.clone(), - }) - }), - } -} - -fn run_audit(target: &Target) -> anyhow::Result<()> { - match target { - Target::Crates | Target::Examples => { - group!("Audit: Crates and Examples"); - ensure_cargo_crate_is_installed("cargo-audit", Some("fix"), false)?; - info!("Command line: cargo audit"); - let status = Command::new("cargo") - .args(["audit", "-q", "--color", "always"]) - .status() - .map_err(|e| anyhow!("Failed to execute cargo audit: {}", e))?; - if !status.success() { - return Err(anyhow!("Audit check execution failed")); - } - endgroup!(); - } - Target::All => { - Target::iter() - .filter(|t| *t != Target::All && *t != Target::Examples) - .try_for_each(|t| run_audit(&t))?; - } - } - Ok(()) -} - -fn run_format(target: &Target) -> Result<()> { - match target { - Target::Crates | Target::Examples => { - let members = match target { - Target::Crates => get_workspace_members(WorkspaceMemberType::Crate), - Target::Examples => get_workspace_members(WorkspaceMemberType::Example), - _ => unreachable!(), - }; - - for member in members { - group!("Format: {}", member.name); - info!("Command line: cargo fmt --check -p {}", &member.name); - let status = Command::new("cargo") - .args(["fmt", "--check", "-p", &member.name]) - .status() - .map_err(|e| anyhow!("Failed to execute cargo fmt: {}", e))?; - if 
!status.success() { - return Err(anyhow!( - "Format check execution failed for {}", - &member.name - )); - } - endgroup!(); - } - } - Target::All => { - Target::iter() - .filter(|t| *t != Target::All) - .try_for_each(|t| run_format(&t))?; - } - } - Ok(()) -} - -fn run_lint(target: &Target) -> anyhow::Result<()> { - match target { - Target::Crates | Target::Examples => { - let members = match target { - Target::Crates => get_workspace_members(WorkspaceMemberType::Crate), - Target::Examples => get_workspace_members(WorkspaceMemberType::Example), - _ => unreachable!(), - }; - - for member in members { - group!("Lint: {}", member.name); - info!( - "Command line: cargo clippy --no-deps -p {} -- --deny warnings", - &member.name - ); - let status = Command::new("cargo") - .args([ - "clippy", - "--no-deps", - "-p", - &member.name, - "--", - "--deny", - "warnings", - ]) - .status() - .map_err(|e| anyhow!("Failed to execute cargo clippy: {}", e))?; - if !status.success() { - return Err(anyhow!("Lint fix execution failed for {}", &member.name)); - } - endgroup!(); - } - } - Target::All => { - Target::iter() - .filter(|t| *t != Target::All) - .try_for_each(|t| run_lint(&t))?; - } - } - Ok(()) -} - -fn run_unit_tests(target: &Target) -> anyhow::Result<()> { - run_unit(target) -} - -fn run_integration_tests(target: &Target) -> anyhow::Result<()> { - run_integration(target) -} - -fn run_doc_tests(target: &Target) -> anyhow::Result<()> { - run_documentation(target) -} - -fn run_all_tests(target: &Target) -> anyhow::Result<()> { - run_unit_tests(target)?; - run_integration_tests(target)?; - run_doc_tests(target)?; - Ok(()) -} diff --git a/xtask/src/commands/mod.rs b/xtask/src/commands/mod.rs index 00d1263f4..726f3ad4c 100644 --- a/xtask/src/commands/mod.rs +++ b/xtask/src/commands/mod.rs @@ -1,17 +1,2 @@ -pub(crate) mod bump; -pub(crate) mod check; -pub(crate) mod ci; -pub(crate) mod publish; -pub(crate) mod pull_request_checks; +pub(crate) mod build; pub(crate) mod test; - -use 
clap::ValueEnum; -use strum::{Display, EnumIter, EnumString}; - -#[derive(EnumString, EnumIter, Display, Clone, PartialEq, ValueEnum)] -#[strum(serialize_all = "lowercase")] -pub(crate) enum Target { - All, - Crates, - Examples, -} diff --git a/xtask/src/commands/publish.rs b/xtask/src/commands/publish.rs deleted file mode 100644 index e2ee4ef26..000000000 --- a/xtask/src/commands/publish.rs +++ /dev/null @@ -1,107 +0,0 @@ -use std::{env, process::Command, str}; - -use anyhow::{anyhow, Ok}; -use clap::Args; - -use crate::{endgroup, group}; - -// Crates.io API token -const CRATES_IO_API_TOKEN: &str = "CRATES_IO_API_TOKEN"; - -#[derive(Args)] -pub(crate) struct PublishCmdArgs { - /// The name of the crate to publish on crates.io - name: String, -} - -pub(crate) fn handle_command(args: PublishCmdArgs) -> anyhow::Result<()> { - let crate_name = args.name; - - group!("Publishing crate '{}'...", &crate_name); - // Retrieve local version for crate - let local_version = local_version(&crate_name)?; - info!("Local version: {local_version}"); - // Retrieve remote version for crate if it exists - match remote_version(&crate_name)? 
{ - Some(remote_version) => { - info!("Found remote version: {remote_version}"); - // Early return if we don't need to publish the crate - if local_version == remote_version { - info!("Remote version is up to date, skipping publishing!"); - return Ok(()); - } - } - None => info!("This is the first version to be published on crates.io!"), - } - // Publish the crate - publish(crate_name)?; - endgroup!(); - - Ok(()) -} - -// Obtain local crate version -fn local_version(crate_name: &str) -> anyhow::Result { - // Obtain local crate version contained in cargo pkgid data - let cargo_pkgid_output = Command::new("cargo") - .args(["pkgid", "-p", crate_name]) - .output() - .map_err(|e| anyhow!("Failed to execute cargo pkgid: {}", e))?; - // Convert cargo pkgid output into a str - let cargo_pkgid_str = str::from_utf8(&cargo_pkgid_output.stdout) - .expect("Failed to convert pkgid output into a str"); - // Extract only the local crate version from str - let (_, local_version) = cargo_pkgid_str - .split_once('#') - .expect("Failed to get local crate version"); - Ok(local_version.trim_end().to_string()) -} - -// Obtain remote crate version -fn remote_version(crate_name: &str) -> anyhow::Result> { - // Obtain remote crate version contained in cargo search data - let cargo_search_output = Command::new("cargo") - .args(["search", crate_name, "--limit", "1"]) - .output() - .map_err(|e| anyhow!("Failed to execute cargo search: {}", e))?; - // Cargo search returns an empty string in case of a crate not present on crates.io - if cargo_search_output.stdout.is_empty() { - Ok(None) - } else { - // Convert cargo search output into a str - let remote_version_str = str::from_utf8(&cargo_search_output.stdout) - .expect("Failed to convert cargo search output into a str"); - - // Extract only the remote crate version from str - Ok(remote_version_str - .split_once('=') - .and_then(|(_, second)| second.trim_start().split_once(' ')) - .map(|(s, _)| s.trim_matches('"').to_string())) - } -} - -fn 
publish(crate_name: String) -> anyhow::Result<()> { - // Perform dry-run to ensure everything is good for publishing - let status = Command::new("cargo") - .args(["publish", "-p", &crate_name, "--dry-run"]) - .status() - .map_err(|e| anyhow!("Failed to execute cargo publish dry run: {}", e))?; - if !status.success() { - return Err(anyhow!( - "Publish dry run failed for crate '{}'.", - &crate_name - )); - } - let crates_io_token = - env::var(CRATES_IO_API_TOKEN).expect("Failed to retrieve the crates.io API token"); - // Actually publish the crate - let status = Command::new("cargo") - .env("CRATES_IO_API_TOKEN", crates_io_token.clone()) - .args(["publish", "-p", &crate_name, "--token", &crates_io_token]) - .status() - .map_err(|e| anyhow!("Failed to execute cargo publish: {}", e))?; - if !status.success() { - return Err(anyhow!("Publish failed for crate '{}'.", &crate_name)); - } - Ok(()) -} diff --git a/xtask/src/commands/pull_request_checks.rs b/xtask/src/commands/pull_request_checks.rs deleted file mode 100644 index cabdac262..000000000 --- a/xtask/src/commands/pull_request_checks.rs +++ /dev/null @@ -1,15 +0,0 @@ -use strum::IntoEnumIterator; - -use super::ci::{self, CICmdArgs, CICommand}; - -pub(crate) fn handle_command() -> anyhow::Result<()> { - CICommand::iter() - // Skip audit command - .filter(|c| *c != CICommand::All && *c != CICommand::AllTests && *c != CICommand::Audit) - .try_for_each(|c| { - ci::handle_command(CICmdArgs { - target: super::Target::All, - command: c.clone(), - }) - }) -} diff --git a/xtask/src/commands/test.rs b/xtask/src/commands/test.rs index 5503c8f2a..a6d186aae 100644 --- a/xtask/src/commands/test.rs +++ b/xtask/src/commands/test.rs @@ -1,178 +1,17 @@ -use std::process::Command; +use tracel_xtask::prelude::*; -use anyhow::{anyhow, Ok, Result}; -use clap::{Args, Subcommand}; -use strum::{Display, EnumIter, EnumString, IntoEnumIterator}; - -use crate::{ - endgroup, group, - utils::workspace::{get_workspace_members, WorkspaceMember, 
WorkspaceMemberType}, -}; - -use super::Target; - -#[derive(Args)] -pub(crate) struct TestCmdArgs { - /// Target to test for. - #[arg(short, long, value_enum)] - target: Target, - #[command(subcommand)] - command: TestCommand, -} - -#[derive(EnumString, EnumIter, Display, Clone, PartialEq, Subcommand)] -#[strum(serialize_all = "lowercase")] -enum TestCommand { - /// Run unit tests. - Unit, - /// Run integration tests. - Integration, - /// Run documentation tests. - Documentation, - /// Run all the checks. - All, -} - -pub(crate) fn handle_command(args: TestCmdArgs) -> anyhow::Result<()> { - match args.command { - TestCommand::Unit => run_unit(&args.target), - TestCommand::Integration => run_integration(&args.target), - TestCommand::Documentation => run_documentation(&args.target), - TestCommand::All => TestCommand::iter() - .filter(|c| *c != TestCommand::All) - .try_for_each(|c| { - handle_command(TestCmdArgs { - command: c, - target: args.target.clone(), - }) - }), - } -} - -pub(crate) fn run_unit(target: &Target) -> Result<()> { - match target { - Target::Crates | Target::Examples => { - let members = match target { - Target::Crates => get_workspace_members(WorkspaceMemberType::Crate), - Target::Examples => get_workspace_members(WorkspaceMemberType::Example), - _ => unreachable!(), - }; - - for member in members { - run_unit_test(&member)?; - } - } - Target::All => { - Target::iter() - .filter(|t| *t != Target::All) - .try_for_each(|t| run_unit(&t))?; - } - } - Ok(()) -} - -fn run_unit_test(member: &WorkspaceMember) -> Result<(), anyhow::Error> { - group!("Unit Tests: {}", member.name); - info!("Command line: cargo test --lib --bins -p {}", &member.name); - let status = Command::new("cargo") - .args(["test", "--lib", "--bins", "-p", &member.name]) - .status() - .map_err(|e| anyhow!("Failed to execute unit test: {}", e))?; - if !status.success() { - return Err(anyhow!("Failed to execute unit test for {}", &member.name)); - } - endgroup!(); - Ok(()) -} - 
-pub(crate) fn run_documentation(target: &Target) -> Result<()> { - match target { - Target::Crates | Target::Examples => { - let members = match target { - Target::Crates => get_workspace_members(WorkspaceMemberType::Crate), - Target::Examples => get_workspace_members(WorkspaceMemberType::Example), - _ => unreachable!(), - }; - - for member in members { - run_doc_test(&member)?; - } - } - Target::All => { - Target::iter() - .filter(|t| *t != Target::All) - .try_for_each(|t| run_documentation(&t))?; - } - } - Ok(()) -} - -fn run_doc_test(member: &WorkspaceMember) -> Result<(), anyhow::Error> { - group!("Doc Tests: {}", member.name); - info!("Command line: cargo test --doc -p {}", &member.name); - let status = Command::new("cargo") - .args(["test", "--doc", "-p", &member.name]) - .status() - .map_err(|e| anyhow!("Failed to execute documentation test: {}", e))?; - if !status.success() { - return Err(anyhow!( - "Failed to execute documentation test for {}", - &member.name - )); - } - endgroup!(); - Ok(()) -} - -pub(crate) fn run_integration(target: &Target) -> anyhow::Result<()> { - match target { - Target::Crates | Target::Examples => { - let members = match target { - Target::Crates => get_workspace_members(WorkspaceMemberType::Crate), - Target::Examples => get_workspace_members(WorkspaceMemberType::Example), - _ => unreachable!(), - }; - - for member in members { - run_integration_test(&member)?; - } - } - Target::All => { - Target::iter() - .filter(|t| *t != Target::All) - .try_for_each(|t| run_integration(&t))?; - } - } - Ok(()) +#[macros::extend_command_args(TestCmdArgs, Target, TestSubCommand)] +pub struct CubeCLTestCmdArgs { + /// Build in CI mode which excludes unsupported crates. 
+ #[arg(long)] + pub ci: bool, } -fn run_integration_test(member: &WorkspaceMember) -> Result<()> { - group!("Integration Tests: {}", &member.name); - info!( - "Command line: cargo test --test \"test_*\" -p {}", - &member.name - ); - let output = Command::new("cargo") - .args(["test", "--test", "test_*", "-p", &member.name]) - .output() - .map_err(|e| anyhow!("Failed to execute integration test: {}", e))?; - - if !output.status.success() { - let stderr = String::from_utf8_lossy(&output.stderr); - if stderr.contains("no test target matches pattern") { - warn!( - "No tests found matching the pattern `test_*` for {}", - &member.name - ); - endgroup!(); - return Ok(()); - } - return Err(anyhow!( - "Failed to execute integration test for {}: {}", - &member.name, - stderr - )); +pub(crate) fn handle_command(mut args: CubeCLTestCmdArgs) -> anyhow::Result<()> { + if args.ci { + // Exclude crates that are not supported on CI + args.exclude.extend(vec!["cubecl-cuda".to_string()]); } - endgroup!(); + base_commands::test::handle_command(args.try_into().unwrap())?; Ok(()) } diff --git a/xtask/src/logging.rs b/xtask/src/logging.rs deleted file mode 100644 index 672f229a6..000000000 --- a/xtask/src/logging.rs +++ /dev/null @@ -1,67 +0,0 @@ -use std::io::Write; - -/// Initialise and create a `env_logger::Builder` which follows the -/// GitHub Actions logging syntax when running on CI. 
-pub fn init_logger() -> env_logger::Builder { - let mut builder = env_logger::Builder::from_default_env(); - builder.target(env_logger::Target::Stdout); - - // Find and setup the correct log level - builder.filter(None, get_log_level()); - builder.write_style(env_logger::WriteStyle::Always); - - // Custom Formatter for Github Actions - if std::env::var("CI").is_ok() { - builder.format(|buf, record| match record.level().as_str() { - "DEBUG" => writeln!(buf, "::debug:: {}", record.args()), - "WARN" => writeln!(buf, "::warning:: {}", record.args()), - "ERROR" => { - writeln!(buf, "::error:: {}", record.args()) - } - _ => writeln!(buf, "{}", record.args()), - }); - } - - builder -} - -/// Determine the LogLevel for the logger -fn get_log_level() -> log::LevelFilter { - // DEBUG - match std::env::var("DEBUG") { - Ok(_value) => return log::LevelFilter::Debug, - Err(_err) => (), - } - // ACTIONS_RUNNER_DEBUG - match std::env::var("ACTIONS_RUNNER_DEBUG") { - Ok(_value) => return log::LevelFilter::Debug, - Err(_err) => (), - }; - - log::LevelFilter::Info -} - -/// Group Macro -#[macro_export] -macro_rules! group { - // group!() - ($($arg:tt)*) => { - let title = format!($($arg)*); - if std::env::var("CI").is_ok() { - log!(log::Level::Info, "::group::{}", title) - } else { - log!(log::Level::Info, "{}", title) - } - }; -} - -/// End Group Macro -#[macro_export] -macro_rules! 
endgroup { - // endgroup!() - () => { - if std::env::var("CI").is_ok() { - log!(log::Level::Info, "::endgroup::") - } - }; -} diff --git a/xtask/src/main.rs b/xtask/src/main.rs index 4f0347cdf..945661188 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -1,59 +1,42 @@ mod commands; -// mod dependencies; -mod logging; -// mod runchecks; -mod utils; -// mod vulnerabilities; - -use crate::{logging::init_logger, utils::time::format_duration}; -use clap::{Parser, Subcommand}; -use std::time::Instant; #[macro_use] extern crate log; -#[derive(Parser)] -#[command(author, version, about, long_about = None)] -struct XtaskArgs { - #[command(subcommand)] - command: Command, -} - -#[derive(Subcommand)] -enum Command { - /// Bump the version of all crates to be published - Bump(commands::bump::BumpCmdArgs), - /// Runs checks and fix issues (used for development purposes) - Check(commands::check::CheckCmdArgs), - /// Runs checks for Continous Integration - CI(commands::ci::CICmdArgs), - /// Publish a crate to crates.io - Publish(commands::publish::PublishCmdArgs), - /// Runs tests. - Test(commands::test::TestCmdArgs), - /// Runs all tests and checks that should pass before opening a Pull Request. - PullRequestChecks, +use std::time::Instant; +use tracel_xtask::prelude::*; + +#[macros::base_commands( + Bump, + Check, + Compile, + Coverage, + Doc, + Dependencies, + Fix, + Publish, + Validate, + Vulnerabilities +)] +pub enum Command { + /// Build Burn in different modes. + Build(commands::build::CubeCLBuildCmdArgs), + /// Test Burn. 
+ Test(commands::test::CubeCLTestCmdArgs), } fn main() -> anyhow::Result<()> { - init_logger().init(); - let args = XtaskArgs::parse(); - let start = Instant::now(); + let args = init_xtask::()?; match args.command { - Command::Bump(args) => commands::bump::handle_command(args), - Command::Check(args) => commands::check::handle_command(args, None), - Command::CI(args) => commands::ci::handle_command(args), - Command::Publish(args) => commands::publish::handle_command(args), - Command::Test(args) => commands::test::handle_command(args), - Command::PullRequestChecks => commands::pull_request_checks::handle_command(), + Command::Build(cmd_args) => commands::build::handle_command(cmd_args), + Command::Test(cmd_args) => commands::test::handle_command(cmd_args), + _ => dispatch_base_commands(args), }?; - let duration = start.elapsed(); info!( "\x1B[32;1mTime elapsed for the current execution: {}\x1B[0m", format_duration(&duration) ); - Ok(()) } diff --git a/xtask/src/test.rs b/xtask/src/test.rs deleted file mode 100644 index e69de29bb..000000000