From 8c71394dce3f4bbdf004209bf1bdf66e1cd40a5a Mon Sep 17 00:00:00 2001 From: hnj21 Date: Wed, 7 Feb 2024 11:35:42 +0000 Subject: [PATCH] fix: fix definition of pareto dominance in compute_pareto_dominance and compute_masked_pareto_dominance --- qdax/utils/pareto_front.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/qdax/utils/pareto_front.py b/qdax/utils/pareto_front.py index 3ece1f64..569c1b1a 100644 --- a/qdax/utils/pareto_front.py +++ b/qdax/utils/pareto_front.py @@ -24,10 +24,6 @@ def compute_pareto_dominance( Return booleans when the vector is dominated by the batch. """ diff = jnp.subtract(batch_of_criteria, criteria_point) - neutral_values = -jnp.ones_like(diff) - diff = jax.vmap(lambda x1, x2: jnp.where(mask, x1, x2), in_axes=(1, 1), out_axes=1)( - neutral_values, diff - ) diff_greater_than_zero = jnp.any(diff > 0, axis=-1) diff_geq_than_zero = jnp.all(diff >= 0, axis=-1) @@ -75,8 +71,10 @@ def compute_masked_pareto_dominance( diff = jax.vmap(lambda x1, x2: jnp.where(mask, x1, x2), in_axes=(1, 1), out_axes=1)( neutral_values, diff ) - return jnp.any(jnp.all(diff > 0, axis=-1)) + diff_greater_than_zero = jnp.any(diff > 0, axis=-1) + diff_geq_than_zero = jnp.all(diff >= 0, axis=-1) + return jnp.any(jnp.logical_and(diff_greater_than_zero, diff_geq_than_zero)) def compute_masked_pareto_front( batch_of_criteria: jnp.ndarray, mask: Mask