diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 2d75805d0..af29eb038 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -603,7 +603,7 @@ class _PositiveDefinite(_SingletonConstraint): def __call__(self, x): jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy # check for symmetric - symmetric = jnp.all(jnp.all(x == jnp.swapaxes(x, -2, -1), axis=-1), axis=-1) + symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1)) # check for the smallest eigenvalue is positive positive = jnp.linalg.eigh(x)[0][..., 0] > 0 return symmetric & positive