Test exponential rewards
kngwyu committed Jan 18, 2024
1 parent 73c9d81 commit 2712141
Showing 3 changed files with 48 additions and 34 deletions.
48 changes: 26 additions & 22 deletions experiments/cf_asexual_evo.py
@@ -33,11 +33,13 @@
SavedProfile,
)
from emevo.reward_fn import (
ExponentialReward,
LinearReward,
RewardFn,
SigmoidReward,
SigmoidReward_01,
mutate_reward_fn,
serialize_weight,
)
from emevo.rl.ppo_normal import (
NormalPPONet,
@@ -54,6 +56,7 @@

class RewardKind(str, enum.Enum):
LINEAR = "linear"
EXPONENTIAL = "exponential"
SIGMOID = "sigmoid"
SIGMOID_01 = "sigmoid-01"

@@ -94,30 +97,22 @@ def extract_sigmoid(
return jnp.concatenate((collision, act_input), axis=1), energy


def slice_last(w: jax.Array, i: int) -> jax.Array:
return jnp.squeeze(jax.lax.slice_in_dim(w, i, i + 1, axis=-1))


def linear_reward_serializer(w: jax.Array) -> dict[str, jax.Array]:
return {
"agent": slice_last(w, 0),
"food": slice_last(w, 1),
"wall": slice_last(w, 2),
"action": slice_last(w, 3),
}
def exp_reward_serializer(w: jax.Array, scale: jax.Array) -> dict[str, jax.Array]:
w_dict = serialize_weight(w, ["w_agent", "w_food", "w_wall", "w_action"])
scale_dict = serialize_weight(
scale,
["scale_agent", "scale_food", "scale_wall", "scale_action"],
)
return w_dict | scale_dict


def sigmoid_reward_serializer(w: jax.Array, alpha: jax.Array) -> dict[str, jax.Array]:
return {
"w_agent": slice_last(w, 0),
"w_food": slice_last(w, 1),
"w_wall": slice_last(w, 2),
"w_action": slice_last(w, 3),
"alpha_agent": slice_last(alpha, 0),
"alpha_food": slice_last(alpha, 1),
"alpha_wall": slice_last(alpha, 2),
"alpha_action": slice_last(alpha, 3),
}
w_dict = serialize_weight(w, ["w_agent", "w_food", "w_wall", "w_action"])
alpha_dict = serialize_weight(
alpha,
["alpha_agent", "alpha_food", "alpha_wall", "alpha_action"],
)
return w_dict | alpha_dict


def exec_rollout(
@@ -459,7 +454,16 @@ def evolve(
reward_fn_instance = LinearReward(
**common_rewardfn_args,
extractor=reward_extracor.extract_linear,
serializer=linear_reward_serializer,
serializer=lambda w: serialize_weight(
w,
["agent", "food", "wall", "action"],
),
)
elif reward_fn == RewardKind.EXPONENTIAL:
reward_fn_instance = ExponentialReward(
**common_rewardfn_args,
extractor=reward_extracor.extract_linear,
serializer=exp_reward_serializer,
)
elif reward_fn == RewardKind.SIGMOID:
reward_fn_instance = SigmoidReward(
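Note: the new exp_reward_serializer builds its output purely from serialize_weight plus a dict merge. Below is a minimal sketch of that behaviour with purely illustrative values and a single flat weight vector (in the experiment these arrays are learned per agent):

import jax.numpy as jnp

from emevo.reward_fn import serialize_weight

# Illustrative values only.
w = jnp.array([0.1, 0.2, 0.3, 0.4])
scale = jnp.array([1.0, 0.0, -1.0, 0.5])

w_dict = serialize_weight(w, ["w_agent", "w_food", "w_wall", "w_action"])
scale_dict = serialize_weight(
    scale,
    ["scale_agent", "scale_food", "scale_wall", "scale_action"],
)
# The two dicts are merged, yielding eight named entries such as
# {"w_agent": 0.1, ..., "scale_action": 0.5}.
print(w_dict | scale_dict)
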
17 changes: 9 additions & 8 deletions src/emevo/exp_utils.py
@@ -325,16 +325,17 @@ def save_agents(
modelpath = self.logdir.joinpath(f"trained-{uid}.eqx")
eqx.tree_serialise_leaves(modelpath, sliced_net)

def save_profile_and_rewards(self) -> None:
profile_and_rewards = [
v.serialise() | dataclasses.asdict(self.profile_dict[k])
for k, v in self.reward_fn_dict.items()
]
table = pa.Table.from_pylist(profile_and_rewards)
pq.write_table(table, self.logdir.joinpath("profile_and_rewards.parquet"))

def finalize(self) -> None:
if self.mode != LogMode.NONE:
profile_and_rewards = [
v.serialise() | dataclasses.asdict(self.profile_dict[k])
for k, v in self.reward_fn_dict.items()
]
pq.write_table(
pa.Table.from_pylist(profile_and_rewards),
self.logdir.joinpath("profile_and_rewards.parquet"),
)
self.save_profile_and_rewards()

if self.mode in [LogMode.FULL, LogMode.REWARD_AND_LOG]:
self._save_log()
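Note: the only behavioural change here is that the parquet-writing logic moves out of finalize into its own save_profile_and_rewards method; finalize still calls it when logging is enabled. Each row merges a reward function's serialise() output with the corresponding profile fields. A sketch of reading the resulting file back (path and column names are illustrative):

import pyarrow.parquet as pq

table = pq.read_table("logdir/profile_and_rewards.parquet")
# Columns are the serialized reward-fn entries (e.g. w_food, scale_food)
# plus the saved profile fields.
print(table.column_names)
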
17 changes: 13 additions & 4 deletions src/emevo/reward_fn.py
@@ -36,6 +36,14 @@ def _item_or_np(array: jax.Array) -> float | NDArray:
return np.array(array)


def slice_last(w: jax.Array, i: int) -> jax.Array:
return jnp.squeeze(jax.lax.slice_in_dim(w, i, i + 1, axis=-1))


def serialize_weight(w: jax.Array, keys: list[str]) -> dict[str, jax.Array]:
return {key: slice_last(w, i) for i, key in enumerate(keys)}


class LinearReward(RewardFn):
weight: jax.Array
extractor: Callable[..., jax.Array]
@@ -63,10 +71,11 @@ def __call__(self, *args) -> jax.Array:
def serialise(self) -> dict[str, float | NDArray]:
return jax.tree_map(_item_or_np, self.serializer(self.weight))


class ExponentialReward(RewardFn):
weight: jax.Array
alpha: jax.Array
extractor: Callable[..., tuple[jax.Array, jax.Array]]
scale: jax.Array
extractor: Callable[..., jax.Array]
serializer: Callable[[jax.Array, jax.Array], dict[str, jax.Array]]

def __init__(
@@ -75,7 +84,7 @@ def __init__(
key: chex.PRNGKey,
n_agents: int,
n_weights: int,
extractor: Callable[..., tuple[jax.Array, jax.Array]],
extractor: Callable[..., jax.Array],
serializer: Callable[[jax.Array, jax.Array], dict[str, jax.Array]],
std: float = 1.0,
mean: float = 0.0,
@@ -88,7 +97,7 @@

def __call__(self, *args) -> jax.Array:
extracted = self.extractor(*args)
weight = (10 ** self.scale) * self.weight
weight = (10**self.scale) * self.weight
return jax.vmap(jnp.dot)(extracted, weight)

def serialise(self) -> dict[str, float | NDArray]:
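Note: with this change ExponentialReward keeps a scale array alongside weight, reuses the linear extractor, and computes the reward as a per-agent dot product with exponentially scaled weights. A minimal sketch with assumed shapes (n_agents=2, n_weights=4) and illustrative values:

import jax
import jax.numpy as jnp

extracted = jnp.ones((2, 4))    # per-agent inputs from the extractor
weight = jnp.full((2, 4), 0.5)  # learned weights
scale = jnp.zeros((2, 4))       # learned exponents

scaled_weight = (10**scale) * weight  # 10**0 == 1, so weights are unchanged here
reward = jax.vmap(jnp.dot)(extracted, scaled_weight)
print(reward)                         # shape (2,): one reward per agent
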
