Skip to content

Commit

Permalink
Start implementing RewardFn
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Dec 20, 2023
1 parent 56f3b02 commit 1a27274
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 0 deletions.
65 changes: 65 additions & 0 deletions src/emevo/reward_fn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Example of using circle foraging environment"""
from __future__ import annotations

import abc
from typing import Callable, Protocol

import chex
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from numpy.typing import NDArray

from emevo import genetic_ops as gops


class RewardFn(abc.ABC, eqx.Module):
@abc.abstractmethod
def as_logdict(self) -> dict[str, float | NDArray]:
pass


class LinearReward(RewardFn):
weight: jax.Array
extractor: Callable[..., jax.Array]
serializer: Callable[[jax.Array], jax.Array]

def __init__(
self,
key: chex.PRNGKey,
n_agents: int,
extractor: Callable[..., jax.Array],
) -> None:
self.weight = jax.random.normal(key, (n_agents, 4))
self.extractor = extractor

def __call__(self, *args) -> jax.Array:
extracted = self.extractor(*args)
return jax.vmap(jnp.dot)(extracted, self.weight)

def as_logdict(self) -> dict[str, float | NDArray]:
return {""}


def mutate_reward_fn(
key: chex.PRNGKey,
reward_fn_dict: dict[int, eqx.Module],
old: eqx.Module,
mutation: gops.Mutation,
parents: jax.Array,
unique_id: jax.Array,
) -> eqx.Module:
# new[i] := old[i] if i not in parents
# new[i] := mutation(old[parents[i]]) if i in parents
is_parent = parents != -1
if not jnp.any(is_parent):
return old
dynamic_net, static_net = eqx.partition(old, eqx.is_array)
keys = jax.random.split(key, jnp.sum(is_parent).item())
for i, key in zip(jnp.nonzero(is_parent)[0], keys):
parent_reward_fn = reward_fn_dict[parents[i]]
mutated_dnet = mutation(key, parent_reward_fn)
reward_fn_dict[unique_id[i]] = eqx.combine(mutated_dnet, static_net)
dynamic_net = jax.tree_map(lambda arr: arr[i].set(mutated_dnet), dynamic_net)
return eqx.combine(dynamic_net, static_net)
18 changes: 18 additions & 0 deletions tests/test_reward_fn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import jax.numpy as jnp



from emevo.reward_fn import init_status


def test_status_clipping(n: int, capacity: float) -> None:
status = init_status(max_n=n, init_energy=0.0)
for _ in range(200):
status.update(energy_delta=jnp.ones(n), capacity=capacity)
assert jnp.all(status.energy >= 0.0)
assert jnp.all(status.energy <= capacity)

for _ in range(300):
status.update(energy_delta=jnp.ones(n) * -1.0, capacity=capacity)
assert jnp.all(status.energy >= 0.0)
assert jnp.all(status.energy <= capacity)

0 comments on commit 1a27274

Please sign in to comment.