Skip to content

Commit

Permalink
Merge pull request #40 from Co-Evolve/brittle-star-continuous-contact
Browse files Browse the repository at this point in the history
[brittle-star] continuous contact and specifiable contact sensing granularity
  • Loading branch information
driesmarzougui authored Dec 3, 2024
2 parents 3c96ce9 + ccd1ff3 commit 6f73d7b
Show file tree
Hide file tree
Showing 14 changed files with 105 additions and 84 deletions.
2 changes: 1 addition & 1 deletion biorobot/brittle_star/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ brittle star environment returns as observations (further discussed below).
- Central disk's velocity (w.r.t. world frame, in m/s)
- Central disk's angular velocity (w.r.t. world frame, in radians/s)
- Exteroception
- Touch (per segment, boolean)
- Contact (X continuous values per segment, in Newtons) (X is defined in the morphology specification)

In terms of actuation, the following actuators are implemented (two per segment, one for the in-plane DoF and one for
the out-of-plane DoF). The brittle star's morphology specification defines which if either position-based or
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _get_mj_models_and_datas_to_render(
mj_models, mj_datas = super()._get_mj_models_and_datas_to_render(state=state)
if self.environment_configuration.color_contacts:
self._color_segment_capsule_contacts(
mj_models=mj_models, contact_bools=state.observations["segment_contact"]
mj_models=mj_models, contacts=state.observations["segment_contact"]
)
return mj_models, mj_datas

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _get_mj_models_and_datas_to_render(
mj_models, mj_datas = super()._get_mj_models_and_datas_to_render(state=state)
if self.environment_configuration.color_contacts:
self._color_segment_capsule_contacts(
mj_models=mj_models, contact_bools=state.observations["segment_contact"]
mj_models=mj_models, contacts=state.observations["segment_contact"]
)
return mj_models, mj_datas

Expand Down
2 changes: 1 addition & 1 deletion biorobot/brittle_star/environment/light_escape/mjc_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _get_mj_models_and_datas_to_render(
self._update_mj_models_tex_data(mj_models=mj_models, state=state)
if self.environment_configuration.color_contacts:
self._color_segment_capsule_contacts(
mj_models=mj_models, contact_bools=state.observations["segment_contact"]
mj_models=mj_models, contacts=state.observations["segment_contact"]
)
return mj_models, mj_datas

Expand Down
2 changes: 1 addition & 1 deletion biorobot/brittle_star/environment/light_escape/mjx_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def _get_mj_models_and_datas_to_render(
self._update_mj_models_tex_data(mj_models=mj_models, state=state)
if self.environment_configuration.color_contacts:
self._color_segment_capsule_contacts(
mj_models=mj_models, contact_bools=state.observations["segment_contact"]
mj_models=mj_models, contacts=state.observations["segment_contact"]
)
return mj_models, mj_datas

Expand Down
16 changes: 9 additions & 7 deletions biorobot/brittle_star/environment/shared/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import chex
import jax.numpy as jnp
import mujoco
import numpy as np
from moojoco.environment.base import MuJoCoEnvironmentConfiguration

from biorobot.utils import colors
Expand Down Expand Up @@ -75,17 +76,18 @@ def _get_segment_capsule_geom_ids(self, mj_model: mujoco.MjModel) -> jnp.ndarray
return self._segment_capsule_geom_ids

def _color_segment_capsule_contacts(
self, mj_models: List[mujoco.MjModel], contact_bools: chex.Array
self, mj_models: List[mujoco.MjModel], contacts: chex.Array
) -> None:
for i, mj_model in enumerate(mj_models):
if len(contact_bools.shape) > 1:
contacts = contact_bools[i]
if len(contacts.shape) > 1:
c = contacts[i]
else:
contacts = contact_bools
c = contacts

for capsule_geom_id, contact in zip(
self._segment_capsule_geom_ids, contacts
):
c = c.reshape((len(self._segment_capsule_geom_ids), -1)) > 0
c = np.any(c, axis=-1).flatten()

for capsule_geom_id, contact in zip(self._segment_capsule_geom_ids, c):
if contact:
mj_model.geom(capsule_geom_id).rgba = colors.rgba_red
else:
Expand Down
86 changes: 17 additions & 69 deletions biorobot/brittle_star/environment/shared/observables.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from itertools import count
from typing import List, Callable, Tuple
from typing import List, Callable

import chex
import jax
import jax.numpy as jnp
import mujoco
import numpy as np
from jax.scipy.spatial.transform import Rotation
from moojoco.environment.base import BaseObservable, BaseEnvState
from moojoco.environment.mjc_env import MJCObservable, MJCEnvState
from moojoco.environment.mjx_env import MJXObservable, MJXEnvState
from moojoco.environment.base import BaseObservable
from moojoco.environment.mjc_env import MJCObservable
from moojoco.environment.mjx_env import MJXObservable
from transforms3d.euler import quat2euler


Expand All @@ -25,63 +23,6 @@ def jquat2euler(quat):
return quat2euler


def get_num_contacts_and_segment_contacts_fn(
mj_model: mujoco.MjModel, backend: str
) -> Tuple[int, Callable[[BaseEnvState], chex.Array]]:
if backend == "mjx":
segment_capsule_geom_ids = np.array(
[
geom_id
for geom_id in range(mj_model.ngeom)
if "segment" in mj_model.geom(geom_id).name
and "capsule" in mj_model.geom(geom_id).name
]
)

def get_segment_contacts(state: MJXEnvState) -> jnp.ndarray:
contact_data = state.mjx_data.contact
contacts = contact_data.dist <= 0

def solve_contact(geom_id: int) -> jnp.ndarray:
return (
jnp.sum(contacts * jnp.any(geom_id == contact_data.geom, axis=-1))
> 0
).astype(int)

return jax.vmap(solve_contact)(segment_capsule_geom_ids)

num_contacts = len(segment_capsule_geom_ids)
else:
# segment touch values
# Start by mapping geom indices of segment capsules to a contact output index
indexer = count(0)
segment_capsule_geom_id_to_contact_idx = {}
for geom_id in range(mj_model.ngeom):
geom_name = mj_model.geom(geom_id).name
if "segment" in geom_name and "capsule" in geom_name:
segment_capsule_geom_id_to_contact_idx[geom_id] = next(indexer)

def get_segment_contacts(state: MJCEnvState) -> np.ndarray:
contacts = np.zeros(len(segment_capsule_geom_id_to_contact_idx), dtype=int)
# based on https://gist.github.com/WuXinyang2012/b6649817101dfcb061eff901e9942057
for contact_id in range(state.mj_data.ncon):
contact = state.mj_data.contact[contact_id]
if contact.dist < 0:
if contact.geom1 in segment_capsule_geom_id_to_contact_idx:
contacts[
segment_capsule_geom_id_to_contact_idx[contact.geom1]
] = 1
if contact.geom2 in segment_capsule_geom_id_to_contact_idx:
contacts[
segment_capsule_geom_id_to_contact_idx[contact.geom2]
] = 1

return contacts

num_contacts = len(segment_capsule_geom_id_to_contact_idx)
return num_contacts, get_segment_contacts


def get_base_brittle_star_observables(
mj_model: mujoco.MjModel, backend: str
) -> List[BaseObservable]:
Expand Down Expand Up @@ -287,14 +228,21 @@ def get_base_brittle_star_observables(
)

# contacts
num_contacts, get_segment_contacts_fn = get_num_contacts_and_segment_contacts_fn(
mj_model=mj_model, backend=backend
)
contact_sensors = [
sensor for sensor in sensors if sensor.type[0] == mujoco.mjtSensor.mjSENS_TOUCH
]
segment_contact_observable = observable_class(
name="segment_contact",
low=np.zeros(num_contacts),
high=np.ones(num_contacts),
retriever=get_segment_contacts_fn,
low=bnp.zeros(len(contact_sensors)),
high=bnp.inf * bnp.ones(len(contact_sensors)),
retriever=lambda state: bnp.array(
[
get_data(state).sensordata[
sensor.adr[0] : sensor.adr[0] + sensor.dim[0]
]
for sensor in contact_sensors
]
).flatten(),
)

return [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _get_mj_models_and_datas_to_render(
mj_models, mj_datas = super()._get_mj_models_and_datas_to_render(state=state)
if self.environment_configuration.color_contacts:
self._color_segment_capsule_contacts(
mj_models=mj_models, contact_bools=state.observations["segment_contact"]
mj_models=mj_models, contacts=state.observations["segment_contact"]
)
return mj_models, mj_datas

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _get_mj_models_and_datas_to_render(
mj_models, mj_datas = super()._get_mj_models_and_datas_to_render(state=state)
if self.environment_configuration.color_contacts:
self._color_segment_capsule_contacts(
mj_models=mj_models, contact_bools=state.observations["segment_contact"]
mj_models=mj_models, contacts=state.observations["segment_contact"]
)
return mj_models, mj_datas

Expand Down
1 change: 1 addition & 0 deletions biorobot/brittle_star/mjcf/morphology/morphology.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,6 @@ def _configure_camera(self) -> None:
use_torque_control=True,
use_tendons=True,
radius_to_strength_factor=200,
num_contact_sensors_per_segment=8,
)
MJCFBrittleStarMorphology(spec).export_to_xml_with_assets("./mjcf")
54 changes: 53 additions & 1 deletion biorobot/brittle_star/mjcf/morphology/parts/arm_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
BrittleStarMorphologySpecification,
)
from biorobot.utils import colors
from biorobot.utils.colors import rgba_red, rgba_tendon_contracted, rgba_tendon_relaxed
from biorobot.utils.colors import rgba_red, rgba_tendon_relaxed


class MJCFBrittleStarArmSegment(MJCFMorphologyPart):
Expand Down Expand Up @@ -308,8 +308,60 @@ def _configure_tendon_sensors(self) -> None:
"tendonvel", name=f"{tendon.name}_tendonvel_sensor", tendon=tendon
)

def _configure_contact_sensors(self) -> None:
num_contact_sensors = (
self.morphology_specification.sensor_specification.num_contact_sensors_per_segment.value
)

contact_sites = []
if num_contact_sensors == 1:
contact_sites.append(
self.mjcf_body.add(
"site",
type="capsule",
pos=self._capsule.pos,
size=self._capsule.size * np.array([1.01, 1]),
rgba=rgba_red * np.array([1, 1, 1, 0.5]),
euler=self._capsule.euler,
name=f"{self.base_name}_contact_site",
group=3,
),
)
else:
angles = np.linspace(-np.pi / 2, 1.5 * np.pi, num_contact_sensors + 1)[
:num_contact_sensors
]
radius = self._segment_specification.radius.value

for i, angle in enumerate(angles):
pos = self.center_of_capsule + 0.95 * radius * np.array(
[0, np.cos(angle), np.sin(angle)]
)
contact_sites.append(
self.mjcf_body.add(
"site",
pos=pos,
euler=[angle, 0, 0],
type="box",
size=[
self._segment_specification.length.value / 2 + radius / 2,
0.05 * radius,
0.4 * radius,
],
rgba=rgba_red * np.array([1, 1, 1, 0.5]),
group=3,
name=f"{self.base_name}_contact_site_{i}",
)
)

for site in contact_sites:
self.mjcf_model.sensor.add(
"touch", name=f"{site.name}_contact_sensor", site=site
)

def _configure_sensors(self) -> None:
self._configure_position_sensor()
self._configure_joints_sensors()
self._configure_actuator_sensors()
self._configure_tendon_sensors()
self._configure_contact_sensors()
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
BrittleStarDiskSpecification,
BrittleStarJointSpecification,
BrittleStarMorphologySpecification,
BrittleStarSensorSpecification,
)

START_SEGMENT_RADIUS = 0.025
Expand Down Expand Up @@ -78,6 +79,7 @@ def default_brittle_star_morphology_specification(
use_p_control: bool = False,
use_torque_control: bool = False,
radius_to_strength_factor: float = 200,
num_contact_sensors_per_segment: int = 1,
) -> BrittleStarMorphologySpecification:
disk_specification = BrittleStarDiskSpecification(
diameter=DISK_DIAMETER, height=DISK_HEIGHT
Expand All @@ -102,10 +104,15 @@ def default_brittle_star_morphology_specification(
use_torque_control=use_torque_control,
radius_to_strength_factor=radius_to_strength_factor,
)
sensor_specification = BrittleStarSensorSpecification(
num_contact_sensors_per_segment=num_contact_sensors_per_segment
)

specification = BrittleStarMorphologySpecification(
disk_specification=disk_specification,
arm_specifications=arm_specifications,
actuation_specification=actuation_specification,
sensor_specification=sensor_specification,
)

return specification
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,27 @@ def __init__(
self.radius_to_strength_factor = FixedParameter(radius_to_strength_factor)


class BrittleStarSensorSpecification(Specification):
def __init__(self, num_contact_sensors_per_segment: int) -> int:
super().__init__()
self.num_contact_sensors_per_segment = FixedParameter(
num_contact_sensors_per_segment
)


class BrittleStarMorphologySpecification(MorphologySpecification):
def __init__(
self,
disk_specification: BrittleStarDiskSpecification,
arm_specifications: List[BrittleStarArmSpecification],
actuation_specification: BrittleStarActuationSpecification,
sensor_specification: BrittleStarSensorSpecification,
) -> None:
super(BrittleStarMorphologySpecification, self).__init__()
self.disk_specification = disk_specification
self.arm_specifications = arm_specifications
self.actuation_specification = actuation_specification
self.sensor_specification = sensor_specification

@property
def number_of_arms(self) -> int:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def create_env(
use_p_control=False,
use_torque_control=True,
use_tendons=False,
num_contact_sensors_per_segment=1,
)
morphology = MJCFBrittleStarMorphology(morphology_spec)
arena_config = AquariumArenaConfiguration(attach_target=True)
Expand Down

0 comments on commit 6f73d7b

Please sign in to comment.