Skip to content

Commit

Permalink
[brittle-star] continuous contact observations, and contact granulari…
Browse files Browse the repository at this point in the history
…ty in specification
  • Loading branch information
driesmarzougui committed Dec 3, 2024
1 parent 3c96ce9 commit ca5eb1d
Show file tree
Hide file tree
Showing 14 changed files with 166 additions and 158 deletions.
2 changes: 1 addition & 1 deletion biorobot/brittle_star/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ brittle star environment returns as observations (further discussed below).
- Central disk's velocity (w.r.t. world frame, in m/s)
- Central disk's angular velocity (w.r.t. world frame, in radians/s)
- Exteroception
- Touch (per segment, boolean)
- Contact (X continuous values per segment, in Newtons) (X is defined in the morphology specification)

In terms of actuation, the following actuators are implemented (two per segment, one for the in-plane DoF and one for
the out-of-plane DoF). The brittle star's morphology specification defines which if either position-based or
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _get_mj_models_and_datas_to_render(
mj_models, mj_datas = super()._get_mj_models_and_datas_to_render(state=state)
if self.environment_configuration.color_contacts:
self._color_segment_capsule_contacts(
mj_models=mj_models, contact_bools=state.observations["segment_contact"]
mj_models=mj_models, contacts=state.observations["segment_contact"]
)
return mj_models, mj_datas

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _get_mj_models_and_datas_to_render(
mj_models, mj_datas = super()._get_mj_models_and_datas_to_render(state=state)
if self.environment_configuration.color_contacts:
self._color_segment_capsule_contacts(
mj_models=mj_models, contact_bools=state.observations["segment_contact"]
mj_models=mj_models, contacts=state.observations["segment_contact"]
)
return mj_models, mj_datas

Expand Down
2 changes: 1 addition & 1 deletion biorobot/brittle_star/environment/light_escape/mjc_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _get_mj_models_and_datas_to_render(
self._update_mj_models_tex_data(mj_models=mj_models, state=state)
if self.environment_configuration.color_contacts:
self._color_segment_capsule_contacts(
mj_models=mj_models, contact_bools=state.observations["segment_contact"]
mj_models=mj_models, contacts=state.observations["segment_contact"]
)
return mj_models, mj_datas

Expand Down
2 changes: 1 addition & 1 deletion biorobot/brittle_star/environment/light_escape/mjx_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def _get_mj_models_and_datas_to_render(
self._update_mj_models_tex_data(mj_models=mj_models, state=state)
if self.environment_configuration.color_contacts:
self._color_segment_capsule_contacts(
mj_models=mj_models, contact_bools=state.observations["segment_contact"]
mj_models=mj_models, contacts=state.observations["segment_contact"]
)
return mj_models, mj_datas

Expand Down
14 changes: 9 additions & 5 deletions biorobot/brittle_star/environment/shared/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import chex
import jax.numpy as jnp
import mujoco
import numpy as np
from moojoco.environment.base import MuJoCoEnvironmentConfiguration

from biorobot.utils import colors
Expand Down Expand Up @@ -75,16 +76,19 @@ def _get_segment_capsule_geom_ids(self, mj_model: mujoco.MjModel) -> jnp.ndarray
return self._segment_capsule_geom_ids

def _color_segment_capsule_contacts(
self, mj_models: List[mujoco.MjModel], contact_bools: chex.Array
self, mj_models: List[mujoco.MjModel], contacts: chex.Array
) -> None:
for i, mj_model in enumerate(mj_models):
if len(contact_bools.shape) > 1:
contacts = contact_bools[i]
if len(contacts.shape) > 1:
c = contacts[i]
else:
contacts = contact_bools
c = contacts

c = c.reshape((len(self._segment_capsule_geom_ids), -1)) > 0
c = np.any(c, axis=-1).flatten()

for capsule_geom_id, contact in zip(
self._segment_capsule_geom_ids, contacts
self._segment_capsule_geom_ids, c
):
if contact:
mj_model.geom(capsule_geom_id).rgba = colors.rgba_red
Expand Down
132 changes: 41 additions & 91 deletions biorobot/brittle_star/environment/shared/observables.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from itertools import count
from typing import List, Callable, Tuple
from typing import List, Callable

import chex
import jax
import jax.numpy as jnp
import mujoco
import numpy as np
from jax.scipy.spatial.transform import Rotation
from moojoco.environment.base import BaseObservable, BaseEnvState
from moojoco.environment.mjc_env import MJCObservable, MJCEnvState
from moojoco.environment.mjx_env import MJXObservable, MJXEnvState
from moojoco.environment.base import BaseObservable
from moojoco.environment.mjc_env import MJCObservable
from moojoco.environment.mjx_env import MJXObservable
from transforms3d.euler import quat2euler


Expand All @@ -25,65 +23,8 @@ def jquat2euler(quat):
return quat2euler


def get_num_contacts_and_segment_contacts_fn(
mj_model: mujoco.MjModel, backend: str
) -> Tuple[int, Callable[[BaseEnvState], chex.Array]]:
if backend == "mjx":
segment_capsule_geom_ids = np.array(
[
geom_id
for geom_id in range(mj_model.ngeom)
if "segment" in mj_model.geom(geom_id).name
and "capsule" in mj_model.geom(geom_id).name
]
)

def get_segment_contacts(state: MJXEnvState) -> jnp.ndarray:
contact_data = state.mjx_data.contact
contacts = contact_data.dist <= 0

def solve_contact(geom_id: int) -> jnp.ndarray:
return (
jnp.sum(contacts * jnp.any(geom_id == contact_data.geom, axis=-1))
> 0
).astype(int)

return jax.vmap(solve_contact)(segment_capsule_geom_ids)

num_contacts = len(segment_capsule_geom_ids)
else:
# segment touch values
# Start by mapping geom indices of segment capsules to a contact output index
indexer = count(0)
segment_capsule_geom_id_to_contact_idx = {}
for geom_id in range(mj_model.ngeom):
geom_name = mj_model.geom(geom_id).name
if "segment" in geom_name and "capsule" in geom_name:
segment_capsule_geom_id_to_contact_idx[geom_id] = next(indexer)

def get_segment_contacts(state: MJCEnvState) -> np.ndarray:
contacts = np.zeros(len(segment_capsule_geom_id_to_contact_idx), dtype=int)
# based on https://gist.github.com/WuXinyang2012/b6649817101dfcb061eff901e9942057
for contact_id in range(state.mj_data.ncon):
contact = state.mj_data.contact[contact_id]
if contact.dist < 0:
if contact.geom1 in segment_capsule_geom_id_to_contact_idx:
contacts[
segment_capsule_geom_id_to_contact_idx[contact.geom1]
] = 1
if contact.geom2 in segment_capsule_geom_id_to_contact_idx:
contacts[
segment_capsule_geom_id_to_contact_idx[contact.geom2]
] = 1

return contacts

num_contacts = len(segment_capsule_geom_id_to_contact_idx)
return num_contacts, get_segment_contacts


def get_base_brittle_star_observables(
mj_model: mujoco.MjModel, backend: str
mj_model: mujoco.MjModel, backend: str
) -> List[BaseObservable]:
if backend == "mjx":
observable_class = MJXObservable
Expand Down Expand Up @@ -111,7 +52,7 @@ def get_base_brittle_star_observables(
retriever=lambda state: bnp.array(
[
get_data(state).sensordata[
sensor.adr[0] : sensor.adr[0] + sensor.dim[0]
sensor.adr[0]: sensor.adr[0] + sensor.dim[0]
]
for sensor in joint_pos_sensors
]
Expand All @@ -131,7 +72,7 @@ def get_base_brittle_star_observables(
retriever=lambda state: bnp.array(
[
get_data(state).sensordata[
sensor.adr[0] : sensor.adr[0] + sensor.dim[0]
sensor.adr[0]: sensor.adr[0] + sensor.dim[0]
]
for sensor in joint_vel_sensors
]
Expand All @@ -151,7 +92,7 @@ def get_base_brittle_star_observables(
retriever=lambda state: bnp.array(
[
get_data(state).sensordata[
sensor.adr[0] : sensor.adr[0] + sensor.dim[0]
sensor.adr[0]: sensor.adr[0] + sensor.dim[0]
]
for sensor in joint_actuator_frc_sensors
]
Expand All @@ -171,7 +112,7 @@ def get_base_brittle_star_observables(
retriever=lambda state: bnp.array(
[
get_data(state).sensordata[
sensor.adr[0] : sensor.adr[0] + sensor.dim[0]
sensor.adr[0]: sensor.adr[0] + sensor.dim[0]
]
for sensor in actuator_frc_sensors
]
Expand All @@ -183,32 +124,32 @@ def get_base_brittle_star_observables(
mj_model.sensor(i)
for i in range(mj_model.nsensor)
if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMEPOS
and "disk" in mj_model.sensor(i).name
and "disk" in mj_model.sensor(i).name
][0]
disk_position_observable = observable_class(
name="disk_position",
low=-bnp.inf * bnp.ones(3),
high=bnp.inf * bnp.ones(3),
retriever=lambda state: get_data(state).sensordata[
disk_framepos_sensor.adr[0] : disk_framepos_sensor.adr[0]
+ disk_framepos_sensor.dim[0]
],
disk_framepos_sensor.adr[0]: disk_framepos_sensor.adr[0]
+ disk_framepos_sensor.dim[0]
],
)
# disk rotation
disk_framequat_sensor = [
mj_model.sensor(i)
for i in range(mj_model.nsensor)
if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMEQUAT
and "disk" in mj_model.sensor(i).name
and "disk" in mj_model.sensor(i).name
][0]
disk_rotation_observable = observable_class(
name="disk_rotation",
low=-bnp.pi * bnp.ones(3),
high=bnp.pi * bnp.ones(3),
retriever=lambda state: get_quat2euler_fn(backend=backend)(
get_data(state).sensordata[
disk_framequat_sensor.adr[0] : disk_framequat_sensor.adr[0]
+ disk_framequat_sensor.dim[0]
disk_framequat_sensor.adr[0]: disk_framequat_sensor.adr[0]
+ disk_framequat_sensor.dim[0]
]
),
)
Expand All @@ -218,33 +159,33 @@ def get_base_brittle_star_observables(
mj_model.sensor(i)
for i in range(mj_model.nsensor)
if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMELINVEL
and "disk" in mj_model.sensor(i).name
and "disk" in mj_model.sensor(i).name
][0]
disk_linvel_observable = observable_class(
name="disk_linear_velocity",
low=-bnp.inf * bnp.ones(3),
high=bnp.inf * bnp.ones(3),
retriever=lambda state: get_data(state).sensordata[
disk_framelinvel_sensor.adr[0] : disk_framelinvel_sensor.adr[0]
+ disk_framelinvel_sensor.dim[0]
],
disk_framelinvel_sensor.adr[0]: disk_framelinvel_sensor.adr[0]
+ disk_framelinvel_sensor.dim[0]
],
)

# disk angvel
disk_frameangvel_sensor = [
mj_model.sensor(i)
for i in range(mj_model.nsensor)
if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMEANGVEL
and "disk" in mj_model.sensor(i).name
and "disk" in mj_model.sensor(i).name
][0]
disk_angvel_observable = observable_class(
name="disk_angular_velocity",
low=-bnp.inf * bnp.ones(3),
high=bnp.inf * bnp.ones(3),
retriever=lambda state: get_data(state).sensordata[
disk_frameangvel_sensor.adr[0] : disk_frameangvel_sensor.adr[0]
+ disk_frameangvel_sensor.dim[0]
],
disk_frameangvel_sensor.adr[0]: disk_frameangvel_sensor.adr[0]
+ disk_frameangvel_sensor.dim[0]
],
)

# tendons
Expand All @@ -260,7 +201,7 @@ def get_base_brittle_star_observables(
retriever=lambda state: bnp.array(
[
get_data(state).sensordata[
sensor.adr[0] : sensor.adr[0] + sensor.dim[0]
sensor.adr[0]: sensor.adr[0] + sensor.dim[0]
]
for sensor in tendon_pos_sensors
]
Expand All @@ -279,22 +220,31 @@ def get_base_brittle_star_observables(
retriever=lambda state: bnp.array(
[
get_data(state).sensordata[
sensor.adr[0] : sensor.adr[0] + sensor.dim[0]
sensor.adr[0]: sensor.adr[0] + sensor.dim[0]
]
for sensor in tendon_vel_sensors
]
).flatten(),
)

# contacts
num_contacts, get_segment_contacts_fn = get_num_contacts_and_segment_contacts_fn(
mj_model=mj_model, backend=backend
)
contact_sensors = [
sensor
for sensor in sensors
if sensor.type[0] == mujoco.mjtSensor.mjSENS_TOUCH
]
segment_contact_observable = observable_class(
name="segment_contact",
low=np.zeros(num_contacts),
high=np.ones(num_contacts),
retriever=get_segment_contacts_fn,
low=bnp.zeros(len(contact_sensors)),
high=bnp.inf * bnp.ones(len(contact_sensors)),
retriever=lambda state: bnp.array(
[
get_data(state).sensordata[
sensor.adr[0]: sensor.adr[0] + sensor.dim[0]
]
for sensor in contact_sensors
]
).flatten(),
)

return [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _get_mj_models_and_datas_to_render(
mj_models, mj_datas = super()._get_mj_models_and_datas_to_render(state=state)
if self.environment_configuration.color_contacts:
self._color_segment_capsule_contacts(
mj_models=mj_models, contact_bools=state.observations["segment_contact"]
mj_models=mj_models, contacts=state.observations["segment_contact"]
)
return mj_models, mj_datas

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _get_mj_models_and_datas_to_render(
mj_models, mj_datas = super()._get_mj_models_and_datas_to_render(state=state)
if self.environment_configuration.color_contacts:
self._color_segment_capsule_contacts(
mj_models=mj_models, contact_bools=state.observations["segment_contact"]
mj_models=mj_models, contacts=state.observations["segment_contact"]
)
return mj_models, mj_datas

Expand Down
1 change: 1 addition & 0 deletions biorobot/brittle_star/mjcf/morphology/morphology.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,6 @@ def _configure_camera(self) -> None:
use_torque_control=True,
use_tendons=True,
radius_to_strength_factor=200,
num_contact_sensors_per_segment=8
)
MJCFBrittleStarMorphology(spec).export_to_xml_with_assets("./mjcf")
Loading

0 comments on commit ca5eb1d

Please sign in to comment.