Skip to content

Commit

Permalink
black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
driesmarzougui committed Nov 13, 2024
1 parent 97e459e commit 9598ab1
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 107 deletions.
54 changes: 27 additions & 27 deletions biorobot/brittle_star/environment/shared/observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ def jquat2euler(quat):


def get_num_contacts_and_segment_contacts_fn(
mj_model: mujoco.MjModel, backend: str
mj_model: mujoco.MjModel, backend: str
) -> Tuple[int, Callable[[BaseEnvState], chex.Array]]:
if backend == "mjx":
segment_capsule_geom_ids = np.array(
[
geom_id
for geom_id in range(mj_model.ngeom)
if "segment" in mj_model.geom(geom_id).name
and "capsule" in mj_model.geom(geom_id).name
and "capsule" in mj_model.geom(geom_id).name
]
)

Expand All @@ -44,8 +44,8 @@ def get_segment_contacts(state: MJXEnvState) -> jnp.ndarray:

def solve_contact(geom_id: int) -> jnp.ndarray:
return (
jnp.sum(contacts * jnp.any(geom_id == contact_data.geom, axis=-1))
> 0
jnp.sum(contacts * jnp.any(geom_id == contact_data.geom, axis=-1))
> 0
).astype(int)

return jax.vmap(solve_contact)(segment_capsule_geom_ids)
Expand Down Expand Up @@ -83,7 +83,7 @@ def get_segment_contacts(state: MJCEnvState) -> np.ndarray:


def get_base_brittle_star_observables(
mj_model: mujoco.MjModel, backend: str
mj_model: mujoco.MjModel, backend: str
) -> List[BaseObservable]:
if backend == "mjx":
observable_class = MJXObservable
Expand Down Expand Up @@ -111,7 +111,7 @@ def get_base_brittle_star_observables(
retriever=lambda state: bnp.array(
[
get_data(state).sensordata[
sensor.adr[0]: sensor.adr[0] + sensor.dim[0]
sensor.adr[0] : sensor.adr[0] + sensor.dim[0]
]
for sensor in joint_pos_sensors
]
Expand All @@ -131,7 +131,7 @@ def get_base_brittle_star_observables(
retriever=lambda state: bnp.array(
[
get_data(state).sensordata[
sensor.adr[0]: sensor.adr[0] + sensor.dim[0]
sensor.adr[0] : sensor.adr[0] + sensor.dim[0]
]
for sensor in joint_vel_sensors
]
Expand All @@ -151,7 +151,7 @@ def get_base_brittle_star_observables(
retriever=lambda state: bnp.array(
[
get_data(state).sensordata[
sensor.adr[0]: sensor.adr[0] + sensor.dim[0]
sensor.adr[0] : sensor.adr[0] + sensor.dim[0]
]
for sensor in joint_actuator_frc_sensors
]
Expand All @@ -171,7 +171,7 @@ def get_base_brittle_star_observables(
retriever=lambda state: bnp.array(
[
get_data(state).sensordata[
sensor.adr[0]: sensor.adr[0] + sensor.dim[0]
sensor.adr[0] : sensor.adr[0] + sensor.dim[0]
]
for sensor in actuator_frc_sensors
]
Expand All @@ -183,32 +183,32 @@ def get_base_brittle_star_observables(
mj_model.sensor(i)
for i in range(mj_model.nsensor)
if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMEPOS
and "disk" in mj_model.sensor(i).name
and "disk" in mj_model.sensor(i).name
][0]
disk_position_observable = observable_class(
name="disk_position",
low=-bnp.inf * bnp.ones(3),
high=bnp.inf * bnp.ones(3),
retriever=lambda state: get_data(state).sensordata[
disk_framepos_sensor.adr[0]: disk_framepos_sensor.adr[0]
+ disk_framepos_sensor.dim[0]
],
disk_framepos_sensor.adr[0] : disk_framepos_sensor.adr[0]
+ disk_framepos_sensor.dim[0]
],
)
# disk rotation
disk_framequat_sensor = [
mj_model.sensor(i)
for i in range(mj_model.nsensor)
if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMEQUAT
and "disk" in mj_model.sensor(i).name
and "disk" in mj_model.sensor(i).name
][0]
disk_rotation_observable = observable_class(
name="disk_rotation",
low=-bnp.pi * bnp.ones(3),
high=bnp.pi * bnp.ones(3),
retriever=lambda state: get_quat2euler_fn(backend=backend)(
get_data(state).sensordata[
disk_framequat_sensor.adr[0]: disk_framequat_sensor.adr[0]
+ disk_framequat_sensor.dim[0]
disk_framequat_sensor.adr[0] : disk_framequat_sensor.adr[0]
+ disk_framequat_sensor.dim[0]
]
),
)
Expand All @@ -218,33 +218,33 @@ def get_base_brittle_star_observables(
mj_model.sensor(i)
for i in range(mj_model.nsensor)
if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMELINVEL
and "disk" in mj_model.sensor(i).name
and "disk" in mj_model.sensor(i).name
][0]
disk_linvel_observable = observable_class(
name="disk_linear_velocity",
low=-bnp.inf * bnp.ones(3),
high=bnp.inf * bnp.ones(3),
retriever=lambda state: get_data(state).sensordata[
disk_framelinvel_sensor.adr[0]: disk_framelinvel_sensor.adr[0]
+ disk_framelinvel_sensor.dim[0]
],
disk_framelinvel_sensor.adr[0] : disk_framelinvel_sensor.adr[0]
+ disk_framelinvel_sensor.dim[0]
],
)

# disk angvel
disk_frameangvel_sensor = [
mj_model.sensor(i)
for i in range(mj_model.nsensor)
if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMEANGVEL
and "disk" in mj_model.sensor(i).name
and "disk" in mj_model.sensor(i).name
][0]
disk_angvel_observable = observable_class(
name="disk_angular_velocity",
low=-bnp.inf * bnp.ones(3),
high=bnp.inf * bnp.ones(3),
retriever=lambda state: get_data(state).sensordata[
disk_frameangvel_sensor.adr[0]: disk_frameangvel_sensor.adr[0]
+ disk_frameangvel_sensor.dim[0]
],
disk_frameangvel_sensor.adr[0] : disk_frameangvel_sensor.adr[0]
+ disk_frameangvel_sensor.dim[0]
],
)

# tendons
Expand All @@ -260,7 +260,7 @@ def get_base_brittle_star_observables(
retriever=lambda state: bnp.array(
[
get_data(state).sensordata[
sensor.adr[0]: sensor.adr[0] + sensor.dim[0]
sensor.adr[0] : sensor.adr[0] + sensor.dim[0]
]
for sensor in tendon_pos_sensors
]
Expand All @@ -279,7 +279,7 @@ def get_base_brittle_star_observables(
retriever=lambda state: bnp.array(
[
get_data(state).sensordata[
sensor.adr[0]: sensor.adr[0] + sensor.dim[0]
sensor.adr[0] : sensor.adr[0] + sensor.dim[0]
]
for sensor in tendon_vel_sensors
]
Expand Down Expand Up @@ -308,5 +308,5 @@ def get_base_brittle_star_observables(
disk_angvel_observable,
tendon_pos_observable,
tendon_vel_observable,
segment_contact_observable
segment_contact_observable,
]
8 changes: 6 additions & 2 deletions biorobot/brittle_star/mjcf/morphology/morphology.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@ def _configure_camera(self) -> None:

if __name__ == "__main__":
spec = default_brittle_star_morphology_specification(
num_arms=5, num_segments_per_arm=5, use_p_control=False, use_torque_control=True, use_tendons=True,
radius_to_strength_factor=200
num_arms=5,
num_segments_per_arm=5,
use_p_control=False,
use_torque_control=True,
use_tendons=True,
radius_to_strength_factor=200,
)
MJCFBrittleStarMorphology(spec).export_to_xml_with_assets("./mjcf")
Loading

0 comments on commit 9598ab1

Please sign in to comment.