diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ff68a16..2c07915a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ - Fix printing vector and matrix adjoints in backward kernels. - Fix kernel compile error when printing structs. +- Fix an incorrect user function being sometimes resolved when multiple overloads are available with array parameters with different `dtype` values. ## [1.4.1] - 2024-10-15 diff --git a/warp/tests/test_func.py b/warp/tests/test_func.py index 495e0a9c..631fe769 100644 --- a/warp/tests/test_func.py +++ b/warp/tests/test_func.py @@ -7,7 +7,7 @@ import math import unittest -from typing import Tuple +from typing import Any, Tuple import numpy as np @@ -191,6 +191,37 @@ def test_user_func_return_multiple_values(): wp.expect_eq(b, 54756.0) +@wp.func +def user_func_overload( + b: wp.array(dtype=Any), + i: int, +): + return b[i] * 2.0 + + +@wp.kernel +def user_func_overload_resolution_kernel( + a: wp.array(dtype=Any), + b: wp.array(dtype=Any), +): + i = wp.tid() + a[i] = user_func_overload(b, i) + + +def test_user_func_overload_resolution(test, device): + a0 = wp.array((1, 2, 3), dtype=wp.vec3) + b0 = wp.array((2, 3, 4), dtype=wp.vec3) + + a1 = wp.array((5,), dtype=float) + b1 = wp.array((6,), dtype=float) + + wp.launch(user_func_overload_resolution_kernel, a0.shape, (a0, b0)) + wp.launch(user_func_overload_resolution_kernel, a1.shape, (a1, b1)) + + assert_np_equal(a0.numpy()[0], (4, 6, 8)) + assert a1.numpy()[0] == 12 + + devices = get_test_devices() @@ -375,6 +406,9 @@ def test_native_function_error_resolution(self): dim=1, devices=devices, ) +add_function_test( + TestFunc, func=test_user_func_overload_resolution, name="test_user_func_overload_resolution", devices=devices +) if __name__ == "__main__": diff --git a/warp/types.py b/warp/types.py index cdd659f8..1b7e3ded 100644 --- a/warp/types.py +++ b/warp/types.py @@ -1492,7 +1492,7 @@ def types_equal(a, b, match_generic=False): return True - if is_array(a) and type(a) is type(b): + if is_array(a) and type(a) is type(b) and types_equal(a.dtype, b.dtype, match_generic=match_generic): return True # match NewStructInstance and Struct dtype