Skip to content

Commit

Permalink
Promote generic scalar annotation types to public
Browse files Browse the repository at this point in the history
  • Loading branch information
christophercrouzet committed Oct 9, 2024
1 parent ce149b9 commit 8e4c2a0
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
### Changed

- Relax the integer types expected when indexing arrays.
- Promote the `wp.Int`, `wp.Float`, and `wp.Scalar` generic annotation types to the public API.

### Fixed

Expand Down
3 changes: 3 additions & 0 deletions warp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from warp.types import spatial_vector, spatial_vectorh, spatial_vectorf, spatial_vectord
from warp.types import spatial_matrix, spatial_matrixh, spatial_matrixf, spatial_matrixd

# annotation types
from warp.types import Int, Float, Scalar

# geometry types
from warp.types import Bvh, Mesh, HashGrid, Volume, MarchingCubes
from warp.types import BvhQuery, HashGridQuery, MeshQueryAABB, MeshQueryPoint, MeshQueryRay
Expand Down
2 changes: 2 additions & 0 deletions warp/stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
from warp.types import spatial_vector, spatial_vectorh, spatial_vectorf, spatial_vectord
from warp.types import spatial_matrix, spatial_matrixh, spatial_matrixf, spatial_matrixd

from warp.types import Int, Float, Scalar

from warp.types import Bvh, Mesh, HashGrid, Volume, MarchingCubes
from warp.types import BvhQuery, HashGridQuery, MeshQueryAABB, MeshQueryPoint, MeshQueryRay

Expand Down
52 changes: 52 additions & 0 deletions warp/tests/test_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,57 @@ def kernel():
)


@wp.func
def vec_int_annotation_func(v: wp.vec(3, wp.Int)) -> wp.Int:
return v[0] + v[1] + v[2]


@wp.func
def vec_float_annotation_func(v: wp.vec(3, wp.Float)) -> wp.Float:
return v[0] + v[1] + v[2]


@wp.func
def vec_scalar_annotation_func(v: wp.vec(3, wp.Scalar)) -> wp.Scalar:
return v[0] + v[1] + v[2]


@wp.func
def mat_int_annotation_func(m: wp.mat((2, 2), wp.Int)) -> wp.Int:
return m[0, 0] + m[0, 1] + m[1, 0] + m[1, 1]


@wp.func
def mat_float_annotation_func(m: wp.mat((2, 2), wp.Float)) -> wp.Float:
return m[0, 0] + m[0, 1] + m[1, 0] + m[1, 1]


@wp.func
def mat_scalar_annotation_func(m: wp.mat((2, 2), wp.Scalar)) -> wp.Scalar:
return m[0, 0] + m[0, 1] + m[1, 0] + m[1, 1]


mat22s = wp.mat((2, 2), wp.int16)
mat22d = wp.mat((2, 2), wp.float64)


@wp.kernel
def test_annotations_kernel():
vi16 = wp.vec3s(wp.int16(1), wp.int16(2), wp.int16(3))
vf64 = wp.vec3d(wp.float64(1), wp.float64(2), wp.float64(3))
wp.expect_eq(vec_int_annotation_func(vi16), wp.int16(6))
wp.expect_eq(vec_float_annotation_func(vf64), wp.float64(6))
wp.expect_eq(vec_scalar_annotation_func(vi16), wp.int16(6))
wp.expect_eq(vec_scalar_annotation_func(vf64), wp.float64(6))

mi16 = mat22s(wp.int16(1), wp.int16(2), wp.int16(3), wp.int16(4))
mf64 = mat22d(wp.float64(1), wp.float64(2), wp.float64(3), wp.float64(4))
wp.expect_eq(mat_int_annotation_func(mi16), wp.int16(10))
wp.expect_eq(mat_float_annotation_func(mf64), wp.float64(10))
wp.expect_eq(mat_scalar_annotation_func(mi16), wp.int16(10))
wp.expect_eq(mat_scalar_annotation_func(mf64), wp.float64(10))


class TestGenerics(unittest.TestCase):
pass

Expand Down Expand Up @@ -590,6 +641,7 @@ class TestGenerics(unittest.TestCase):
)
add_function_test(TestGenerics, "test_type_operator_misspell", test_type_operator_misspell, devices=devices)
add_function_test(TestGenerics, "test_type_attribute_error", test_type_attribute_error, devices=devices)
add_kernel_test(TestGenerics, name="test_annotations_kernel", kernel=test_annotations_kernel, dim=1, devices=devices)

if __name__ == "__main__":
wp.clear_kernel_cache()
Expand Down
8 changes: 6 additions & 2 deletions warp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,10 @@ class vec_t(ctypes.Array):

if dtype is bool:
_type_ = ctypes.c_bool
elif dtype in [Scalar, Float]:
elif dtype in (Scalar, Float):
_type_ = ctypes.c_float
elif dtype is Int:
_type_ = ctypes.c_int
else:
_type_ = dtype._type_

Expand Down Expand Up @@ -289,8 +291,10 @@ class mat_t(ctypes.Array):

if dtype is bool:
_type_ = ctypes.c_bool
elif dtype in [Scalar, Float]:
elif dtype in (Scalar, Float):
_type_ = ctypes.c_float
elif dtype is Int:
_type_ = ctypes.c_int
else:
_type_ = dtype._type_

Expand Down

0 comments on commit 8e4c2a0

Please sign in to comment.