diff --git a/qdax/utils/pareto_front.py b/qdax/utils/pareto_front.py index 692a2bde..3ece1f64 100644 --- a/qdax/utils/pareto_front.py +++ b/qdax/utils/pareto_front.py @@ -24,7 +24,15 @@ def compute_pareto_dominance( Return booleans when the vector is dominated by the batch. """ diff = jnp.subtract(batch_of_criteria, criteria_point) - return jnp.any(jnp.all(diff > 0, axis=-1)) + neutral_values = -jnp.ones_like(diff) + diff = jax.vmap(lambda x1, x2: jnp.where(mask, x1, x2), in_axes=(1, 1), out_axes=1)( + neutral_values, diff + ) + diff_greater_than_zero = jnp.any(diff > 0, axis=-1) + diff_geq_than_zero = jnp.all(diff >= 0, axis=-1) + + return jnp.any(jnp.logical_and(diff_greater_than_zero, diff_geq_than_zero)) + def compute_pareto_front(batch_of_criteria: jnp.ndarray) -> jnp.ndarray: