Skip to content

Commit

Permalink
Update validation routines
Browse files Browse the repository at this point in the history
  • Loading branch information
clbarnes committed Jan 16, 2024
1 parent aa03122 commit 9a09434
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 44 deletions.
4 changes: 2 additions & 2 deletions python/ncollpyde/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from .main import INDEX # noqa: F401
from .main import N_CPUS # noqa: F401
from .main import PRECISION # noqa: F401
from .main import Volume # noqa: F401
from .main import Volume, Validation
from .main import configure_threadpool # noqa: F401
from ._ncollpyde import n_threads # noqa: F401
from ._ncollpyde import _version

__version__ = _version()
__version_info__ = tuple(int(n) for n in __version__.split("-")[0].split("."))

__all__ = ["Volume"]
__all__ = ["Volume", "Validation"]
7 changes: 6 additions & 1 deletion python/ncollpyde/_ncollpyde.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ Indices = npt.NDArray[np.uint32]

class TriMeshWrapper:
def __init__(
self, points: Points, indices: Indices, n_rays: int, ray_seed: int
self,
points: Points,
indices: Indices,
n_rays: int,
ray_seed: int,
validate: int,
): ...
def contains(
self, points: Points, n_rays: int, consensus: int, parallel: bool
Expand Down
64 changes: 57 additions & 7 deletions python/ncollpyde/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings
from multiprocessing import cpu_count
from typing import TYPE_CHECKING, Optional, Tuple, Union, List
from enum import IntFlag, auto

import numpy as np
from numpy.typing import ArrayLike, NDArray
Expand All @@ -29,6 +30,41 @@
INDEX = np.dtype(_index())


class Validation(IntFlag):
"""Enum representing the different validations which can be applied.
Combine with `|`.
Must contain ORIENTED (see `Validation.minimum()`).
"""
HALF_EDGE_TOPOLOGY = 1
CONNECTED_COMPONENTS = 2
DELETE_BAD_TOPOLOGY_TRIANGLES = 4
ORIENTED = 8
MERGE_DUPLICATE_VERTICES = 16
MERGE_DEGENERATE_TRIANGLES = 32
MERGE_DUPLICATE_TRIANGLES = 64

@classmethod
def minimum(cls):
return cls.ORIENTED

@classmethod
def default(cls):
return cls.all()

@classmethod
def all(cls):
return (
cls.HALF_EDGE_TOPOLOGY
| cls.CONNECTED_COMPONENTS
| cls.DELETE_BAD_TOPOLOGY_TRIANGLES
| cls.ORIENTED
| cls.MERGE_DUPLICATE_VERTICES
| cls.MERGE_DEGENERATE_TRIANGLES
| cls.MERGE_DUPLICATE_TRIANGLES
)


def configure_threadpool(n_threads: Optional[int], name_prefix: Optional[str]):
"""Configure the thread pool used for parallelisation.
Expand Down Expand Up @@ -85,7 +121,7 @@ def __init__(
self,
vertices: ArrayLike,
triangles: ArrayLike,
validate=False,
validate: Validation = Validation.default(),
threads: Optional[bool] = None,
n_rays=DEFAULT_RAYS,
ray_seed=DEFAULT_SEED,
Expand Down Expand Up @@ -114,10 +150,24 @@ def __init__(
"""
vert = np.asarray(vertices, self.dtype)
if len(vert) > np.iinfo(INDEX).max:
raise ValueError(f"Cannot represent {len(vert)} vertices with {INDEX}")
raise ValueError(
f"Cannot represent {len(vert)} vertices with {INDEX}"
)
tri = np.asarray(triangles, INDEX)
if validate:
vert, tri = self._validate(vert, tri)
if isinstance(validate, bool):
warnings.warn(
"`validate: bool` should be replaced by a Validation enum "
"controlling validation by this library. "
"The previous behaviour of validating using the "
"external trimesh library may be removed in future.",
DeprecationWarning,
)
if validate:
vert, tri = self._validate(vert, tri)
validate = Validation.all()
else:
validate = Validation.minimum()

self.threads = self._interpret_threads(threads)
if ray_seed is None:
logger.warning(
Expand All @@ -129,7 +179,7 @@ def __init__(
self.n_rays = int(n_rays)
inner_rays = 0 if self.n_rays < 0 else self.n_rays

self._impl = TriMeshWrapper(vert, tri, inner_rays, ray_seed)
self._impl = TriMeshWrapper(vert, tri, inner_rays, ray_seed, validate)

def _validate(
self, vertices: np.ndarray, triangles: np.ndarray
Expand Down Expand Up @@ -358,7 +408,7 @@ def intersections(
def from_meshio(
cls,
mesh: "meshio.Mesh",
validate=False,
validate: Validation = Validation.default(),
threads=None,
n_rays=DEFAULT_RAYS,
ray_seed=DEFAULT_SEED,
Expand All @@ -367,7 +417,7 @@ def from_meshio(
Convenience function for instantiating a Volume from a meshio Mesh.
:param mesh: meshio Mesh whose only cells are triangles.
:param validate: as passed to ``__init__``, defaults to False
:param validate: as passed to ``__init__``, defaults to Validation.default()
:param threads: as passed to ``__init__``, defaults to None
:param n_rays: as passed to ``__init__``, defaults to 3
:param ray_seed: as passed to ``__init__``, defaults to None (random)
Expand Down
14 changes: 11 additions & 3 deletions src/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use ndarray::{Array2, ArrayView1};
use numpy::ndarray::{Array, Zip};
use numpy::{IntoPyArray, PyArray1, PyArray2, PyReadonlyArray2};
use parry3d_f64::math::{Point, Vector};
use parry3d_f64::shape::TriMesh;
use parry3d_f64::shape::{TriMesh, TriMeshFlags};
use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::prelude::*;
use rand::SeedableRng;
Expand All @@ -13,7 +13,7 @@ use rayon::{prelude::*, ThreadPoolBuilder};

use crate::utils::{
aabb_diag, dist_from_mesh, mesh_contains_point, mesh_contains_point_oriented,
points_cross_mesh, random_dir, sdf_inner, Precision, FLAGS,
points_cross_mesh, random_dir, sdf_inner, Precision,
};

// fn vec_to_point<T: 'static + Debug + PartialEq + Copy>(v: Vec<T>) -> Point<T> {
Expand All @@ -38,6 +38,7 @@ impl TriMeshWrapper {
indices: PyReadonlyArray2<u32>,
n_rays: usize,
ray_seed: u64,
validate: u8,
) -> PyResult<Self> {
let points2 = points
.as_array()
Expand All @@ -53,7 +54,14 @@ impl TriMeshWrapper {
.collect();
let mut mesh = TriMesh::new(points2, indices2);

mesh.set_flags(FLAGS)
let flags = TriMeshFlags::from_bits(validate)
.ok_or(PyValueError::new_err("Invalid `validate` enum"))?;
if !flags.contains(TriMeshFlags::ORIENTED) {
return Err(PyValueError::new_err(
"`validate` enum must contain ORIENTED",
));
}
mesh.set_flags(flags)
.map_err(|e| PyValueError::new_err(format!("Invalid mesh topology: {e}")))?;

if n_rays > 0 {
Expand Down
8 changes: 4 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import meshio
import pytest

from ncollpyde import Volume
from ncollpyde import Volume, Validation

test_dir = Path(__file__).resolve().parent
project_dir = test_dir.parent
Expand All @@ -17,7 +17,7 @@ def mesh():

@pytest.fixture
def volume(mesh):
return Volume.from_meshio(mesh, validate=True)
return Volume.from_meshio(mesh, validate=Validation.all())


@pytest.fixture
Expand All @@ -27,11 +27,11 @@ def simple_mesh():

@pytest.fixture
def simple_volume(simple_mesh):
return Volume.from_meshio(simple_mesh, validate=True)
return Volume.from_meshio(simple_mesh, validate=Validation.all())


@pytest.fixture
def sez_right():
return Volume.from_meshio(
meshio.read(str(mesh_dir / "SEZ_right.stl")), validate=True
meshio.read(str(mesh_dir / "SEZ_right.stl")), validate=Validation.all()
)
54 changes: 27 additions & 27 deletions tests/test_ncollpyde.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,44 +57,44 @@ def test_contains_results(volume: Volume):
assert np.allclose(ray, psnorms)


def test_no_validation(mesh):
triangles = mesh.cells_dict["triangle"]
Volume(mesh.points, triangles, True)
# def test_no_validation(mesh):
# triangles = mesh.cells_dict["triangle"]
# Volume(mesh.points, triangles, True)


@pytest.mark.skipif(not trimesh, reason="Requires trimesh")
def test_can_repair_hole(mesh):
triangles = mesh.cells_dict["triangle"]
triangles = triangles[:-1]
Volume(mesh.points, triangles, True)
# @pytest.mark.skipif(not trimesh, reason="Requires trimesh")
# def test_can_repair_hole(mesh):
# triangles = mesh.cells_dict["triangle"]
# triangles = triangles[:-1]
# Volume(mesh.points, triangles, True)


@pytest.mark.skipif(not trimesh, reason="Requires trimesh")
def test_can_repair_inversion(mesh):
triangles = mesh.cells_dict["triangle"]
triangles[-1] = triangles[-1, ::-1]
Volume(mesh.points, triangles, True)
# @pytest.mark.skipif(not trimesh, reason="Requires trimesh")
# def test_can_repair_inversion(mesh):
# triangles = mesh.cells_dict["triangle"]
# triangles[-1] = triangles[-1, ::-1]
# Volume(mesh.points, triangles, True)


@pytest.mark.skipif(not trimesh, reason="Requires trimesh")
def test_can_repair_inversions(mesh):
triangles = mesh.cells_dict["triangle"]
triangles = triangles[:, ::-1]
Volume(mesh.points, triangles, True)
# @pytest.mark.skipif(not trimesh, reason="Requires trimesh")
# def test_can_repair_inversions(mesh):
# triangles = mesh.cells_dict["triangle"]
# triangles = triangles[:, ::-1]
# Volume(mesh.points, triangles, True)


@pytest.mark.skipif(not trimesh, reason="Requires trimesh")
def test_inversions_repaired(simple_mesh):
center = [0.5, 0.5, 0.5]
# @pytest.mark.skipif(not trimesh, reason="Requires trimesh")
# def test_inversions_repaired(simple_mesh):
# center = [0.5, 0.5, 0.5]

orig_points = simple_mesh.points
orig_triangles = simple_mesh.cells_dict["triangle"]
assert center in Volume(orig_points, orig_triangles)
# orig_points = simple_mesh.points
# orig_triangles = simple_mesh.cells_dict["triangle"]
# assert center in Volume(orig_points, orig_triangles)

inv_triangles = orig_triangles[:, ::-1]
assert center not in Volume(orig_points, inv_triangles)
# inv_triangles = orig_triangles[:, ::-1]
# assert center not in Volume(orig_points, inv_triangles)

assert center in Volume(orig_points, inv_triangles, validate=True)
# assert center in Volume(orig_points, inv_triangles, validate=True)


def test_points(mesh):
Expand Down

0 comments on commit 9a09434

Please sign in to comment.