Skip to content

Commit

Permalink
[brittle-star] Added tendon observables
Browse files Browse the repository at this point in the history
  • Loading branch information
driesmarzougui committed Nov 13, 2024
1 parent 1b7c930 commit cfc95ce
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 27 deletions.
92 changes: 67 additions & 25 deletions biorobot/brittle_star/environment/shared/observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ def jquat2euler(quat):


def get_num_contacts_and_segment_contacts_fn(
mj_model: mujoco.MjModel, backend: str
mj_model: mujoco.MjModel, backend: str
) -> Tuple[int, Callable[[BaseEnvState], chex.Array]]:
if backend == "mjx":
segment_capsule_geom_ids = np.array(
[
geom_id
for geom_id in range(mj_model.ngeom)
if "segment" in mj_model.geom(geom_id).name
and "capsule" in mj_model.geom(geom_id).name
and "capsule" in mj_model.geom(geom_id).name
]
)

Expand All @@ -44,8 +44,8 @@ def get_segment_contacts(state: MJXEnvState) -> jnp.ndarray:

def solve_contact(geom_id: int) -> jnp.ndarray:
return (
jnp.sum(contacts * jnp.any(geom_id == contact_data.geom, axis=-1))
> 0
jnp.sum(contacts * jnp.any(geom_id == contact_data.geom, axis=-1))
> 0
).astype(int)

return jax.vmap(solve_contact)(segment_capsule_geom_ids)
Expand Down Expand Up @@ -83,7 +83,7 @@ def get_segment_contacts(state: MJCEnvState) -> np.ndarray:


def get_base_brittle_star_observables(
mj_model: mujoco.MjModel, backend: str
mj_model: mujoco.MjModel, backend: str
) -> List[BaseObservable]:
if backend == "mjx":
observable_class = MJXObservable
Expand Down Expand Up @@ -111,7 +111,7 @@ def get_base_brittle_star_observables(
retriever=lambda state: bnp.array(
[
get_data(state).sensordata[
sensor.adr[0] : sensor.adr[0] + sensor.dim[0]
sensor.adr[0]: sensor.adr[0] + sensor.dim[0]
]
for sensor in joint_pos_sensors
]
Expand All @@ -131,7 +131,7 @@ def get_base_brittle_star_observables(
retriever=lambda state: bnp.array(
[
get_data(state).sensordata[
sensor.adr[0] : sensor.adr[0] + sensor.dim[0]
sensor.adr[0]: sensor.adr[0] + sensor.dim[0]
]
for sensor in joint_vel_sensors
]
Expand All @@ -151,7 +151,7 @@ def get_base_brittle_star_observables(
retriever=lambda state: bnp.array(
[
get_data(state).sensordata[
sensor.adr[0] : sensor.adr[0] + sensor.dim[0]
sensor.adr[0]: sensor.adr[0] + sensor.dim[0]
]
for sensor in joint_actuator_frc_sensors
]
Expand All @@ -171,7 +171,7 @@ def get_base_brittle_star_observables(
retriever=lambda state: bnp.array(
[
get_data(state).sensordata[
sensor.adr[0] : sensor.adr[0] + sensor.dim[0]
sensor.adr[0]: sensor.adr[0] + sensor.dim[0]
]
for sensor in actuator_frc_sensors
]
Expand All @@ -183,32 +183,32 @@ def get_base_brittle_star_observables(
mj_model.sensor(i)
for i in range(mj_model.nsensor)
if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMEPOS
and "disk" in mj_model.sensor(i).name
and "disk" in mj_model.sensor(i).name
][0]
disk_position_observable = observable_class(
name="disk_position",
low=-bnp.inf * bnp.ones(3),
high=bnp.inf * bnp.ones(3),
retriever=lambda state: get_data(state).sensordata[
disk_framepos_sensor.adr[0] : disk_framepos_sensor.adr[0]
+ disk_framepos_sensor.dim[0]
],
disk_framepos_sensor.adr[0]: disk_framepos_sensor.adr[0]
+ disk_framepos_sensor.dim[0]
],
)
# disk rotation
disk_framequat_sensor = [
mj_model.sensor(i)
for i in range(mj_model.nsensor)
if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMEQUAT
and "disk" in mj_model.sensor(i).name
and "disk" in mj_model.sensor(i).name
][0]
disk_rotation_observable = observable_class(
name="disk_rotation",
low=-bnp.pi * bnp.ones(3),
high=bnp.pi * bnp.ones(3),
retriever=lambda state: get_quat2euler_fn(backend=backend)(
get_data(state).sensordata[
disk_framequat_sensor.adr[0] : disk_framequat_sensor.adr[0]
+ disk_framequat_sensor.dim[0]
disk_framequat_sensor.adr[0]: disk_framequat_sensor.adr[0]
+ disk_framequat_sensor.dim[0]
]
),
)
Expand All @@ -218,35 +218,75 @@ def get_base_brittle_star_observables(
mj_model.sensor(i)
for i in range(mj_model.nsensor)
if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMELINVEL
and "disk" in mj_model.sensor(i).name
and "disk" in mj_model.sensor(i).name
][0]
disk_linvel_observable = observable_class(
name="disk_linear_velocity",
low=-bnp.inf * bnp.ones(3),
high=bnp.inf * bnp.ones(3),
retriever=lambda state: get_data(state).sensordata[
disk_framelinvel_sensor.adr[0] : disk_framelinvel_sensor.adr[0]
+ disk_framelinvel_sensor.dim[0]
],
disk_framelinvel_sensor.adr[0]: disk_framelinvel_sensor.adr[0]
+ disk_framelinvel_sensor.dim[0]
],
)

# disk angvel
disk_frameangvel_sensor = [
mj_model.sensor(i)
for i in range(mj_model.nsensor)
if mj_model.sensor(i).type[0] == mujoco.mjtSensor.mjSENS_FRAMEANGVEL
and "disk" in mj_model.sensor(i).name
and "disk" in mj_model.sensor(i).name
][0]
disk_angvel_observable = observable_class(
name="disk_angular_velocity",
low=-bnp.inf * bnp.ones(3),
high=bnp.inf * bnp.ones(3),
retriever=lambda state: get_data(state).sensordata[
disk_frameangvel_sensor.adr[0] : disk_frameangvel_sensor.adr[0]
+ disk_frameangvel_sensor.dim[0]
],
disk_frameangvel_sensor.adr[0]: disk_frameangvel_sensor.adr[0]
+ disk_frameangvel_sensor.dim[0]
],
)

# tendons
tendon_pos_sensors = [
sensor
for sensor in sensors
if sensor.type[0] == mujoco.mjtSensor.mjSENS_TENDONPOS
]
tendon_pos_observable = observable_class(
name="tendon_position",
low=-bnp.inf * bnp.ones(len(tendon_pos_sensors)),
high=bnp.inf * bnp.ones(len(tendon_pos_sensors)),
retriever=lambda state: bnp.array(
[
get_data(state).sensordata[
sensor.adr[0]: sensor.adr[0] + sensor.dim[0]
]
for sensor in tendon_pos_sensors
]
).flatten(),
)

tendon_vel_sensors = [
sensor
for sensor in sensors
if sensor.type[0] == mujoco.mjtSensor.mjSENS_TENDONVEL
]
tendon_vel_observable = observable_class(
name="tendon_velocity",
low=-bnp.inf * bnp.ones(len(tendon_vel_sensors)),
high=bnp.inf * bnp.ones(len(tendon_vel_sensors)),
retriever=lambda state: bnp.array(
[
get_data(state).sensordata[
sensor.adr[0]: sensor.adr[0] + sensor.dim[0]
]
for sensor in tendon_vel_sensors
]
).flatten(),
)

# contacts
num_contacts, get_segment_contacts_fn = get_num_contacts_and_segment_contacts_fn(
mj_model=mj_model, backend=backend
)
Expand All @@ -266,5 +306,7 @@ def get_base_brittle_star_observables(
disk_rotation_observable,
disk_linvel_observable,
disk_angvel_observable,
segment_contact_observable,
tendon_pos_observable,
tendon_vel_observable,
segment_contact_observable
]
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def create_env(
backend: str, render_mode: str
) -> BrittleStarDirectedLocomotionEnvironment:
morphology_spec = default_brittle_star_morphology_specification(
num_arms=5, num_segments_per_arm=5, use_p_control=False, use_torque_control=True
num_arms=5, num_segments_per_arm=5, use_p_control=False, use_torque_control=True, use_tendons=False
)
morphology = MJCFBrittleStarMorphology(morphology_spec)
arena_config = AquariumArenaConfiguration(attach_target=True)
Expand Down Expand Up @@ -100,7 +100,7 @@ def action_sample_fn(rng: chex.PRNGKey) -> Tuple[jnp.ndarray, chex.PRNGKey]:
action, action_rng = action_sample_fn(action_rng)
state = step_fn(state=state, action=action)
post_render(env.render(state=state), env.environment_configuration)
print(state.observations["actuator_force"])
print(state.observations["tendon_position"])
if state.terminated | state.truncated:
state = reset_fn(env_rng)
env.close()

0 comments on commit cfc95ce

Please sign in to comment.