Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Selection-Variation Emitter #89

Draft
wants to merge 16 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
293 changes: 61 additions & 232 deletions qdax/core/emitters/mutation_operators.py
Original file line number Diff line number Diff line change
@@ -1,239 +1,68 @@
"""File defining mutation and crossover functions."""

from functools import partial
from typing import Optional, Tuple

import jax
import jax.numpy as jnp

from qdax.core.containers.repertoire import Repertoire
from qdax.core.emitters.emitter import Emitter, EmitterState
from qdax.core.emitters.selectors.abstract_selector import Selector
from qdax.core.emitters.selectors.uniform import UniformSelector
from qdax.core.emitters.variation_operators.abstract_variation import VariationOperator
from qdax.custom_types import Genotype, RNGKey


def _polynomial_mutation(
x: jnp.ndarray,
random_key: RNGKey,
proportion_to_mutate: float,
eta: float,
minval: float,
maxval: float,
) -> jnp.ndarray:
"""Base polynomial mutation for one genotype.

Proportion to mutate between 0 and 1
Assumed to be of shape (genotype_dim,)

Args:
x: parameters.
random_key: a random key
proportion_to_mutate: the proportion of the given parameters
that need to be mutated.
eta: the inverse of the power of the mutation applied.
minval: range of the perturbation applied by the mutation.
maxval: range of the perturbation applied by the mutation.

Returns:
New parameters.
"""

# Select positions to mutate
num_positions = x.shape[0]
positions = jnp.arange(start=0, stop=num_positions)
num_positions_to_mutate = int(proportion_to_mutate * num_positions)
random_key, subkey = jax.random.split(random_key)
selected_positions = jax.random.choice(
key=subkey, a=positions, shape=(num_positions_to_mutate,), replace=False
)

# New values
mutable_x = x[selected_positions]
delta_1 = (mutable_x - minval) / (maxval - minval)
delta_2 = (maxval - mutable_x) / (maxval - minval)
mutpow = 1.0 / (1.0 + eta)

# Randomly select where to put delta_1 and delta_2
random_key, subkey = jax.random.split(random_key)
rand = jax.random.uniform(
key=subkey,
shape=delta_1.shape,
minval=0,
maxval=1,
dtype=jnp.float32,
)

value1 = 2.0 * rand + (jnp.power(delta_1, 1.0 + eta) * (1.0 - 2.0 * rand))
value2 = 2.0 * (1 - rand) + 2.0 * (jnp.power(delta_2, 1.0 + eta) * (rand - 0.5))
value1 = jnp.power(value1, mutpow) - 1.0
value2 = 1.0 - jnp.power(value2, mutpow)

delta_q = jnp.zeros_like(mutable_x)
delta_q = jnp.where(rand < 0.5, value1, delta_q)
delta_q = jnp.where(rand >= 0.5, value2, delta_q)

# Mutate values
x = x.at[selected_positions].set(mutable_x + (delta_q * (maxval - minval)))

# Back in bounds if necessary (floating point issues)
x = jnp.clip(x, minval, maxval)

return x


def polynomial_mutation(
x: Genotype,
random_key: RNGKey,
proportion_to_mutate: float,
eta: float,
minval: float,
maxval: float,
) -> Tuple[Genotype, RNGKey]:
"""
Polynomial mutation over several genotypes

Parameters:
x: array of genotypes to transform (real values only)
random_key: RNG key for reproducibility.
Assumed to be of shape (batch_size, genotype_dim)
proportion_to_mutate (float): proportion of variables to mutate in
each genotype (must be in [0, 1]).
eta: scaling parameter, the larger the more spread the new
values will be.
minval: minimum value to clip the genotypes.
maxval: maximum value to clip the genotypes.

Returns:
New genotypes - same shape as input and a new RNG key
"""
random_key, subkey = jax.random.split(random_key)
batch_size = jax.tree_util.tree_leaves(x)[0].shape[0]
mutation_key = jax.random.split(subkey, num=batch_size)
mutation_fn = partial(
_polynomial_mutation,
proportion_to_mutate=proportion_to_mutate,
eta=eta,
minval=minval,
maxval=maxval,
)
mutation_fn = jax.vmap(mutation_fn)
x = jax.tree_util.tree_map(lambda x_: mutation_fn(x_, mutation_key), x)
return x, random_key


def _polynomial_crossover(
x1: jnp.ndarray,
x2: jnp.ndarray,
random_key: RNGKey,
proportion_var_to_change: float,
) -> jnp.ndarray:
"""
Base crossover for one pair of genotypes.

x1 and x2 should have the same shape
In this function we assume x1 shape and x2 shape to be (genotype_dim,)
"""
num_var_to_change = int(proportion_var_to_change * x1.shape[0])
indices = jnp.arange(start=0, stop=x1.shape[0])
selected_indices = jax.random.choice(
random_key, indices, shape=(num_var_to_change,)
)
x = x1.at[selected_indices].set(x2[selected_indices])
return x


def polynomial_crossover(
x1: Genotype,
x2: Genotype,
random_key: RNGKey,
proportion_var_to_change: float,
) -> Tuple[Genotype, RNGKey]:
"""
Crossover over a set of pairs of genotypes.

Batched version of _simple_crossover_function
x1 and x2 should have the same shape
In this function we assume x1 shape and x2 shape to be
(batch_size, genotype_dim)

Parameters:
x1: first batch of genotypes
x2: second batch of genotypes
random_key: RNG key for reproducibility
proportion_var_to_change: proportion of variables to exchange
between genotypes (must be [0, 1])

Returns:
New genotypes and a new RNG key
"""

random_key, subkey = jax.random.split(random_key)
batch_size = jax.tree_util.tree_leaves(x2)[0].shape[0]
crossover_keys = jax.random.split(subkey, num=batch_size)
crossover_fn = partial(
_polynomial_crossover,
proportion_var_to_change=proportion_var_to_change,
)
crossover_fn = jax.vmap(crossover_fn)
# TODO: check that key usage is correct
x = jax.tree_util.tree_map(
lambda x1_, x2_: crossover_fn(x1_, x2_, crossover_keys), x1, x2
)
return x, random_key


def isoline_variation(
x1: Genotype,
x2: Genotype,
random_key: RNGKey,
iso_sigma: float,
line_sigma: float,
minval: Optional[float] = None,
maxval: Optional[float] = None,
) -> Tuple[Genotype, RNGKey]:
"""
Iso+Line-DD Variation Operator [1] over a set of pairs of genotypes

Parameters:
x1 (Genotypes): first batch of genotypes
x2 (Genotypes): second batch of genotypes
random_key (RNGKey): RNG key for reproducibility
iso_sigma (float): spread parameter (noise)
line_sigma (float): line parameter (direction of the new genotype)
minval (float, Optional): minimum value to clip the genotypes
maxval (float, Optional): maximum value to clip the genotypes

Returns:
x (Genotypes): new genotypes
random_key (RNGKey): new RNG key

[1] Vassiliades, Vassilis, and Jean-Baptiste Mouret. "Discovering the elite
hypervolume by leveraging interspecies correlation." Proceedings of the Genetic and
Evolutionary Computation Conference. 2018.
"""

# Computing line_noise
random_key, key_line_noise = jax.random.split(random_key)
batch_size = jax.tree_util.tree_leaves(x1)[0].shape[0]
line_noise = jax.random.normal(key_line_noise, shape=(batch_size,)) * line_sigma

def _variation_fn(
x1: jnp.ndarray, x2: jnp.ndarray, random_key: RNGKey
) -> jnp.ndarray:
iso_noise = jax.random.normal(random_key, shape=x1.shape) * iso_sigma
x = (x1 + iso_noise) + jax.vmap(jnp.multiply)((x2 - x1), line_noise)

# Back in bounds if necessary (floating point issues)
if (minval is not None) or (maxval is not None):
x = jnp.clip(x, minval, maxval)
return x

# create a tree with random keys
nb_leaves = len(jax.tree_util.tree_leaves(x1))
random_key, subkey = jax.random.split(random_key)
subkeys = jax.random.split(subkey, num=nb_leaves)
keys_tree = jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(x1), subkeys)

# apply isolinedd to each branch of the tree
x = jax.tree_util.tree_map(
lambda y1, y2, key: _variation_fn(y1, y2, key), x1, x2, keys_tree
)

return x, random_key
class SelectionVariationEmitter(Emitter):
def __init__(
self,
batch_size: int,
variation_operator: VariationOperator,
selector: Optional[Selector] = None,
):
"""
Emitter that selects a batch of genotypes from the repertoire and applies
a variation operator to them.

Args:
batch_size: number of genotypes to select from the repertoire
variation_operator: operator to apply to the selected genotypes
selector: selector to use to select the genotypes. Defaults to
UniformSelector.
"""
self._batch_size = batch_size
self._variation_operator = variation_operator

if selector is not None:
self._selector = selector
else:
self._selector = UniformSelector()

def emit(
self,
repertoire: Optional[Repertoire],
emitter_state: Optional[EmitterState],
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""
Select a batch of genotypes from the repertoire and apply a variation
operator to them.

Args:
repertoire: repertoire to select genotypes from
emitter_state: state of the emitter
random_key: random key to handle stochasticity

Returns:
The new genotypes and the updated random key
"""

number_parents_to_select = (
self._variation_operator.calculate_number_parents_to_select(
self._batch_size
)
)
genotypes, emitter_state, random_key = self._selector.select(
number_parents_to_select, repertoire, emitter_state, random_key
)
new_genotypes, random_key = self._variation_operator.apply_with_clip(
genotypes, emitter_state, random_key
)
return new_genotypes, random_key
Empty file.
17 changes: 17 additions & 0 deletions qdax/core/emitters/selectors/abstract_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import abc
from typing import Tuple

from qdax.core.containers.repertoire import Repertoire
from qdax.core.emitters.emitter import EmitterState
from qdax.custom_types import Genotype, RNGKey


class Selector(metaclass=abc.ABCMeta):
@abc.abstractmethod
def select(
self,
number_parents_to_select: int,
repertoire: Repertoire,
emitter_state: EmitterState,
random_key: RNGKey,
) -> Tuple[Genotype, EmitterState, RNGKey]: ...
79 changes: 79 additions & 0 deletions qdax/core/emitters/selectors/novelty.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from typing import Tuple

import jax
import jax.numpy as jnp

from qdax.core.containers.repertoire import Repertoire
from qdax.core.emitters.emitter import EmitterState
from qdax.core.emitters.selectors.abstract_selector import Selector
from qdax.custom_types import Genotype, RNGKey


class NoveltySelector(Selector):
def __init__(self, num_nn: int):
self._num_nn = num_nn

def select(
self,
number_parents_to_select: int,
repertoire: Repertoire,
emitter_state: EmitterState,
random_key: RNGKey,
) -> Tuple[Genotype, EmitterState, RNGKey]:
"""
Novelty-based selection of parents
"""

repertoire_genotypes = repertoire.genotypes
fitnesses = repertoire.fitnesses
descriptors = repertoire.descriptors

num_genotypes = descriptors.shape[0]
repertoire_empty = fitnesses == -jnp.inf

# calculate novelty score using the k-nearest-neighbors
v_dist = jax.vmap(lambda x, y: jnp.linalg.norm(x - y), in_axes=(0, None))
vv_dist = jax.vmap(v_dist, in_axes=(None, 0))

# Matrix of distances between all genotypes
distances = vv_dist(descriptors, descriptors)

inf_mask = jnp.logical_or(
jnp.tile(repertoire_empty.reshape(1, -1), (num_genotypes, 1)),
jnp.tile(repertoire_empty.reshape(1, -1), (num_genotypes, 1)).T,
)
distances = jnp.where(inf_mask, +jnp.inf, distances)
distances = jnp.where(
jnp.eye(num_genotypes) == 1, 0, distances
) # set diagonal to 0

# Calculate novelty scores
closest_distances, _ = jax.vmap(jax.lax.top_k, in_axes=(0, None))(
distances, self._num_nn + 1
)
closest_distances = jnp.where(
jnp.isinf(closest_distances), 0, closest_distances
)
novelty_scores = jax.vmap(lambda x: jnp.sum(x) / self._num_nn)(
closest_distances
)

nonempty_novelty_scores = novelty_scores[~repertoire_empty]
novelty_scores = jnp.where(
repertoire_empty, jnp.min(nonempty_novelty_scores), novelty_scores
)

# probability of selecting each genotype
p = novelty_scores - jnp.min(novelty_scores)
p = p / jnp.sum()

# select parents
random_key, subkey = jax.random.split(random_key)
selected_parents = jax.tree_util.tree_map(
lambda x: jax.random.choice(
subkey, x, shape=(number_parents_to_select,), p=p
),
repertoire_genotypes,
)

return selected_parents, emitter_state, random_key
Loading
Loading