From 5994b1de71d77ba3585daa9754bcf737934e4637 Mon Sep 17 00:00:00 2001 From: hnj21 Date: Mon, 29 Jan 2024 13:13:36 +0000 Subject: [PATCH 1/4] fix/ fix definition of pareto dominance to allow for cases where fitness values are same along one axis --- qdax/utils/pareto_front.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/qdax/utils/pareto_front.py b/qdax/utils/pareto_front.py index 692a2bde..3ece1f64 100644 --- a/qdax/utils/pareto_front.py +++ b/qdax/utils/pareto_front.py @@ -24,7 +24,15 @@ def compute_pareto_dominance( Return booleans when the vector is dominated by the batch. """ diff = jnp.subtract(batch_of_criteria, criteria_point) - return jnp.any(jnp.all(diff > 0, axis=-1)) + neutral_values = -jnp.ones_like(diff) + diff = jax.vmap(lambda x1, x2: jnp.where(mask, x1, x2), in_axes=(1, 1), out_axes=1)( + neutral_values, diff + ) + diff_greater_than_zero = jnp.any(diff > 0, axis=-1) + diff_geq_than_zero = jnp.all(diff >= 0, axis=-1) + + return jnp.any(jnp.logical_and(diff_greater_than_zero, diff_geq_than_zero)) + def compute_pareto_front(batch_of_criteria: jnp.ndarray) -> jnp.ndarray: From 8c71394dce3f4bbdf004209bf1bdf66e1cd40a5a Mon Sep 17 00:00:00 2001 From: hnj21 Date: Wed, 7 Feb 2024 11:35:42 +0000 Subject: [PATCH 2/4] fix: fix definition of pareto dominance in compute_pareto_dominance and compute_masked_pareto_dominance --- qdax/utils/pareto_front.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/qdax/utils/pareto_front.py b/qdax/utils/pareto_front.py index 3ece1f64..569c1b1a 100644 --- a/qdax/utils/pareto_front.py +++ b/qdax/utils/pareto_front.py @@ -24,10 +24,6 @@ def compute_pareto_dominance( Return booleans when the vector is dominated by the batch. """ diff = jnp.subtract(batch_of_criteria, criteria_point) - neutral_values = -jnp.ones_like(diff) - diff = jax.vmap(lambda x1, x2: jnp.where(mask, x1, x2), in_axes=(1, 1), out_axes=1)( - neutral_values, diff - ) diff_greater_than_zero = jnp.any(diff > 0, axis=-1) diff_geq_than_zero = jnp.all(diff >= 0, axis=-1) @@ -75,8 +71,10 @@ def compute_masked_pareto_dominance( diff = jax.vmap(lambda x1, x2: jnp.where(mask, x1, x2), in_axes=(1, 1), out_axes=1)( neutral_values, diff ) - return jnp.any(jnp.all(diff > 0, axis=-1)) + diff_greater_than_zero = jnp.any(diff > 0, axis=-1) + diff_geq_than_zero = jnp.all(diff >= 0, axis=-1) + return jnp.any(jnp.logical_and(diff_greater_than_zero, diff_geq_than_zero)) def compute_masked_pareto_front( batch_of_criteria: jnp.ndarray, mask: Mask From 572dc18152329363214275f65430c7203cb0432d Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Thu, 8 Feb 2024 11:16:06 +0000 Subject: [PATCH 3/4] Fix style (missing line) --- qdax/utils/pareto_front.py | 1 + 1 file changed, 1 insertion(+) diff --git a/qdax/utils/pareto_front.py b/qdax/utils/pareto_front.py index 569c1b1a..42211bd5 100644 --- a/qdax/utils/pareto_front.py +++ b/qdax/utils/pareto_front.py @@ -76,6 +76,7 @@ def compute_masked_pareto_dominance( return jnp.any(jnp.logical_and(diff_greater_than_zero, diff_geq_than_zero)) + def compute_masked_pareto_front( batch_of_criteria: jnp.ndarray, mask: Mask ) -> jnp.ndarray: From f1a2390b622a16b603bc8e21e9ad42657807f263 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Thu, 8 Feb 2024 11:17:01 +0000 Subject: [PATCH 4/4] Fix style (3 lines instead of 2) --- qdax/utils/pareto_front.py | 1 - 1 file changed, 1 deletion(-) diff --git a/qdax/utils/pareto_front.py b/qdax/utils/pareto_front.py index 42211bd5..f9bd77ae 100644 --- a/qdax/utils/pareto_front.py +++ b/qdax/utils/pareto_front.py @@ -30,7 +30,6 @@ def compute_pareto_dominance( return jnp.any(jnp.logical_and(diff_greater_than_zero, diff_geq_than_zero)) - def compute_pareto_front(batch_of_criteria: jnp.ndarray) -> jnp.ndarray: """Returns an array of boolean that states for each element if it is in the pareto front or not.