Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/jax' into jax
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Feb 28, 2024
2 parents 0cd19f2 + 6eb05a5 commit 0930851
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 19 deletions.
2 changes: 1 addition & 1 deletion config/env/20240226-6s-lvp.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ n_max_foods = 60
food_num_fn = ["logistic", 30, 0.01, 60]
food_loc_fn = [
"scheduled",
102400,
1024000,
["gaussian", [360.0, 240.0], [80.0, 60.0]],
["switching",
1000,
Expand Down
35 changes: 35 additions & 0 deletions config/env/20240227-ls-season.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Environment config for a square arena with seasonally switching food patches.
# NOTE(review): field semantics below are inferred from names and from the
# sibling config in this commit — confirm against the environment loader.
n_initial_agents = 50
n_max_agents = 150
n_max_foods = 60
# Logistic food regrowth; presumably [initial count, growth rate, capacity] — verify.
food_num_fn = ["logistic", 30, 0.01, 60]
# Food location cycles through Gaussian patches, presumably switching every
# 1000 steps; each patch is ["gaussian", [mean_x, mean_y], [std_x, std_y]].
# The patch sweeps left (x: 360 -> 240 -> 120) and back to 240 before repeating.
food_loc_fn = ["switching",
1000,
["gaussian", [360.0, 270.0], [48.0, 36.0]],
["gaussian", [240.0, 270.0], [48.0, 36.0]],
["gaussian", [120.0, 270.0], [48.0, 36.0]],
["gaussian", [240.0, 270.0], [48.0, 36.0]],
]
agent_loc_fn = "uniform"
# Arena bounds (world units).
xlim = [0.0, 480.0]
ylim = [0.0, 360.0]
env_shape = "square"
neighbor_stddev = 100.0
# Agent sensing: 24 rays of length 200 with a wide angular range.
n_agent_sensors = 24
sensor_length = 200.0
sensor_range = "wide"
agent_radius = 10.0
food_radius = 4.0
# Physics step and damping parameters.
dt = 0.1
linear_damping = 0.8
angular_damping = 0.6
max_force = 80.0
min_force = -20.0
# Energy model: agents start at 40, cap at 400; consumption per force/step below.
init_energy = 40.0
energy_capacity = 400.0
force_energy_consumption = 2e-5
basic_energy_consumption = 2e-4
energy_share_ratio = 0.4
# Physics solver iteration counts.
n_velocity_iter = 6
n_position_iter = 2
n_physics_iter = 5
max_place_attempts = 10
7 changes: 7 additions & 0 deletions experiments/cf_asexual_evo.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class RewardKind(str, enum.Enum):
DELAYED_SE = "delayed-se"
LINEAR = "linear"
EXPONENTIAL = "exponential"
OFFSET_DELAYED_SE = "offset-delayed-se"
SIGMOID = "sigmoid"
SIGMOID_01 = "sigmoid-01"
SIGMOID_EXP = "sigmoid-exp"
Expand Down Expand Up @@ -597,6 +598,12 @@ def evolve(
extractor=reward_extracor.extract_sigmoid,
serializer=delayed_se_rs_withp if poison_reward else delayed_se_rs,
)
elif reward_fn == RewardKind.OFFSET_DELAYED_SE:
reward_fn_instance = rfn.OffsetDelayedSEReward(
**common_rewardfn_args,
extractor=reward_extracor.extract_sigmoid,
serializer=delayed_se_rs_withp if poison_reward else delayed_se_rs,
)
elif reward_fn == RewardKind.SINH:
reward_fn_instance = rfn.SinhReward(
**common_rewardfn_args,
Expand Down
42 changes: 24 additions & 18 deletions notebooks/reward_fn.ipynb

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions src/emevo/reward_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,18 @@ def serialise(self) -> dict[str, float | NDArray]:
)


class OffsetDelayedSEReward(DelayedSEReward):
    """Delayed sigmoid-extraction reward with an extra ``delay_scale`` offset.

    Identical to ``DelayedSEReward`` except that the negative-delay branch of
    the gate is shifted by one additional ``delay_scale`` term. The extracted
    features are passed through an energy-dependent sigmoid gate and then
    combined with the (power-of-ten scaled) weights via a per-row dot product.
    """

    def __call__(self, *args) -> jax.Array:
        features, energy = self.extractor(*args)
        # Scale learned weights by a power-of-ten factor.
        scaled_w = (10**self.scale) * self.weight
        # One column per agent so the gate broadcasts across all weight columns.
        energy_col = energy.reshape(-1, 1)
        # Two candidate gate exponents; which applies depends on the delay sign.
        gate_if_pos = jnp.exp(self.delay_scale * self.delay - energy_col)
        gate_if_neg = jnp.exp(
            energy_col - self.delay_scale * (1.0 + self.delay) - self.delay_scale
        )
        gate = jnp.where(self.delay > 0, gate_if_pos, gate_if_neg)
        # Logistic gate: features / (1 + exp(...)).
        gated = features / (1.0 + gate)
        # Row-wise dot product between gated features and scaled weights.
        return jax.vmap(jnp.dot)(gated, scaled_w)


def mutate_reward_fn(
key: chex.PRNGKey,
reward_fn_dict: dict[int, RF],
Expand Down

0 comments on commit 0930851

Please sign in to comment.