Skip to content

Commit

Permalink
Delayed sigmoid sinh
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Feb 29, 2024
1 parent 0930851 commit e994450
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 4 deletions.
29 changes: 29 additions & 0 deletions experiments/cf_asexual_evo.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class RewardKind(str, enum.Enum):
LINEAR = "linear"
EXPONENTIAL = "exponential"
OFFSET_DELAYED_SE = "offset-delayed-se"
OFFSET_DELAYED_SINH = "offset-delayed-sinh"
SIGMOID = "sigmoid"
SIGMOID_01 = "sigmoid-01"
SIGMOID_EXP = "sigmoid-exp"
Expand Down Expand Up @@ -122,6 +123,15 @@ def sigmoid_rs(w: jax.Array, alpha: jax.Array) -> dict[str, jax.Array]:
return w_dict | alpha_dict


def delayed_sigmoid_rs(w: jax.Array, delay: jax.Array) -> dict[str, jax.Array]:
w_dict = rfn.serialize_weight(w, ["w_agent", "w_food", "w_wall", "w_action"])
delay_dict = rfn.serialize_weight(
delay,
["delay_agent", "delay_food", "delay_wall", "delay_action"],
)
return w_dict | delay_dict


def sigmoid_exp_rs(
w: jax.Array,
scale: jax.Array,
Expand Down Expand Up @@ -183,6 +193,17 @@ def sigmoid_rs_withp(w: jax.Array, alpha: jax.Array) -> dict[str, jax.Array]:
return w_dict | alpha_dict


def delayed_sigmoid_rs_withp(w: jax.Array, delay: jax.Array) -> dict[str, jax.Array]:
w_dict = rfn.serialize_weight(
w, ["w_agent", "w_food", "w_poison", "w_wall", "w_action"]
)
delay_dict = rfn.serialize_weight(
delay,
["delay_agent", "delay_food", "w_poison", "delay_wall", "delay_action"],
)
return w_dict | delay_dict


def sigmoid_exp_rs_withp(
w: jax.Array, scale: jax.Array, alpha: jax.Array
) -> dict[str, jax.Array]:
Expand Down Expand Up @@ -610,6 +631,14 @@ def evolve(
extractor=reward_extracor.extract_linear,
serializer=linear_rs_withp if poison_reward else linear_rs,
)
elif reward_fn == RewardKind.OFFSET_DELAYED_SINH:
reward_fn_instance = rfn.OffsetDelayedSinhReward(
**common_rewardfn_args,
extractor=reward_extracor.extract_sigmoid,
serializer=delayed_sigmoid_rs_withp
if poison_reward
else delayed_sigmoid_rs,
)
else:
raise ValueError(f"Invalid reward_fn {reward_fn}")

Expand Down
52 changes: 48 additions & 4 deletions src/emevo/reward_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,10 @@ def __init__(
std: float = 1.0,
mean: float = 0.0,
) -> None:
k1, k2 = jax.random.split(key)
k1, k2, k3 = jax.random.split(key, 3)
self.weight = jax.random.normal(k1, (n_agents, n_weights)) * std + mean
self.scale = jax.random.normal(k2, (n_agents, n_weights)) * std + mean
self.alpha = jax.random.normal(k2, (n_agents, n_weights)) * std + mean
self.alpha = jax.random.normal(k3, (n_agents, n_weights)) * std + mean
self.extractor = extractor
self.serializer = serializer

Expand Down Expand Up @@ -242,10 +242,10 @@ def __init__(
mean: float = 0.0,
delay_scale: float = 20.0,
) -> None:
k1, k2 = jax.random.split(key)
k1, k2, k3 = jax.random.split(key, 3)
self.weight = jax.random.normal(k1, (n_agents, n_weights)) * std + mean
self.scale = jax.random.normal(k2, (n_agents, n_weights)) * std + mean
self.delay = jax.random.normal(k2, (n_agents, n_weights)) * std + mean
self.delay = jax.random.normal(k3, (n_agents, n_weights)) * std + mean
self.extractor = extractor
self.serializer = serializer
self.delay_scale = delay_scale
Expand Down Expand Up @@ -279,6 +279,50 @@ def __call__(self, *args) -> jax.Array:
return jax.vmap(jnp.dot)(filtered, weight)


class OffsetDelayedSinhReward(RewardFn):
weight: jax.Array
delay: jax.Array
extractor: Callable[..., tuple[jax.Array, jax.Array]]
serializer: Callable[[jax.Array, jax.Array], dict[str, jax.Array]]
delay_scale: float
scale: float

def __init__(
self,
*, # order of arguments are a bit confusing here...
key: chex.PRNGKey,
n_agents: int,
n_weights: int,
extractor: Callable[..., tuple[jax.Array, jax.Array]],
serializer: Callable[[jax.Array, jax.Array], dict[str, jax.Array]],
std: float = 1.0,
mean: float = 0.0,
scale: float = 2.5,
delay_scale: float = 20.0,
) -> None:
k1, k2 = jax.random.split(key)
self.weight = jax.random.normal(k1, (n_agents, n_weights)) * std + mean
self.delay = jax.random.normal(k2, (n_agents, n_weights)) * std + mean
self.extractor = extractor
self.serializer = serializer
self.scale = scale
self.delay_scale = delay_scale

def __call__(self, *args) -> jax.Array:
extracted = self.extractor(*args)
extracted, energy = self.extractor(*args)
weight = jnp.sinh(self.weight * self.scale)
e = energy.reshape(-1, 1) # (N, n_weights)
exp_pos = jnp.exp(-e + self.delay_scale * self.delay)
exp_neg = jnp.exp(e - self.delay_scale * (1.0 + self.delay) - self.delay_scale)
exp = jnp.where(self.delay > 0, exp_pos, exp_neg)
filtered = extracted / (1.0 + exp)
return jax.vmap(jnp.dot)(filtered, weight)

def serialise(self) -> dict[str, float | NDArray]:
return jax.tree_map(_item_or_np, self.serializer(self.weight, self.delay))


def mutate_reward_fn(
key: chex.PRNGKey,
reward_fn_dict: dict[int, RF],
Expand Down

0 comments on commit e994450

Please sign in to comment.