From 8b9c7b24e977454c335a77fac0837feaf9d7d727 Mon Sep 17 00:00:00 2001
From: Tom Erez
Date: Mon, 26 Feb 2024 08:35:04 -0800
Subject: [PATCH] Add multi-plan optimization.

PiperOrigin-RevId: 610417779
Change-Id: I0ee7274760c4a28b56c8f2f86d36beb1fd45143b
---
 python/mujoco_mpc/mjx/predictive_sampling.py  | 143 +++++++++++-------
 .../mujoco_mpc/mjx/tasks/bimanual/handover.py |  42 ++---
 python/mujoco_mpc/mjx/visualize.py            | 111 +++++++++-----
 3 files changed, 190 insertions(+), 106 deletions(-)

diff --git a/python/mujoco_mpc/mjx/predictive_sampling.py b/python/mujoco_mpc/mjx/predictive_sampling.py
index 764b677e2..419f85c93 100644
--- a/python/mujoco_mpc/mjx/predictive_sampling.py
+++ b/python/mujoco_mpc/mjx/predictive_sampling.py
@@ -23,6 +23,7 @@
 import mujoco
 from mujoco import mjx
 from mujoco.mjx._src import dataclasses
+import numpy as np
 
 CostFn = Callable[[mjx.Model, mjx.Data], jax.Array]
 
@@ -48,6 +49,7 @@ class Planner(dataclasses.PyTreeNode):
   interp: str
 
 
+@jax.jit
 def _rollout(p: Planner, d: mjx.Data, policy: jax.Array) -> jax.Array:
   """Expand the policy into actions and roll out dynamics and cost."""
   actions = get_actions(p, policy)
@@ -59,7 +61,6 @@ def step(d, action):
     return d, cost
 
   _, costs = jax.lax.scan(step, d, actions, length=p.horizon)
-
   return jnp.sum(costs)
 
 
@@ -77,103 +78,141 @@ def get_actions(p: Planner, policy: jax.Array) -> jax.Array:
     raise ValueError(f'unimplemented interp: {p.interp}')
   return actions
 
+v_get_actions = jax.vmap(get_actions, in_axes=[None, 0])
 
 def improve_policy(
     p: Planner, d: mjx.Data, policy: jax.Array, rng: jax.Array
 ) -> Tuple[jax.Array, jax.Array]:
   """Improves policy."""
-  limit = p.model.actuator_ctrlrange
   # create noisy policies, with nominal policy at index 0
-  noise = jax.random.normal(rng, (p.nsample, p.nspline, p.model.nu))
-  noise = noise * p.noise_scale * (limit[:, 1] - limit[:, 0])
+  noise = (
+      jax.random.normal(rng, (p.nsample, p.nspline, p.model.nu)) * p.noise_scale
+  )
   policies = jnp.concatenate((policy[None], policy + noise))
   # clamp actions to ctrlrange
+  limit = p.model.actuator_ctrlrange
   policies = jnp.clip(policies, limit[:, 0], limit[:, 1])
-
   # perform nsample + 1 parallel rollouts
   costs = jax.vmap(_rollout, in_axes=(None, None, 0))(p, d, policies)
   costs = jnp.nan_to_num(costs, nan=jnp.inf)
-  best_id = jnp.argmin(costs)
+  winners = jnp.argmin(costs)
 
-  return policies[best_id], costs[best_id]
+  return policies[winners], winners
 
 
 def resample(p: Planner, policy: jax.Array, steps_per_plan: int) -> jax.Array:
   """Resample policy to new advanced time."""
-  if p.horizon % p.nspline != 0:
-    raise ValueError("horizon must be divisible by nspline")
-  splinesteps = p.horizon // p.nspline
-  if splinesteps % steps_per_plan != 0:
-    raise ValueError(
-        f'splinesteps ({splinesteps}) must be divisible by steps_per_plan'
-        f' ({steps_per_plan})'
-    )
-  roll = splinesteps // steps_per_plan
-  policy = jnp.roll(policy, -roll, axis=0)
-  policy = policy.at[-roll:].set(policy[-roll - 1])
+  if p.interp == 'zero':
+    return policy  # assuming steps_per_plan < splinesteps
+  elif p.interp == 'linear':
+    actions = v_get_actions(p, policy)
+    roll = steps_per_plan
+    actions = jnp.roll(actions, -roll, axis=1)
+    actions = actions.at[:, -roll:, :].set(actions[:, [-1], :])
+    idx = jnp.floor(jnp.linspace(0, p.horizon, p.nspline)).astype(int)
+    return actions[:, idx, :]
   return policy
 
 
-def set_state(d_out, d_in):
-  return d_out.replace(
-      time=d_in.time, qpos=d_in.qpos, qvel=d_in.qvel, act=d_in.act,
-      ctrl=d_in.ctrl)
-
+def set_state(d, state):
+  return d.replace(
+      time=state.time, qpos=state.qpos, qvel=state.qvel, act=state.act,
+      ctrl=state.ctrl)
+
+set_states = jax.vmap(set_state, in_axes=[0, 0])
 
 
 def receding_horizon_optimization(
     p: Planner,
     plan_model_cpu,
     sim_model_cpu,
     nsteps,
+    nplans,
     steps_per_plan,
     frame_skip,
+    verbose=False,
 ):
-  d = mujoco.MjData(plan_model_cpu)
-  d = mjx.put_data(plan_model_cpu, d)
+  """Receding horizon optimization, all nplans start from same keyframe."""
+  plan_data = mujoco.MjData(plan_model_cpu)
+  plan_data = mjx.put_data(plan_model_cpu, plan_data)
   m = mjx.put_model(plan_model_cpu)
   p = p.replace(model=m)
-  jitted_cost = jax.jit(p.cost)
-
-  policy = jnp.zeros((p.nspline, m.nu))
-  rng = jax.random.key(0)
-  improve_fn = (
-      jax.jit(improve_policy)
-      .lower(p, d, policy, rng)
-      .compile()
-  )
-  step_fn = jax.jit(mjx.step).lower(m, d).compile()
-  trajectory, costs = [], []
-  plan_time = 0
 
   sim_data = mujoco.MjData(sim_model_cpu)
   mujoco.mj_resetDataKeyframe(sim_model_cpu, sim_data, 0)
   # without kinematics, the first cost is off:
   mujoco.mj_forward(sim_model_cpu, sim_data)
   sim_data = mjx.put_data(sim_model_cpu, sim_data)
   sim_model = mjx.put_model(sim_model_cpu)
-  actions = get_actions(p, policy)
+
+  policy = jnp.tile(sim_data.ctrl, (nplans, p.nspline, 1))
+  multi_actions = v_get_actions(p, policy)
+  first_actions = multi_actions[:, 0, :]  # just the first actions
 
+  # duplicate data for each plan
+  def set_action(data, action):
+    return data.replace(ctrl=action)
+
+  duplicate_data = jax.vmap(set_action, in_axes=[None, 0], out_axes=0)
+  sim_datas = duplicate_data(sim_data, first_actions)
+  plan_datas = duplicate_data(plan_data, first_actions)
+
+  def step_and_cost(model, data, action):
+    data = data.replace(ctrl=action)
+    cost = p.cost(model, data)
+    data = mjx.step(model, data)
+    return data, cost
+
+  multi_step = (
+      jax.jit(
+          jax.vmap(step_and_cost, in_axes=[None, 0, 0])
+      )
+      .lower(sim_model, sim_datas, first_actions)
+      .compile()
+  )
+
+  rng = jax.random.key(0)
+  keys = jax.random.split(rng, nplans)
+  improve_fn = (
+      jax.jit(
+          jax.vmap(improve_policy, in_axes=(None, 0, 0, 0))
+      )
+      .lower(p, plan_datas, policy, keys)
+      .compile()
+  )
+  trajectories = np.zeros(
+      (nplans, nsteps // frame_skip, sim_data.qpos.shape[0])
+  )
+  costs = np.zeros((nplans, nsteps))
+  plan_time = 0
+  multi_actions = v_get_actions(p, policy)
   for step in range(nsteps):
     if step % steps_per_plan == 0:
+      if verbose:
+        print('re-planning')
       # resample policy to new advanced time
-      print('re-planning')
       policy = resample(p, policy, steps_per_plan)
       beg = time.perf_counter()
-      d = set_state(d, sim_data)
-      policy, _ = improve_fn(p, d, policy, jax.random.key(step))
-      plan_time += time.perf_counter() - beg
-      actions = get_actions(p, policy)
-
-    sim_data = sim_data.replace(ctrl=actions[0])
-    cost = jitted_cost(sim_model, sim_data)
-    sim_data = step_fn(sim_model, sim_data)
-    costs.append(cost)
-    print(f'step: {step}')
-    print(f'cost: {cost}')
+      plan_datas = set_states(plan_datas, sim_datas)
+      policy, winners = improve_fn(
+          p, plan_datas, policy, jax.random.split(jax.random.key(step), nplans)
+      )
+      this_plan_time = time.perf_counter() - beg
+      plan_time += this_plan_time
+      if verbose:
+        print(f'winners: {winners}')
+      multi_actions = v_get_actions(p, policy)
+
+    step_index = step % steps_per_plan
+    sim_datas, cost = multi_step(
+        sim_model, sim_datas, multi_actions[:, step_index, :]
+    )
+    costs[:, step] = jax.device_get(cost)
     if step % frame_skip == 0:
-      trajectory.append(jax.device_get(sim_data.qpos))
+      trajectories[:, step // frame_skip, :] = jax.device_get(sim_datas.qpos)
+    if verbose:
+      print(f'step: {step}')
+      print(f'avg cost: {np.mean(costs[:, step])}')
 
-  return trajectory, costs, plan_time
+  return trajectories, costs, plan_time
diff --git a/python/mujoco_mpc/mjx/tasks/bimanual/handover.py b/python/mujoco_mpc/mjx/tasks/bimanual/handover.py
index ccbcfd24e..cbf8b835f 100644
--- a/python/mujoco_mpc/mjx/tasks/bimanual/handover.py
+++ b/python/mujoco_mpc/mjx/tasks/bimanual/handover.py
@@ -13,27 +13,31 @@
 # limitations under the License.
 # ==============================================================================
 
+from typing import Callable
+
 from etils import epath
+# internal import
 import jax
-from jax import numpy as jp
+from jax import numpy as jnp
 import mujoco
 from mujoco import mjx
-from mujoco_mpc.mjx import predictive_sampling
 
 
 def bring_to_target(m: mjx.Model, d: mjx.Data) -> jax.Array:
   """Returns cost for bimanual bring to target task."""
   # reach
-  left_gripper = d.site_xpos[3]
-  right_gripper = d.site_xpos[6]
-  box = d.xpos[m.nbody - 1]
+  left_gripper_site_index = 3
+  right_gripper_site_index = 6
+  box_body_index = m.nbody - 1
+  left_gripper_pos = d.site_xpos[..., left_gripper_site_index, :]
+  right_gripper_pos = d.site_xpos[..., right_gripper_site_index, :]
+  box_pos = d.xpos[..., box_body_index, :]
 
-  reach_l = left_gripper - box
-  reach_r = right_gripper - box
+  reach_l = left_gripper_pos - box_pos
+  reach_r = right_gripper_pos - box_pos
 
-  # bring
-  target = jp.array([-0.4, -0.2, 0.3])
-  bring = box - target
+  target = jnp.array([-0.4, -0.2, 0.3])
+  bring = box_pos - target
 
   residuals = [reach_l, reach_r, bring]
   weights = [0.1, 0.1, 1]
@@ -41,18 +45,22 @@ def bring_to_target(m: mjx.Model, d: mjx.Data) -> jax.Array:
   # NormType::kL2: y = sqrt(x*x' + p^2) - p
   terms = []
-  for r, w, p in zip(residuals, weights, norm_p):
-    terms.append(w * (jp.sqrt(jp.dot(r, r) + p**2) - p))
+  for t, w, p in zip(residuals, weights, norm_p):
+    terms.append(w * jnp.sqrt(jnp.sum(t**2, axis=-1) + p**2) - p)
+  costs = jnp.sum(jnp.array(terms), axis=-1)
 
-  return jp.sum(jp.array(terms))
+  return costs
 
 
-def get_models_and_cost_fn() -> (
-    tuple[mujoco.MjModel, mujoco.MjModel, predictive_sampling.CostFn]
-):
+def get_models_and_cost_fn() -> tuple[
+    mujoco.MjModel,
+    mujoco.MjModel,
+    Callable[[mjx.Model, mjx.Data], jax.Array],
+]:
   """Returns a tuple of the model and the cost function."""
   path = epath.Path(
-      'build/mujoco_menagerie/aloha/'
+      'build'
+      / 'mujoco_menagerie/aloha/'
   )
   model_file_name = 'mjx_scene.xml'
   xml = (path / model_file_name).read_text()
diff --git a/python/mujoco_mpc/mjx/visualize.py b/python/mujoco_mpc/mjx/visualize.py
index deb493d65..1c66313f2 100644
--- a/python/mujoco_mpc/mjx/visualize.py
+++ b/python/mujoco_mpc/mjx/visualize.py
@@ -18,56 +18,93 @@
 import mujoco
 from mujoco_mpc.mjx import predictive_sampling
 from mujoco_mpc.mjx.tasks.bimanual import handover
+
+import numpy as np
 
 # %%
-nsteps = 500
-steps_per_plan = 4
-frame_skip = 5  # how many steps between each rendered frame
+costs_to_compare = {}
+for it in [0.3, 0.5, 0.8]:
+  nsteps = 300
+  steps_per_plan = 10
+  frame_skip = 5  # how many steps between each rendered frame
+  batch_size = 8192
+  nsamples = 512
+  nplans = batch_size // nsamples
+  print(f'nplans: {nplans}')
+
+  sim_model_cpu, plan_model_cpu, cost_fn = handover.get_models_and_cost_fn()
+  p = predictive_sampling.Planner(
+      model=plan_model_cpu,  # dummy
+      cost=cost_fn,
+      noise_scale=it,
+      horizon=128,
+      nspline=4,
+      nsample=nsamples - 1,
+      interp='zero',
+  )
+
+  trajectories, costs, plan_time = (
+      predictive_sampling.receding_horizon_optimization(
+          p,
+          plan_model_cpu,
+          sim_model_cpu,
+          nsteps,
+          nplans,
+          steps_per_plan,
+          frame_skip,
+      )
+  )
+  print(f'plan_time: {plan_time}')
+  plt.figure()
+  plt.xlim([0, nsteps * sim_model_cpu.opt.timestep])
+  plt.ylim([0, max(costs.flatten())])
+  plt.xlabel('time')
+  plt.ylabel('cost')
+  x_time = [i * sim_model_cpu.opt.timestep for i in range(nsteps)]
+  for i in range(nplans):
+    plt.plot(x_time, costs[i], alpha=0.1)
+  avg = np.mean(costs, axis=0)
+  plt.plot(x_time, avg, linewidth=2.0)
+  var = np.var(costs, axis=0)
+  plt.errorbar(
+      x_time,
+      avg,
+      yerr=var,
+      fmt='none',
+      ecolor='b',
+      elinewidth=1,
+      alpha=0.2,
+      capsize=0,
+  )
 
-sim_model, plan_model, cost_fn = handover.get_models_and_cost_fn()
-p = predictive_sampling.Planner(
-    model=plan_model,
-    cost=cost_fn,
-    noise_scale=0.3,
-    horizon=128,
-    nspline=4,
-    nsample=128 - 1,
-    interp='zero',
-)
+  plt.show()
+  costs_to_compare[it] = costs
 
-trajectory, costs, plan_time = (
-    predictive_sampling.receding_horizon_optimization(
-        p,
-        plan_model,
-        sim_model,
-        nsteps,
-        steps_per_plan,
-        frame_skip,
-    )
-)
+trajectory = trajectories[0]
 # %%
-plt.xlim([0, nsteps * sim_model.opt.timestep])
-plt.ylim([0, max(costs)])
+plt.figure()
+plt.xlim([0, nsteps * sim_model_cpu.opt.timestep])
+plt.ylim([0, max(costs.flatten())])
 plt.xlabel('time')
 plt.ylabel('cost')
-plt.plot([i * sim_model.opt.timestep for i in range(nsteps)], costs)
-plt.show()
+x_time = [i * sim_model_cpu.opt.timestep for i in range(nsteps)]
+for val, costs in costs_to_compare.items():
+  avg = np.mean(costs, axis=0)
+  plt.plot(x_time, avg, label=str(val))
+  var = np.var(costs, axis=0)
+  plt.errorbar(x_time, avg, yerr=var, fmt='none', elinewidth=1, alpha=0.2, capsize=0)
 
-sim_time = nsteps * sim_model.opt.timestep
-plan_steps = nsteps // steps_per_plan
-real_factor = sim_time / plan_time
-print(f'Total wall time ({plan_steps} planning steps): {plan_time} s'
-      f' ({real_factor:.2f}x realtime)')
+plt.legend()
+plt.show()
 # %%
 frames = []
-renderer = mujoco.Renderer(sim_model)
-d = mujoco.MjData(sim_model)
+renderer = mujoco.Renderer(sim_model_cpu)
+d = mujoco.MjData(sim_model_cpu)
 for qpos in trajectory:
   d.qpos = qpos
-  mujoco.mj_forward(sim_model, d)
+  mujoco.mj_forward(sim_model_cpu, d)
   renderer.update_scene(d)
   frames.append(renderer.render())
 # %%
-mediapy.show_video(frames, fps=1/sim_model.opt.timestep/frame_skip)
-# %%
+mediapy.show_video(frames, fps=1/sim_model_cpu.opt.timestep/frame_skip)
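
Illustrative usage of the new multi-plan entry point (a minimal sketch, not part of the patch; it condenses the visualize.py driver above and assumes the same module layout, built menagerie assets, and keyframe 0 in the task model):

from mujoco_mpc.mjx import predictive_sampling
from mujoco_mpc.mjx.tasks.bimanual import handover

# Task models and cost come from the handover task added earlier in this repo.
sim_model, plan_model, cost_fn = handover.get_models_and_cost_fn()
planner = predictive_sampling.Planner(
    model=plan_model,  # placeholder; replaced by the mjx plan model internally
    cost=cost_fn,
    noise_scale=0.3,
    horizon=128,
    nspline=4,
    nsample=511,  # noisy rollouts per plan; the nominal policy is also evaluated
    interp='zero',
)
# Runs nplans independent plans in parallel from the same keyframe and returns
# per-plan qpos trajectories, a (nplans, nsteps) cost array, and planning time.
trajectories, costs, plan_time = (
    predictive_sampling.receding_horizon_optimization(
        planner,
        plan_model,
        sim_model,
        nsteps=300,
        nplans=16,
        steps_per_plan=10,
        frame_skip=5,
        verbose=True,
    )
)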