From ea5ab54a0d158ff184343a197d22241593dc7530 Mon Sep 17 00:00:00 2001 From: Eric Heiden Date: Mon, 7 Oct 2024 17:46:47 -0700 Subject: [PATCH] Fix codegen error when nesting dynamic and static for-loops --- CHANGELOG.md | 1 + warp/codegen.py | 25 +++++++++++++++++-------- warp/tests/test_codegen.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ea4b281f..3302df1d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ - Fix potential out-of-bounds memory access when a `wp.sparse.BsrMatrix` object is reused for storing matrices of different shapes - Fix robustness to very low desired tolerance in `wp.fem.utils.symmetric_eigenvalues_qr` +- Fix invalid code generation error messages when nesting dynamic and static for-loops ## [1.4.0] - 2024-10-01 diff --git a/warp/codegen.py b/warp/codegen.py index 9bb817ad..7db6ecc5 100644 --- a/warp/codegen.py +++ b/warp/codegen.py @@ -939,7 +939,7 @@ def build(adj, builder, default_builder_options=None): adj.return_var = None # return type for function or kernel adj.loop_symbols = [] # symbols at the start of each loop - adj.loop_const_iter_symbols = set() # iteration variables (constant) for static loops + adj.loop_const_iter_symbols = [] # iteration variables (constant) for static loops # blocks adj.blocks = [Block()] @@ -1849,7 +1849,7 @@ def materialize_redefinitions(adj, symbols): # detect symbols with conflicting definitions (assigned inside the for loop) for items in symbols.items(): sym = items[0] - if adj.loop_const_iter_symbols is not None and sym in adj.loop_const_iter_symbols: + if adj.is_constant_iter_symbol(sym): # ignore constant overwriting in for-loops if it is a loop iterator # (it is no problem to unroll static loops multiple times in sequence) continue @@ -2001,11 +2001,21 @@ def get_unroll_range(adj, loop): return range_call def begin_record_constant_iter_symbols(adj): - if adj.loop_const_iter_symbols is None: - adj.loop_const_iter_symbols = set() + if len(adj.loop_const_iter_symbols) > 0: + adj.loop_const_iter_symbols.append(adj.loop_const_iter_symbols[-1]) + else: + adj.loop_const_iter_symbols.append(set()) def end_record_constant_iter_symbols(adj): - adj.loop_const_iter_symbols = None + if len(adj.loop_const_iter_symbols) > 0: + adj.loop_const_iter_symbols.pop() + + def record_constant_iter_symbol(adj, sym): + if len(adj.loop_const_iter_symbols) > 0: + adj.loop_const_iter_symbols[-1].add(sym) + + def is_constant_iter_symbol(adj, sym): + return len(adj.loop_const_iter_symbols) > 0 and sym in adj.loop_const_iter_symbols[-1] def emit_For(adj, node): # try and unroll simple range() statements that use constant args @@ -2013,9 +2023,8 @@ def emit_For(adj, node): if isinstance(unroll_range, range): const_iter_sym = node.target.id - if adj.loop_const_iter_symbols is not None: - # prevent constant conflicts in `materialize_redefinitions()` - adj.loop_const_iter_symbols.add(const_iter_sym) + # prevent constant conflicts in `materialize_redefinitions()` + adj.record_constant_iter_symbol(const_iter_sym) # unroll static for-loop for i in unroll_range: diff --git a/warp/tests/test_codegen.py b/warp/tests/test_codegen.py index e3552ad2..6beb4d54 100644 --- a/warp/tests/test_codegen.py +++ b/warp/tests/test_codegen.py @@ -503,6 +503,37 @@ def dynamic_loop_kernel(n: int, input: wp.array(dtype=float)): ): wp.launch(dynamic_loop_kernel, dim=1, inputs=[3, inputs], device=device) + # the following nested loop must not raise an error + const_a = 7 + const_b = 5 + + @wp.kernel + def mixed_dyn_static_loop_kernel(dyn_a: int, dyn_b: int, dyn_c: int, output: wp.array(dtype=float, ndim=2)): + tid = wp.tid() + for i in range(const_a + 1): + for j in range(dyn_a + 1): + for k in range(dyn_b + 1): + for l in range(const_b + 1): + for m in range(dyn_c + 1): + coeff = i + j + k + l + m + output[tid, coeff] = 1.0 + + dyn_a, dyn_b, dyn_c = 3, 4, 5 + num_threads = 10 + output = wp.empty([num_threads, const_a + const_b + dyn_a + dyn_b + dyn_c + 1], dtype=float, device=device) + wp.launch( + mixed_dyn_static_loop_kernel, + num_threads, + inputs=[ + dyn_a, + dyn_b, + dyn_c, + ], + outputs=[output], + device=device, + ) + assert_np_equal(output.numpy(), np.ones([num_threads, const_a + const_b + dyn_a + dyn_b + dyn_c + 1])) + @wp.kernel def test_call_syntax():