Skip to content

Commit

Permalink
add comptime marker to line_size in naive
Browse files Browse the repository at this point in the history
  • Loading branch information
maxtremblay committed Nov 29, 2024
1 parent 6e37492 commit 8b98e53
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions crates/cubecl-reduce/src/naive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub trait ReduceNaiveInstruction<EI: Numeric>: Send + Sync + 'static {
/// This could be called many time during reduction. It is required
/// that reducing the initial accumulator any number of times do not change the outcome
/// of the reduction. For example, adding 0s in a sum do not change the outcome.
fn init_accumulator(line_size: u32) -> Self::Accumulator;
fn init_accumulator(#[comptime] ine_size: u32) -> Self::Accumulator;

/// Reduce `current_value` into `accumulator`.
fn accumulate(accumulator: &mut Self::Accumulator, item: Line<EI>, coordinate: u32);
Expand Down Expand Up @@ -75,7 +75,7 @@ pub fn reduce_naive<RD: ReduceNaiveInstruction<EI>, EI: Numeric, EO: Numeric>(
impl<EI: Numeric> ReduceNaiveInstruction<EI> for Sum {
type Accumulator = Line<EI>;

fn init_accumulator(line_size: u32) -> Line<EI> {
fn init_accumulator(#[comptime] line_size: u32) -> Line<EI> {
Line::empty(line_size).fill(EI::from_int(0))
}

Expand All @@ -97,7 +97,7 @@ impl<EI: Numeric> ReduceNaiveInstruction<EI> for Sum {
impl<EI: Numeric> ReduceNaiveInstruction<EI> for Prod {
type Accumulator = Line<EI>;

fn init_accumulator(line_size: u32) -> Line<EI> {
fn init_accumulator(#[comptime] line_size: u32) -> Line<EI> {
Line::empty(line_size).fill(EI::from_int(1))
}

Expand All @@ -119,7 +119,7 @@ impl<EI: Numeric> ReduceNaiveInstruction<EI> for Prod {
impl<EI: Numeric> ReduceNaiveInstruction<EI> for Mean {
type Accumulator = Line<EI>;

fn init_accumulator(line_size: u32) -> Self::Accumulator {
fn init_accumulator(#[comptime] line_size: u32) -> Self::Accumulator {
Line::empty(line_size).fill(EI::from_int(0))
}

Expand All @@ -143,7 +143,7 @@ impl<EI: Numeric> ReduceNaiveInstruction<EI> for Mean {
impl<EI: Numeric> ReduceNaiveInstruction<EI> for ArgMax {
type Accumulator = (Line<EI>, Line<u32>);

fn init_accumulator(line_size: u32) -> Self::Accumulator {
fn init_accumulator(#[comptime] line_size: u32) -> Self::Accumulator {
(
// TODO: switch to using f32::NEG_INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68
Line::empty(line_size).fill(EI::MIN),
Expand Down Expand Up @@ -178,7 +178,7 @@ impl<EI: Numeric> ReduceNaiveInstruction<EI> for ArgMax {
impl<EI: Numeric> ReduceNaiveInstruction<EI> for ArgMin {
type Accumulator = (Line<EI>, Line<u32>);

fn init_accumulator(line_size: u32) -> Self::Accumulator {
fn init_accumulator(#[comptime] line_size: u32) -> Self::Accumulator {
(
// TODO: switch to using f32::INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68
Line::empty(line_size).fill(EI::MAX),
Expand Down

0 comments on commit 8b98e53

Please sign in to comment.