Skip to content

Commit

Permalink
Fix get_relative_angle ad modify test_observe
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Mar 20, 2024
1 parent a5e60cb commit 02a7d43
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 14 deletions.
8 changes: 4 additions & 4 deletions src/emevo/environments/circle_foraging.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,12 +332,12 @@ def _get_tactile(
collision_mat: jax.Array,
) -> tuple[jax.Array, jax.Array]:
nm_shape = collision_mat.shape
rel_angle = get_relative_angle(s1, s2) # [0, 2π]
rel_angle = get_relative_angle(s1, s2) # [0, 2π] (N, M)
weights = (jnp.pi * 2 / n_bins) * jnp.arange(n_bins + 1) # [0, ..., 2π]
in_range = _search_bin(rel_angle.ravel(), weights).reshape(*nm_shape, n_bins)
tactile_raw = in_range * jnp.expand_dims(collision_mat, axis=2)
tactile_raw = in_range * jnp.expand_dims(collision_mat, axis=2) # (N, M, B)
tactile = jnp.sum(tactile_raw, axis=1, keepdims=True) # (N, 1, B)
return tactile, tactile_raw
return tactile, jnp.expand_dims(tactile_raw, axis=2) # (N, M, 1, B)


def _food_tactile_with_labels(
Expand Down Expand Up @@ -804,7 +804,7 @@ def step(
seg2c.transpose(),
)
collision = jnp.concatenate(
(food_tactile > 0, ag_tactile > 0, wall_tactile > 0),
(ag_tactile > 0, food_tactile > 0, wall_tactile > 0),
axis=1,
)
# Gather sensor obs
Expand Down
9 changes: 6 additions & 3 deletions src/emevo/environments/phyjax2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,12 @@ def batch_size(self) -> int:


def get_relative_angle(s_a: State, s_b: State) -> jax.Array:
a2b_x, a2b_y = _get_xy(s_b.p.xy - s_a.p.xy)
a2b_angle = jnp.arctan2(a2b_y, a2b_x)
return (a2b_angle - s_a.p.angle + 2.0 * TWO_PI) % TWO_PI
a2b = jax.vmap(jnp.subtract, in_axes=(None, 0))(s_b.p.xy, s_a.p.xy)
a2b_x, a2b_y = _get_xy(a2b)
a2b_angle = jnp.arctan2(a2b_y, a2b_x) # (N_A, N_B)
a_angle = jnp.expand_dims(s_a.p.angle, axis=1)
# Subtract 0.5𝛑 because our angle starts from 0.5𝛑 (90 degree)
return (a2b_angle - a_angle + TWO_PI * 3 - jnp.pi * 0.5) % TWO_PI


@chex.dataclass
Expand Down
14 changes: 7 additions & 7 deletions tests/test_observe.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,13 +238,13 @@ def test_encount_and_collision(key: chex.PRNGKey) -> None:
if not p2p4_ok and jnp.linalg.norm(p2 - p4) <= 2 * AGENT_RADIUS:
assert bool(ts.encount[2, 4]), (p2, p3, p4)
assert bool(ts.encount[4, 2]), (p2, p3, p4)
assert bool(ts.obs.collision[2, 0]), (p2, p3, p4)
assert bool(ts.obs.collision[4, 0]), (p2, p3, p4)
assert bool(ts.obs.collision[2, 0, -1]), (p2, p3, p4)
assert bool(ts.obs.collision[4, 0, 0]), (p2, p3, p4)
p2p4_ok = True

p3_to_food = jnp.linalg.norm(p3 - jnp.array([80.0, 90.0]))
if not p3_ok and p3_to_food <= AGENT_RADIUS + FOOD_RADIUS:
assert bool(ts.obs.collision[3, 1]), (p2, p3, p4)
assert bool(ts.obs.collision[3, 1, 0]), (p2, p3, p4)
p3_ok = True

if p2p4_ok and p3_ok:
Expand Down Expand Up @@ -281,12 +281,12 @@ def test_collision_with_foodlabels(key: chex.PRNGKey) -> None:
if jnp.any(ts.obs.collision):
assert jnp.all(to_food <= AGENT_RADIUS + FOOD_RADIUS + 0.1)
chex.assert_trees_all_close(
ts.obs.collision[:3],
ts.obs.collision[:3, :, -1],
jnp.array(
[
[0.0, 1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0, 0.0],
[False, True, False, False, False],
[False, False, True, False, False],
[False, False, False, True, False],
]
),
)
Expand Down

0 comments on commit 02a7d43

Please sign in to comment.