diff --git a/experiments/cf_asexual_evo.py b/experiments/cf_asexual_evo.py
index 091336de..f913d8dd 100644
--- a/experiments/cf_asexual_evo.py
+++ b/experiments/cf_asexual_evo.py
@@ -33,11 +33,13 @@
     SavedProfile,
 )
 from emevo.reward_fn import (
+    ExponentialReward,
     LinearReward,
     RewardFn,
     SigmoidReward,
     SigmoidReward_01,
     mutate_reward_fn,
+    serialize_weight,
 )
 from emevo.rl.ppo_normal import (
     NormalPPONet,
@@ -54,6 +56,7 @@ class RewardKind(str, enum.Enum):
     LINEAR = "linear"
+    EXPONENTIAL = "exponential"
     SIGMOID = "sigmoid"
     SIGMOID_01 = "sigmoid-01"
@@ -94,30 +97,22 @@ def extract_sigmoid(
         return jnp.concatenate((collision, act_input), axis=1), energy
 
 
-def slice_last(w: jax.Array, i: int) -> jax.Array:
-    return jnp.squeeze(jax.lax.slice_in_dim(w, i, i + 1, axis=-1))
-
-
-def linear_reward_serializer(w: jax.Array) -> dict[str, jax.Array]:
-    return {
-        "agent": slice_last(w, 0),
-        "food": slice_last(w, 1),
-        "wall": slice_last(w, 2),
-        "action": slice_last(w, 3),
-    }
+def exp_reward_serializer(w: jax.Array, scale: jax.Array) -> dict[str, jax.Array]:
+    w_dict = serialize_weight(w, ["w_agent", "w_food", "w_wall", "w_action"])
+    scale_dict = serialize_weight(
+        scale,
+        ["scale_agent", "scale_food", "scale_wall", "scale_action"],
+    )
+    return w_dict | scale_dict
 
 
 def sigmoid_reward_serializer(w: jax.Array, alpha: jax.Array) -> dict[str, jax.Array]:
-    return {
-        "w_agent": slice_last(w, 0),
-        "w_food": slice_last(w, 1),
-        "w_wall": slice_last(w, 2),
-        "w_action": slice_last(w, 3),
-        "alpha_agent": slice_last(alpha, 0),
-        "alpha_food": slice_last(alpha, 1),
-        "alpha_wall": slice_last(alpha, 2),
-        "alpha_action": slice_last(alpha, 3),
-    }
+    w_dict = serialize_weight(w, ["w_agent", "w_food", "w_wall", "w_action"])
+    alpha_dict = serialize_weight(
+        alpha,
+        ["alpha_agent", "alpha_food", "alpha_wall", "alpha_action"],
+    )
+    return w_dict | alpha_dict
 
 
 def exec_rollout(
@@ -459,7 +454,16 @@ def evolve(
         reward_fn_instance = LinearReward(
             **common_rewardfn_args,
             extractor=reward_extracor.extract_linear,
-            serializer=linear_reward_serializer,
+            serializer=lambda w: serialize_weight(
+                w,
+                ["agent", "food", "wall", "action"],
+            ),
+        )
+    elif reward_fn == RewardKind.EXPONENTIAL:
+        reward_fn_instance = ExponentialReward(
+            **common_rewardfn_args,
+            extractor=reward_extracor.extract_linear,
+            serializer=exp_reward_serializer,
         )
     elif reward_fn == RewardKind.SIGMOID:
         reward_fn_instance = SigmoidReward(
diff --git a/src/emevo/exp_utils.py b/src/emevo/exp_utils.py
index 43e99367..dd8d483c 100644
--- a/src/emevo/exp_utils.py
+++ b/src/emevo/exp_utils.py
@@ -325,16 +325,17 @@ def save_agents(
             modelpath = self.logdir.joinpath(f"trained-{uid}.eqx")
             eqx.tree_serialise_leaves(modelpath, sliced_net)
 
+    def save_profile_and_rewards(self) -> None:
+        profile_and_rewards = [
+            v.serialise() | dataclasses.asdict(self.profile_dict[k])
+            for k, v in self.reward_fn_dict.items()
+        ]
+        table = pa.Table.from_pylist(profile_and_rewards)
+        pq.write_table(table, self.logdir.joinpath("profile_and_rewards.parquet"))
+
     def finalize(self) -> None:
         if self.mode != LogMode.NONE:
-            profile_and_rewards = [
-                v.serialise() | dataclasses.asdict(self.profile_dict[k])
-                for k, v in self.reward_fn_dict.items()
-            ]
-            pq.write_table(
-                pa.Table.from_pylist(profile_and_rewards),
-                self.logdir.joinpath("profile_and_rewards.parquet"),
-            )
+            self.save_profile_and_rewards()
         if self.mode in [LogMode.FULL, LogMode.REWARD_AND_LOG]:
             self._save_log()
diff --git a/src/emevo/reward_fn.py b/src/emevo/reward_fn.py
index 4316ddb2..f70979da 100644
--- a/src/emevo/reward_fn.py
+++ b/src/emevo/reward_fn.py
@@ -36,6 +36,14 @@ def _item_or_np(array: jax.Array) -> float | NDArray:
     return np.array(array)
 
 
+def slice_last(w: jax.Array, i: int) -> jax.Array:
+    return jnp.squeeze(jax.lax.slice_in_dim(w, i, i + 1, axis=-1))
+
+
+def serialize_weight(w: jax.Array, keys: list[str]) -> dict[str, jax.Array]:
+    return {key: slice_last(w, i) for i, key in enumerate(keys)}
+
+
 class LinearReward(RewardFn):
     weight: jax.Array
     extractor: Callable[..., jax.Array]
@@ -63,10 +71,11 @@ def __call__(self, *args) -> jax.Array:
     def serialise(self) -> dict[str, float | NDArray]:
         return jax.tree_map(_item_or_np, self.serializer(self.weight))
 
+
 class ExponentialReward(RewardFn):
     weight: jax.Array
-    alpha: jax.Array
-    extractor: Callable[..., tuple[jax.Array, jax.Array]]
+    scale: jax.Array
+    extractor: Callable[..., jax.Array]
     serializer: Callable[[jax.Array, jax.Array], dict[str, jax.Array]]
 
     def __init__(
@@ -75,7 +84,7 @@ def __init__(
         self,
         *,
         key: chex.PRNGKey,
         n_agents: int,
         n_weights: int,
-        extractor: Callable[..., tuple[jax.Array, jax.Array]],
+        extractor: Callable[..., jax.Array],
         serializer: Callable[[jax.Array, jax.Array], dict[str, jax.Array]],
         std: float = 1.0,
         mean: float = 0.0,
@@ -88,7 +97,7 @@ def __call__(self, *args) -> jax.Array:
         extracted = self.extractor(*args)
-        weight = (10 ** self.scale) * self.weight
+        weight = (10**self.scale) * self.weight
         return jax.vmap(jnp.dot)(extracted, weight)
 
     def serialise(self) -> dict[str, float | NDArray]: