From ca5eb1dd315d58ce51820017b4209d25d31a48bf Mon Sep 17 00:00:00 2001 From: dmarzoug Date: Tue, 3 Dec 2024 15:55:01 +0100 Subject: [PATCH] [brittle-star] continuous contact observations, and contact granularity in specification --- biorobot/brittle_star/README.md | 2 +- .../directed_locomotion/mjc_env.py | 2 +- .../directed_locomotion/mjx_env.py | 2 +- .../environment/light_escape/mjc_env.py | 2 +- .../environment/light_escape/mjx_env.py | 2 +- .../brittle_star/environment/shared/base.py | 14 +- .../environment/shared/observables.py | 132 ++++++------------ .../undirected_locomotion/mjc_env.py | 2 +- .../undirected_locomotion/mjx_env.py | 2 +- .../mjcf/morphology/morphology.py | 1 + .../mjcf/morphology/parts/arm_segment.py | 82 ++++++++--- .../mjcf/morphology/specification/default.py | 22 +-- .../morphology/specification/specification.py | 58 ++++---- .../directed_locomotion_single.py | 1 + 14 files changed, 166 insertions(+), 158 deletions(-) diff --git a/biorobot/brittle_star/README.md b/biorobot/brittle_star/README.md index 10eaffa..b649c40 100644 --- a/biorobot/brittle_star/README.md +++ b/biorobot/brittle_star/README.md @@ -186,7 +186,7 @@ brittle star environment returns as observations (further discussed below). - Central disk's velocity (w.r.t. world frame, in m/s) - Central disk's angular velocity (w.r.t. world frame, in radians/s) - Exteroception - - Touch (per segment, boolean) + - Contact (X continuous values per segment, in Newtons) (X is defined in the morphology specification) In terms of actuation, the following actuators are implemented (two per segment, one for the in-plane DoF and one for the out-of-plane DoF). The brittle star's morphology specification defines which if either position-based or diff --git a/biorobot/brittle_star/environment/directed_locomotion/mjc_env.py b/biorobot/brittle_star/environment/directed_locomotion/mjc_env.py index c9bcd1f..a32ffdc 100644 --- a/biorobot/brittle_star/environment/directed_locomotion/mjc_env.py +++ b/biorobot/brittle_star/environment/directed_locomotion/mjc_env.py @@ -119,7 +119,7 @@ def _get_mj_models_and_datas_to_render( mj_models, mj_datas = super()._get_mj_models_and_datas_to_render(state=state) if self.environment_configuration.color_contacts: self._color_segment_capsule_contacts( - mj_models=mj_models, contact_bools=state.observations["segment_contact"] + mj_models=mj_models, contacts=state.observations["segment_contact"] ) return mj_models, mj_datas diff --git a/biorobot/brittle_star/environment/directed_locomotion/mjx_env.py b/biorobot/brittle_star/environment/directed_locomotion/mjx_env.py index 5b8747d..ffceb48 100644 --- a/biorobot/brittle_star/environment/directed_locomotion/mjx_env.py +++ b/biorobot/brittle_star/environment/directed_locomotion/mjx_env.py @@ -121,7 +121,7 @@ def _get_mj_models_and_datas_to_render( mj_models, mj_datas = super()._get_mj_models_and_datas_to_render(state=state) if self.environment_configuration.color_contacts: self._color_segment_capsule_contacts( - mj_models=mj_models, contact_bools=state.observations["segment_contact"] + mj_models=mj_models, contacts=state.observations["segment_contact"] ) return mj_models, mj_datas diff --git a/biorobot/brittle_star/environment/light_escape/mjc_env.py b/biorobot/brittle_star/environment/light_escape/mjc_env.py index c40dbdd..c465f82 100644 --- a/biorobot/brittle_star/environment/light_escape/mjc_env.py +++ b/biorobot/brittle_star/environment/light_escape/mjc_env.py @@ -165,7 +165,7 @@ def _get_mj_models_and_datas_to_render( self._update_mj_models_tex_data(mj_models=mj_models, state=state) if self.environment_configuration.color_contacts: self._color_segment_capsule_contacts( - mj_models=mj_models, contact_bools=state.observations["segment_contact"] + mj_models=mj_models, contacts=state.observations["segment_contact"] ) return mj_models, mj_datas diff --git a/biorobot/brittle_star/environment/light_escape/mjx_env.py b/biorobot/brittle_star/environment/light_escape/mjx_env.py index 489f933..e00f834 100644 --- a/biorobot/brittle_star/environment/light_escape/mjx_env.py +++ b/biorobot/brittle_star/environment/light_escape/mjx_env.py @@ -159,7 +159,7 @@ def _get_mj_models_and_datas_to_render( self._update_mj_models_tex_data(mj_models=mj_models, state=state) if self.environment_configuration.color_contacts: self._color_segment_capsule_contacts( - mj_models=mj_models, contact_bools=state.observations["segment_contact"] + mj_models=mj_models, contacts=state.observations["segment_contact"] ) return mj_models, mj_datas diff --git a/biorobot/brittle_star/environment/shared/base.py b/biorobot/brittle_star/environment/shared/base.py index 8578995..44148d4 100644 --- a/biorobot/brittle_star/environment/shared/base.py +++ b/biorobot/brittle_star/environment/shared/base.py @@ -3,6 +3,7 @@ import chex import jax.numpy as jnp import mujoco +import numpy as np from moojoco.environment.base import MuJoCoEnvironmentConfiguration from biorobot.utils import colors @@ -75,16 +76,19 @@ def _get_segment_capsule_geom_ids(self, mj_model: mujoco.MjModel) -> jnp.ndarray return self._segment_capsule_geom_ids def _color_segment_capsule_contacts( - self, mj_models: List[mujoco.MjModel], contact_bools: chex.Array + self, mj_models: List[mujoco.MjModel], contacts: chex.Array ) -> None: for i, mj_model in enumerate(mj_models): - if len(contact_bools.shape) > 1: - contacts = contact_bools[i] + if len(contacts.shape) > 1: + c = contacts[i] else: - contacts = contact_bools + c = contacts + + c = c.reshape((len(self._segment_capsule_geom_ids), -1)) > 0 + c = np.any(c, axis=-1).flatten() for capsule_geom_id, contact in zip( - self._segment_capsule_geom_ids, contacts + self._segment_capsule_geom_ids, c ): if contact: mj_model.geom(capsule_geom_id).rgba = colors.rgba_red diff --git a/biorobot/brittle_star/environment/shared/observables.py b/biorobot/brittle_star/environment/shared/observables.py index 94bcf05..ab5ed26 100644 --- a/biorobot/brittle_star/environment/shared/observables.py +++ b/biorobot/brittle_star/environment/shared/observables.py @@ -1,15 +1,13 @@ -from itertools import count -from typing import List, Callable, Tuple +from typing import List, Callable import chex -import jax import jax.numpy as jnp import mujoco import numpy as np from jax.scipy.spatial.transform import Rotation -from moojoco.environment.base import BaseObservable, BaseEnvState -from moojoco.environment.mjc_env import MJCObservable, MJCEnvState -from moojoco.environment.mjx_env import MJXObservable, MJXEnvState +from moojoco.environment.base import BaseObservable +from moojoco.environment.mjc_env import MJCObservable +from moojoco.environment.mjx_env import MJXObservable from transforms3d.euler import quat2euler @@ -25,65 +23,8 @@ def jquat2euler(quat): return quat2euler -def get_num_contacts_and_segment_contacts_fn( - mj_model: mujoco.MjModel, backend: str -) -> Tuple[int, Callable[[BaseEnvState], chex.Array]]: - if backend == "mjx": - segment_capsule_geom_ids = np.array( - [ - geom_id - for geom_id in range(mj_model.ngeom) - if "segment" in mj_model.geom(geom_id).name - and "capsule" in mj_model.geom(geom_id).name - ] - ) - - def get_segment_contacts(state: MJXEnvState) -> jnp.ndarray: - contact_data = state.mjx_data.contact - contacts = contact_data.dist <= 0 - - def solve_contact(geom_id: int) -> jnp.ndarray: - return ( - jnp.sum(contacts * jnp.any(geom_id == contact_data.geom, axis=-1)) - > 0 - ).astype(int) - - return jax.vmap(solve_contact)(segment_capsule_geom_ids) - - num_contacts = len(segment_capsule_geom_ids) - else: - # segment touch values - # Start by mapping geom indices of segment capsules to a contact output index - indexer = count(0) - segment_capsule_geom_id_to_contact_idx = {} - for geom_id in range(mj_model.ngeom): - geom_name = mj_model.geom(geom_id).name - if "segment" in geom_name and "capsule" in geom_name: - segment_capsule_geom_id_to_contact_idx[geom_id] = next(indexer) - - def get_segment_contacts(state: MJCEnvState) -> np.ndarray: - contacts = np.zeros(len(segment_capsule_geom_id_to_contact_idx), dtype=int) - # based on https://gist.github.com/WuXinyang2012/b6649817101dfcb061eff901e9942057 - for contact_id in range(state.mj_data.ncon): - contact = state.mj_data.contact[contact_id] - if contact.dist < 0: - if contact.geom1 in segment_capsule_geom_id_to_contact_idx: - contacts[ - segment_capsule_geom_id_to_contact_idx[contact.geom1] - ] = 1 - if contact.geom2 in segment_capsule_geom_id_to_contact_idx: - contacts[ - segment_capsule_geom_id_to_contact_idx[contact.geom2] - ] = 1 - - return contacts - - num_contacts = len(segment_capsule_geom_id_to_contact_idx) - return num_contacts, get_segment_contacts - - def get_base_brittle_star_observables( - mj_model: mujoco.MjModel, backend: str + mj_model: mujoco.MjModel, backend: str ) -> List[BaseObservable]: if backend == "mjx": observable_class = MJXObservable @@ -111,7 +52,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0] : sensor.adr[0] + sensor.dim[0] + sensor.adr[0]: sensor.adr[0] + sensor.dim[0] ] for sensor in joint_pos_sensors ] @@ -131,7 +72,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0] : sensor.adr[0] + sensor.dim[0] + sensor.adr[0]: sensor.adr[0] + sensor.dim[0] ] for sensor in joint_vel_sensors ] @@ -151,7 +92,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0] : sensor.adr[0] + sensor.dim[0] + sensor.adr[0]: sensor.adr[0] + sensor.dim[0] ] for sensor in joint_actuator_frc_sensors ] @@ -171,7 +112,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0] : sensor.adr[0] + sensor.dim[0] + sensor.adr[0]: sensor.adr[0] + sensor.dim[0] ] for sensor in actuator_frc_sensors ] @@ -183,23 +124,23 @@ def get_base_brittle_star_observables( mj_model.sensor(i) for i in range(mj_model.nsensor) if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMEPOS - and "disk" in mj_model.sensor(i).name + and "disk" in mj_model.sensor(i).name ][0] disk_position_observable = observable_class( name="disk_position", low=-bnp.inf * bnp.ones(3), high=bnp.inf * bnp.ones(3), retriever=lambda state: get_data(state).sensordata[ - disk_framepos_sensor.adr[0] : disk_framepos_sensor.adr[0] - + disk_framepos_sensor.dim[0] - ], + disk_framepos_sensor.adr[0]: disk_framepos_sensor.adr[0] + + disk_framepos_sensor.dim[0] + ], ) # disk rotation disk_framequat_sensor = [ mj_model.sensor(i) for i in range(mj_model.nsensor) if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMEQUAT - and "disk" in mj_model.sensor(i).name + and "disk" in mj_model.sensor(i).name ][0] disk_rotation_observable = observable_class( name="disk_rotation", @@ -207,8 +148,8 @@ def get_base_brittle_star_observables( high=bnp.pi * bnp.ones(3), retriever=lambda state: get_quat2euler_fn(backend=backend)( get_data(state).sensordata[ - disk_framequat_sensor.adr[0] : disk_framequat_sensor.adr[0] - + disk_framequat_sensor.dim[0] + disk_framequat_sensor.adr[0]: disk_framequat_sensor.adr[0] + + disk_framequat_sensor.dim[0] ] ), ) @@ -218,16 +159,16 @@ def get_base_brittle_star_observables( mj_model.sensor(i) for i in range(mj_model.nsensor) if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMELINVEL - and "disk" in mj_model.sensor(i).name + and "disk" in mj_model.sensor(i).name ][0] disk_linvel_observable = observable_class( name="disk_linear_velocity", low=-bnp.inf * bnp.ones(3), high=bnp.inf * bnp.ones(3), retriever=lambda state: get_data(state).sensordata[ - disk_framelinvel_sensor.adr[0] : disk_framelinvel_sensor.adr[0] - + disk_framelinvel_sensor.dim[0] - ], + disk_framelinvel_sensor.adr[0]: disk_framelinvel_sensor.adr[0] + + disk_framelinvel_sensor.dim[0] + ], ) # disk angvel @@ -235,16 +176,16 @@ def get_base_brittle_star_observables( mj_model.sensor(i) for i in range(mj_model.nsensor) if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMEANGVEL - and "disk" in mj_model.sensor(i).name + and "disk" in mj_model.sensor(i).name ][0] disk_angvel_observable = observable_class( name="disk_angular_velocity", low=-bnp.inf * bnp.ones(3), high=bnp.inf * bnp.ones(3), retriever=lambda state: get_data(state).sensordata[ - disk_frameangvel_sensor.adr[0] : disk_frameangvel_sensor.adr[0] - + disk_frameangvel_sensor.dim[0] - ], + disk_frameangvel_sensor.adr[0]: disk_frameangvel_sensor.adr[0] + + disk_frameangvel_sensor.dim[0] + ], ) # tendons @@ -260,7 +201,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0] : sensor.adr[0] + sensor.dim[0] + sensor.adr[0]: sensor.adr[0] + sensor.dim[0] ] for sensor in tendon_pos_sensors ] @@ -279,7 +220,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0] : sensor.adr[0] + sensor.dim[0] + sensor.adr[0]: sensor.adr[0] + sensor.dim[0] ] for sensor in tendon_vel_sensors ] @@ -287,14 +228,23 @@ def get_base_brittle_star_observables( ) # contacts - num_contacts, get_segment_contacts_fn = get_num_contacts_and_segment_contacts_fn( - mj_model=mj_model, backend=backend - ) + contact_sensors = [ + sensor + for sensor in sensors + if sensor.type[0] == mujoco.mjtSensor.mjSENS_TOUCH + ] segment_contact_observable = observable_class( name="segment_contact", - low=np.zeros(num_contacts), - high=np.ones(num_contacts), - retriever=get_segment_contacts_fn, + low=bnp.zeros(len(contact_sensors)), + high=bnp.inf * bnp.ones(len(contact_sensors)), + retriever=lambda state: bnp.array( + [ + get_data(state).sensordata[ + sensor.adr[0]: sensor.adr[0] + sensor.dim[0] + ] + for sensor in contact_sensors + ] + ).flatten(), ) return [ diff --git a/biorobot/brittle_star/environment/undirected_locomotion/mjc_env.py b/biorobot/brittle_star/environment/undirected_locomotion/mjc_env.py index cf4dd87..74aba57 100644 --- a/biorobot/brittle_star/environment/undirected_locomotion/mjc_env.py +++ b/biorobot/brittle_star/environment/undirected_locomotion/mjc_env.py @@ -76,7 +76,7 @@ def _get_mj_models_and_datas_to_render( mj_models, mj_datas = super()._get_mj_models_and_datas_to_render(state=state) if self.environment_configuration.color_contacts: self._color_segment_capsule_contacts( - mj_models=mj_models, contact_bools=state.observations["segment_contact"] + mj_models=mj_models, contacts=state.observations["segment_contact"] ) return mj_models, mj_datas diff --git a/biorobot/brittle_star/environment/undirected_locomotion/mjx_env.py b/biorobot/brittle_star/environment/undirected_locomotion/mjx_env.py index 8cecccc..4dac0fe 100644 --- a/biorobot/brittle_star/environment/undirected_locomotion/mjx_env.py +++ b/biorobot/brittle_star/environment/undirected_locomotion/mjx_env.py @@ -80,7 +80,7 @@ def _get_mj_models_and_datas_to_render( mj_models, mj_datas = super()._get_mj_models_and_datas_to_render(state=state) if self.environment_configuration.color_contacts: self._color_segment_capsule_contacts( - mj_models=mj_models, contact_bools=state.observations["segment_contact"] + mj_models=mj_models, contacts=state.observations["segment_contact"] ) return mj_models, mj_datas diff --git a/biorobot/brittle_star/mjcf/morphology/morphology.py b/biorobot/brittle_star/mjcf/morphology/morphology.py index 3566fac..e5570f3 100644 --- a/biorobot/brittle_star/mjcf/morphology/morphology.py +++ b/biorobot/brittle_star/mjcf/morphology/morphology.py @@ -79,5 +79,6 @@ def _configure_camera(self) -> None: use_torque_control=True, use_tendons=True, radius_to_strength_factor=200, + num_contact_sensors_per_segment=8 ) MJCFBrittleStarMorphology(spec).export_to_xml_with_assets("./mjcf") diff --git a/biorobot/brittle_star/mjcf/morphology/parts/arm_segment.py b/biorobot/brittle_star/mjcf/morphology/parts/arm_segment.py index 751f0f5..2804c7e 100644 --- a/biorobot/brittle_star/mjcf/morphology/parts/arm_segment.py +++ b/biorobot/brittle_star/mjcf/morphology/parts/arm_segment.py @@ -12,18 +12,18 @@ BrittleStarMorphologySpecification, ) from biorobot.utils import colors -from biorobot.utils.colors import rgba_red, rgba_tendon_contracted, rgba_tendon_relaxed +from biorobot.utils.colors import rgba_red, rgba_tendon_relaxed class MJCFBrittleStarArmSegment(MJCFMorphologyPart): def __init__( - self, - parent: Union[MJCFMorphology, MJCFMorphologyPart], - name: str, - pos: np.array, - euler: np.array, - *args, - **kwargs, + self, + parent: Union[MJCFMorphology, MJCFMorphologyPart], + name: str, + pos: np.array, + euler: np.array, + *args, + **kwargs, ) -> None: super().__init__(parent, name, pos, euler, *args, **kwargs) @@ -88,10 +88,10 @@ def center_of_capsule(self) -> np.ndarray: return np.array([x_offset, 0, 0]) def _configure_joint( - self, - name: str, - axis: np.ndarray, - joint_specification: BrittleStarJointSpecification, + self, + name: str, + axis: np.ndarray, + joint_specification: BrittleStarJointSpecification, ) -> _ElementImpl: joint = self.mjcf_body.add( "joint", @@ -127,9 +127,9 @@ def _configure_tendon_attachment_points(self) -> None: for i, angle in enumerate(angles): # proximal pos = ( - 0.8 - * self._segment_specification.radius.value - * np.array([0, np.cos(angle), np.sin(angle)]) + 0.8 + * self._segment_specification.radius.value + * np.array([0, np.cos(angle), np.sin(angle)]) ) pos[0] = self._segment_specification.radius.value self._proximal_taps.append( @@ -145,8 +145,8 @@ def _configure_tendon_attachment_points(self) -> None: # distal pos[0] = ( - self._segment_specification.radius.value - + self._segment_specification.length.value + self._segment_specification.radius.value + + self._segment_specification.length.value ) self.distal_taps.append( self.mjcf_body.add( @@ -168,7 +168,7 @@ def _build_tendons(self) -> None: self._tendons = [] for tendon_index, (parent_tap, segment_tap) in enumerate( - zip(distal_taps, self._proximal_taps) + zip(distal_taps, self._proximal_taps) ): tendon = self.mjcf_model.tendon.add( "spatial", @@ -195,8 +195,8 @@ def _is_last_segment(self) -> bool: @property def _actuator_strength(self) -> float: strength = ( - self._segment_specification.radius.value - * self.morphology_specification.actuation_specification.radius_to_strength_factor.value + self._segment_specification.radius.value + * self.morphology_specification.actuation_specification.radius_to_strength_factor.value ) return strength @@ -228,7 +228,7 @@ def _configure_p_control_actuators(self) -> None: ] def _configure_torque_control_actuator( - self, transmission: _ElementImpl + self, transmission: _ElementImpl ) -> _ElementImpl: actuator_attributes = { "name": f"{transmission.name}_torque_control", @@ -259,7 +259,7 @@ def _configure_torque_control_actuator( def _configure_torque_control_actuators(self) -> None: if ( - self.morphology_specification.actuation_specification.use_torque_control.value + self.morphology_specification.actuation_specification.use_torque_control.value ): self._actuators = [ self._configure_torque_control_actuator(transmission) @@ -308,8 +308,46 @@ def _configure_tendon_sensors(self) -> None: "tendonvel", name=f"{tendon.name}_tendonvel_sensor", tendon=tendon ) + def _configure_contact_sensors(self) -> None: + num_contact_sensors = self.morphology_specification.sensor_specification.num_contact_sensors_per_segment.value + + contact_sites = [] + if num_contact_sensors == 1: + contact_sites.append( + self.mjcf_body.add("site", + type="capsule", + pos=self._capsule.pos, + size=self._capsule.size * np.array([1.01, 1]), + rgba=rgba_red * np.array([1, 1, 1, 0.5]), + euler=self._capsule.euler, + name=f"{self.base_name}_contact_site", + group=3), + ) + else: + angles = np.linspace(-np.pi / 2, 1.5 * np.pi, num_contact_sensors + 1)[:num_contact_sensors] + radius = self._segment_specification.radius.value + + for i, angle in enumerate(angles): + pos = self.center_of_capsule + 0.95 * radius * np.array([0, np.cos(angle), np.sin(angle)]) + contact_sites.append(self.mjcf_body.add( + "site", + pos=pos, + euler=[angle, 0, 0], + type="box", + size=[self._segment_specification.length.value / 2 + radius / 2, 0.05 * radius, 0.4 * radius], + rgba=rgba_red * np.array([1, 1, 1, 0.5]), + group=3, + name=f"{self.base_name}_contact_site_{i}" + )) + + for site in contact_sites: + self.mjcf_model.sensor.add( + "touch", name=f"{site.name}_contact_sensor", site=site + ) + def _configure_sensors(self) -> None: self._configure_position_sensor() self._configure_joints_sensors() self._configure_actuator_sensors() self._configure_tendon_sensors() + self._configure_contact_sensors() diff --git a/biorobot/brittle_star/mjcf/morphology/specification/default.py b/biorobot/brittle_star/mjcf/morphology/specification/default.py index fa5b158..63980e2 100644 --- a/biorobot/brittle_star/mjcf/morphology/specification/default.py +++ b/biorobot/brittle_star/mjcf/morphology/specification/default.py @@ -8,7 +8,7 @@ BrittleStarArmSpecification, BrittleStarDiskSpecification, BrittleStarJointSpecification, - BrittleStarMorphologySpecification, + BrittleStarMorphologySpecification, BrittleStarSensorSpecification, ) START_SEGMENT_RADIUS = 0.025 @@ -32,7 +32,7 @@ def default_joint_specification(range: float) -> BrittleStarJointSpecification: def default_arm_segment_specification( - alpha: float, + alpha: float, ) -> BrittleStarArmSegmentSpecification: in_plane_joint_specification = default_joint_specification( range=30 / 180 * np.pi @@ -72,12 +72,13 @@ def default_arm_specification(num_segments_per_arm: int) -> BrittleStarArmSpecif def default_brittle_star_morphology_specification( - num_arms: int = 5, - num_segments_per_arm: Union[int, List[int]] = 5, - use_tendons: bool = False, - use_p_control: bool = False, - use_torque_control: bool = False, - radius_to_strength_factor: float = 200, + num_arms: int = 5, + num_segments_per_arm: Union[int, List[int]] = 5, + use_tendons: bool = False, + use_p_control: bool = False, + use_torque_control: bool = False, + radius_to_strength_factor: float = 200, + num_contact_sensors_per_segment: int = 1 ) -> BrittleStarMorphologySpecification: disk_specification = BrittleStarDiskSpecification( diameter=DISK_DIAMETER, height=DISK_HEIGHT @@ -102,10 +103,15 @@ def default_brittle_star_morphology_specification( use_torque_control=use_torque_control, radius_to_strength_factor=radius_to_strength_factor, ) + sensor_specification = BrittleStarSensorSpecification( + num_contact_sensors_per_segment=num_contact_sensors_per_segment + ) + specification = BrittleStarMorphologySpecification( disk_specification=disk_specification, arm_specifications=arm_specifications, actuation_specification=actuation_specification, + sensor_specification=sensor_specification ) return specification diff --git a/biorobot/brittle_star/mjcf/morphology/specification/specification.py b/biorobot/brittle_star/mjcf/morphology/specification/specification.py index 7f5f64b..241fa99 100644 --- a/biorobot/brittle_star/mjcf/morphology/specification/specification.py +++ b/biorobot/brittle_star/mjcf/morphology/specification/specification.py @@ -7,11 +7,11 @@ class BrittleStarJointSpecification(Specification): def __init__( - self, - range: float, - stiffness: float, - damping: float, - armature: float, + self, + range: float, + stiffness: float, + damping: float, + armature: float, ) -> None: super().__init__() self.stiffness = FixedParameter(value=stiffness) @@ -22,11 +22,11 @@ def __init__( class BrittleStarArmSegmentSpecification(Specification): def __init__( - self, - radius: float, - length: float, - in_plane_joint_specification: BrittleStarJointSpecification, - out_of_plane_joint_specification: BrittleStarJointSpecification, + self, + radius: float, + length: float, + in_plane_joint_specification: BrittleStarJointSpecification, + out_of_plane_joint_specification: BrittleStarJointSpecification, ) -> None: super().__init__() self.radius = FixedParameter(radius) @@ -37,7 +37,7 @@ def __init__( class BrittleStarArmSpecification(Specification): def __init__( - self, segment_specifications: List[BrittleStarArmSegmentSpecification] + self, segment_specifications: List[BrittleStarArmSegmentSpecification] ) -> None: super().__init__() self.segment_specifications = segment_specifications @@ -49,9 +49,9 @@ def number_of_segments(self) -> int: class BrittleStarDiskSpecification(Specification): def __init__( - self, - diameter: float, - height: float, + self, + diameter: float, + height: float, ) -> None: super().__init__() self.radius = FixedParameter(diameter / 2) @@ -60,18 +60,18 @@ def __init__( class BrittleStarActuationSpecification(Specification): def __init__( - self, - use_tendons: bool, - use_p_control: bool, - use_torque_control: bool, - radius_to_strength_factor: float, + self, + use_tendons: bool, + use_p_control: bool, + use_torque_control: bool, + radius_to_strength_factor: float, ) -> None: super().__init__() assert ( - use_p_control + use_torque_control == 1 + use_p_control + use_torque_control == 1 ), "Only one actuation method can be used." assert ( - not use_tendons or use_torque_control + not use_tendons or use_torque_control ), "Only torque control is supported with tendons." self.use_tendons = FixedParameter(use_tendons) self.use_p_control = FixedParameter(use_p_control) @@ -79,17 +79,25 @@ def __init__( self.radius_to_strength_factor = FixedParameter(radius_to_strength_factor) +class BrittleStarSensorSpecification(Specification): + def __init__(self, num_contact_sensors_per_segment: int) -> int: + super().__init__() + self.num_contact_sensors_per_segment = FixedParameter(num_contact_sensors_per_segment) + + class BrittleStarMorphologySpecification(MorphologySpecification): def __init__( - self, - disk_specification: BrittleStarDiskSpecification, - arm_specifications: List[BrittleStarArmSpecification], - actuation_specification: BrittleStarActuationSpecification, + self, + disk_specification: BrittleStarDiskSpecification, + arm_specifications: List[BrittleStarArmSpecification], + actuation_specification: BrittleStarActuationSpecification, + sensor_specification: BrittleStarSensorSpecification ) -> None: super(BrittleStarMorphologySpecification, self).__init__() self.disk_specification = disk_specification self.arm_specifications = arm_specifications self.actuation_specification = actuation_specification + self.sensor_specification = sensor_specification @property def number_of_arms(self) -> int: diff --git a/biorobot/brittle_star/usage_examples/directed_locomotion_single.py b/biorobot/brittle_star/usage_examples/directed_locomotion_single.py index d0b8c8d..7900aad 100644 --- a/biorobot/brittle_star/usage_examples/directed_locomotion_single.py +++ b/biorobot/brittle_star/usage_examples/directed_locomotion_single.py @@ -54,6 +54,7 @@ def create_env( use_p_control=False, use_torque_control=True, use_tendons=False, + num_contact_sensors_per_segment=1 ) morphology = MJCFBrittleStarMorphology(morphology_spec) arena_config = AquariumArenaConfiguration(attach_target=True)