Skip to content

Commit

Permalink
Merge branch 'Mesh_add_change_vertex_buffer' into 'main'
Browse files Browse the repository at this point in the history
Implemented a Setter Method for wp.Mesh Which Allows Changing the Vertex Buffer

See merge request omniverse/warp!723
  • Loading branch information
mmacklin committed Sep 24, 2024
2 parents 909bf94 + 1873eb5 commit 8db4b75
Show file tree
Hide file tree
Showing 8 changed files with 357 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
- Unexposed `wp.rand*()`, `wp.sample*()`, and `wp.poisson()` from the Python scope.
- Skip unused functions in module code generation, improving performance.
- Avoid reloading modules if their content does not change, improving performance.
- `wp.Mesh.points` is now a property instead of a raw data member, its reference can be changed after the mesh is initialized.

### Fixed

Expand Down
6 changes: 6 additions & 0 deletions warp/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2996,6 +2996,12 @@ def __init__(self):
self.core.mesh_refit_host.argtypes = [ctypes.c_uint64]
self.core.mesh_refit_device.argtypes = [ctypes.c_uint64]

self.core.mesh_set_points_host.argtypes = [ctypes.c_uint64, warp.types.array_t]
self.core.mesh_set_points_device.argtypes = [ctypes.c_uint64, warp.types.array_t]

self.core.mesh_set_velocities_host.argtypes = [ctypes.c_uint64, warp.types.array_t]
self.core.mesh_set_velocities_device.argtypes = [ctypes.c_uint64, warp.types.array_t]

self.core.hash_grid_create_host.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
self.core.hash_grid_create_host.restype = ctypes.c_uint64
self.core.hash_grid_destroy_host.argtypes = [ctypes.c_uint64]
Expand Down
36 changes: 36 additions & 0 deletions warp/native/mesh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@ bool mesh_get_descriptor(uint64_t id, Mesh& mesh)
return true;
}

bool mesh_set_descriptor(uint64_t id, const Mesh& mesh)
{
const auto& iter = g_mesh_descriptors.find(id);
if (iter == g_mesh_descriptors.end())
return false;
else
iter->second = mesh;
return true;
}

void mesh_add_descriptor(uint64_t id, const Mesh& mesh)
{
g_mesh_descriptors[id] = mesh;
Expand Down Expand Up @@ -191,6 +201,30 @@ void mesh_refit_host(uint64_t id)
}
}

void mesh_set_points_host(uint64_t id, wp::array_t<wp::vec3> points)
{
Mesh* m = (Mesh*)(id);
if (points.ndim != 1 || points.shape[0] != m->points.shape[0])
{
fprintf(stderr, "The new points input for mesh_set_points_host does not match the shape of the original points!\n");
return;
}

m->points = points;

mesh_refit_host(id);
}

void mesh_set_velocities_host(uint64_t id, wp::array_t<wp::vec3> velocities)
{
Mesh* m = (Mesh*)(id);
if (velocities.ndim != 1 || velocities.shape[0] != m->velocities.shape[0])
{
fprintf(stderr, "The new velocities input for mesh_set_velocities_host does not match the shape of the original velocities!\n");
return;
}
m->velocities = velocities;
}

// stubs for non-CUDA platforms
#if !WP_ENABLE_CUDA
Expand All @@ -199,6 +233,8 @@ void mesh_refit_host(uint64_t id)
WP_API uint64_t mesh_create_device(void* context, wp::array_t<wp::vec3> points, wp::array_t<wp::vec3> velocities, wp::array_t<int> tris, int num_points, int num_tris, int support_winding_number) { return 0; }
WP_API void mesh_destroy_device(uint64_t id) {}
WP_API void mesh_refit_device(uint64_t id) {}
WP_API void mesh_set_points_device(uint64_t id, wp::array_t<wp::vec3> points) {};
WP_API void mesh_set_velocities_device(uint64_t id, wp::array_t<wp::vec3> points) {};


#endif // !WP_ENABLE_CUDA
51 changes: 51 additions & 0 deletions warp/native/mesh.cu
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,54 @@ void mesh_refit_device(uint64_t id)
}
}

void mesh_set_points_device(uint64_t id, wp::array_t<wp::vec3> points)
{
wp::Mesh m;
if (mesh_get_descriptor(id, m))
{
if (points.ndim != 1 || points.shape[0] != m.points.shape[0])
{
fprintf(stderr, "The new points input for mesh_set_points_device does not match the shape of the original points!\n");
return;
}

m.points = points;

wp::Mesh* mesh_device = (wp::Mesh*)id;
memcpy_h2d(WP_CURRENT_CONTEXT, mesh_device, &m, sizeof(wp::Mesh));

// update the cpu copy as well
mesh_set_descriptor(id, m);

mesh_refit_device(id);
}
else
{
fprintf(stderr, "The mesh id provided to mesh_set_points_device is not valid!\n");
return;
}
}

void mesh_set_velocities_device(uint64_t id, wp::array_t<wp::vec3> velocities)
{
wp::Mesh m;
if (mesh_get_descriptor(id, m))
{
if (velocities.ndim != 1 || velocities.shape[0] != m.velocities.shape[0])
{
fprintf(stderr, "The new velocities input for mesh_set_velocities_device does not match the shape of the original velocities\n");
return;
}

m.velocities = velocities;

wp::Mesh* mesh_device = (wp::Mesh*)id;
memcpy_h2d(WP_CURRENT_CONTEXT, mesh_device, &m, sizeof(wp::Mesh));
mesh_set_descriptor(id, m);
}
else
{
fprintf(stderr, "The mesh id provided to mesh_set_velocities_device is not valid!\n");
return;
}
}
1 change: 1 addition & 0 deletions warp/native/mesh.h
Original file line number Diff line number Diff line change
Expand Up @@ -1881,6 +1881,7 @@ CUDA_CALLABLE inline void adj_mesh_get_index(uint64_t id, int index,
}

CUDA_CALLABLE bool mesh_get_descriptor(uint64_t id, Mesh& mesh);
CUDA_CALLABLE bool mesh_set_descriptor(uint64_t id, const Mesh& mesh);
CUDA_CALLABLE void mesh_add_descriptor(uint64_t id, const Mesh& mesh);
CUDA_CALLABLE void mesh_rem_descriptor(uint64_t id);

Expand Down
6 changes: 6 additions & 0 deletions warp/native/warp.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ extern "C"
WP_API void mesh_destroy_device(uint64_t id);
WP_API void mesh_refit_device(uint64_t id);

WP_API void mesh_set_points_host(uint64_t id, wp::array_t<wp::vec3> points);
WP_API void mesh_set_points_device(uint64_t id, wp::array_t<wp::vec3> points);

WP_API void mesh_set_velocities_host(uint64_t id, wp::array_t<wp::vec3> velocities);
WP_API void mesh_set_velocities_device(uint64_t id, wp::array_t<wp::vec3> velocities);

WP_API uint64_t hash_grid_create_host(int dim_x, int dim_y, int dim_z);
WP_API void hash_grid_reserve_host(uint64_t id, int num_points);
WP_API void hash_grid_destroy_host(uint64_t id);
Expand Down
188 changes: 188 additions & 0 deletions warp/tests/test_mesh_query_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np

import warp as wp
import warp.examples
from warp.tests.unittest_utils import *


Expand Down Expand Up @@ -654,6 +655,192 @@ def test_mesh_query_furthest_point(test, device):
assert_np_equal(dist_query.numpy(), dist_brute.numpy(), tol=1.0e-3)


@wp.func
def triangle_closest_point_for_test(a: wp.vec3, b: wp.vec3, c: wp.vec3, p: wp.vec3):
ab = b - a
ac = c - a
ap = p - a

d1 = wp.dot(ab, ap)
d2 = wp.dot(ac, ap)
if d1 <= 0.0 and d2 <= 0.0:
bary = wp.vec3(1.0, 0.0, 0.0)
return a, bary

bp = p - b
d3 = wp.dot(ab, bp)
d4 = wp.dot(ac, bp)
if d3 >= 0.0 and d4 <= d3:
bary = wp.vec3(0.0, 1.0, 0.0)
return b, bary

cp = p - c
d5 = wp.dot(ab, cp)
d6 = wp.dot(ac, cp)
if d6 >= 0.0 and d5 <= d6:
bary = wp.vec3(0.0, 0.0, 1.0)
return c, bary

vc = d1 * d4 - d3 * d2
if vc <= 0.0 and d1 >= 0.0 and d3 <= 0.0:
v = d1 / (d1 - d3)
bary = wp.vec3(1.0 - v, v, 0.0)
return a + v * ab, bary

vb = d5 * d2 - d1 * d6
if vb <= 0.0 and d2 >= 0.0 and d6 <= 0.0:
v = d2 / (d2 - d6)
bary = wp.vec3(1.0 - v, 0.0, v)
return a + v * ac, bary

va = d3 * d6 - d5 * d4
if va <= 0.0 and (d4 - d3) >= 0.0 and (d5 - d6) >= 0.0:
v = (d4 - d3) / ((d4 - d3) + (d5 - d6))
bary = wp.vec3(0.0, 1.0 - v, v)
return b + v * (c - b), bary

denom = 1.0 / (va + vb + vc)
v = vb * denom
w = vc * denom
bary = wp.vec3(1.0 - v - w, v, w)
return a + v * ab + w * ac, bary


def load_mesh():
from pxr import Usd, UsdGeom

usd_stage = Usd.Stage.Open(os.path.join(wp.examples.get_asset_directory(), "bunny.usd"))
usd_geom = UsdGeom.Mesh(usd_stage.GetPrimAtPath("/root/bunny"))

vertices = np.array(usd_geom.GetPointsAttr().Get())
faces = np.array(usd_geom.GetFaceVertexIndicesAttr().Get())

return vertices, faces


@wp.kernel
def point_query_aabb_and_closest(
query_radius: float,
mesh_id: wp.uint64,
pts: wp.array(dtype=wp.vec3),
pos: wp.array(dtype=wp.vec3),
tri_indices: wp.array(dtype=wp.int32, ndim=2),
query_results_num_collisions: wp.array(dtype=wp.int32),
query_results_min_dist: wp.array(dtype=float),
query_results_closest_point_velocity: wp.array(dtype=wp.vec3),
):
p_index = wp.tid()
p = pts[p_index]

lower = wp.vec3(p[0] - query_radius, p[1] - query_radius, p[2] - query_radius)
upper = wp.vec3(p[0] + query_radius, p[1] + query_radius, p[2] + query_radius)

closest_query = wp.mesh_query_point_no_sign(mesh_id, p, query_radius)
if closest_query.result:
closest_p = wp.mesh_eval_position(mesh_id, closest_query.face, closest_query.u, closest_query.v)
closest_p_vel = wp.mesh_eval_velocity(mesh_id, closest_query.face, closest_query.u, closest_query.v)

query_results_min_dist[p_index] = wp.length(closest_p - p)
query_results_closest_point_velocity[p_index] = closest_p_vel

query = wp.mesh_query_aabb(mesh_id, lower, upper)

tri_index = wp.int32(0)
num_collisions = wp.int32(0)
min_dis_to_tris = query_radius
while wp.mesh_query_aabb_next(query, tri_index):
t1 = tri_indices[tri_index, 0]
t2 = tri_indices[tri_index, 1]
t3 = tri_indices[tri_index, 2]

u1 = pos[t1]
u2 = pos[t2]
u3 = pos[t3]

closest_p1, barycentric1 = triangle_closest_point_for_test(u1, u2, u3, p)

dis = wp.length(closest_p1 - p)

if dis < query_radius:
num_collisions = num_collisions + 1

query_results_num_collisions[p_index] = num_collisions


@unittest.skipUnless(USD_AVAILABLE, "Requires usd-core")
def test_set_mesh_points(test, device):
vs, fs = load_mesh()

vertices1 = wp.array(vs, dtype=wp.vec3, device=device)
velocities1_np = np.random.randn(vertices1.shape[0], 3)
velocities1 = wp.array(velocities1_np, dtype=wp.vec3, device=device)

faces = wp.array(fs, dtype=wp.int32, device=device)
mesh = wp.Mesh(vertices1, faces, velocities=velocities1)
fs_2D = faces.reshape((-1, 3))
np.random.seed(12345)
n = 1000
query_radius = 0.2

pts1 = wp.array(np.random.randn(n, 3), dtype=wp.vec3, device=device)

query_results_num_cols1 = wp.zeros(n, dtype=wp.int32, device=device)
query_results_min_dist1 = wp.zeros(n, dtype=float, device=device)
query_results_closest_point_velocity1 = wp.zeros(n, dtype=wp.vec3, device=device)

wp.launch(
kernel=point_query_aabb_and_closest,
inputs=[
query_radius,
mesh.id,
pts1,
vertices1,
fs_2D,
query_results_num_cols1,
query_results_min_dist1,
query_results_closest_point_velocity1,
],
dim=n,
device=device,
)

shift = np.random.randn(3)

vs_higher = vs + shift
vertices2 = wp.array(vs_higher, dtype=wp.vec3, device=device)

velocities2_np = velocities1_np + shift[None, ...]
velocities2 = wp.array(velocities2_np, dtype=wp.vec3, device=device)

pts2 = wp.array(pts1.numpy() + shift, dtype=wp.vec3, device=device)

mesh.points = vertices2
mesh.velocities = velocities2

query_results_num_cols2 = wp.zeros(n, dtype=wp.int32, device=device)
query_results_min_dist2 = wp.zeros(n, dtype=float, device=device)
query_results_closest_point_velocity2 = wp.array([shift for i in range(n)], dtype=wp.vec3, device=device)

wp.launch(
kernel=point_query_aabb_and_closest,
inputs=[
query_radius,
mesh.id,
pts2,
vertices2,
fs_2D,
query_results_num_cols2,
query_results_min_dist2,
query_results_closest_point_velocity2,
],
dim=n,
device=device,
)

test.assertTrue((query_results_num_cols1.numpy() == query_results_num_cols2.numpy()).all())
test.assertTrue(((query_results_min_dist1.numpy() - query_results_min_dist2.numpy()) < 1e-5).all())


devices = get_test_devices()


Expand Down Expand Up @@ -684,6 +871,7 @@ def kernel_fn(
add_function_test(TestMeshQueryPoint, "test_mesh_query_point", test_mesh_query_point, devices=devices)
add_function_test(TestMeshQueryPoint, "test_mesh_query_furthest_point", test_mesh_query_furthest_point, devices=devices)
add_function_test(TestMeshQueryPoint, "test_adj_mesh_query_point", test_adj_mesh_query_point, devices=devices)
add_function_test(TestMeshQueryPoint, "test_set_mesh_points", test_set_mesh_points, devices=devices)

if __name__ == "__main__":
wp.clear_kernel_cache()
Expand Down
Loading

0 comments on commit 8db4b75

Please sign in to comment.