diff --git a/CHANGELOG.md b/CHANGELOG.md index 3302df1d..c3c12235 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ ### Changed - Relax the integer types expected when indexing arrays. +- Promote the `wp.Int`, `wp.Float`, and `wp.Scalar` generic annotation types to the public API. ### Fixed diff --git a/warp/__init__.py b/warp/__init__.py index dc77075c..28c165b4 100644 --- a/warp/__init__.py +++ b/warp/__init__.py @@ -26,6 +26,9 @@ from warp.types import spatial_vector, spatial_vectorh, spatial_vectorf, spatial_vectord from warp.types import spatial_matrix, spatial_matrixh, spatial_matrixf, spatial_matrixd +# annotation types +from warp.types import Int, Float, Scalar + # geometry types from warp.types import Bvh, Mesh, HashGrid, Volume, MarchingCubes from warp.types import BvhQuery, HashGridQuery, MeshQueryAABB, MeshQueryPoint, MeshQueryRay diff --git a/warp/stubs.py b/warp/stubs.py index 050f601f..fcf8c4c8 100644 --- a/warp/stubs.py +++ b/warp/stubs.py @@ -39,6 +39,8 @@ from warp.types import spatial_vector, spatial_vectorh, spatial_vectorf, spatial_vectord from warp.types import spatial_matrix, spatial_matrixh, spatial_matrixf, spatial_matrixd +from warp.types import Int, Float, Scalar + from warp.types import Bvh, Mesh, HashGrid, Volume, MarchingCubes from warp.types import BvhQuery, HashGridQuery, MeshQueryAABB, MeshQueryPoint, MeshQueryRay diff --git a/warp/tests/test_generics.py b/warp/tests/test_generics.py index ed769338..1b5ab9ac 100644 --- a/warp/tests/test_generics.py +++ b/warp/tests/test_generics.py @@ -522,6 +522,57 @@ def kernel(): ) +@wp.func +def vec_int_annotation_func(v: wp.vec(3, wp.Int)) -> wp.Int: + return v[0] + v[1] + v[2] + + +@wp.func +def vec_float_annotation_func(v: wp.vec(3, wp.Float)) -> wp.Float: + return v[0] + v[1] + v[2] + + +@wp.func +def vec_scalar_annotation_func(v: wp.vec(3, wp.Scalar)) -> wp.Scalar: + return v[0] + v[1] + v[2] + + +@wp.func +def mat_int_annotation_func(m: wp.mat((2, 2), wp.Int)) -> wp.Int: + return m[0, 0] + m[0, 1] + m[1, 0] + m[1, 1] + + +@wp.func +def mat_float_annotation_func(m: wp.mat((2, 2), wp.Float)) -> wp.Float: + return m[0, 0] + m[0, 1] + m[1, 0] + m[1, 1] + + +@wp.func +def mat_scalar_annotation_func(m: wp.mat((2, 2), wp.Scalar)) -> wp.Scalar: + return m[0, 0] + m[0, 1] + m[1, 0] + m[1, 1] + + +mat22s = wp.mat((2, 2), wp.int16) +mat22d = wp.mat((2, 2), wp.float64) + + +@wp.kernel +def test_annotations_kernel(): + vi16 = wp.vec3s(wp.int16(1), wp.int16(2), wp.int16(3)) + vf64 = wp.vec3d(wp.float64(1), wp.float64(2), wp.float64(3)) + wp.expect_eq(vec_int_annotation_func(vi16), wp.int16(6)) + wp.expect_eq(vec_float_annotation_func(vf64), wp.float64(6)) + wp.expect_eq(vec_scalar_annotation_func(vi16), wp.int16(6)) + wp.expect_eq(vec_scalar_annotation_func(vf64), wp.float64(6)) + + mi16 = mat22s(wp.int16(1), wp.int16(2), wp.int16(3), wp.int16(4)) + mf64 = mat22d(wp.float64(1), wp.float64(2), wp.float64(3), wp.float64(4)) + wp.expect_eq(mat_int_annotation_func(mi16), wp.int16(10)) + wp.expect_eq(mat_float_annotation_func(mf64), wp.float64(10)) + wp.expect_eq(mat_scalar_annotation_func(mi16), wp.int16(10)) + wp.expect_eq(mat_scalar_annotation_func(mf64), wp.float64(10)) + + class TestGenerics(unittest.TestCase): pass @@ -590,6 +641,7 @@ class TestGenerics(unittest.TestCase): ) add_function_test(TestGenerics, "test_type_operator_misspell", test_type_operator_misspell, devices=devices) add_function_test(TestGenerics, "test_type_attribute_error", test_type_attribute_error, devices=devices) +add_kernel_test(TestGenerics, name="test_annotations_kernel", kernel=test_annotations_kernel, dim=1, devices=devices) if __name__ == "__main__": wp.clear_kernel_cache() diff --git a/warp/types.py b/warp/types.py index e50b4cfd..cdd659f8 100644 --- a/warp/types.py +++ b/warp/types.py @@ -100,8 +100,10 @@ class vec_t(ctypes.Array): if dtype is bool: _type_ = ctypes.c_bool - elif dtype in [Scalar, Float]: + elif dtype in (Scalar, Float): _type_ = ctypes.c_float + elif dtype is Int: + _type_ = ctypes.c_int else: _type_ = dtype._type_ @@ -289,8 +291,10 @@ class mat_t(ctypes.Array): if dtype is bool: _type_ = ctypes.c_bool - elif dtype in [Scalar, Float]: + elif dtype in (Scalar, Float): _type_ = ctypes.c_float + elif dtype is Int: + _type_ = ctypes.c_int else: _type_ = dtype._type_