Skip to content

Commit

Permalink
Fix codegen error when nesting dynamic and static for-loops
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-heiden committed Oct 8, 2024
1 parent 676edf8 commit ea5ab54
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

- Fix potential out-of-bounds memory access when a `wp.sparse.BsrMatrix` object is reused for storing matrices of different shapes
- Fix robustness to very low desired tolerance in `wp.fem.utils.symmetric_eigenvalues_qr`
- Fix invalid code generation error messages when nesting dynamic and static for-loops

## [1.4.0] - 2024-10-01

Expand Down
25 changes: 17 additions & 8 deletions warp/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,7 +939,7 @@ def build(adj, builder, default_builder_options=None):

adj.return_var = None # return type for function or kernel
adj.loop_symbols = [] # symbols at the start of each loop
adj.loop_const_iter_symbols = set() # iteration variables (constant) for static loops
adj.loop_const_iter_symbols = [] # iteration variables (constant) for static loops

# blocks
adj.blocks = [Block()]
Expand Down Expand Up @@ -1849,7 +1849,7 @@ def materialize_redefinitions(adj, symbols):
# detect symbols with conflicting definitions (assigned inside the for loop)
for items in symbols.items():
sym = items[0]
if adj.loop_const_iter_symbols is not None and sym in adj.loop_const_iter_symbols:
if adj.is_constant_iter_symbol(sym):
# ignore constant overwriting in for-loops if it is a loop iterator
# (it is no problem to unroll static loops multiple times in sequence)
continue
Expand Down Expand Up @@ -2001,21 +2001,30 @@ def get_unroll_range(adj, loop):
return range_call

# Open a recording scope for the iteration variables of static (unrolled) for-loops
# by pushing one entry onto the `adj.loop_const_iter_symbols` stack.
# NOTE(review): this is a diff rendering with the +/- markers and the indentation
# stripped. The `is None` check and the `set()` assignment directly below look like
# the removed pre-change lines, and the list-based push after them like their
# replacement -- confirm against the repository before relying on either form.
def begin_record_constant_iter_symbols(adj):
if adj.loop_const_iter_symbols is None:
adj.loop_const_iter_symbols = set()
if len(adj.loop_const_iter_symbols) > 0:
# nested scope: push the current top entry again. This is the same set object,
# not a copy, so symbols recorded in the nested scope are also visible through
# the enclosing entry after the nested scope has been popped.
adj.loop_const_iter_symbols.append(adj.loop_const_iter_symbols[-1])
else:
# outermost scope: start from an empty set
adj.loop_const_iter_symbols.append(set())

# Close the innermost recording scope by popping the top entry of the
# `adj.loop_const_iter_symbols` stack; does nothing when the stack is already empty.
# NOTE(review): this is a diff rendering with the +/- markers and the indentation
# stripped. The `= None` assignment directly below looks like the removed pre-change
# line, and the guarded `pop()` after it like its replacement -- confirm against the
# repository before relying on either form.
def end_record_constant_iter_symbols(adj):
adj.loop_const_iter_symbols = None
if len(adj.loop_const_iter_symbols) > 0:
adj.loop_const_iter_symbols.pop()

# Register `sym` as a constant loop-iteration symbol in the innermost recording scope.
#   adj: object holding the `loop_const_iter_symbols` stack (a list of sets)
#   sym: symbol to record; `emit_For` passes `node.target.id` of a loop it unrolls
def record_constant_iter_symbol(adj, sym):
# with an empty stack there is no open scope to record into, so this is a no-op
if len(adj.loop_const_iter_symbols) > 0:
adj.loop_const_iter_symbols[-1].add(sym)

# Return True when `sym` is a member of the set on top of the
# `adj.loop_const_iter_symbols` stack, and False when it is not or when the stack is
# empty. `materialize_redefinitions()` uses this check to skip such symbols while
# detecting conflicting definitions.
def is_constant_iter_symbol(adj, sym):
return len(adj.loop_const_iter_symbols) > 0 and sym in adj.loop_const_iter_symbols[-1]

def emit_For(adj, node):
# try and unroll simple range() statements that use constant args
unroll_range = adj.get_unroll_range(node)

if isinstance(unroll_range, range):
const_iter_sym = node.target.id
if adj.loop_const_iter_symbols is not None:
# prevent constant conflicts in `materialize_redefinitions()`
adj.loop_const_iter_symbols.add(const_iter_sym)
# prevent constant conflicts in `materialize_redefinitions()`
adj.record_constant_iter_symbol(const_iter_sym)

# unroll static for-loop
for i in unroll_range:
Expand Down
31 changes: 31 additions & 0 deletions warp/tests/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,37 @@ def dynamic_loop_kernel(n: int, input: wp.array(dtype=float)):
):
wp.launch(dynamic_loop_kernel, dim=1, inputs=[3, inputs], device=device)

# the following nested loop must not raise an error
const_a = 7
const_b = 5

# Kernel with five nested for-loops whose bounds alternate between values defined
# outside the kernel (`const_a`, `const_b`) and kernel arguments (`dyn_a`, `dyn_b`,
# `dyn_c`). Each thread writes 1.0 to `output[tid, coeff]` for every sum of the five
# loop indices.
@wp.kernel
def mixed_dyn_static_loop_kernel(dyn_a: int, dyn_b: int, dyn_c: int, output: wp.array(dtype=float, ndim=2)):
tid = wp.tid()
# bound taken from const_a, defined outside the kernel
for i in range(const_a + 1):
# bounds taken from kernel arguments
for j in range(dyn_a + 1):
for k in range(dyn_b + 1):
# bound taken from const_b, defined outside the kernel
for l in range(const_b + 1):
# bound taken from a kernel argument
for m in range(dyn_c + 1):
# the sum spans 0 .. const_a + dyn_a + dyn_b + const_b + dyn_c inclusive
coeff = i + j + k + l + m
output[tid, coeff] = 1.0

dyn_a, dyn_b, dyn_c = 3, 4, 5
num_threads = 10
output = wp.empty([num_threads, const_a + const_b + dyn_a + dyn_b + dyn_c + 1], dtype=float, device=device)
wp.launch(
mixed_dyn_static_loop_kernel,
num_threads,
inputs=[
dyn_a,
dyn_b,
dyn_c,
],
outputs=[output],
device=device,
)
assert_np_equal(output.numpy(), np.ones([num_threads, const_a + const_b + dyn_a + dyn_b + dyn_c + 1]))


@wp.kernel
def test_call_syntax():
Expand Down

0 comments on commit ea5ab54

Please sign in to comment.