Commit 8b9c7b2
Add multi-plan optimization.
PiperOrigin-RevId: 610417779
Change-Id: I0ee7274760c4a28b56c8f2f86d36beb1fd45143b
erez-tom authored and copybara-github committed Feb 26, 2024
1 parent cc6c1f3 commit 8b9c7b2
Showing 3 changed files with 190 additions and 106 deletions.
143 changes: 91 additions & 52 deletions python/mujoco_mpc/mjx/predictive_sampling.py
@@ -23,6 +23,7 @@
import mujoco
from mujoco import mjx
from mujoco.mjx._src import dataclasses
+import numpy as np

CostFn = Callable[[mjx.Model, mjx.Data], jax.Array]

@@ -48,6 +49,7 @@ class Planner(dataclasses.PyTreeNode):
  interp: str


+@jax.jit
def _rollout(p: Planner, d: mjx.Data, policy: jax.Array) -> jax.Array:
  """Expand the policy into actions and roll out dynamics and cost."""
  actions = get_actions(p, policy)
@@ -59,7 +61,6 @@ def step(d, action):
    return d, cost

  _, costs = jax.lax.scan(step, d, actions, length=p.horizon)
-
  return jnp.sum(costs)
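For readers unfamiliar with jax.lax.scan: the scan above threads the simulation state d through p.horizon steps and stacks the per-step costs. A rough sketch of the equivalent loop, assuming the step body shown above (not part of the commit):

    # Sketch only: what the scan computes, as an explicit Python loop.
    def rollout_loop(p, d, policy):
      actions = get_actions(p, policy)  # (horizon, nu)
      total = 0.0
      for action in actions:            # scan carries d, accumulates cost
        d, cost = step(d, action)       # cost of current state, then mjx.step
        total += cost
      return total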


@@ -77,103 +78,141 @@ def get_actions(p: Planner, policy: jax.Array) -> jax.Array:
    raise ValueError(f'unimplemented interp: {p.interp}')

  return actions
+v_get_actions = jax.vmap(get_actions, in_axes=[None, 0])
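The vmap maps over the leading axis of policy while broadcasting the planner (in_axes=[None, 0]), so a whole batch of plans is interpolated in one call. Under the batched layout introduced below, the shapes would be roughly:

    # Sketch: policy (nplans, nspline, nu) -> actions (nplans, horizon, nu).
    actions = v_get_actions(p, policy)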


def improve_policy(
    p: Planner, d: mjx.Data, policy: jax.Array, rng: jax.Array
) -> Tuple[jax.Array, jax.Array]:
  """Improves policy."""
-  limit = p.model.actuator_ctrlrange
-
  # create noisy policies, with nominal policy at index 0
-  noise = jax.random.normal(rng, (p.nsample, p.nspline, p.model.nu))
-  noise = noise * p.noise_scale * (limit[:, 1] - limit[:, 0])
+  noise = (
+      jax.random.normal(rng, (p.nsample, p.nspline, p.model.nu)) * p.noise_scale
+  )
  policies = jnp.concatenate((policy[None], policy + noise))
+  # clamp actions to ctrlrange
+  limit = p.model.actuator_ctrlrange
  policies = jnp.clip(policies, limit[:, 0], limit[:, 1])

  # perform nsample + 1 parallel rollouts
  costs = jax.vmap(_rollout, in_axes=(None, None, 0))(p, d, policies)
  costs = jnp.nan_to_num(costs, nan=jnp.inf)
-  best_id = jnp.argmin(costs)
+  winners = jnp.argmin(costs)

-  return policies[best_id], costs[best_id]
+  return policies[winners], winners
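Keeping the nominal policy at index 0 means the argmin can return 0 and leave the policy unchanged, so an improvement step never does worse under the sampled costs; nan_to_num turns diverged rollouts into inf so they can never win. A minimal sketch of one call, where planner and data are assumed objects (not from the commit):

    # Sketch: one predictive-sampling improvement step.
    # costs has shape (nsample + 1,); winner == 0 means "keep nominal policy".
    new_policy, winner = improve_policy(planner, data, policy, jax.random.key(0))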


def resample(p: Planner, policy: jax.Array, steps_per_plan: int) -> jax.Array:
  """Resample policy to new advanced time."""
-  if p.horizon % p.nspline != 0:
-    raise ValueError("horizon must be divisible by nspline")
-  splinesteps = p.horizon // p.nspline
-  if splinesteps % steps_per_plan != 0:
-    raise ValueError(
-        f'splinesteps ({splinesteps}) must be divisible by steps_per_plan'
-        f' ({steps_per_plan})'
-    )
-  roll = splinesteps // steps_per_plan
-  policy = jnp.roll(policy, -roll, axis=0)
-  policy = policy.at[-roll:].set(policy[-roll - 1])
+  if p.interp == 'zero':
+    return policy  # assuming steps_per_plan < splinesteps
+  elif p.interp == 'linear':
+    actions = v_get_actions(p, policy)
+    roll = steps_per_plan
+    actions = jnp.roll(actions, -roll, axis=1)
+    actions = actions.at[:, -roll:, :].set(actions[:, [-1], :])
+    idx = jnp.floor(jnp.linspace(0, p.horizon, p.nspline)).astype(int)
+    return actions[:, idx, :]

  return policy
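In the linear branch, the batched spline is first expanded to per-step actions, shifted left by the steps that have already been executed, and then re-sampled back down to nspline knots. Approximate shapes, assuming the multi-plan layout:

    # policy (nplans, nspline, nu) -> actions (nplans, horizon, nu),
    # rolled left by steps_per_plan along time, tail padded with the last
    # action, then sampled at nspline knot indices -> (nplans, nspline, nu).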


-def set_state(d_out, d_in):
-  return d_out.replace(
-      time=d_in.time, qpos=d_in.qpos, qvel=d_in.qvel, act=d_in.act,
-      ctrl=d_in.ctrl)

+def set_state(d, state):
+  return d.replace(
+      time=state.time, qpos=state.qpos, qvel=state.qvel, act=state.act,
+      ctrl=state.ctrl)
+set_states = jax.vmap(set_state, in_axes=[0, 0])
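The vmapped set_states copies each simulator state into the matching planning Data, which is what lets planner and simulator run different models on the same underlying state:

    # Both arguments are batched over plans (in_axes=[0, 0]):
    # plan_datas[i] receives time/qpos/qvel/act/ctrl from sim_datas[i].
    plan_datas = set_states(plan_datas, sim_datas)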

def receding_horizon_optimization(
    p: Planner,
    plan_model_cpu,
    sim_model_cpu,
    nsteps,
    nplans,
    steps_per_plan,
    frame_skip,
    verbose=False,
):
-  d = mujoco.MjData(plan_model_cpu)
-  d = mjx.put_data(plan_model_cpu, d)
+  """Receding horizon optimization, all nplans start from same keyframe."""
+  plan_data = mujoco.MjData(plan_model_cpu)
+  plan_data = mjx.put_data(plan_model_cpu, plan_data)
  m = mjx.put_model(plan_model_cpu)
  p = p.replace(model=m)
-  jitted_cost = jax.jit(p.cost)

-  policy = jnp.zeros((p.nspline, m.nu))
-  rng = jax.random.key(0)
-  improve_fn = (
-      jax.jit(improve_policy)
-      .lower(p, d, policy, rng)
-      .compile()
-  )
-  step_fn = jax.jit(mjx.step).lower(m, d).compile()

-  trajectory, costs = [], []
-  plan_time = 0
  sim_data = mujoco.MjData(sim_model_cpu)
  mujoco.mj_resetDataKeyframe(sim_model_cpu, sim_data, 0)
  # without kinematics, the first cost is off:
  mujoco.mj_forward(sim_model_cpu, sim_data)
  sim_data = mjx.put_data(sim_model_cpu, sim_data)
  sim_model = mjx.put_model(sim_model_cpu)
-  actions = get_actions(p, policy)

+  policy = jnp.tile(sim_data.ctrl, (nplans, p.nspline, 1))
+  multi_actions = v_get_actions(p, policy)
+  first_actions = multi_actions[:, 0, :]  # just the first actions
+  # duplicate data for each plan
+  def set_action(data, action):
+    return data.replace(ctrl=action)
+
+  duplicate_data = jax.vmap(set_action, in_axes=[None, 0], out_axes=0)
+  sim_datas = duplicate_data(sim_data, first_actions)
+  plan_datas = duplicate_data(plan_data, first_actions)
+
+  def step_and_cost(model, data, action):
+    data = data.replace(ctrl=action)
+    cost = p.cost(model, data)
+    data = mjx.step(model, data)
+    return data, cost
+
+  multi_step = (
+      jax.jit(
+          jax.vmap(step_and_cost, in_axes=[None, 0, 0])
+      )
+      .lower(sim_model, sim_datas, first_actions)
+      .compile()
+  )
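The jax.jit(...).lower(args).compile() chain is JAX's ahead-of-time path: tracing and XLA compilation happen once here, for these concrete shapes, so the timing loop below measures only execution. A generic sketch of the pattern (toy function, not from the commit):

    # Compile once for a given shape/dtype, then call without retracing.
    f = jax.jit(lambda x: x * 2.0)
    f_compiled = f.lower(jnp.ones((4,))).compile()
    y = f_compiled(jnp.arange(4.0))  # arguments must match the lowered shapes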

+  rng = jax.random.key(0)
+  keys = jax.random.split(rng, nplans)
+  improve_fn = (
+      jax.jit(
+          jax.vmap(improve_policy, in_axes=(None, 0, 0, 0))
+      )
+      .lower(p, plan_datas, policy, keys)
+      .compile()
+  )
+  trajectories = np.zeros(
+      (nplans, nsteps // frame_skip, sim_data.qpos.shape[0])
+  )
+  costs = np.zeros((nplans, nsteps))
+  plan_time = 0
+  multi_actions = v_get_actions(p, policy)
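improve_policy is vmapped with in_axes=(None, 0, 0, 0): the planner is shared while data, policy, and rng key are per-plan, so each of the nplans plans draws independent noise. Splitting one key keeps the streams independent but reproducible:

    # Sketch: one key per plan; keys[i] drives the noise of plan i.
    keys = jax.random.split(jax.random.key(0), nplans)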

  for step in range(nsteps):
    if step % steps_per_plan == 0:
-      if verbose:
-        print('re-planning')
-      # resample policy to new advanced time
+      print('re-planning')
      policy = resample(p, policy, steps_per_plan)
      beg = time.perf_counter()
-      d = set_state(d, sim_data)
-      policy, _ = improve_fn(p, d, policy, jax.random.key(step))
-      plan_time += time.perf_counter() - beg
-      actions = get_actions(p, policy)

-    sim_data = sim_data.replace(ctrl=actions[0])
-    cost = jitted_cost(sim_model, sim_data)
-    sim_data = step_fn(sim_model, sim_data)
-    costs.append(cost)
-    print(f'step: {step}')
-    print(f'cost: {cost}')
+      plan_datas = set_states(plan_datas, sim_datas)
+      policy, winners = improve_fn(
+          p, plan_datas, policy, jax.random.split(jax.random.key(step), nplans)
+      )
+      this_plan_time = time.perf_counter() - beg
+      plan_time += this_plan_time
+      if verbose:
+        print(f'winners: {winners}')
+      multi_actions = v_get_actions(p, policy)

+    step_index = step % steps_per_plan
+    sim_datas, cost = multi_step(
+        sim_model, sim_datas, multi_actions[:, step_index, :]
+    )
+    costs[:, step] = jax.device_get(cost)
    if step % frame_skip == 0:
-      trajectory.append(jax.device_get(sim_data.qpos))
+      trajectories[:, step // frame_skip, :] = jax.device_get(sim_datas.qpos)
+    if verbose:
+      print(f'step: {step}')
+      print(f'avg cost: {np.mean(costs[:, step])}')

-  return trajectory, costs, plan_time
+  return trajectories, costs, plan_time
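A minimal driver for the new multi-plan API might look like the following sketch; the parameter values, the handover module name, and the order of the returned models are illustrative assumptions, not part of the commit:

    # Hypothetical usage (values made up; returned-model order assumed).
    m1, m2, cost_fn = handover.get_models_and_cost_fn()
    p = predictive_sampling.Planner(
        model=None,  # replaced inside receding_horizon_optimization
        cost=cost_fn, noise_scale=0.1, horizon=128,
        nspline=4, nsample=127, interp='zero',
    )
    trajectories, costs, plan_time = predictive_sampling.receding_horizon_optimization(
        p, plan_model_cpu=m1, sim_model_cpu=m2,
        nsteps=500, nplans=8, steps_per_plan=4, frame_skip=5,
    )
    # trajectories: (nplans, nsteps // frame_skip, nq); costs: (nplans, nsteps)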
42 changes: 25 additions & 17 deletions python/mujoco_mpc/mjx/tasks/bimanual/handover.py
@@ -13,46 +13,54 @@
# limitations under the License.
# ==============================================================================

+from typing import Callable
+
from etils import epath
# internal import
import jax
-from jax import numpy as jp
+from jax import numpy as jnp
import mujoco
from mujoco import mjx
from mujoco_mpc.mjx import predictive_sampling


def bring_to_target(m: mjx.Model, d: mjx.Data) -> jax.Array:
  """Returns cost for bimanual bring to target task."""
  # reach
-  left_gripper = d.site_xpos[3]
-  right_gripper = d.site_xpos[6]
-  box = d.xpos[m.nbody - 1]
+  left_gripper_site_index = 3
+  right_gripper_site_index = 6
+  box_body_index = m.nbody - 1
+  left_gripper_pos = d.site_xpos[..., left_gripper_site_index, :]
+  right_gripper_pos = d.site_xpos[..., right_gripper_site_index, :]
+  box_pos = d.xpos[..., box_body_index, :]

-  reach_l = left_gripper - box
-  reach_r = right_gripper - box
+  reach_l = left_gripper_pos - box_pos
+  reach_r = right_gripper_pos - box_pos

  # bring
-  target = jp.array([-0.4, -0.2, 0.3])
-  bring = box - target
+  target = jnp.array([-0.4, -0.2, 0.3])
+  bring = box_pos - target

  residuals = [reach_l, reach_r, bring]
  weights = [0.1, 0.1, 1]
  norm_p = [0.005, 0.005, 0.003]

  # NormType::kL2: y = sqrt(x*x' + p^2) - p
  terms = []
-  for r, w, p in zip(residuals, weights, norm_p):
-    terms.append(w * (jp.sqrt(jp.dot(r, r) + p**2) - p))
+  for t, w, p in zip(residuals, weights, norm_p):
+    terms.append(w * jnp.sqrt(jnp.sum(t**2, axis=-1) + p**2) - p)
+  costs = jnp.sum(jnp.array(terms), axis=-1)

-  return jp.sum(jp.array(terms))
+  return costs
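The comment names MJPC's NormType::kL2, the smoothed L2 norm y = sqrt(x*x' + p^2) - p: approximately quadratic within p of zero (hence differentiable at 0) and close to the plain L2 norm further out. A sketch for a single residual:

    # Smooth L2 norm (sketch); r is a residual vector, p a small smoothing term.
    def smooth_l2(r, p):
      return jnp.sqrt(jnp.sum(r**2, axis=-1) + p**2) - p  # ~ ||r|| when ||r|| >> p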


-def get_models_and_cost_fn() -> (
-    tuple[mujoco.MjModel, mujoco.MjModel, predictive_sampling.CostFn]
-):
+def get_models_and_cost_fn() -> tuple[
+    mujoco.MjModel,
+    mujoco.MjModel,
+    Callable[[mjx.Model, mjx.Data], jax.Array],
+]:
  """Returns a tuple of the model and the cost function."""
  path = epath.Path(
-      'build/mujoco_menagerie/aloha/'
+      'build'
+      / 'mujoco_menagerie/aloha/'
  )
  model_file_name = 'mjx_scene.xml'
  xml = (path / model_file_name).read_text()
…
