An implementation of Short Horizon Actor Critic written in JAX. The core algorithm is written in the style of Brax, with several pieces taken from Xu's original paper.

Andrew-Luo1/jax_shac

jax-shac

  • An implementation of Short Horizon Actor Critic (Xu et al., 2022) written in JAX
  • Simulation using the MuJoCo MJX simulator
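To make the idea concrete, here is a minimal sketch of a short-horizon actor loss: rewards are accumulated over a few differentiable simulation steps, a critic estimate is added at the end of the window, and the gradient flows back through the dynamics. Everything in it (`toy_step`, `reward`, the linear `policy`, the quadratic `critic`, `HORIZON`, `GAMMA`) is a hypothetical stand-in chosen for illustration; it is not this repository's code, it does not use MJX, and it omits the normalization and critic training of the full algorithm.

```python
import jax
import jax.numpy as jnp

GAMMA = 0.99   # discount factor (assumed value)
HORIZON = 8    # short rollout length (assumed value)
DT = 0.05      # time step of the toy dynamics


def toy_step(state, action):
    # Hypothetical differentiable dynamics: point mass, state = [pos, vel].
    vel = state[1] + DT * action
    pos = state[0] + DT * vel
    return jnp.array([pos, vel])


def reward(state, action):
    # Penalize distance from the origin and control effort.
    return -(state[0] ** 2) - 1e-3 * action ** 2


def policy(params, state):
    # Linear policy producing a scalar action.
    return jnp.dot(params, state)


def critic(value_params, state):
    # Toy quadratic value estimate standing in for a learned critic.
    return -jnp.dot(value_params, state ** 2)


def actor_loss(params, value_params, state0):
    def body(carry, _):
        state, ret, discount = carry
        action = policy(params, state)
        ret = ret + discount * reward(state, action)
        return (toy_step(state, action), ret, discount * GAMMA), None

    init = (state0, jnp.zeros(()), jnp.ones(()))
    (state_h, ret, discount), _ = jax.lax.scan(body, init, None, length=HORIZON)
    # Bootstrap with the critic at the end of the short horizon.
    return -(ret + discount * critic(value_params, state_h))


params = jnp.array([-1.0, -0.5])
value_params = jnp.array([1.0, 0.1])
state0 = jnp.array([1.0, 0.0])

# Gradient of the loss w.r.t. the policy parameters, through the dynamics.
loss, grads = jax.value_and_grad(actor_loss)(params, value_params, state0)
```

Keeping the differentiated window short is what limits how far gradients must propagate through the simulator; the critic term accounts for everything beyond the window.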

Results

Inverted Pendulum

[Animation: inverted pendulum rollout, followed by an untitled image]

Run Time: 1 min JIT compilation, 2 min training

Known Issues: For some random seeds, the cart position drifts.

1 DOF Hopper

[Animation: framed hopper rollout]

[Figure: rewards]

Run Time: 1 min JIT compilation, 2 min training

Known Issues: As seen in the rewards figure, training can be unstable.

Warning: MJX + Exploding Gradients

  • Applying SHAC to get Anymal to walk has proven very difficult with the default 32-bit precision. (See MuJoCo for an example with 64-bit precision.)
  • Hypothesis: quadruped gait is very contact-rich, which leads to uninformative gradients.
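For reference, JAX itself defaults to 32-bit floats, and 64-bit mode is switched on through a global configuration flag. The snippet below is a minimal sketch of that generic JAX switch; it is an assumption that this is relevant here, and it is not necessarily how this repository or the MuJoCo example configures precision.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit mode; do this at startup, before creating any arrays.
jax.config.update("jax_enable_x64", True)

x = jnp.ones(3)
print(x.dtype)  # dtype of newly created float arrays
```

The same flag can also be set through the `JAX_ENABLE_X64` environment variable. Expect higher memory use and slower execution on accelerators in 64-bit mode.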

[Animation: Anymal rollout]

32-step rollout. The ground flashes red when the step Jacobian is greater than 10e2.
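A check of this kind can be sketched with `jax.jacfwd` on a single simulation step. The example below uses a hypothetical `toy_step` (a stiff spring standing in for contact) instead of an MJX step, and it assumes "greater than" means the largest absolute Jacobian entry; the source does not say which measure is actually used.

```python
import jax
import jax.numpy as jnp

THRESHOLD = 10e2  # value taken from the caption; equals 1000.0 in Python
DT = 0.01         # time step of the toy dynamics (assumed)


def toy_step(state, action, stiffness):
    # Hypothetical step: a spring of the given stiffness acting on
    # state = [pos, vel]. High stiffness mimics a stiff contact.
    new_vel = state[1] + DT * (action - stiffness * state[0])
    new_pos = state[0] + DT * new_vel
    return jnp.array([new_pos, new_vel])


def max_step_jacobian(state, action, stiffness):
    # Jacobian of the next state with respect to the current state.
    jac = jax.jacfwd(toy_step, argnums=0)(state, action, stiffness)
    return jnp.max(jnp.abs(jac))


def is_exploding(state, action, stiffness):
    return bool(max_step_jacobian(state, action, stiffness) > THRESHOLD)


state = jnp.array([0.1, 0.0])
action = jnp.array(0.0)

soft = is_exploding(state, action, 1e2)   # compliant spring
stiff = is_exploding(state, action, 1e6)  # stiff, contact-like spring
```

Because per-step Jacobians multiply along a rollout, a few steps with large entries can dominate the gradient of a 32-step window.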

Setup

  • pip install -r requirements.txt
  • Add the parent folder of this repository to your PYTHONPATH environment variable.
