Skip to content

Commit

Permalink
[spv-in] Bubble up loop breaks, out of switch cases.
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyb committed May 15, 2023
1 parent 65a791c commit 13c52a0
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 23 deletions.
178 changes: 156 additions & 22 deletions src/front/spv/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{

use super::{Error, Instruction, LookupExpression, LookupHelper as _};
use crate::front::Emitter;
use std::cell::Cell;

pub type BlockId = u32;

Expand Down Expand Up @@ -128,7 +129,7 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> {
expressions: &mut fun.expressions,
local_arena: &mut fun.local_variables,
const_arena: &mut module.constants,
type_arena: &module.types,
type_arena: &mut module.types,
global_arena: &module.global_variables,
arguments: &fun.arguments,
parameter_sampling: &mut parameters_sampling,
Expand Down Expand Up @@ -576,33 +577,54 @@ impl<'function> BlockContext<'function> {
}

/// Consumes the `BlockContext` producing a Ir [`Block`](crate::Block)
fn lower(mut self) -> crate::Block {
fn lower(self) -> crate::Block {
/// Smaller context type, with only the subset of `BlockContext`'s fields
/// that are needed for `lower_impl` below.
struct BlockLowerContext<'a, 'function> {
expressions: &'function mut Arena<crate::Expression>,
local_arena: &'function mut Arena<crate::LocalVariable>,
const_arena: &'function mut Arena<crate::Constant>,
type_arena: &'function mut crate::UniqueArena<crate::Type>,

blocks: crate::FastHashMap<spirv::Word, crate::Block>,
bodies: &'a [super::Body],
}

/// Helper type used for tracking the control-flow constructs which
/// support the [`Statement::Break`](crate::Statement::Break) syntax.
#[derive(Copy, Clone, Debug)]
enum Breakable {
enum Breakable<'a> {
Loop,
Switch,
Switch {
/// This `Cell<Option<...>>` is set to `Some(loop_break_cond_var_ptr)`
/// when a loop `break` is found nested in the switch `case`s,
/// and `loop_break_cond_var_ptr` is a pointer to a `bool` local,
/// which (dynamically) tracks whether the nested loop `break`
/// was actually reached (and so the `switch` gets to handle
/// `break`-ing out of the `loop`, after the `switch` itself).
bubbled_up_loop_break_cond_var_ptr: &'a Cell<Option<Handle<crate::Expression>>>,
},
}

fn lower_impl(
blocks: &mut crate::FastHashMap<spirv::Word, crate::Block>,
bodies: &[super::Body],
lower_ctx: &mut BlockLowerContext<'_, '_>,
body_idx: BodyIndex,
innermost_breakable: Option<Breakable>,
) -> crate::Block {
let mut block = crate::Block::new();

for item in bodies[body_idx].data.iter() {
for item in lower_ctx.bodies[body_idx].data.iter() {
match *item {
super::BodyFragment::BlockId(id) => block.append(blocks.get_mut(&id).unwrap()),
super::BodyFragment::BlockId(id) => {
block.append(lower_ctx.blocks.get_mut(&id).unwrap())
}
super::BodyFragment::If {
condition,
accept,
reject,
} => {
let accept = lower_impl(blocks, bodies, accept, innermost_breakable);
let reject = lower_impl(blocks, bodies, reject, innermost_breakable);
let accept = lower_impl(lower_ctx, accept, innermost_breakable);
let reject = lower_impl(lower_ctx, reject, innermost_breakable);

block.push(
crate::Statement::If {
Expand All @@ -618,12 +640,11 @@ impl<'function> BlockContext<'function> {
continuing,
break_if,
} => {
let body = lower_impl(blocks, bodies, body, Some(Breakable::Loop));
let body = lower_impl(lower_ctx, body, Some(Breakable::Loop));
// NOTE(eddyb) the `continuing {...}` block cannot `break`,
// but this is checked in the validator, and so it's allowed
// here (where we could only panic, which is worse UX).
let continuing =
lower_impl(blocks, bodies, continuing, Some(Breakable::Loop));
let continuing = lower_impl(lower_ctx, continuing, Some(Breakable::Loop));

block.push(
crate::Statement::Loop {
Expand All @@ -639,11 +660,17 @@ impl<'function> BlockContext<'function> {
ref cases,
default,
} => {
let bubbled_up_loop_break_cond_var_ptr = &Cell::new(None);
let mut ir_cases: Vec<_> = cases
.iter()
.map(|&(value, body_idx)| {
let body =
lower_impl(blocks, bodies, body_idx, Some(Breakable::Switch));
let body = lower_impl(
lower_ctx,
body_idx,
Some(Breakable::Switch {
bubbled_up_loop_break_cond_var_ptr,
}),
);

// Handle simple cases that would make a fallthrough statement unreachable code
let fall_through = body.last().map_or(true, |s| !s.is_terminator());
Expand All @@ -657,7 +684,13 @@ impl<'function> BlockContext<'function> {
.collect();
ir_cases.push(crate::SwitchCase {
value: crate::SwitchValue::Default,
body: lower_impl(blocks, bodies, default, Some(Breakable::Switch)),
body: lower_impl(
lower_ctx,
default,
Some(Breakable::Switch {
bubbled_up_loop_break_cond_var_ptr,
}),
),
fall_through: false,
});

Expand All @@ -667,27 +700,108 @@ impl<'function> BlockContext<'function> {
cases: ir_cases,
},
crate::Span::default(),
)
);

if let Some(loop_break_cond_var_ptr) =
bubbled_up_loop_break_cond_var_ptr.get()
{
match innermost_breakable.expect("stray loop `break`") {
Breakable::Loop => {}
Breakable::Switch { .. } => {
unreachable!(
"loop `break` from multiple levels of nested switch"
)
}
}

let mut emitter = Emitter::default();
emitter.start(lower_ctx.expressions);
let loop_break_cond = lower_ctx.expressions.append(
crate::Expression::Load {
pointer: loop_break_cond_var_ptr,
},
crate::Span::default(),
);
block.extend(emitter.finish(lower_ctx.expressions));

block.push(
crate::Statement::If {
condition: loop_break_cond,
accept: crate::Block::from_vec(vec![crate::Statement::Break]),
reject: crate::Block::new(),
},
crate::Span::default(),
);
}
}
super::BodyFragment::Continue => {
block.push(crate::Statement::Continue, crate::Span::default())
}
super::BodyFragment::LoopBreak => {
match innermost_breakable.expect("stray loop `break`") {
Breakable::Loop => {}
Breakable::Switch => {
unimplemented!("loop `break` from nested switch")
Breakable::Switch {
bubbled_up_loop_break_cond_var_ptr,
} => {
let bool_ty = lower_ctx.type_arena.insert(
crate::Type {
name: None,
inner: crate::TypeInner::Scalar {
kind: crate::ScalarKind::Bool,
width: crate::BOOL_WIDTH,
},
},
crate::Span::default(),
);
let mut bool_const = |value| {
lower_ctx.const_arena.fetch_or_append(
crate::Constant {
name: None,
specialization: None,
inner: crate::ConstantInner::boolean(value),
},
crate::Span::default(),
)
};
if bubbled_up_loop_break_cond_var_ptr.get().is_none() {
let local = lower_ctx.local_arena.append(
crate::LocalVariable {
name: None,
ty: bool_ty,
init: Some(bool_const(false)),
},
crate::Span::default(),
);
let local_ptr = lower_ctx.expressions.append(
crate::Expression::LocalVariable(local),
crate::Span::default(),
);
bubbled_up_loop_break_cond_var_ptr.set(Some(local_ptr));
}

// Store a `true` in the local variable that the
// parent `switch` will read and use to decide
// whether to `break` out of its parent `loop`.
block.push(
crate::Statement::Store {
pointer: bubbled_up_loop_break_cond_var_ptr.get().unwrap(),
value: lower_ctx.expressions.append(
crate::Expression::Constant(bool_const(true)),
crate::Span::default(),
),
},
crate::Span::default(),
);
}
}
assert!(matches!(innermost_breakable.unwrap(), Breakable::Loop), "");
block.push(crate::Statement::Break, crate::Span::default())
}
super::BodyFragment::SwitchBreak => {
match innermost_breakable.expect("stray switch `break`") {
Breakable::Loop => {
unreachable!("switch `break` from nested loop")
}
Breakable::Switch => {}
Breakable::Switch { .. } => {}
}
block.push(crate::Statement::Break, crate::Span::default())
}
Expand All @@ -697,6 +811,26 @@ impl<'function> BlockContext<'function> {
block
}

lower_impl(&mut self.blocks, &self.bodies, 0, None)
let Self {
expressions,
local_arena,
const_arena,
type_arena,
blocks,
ref bodies,
..
} = self;
lower_impl(
&mut BlockLowerContext {
expressions,
local_arena,
const_arena,
type_arena,
blocks,
bodies,
},
0,
None,
)
}
}
2 changes: 1 addition & 1 deletion src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ struct BlockContext<'function> {
/// Constants arena of the module being processed
const_arena: &'function mut Arena<crate::Constant>,
/// Type arena of the module being processed
type_arena: &'function UniqueArena<crate::Type>,
type_arena: &'function mut UniqueArena<crate::Type>,
/// Global arena of the module being processed
global_arena: &'function Arena<crate::GlobalVariable>,
/// Arguments of the function currently being processed
Expand Down
6 changes: 6 additions & 0 deletions tests/out/glsl/loop-break-from-switch.main.Fragment.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,23 @@ layout(location = 0) flat in int _vs2fs_location0;
layout(location = 0) out int _fs2p_location0;

void function() {
bool local = false;
int _e8 = global;
while(true) {
switch(_e8) {
case 0: {
global_1 = 0;
local = true;
break;
}
default: {
break;
}
}
bool _e11 = local;
if (_e11) {
break;
}
global_1 = -9;
break;
}
Expand Down
7 changes: 7 additions & 0 deletions tests/out/hlsl/loop-break-from-switch.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,24 @@ struct FragmentInput_main {

void function()
{
bool local = false;

int _expr8 = global;
while(true) {
switch(_expr8) {
case 0: {
global_1 = 0;
local = true;
break;
}
default: {
break;
}
}
bool _expr11 = local;
if (_expr11) {
break;
}
global_1 = -9;
break;
}
Expand Down
6 changes: 6 additions & 0 deletions tests/out/msl/loop-break-from-switch.msl
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,23 @@ void function(
thread int& global,
thread int& global_1
) {
bool local = false;
int _e8 = global;
while(true) {
switch(_e8) {
case 0: {
global_1 = 0;
local = true;
break;
}
default: {
break;
}
}
bool _e11 = local;
if (_e11) {
break;
}
global_1 = -9;
break;
}
Expand Down
7 changes: 7 additions & 0 deletions tests/out/wgsl/loop-break-from-switch.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,24 @@ var<private> global: i32;
var<private> global_1: i32;

fn function() {
var local: bool = false;

let _e8 = global;
loop {
switch _e8 {
case 0: {
global_1 = 0;
local = true;
break;
}
default: {
break;
}
}
let _e11 = local;
if _e11 {
break;
}
global_1 = -9;
break;
}
Expand Down

0 comments on commit 13c52a0

Please sign in to comment.