From a9117c989a36404ec9cd59ef469d96cbba587ed1 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Thu, 10 Oct 2024 18:08:18 -0400 Subject: [PATCH] Added test for function assignment to a variable inside kernel --- warp/tests/test_codegen.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/warp/tests/test_codegen.py b/warp/tests/test_codegen.py index 6beb4d54..e052ef65 100644 --- a/warp/tests/test_codegen.py +++ b/warp/tests/test_codegen.py @@ -576,6 +576,23 @@ def test_while_condition_eval(): while it.valid: it.valid = False +def test_function_assignment(test, device): + + @wp.func + def multiply_by_two(a: float): + return a * 2.0 + + @wp.kernel + def kernel_function_assignment(input: wp.array(dtype=wp.float32), output: wp.array(dtype=wp.float32)): + tid = wp.tid() + func = multiply_by_two # Assign function to variable + output[tid] = func(input[tid]) # Call function through variable + + input_data = wp.array([1.0, 2.0, 3.0], dtype=wp.float32, device=device) + output_data = wp.empty_like(input_data) + wp.launch(kernel_function_assignment, dim=input_data.size, inputs=[input_data], outputs=[output_data], device=device) + expected_output = np.array([2.0, 4.0, 6.0], dtype=np.float32) + assert_np_equal(output_data.numpy(), expected_output) class TestCodeGen(unittest.TestCase): pass @@ -719,6 +736,12 @@ class TestCodeGen(unittest.TestCase): name="test_error_mutating_constant_in_dynamic_loop", devices=devices, ) +add_function_test( + TestCodeGen, + func=test_function_assignment, + name="test_function_assignment", + devices=devices +) add_kernel_test(TestCodeGen, name="test_call_syntax", kernel=test_call_syntax, dim=1, devices=devices) add_kernel_test(TestCodeGen, name="test_shadow_builtin", kernel=test_shadow_builtin, dim=1, devices=devices)