diff --git a/biorobot/brittle_star/environment/shared/base.py b/biorobot/brittle_star/environment/shared/base.py index 44148d4..723143f 100644 --- a/biorobot/brittle_star/environment/shared/base.py +++ b/biorobot/brittle_star/environment/shared/base.py @@ -87,9 +87,7 @@ def _color_segment_capsule_contacts( c = c.reshape((len(self._segment_capsule_geom_ids), -1)) > 0 c = np.any(c, axis=-1).flatten() - for capsule_geom_id, contact in zip( - self._segment_capsule_geom_ids, c - ): + for capsule_geom_id, contact in zip(self._segment_capsule_geom_ids, c): if contact: mj_model.geom(capsule_geom_id).rgba = colors.rgba_red else: diff --git a/biorobot/brittle_star/environment/shared/observables.py b/biorobot/brittle_star/environment/shared/observables.py index ab5ed26..5aaba8e 100644 --- a/biorobot/brittle_star/environment/shared/observables.py +++ b/biorobot/brittle_star/environment/shared/observables.py @@ -24,7 +24,7 @@ def jquat2euler(quat): def get_base_brittle_star_observables( - mj_model: mujoco.MjModel, backend: str + mj_model: mujoco.MjModel, backend: str ) -> List[BaseObservable]: if backend == "mjx": observable_class = MJXObservable @@ -52,7 +52,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0]: sensor.adr[0] + sensor.dim[0] + sensor.adr[0] : sensor.adr[0] + sensor.dim[0] ] for sensor in joint_pos_sensors ] @@ -72,7 +72,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0]: sensor.adr[0] + sensor.dim[0] + sensor.adr[0] : sensor.adr[0] + sensor.dim[0] ] for sensor in joint_vel_sensors ] @@ -92,7 +92,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0]: sensor.adr[0] + sensor.dim[0] + sensor.adr[0] : sensor.adr[0] + sensor.dim[0] ] for sensor in joint_actuator_frc_sensors ] @@ -112,7 +112,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0]: sensor.adr[0] + sensor.dim[0] + sensor.adr[0] : sensor.adr[0] + sensor.dim[0] ] for sensor in actuator_frc_sensors ] @@ -124,23 +124,23 @@ def get_base_brittle_star_observables( mj_model.sensor(i) for i in range(mj_model.nsensor) if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMEPOS - and "disk" in mj_model.sensor(i).name + and "disk" in mj_model.sensor(i).name ][0] disk_position_observable = observable_class( name="disk_position", low=-bnp.inf * bnp.ones(3), high=bnp.inf * bnp.ones(3), retriever=lambda state: get_data(state).sensordata[ - disk_framepos_sensor.adr[0]: disk_framepos_sensor.adr[0] - + disk_framepos_sensor.dim[0] - ], + disk_framepos_sensor.adr[0] : disk_framepos_sensor.adr[0] + + disk_framepos_sensor.dim[0] + ], ) # disk rotation disk_framequat_sensor = [ mj_model.sensor(i) for i in range(mj_model.nsensor) if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMEQUAT - and "disk" in mj_model.sensor(i).name + and "disk" in mj_model.sensor(i).name ][0] disk_rotation_observable = observable_class( name="disk_rotation", @@ -148,8 +148,8 @@ def get_base_brittle_star_observables( high=bnp.pi * bnp.ones(3), retriever=lambda state: get_quat2euler_fn(backend=backend)( get_data(state).sensordata[ - disk_framequat_sensor.adr[0]: disk_framequat_sensor.adr[0] - + disk_framequat_sensor.dim[0] + disk_framequat_sensor.adr[0] : disk_framequat_sensor.adr[0] + + disk_framequat_sensor.dim[0] ] ), ) @@ -159,16 +159,16 @@ def get_base_brittle_star_observables( mj_model.sensor(i) for i in range(mj_model.nsensor) if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMELINVEL - and "disk" in mj_model.sensor(i).name + and "disk" in mj_model.sensor(i).name ][0] disk_linvel_observable = observable_class( name="disk_linear_velocity", low=-bnp.inf * bnp.ones(3), high=bnp.inf * bnp.ones(3), retriever=lambda state: get_data(state).sensordata[ - disk_framelinvel_sensor.adr[0]: disk_framelinvel_sensor.adr[0] - + disk_framelinvel_sensor.dim[0] - ], + disk_framelinvel_sensor.adr[0] : disk_framelinvel_sensor.adr[0] + + disk_framelinvel_sensor.dim[0] + ], ) # disk angvel @@ -176,16 +176,16 @@ def get_base_brittle_star_observables( mj_model.sensor(i) for i in range(mj_model.nsensor) if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMEANGVEL - and "disk" in mj_model.sensor(i).name + and "disk" in mj_model.sensor(i).name ][0] disk_angvel_observable = observable_class( name="disk_angular_velocity", low=-bnp.inf * bnp.ones(3), high=bnp.inf * bnp.ones(3), retriever=lambda state: get_data(state).sensordata[ - disk_frameangvel_sensor.adr[0]: disk_frameangvel_sensor.adr[0] - + disk_frameangvel_sensor.dim[0] - ], + disk_frameangvel_sensor.adr[0] : disk_frameangvel_sensor.adr[0] + + disk_frameangvel_sensor.dim[0] + ], ) # tendons @@ -201,7 +201,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0]: sensor.adr[0] + sensor.dim[0] + sensor.adr[0] : sensor.adr[0] + sensor.dim[0] ] for sensor in tendon_pos_sensors ] @@ -220,7 +220,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0]: sensor.adr[0] + sensor.dim[0] + sensor.adr[0] : sensor.adr[0] + sensor.dim[0] ] for sensor in tendon_vel_sensors ] @@ -229,9 +229,7 @@ def get_base_brittle_star_observables( # contacts contact_sensors = [ - sensor - for sensor in sensors - if sensor.type[0] == mujoco.mjtSensor.mjSENS_TOUCH + sensor for sensor in sensors if sensor.type[0] == mujoco.mjtSensor.mjSENS_TOUCH ] segment_contact_observable = observable_class( name="segment_contact", @@ -240,7 +238,7 @@ def get_base_brittle_star_observables( retriever=lambda state: bnp.array( [ get_data(state).sensordata[ - sensor.adr[0]: sensor.adr[0] + sensor.dim[0] + sensor.adr[0] : sensor.adr[0] + sensor.dim[0] ] for sensor in contact_sensors ] diff --git a/biorobot/brittle_star/mjcf/morphology/morphology.py b/biorobot/brittle_star/mjcf/morphology/morphology.py index e5570f3..b966157 100644 --- a/biorobot/brittle_star/mjcf/morphology/morphology.py +++ b/biorobot/brittle_star/mjcf/morphology/morphology.py @@ -79,6 +79,6 @@ def _configure_camera(self) -> None: use_torque_control=True, use_tendons=True, radius_to_strength_factor=200, - num_contact_sensors_per_segment=8 + num_contact_sensors_per_segment=8, ) MJCFBrittleStarMorphology(spec).export_to_xml_with_assets("./mjcf") diff --git a/biorobot/brittle_star/mjcf/morphology/parts/arm_segment.py b/biorobot/brittle_star/mjcf/morphology/parts/arm_segment.py index 2804c7e..5934496 100644 --- a/biorobot/brittle_star/mjcf/morphology/parts/arm_segment.py +++ b/biorobot/brittle_star/mjcf/morphology/parts/arm_segment.py @@ -17,13 +17,13 @@ class MJCFBrittleStarArmSegment(MJCFMorphologyPart): def __init__( - self, - parent: Union[MJCFMorphology, MJCFMorphologyPart], - name: str, - pos: np.array, - euler: np.array, - *args, - **kwargs, + self, + parent: Union[MJCFMorphology, MJCFMorphologyPart], + name: str, + pos: np.array, + euler: np.array, + *args, + **kwargs, ) -> None: super().__init__(parent, name, pos, euler, *args, **kwargs) @@ -88,10 +88,10 @@ def center_of_capsule(self) -> np.ndarray: return np.array([x_offset, 0, 0]) def _configure_joint( - self, - name: str, - axis: np.ndarray, - joint_specification: BrittleStarJointSpecification, + self, + name: str, + axis: np.ndarray, + joint_specification: BrittleStarJointSpecification, ) -> _ElementImpl: joint = self.mjcf_body.add( "joint", @@ -127,9 +127,9 @@ def _configure_tendon_attachment_points(self) -> None: for i, angle in enumerate(angles): # proximal pos = ( - 0.8 - * self._segment_specification.radius.value - * np.array([0, np.cos(angle), np.sin(angle)]) + 0.8 + * self._segment_specification.radius.value + * np.array([0, np.cos(angle), np.sin(angle)]) ) pos[0] = self._segment_specification.radius.value self._proximal_taps.append( @@ -145,8 +145,8 @@ def _configure_tendon_attachment_points(self) -> None: # distal pos[0] = ( - self._segment_specification.radius.value - + self._segment_specification.length.value + self._segment_specification.radius.value + + self._segment_specification.length.value ) self.distal_taps.append( self.mjcf_body.add( @@ -168,7 +168,7 @@ def _build_tendons(self) -> None: self._tendons = [] for tendon_index, (parent_tap, segment_tap) in enumerate( - zip(distal_taps, self._proximal_taps) + zip(distal_taps, self._proximal_taps) ): tendon = self.mjcf_model.tendon.add( "spatial", @@ -195,8 +195,8 @@ def _is_last_segment(self) -> bool: @property def _actuator_strength(self) -> float: strength = ( - self._segment_specification.radius.value - * self.morphology_specification.actuation_specification.radius_to_strength_factor.value + self._segment_specification.radius.value + * self.morphology_specification.actuation_specification.radius_to_strength_factor.value ) return strength @@ -228,7 +228,7 @@ def _configure_p_control_actuators(self) -> None: ] def _configure_torque_control_actuator( - self, transmission: _ElementImpl + self, transmission: _ElementImpl ) -> _ElementImpl: actuator_attributes = { "name": f"{transmission.name}_torque_control", @@ -259,7 +259,7 @@ def _configure_torque_control_actuator( def _configure_torque_control_actuators(self) -> None: if ( - self.morphology_specification.actuation_specification.use_torque_control.value + self.morphology_specification.actuation_specification.use_torque_control.value ): self._actuators = [ self._configure_torque_control_actuator(transmission) @@ -309,36 +309,50 @@ def _configure_tendon_sensors(self) -> None: ) def _configure_contact_sensors(self) -> None: - num_contact_sensors = self.morphology_specification.sensor_specification.num_contact_sensors_per_segment.value + num_contact_sensors = ( + self.morphology_specification.sensor_specification.num_contact_sensors_per_segment.value + ) contact_sites = [] if num_contact_sensors == 1: contact_sites.append( - self.mjcf_body.add("site", - type="capsule", - pos=self._capsule.pos, - size=self._capsule.size * np.array([1.01, 1]), - rgba=rgba_red * np.array([1, 1, 1, 0.5]), - euler=self._capsule.euler, - name=f"{self.base_name}_contact_site", - group=3), + self.mjcf_body.add( + "site", + type="capsule", + pos=self._capsule.pos, + size=self._capsule.size * np.array([1.01, 1]), + rgba=rgba_red * np.array([1, 1, 1, 0.5]), + euler=self._capsule.euler, + name=f"{self.base_name}_contact_site", + group=3, + ), ) else: - angles = np.linspace(-np.pi / 2, 1.5 * np.pi, num_contact_sensors + 1)[:num_contact_sensors] + angles = np.linspace(-np.pi / 2, 1.5 * np.pi, num_contact_sensors + 1)[ + :num_contact_sensors + ] radius = self._segment_specification.radius.value for i, angle in enumerate(angles): - pos = self.center_of_capsule + 0.95 * radius * np.array([0, np.cos(angle), np.sin(angle)]) - contact_sites.append(self.mjcf_body.add( - "site", - pos=pos, - euler=[angle, 0, 0], - type="box", - size=[self._segment_specification.length.value / 2 + radius / 2, 0.05 * radius, 0.4 * radius], - rgba=rgba_red * np.array([1, 1, 1, 0.5]), - group=3, - name=f"{self.base_name}_contact_site_{i}" - )) + pos = self.center_of_capsule + 0.95 * radius * np.array( + [0, np.cos(angle), np.sin(angle)] + ) + contact_sites.append( + self.mjcf_body.add( + "site", + pos=pos, + euler=[angle, 0, 0], + type="box", + size=[ + self._segment_specification.length.value / 2 + radius / 2, + 0.05 * radius, + 0.4 * radius, + ], + rgba=rgba_red * np.array([1, 1, 1, 0.5]), + group=3, + name=f"{self.base_name}_contact_site_{i}", + ) + ) for site in contact_sites: self.mjcf_model.sensor.add( diff --git a/biorobot/brittle_star/mjcf/morphology/specification/default.py b/biorobot/brittle_star/mjcf/morphology/specification/default.py index 63980e2..62639bb 100644 --- a/biorobot/brittle_star/mjcf/morphology/specification/default.py +++ b/biorobot/brittle_star/mjcf/morphology/specification/default.py @@ -8,7 +8,8 @@ BrittleStarArmSpecification, BrittleStarDiskSpecification, BrittleStarJointSpecification, - BrittleStarMorphologySpecification, BrittleStarSensorSpecification, + BrittleStarMorphologySpecification, + BrittleStarSensorSpecification, ) START_SEGMENT_RADIUS = 0.025 @@ -32,7 +33,7 @@ def default_joint_specification(range: float) -> BrittleStarJointSpecification: def default_arm_segment_specification( - alpha: float, + alpha: float, ) -> BrittleStarArmSegmentSpecification: in_plane_joint_specification = default_joint_specification( range=30 / 180 * np.pi @@ -72,13 +73,13 @@ def default_arm_specification(num_segments_per_arm: int) -> BrittleStarArmSpecif def default_brittle_star_morphology_specification( - num_arms: int = 5, - num_segments_per_arm: Union[int, List[int]] = 5, - use_tendons: bool = False, - use_p_control: bool = False, - use_torque_control: bool = False, - radius_to_strength_factor: float = 200, - num_contact_sensors_per_segment: int = 1 + num_arms: int = 5, + num_segments_per_arm: Union[int, List[int]] = 5, + use_tendons: bool = False, + use_p_control: bool = False, + use_torque_control: bool = False, + radius_to_strength_factor: float = 200, + num_contact_sensors_per_segment: int = 1, ) -> BrittleStarMorphologySpecification: disk_specification = BrittleStarDiskSpecification( diameter=DISK_DIAMETER, height=DISK_HEIGHT @@ -111,7 +112,7 @@ def default_brittle_star_morphology_specification( disk_specification=disk_specification, arm_specifications=arm_specifications, actuation_specification=actuation_specification, - sensor_specification=sensor_specification + sensor_specification=sensor_specification, ) return specification diff --git a/biorobot/brittle_star/mjcf/morphology/specification/specification.py b/biorobot/brittle_star/mjcf/morphology/specification/specification.py index 241fa99..430d4e1 100644 --- a/biorobot/brittle_star/mjcf/morphology/specification/specification.py +++ b/biorobot/brittle_star/mjcf/morphology/specification/specification.py @@ -7,11 +7,11 @@ class BrittleStarJointSpecification(Specification): def __init__( - self, - range: float, - stiffness: float, - damping: float, - armature: float, + self, + range: float, + stiffness: float, + damping: float, + armature: float, ) -> None: super().__init__() self.stiffness = FixedParameter(value=stiffness) @@ -22,11 +22,11 @@ def __init__( class BrittleStarArmSegmentSpecification(Specification): def __init__( - self, - radius: float, - length: float, - in_plane_joint_specification: BrittleStarJointSpecification, - out_of_plane_joint_specification: BrittleStarJointSpecification, + self, + radius: float, + length: float, + in_plane_joint_specification: BrittleStarJointSpecification, + out_of_plane_joint_specification: BrittleStarJointSpecification, ) -> None: super().__init__() self.radius = FixedParameter(radius) @@ -37,7 +37,7 @@ def __init__( class BrittleStarArmSpecification(Specification): def __init__( - self, segment_specifications: List[BrittleStarArmSegmentSpecification] + self, segment_specifications: List[BrittleStarArmSegmentSpecification] ) -> None: super().__init__() self.segment_specifications = segment_specifications @@ -49,9 +49,9 @@ def number_of_segments(self) -> int: class BrittleStarDiskSpecification(Specification): def __init__( - self, - diameter: float, - height: float, + self, + diameter: float, + height: float, ) -> None: super().__init__() self.radius = FixedParameter(diameter / 2) @@ -60,18 +60,18 @@ def __init__( class BrittleStarActuationSpecification(Specification): def __init__( - self, - use_tendons: bool, - use_p_control: bool, - use_torque_control: bool, - radius_to_strength_factor: float, + self, + use_tendons: bool, + use_p_control: bool, + use_torque_control: bool, + radius_to_strength_factor: float, ) -> None: super().__init__() assert ( - use_p_control + use_torque_control == 1 + use_p_control + use_torque_control == 1 ), "Only one actuation method can be used." assert ( - not use_tendons or use_torque_control + not use_tendons or use_torque_control ), "Only torque control is supported with tendons." self.use_tendons = FixedParameter(use_tendons) self.use_p_control = FixedParameter(use_p_control) @@ -82,16 +82,18 @@ def __init__( class BrittleStarSensorSpecification(Specification): def __init__(self, num_contact_sensors_per_segment: int) -> int: super().__init__() - self.num_contact_sensors_per_segment = FixedParameter(num_contact_sensors_per_segment) + self.num_contact_sensors_per_segment = FixedParameter( + num_contact_sensors_per_segment + ) class BrittleStarMorphologySpecification(MorphologySpecification): def __init__( - self, - disk_specification: BrittleStarDiskSpecification, - arm_specifications: List[BrittleStarArmSpecification], - actuation_specification: BrittleStarActuationSpecification, - sensor_specification: BrittleStarSensorSpecification + self, + disk_specification: BrittleStarDiskSpecification, + arm_specifications: List[BrittleStarArmSpecification], + actuation_specification: BrittleStarActuationSpecification, + sensor_specification: BrittleStarSensorSpecification, ) -> None: super(BrittleStarMorphologySpecification, self).__init__() self.disk_specification = disk_specification diff --git a/biorobot/brittle_star/usage_examples/directed_locomotion_single.py b/biorobot/brittle_star/usage_examples/directed_locomotion_single.py index 7900aad..920ca1d 100644 --- a/biorobot/brittle_star/usage_examples/directed_locomotion_single.py +++ b/biorobot/brittle_star/usage_examples/directed_locomotion_single.py @@ -54,7 +54,7 @@ def create_env( use_p_control=False, use_torque_control=True, use_tendons=False, - num_contact_sensors_per_segment=1 + num_contact_sensors_per_segment=1, ) morphology = MJCFBrittleStarMorphology(morphology_spec) arena_config = AquariumArenaConfiguration(attach_target=True)