Skip to content

Commit

Permalink
Compute reward from n_ate_food instead of observation
Browse files (browse the repository at this point in the history)
  • Loading branch information
kngwyu committed Mar 21, 2024
1 parent 02a7d43 commit 4bf9d68
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
13 changes: 6 additions & 7 deletions experiments/cf_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,27 +59,26 @@ def __post_init__(self) -> None:

# NOTE(review): this is a flattened diff hunk (no +/- markers, indentation
# stripped). The two `norm = ...` lines below appear to be the pre- and
# post-commit versions of a single statement, not two consecutive assignments.
def normalize_action(self, action: jax.Array) -> jax.Array:
# Map the action through the action space's sigmoid_scale helper.
scaled = self.act_space.sigmoid_scale(action)
# Presumably the removed line: Euclidean norm over the last axis; the
# reduced axis is dropped from the result.
norm = jnp.sqrt(jnp.sum(scaled**2, axis=-1))
# Presumably the added line: same norm, but keepdims=True keeps the reduced
# axis with length 1, so the result has the same rank as `scaled`.
norm = jnp.sqrt(jnp.sum(scaled**2, axis=-1, keepdims=True))
# Divide by self._max_norm (assigned outside this view).
return norm / self._max_norm

# NOTE(review): this is a flattened diff hunk (no +/- markers, indentation
# stripped). `collision` / `ate_food` and the two `return` statements appear to
# be the pre- and post-commit variants of the same parameter and return value;
# they are not meant to coexist in one function.
def extract_linear(
self,
# Presumably the pre-commit first argument (removed by this commit).
collision: jax.Array,
# Presumably the post-commit replacement for `collision`.
ate_food: jax.Array,
action: jax.Array,
energy: jax.Array,
) -> jax.Array:
# `energy` is accepted but deliberately discarded in this function.
del energy
# Action term: normalize_action(action) weighted by self.act_coef.
act_input = self.act_coef * self.normalize_action(action)
# Presumably pre-commit: take column 1 of `collision` and stack it with
# act_input along a new axis 1.
food_collision = collision[:, 1]
return jnp.stack((food_collision, act_input), axis=1)
# Presumably post-commit: cast ate_food to float32 and join it with act_input
# along the existing axis 1 (assumes both inputs are 2-D — TODO confirm).
return jnp.concatenate((ate_food.astype(jnp.float32), act_input), axis=1)

# NOTE(review): this is a flattened diff hunk (no +/- markers, indentation
# stripped). `collision` / `ate_food` and the two `return` statements appear to
# be the pre- and post-commit variants of the same parameter and return value.
def extract_sigmoid(
self,
# Presumably the pre-commit first argument (removed by this commit).
collision: jax.Array,
# Presumably the post-commit replacement for `collision`.
ate_food: jax.Array,
action: jax.Array,
energy: jax.Array,
) -> tuple[jax.Array, jax.Array]:
# Both variants return a pair: the extract_linear(...) result and `energy`
# passed through unchanged.
# Presumably pre-commit: forwards `collision`.
return self.extract_linear(collision, action, energy), energy
# Presumably post-commit: forwards `ate_food`.
return self.extract_linear(ate_food, action, energy), energy


def serialize_weight(w: jax.Array) -> dict[str, jax.Array]:
Expand Down Expand Up @@ -115,7 +114,7 @@ def step_rollout(
)
obs_t1 = timestep.obs
energy = state_t.status.energy
rewards = reward_fn(obs_t1.collision, actions, energy).reshape(-1, 1)
rewards = reward_fn(timestep.info["n_ate_food"], actions, energy).reshape(-1, 1)
rollout = Rollout(
observations=obs_t_array,
actions=actions,
Expand Down
10 changes: 5 additions & 5 deletions src/emevo/environments/circle_foraging.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def as_array(self) -> jax.Array:
return jnp.concatenate(
(
self.sensor.reshape(self.sensor.shape[0], -1),
self.collision.reshape(self.collision.shape[0], -1),
self.collision.reshape(self.collision.shape[0], -1).astype(jnp.float32),
self.velocity,
jnp.expand_dims(self.angle, axis=1),
jnp.expand_dims(self.angular_velocity, axis=1),
Expand Down Expand Up @@ -208,8 +208,8 @@ def _make_physics(
for _ in range(n_max_foods):
builder.add_circle(
radius=food_radius,
friction=0.1,
elasticity=0.1,
friction=0.2,
elasticity=0.4,
color=FOOD_COLOR,
is_static=True,
)
Expand Down Expand Up @@ -670,7 +670,7 @@ def place_newborn_neighbor(
self.obs_space = NamedTupleSpace(
CFObs,
sensor=BoxSpace(low=0.0, high=1.0, shape=(n_agent_sensors, self._n_obj)),
collision=BoxSpace(low=0.0, high=1.0, shape=(self._n_obj,)),
collision=BoxSpace(low=0.0, high=1.0, shape=(self._n_obj, n_tactile_bins)),
velocity=BoxSpace(low=-MAX_VELOCITY, high=MAX_VELOCITY, shape=(2,)),
angle=BoxSpace(low=-2 * np.pi, high=2 * np.pi, shape=()),
angular_velocity=BoxSpace(low=-np.pi / 10, high=np.pi / 10, shape=()),
Expand Down Expand Up @@ -962,7 +962,7 @@ def reset(self, key: chex.PRNGKey) -> tuple[CFState, TimeStep[CFObs]]:
sensor_obs = self._sensor_obs(stated=physics)
obs = CFObs(
sensor=sensor_obs.reshape(-1, self._n_sensors, self._n_obj),
collision=jnp.zeros((N, self._n_obj), dtype=bool),
collision=jnp.zeros((N, self._n_obj, self._n_tactile_bins), dtype=bool),
angle=physics.circle.p.angle,
velocity=physics.circle.v.xy,
angular_velocity=physics.circle.v.angle,
Expand Down

0 comments on commit 4bf9d68

Please sign in to comment.