Skip to content

Commit

Permalink
Merge branch 'fix-mutating-error' into 'main'
Browse files Browse the repository at this point in the history
Fix error raising for mixing constant + dynamic for-loops [GH-331]

See merge request omniverse/warp!821
  • Loading branch information
shi-eric committed Oct 28, 2024
2 parents 72e3e81 + e7b7c0b commit 0c60fa8
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
- Fix printing vector and matrix adjoints in backward kernels.
- Fix kernel compile error when printing structs.
- Fix an incorrect user function being sometimes resolved when multiple overloads are available with array parameters with different `dtype` values.
- Fix error being raised when static and dynamic for-loops are written in sequence with the same iteration variable names ([GH-331](https://github.com/NVIDIA/warp/issues/331)).

## [1.4.1] - 2024-10-15

Expand Down
21 changes: 5 additions & 16 deletions warp/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,7 +939,9 @@ def build(adj, builder, default_builder_options=None):

adj.return_var = None # return type for function or kernel
adj.loop_symbols = [] # symbols at the start of each loop
adj.loop_const_iter_symbols = [] # iteration variables (constant) for static loops
adj.loop_const_iter_symbols = (
set()
) # constant iteration variables for static loops (mutating them does not raise an error)

# blocks
adj.blocks = [Block()]
Expand Down Expand Up @@ -2000,22 +2002,11 @@ def get_unroll_range(adj, loop):
)
return range_call

def begin_record_constant_iter_symbols(adj):
if len(adj.loop_const_iter_symbols) > 0:
adj.loop_const_iter_symbols.append(adj.loop_const_iter_symbols[-1])
else:
adj.loop_const_iter_symbols.append(set())

def end_record_constant_iter_symbols(adj):
if len(adj.loop_const_iter_symbols) > 0:
adj.loop_const_iter_symbols.pop()

def record_constant_iter_symbol(adj, sym):
if len(adj.loop_const_iter_symbols) > 0:
adj.loop_const_iter_symbols[-1].add(sym)
adj.loop_const_iter_symbols.add(sym)

def is_constant_iter_symbol(adj, sym):
return len(adj.loop_const_iter_symbols) > 0 and sym in adj.loop_const_iter_symbols[-1]
return sym in adj.loop_const_iter_symbols

def emit_For(adj, node):
# try and unroll simple range() statements that use constant args
Expand Down Expand Up @@ -2045,7 +2036,6 @@ def emit_For(adj, node):
iter = adj.eval(node.iter)

adj.symbols[node.target.id] = adj.begin_for(iter)
adj.begin_record_constant_iter_symbols()

# for loops should be side-effect free, here we store a copy
adj.loop_symbols.append(adj.symbols.copy())
Expand All @@ -2056,7 +2046,6 @@ def emit_For(adj, node):

adj.materialize_redefinitions(adj.loop_symbols[-1])
adj.loop_symbols.pop()
adj.end_record_constant_iter_symbols()

adj.end_for(iter)

Expand Down
39 changes: 39 additions & 0 deletions warp/tests/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,45 @@ def mixed_dyn_static_loop_kernel(dyn_a: int, dyn_b: int, dyn_c: int, output: wp.
)
assert_np_equal(output.numpy(), np.ones([num_threads, const_a + const_b + dyn_a + dyn_b + dyn_c + 1]))

@wp.kernel
def static_then_dynamic_loop_kernel(mats: wp.array(dtype=wp.mat33d)):
tid = wp.tid()
mat = wp.mat33d()
for i in range(3):
for j in range(3):
mat[i, j] = wp.float64(0.0)

dim = 2
for i in range(dim + 1):
for j in range(dim + 1):
mat[i, j] = wp.float64(1.0)

mats[tid] = mat

mats = wp.empty(1, dtype=wp.mat33d, device=device)
wp.launch(static_then_dynamic_loop_kernel, dim=1, inputs=[mats], device=device)
assert_np_equal(mats.numpy(), np.ones((1, 3, 3)))

@wp.kernel
def dynamic_then_static_loop_kernel(mats: wp.array(dtype=wp.mat33d)):
tid = wp.tid()
mat = wp.mat33d()

dim = 2
for i in range(dim + 1):
for j in range(dim + 1):
mat[i, j] = wp.float64(1.0)

for i in range(3):
for j in range(3):
mat[i, j] = wp.float64(0.0)

mats[tid] = mat

mats = wp.empty(1, dtype=wp.mat33d, device=device)
wp.launch(dynamic_then_static_loop_kernel, dim=1, inputs=[mats], device=device)
assert_np_equal(mats.numpy(), np.zeros((1, 3, 3)))


@wp.kernel
def test_call_syntax():
Expand Down

0 comments on commit 0c60fa8

Please sign in to comment.