From 9598ab127f1404a8cb477905c7e55ca4b2ab76a8 Mon Sep 17 00:00:00 2001 From: dmarzoug Date: Wed, 13 Nov 2024 13:22:43 +0100 Subject: [PATCH] black formatting --- .../environment/shared/observables.py | 54 +++---- .../mjcf/morphology/morphology.py | 8 +- .../mjcf/morphology/parts/arm_segment.py | 146 +++++++++++------- .../mjcf/morphology/parts/disk.py | 44 ++++-- .../morphology/specification/specification.py | 4 +- .../directed_locomotion_single.py | 6 +- 6 files changed, 155 insertions(+), 107 deletions(-) diff --git a/biorobot/brittle_star/environment/shared/observables.py b/biorobot/brittle_star/environment/shared/observables.py index cffc380..94bcf05 100644 --- a/biorobot/brittle_star/environment/shared/observables.py +++ b/biorobot/brittle_star/environment/shared/observables.py @@ -26,7 +26,7 @@ def jquat2euler(quat): def get_num_contacts_and_segment_contacts_fn( - mj_model: mujoco.MjModel, backend: str + mj_model: mujoco.MjModel, backend: str ) -> Tuple[int, Callable[[BaseEnvState], chex.Array]]: if backend == "mjx": segment_capsule_geom_ids = np.array( @@ -34,7 +34,7 @@ def get_num_contacts_and_segment_contacts_fn( geom_id for geom_id in range(mj_model.ngeom) if "segment" in mj_model.geom(geom_id).name - and "capsule" in mj_model.geom(geom_id).name + and "capsule" in mj_model.geom(geom_id).name ] ) @@ -44,8 +44,8 @@ def get_segment_contacts(state: MJXEnvState) -> jnp.ndarray: def solve_contact(geom_id: int) -> jnp.ndarray: return ( - jnp.sum(contacts * jnp.any(geom_id == contact_data.geom, axis=-1)) - > 0 + jnp.sum(contacts * jnp.any(geom_id == contact_data.geom, axis=-1)) + > 0 ).astype(int) return jax.vmap(solve_contact)(segment_capsule_geom_ids) @@ -83,7 +83,7 @@ def get_segment_contacts(state: MJCEnvState) -> np.ndarray: def get_base_brittle_star_observables( - mj_model: mujoco.MjModel, backend: str + mj_model: mujoco.MjModel, backend: str ) -> List[BaseObservable]: if backend == "mjx": observable_class = MJXObservable @@ -111,7 +111,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0]: sensor.adr[0] + sensor.dim[0] + sensor.adr[0] : sensor.adr[0] + sensor.dim[0] ] for sensor in joint_pos_sensors ] @@ -131,7 +131,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0]: sensor.adr[0] + sensor.dim[0] + sensor.adr[0] : sensor.adr[0] + sensor.dim[0] ] for sensor in joint_vel_sensors ] @@ -151,7 +151,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0]: sensor.adr[0] + sensor.dim[0] + sensor.adr[0] : sensor.adr[0] + sensor.dim[0] ] for sensor in joint_actuator_frc_sensors ] @@ -171,7 +171,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0]: sensor.adr[0] + sensor.dim[0] + sensor.adr[0] : sensor.adr[0] + sensor.dim[0] ] for sensor in actuator_frc_sensors ] @@ -183,23 +183,23 @@ def get_base_brittle_star_observables( mj_model.sensor(i) for i in range(mj_model.nsensor) if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMEPOS - and "disk" in mj_model.sensor(i).name + and "disk" in mj_model.sensor(i).name ][0] disk_position_observable = observable_class( name="disk_position", low=-bnp.inf * bnp.ones(3), high=bnp.inf * bnp.ones(3), retriever=lambda state: get_data(state).sensordata[ - disk_framepos_sensor.adr[0]: disk_framepos_sensor.adr[0] - + disk_framepos_sensor.dim[0] - ], + disk_framepos_sensor.adr[0] : disk_framepos_sensor.adr[0] + + disk_framepos_sensor.dim[0] + ], ) # disk rotation disk_framequat_sensor = [ mj_model.sensor(i) for i in range(mj_model.nsensor) if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMEQUAT - and "disk" in mj_model.sensor(i).name + and "disk" in mj_model.sensor(i).name ][0] disk_rotation_observable = observable_class( name="disk_rotation", @@ -207,8 +207,8 @@ def get_base_brittle_star_observables( high=bnp.pi * bnp.ones(3), retriever=lambda state: get_quat2euler_fn(backend=backend)( get_data(state).sensordata[ - disk_framequat_sensor.adr[0]: disk_framequat_sensor.adr[0] - + disk_framequat_sensor.dim[0] + disk_framequat_sensor.adr[0] : disk_framequat_sensor.adr[0] + + disk_framequat_sensor.dim[0] ] ), ) @@ -218,16 +218,16 @@ def get_base_brittle_star_observables( mj_model.sensor(i) for i in range(mj_model.nsensor) if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMELINVEL - and "disk" in mj_model.sensor(i).name + and "disk" in mj_model.sensor(i).name ][0] disk_linvel_observable = observable_class( name="disk_linear_velocity", low=-bnp.inf * bnp.ones(3), high=bnp.inf * bnp.ones(3), retriever=lambda state: get_data(state).sensordata[ - disk_framelinvel_sensor.adr[0]: disk_framelinvel_sensor.adr[0] - + disk_framelinvel_sensor.dim[0] - ], + disk_framelinvel_sensor.adr[0] : disk_framelinvel_sensor.adr[0] + + disk_framelinvel_sensor.dim[0] + ], ) # disk angvel @@ -235,16 +235,16 @@ def get_base_brittle_star_observables( mj_model.sensor(i) for i in range(mj_model.nsensor) if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMEANGVEL - and "disk" in mj_model.sensor(i).name + and "disk" in mj_model.sensor(i).name ][0] disk_angvel_observable = observable_class( name="disk_angular_velocity", low=-bnp.inf * bnp.ones(3), high=bnp.inf * bnp.ones(3), retriever=lambda state: get_data(state).sensordata[ - disk_frameangvel_sensor.adr[0]: disk_frameangvel_sensor.adr[0] - + disk_frameangvel_sensor.dim[0] - ], + disk_frameangvel_sensor.adr[0] : disk_frameangvel_sensor.adr[0] + + disk_frameangvel_sensor.dim[0] + ], ) # tendons @@ -260,7 +260,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0]: sensor.adr[0] + sensor.dim[0] + sensor.adr[0] : sensor.adr[0] + sensor.dim[0] ] for sensor in tendon_pos_sensors ] @@ -279,7 +279,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0]: sensor.adr[0] + sensor.dim[0] + sensor.adr[0] : sensor.adr[0] + sensor.dim[0] ] for sensor in tendon_vel_sensors ] @@ -308,5 +308,5 @@ def get_base_brittle_star_observables( disk_angvel_observable, tendon_pos_observable, tendon_vel_observable, - segment_contact_observable + segment_contact_observable, ] diff --git a/biorobot/brittle_star/mjcf/morphology/morphology.py b/biorobot/brittle_star/mjcf/morphology/morphology.py index 72a271d..3566fac 100644 --- a/biorobot/brittle_star/mjcf/morphology/morphology.py +++ b/biorobot/brittle_star/mjcf/morphology/morphology.py @@ -73,7 +73,11 @@ def _configure_camera(self) -> None: if __name__ == "__main__": spec = default_brittle_star_morphology_specification( - num_arms=5, num_segments_per_arm=5, use_p_control=False, use_torque_control=True, use_tendons=True, - radius_to_strength_factor=200 + num_arms=5, + num_segments_per_arm=5, + use_p_control=False, + use_torque_control=True, + use_tendons=True, + radius_to_strength_factor=200, ) MJCFBrittleStarMorphology(spec).export_to_xml_with_assets("./mjcf") diff --git a/biorobot/brittle_star/mjcf/morphology/parts/arm_segment.py b/biorobot/brittle_star/mjcf/morphology/parts/arm_segment.py index f5085a0..751f0f5 100644 --- a/biorobot/brittle_star/mjcf/morphology/parts/arm_segment.py +++ b/biorobot/brittle_star/mjcf/morphology/parts/arm_segment.py @@ -17,13 +17,13 @@ class MJCFBrittleStarArmSegment(MJCFMorphologyPart): def __init__( - self, - parent: Union[MJCFMorphology, MJCFMorphologyPart], - name: str, - pos: np.array, - euler: np.array, - *args, - **kwargs, + self, + parent: Union[MJCFMorphology, MJCFMorphologyPart], + name: str, + pos: np.array, + euler: np.array, + *args, + **kwargs, ) -> None: super().__init__(parent, name, pos, euler, *args, **kwargs) @@ -64,7 +64,7 @@ def _build_capsule(self) -> None: pos=self.center_of_capsule, euler=[0, np.pi / 2, 0], size=[radius, length / 2], - rgba=colors.rgba_green + rgba=colors.rgba_green, ) def _build_connector(self) -> None: @@ -88,10 +88,10 @@ def center_of_capsule(self) -> np.ndarray: return np.array([x_offset, 0, 0]) def _configure_joint( - self, - name: str, - axis: np.ndarray, - joint_specification: BrittleStarJointSpecification, + self, + name: str, + axis: np.ndarray, + joint_specification: BrittleStarJointSpecification, ) -> _ElementImpl: joint = self.mjcf_body.add( "joint", @@ -111,11 +111,14 @@ def _configure_joints(self) -> None: self._configure_joint( name=f"{self.base_name}_in_plane_joint", axis=[0, 0, 1], - joint_specification=self._segment_specification.in_plane_joint_specification), + joint_specification=self._segment_specification.in_plane_joint_specification, + ), self._configure_joint( name=f"{self.base_name}_out_of_plane_joint", axis=[0, -1, 0], - joint_specification=self._segment_specification.out_of_plane_joint_specification)] + joint_specification=self._segment_specification.out_of_plane_joint_specification, + ), + ] def _configure_tendon_attachment_points(self) -> None: angles = np.linspace(np.pi / 4, 7 * np.pi / 4, 4) @@ -123,23 +126,38 @@ def _configure_tendon_attachment_points(self) -> None: self.distal_taps = [] for i, angle in enumerate(angles): # proximal - pos = 0.8 * self._segment_specification.radius.value * np.array([0, np.cos(angle), np.sin(angle)]) + pos = ( + 0.8 + * self._segment_specification.radius.value + * np.array([0, np.cos(angle), np.sin(angle)]) + ) pos[0] = self._segment_specification.radius.value - self._proximal_taps.append(self.mjcf_body.add("site", - name=f"{self.base_name}_proximal_tap_{i}", - type="sphere", - rgba=rgba_red, - pos=pos, - size=[0.001])) + self._proximal_taps.append( + self.mjcf_body.add( + "site", + name=f"{self.base_name}_proximal_tap_{i}", + type="sphere", + rgba=rgba_red, + pos=pos, + size=[0.001], + ) + ) # distal - pos[0] = self._segment_specification.radius.value + self._segment_specification.length.value - self.distal_taps.append(self.mjcf_body.add("site", - name=f"{self.base_name}_distal_tap_{i}", - type="sphere", - rgba=rgba_red, - pos=pos, - size=[0.001])) + pos[0] = ( + self._segment_specification.radius.value + + self._segment_specification.length.value + ) + self.distal_taps.append( + self.mjcf_body.add( + "site", + name=f"{self.base_name}_distal_tap_{i}", + type="sphere", + rgba=rgba_red, + pos=pos, + size=[0.001], + ) + ) def _build_tendons(self) -> None: if self._segment_index == 0: @@ -149,12 +167,17 @@ def _build_tendons(self) -> None: distal_taps = self.parent.distal_taps self._tendons = [] - for tendon_index, (parent_tap, segment_tap) in enumerate(zip(distal_taps, self._proximal_taps)): - tendon = self.mjcf_model.tendon.add('spatial', name=f"{self.base_name}_tendon_{tendon_index}", - rgba=rgba_tendon_relaxed, - width=self._segment_specification.radius.value * 0.1) - tendon.add('site', site=parent_tap) - tendon.add('site', site=segment_tap) + for tendon_index, (parent_tap, segment_tap) in enumerate( + zip(distal_taps, self._proximal_taps) + ): + tendon = self.mjcf_model.tendon.add( + "spatial", + name=f"{self.base_name}_tendon_{tendon_index}", + rgba=rgba_tendon_relaxed, + width=self._segment_specification.radius.value * 0.1, + ) + tendon.add("site", site=parent_tap) + tendon.add("site", site=segment_tap) self._tendons.append(tendon) def _configure_tendons(self) -> None: @@ -172,8 +195,8 @@ def _is_last_segment(self) -> bool: @property def _actuator_strength(self) -> float: strength = ( - self._segment_specification.radius.value - * self.morphology_specification.actuation_specification.radius_to_strength_factor.value + self._segment_specification.radius.value + * self.morphology_specification.actuation_specification.radius_to_strength_factor.value ) return strength @@ -192,25 +215,27 @@ def _configure_p_control_actuator(self, transmission: _ElementImpl) -> _ElementI "ctrlrange": transmission.range, "forcelimited": True, "forcerange": [-self._actuator_strength, self._actuator_strength], - "joint": transmission + "joint": transmission, } - return self.mjcf_model.actuator.add( - "position", - **actuator_attributes - ) + return self.mjcf_model.actuator.add("position", **actuator_attributes) def _configure_p_control_actuators(self) -> None: if self.morphology_specification.actuation_specification.use_p_control.value: - self._actuators = [self._configure_p_control_actuator(transmission) for transmission in self._transmissions] + self._actuators = [ + self._configure_p_control_actuator(transmission) + for transmission in self._transmissions + ] - def _configure_torque_control_actuator(self, transmission: _ElementImpl) -> _ElementImpl: + def _configure_torque_control_actuator( + self, transmission: _ElementImpl + ) -> _ElementImpl: actuator_attributes = { "name": f"{transmission.name}_torque_control", "ctrllimited": True, "forcelimited": True, "ctrlrange": [-self._actuator_strength, self._actuator_strength], - "forcerange": [-self._actuator_strength, self._actuator_strength] + "forcerange": [-self._actuator_strength, self._actuator_strength], } if self.morphology_specification.actuation_specification.use_tendons.value: @@ -221,18 +246,25 @@ def _configure_torque_control_actuator(self, transmission: _ElementImpl) -> _Ele actuator_attributes["forcerange"] = [-self._actuator_strength * gear, 0] else: actuator_attributes["joint"] = transmission - actuator_attributes["ctrlrange"] = [-self._actuator_strength, self._actuator_strength] - actuator_attributes["forcerange"] = [-self._actuator_strength, self._actuator_strength] + actuator_attributes["ctrlrange"] = [ + -self._actuator_strength, + self._actuator_strength, + ] + actuator_attributes["forcerange"] = [ + -self._actuator_strength, + self._actuator_strength, + ] - return self.mjcf_model.actuator.add( - "motor", - **actuator_attributes - ) + return self.mjcf_model.actuator.add("motor", **actuator_attributes) def _configure_torque_control_actuators(self) -> None: - if self.morphology_specification.actuation_specification.use_torque_control.value: - self._actuators = [self._configure_torque_control_actuator(transmission) for transmission in - self._transmissions] + if ( + self.morphology_specification.actuation_specification.use_torque_control.value + ): + self._actuators = [ + self._configure_torque_control_actuator(transmission) + for transmission in self._transmissions + ] def _configure_actuators(self) -> None: self._configure_p_control_actuators() @@ -270,14 +302,10 @@ def _configure_tendon_sensors(self) -> None: if self.morphology_specification.actuation_specification.use_tendons.value: for tendon in self._tendons: self.mjcf_model.sensor.add( - "tendonpos", - name=f"{tendon.name}_tendonpos_sensor", - tendon=tendon + "tendonpos", name=f"{tendon.name}_tendonpos_sensor", tendon=tendon ) self.mjcf_model.sensor.add( - "tendonvel", - name=f"{tendon.name}_tendonvel_sensor", - tendon=tendon + "tendonvel", name=f"{tendon.name}_tendonvel_sensor", tendon=tendon ) def _configure_sensors(self) -> None: diff --git a/biorobot/brittle_star/mjcf/morphology/parts/disk.py b/biorobot/brittle_star/mjcf/morphology/parts/disk.py index 1c5a70e..5ffc870 100644 --- a/biorobot/brittle_star/mjcf/morphology/parts/disk.py +++ b/biorobot/brittle_star/mjcf/morphology/parts/disk.py @@ -13,13 +13,13 @@ class MJCFBrittleStarDisk(MJCFMorphologyPart): def __init__( - self, - parent: Union[MJCFMorphology, MJCFMorphologyPart], - name: str, - pos: np.array, - euler: np.array, - *args, - **kwargs, + self, + parent: Union[MJCFMorphology, MJCFMorphologyPart], + name: str, + pos: np.array, + euler: np.array, + *args, + **kwargs, ) -> None: super().__init__(parent, name, pos, euler, *args, **kwargs) @@ -108,32 +108,42 @@ def _configure_tendon_attachment_points(self) -> None: tap_angles = np.linspace(np.pi / 4, 7 * np.pi / 4, 4) for arm_index, arm_angle in enumerate(arm_angles): - arm_specification = self.morphology_specification.arm_specifications[arm_index] + arm_specification = self.morphology_specification.arm_specifications[ + arm_index + ] if arm_specification.number_of_segments == 0: continue - base_segment_radius = arm_specification.segment_specifications[0].radius.value + base_segment_radius = arm_specification.segment_specifications[ + 0 + ].radius.value arm_taps = [] positions = [] for angle in tap_angles: - pos = center_pos + 0.8 * base_segment_radius * np.array([0, np.cos(angle), np.sin(angle)]) + pos = center_pos + 0.8 * base_segment_radius * np.array( + [0, np.cos(angle), np.sin(angle)] + ) positions.append(pos) for tap_index, position in enumerate(positions): # rotate position around arm_angle degress # Define the rotation - rotation = R.from_euler('z', arm_angle, degrees=False) + rotation = R.from_euler("z", arm_angle, degrees=False) # Rotate point A around point B rotated_point = rotation.apply(position) - arm_taps.append(self.mjcf_body.add("site", - name=f"{self.base_name}_arm_{arm_index}_tap_{tap_index}", - type="sphere", - rgba=rgba_red, - pos=rotated_point, - size=[0.001])) + arm_taps.append( + self.mjcf_body.add( + "site", + name=f"{self.base_name}_arm_{arm_index}_tap_{tap_index}", + type="sphere", + rgba=rgba_red, + pos=rotated_point, + size=[0.001], + ) + ) self.distal_taps.append(arm_taps) def _configure_sensors(self) -> None: diff --git a/biorobot/brittle_star/mjcf/morphology/specification/specification.py b/biorobot/brittle_star/mjcf/morphology/specification/specification.py index e974871..7f5f64b 100644 --- a/biorobot/brittle_star/mjcf/morphology/specification/specification.py +++ b/biorobot/brittle_star/mjcf/morphology/specification/specification.py @@ -70,7 +70,9 @@ def __init__( assert ( use_p_control + use_torque_control == 1 ), "Only one actuation method can be used." - assert (not use_tendons or use_torque_control), "Only torque control is supported with tendons." + assert ( + not use_tendons or use_torque_control + ), "Only torque control is supported with tendons." self.use_tendons = FixedParameter(use_tendons) self.use_p_control = FixedParameter(use_p_control) self.use_torque_control = FixedParameter(use_torque_control) diff --git a/biorobot/brittle_star/usage_examples/directed_locomotion_single.py b/biorobot/brittle_star/usage_examples/directed_locomotion_single.py index 0e554e8..d0b8c8d 100644 --- a/biorobot/brittle_star/usage_examples/directed_locomotion_single.py +++ b/biorobot/brittle_star/usage_examples/directed_locomotion_single.py @@ -49,7 +49,11 @@ def create_env( backend: str, render_mode: str ) -> BrittleStarDirectedLocomotionEnvironment: morphology_spec = default_brittle_star_morphology_specification( - num_arms=5, num_segments_per_arm=5, use_p_control=False, use_torque_control=True, use_tendons=False + num_arms=5, + num_segments_per_arm=5, + use_p_control=False, + use_torque_control=True, + use_tendons=False, ) morphology = MJCFBrittleStarMorphology(morphology_spec) arena_config = AquariumArenaConfiguration(attach_target=True)