diff --git a/python/ncollpyde/__init__.py b/python/ncollpyde/__init__.py index 96614cc..04df5b1 100644 --- a/python/ncollpyde/__init__.py +++ b/python/ncollpyde/__init__.py @@ -11,7 +11,7 @@ from .main import INDEX # noqa: F401 from .main import N_CPUS # noqa: F401 from .main import PRECISION # noqa: F401 -from .main import Volume # noqa: F401 +from .main import Volume, Validation from .main import configure_threadpool # noqa: F401 from ._ncollpyde import n_threads # noqa: F401 from ._ncollpyde import _version @@ -19,4 +19,4 @@ __version__ = _version() __version_info__ = tuple(int(n) for n in __version__.split("-")[0].split(".")) -__all__ = ["Volume"] +__all__ = ["Volume", "Validation"] diff --git a/python/ncollpyde/_ncollpyde.pyi b/python/ncollpyde/_ncollpyde.pyi index 3fe119f..a58cd90 100644 --- a/python/ncollpyde/_ncollpyde.pyi +++ b/python/ncollpyde/_ncollpyde.pyi @@ -14,7 +14,12 @@ Indices = npt.NDArray[np.uint32] class TriMeshWrapper: def __init__( - self, points: Points, indices: Indices, n_rays: int, ray_seed: int + self, + points: Points, + indices: Indices, + n_rays: int, + ray_seed: int, + validate: int, ): ... def contains( self, points: Points, n_rays: int, consensus: int, parallel: bool diff --git a/python/ncollpyde/main.py b/python/ncollpyde/main.py index 54efe1e..1cc9ca8 100644 --- a/python/ncollpyde/main.py +++ b/python/ncollpyde/main.py @@ -3,6 +3,7 @@ import warnings from multiprocessing import cpu_count from typing import TYPE_CHECKING, Optional, Tuple, Union, List +from enum import IntFlag, auto import numpy as np from numpy.typing import ArrayLike, NDArray @@ -29,6 +30,41 @@ INDEX = np.dtype(_index()) +class Validation(IntFlag): + """Enum representing the different validations which can be applied. + + Combine with `|`. + Must contain ORIENTED (see `Validation.minimum()`). + """ + HALF_EDGE_TOPOLOGY = 1 + CONNECTED_COMPONENTS = 2 + DELETE_BAD_TOPOLOGY_TRIANGLES = 4 + ORIENTED = 8 + MERGE_DUPLICATE_VERTICES = 16 + MERGE_DEGENERATE_TRIANGLES = 32 + MERGE_DUPLICATE_TRIANGLES = 64 + + @classmethod + def minimum(cls): + return cls.ORIENTED + + @classmethod + def default(cls): + return cls.all() + + @classmethod + def all(cls): + return ( + cls.HALF_EDGE_TOPOLOGY + | cls.CONNECTED_COMPONENTS + | cls.DELETE_BAD_TOPOLOGY_TRIANGLES + | cls.ORIENTED + | cls.MERGE_DUPLICATE_VERTICES + | cls.MERGE_DEGENERATE_TRIANGLES + | cls.MERGE_DUPLICATE_TRIANGLES + ) + + def configure_threadpool(n_threads: Optional[int], name_prefix: Optional[str]): """Configure the thread pool used for parallelisation. @@ -85,7 +121,7 @@ def __init__( self, vertices: ArrayLike, triangles: ArrayLike, - validate=False, + validate: Validation = Validation.default(), threads: Optional[bool] = None, n_rays=DEFAULT_RAYS, ray_seed=DEFAULT_SEED, @@ -114,10 +150,24 @@ def __init__( """ vert = np.asarray(vertices, self.dtype) if len(vert) > np.iinfo(INDEX).max: - raise ValueError(f"Cannot represent {len(vert)} vertices with {INDEX}") + raise ValueError( + f"Cannot represent {len(vert)} vertices with {INDEX}" + ) tri = np.asarray(triangles, INDEX) - if validate: - vert, tri = self._validate(vert, tri) + if isinstance(validate, bool): + warnings.warn( + "`validate: bool` should be replaced by a Validation enum " + "controlling validation by this library. " + "The previous behaviour of validating using the " + "external trimesh library may be removed in future.", + DeprecationWarning, + ) + if validate: + vert, tri = self._validate(vert, tri) + validate = Validation.all() + else: + validate = Validation.minimum() + self.threads = self._interpret_threads(threads) if ray_seed is None: logger.warning( @@ -129,7 +179,7 @@ def __init__( self.n_rays = int(n_rays) inner_rays = 0 if self.n_rays < 0 else self.n_rays - self._impl = TriMeshWrapper(vert, tri, inner_rays, ray_seed) + self._impl = TriMeshWrapper(vert, tri, inner_rays, ray_seed, validate) def _validate( self, vertices: np.ndarray, triangles: np.ndarray @@ -358,7 +408,7 @@ def intersections( def from_meshio( cls, mesh: "meshio.Mesh", - validate=False, + validate: Validation = Validation.default(), threads=None, n_rays=DEFAULT_RAYS, ray_seed=DEFAULT_SEED, @@ -367,7 +417,7 @@ def from_meshio( Convenience function for instantiating a Volume from a meshio Mesh. :param mesh: meshio Mesh whose only cells are triangles. - :param validate: as passed to ``__init__``, defaults to False + :param validate: as passed to ``__init__``, defaults to Validation.default() :param threads: as passed to ``__init__``, defaults to None :param n_rays: as passed to ``__init__``, defaults to 3 :param ray_seed: as passed to ``__init__``, defaults to None (random) diff --git a/src/interface.rs b/src/interface.rs index 81d39f8..5fee118 100644 --- a/src/interface.rs +++ b/src/interface.rs @@ -4,7 +4,7 @@ use ndarray::{Array2, ArrayView1}; use numpy::ndarray::{Array, Zip}; use numpy::{IntoPyArray, PyArray1, PyArray2, PyReadonlyArray2}; use parry3d_f64::math::{Point, Vector}; -use parry3d_f64::shape::TriMesh; +use parry3d_f64::shape::{TriMesh, TriMeshFlags}; use pyo3::exceptions::{PyRuntimeError, PyValueError}; use pyo3::prelude::*; use rand::SeedableRng; @@ -13,7 +13,7 @@ use rayon::{prelude::*, ThreadPoolBuilder}; use crate::utils::{ aabb_diag, dist_from_mesh, mesh_contains_point, mesh_contains_point_oriented, - points_cross_mesh, random_dir, sdf_inner, Precision, FLAGS, + points_cross_mesh, random_dir, sdf_inner, Precision, }; // fn vec_to_point(v: Vec) -> Point { @@ -38,6 +38,7 @@ impl TriMeshWrapper { indices: PyReadonlyArray2, n_rays: usize, ray_seed: u64, + validate: u8, ) -> PyResult { let points2 = points .as_array() @@ -53,7 +54,14 @@ impl TriMeshWrapper { .collect(); let mut mesh = TriMesh::new(points2, indices2); - mesh.set_flags(FLAGS) + let flags = TriMeshFlags::from_bits(validate) + .ok_or(PyValueError::new_err("Invalid `validate` enum"))?; + if !flags.contains(TriMeshFlags::ORIENTED) { + return Err(PyValueError::new_err( + "`validate` enum must contain ORIENTED", + )); + } + mesh.set_flags(flags) .map_err(|e| PyValueError::new_err(format!("Invalid mesh topology: {e}")))?; if n_rays > 0 { diff --git a/tests/conftest.py b/tests/conftest.py index 9cd3fb2..365ca99 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,7 @@ import meshio import pytest -from ncollpyde import Volume +from ncollpyde import Volume, Validation test_dir = Path(__file__).resolve().parent project_dir = test_dir.parent @@ -17,7 +17,7 @@ def mesh(): @pytest.fixture def volume(mesh): - return Volume.from_meshio(mesh, validate=True) + return Volume.from_meshio(mesh, validate=Validation.all()) @pytest.fixture @@ -27,11 +27,11 @@ def simple_mesh(): @pytest.fixture def simple_volume(simple_mesh): - return Volume.from_meshio(simple_mesh, validate=True) + return Volume.from_meshio(simple_mesh, validate=Validation.all()) @pytest.fixture def sez_right(): return Volume.from_meshio( - meshio.read(str(mesh_dir / "SEZ_right.stl")), validate=True + meshio.read(str(mesh_dir / "SEZ_right.stl")), validate=Validation.all() ) diff --git a/tests/test_ncollpyde.py b/tests/test_ncollpyde.py index d4b9f3a..c8c0455 100644 --- a/tests/test_ncollpyde.py +++ b/tests/test_ncollpyde.py @@ -57,44 +57,44 @@ def test_contains_results(volume: Volume): assert np.allclose(ray, psnorms) -def test_no_validation(mesh): - triangles = mesh.cells_dict["triangle"] - Volume(mesh.points, triangles, True) +# def test_no_validation(mesh): +# triangles = mesh.cells_dict["triangle"] +# Volume(mesh.points, triangles, True) -@pytest.mark.skipif(not trimesh, reason="Requires trimesh") -def test_can_repair_hole(mesh): - triangles = mesh.cells_dict["triangle"] - triangles = triangles[:-1] - Volume(mesh.points, triangles, True) +# @pytest.mark.skipif(not trimesh, reason="Requires trimesh") +# def test_can_repair_hole(mesh): +# triangles = mesh.cells_dict["triangle"] +# triangles = triangles[:-1] +# Volume(mesh.points, triangles, True) -@pytest.mark.skipif(not trimesh, reason="Requires trimesh") -def test_can_repair_inversion(mesh): - triangles = mesh.cells_dict["triangle"] - triangles[-1] = triangles[-1, ::-1] - Volume(mesh.points, triangles, True) +# @pytest.mark.skipif(not trimesh, reason="Requires trimesh") +# def test_can_repair_inversion(mesh): +# triangles = mesh.cells_dict["triangle"] +# triangles[-1] = triangles[-1, ::-1] +# Volume(mesh.points, triangles, True) -@pytest.mark.skipif(not trimesh, reason="Requires trimesh") -def test_can_repair_inversions(mesh): - triangles = mesh.cells_dict["triangle"] - triangles = triangles[:, ::-1] - Volume(mesh.points, triangles, True) +# @pytest.mark.skipif(not trimesh, reason="Requires trimesh") +# def test_can_repair_inversions(mesh): +# triangles = mesh.cells_dict["triangle"] +# triangles = triangles[:, ::-1] +# Volume(mesh.points, triangles, True) -@pytest.mark.skipif(not trimesh, reason="Requires trimesh") -def test_inversions_repaired(simple_mesh): - center = [0.5, 0.5, 0.5] +# @pytest.mark.skipif(not trimesh, reason="Requires trimesh") +# def test_inversions_repaired(simple_mesh): +# center = [0.5, 0.5, 0.5] - orig_points = simple_mesh.points - orig_triangles = simple_mesh.cells_dict["triangle"] - assert center in Volume(orig_points, orig_triangles) +# orig_points = simple_mesh.points +# orig_triangles = simple_mesh.cells_dict["triangle"] +# assert center in Volume(orig_points, orig_triangles) - inv_triangles = orig_triangles[:, ::-1] - assert center not in Volume(orig_points, inv_triangles) +# inv_triangles = orig_triangles[:, ::-1] +# assert center not in Volume(orig_points, inv_triangles) - assert center in Volume(orig_points, inv_triangles, validate=True) +# assert center in Volume(orig_points, inv_triangles, validate=True) def test_points(mesh):