diff --git a/src/emevo/environments/circle_foraging_with_neurotoxin.py b/src/emevo/environments/circle_foraging_with_neurotoxin.py index 0265236..d122cb9 100644 --- a/src/emevo/environments/circle_foraging_with_neurotoxin.py +++ b/src/emevo/environments/circle_foraging_with_neurotoxin.py @@ -69,14 +69,15 @@ def step( # type: ignore ) toxin_decay = jnp.expand_dims(1.0 - toxin_decay_rate, axis=1) # Add force - act = jax.vmap(self.act_space.clip)(jnp.array(action)) * toxin_decay + act = jax.vmap(self.act_space.clip)(jnp.array(action)) f1_raw = jax.lax.slice_in_dim(act, 0, 1, axis=-1) f2_raw = jax.lax.slice_in_dim(act, 1, 2, axis=-1) f1 = jnp.concatenate((jnp.zeros_like(f1_raw), f1_raw), axis=1) f2 = jnp.concatenate((jnp.zeros_like(f2_raw), f2_raw), axis=1) circle = state.physics.circle - circle = circle.apply_force_local(self._act_p1, f1) - circle = circle.apply_force_local(self._act_p2, f2) + # Decay force by toxin + circle = circle.apply_force_local(self._act_p1, f1 * toxin_decay) + circle = circle.apply_force_local(self._act_p2, f2 * toxin_decay) stated = replace(state.physics, circle=circle) # Step physics simulator stated, solver, nstep_contacts = nstep(