diff --git a/biorobot/brittle_star/README.md b/biorobot/brittle_star/README.md index 7b2dd6d..10eaffa 100644 --- a/biorobot/brittle_star/README.md +++ b/biorobot/brittle_star/README.md @@ -177,7 +177,8 @@ brittle star environment returns as observations (further discussed below). - Joint positions (two per segment, in-plane and out-of-plane, in radians) - Joint velocities (two per segment, in-plane and out-of-plane, in radians / second) - Joint actuator force (i.e. the total actuator force acting on a joint, in Newton meters) (two per segment) - - Actuator force (the scalar actuator force, in Newtons) (four per segment in case of tendon transmission, otherwise two) + - Actuator force (the scalar actuator force, in Newtons) (four per segment in case of tendon transmission, otherwise + two) - Tendon position (in case tendon transmission is used, four per segment, in meters) - Tendon velocity (in case tendon transmission is used, four per segment, in meters / second) - Central disk's position (w.r.t. world frame) @@ -222,6 +223,8 @@ difficulty: be positive if the measured light income has decreased in the current timestep, and negative if the light income has increased. The light income at a given timestep is calculated as a weighted average over all body geoms ( weight scales with the surface area of the geom). + - The light escape environment configuration accepts an additional boolean argument `random_initial_rotation` + (default `False`). If `True`, every reset applies a uniformly random z-axis rotation to the brittle star. - Requires an aquarium with `sand_ground_color=True`. - Additional observations: - The amount of light each segment takes in. 
diff --git a/biorobot/brittle_star/environment/directed_locomotion/mjx_env.py b/biorobot/brittle_star/environment/directed_locomotion/mjx_env.py index ab56e94..5b8747d 100644 --- a/biorobot/brittle_star/environment/directed_locomotion/mjx_env.py +++ b/biorobot/brittle_star/environment/directed_locomotion/mjx_env.py @@ -157,7 +157,6 @@ def reset( rng, target_pos_rng, qpos_rng, qvel_rng = jax.random.split(key=rng, num=4) target_body_id = mj_model.body("target").id - disk_body_id = mj_model.body("BrittleStarMorphology/central_disk").id # Set random target position target_pos = self._get_target_position( diff --git a/biorobot/brittle_star/environment/light_escape/mjc_env.py b/biorobot/brittle_star/environment/light_escape/mjc_env.py index 99fba01..c40dbdd 100644 --- a/biorobot/brittle_star/environment/light_escape/mjc_env.py +++ b/biorobot/brittle_star/environment/light_escape/mjc_env.py @@ -7,6 +7,7 @@ from gymnasium.core import RenderFrame from moojoco.environment.mjc_env import MJCEnv, MJCEnvState, MJCObservable from moojoco.environment.renderer import MujocoRenderer +from transforms3d.euler import euler2quat from biorobot.brittle_star.environment.light_escape.shared import ( BrittleStarLightEscapeEnvironmentBase, @@ -263,11 +264,23 @@ def reset(self, rng: np.random.RandomState, *args, **kwargs) -> MJCEnvState: mj_model, mj_data = self._prepare_reset() # Set morphology position - mj_model.body("BrittleStarMorphology/central_disk").pos = [ - self._get_x_start_position(mj_model=mj_model), - 0.0, - 0.11, - ] + morphology_qpos_adr = mj_model.joint( + "BrittleStarMorphology/freejoint/" + ).qposadr[0] + morphology_pos = np.array( + [ + self._get_x_start_position(mj_model=mj_model), + 0.0, + 0.11, + ] + ) + + mj_data.qpos[morphology_qpos_adr : morphology_qpos_adr + 3] = morphology_pos + + if self.environment_configuration.random_initial_rotation: + z_axis_rotation = rng.uniform(-np.pi, np.pi) + quat = euler2quat(0, 0, z_axis_rotation, axes="sxyz") + 
mj_data.qpos[morphology_qpos_adr + 3 : morphology_qpos_adr + 7] = quat # Add noise to initial qpos and qvel of segment joints joint_qpos_adrs = self._get_segment_joints_qpos_adrs(mj_model=mj_model) diff --git a/biorobot/brittle_star/environment/light_escape/mjx_env.py b/biorobot/brittle_star/environment/light_escape/mjx_env.py index 5338804..489f933 100644 --- a/biorobot/brittle_star/environment/light_escape/mjx_env.py +++ b/biorobot/brittle_star/environment/light_escape/mjx_env.py @@ -8,6 +8,7 @@ import numpy as np from gymnasium.core import RenderFrame from jax import numpy as jnp +from jax.scipy.spatial.transform import Rotation from moojoco.environment.mjx_env import MJXEnv, MJXEnvState, MJXObservable from moojoco.environment.renderer import MujocoRenderer from mujoco import mjx @@ -267,9 +268,7 @@ def reset(self, rng: chex.PRNGKey, *args, **kwargs) -> MJXEnvState: ).qposadr[0] morphology_pos = jnp.array( [ - BrittleStarLightEscapeEnvironmentBase._get_x_start_position( - mj_model=mj_model - ), + self._get_x_start_position(mj_model=mj_model), 0.0, 0.11, ] @@ -278,6 +277,18 @@ def reset(self, rng: chex.PRNGKey, *args, **kwargs) -> MJXEnvState: morphology_pos ) + if self.environment_configuration.random_initial_rotation: + rng, rotation_rng = jax.random.split(key=rng, num=2) + z_axis_rotation = jax.random.uniform( + key=rotation_rng, shape=(), minval=-jnp.pi, maxval=jnp.pi + ) + quat = Rotation.from_euler( + seq="xyz", angles=jnp.array([0, 0, z_axis_rotation]), degrees=False + ).as_quat() + qpos = qpos.at[morphology_qpos_adr + 3 : morphology_qpos_adr + 7].set( + jnp.roll(quat, shift=1) + ) + # Add noise to initial qpos and qvel of segment joints joint_qpos_adrs = self._get_segment_joints_qpos_adrs(mj_model=mj_model) joint_qvel_adrs = self._get_segment_joints_qvel_adrs(mj_model=mj_model) diff --git a/biorobot/brittle_star/environment/light_escape/shared.py b/biorobot/brittle_star/environment/light_escape/shared.py index 5a15cc8..d1353c0 100644 --- 
a/biorobot/brittle_star/environment/light_escape/shared.py +++ b/biorobot/brittle_star/environment/light_escape/shared.py @@ -21,6 +21,7 @@ def __init__( self, light_perlin_noise_scale: int = 0, joint_randomization_noise_scale: float = 0.0, + random_initial_rotation: bool = False, color_contacts: bool = False, *args, **kwargs, @@ -36,6 +37,7 @@ def __init__( **kwargs, ) self.light_perlin_noise_scale = int(light_perlin_noise_scale) + self.random_initial_rotation = random_initial_rotation class BrittleStarLightEscapeEnvironmentBase(BrittleStarEnvironmentBase): diff --git a/biorobot/brittle_star/usage_examples/light_escape_single.py b/biorobot/brittle_star/usage_examples/light_escape_single.py index 6cf62b2..4f50fde 100644 --- a/biorobot/brittle_star/usage_examples/light_escape_single.py +++ b/biorobot/brittle_star/usage_examples/light_escape_single.py @@ -37,6 +37,7 @@ def create_env(backend: str, render_mode: str) -> BrittleStarLightEscapeEnvironm time_scale=1, camera_ids=[0, 1], color_contacts=True, + random_initial_rotation=True, ) env = BrittleStarLightEscapeEnvironment.from_morphology_and_arena( morphology=morphology, arena=arena, configuration=env_config, backend=backend @@ -71,10 +72,11 @@ def action_sample_fn(rng: chex.PRNGKey) -> Tuple[jnp.ndarray, chex.PRNGKey]: state = reset_fn(env_rng) while True: action, action_rng = action_sample_fn(action_rng) state = step_fn(state=state, action=action) print(state.observations["joint_position"]) print(state.observations["joint_velocity"]) print(state.observations["joint_actuator_force"]) print() env.render(state=state) + env.close()