From ccb19163476e83cdacf224180c7a32642b5f1224 Mon Sep 17 00:00:00 2001 From: Eric Heiden Date: Mon, 28 Oct 2024 20:22:01 +0100 Subject: [PATCH 1/3] Fix error raising for mixing constant + dynamic for-loops [GH-331] --- CHANGELOG.md | 1 + warp/codegen.py | 19 +++--------------- warp/tests/test_codegen.py | 40 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ad15536b..24987e0d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ - Fix printing vector and matrix adjoints in backward kernels. - Fix kernel compile error when printing structs. - Fix an incorrect user function being sometimes resolved when multiple overloads are available with array parameters with different `dtype` values. +- Fix error being raised when static and dynamic for-loops are written in sequence with the same iteration variable names ([GH-331](https://github.com/NVIDIA/warp/issues/331)). ## [1.4.1] - 2024-10-15 diff --git a/warp/codegen.py b/warp/codegen.py index 7db6ecc5..c4f05920 100644 --- a/warp/codegen.py +++ b/warp/codegen.py @@ -939,7 +939,7 @@ def build(adj, builder, default_builder_options=None): adj.return_var = None # return type for function or kernel adj.loop_symbols = [] # symbols at the start of each loop - adj.loop_const_iter_symbols = [] # iteration variables (constant) for static loops + adj.loop_const_iter_symbols = set() # constant iteration variables for static loops (mutating them does not raise an error) # blocks adj.blocks = [Block()] @@ -2000,22 +2000,11 @@ def get_unroll_range(adj, loop): ) return range_call - def begin_record_constant_iter_symbols(adj): - if len(adj.loop_const_iter_symbols) > 0: - adj.loop_const_iter_symbols.append(adj.loop_const_iter_symbols[-1]) - else: - adj.loop_const_iter_symbols.append(set()) - - def end_record_constant_iter_symbols(adj): - if len(adj.loop_const_iter_symbols) > 0: - adj.loop_const_iter_symbols.pop() - def record_constant_iter_symbol(adj, sym): - if len(adj.loop_const_iter_symbols) > 0: - adj.loop_const_iter_symbols[-1].add(sym) + adj.loop_const_iter_symbols.add(sym) def is_constant_iter_symbol(adj, sym): - return len(adj.loop_const_iter_symbols) > 0 and sym in adj.loop_const_iter_symbols[-1] + return sym in adj.loop_const_iter_symbols def emit_For(adj, node): # try and unroll simple range() statements that use constant args @@ -2045,7 +2034,6 @@ def emit_For(adj, node): iter = adj.eval(node.iter) adj.symbols[node.target.id] = adj.begin_for(iter) - adj.begin_record_constant_iter_symbols() # for loops should be side-effect free, here we store a copy adj.loop_symbols.append(adj.symbols.copy()) @@ -2056,7 +2044,6 @@ def emit_For(adj, node): adj.materialize_redefinitions(adj.loop_symbols[-1]) adj.loop_symbols.pop() - adj.end_record_constant_iter_symbols() adj.end_for(iter) diff --git a/warp/tests/test_codegen.py b/warp/tests/test_codegen.py index 6beb4d54..fd775178 100644 --- a/warp/tests/test_codegen.py +++ b/warp/tests/test_codegen.py @@ -534,6 +534,46 @@ def mixed_dyn_static_loop_kernel(dyn_a: int, dyn_b: int, dyn_c: int, output: wp. ) assert_np_equal(output.numpy(), np.ones([num_threads, const_a + const_b + dyn_a + dyn_b + dyn_c + 1])) + @wp.kernel + def static_then_dynamic_loop_kernel(mats: wp.array(dtype=wp.mat33d)): + tid = wp.tid() + mat = wp.mat33d() + for i in range(3): + for j in range(3): + mat[i, j] = wp.float64(0.0) + + dim = 2 + for i in range(dim + 1): + for j in range(dim + 1): + mat[i, j] = wp.float64(1.0) + + mats[tid] = mat + + mats = wp.empty(1, dtype=wp.mat33d, device=device) + wp.launch(static_then_dynamic_loop_kernel, dim=1, inputs=[mats], device=device) + assert_np_equal(mats.numpy(), np.ones((1, 3, 3))) + + @wp.kernel + def dynamic_then_static_loop_kernel(mats: wp.array(dtype=wp.mat33d)): + tid = wp.tid() + mat = wp.mat33d() + + dim = 2 + for i in range(dim + 1): + for j in range(dim + 1): + mat[i, j] = wp.float64(1.0) + + for i in range(3): + for j in range(3): + mat[i, j] = wp.float64(0.0) + + mats[tid] = mat + + mats = wp.empty(1, dtype=wp.mat33d, device=device) + wp.launch(dynamic_then_static_loop_kernel, dim=1, inputs=[mats], device=device) + assert_np_equal(mats.numpy(), np.zeros((1, 3, 3))) + + @wp.kernel def test_call_syntax(): From 48e695a5a7a3738ee8ff654accfa5cf8393ad8c1 Mon Sep 17 00:00:00 2001 From: Eric Shi Date: Mon, 28 Oct 2024 12:56:03 -0700 Subject: [PATCH 2/3] Fix Ruff issue --- warp/codegen.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/warp/codegen.py b/warp/codegen.py index c4f05920..79c8cd08 100644 --- a/warp/codegen.py +++ b/warp/codegen.py @@ -939,7 +939,9 @@ def build(adj, builder, default_builder_options=None): adj.return_var = None # return type for function or kernel adj.loop_symbols = [] # symbols at the start of each loop - adj.loop_const_iter_symbols = set() # constant iteration variables for static loops (mutating them does not raise an error) + adj.loop_const_iter_symbols = ( + set() + ) # constant iteration variables for static loops (mutating them does not raise an error) # blocks adj.blocks = [Block()] From e7b7c0b892b27b01bef2f0a49cf1153e2981b884 Mon Sep 17 00:00:00 2001 From: Eric Shi Date: Mon, 28 Oct 2024 12:57:03 -0700 Subject: [PATCH 3/3] Fix Ruff issue --- warp/tests/test_codegen.py | 1 - 1 file changed, 1 deletion(-) diff --git a/warp/tests/test_codegen.py b/warp/tests/test_codegen.py index fd775178..db0bdee7 100644 --- a/warp/tests/test_codegen.py +++ b/warp/tests/test_codegen.py @@ -574,7 +574,6 @@ def dynamic_then_static_loop_kernel(mats: wp.array(dtype=wp.mat33d)): assert_np_equal(mats.numpy(), np.zeros((1, 3, 3))) - @wp.kernel def test_call_syntax(): expected_pow = 16.0