From cfc95cecf126bedaf9c90dac2e61d746bb9108da Mon Sep 17 00:00:00 2001 From: dmarzoug Date: Wed, 13 Nov 2024 13:06:32 +0100 Subject: [PATCH] [brittle-star] Added tendon observables --- .../environment/shared/observables.py | 92 ++++++++++++++----- .../directed_locomotion_single.py | 4 +- 2 files changed, 69 insertions(+), 27 deletions(-) diff --git a/biorobot/brittle_star/environment/shared/observables.py b/biorobot/brittle_star/environment/shared/observables.py index 2d3a991..cffc380 100644 --- a/biorobot/brittle_star/environment/shared/observables.py +++ b/biorobot/brittle_star/environment/shared/observables.py @@ -26,7 +26,7 @@ def jquat2euler(quat): def get_num_contacts_and_segment_contacts_fn( - mj_model: mujoco.MjModel, backend: str + mj_model: mujoco.MjModel, backend: str ) -> Tuple[int, Callable[[BaseEnvState], chex.Array]]: if backend == "mjx": segment_capsule_geom_ids = np.array( @@ -34,7 +34,7 @@ def get_num_contacts_and_segment_contacts_fn( geom_id for geom_id in range(mj_model.ngeom) if "segment" in mj_model.geom(geom_id).name - and "capsule" in mj_model.geom(geom_id).name + and "capsule" in mj_model.geom(geom_id).name ] ) @@ -44,8 +44,8 @@ def get_segment_contacts(state: MJXEnvState) -> jnp.ndarray: def solve_contact(geom_id: int) -> jnp.ndarray: return ( - jnp.sum(contacts * jnp.any(geom_id == contact_data.geom, axis=-1)) - > 0 + jnp.sum(contacts * jnp.any(geom_id == contact_data.geom, axis=-1)) + > 0 ).astype(int) return jax.vmap(solve_contact)(segment_capsule_geom_ids) @@ -83,7 +83,7 @@ def get_segment_contacts(state: MJCEnvState) -> np.ndarray: def get_base_brittle_star_observables( - mj_model: mujoco.MjModel, backend: str + mj_model: mujoco.MjModel, backend: str ) -> List[BaseObservable]: if backend == "mjx": observable_class = MJXObservable @@ -111,7 +111,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0] : sensor.adr[0] + sensor.dim[0] + sensor.adr[0]: sensor.adr[0] + sensor.dim[0] ] for sensor in joint_pos_sensors ] @@ -131,7 +131,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0] : sensor.adr[0] + sensor.dim[0] + sensor.adr[0]: sensor.adr[0] + sensor.dim[0] ] for sensor in joint_vel_sensors ] @@ -151,7 +151,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0] : sensor.adr[0] + sensor.dim[0] + sensor.adr[0]: sensor.adr[0] + sensor.dim[0] ] for sensor in joint_actuator_frc_sensors ] @@ -171,7 +171,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0] : sensor.adr[0] + sensor.dim[0] + sensor.adr[0]: sensor.adr[0] + sensor.dim[0] ] for sensor in actuator_frc_sensors ] @@ -183,23 +183,23 @@ def get_base_brittle_star_observables( mj_model.sensor(i) for i in range(mj_model.nsensor) if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMEPOS - and "disk" in mj_model.sensor(i).name + and "disk" in mj_model.sensor(i).name ][0] disk_position_observable = observable_class( name="disk_position", low=-bnp.inf * bnp.ones(3), high=bnp.inf * bnp.ones(3), retriever=lambda state: get_data(state).sensordata[ - disk_framepos_sensor.adr[0] : disk_framepos_sensor.adr[0] - + disk_framepos_sensor.dim[0] - ], + disk_framepos_sensor.adr[0]: disk_framepos_sensor.adr[0] + + disk_framepos_sensor.dim[0] + ], ) # disk rotation disk_framequat_sensor = [ mj_model.sensor(i) for i in range(mj_model.nsensor) if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMEQUAT - and "disk" in mj_model.sensor(i).name + and "disk" in mj_model.sensor(i).name ][0] disk_rotation_observable = observable_class( name="disk_rotation", @@ -207,8 +207,8 @@ def get_base_brittle_star_observables( high=bnp.pi * bnp.ones(3), retriever=lambda state: get_quat2euler_fn(backend=backend)( get_data(state).sensordata[ - disk_framequat_sensor.adr[0] : disk_framequat_sensor.adr[0] - + disk_framequat_sensor.dim[0] + disk_framequat_sensor.adr[0]: disk_framequat_sensor.adr[0] + + disk_framequat_sensor.dim[0] ] ), ) @@ -218,16 +218,16 @@ def get_base_brittle_star_observables( mj_model.sensor(i) for i in range(mj_model.nsensor) if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMELINVEL - and "disk" in mj_model.sensor(i).name + and "disk" in mj_model.sensor(i).name ][0] disk_linvel_observable = observable_class( name="disk_linear_velocity", low=-bnp.inf * bnp.ones(3), high=bnp.inf * bnp.ones(3), retriever=lambda state: get_data(state).sensordata[ - disk_framelinvel_sensor.adr[0] : disk_framelinvel_sensor.adr[0] - + disk_framelinvel_sensor.dim[0] - ], + disk_framelinvel_sensor.adr[0]: disk_framelinvel_sensor.adr[0] + + disk_framelinvel_sensor.dim[0] + ], ) # disk angvel @@ -235,18 +235,58 @@ def get_base_brittle_star_observables( mj_model.sensor(i) for i in range(mj_model.nsensor) if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMEANGVEL - and "disk" in mj_model.sensor(i).name + and "disk" in mj_model.sensor(i).name ][0] disk_angvel_observable = observable_class( name="disk_angular_velocity", low=-bnp.inf * bnp.ones(3), high=bnp.inf * bnp.ones(3), retriever=lambda state: get_data(state).sensordata[ - disk_frameangvel_sensor.adr[0] : disk_frameangvel_sensor.adr[0] - + disk_frameangvel_sensor.dim[0] - ], + disk_frameangvel_sensor.adr[0]: disk_frameangvel_sensor.adr[0] + + disk_frameangvel_sensor.dim[0] + ], ) + # tendons + tendon_pos_sensors = [ + sensor + for sensor in sensors + if sensor.type[0] == mujoco.mjtSensor.mjSENS_TENDONPOS + ] + tendon_pos_observable = observable_class( + name="tendon_position", + low=-bnp.inf * bnp.ones(len(tendon_pos_sensors)), + high=bnp.inf * bnp.ones(len(tendon_pos_sensors)), + retriever=lambda state: bnp.array( + [ + get_data(state).sensordata[ + sensor.adr[0]: sensor.adr[0] + sensor.dim[0] + ] + for sensor in tendon_pos_sensors + ] + ).flatten(), + ) + + tendon_vel_sensors = [ + sensor + for sensor in sensors + if sensor.type[0] == mujoco.mjtSensor.mjSENS_TENDONVEL + ] + tendon_vel_observable = observable_class( + name="tendon_velocity", + low=-bnp.inf * bnp.ones(len(tendon_vel_sensors)), + high=bnp.inf * bnp.ones(len(tendon_vel_sensors)), + retriever=lambda state: bnp.array( + [ + get_data(state).sensordata[ + sensor.adr[0]: sensor.adr[0] + sensor.dim[0] + ] + for sensor in tendon_vel_sensors + ] + ).flatten(), + ) + + # contacts num_contacts, get_segment_contacts_fn = get_num_contacts_and_segment_contacts_fn( mj_model=mj_model, backend=backend ) @@ -266,5 +306,7 @@ def get_base_brittle_star_observables( disk_rotation_observable, disk_linvel_observable, disk_angvel_observable, - segment_contact_observable, + tendon_pos_observable, + tendon_vel_observable, + segment_contact_observable ] diff --git a/biorobot/brittle_star/usage_examples/directed_locomotion_single.py b/biorobot/brittle_star/usage_examples/directed_locomotion_single.py index 3bcefd3..0e554e8 100644 --- a/biorobot/brittle_star/usage_examples/directed_locomotion_single.py +++ b/biorobot/brittle_star/usage_examples/directed_locomotion_single.py @@ -49,7 +49,7 @@ def create_env( backend: str, render_mode: str ) -> BrittleStarDirectedLocomotionEnvironment: morphology_spec = default_brittle_star_morphology_specification( - num_arms=5, num_segments_per_arm=5, use_p_control=False, use_torque_control=True + num_arms=5, num_segments_per_arm=5, use_p_control=False, use_torque_control=True, use_tendons=False ) morphology = MJCFBrittleStarMorphology(morphology_spec) arena_config = AquariumArenaConfiguration(attach_target=True) @@ -100,7 +100,7 @@ def action_sample_fn(rng: chex.PRNGKey) -> Tuple[jnp.ndarray, chex.PRNGKey]: action, action_rng = action_sample_fn(action_rng) state = step_fn(state=state, action=action) post_render(env.render(state=state), env.environment_configuration) - print(state.observations["actuator_force"]) + print(state.observations["tendon_position"]) if state.terminated | state.truncated: state = reset_fn(env_rng) env.close()