Skip to content

Commit

Permalink
Stepped range (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge authored Aug 18, 2024
1 parent b09821d commit f245f4b
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 9 deletions.
85 changes: 85 additions & 0 deletions crates/cubecl-core/src/frontend/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ use crate::ir::{Branch, Elem, If, IfElse, Item, Loop, RangeLoop, Variable};
use super::comptime::Comptime;
use super::ExpandElementTyped;

/// UInt range. Equivalent to:
/// ```no_run
/// for i in start..end { ... }
/// ```
pub fn range<S, E>(start: S, end: E, _unroll: Comptime<bool>) -> impl Iterator<Item = UInt>
where
S: Into<UInt>,
Expand All @@ -17,6 +21,30 @@ where
(start.val..end.val).map(UInt::new)
}

/// Stepped range. Equivalent to:
/// ```no_run
/// for i in (start..end).step_by(step) { ... }
/// ```
pub fn range_stepped<S, E, Step>(
start: S,
end: E,
step: Step,
_unroll: Comptime<bool>,
) -> impl Iterator<Item = UInt>
where
S: Into<UInt>,
E: Into<UInt>,
Step: Into<UInt>,
{
let start: UInt = start.into();
let end: UInt = end.into();
let step: UInt = step.into();

(start.val..end.val)
.step_by(step.val as usize)
.map(UInt::new)
}

pub fn range_expand<F, S, E>(context: &mut CubeContext, start: S, end: E, unroll: bool, mut func: F)
where
F: FnMut(&mut CubeContext, ExpandElementTyped<UInt>),
Expand Down Expand Up @@ -54,6 +82,63 @@ where
i: *i,
start: *start,
end: *end,
step: None,
scope: child.into_scope(),
}));
}
}

pub fn range_stepped_expand<F, S, E, Step>(
context: &mut CubeContext,
start: S,
end: E,
step: Step,
unroll: bool,
mut func: F,
) where
F: FnMut(&mut CubeContext, ExpandElementTyped<UInt>),
S: Into<ExpandElementTyped<UInt>>,
E: Into<ExpandElementTyped<UInt>>,
Step: Into<ExpandElementTyped<UInt>>,
{
let start: ExpandElementTyped<UInt> = start.into();
let end: ExpandElementTyped<UInt> = end.into();
let step: ExpandElementTyped<UInt> = step.into();
let start = start.expand;
let end = end.expand;
let step = step.expand;

if unroll {
let start = match start.deref() {
Variable::ConstantScalar(value) => value.as_usize(),
_ => panic!("Only constant start can be unrolled."),
};
let end = match end.deref() {
Variable::ConstantScalar(value) => value.as_usize(),
_ => panic!("Only constant end can be unrolled."),
};
let step: usize = match step.deref() {
Variable::ConstantScalar(value) => value.as_usize(),
_ => panic!("Only constant step can be unrolled."),
};

for i in (start..end).step_by(step) {
let var: ExpandElement = i.into();
func(context, var.into())
}
} else {
let mut child = context.child();
let index_ty = Item::new(Elem::UInt);
let i = child.scope.borrow_mut().create_local_undeclared(index_ty);
let i = ExpandElement::Plain(i);

func(&mut child, i.clone().into());

context.register(Branch::RangeLoop(RangeLoop {
i: *i,
start: *start,
end: *end,
step: Some(*step),
scope: child.into_scope(),
}));
}
Expand Down
21 changes: 18 additions & 3 deletions crates/cubecl-core/src/ir/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pub struct RangeLoop {
pub i: Variable,
pub start: Variable,
pub end: Variable,
pub step: Option<Variable>,
pub scope: Scope,
}

Expand Down Expand Up @@ -91,6 +92,7 @@ impl RangeLoop {
parent_scope: &mut Scope,
start: Variable,
end: Variable,
step: Option<Variable>,
func: F,
) {
let mut scope = parent_scope.child();
Expand All @@ -103,6 +105,7 @@ impl RangeLoop {
i,
start,
end,
step,
scope,
}));
}
Expand All @@ -125,9 +128,21 @@ pub struct UnrolledRangeLoop;

impl UnrolledRangeLoop {
/// Registers an unrolled range loop to the given scope.
pub fn register<F: Fn(Variable, &mut Scope)>(scope: &mut Scope, start: u32, end: u32, func: F) {
for i in start..end {
func(i.into(), scope);
pub fn register<F: Fn(Variable, &mut Scope)>(
scope: &mut Scope,
start: u32,
end: u32,
step: Option<u32>,
func: F,
) {
if let Some(step) = step {
for i in (start..end).step_by(step as usize) {
func(i.into(), scope);
}
} else {
for i in start..end {
func(i.into(), scope);
}
}
}
}
18 changes: 15 additions & 3 deletions crates/cubecl-core/src/ir/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,16 +361,28 @@ macro_rules! cpa {
};
// range(start, end).for_each(|i, scope| { ... })
($scope:expr, range($start:expr, $end:expr).for_each($arg:expr)) => {
$crate::ir::RangeLoop::register($scope, $start.into(), $end.into(), $arg);
$crate::ir::RangeLoop::register($scope, $start.into(), $end.into(), None, $arg);
};
// range(start, end, unroll).for_each(|i, scope| { ... })
($scope:expr, range($start:expr, $end:expr, $unroll:expr).for_each($arg:expr)) => {
if $unroll {
$crate::ir::UnrolledRangeLoop::register($scope, $start.into(), $end.into(), $arg);
$crate::ir::UnrolledRangeLoop::register($scope, $start.into(), $end.into(), None, $arg);
} else {
$crate::ir::RangeLoop::register($scope, $start.into(), $end.into(), $arg);
$crate::ir::RangeLoop::register($scope, $start.into(), $end.into(), None, $arg);
}
};
// range_stepped(start, end, step).for_each(|i, scope| { ... })
($scope:expr, range($start:expr, $end:expr, $step:expr).for_each($arg:expr)) => {
$crate::ir::RangeLoop::register($scope, $start.into(), $end.into(), Some($step), $arg);
};
// range_stepped(start, end, step, unroll).for_each(|i, scope| { ... })
($scope:expr, range($start:expr, $end:expr, $step:expr, $unroll:expr).for_each($arg:expr)) => {
if $unroll {
$crate::ir::UnrolledRangeLoop::register($scope, $start.into(), $end.into(), Some($step), $arg);
} else {
$crate::ir::RangeLoop::register($scope, $start.into(), $end.into(), Some($step), $arg);
}
};
// loop(|scope| { ... })
($scope:expr, loop($arg:expr)) => {
$crate::ir::Loop::register($scope, $arg);
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-cuda/src/compiler/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ impl CudaCompiler {
i: self.compile_variable(range_loop.i),
start: self.compile_variable(range_loop.start),
end: self.compile_variable(range_loop.end),
step: range_loop.step.map(|it| self.compile_variable(it)),
instructions: self.compile_scope(&mut range_loop.scope),
}),
gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
Expand Down
8 changes: 7 additions & 1 deletion crates/cubecl-cuda/src/compiler/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ pub enum Instruction {
i: Variable,
start: Variable,
end: Variable,
step: Option<Variable>,
instructions: Vec<Self>,
},
Loop {
Expand Down Expand Up @@ -181,11 +182,16 @@ impl Display for Instruction {
i,
start,
end,
step,
instructions,
} => {
let increment = step
.map(|step| format!("{i} += {step}"))
.unwrap_or_else(|| format!("++{i}"));

f.write_fmt(format_args!(
"
for (uint {i} = {start}; {i} < {end}; {i}++) {{
for (uint {i} = {start}; {i} < {end}; {increment}) {{
"
))?;
for instruction in instructions {
Expand Down
43 changes: 42 additions & 1 deletion crates/cubecl-macros/src/codegen_function/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@ use super::{

/// Codegen of for loops
/// Supports range:
/// ```norun
/// for i in range(start, end, unroll) {...}
/// ```
/// and range_stepped:
/// ```norun
/// for i in range_stepped(start, end, step, unroll) {...}
/// ```
pub(crate) fn codegen_for_loop(
for_loop: &syn::ExprForLoop,
loop_level: usize,
Expand All @@ -30,7 +36,7 @@ pub(crate) fn codegen_for_loop(
let invalid_for_loop = || {
syn::Error::new_spanned(
&for_loop.expr,
"Invalid for loop: use [range](cubecl::prelude::range] instead.",
"Invalid for loop: use [range](cubecl::prelude::range] or [range_stepped](cubecl::prelude::range_stepped) instead.",
)
.into_compile_error()
};
Expand Down Expand Up @@ -76,6 +82,41 @@ pub(crate) fn codegen_for_loop(
cubecl::frontend::branch::range_expand(context, _start, _end, _unroll, |context, #i| #block);
}
}
} else if &func_name.to_string() == "range_stepped" {
let mut args = call.args.clone();

let unroll = codegen_expr(
&args.pop().unwrap().into_value(),
loop_level,
variable_tracker,
);
let step = codegen_expr(
&args.pop().unwrap().into_value(),
loop_level,
variable_tracker,
);
let end = codegen_expr(
&args.pop().unwrap().into_value(),
loop_level,
variable_tracker,
);
let start = codegen_expr(
&args.pop().unwrap().into_value(),
loop_level,
variable_tracker,
);

let block = codegen_block(&for_loop.body, loop_level + 1, variable_tracker);

quote::quote! {
{
let _start = #start;
let _end = #end;
let _step = #step;
let _unroll = #unroll;
cubecl::frontend::branch::range_stepped_expand(context, _start, _end, _step, _unroll, |context, #i| #block);
}
}
} else {
invalid_for_loop()
}
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ impl WgslCompiler {
i: self.compile_variable(range_loop.i),
start: self.compile_variable(range_loop.start),
end: self.compile_variable(range_loop.end),
step: range_loop.step.map(|it| self.compile_variable(it)),
instructions: self.compile_scope(&mut range_loop.scope),
})
}
Expand Down
9 changes: 8 additions & 1 deletion crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ pub enum Instruction {
i: Variable,
start: Variable,
end: Variable,
step: Option<Variable>,
instructions: Vec<Instruction>,
},
And {
Expand Down Expand Up @@ -515,11 +516,17 @@ impl Display for Instruction {
i,
start,
end,
step,
instructions,
} => {
let increment = step
.as_ref()
.map(|step| format!("{i} += {step}"))
.unwrap_or_else(|| format!("{i}++"));

f.write_fmt(format_args!(
"
for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{
for (var {i}: u32 = {start}; {i} < {end}; {increment}) {{
"
))?;
for instruction in instructions {
Expand Down

0 comments on commit f245f4b

Please sign in to comment.