Skip to content

Commit

Permalink
Implement inclusive ranges
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge committed Sep 9, 2024
1 parent 02bc447 commit 6a94008
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 2 deletions.
3 changes: 3 additions & 0 deletions crates/cubecl-core/src/ir/processing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ impl ScopeProcessing {
Branch::RangeLoop(op) => {
sanitize_constant_scalar_ref_elem(&mut op.start, Elem::UInt);
sanitize_constant_scalar_ref_elem(&mut op.end, Elem::UInt);
if let Some(step) = &mut op.step {
sanitize_constant_scalar_ref_elem(step, Elem::UInt);
}
}
_ => {
// Nothing to do.
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-cuda/src/compiler/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ impl CudaCompiler {
start: self.compile_variable(range_loop.start),
end: self.compile_variable(range_loop.end),
step: range_loop.step.map(|it| self.compile_variable(it)),
inclusive: range_loop.inclusive,
instructions: self.compile_scope(&mut range_loop.scope),
}),
gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
Expand Down
5 changes: 4 additions & 1 deletion crates/cubecl-cuda/src/compiler/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ pub enum Instruction {
start: Variable,
end: Variable,
step: Option<Variable>,
inclusive: bool,
instructions: Vec<Self>,
},
Loop {
Expand Down Expand Up @@ -188,15 +189,17 @@ impl Display for Instruction {
start,
end,
step,
inclusive,
instructions,
} => {
let increment = step
.map(|step| format!("{i} += {step}"))
.unwrap_or_else(|| format!("++{i}"));
let cmp = if *inclusive { "<=" } else { "<" };

f.write_fmt(format_args!(
"
for (uint {i} = {start}; {i} < {end}; {increment}) {{
for (uint {i} = {start}; {i} {cmp} {end}; {increment}) {{
"
))?;
for instruction in instructions {
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ impl WgslCompiler {
start: self.compile_variable(range_loop.start),
end: self.compile_variable(range_loop.end),
step: range_loop.step.map(|it| self.compile_variable(it)),
inclusive: range_loop.inclusive,
instructions: self.compile_scope(&mut range_loop.scope),
})
}
Expand Down
5 changes: 4 additions & 1 deletion crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ pub enum Instruction {
start: Variable,
end: Variable,
step: Option<Variable>,
inclusive: bool,
instructions: Vec<Instruction>,
},
And {
Expand Down Expand Up @@ -531,16 +532,18 @@ impl Display for Instruction {
start,
end,
step,
inclusive,
instructions,
} => {
let increment = step
.as_ref()
.map(|step| format!("{i} += {step}"))
.unwrap_or_else(|| format!("{i}++"));
let cmp = if *inclusive { "<=" } else { "<" };

f.write_fmt(format_args!(
"
for (var {i}: u32 = {start}; {i} < {end}; {increment}) {{
for (var {i}: u32 = {start}; {i} {cmp} {end}; {increment}) {{
"
))?;
for instruction in instructions {
Expand Down

0 comments on commit 6a94008

Please sign in to comment.