Skip to content

Commit

Permalink
Merge branch 'main' of github.com:Co-Evolve/brt
Browse files Browse the repository at this point in the history
  • Loading branch information
driesmarzougui committed Dec 3, 2024
2 parents 3dd11a0 + f724cc2 commit 3c96ce9
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 11 deletions.
5 changes: 4 additions & 1 deletion biorobot/brittle_star/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@ brittle star environment returns as observations (further discussed below).
- Joint positions (two per segment, in-plane and out-of-plane, in radians)
- Joint velocities (two per segment, in-plane and out-of-plane, in radians / second)
- Joint actuator force (i.e. the total actuator force acting on a joint, in Newton meters) (two per segment)
- Actuator force (the scalar actuator force, in Newtons) (four per segment in case of tendon transmission, otherwise two)
- Actuator force (the scalar actuator force, in Newtons) (four per segment in case of tendon transmission, otherwise
two)
- Tendon position (in case tendon transmission is used, four per segment, in meters)
- Tendon velocity (in case tendon transmission is used, four per segment, in meters / second)
- Central disk's position (w.r.t. world frame)
Expand Down Expand Up @@ -222,6 +223,8 @@ difficulty:
be positive if the measured light income has decreased in the current timestep, and negative if the light income
has increased. The light income at a given timestep is calculated as a weighted average over all body geoms (
weight scales with the surface area of the geom).
- The light escape environment configuration accepts an additional argument `random_initial_rotation`. This sets a
random z-axis rotation of the brittle star upon environment resets.
- Requires an aquarium with `sand_ground_color=True`.
- Additional observations:
- The amount of light each segment takes in.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ def reset(
rng, target_pos_rng, qpos_rng, qvel_rng = jax.random.split(key=rng, num=4)

target_body_id = mj_model.body("target").id
disk_body_id = mj_model.body("BrittleStarMorphology/central_disk").id

# Set random target position
target_pos = self._get_target_position(
Expand Down
23 changes: 18 additions & 5 deletions biorobot/brittle_star/environment/light_escape/mjc_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from gymnasium.core import RenderFrame
from moojoco.environment.mjc_env import MJCEnv, MJCEnvState, MJCObservable
from moojoco.environment.renderer import MujocoRenderer
from transforms3d.euler import euler2quat

from biorobot.brittle_star.environment.light_escape.shared import (
BrittleStarLightEscapeEnvironmentBase,
Expand Down Expand Up @@ -263,11 +264,23 @@ def reset(self, rng: np.random.RandomState, *args, **kwargs) -> MJCEnvState:
mj_model, mj_data = self._prepare_reset()

# Set morphology position
mj_model.body("BrittleStarMorphology/central_disk").pos = [
self._get_x_start_position(mj_model=mj_model),
0.0,
0.11,
]
morphology_qpos_adr = mj_model.joint(
"BrittleStarMorphology/freejoint/"
).qposadr[0]
morphology_pos = np.array(
[
self._get_x_start_position(mj_model=mj_model),
0.0,
0.11,
]
)

mj_data.qpos[morphology_qpos_adr : morphology_qpos_adr + 3] = morphology_pos

if self.environment_configuration.random_initial_rotation:
z_axis_rotation = rng.uniform(-np.pi, np.pi)
quat = euler2quat(0, 0, z_axis_rotation, axes="sxyz")
mj_data.qpos[morphology_qpos_adr + 3 : morphology_qpos_adr + 7] = quat

# Add noise to initial qpos and qvel of segment joints
joint_qpos_adrs = self._get_segment_joints_qpos_adrs(mj_model=mj_model)
Expand Down
17 changes: 14 additions & 3 deletions biorobot/brittle_star/environment/light_escape/mjx_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
from gymnasium.core import RenderFrame
from jax import numpy as jnp
from jax.scipy.spatial.transform import Rotation
from moojoco.environment.mjx_env import MJXEnv, MJXEnvState, MJXObservable
from moojoco.environment.renderer import MujocoRenderer
from mujoco import mjx
Expand Down Expand Up @@ -267,9 +268,7 @@ def reset(self, rng: chex.PRNGKey, *args, **kwargs) -> MJXEnvState:
).qposadr[0]
morphology_pos = jnp.array(
[
BrittleStarLightEscapeEnvironmentBase._get_x_start_position(
mj_model=mj_model
),
self._get_x_start_position(mj_model=mj_model),
0.0,
0.11,
]
Expand All @@ -278,6 +277,18 @@ def reset(self, rng: chex.PRNGKey, *args, **kwargs) -> MJXEnvState:
morphology_pos
)

if self.environment_configuration.random_initial_rotation:
rng, rotation_rng = jax.random.split(key=rng, num=2)
z_axis_rotation = jax.random.uniform(
key=rotation_rng, shape=(), minval=-jnp.pi, maxval=jnp.pi
)
quat = Rotation.from_euler(
seq="xyz", angles=jnp.array([0, 0, z_axis_rotation]), degrees=False
).as_quat()
qpos = qpos.at[morphology_qpos_adr + 3 : morphology_qpos_adr + 7].set(
jnp.roll(quat, shift=1)
)

# Add noise to initial qpos and qvel of segment joints
joint_qpos_adrs = self._get_segment_joints_qpos_adrs(mj_model=mj_model)
joint_qvel_adrs = self._get_segment_joints_qvel_adrs(mj_model=mj_model)
Expand Down
2 changes: 2 additions & 0 deletions biorobot/brittle_star/environment/light_escape/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(
self,
light_perlin_noise_scale: int = 0,
joint_randomization_noise_scale: float = 0.0,
random_initial_rotation: bool = False,
color_contacts: bool = False,
*args,
**kwargs,
Expand All @@ -36,6 +37,7 @@ def __init__(
**kwargs,
)
self.light_perlin_noise_scale = int(light_perlin_noise_scale)
self.random_initial_rotation = random_initial_rotation


class BrittleStarLightEscapeEnvironmentBase(BrittleStarEnvironmentBase):
Expand Down
4 changes: 3 additions & 1 deletion biorobot/brittle_star/usage_examples/light_escape_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def create_env(backend: str, render_mode: str) -> BrittleStarLightEscapeEnvironm
time_scale=1,
camera_ids=[0, 1],
color_contacts=True,
random_initial_rotation=True,
)
env = BrittleStarLightEscapeEnvironment.from_morphology_and_arena(
morphology=morphology, arena=arena, configuration=env_config, backend=backend
Expand Down Expand Up @@ -71,10 +72,11 @@ def action_sample_fn(rng: chex.PRNGKey) -> Tuple[jnp.ndarray, chex.PRNGKey]:
state = reset_fn(env_rng)
while True:
action, action_rng = action_sample_fn(action_rng)
state = step_fn(state=state, action=action)
state = step_fn(state=state, action=action * 0)
print(state.observations["joint_position"])
print(state.observations["joint_velocity"])
print(state.observations["joint_actuator_force"])
print()
env.render(state=state)

env.close()

0 comments on commit 3c96ce9

Please sign in to comment.