From 4bf9d68e4a61333e69b157d1b95189660c84e17c Mon Sep 17 00:00:00 2001 From: kngwyu Date: Thu, 21 Mar 2024 16:08:00 +0900 Subject: [PATCH] Compute reward from n_ate_food instead of observation --- experiments/cf_simple.py | 13 ++++++------- src/emevo/environments/circle_foraging.py | 10 +++++----- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/experiments/cf_simple.py b/experiments/cf_simple.py index ea4fd584..28ce6468 100644 --- a/experiments/cf_simple.py +++ b/experiments/cf_simple.py @@ -59,27 +59,26 @@ def __post_init__(self) -> None: def normalize_action(self, action: jax.Array) -> jax.Array: scaled = self.act_space.sigmoid_scale(action) - norm = jnp.sqrt(jnp.sum(scaled**2, axis=-1)) + norm = jnp.sqrt(jnp.sum(scaled**2, axis=-1, keepdims=True)) return norm / self._max_norm def extract_linear( self, - collision: jax.Array, + ate_food: jax.Array, action: jax.Array, energy: jax.Array, ) -> jax.Array: del energy act_input = self.act_coef * self.normalize_action(action) - food_collision = collision[:, 1] - return jnp.stack((food_collision, act_input), axis=1) + return jnp.concatenate((ate_food.astype(jnp.float32), act_input), axis=1) def extract_sigmoid( self, - collision: jax.Array, + ate_food: jax.Array, action: jax.Array, energy: jax.Array, ) -> tuple[jax.Array, jax.Array]: - return self.extract_linear(collision, action, energy), energy + return self.extract_linear(ate_food, action, energy), energy def serialize_weight(w: jax.Array) -> dict[str, jax.Array]: @@ -115,7 +114,7 @@ def step_rollout( ) obs_t1 = timestep.obs energy = state_t.status.energy - rewards = reward_fn(obs_t1.collision, actions, energy).reshape(-1, 1) + rewards = reward_fn(timestep.info["n_ate_food"], actions, energy).reshape(-1, 1) rollout = Rollout( observations=obs_t_array, actions=actions, diff --git a/src/emevo/environments/circle_foraging.py b/src/emevo/environments/circle_foraging.py index 40e5f77a..4dc7f0a4 100644 --- a/src/emevo/environments/circle_foraging.py +++ b/src/emevo/environments/circle_foraging.py @@ -78,7 +78,7 @@ def as_array(self) -> jax.Array: return jnp.concatenate( ( self.sensor.reshape(self.sensor.shape[0], -1), - self.collision.reshape(self.collision.shape[0], -1), + self.collision.reshape(self.collision.shape[0], -1).astype(jnp.float32), self.velocity, jnp.expand_dims(self.angle, axis=1), jnp.expand_dims(self.angular_velocity, axis=1), @@ -208,8 +208,8 @@ def _make_physics( for _ in range(n_max_foods): builder.add_circle( radius=food_radius, - friction=0.1, - elasticity=0.1, + friction=0.2, + elasticity=0.4, color=FOOD_COLOR, is_static=True, ) @@ -670,7 +670,7 @@ def place_newborn_neighbor( self.obs_space = NamedTupleSpace( CFObs, sensor=BoxSpace(low=0.0, high=1.0, shape=(n_agent_sensors, self._n_obj)), - collision=BoxSpace(low=0.0, high=1.0, shape=(self._n_obj,)), + collision=BoxSpace(low=0.0, high=1.0, shape=(self._n_obj, n_tactile_bins)), velocity=BoxSpace(low=-MAX_VELOCITY, high=MAX_VELOCITY, shape=(2,)), angle=BoxSpace(low=-2 * np.pi, high=2 * np.pi, shape=()), angular_velocity=BoxSpace(low=-np.pi / 10, high=np.pi / 10, shape=()), @@ -962,7 +962,7 @@ def reset(self, key: chex.PRNGKey) -> tuple[CFState, TimeStep[CFObs]]: sensor_obs = self._sensor_obs(stated=physics) obs = CFObs( sensor=sensor_obs.reshape(-1, self._n_sensors, self._n_obj), - collision=jnp.zeros((N, self._n_obj), dtype=bool), + collision=jnp.zeros((N, self._n_obj, self._n_tactile_bins), dtype=bool), angle=physics.circle.p.angle, velocity=physics.circle.v.xy, angular_velocity=physics.circle.v.angle,