Skip to content

Commit

Permalink
Optimize if-else
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge committed Sep 11, 2024
1 parent e9d29d6 commit 90b175c
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 38 deletions.
56 changes: 40 additions & 16 deletions crates/cubecl-core/src/frontend/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,33 +246,57 @@ pub fn if_expand(
}
}

pub enum IfElseExpand {
ComptimeThen,
ComptimeElse,
Runtime {
runtime_cond: ExpandElement,
then_child: CubeContext,
},
}

impl IfElseExpand {
pub fn or_else(self, context: &mut CubeContext, else_block: impl FnOnce(&mut CubeContext)) {
match self {
Self::Runtime {
runtime_cond,
then_child,
} => {
let mut else_child = context.child();
else_block(&mut else_child);

context.register(Branch::IfElse(IfElse {
cond: *runtime_cond,
scope_if: then_child.into_scope(),
scope_else: else_child.into_scope(),
}));
}
Self::ComptimeElse => else_block(context),
Self::ComptimeThen => (),
}
}
}

pub fn if_else_expand(
context: &mut CubeContext,
runtime_cond: ExpandElement,
then_block: impl FnOnce(&mut CubeContext),
else_block: impl FnOnce(&mut CubeContext),
) {
) -> IfElseExpand {
let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
match comptime_cond {
Some(cond) => {
if cond {
then_block(context);
} else {
else_block(context);
}
Some(true) => {
then_block(context);
IfElseExpand::ComptimeThen
}
Some(false) => IfElseExpand::ComptimeElse,
None => {
let mut then_child = context.child();
then_block(&mut then_child);

let mut else_child = context.child();
else_block(&mut else_child);

context.register(Branch::IfElse(IfElse {
cond: *runtime_cond,
scope_if: then_child.into_scope(),
scope_else: else_child.into_scope(),
}));
IfElseExpand::Runtime {
runtime_cond,
then_child,
}
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-core/tests/frontend/comptime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ mod tests {
cpa!(&mut scope, if(runtime_cond).then(|scope| {
cpa!(scope, y = x + 5.0f32);
}).else(|scope| {
cpa!(scope, y = x - 6.0f32);
cpa!(scope, x = x - 6.0f32);
}));
};

Expand All @@ -368,7 +368,7 @@ mod tests {
if comptime_cond {
cpa!(scope, y = x + 5.0f32);
} else {
cpa!(scope, y = x - 6.0f32);
cpa!(scope, x = x - 6.0f32);
}
}));

Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-core/tests/frontend/if.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ mod tests {
cpa!(&mut scope, if(cond).then(|scope| {
cpa!(scope, y = lhs + 4.0f32);
}).else(|scope|{
cpa!(scope, y = lhs - 5.0f32);
cpa!(scope, lhs = lhs - 5.0f32);
}));

format!("{:?}", scope.operations)
Expand All @@ -143,7 +143,7 @@ mod tests {
cpa!(&mut scope, if(cond2).then(|scope| {
cpa!(scope, y = lhs + 1.0f32);
}).else(|scope|{
cpa!(scope, y = lhs + 0.0f32);
cpa!(scope, lhs = lhs + 0.0f32);
}));
}));

Expand Down
15 changes: 10 additions & 5 deletions crates/cubecl-cuda/src/compiler/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,16 @@ impl CudaCompiler {
cond: self.compile_variable(op.cond),
instructions: self.compile_scope(&mut op.scope),
}),
gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
cond: self.compile_variable(op.cond),
instructions_if: self.compile_scope(&mut op.scope_if),
instructions_else: self.compile_scope(&mut op.scope_else),
}),
gpu::Branch::IfElse(mut op) => {
// Else is the latter branch and consumes variables, so compile that first to free
// variables for the then block. Rust doesn't guarantee struct init execution order.
let instructions_else = self.compile_scope(&mut op.scope_else);
instructions.push(Instruction::IfElse {
cond: self.compile_variable(op.cond),
instructions_if: self.compile_scope(&mut op.scope_if),
instructions_else,
});
}
gpu::Branch::Return => instructions.push(Instruction::Return),
gpu::Branch::Break => instructions.push(Instruction::Break),
gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
Expand Down
1 change: 0 additions & 1 deletion crates/cubecl-macros/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ pub enum Expression {
condition: Box<Expression>,
then_block: Block,
else_branch: Option<Box<Expression>>,
scope: Scope,
},
Return {
expr: Option<Box<Expression>>,
Expand Down
7 changes: 3 additions & 4 deletions crates/cubecl-macros/src/generate/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,16 +282,15 @@ impl Expression {
condition,
then_block,
else_branch: Some(else_branch),
scope,
} => {
let path = frontend_path();
let condition = condition.to_tokens(context);
let then_block = context.in_fn_mut(scope, |ctx| then_block.to_tokens(ctx));
let else_branch = context.in_fn_mut(scope, |ctx| else_branch.to_tokens(ctx));
let then_block = then_block.to_tokens(context);
let else_branch = else_branch.to_tokens(context);
quote! {
{
let _cond = #condition;
#path::branch::if_else_expand(context, _cond.into(), |context| #then_block, |context| #else_branch);
#path::branch::if_else_expand(context, _cond.into(), |context| #then_block).or_else(context, |context| #else_branch);
}
}
}
Expand Down
4 changes: 1 addition & 3 deletions crates/cubecl-macros/src/parse/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ pub fn expand_if(if_expr: ExprIf, context: &mut Context) -> syn::Result<Expressi
let condition = Expression::from_expr(*if_expr.cond, context)
.map_err(|_| syn::Error::new(span, "Unsupported while condition"))?;

let (then_block, scope) =
context.in_scope(|ctx| Block::from_block(if_expr.then_branch, ctx))?;
let (then_block, _) = context.in_scope(|ctx| Block::from_block(if_expr.then_branch, ctx))?;
let else_branch = if let Some((_, else_branch)) = if_expr.else_branch {
let (expr, _) = context.in_scope(|ctx| Expression::from_expr(*else_branch, ctx))?;
Some(Box::new(expr))
Expand All @@ -88,7 +87,6 @@ pub fn expand_if(if_expr: ExprIf, context: &mut Context) -> syn::Result<Expressi
condition: Box::new(condition),
then_block,
else_branch,
scope,
})
}

Expand Down
15 changes: 10 additions & 5 deletions crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,11 +361,16 @@ impl WgslCompiler {
cond: self.compile_variable(op.cond),
instructions: self.compile_scope(&mut op.scope),
}),
cube::Branch::IfElse(mut op) => instructions.push(wgsl::Instruction::IfElse {
cond: self.compile_variable(op.cond),
instructions_if: self.compile_scope(&mut op.scope_if),
instructions_else: self.compile_scope(&mut op.scope_else),
}),
cube::Branch::IfElse(mut op) => {
// Else is the latter branch and consumes variables, so compile that first to free
// variables for the then block. Rust doesn't guarantee struct init execution order.
let instructions_else = self.compile_scope(&mut op.scope_else);
instructions.push(wgsl::Instruction::IfElse {
cond: self.compile_variable(op.cond),
instructions_if: self.compile_scope(&mut op.scope_if),
instructions_else,
});
}
cube::Branch::Return => instructions.push(wgsl::Instruction::Return),
cube::Branch::Break => instructions.push(wgsl::Instruction::Break),
cube::Branch::RangeLoop(mut range_loop) => {
Expand Down

0 comments on commit 90b175c

Please sign in to comment.