Skip to content

Commit

Permalink
Merge pull request #315 from google-deepmind/deepmind
Browse files Browse the repository at this point in the history
Merge deepmind branch into main.
  • Loading branch information
erez-tom authored May 3, 2024
2 parents 6d65ce4 + cdeea6c commit c82f6a6
Show file tree
Hide file tree
Showing 51 changed files with 1,785 additions and 459 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ set(MUJOCO_BUILD_TESTS OFF)
set(MUJOCO_TEST_PYTHON_UTIL OFF)

set(MUJOCO_MPC_MUJOCO_GIT_TAG
24eb4c9f092da7dd245a116841a5325a0fb359b9
3.1.4
CACHE STRING "Git revision for MuJoCo."
)

set(MUJOCO_MPC_MENAGERIE_GIT_TAG
8a5f659ac3607dc5adb988e0187f683fe0f4edf4
aff360ad958cb0e45d8f740816b1aacad84e9282
CACHE STRING "Git revision for MuJoCo Menagerie."
)

Expand Down
2 changes: 2 additions & 0 deletions mjpc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ add_library(
tasks/acrobot/acrobot.h
tasks/allegro/allegro.cc
tasks/allegro/allegro.h
tasks/bimanual/insert/insert.cc
tasks/bimanual/insert/insert.h
tasks/bimanual/handover/handover.cc
tasks/bimanual/handover/handover.h
tasks/bimanual/reorient/reorient.cc
Expand Down
10 changes: 8 additions & 2 deletions mjpc/agent.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ Agent::Agent(const mjModel* model, std::shared_ptr<Task> task)
// initialize data, settings, planners, state
void Agent::Initialize(const mjModel* model) {
// ----- model ----- //
if (model_) mj_deleteModel(model_);
mjModel* old_model = model_;
model_ = mj_copyModel(nullptr, model); // agent's copy of model

// check for limits on all actuators
Expand Down Expand Up @@ -152,6 +152,12 @@ void Agent::Initialize(const mjModel* model) {
// planner threads
planner_threads_ =
std::max(1, NumAvailableHardwareThreads() - 3 - 2 * estimator_threads_);

// delete the previous model after all the planners have been updated to use
// the new one.
if (old_model) {
mj_deleteModel(old_model);
}
}

// allocate memory
Expand Down Expand Up @@ -654,7 +660,7 @@ void Agent::GUI(mjUI& ui) {
{mjITEM_SLIDERNUM, "Horizon", 2, &horizon_, "0 1"},
{mjITEM_SLIDERNUM, "Timestep", 2, &timestep_, "0 1"},
{mjITEM_SELECT, "Integrator", 2, &integrator_,
"Euler\nRK4\nImplicit\nFastImplicit"},
"Euler\nRK4\nImplicit\nImplicitFast"},
{mjITEM_SEPARATOR, "Planner Settings", 1},
{mjITEM_END}};

Expand Down
2 changes: 1 addition & 1 deletion mjpc/estimators/batch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -952,7 +952,7 @@ void Batch::GUI(mjUI& ui) {
{mjITEM_BUTTON, "Reset", 2, nullptr, ""},
{mjITEM_SLIDERNUM, "Timestep", 2, &gui_timestep_, "1.0e-3 0.1"},
{mjITEM_SELECT, "Integrator", 2, &gui_integrator_,
"Euler\nRK4\nImplicit\nFastImplicit"},
"Euler\nRK4\nImplicit\nImplicitFast"},
{mjITEM_SLIDERINT, "Horizon", 2, &gui_horizon_, "3 3"},
{mjITEM_SLIDERNUM, "Prior Scale", 2, &gui_scale_prior_, "1.0e-8 0.1"},
{mjITEM_END}};
Expand Down
2 changes: 1 addition & 1 deletion mjpc/estimators/kalman.cc
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ void Kalman::GUI(mjUI& ui) {
{mjITEM_BUTTON, "Reset", 2, nullptr, ""},
{mjITEM_SLIDERNUM, "Timestep", 2, &gui_timestep_, "1.0e-3 0.1"},
{mjITEM_SELECT, "Integrator", 2, &gui_integrator_,
"Euler\nRK4\nImplicit\nFastImplicit"},
"Euler\nRK4\nImplicit\nImplicitFast"},
{mjITEM_END}};

// add estimator
Expand Down
2 changes: 1 addition & 1 deletion mjpc/estimators/unscented.cc
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ void Unscented::GUI(mjUI& ui) {
{mjITEM_BUTTON, "Reset", 2, nullptr, ""},
{mjITEM_SLIDERNUM, "Timestep", 2, &gui_timestep_, "1.0e-3 0.1"},
{mjITEM_SELECT, "Integrator", 2, &gui_integrator_,
"Euler\nRK4\nImplicit\nFastImplicit"},
"Euler\nRK4\nImplicit\nImplicitFast"},
{mjITEM_END}};

// add estimator
Expand Down
2 changes: 1 addition & 1 deletion mjpc/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,6 @@ int main(int argc, char** argv) {
mju_error("Invalid --task flag.");
}

mjpc::StartApp(tasks, 11); // start with quadruped flat
mjpc::StartApp(tasks, task_id); // start with quadruped flat
return 0;
}
70 changes: 41 additions & 29 deletions mjpc/planners/cross_entropy/planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
#include <shared_mutex>

#include <absl/random/random.h>
#include <absl/types/span.h>
#include <mujoco/mujoco.h>
#include "mjpc/array_safety.h"
#include "mjpc/planners/planner.h"
#include "mjpc/planners/sampling/planner.h"
#include "mjpc/spline/spline.h"
#include "mjpc/states/state.h"
#include "mjpc/task.h"
#include "mjpc/threadpool.h"
Expand All @@ -33,6 +35,7 @@
namespace mjpc {

namespace mju = ::mujoco::util_mjpc;
using mjpc::spline::TimeSpline;

// initialize data and settings
void CrossEntropyPlanner::Initialize(mjModel* model, const Task& task) {
Expand Down Expand Up @@ -159,6 +162,8 @@ void CrossEntropyPlanner::SetState(const State& state) {

// optimize nominal policy using random sampling
void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
resampled_policy.plan.SetInterpolation(interpolation_);

// if num_trajectory_ has changed, use it in this new iteration.
// num_trajectory_ might change while this function runs. Keep it constant
// for the duration of this function.
Expand All @@ -172,7 +177,6 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
ResizeMjData(model, pool.NumThreads());

// copy nominal policy
policy.num_parameters = model->nu * policy.num_spline_points;
{
const std::shared_lock<std::shared_mutex> lock(mtx_);
resampled_policy.CopyFrom(policy, policy.num_spline_points);
Expand Down Expand Up @@ -211,7 +215,7 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {

// dimensions
int num_spline_points = resampled_policy.num_spline_points;
int num_parameters = resampled_policy.num_parameters;
int num_parameters = num_spline_points * model->nu;

// averaged return over elites
double avg_return = 0.0;
Expand All @@ -225,8 +229,12 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
int idx = trajectory_order[i];

// add parameters
mju_addTo(parameters_scratch.data(),
candidate_policy[idx].parameters.data(), num_parameters);
for (int i = 0; i < num_spline_points; i++) {
TimeSpline::Node n = candidate_policy[idx].plan.NodeAt(i);
for (int j = 0; j < model->nu; j++) {
parameters_scratch[i * model->nu + j] += n.values()[j];
}
}

// add total return
avg_return += trajectory[idx].total_return;
Expand All @@ -240,13 +248,13 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
// loop over elites to compute variance
std::fill(variance.begin(), variance.end(), 0.0); // reset variance to zero
for (int t = 0; t < num_spline_points; t++) {
TimeSpline::Node n = candidate_policy[trajectory_order[0]].plan.NodeAt(t);
for (int j = 0; j < model->nu; j++) {
// average
double p_avg = parameters_scratch[t * model->nu + j];
for (int i = 0; i < n_elite; i++) {
// candidate parameter
double pi =
candidate_policy[trajectory_order[i]].parameters[t * model->nu + j];
double pi = n.values()[j];
double diff = pi - p_avg;
variance[t * model->nu + j] += diff * diff / (n_elite - 1);
}
Expand All @@ -256,7 +264,14 @@ void CrossEntropyPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
// update
{
const std::shared_lock<std::shared_mutex> lock(mtx_);
policy.CopyParametersFrom(parameters_scratch, times_scratch);
policy.plan.Clear();
policy.plan.SetInterpolation(interpolation_);
for (int t = 0; t < num_spline_points; t++) {
absl::Span<const double> values =
absl::MakeConstSpan(parameters_scratch.data() + t * model->nu,
parameters_scratch.data() + (t + 1) * model->nu);
policy.plan.AddNode(times_scratch[t], values);
}
}

// improvement: compare nominal to elite average
Expand Down Expand Up @@ -298,7 +313,6 @@ void CrossEntropyPlanner::ActionFromPolicy(double* action, const double* state,
// update policy via resampling
void CrossEntropyPlanner::ResamplePolicy(int horizon) {
// dimensions
int num_parameters = resampled_policy.num_parameters;
int num_spline_points = resampled_policy.num_spline_points;

// time
Expand All @@ -315,15 +329,14 @@ void CrossEntropyPlanner::ResamplePolicy(int horizon) {
}

// copy resampled policy parameters
mju_copy(resampled_policy.parameters.data(), parameters_scratch.data(),
num_parameters);
mju_copy(resampled_policy.times.data(), times_scratch.data(),
num_spline_points);

LinearRange(resampled_policy.times.data(), time_shift,
resampled_policy.times[0], num_spline_points);

resampled_policy.representation = policy.representation;
resampled_policy.plan.Clear();
for (int t = 0; t < num_spline_points; t++) {
absl::Span<const double> values =
absl::MakeConstSpan(parameters_scratch.data() + t * model->nu,
parameters_scratch.data() + (t + 1) * model->nu);
resampled_policy.plan.AddNode(times_scratch[t], values);
}
resampled_policy.plan.SetInterpolation(policy.plan.Interpolation());
}

// add random noise to nominal policy
Expand All @@ -333,7 +346,7 @@ void CrossEntropyPlanner::AddNoiseToPolicy(int i, double std_min) {

// dimensions
int num_spline_points = candidate_policy[i].num_spline_points;
int num_parameters = candidate_policy[i].num_parameters;
int num_parameters = num_spline_points * model->nu;

// sampling token
absl::BitGen gen_;
Expand All @@ -350,14 +363,13 @@ void CrossEntropyPlanner::AddNoiseToPolicy(int i, double std_min) {
gen_, 0.0, std::max(std::sqrt(variance[k]), std_min));
}

// add noise
mju_addTo(candidate_policy[i].parameters.data(), DataAt(noise, shift),
num_parameters);

// clamp parameters
for (int t = 0; t < num_spline_points; t++) {
Clamp(DataAt(candidate_policy[i].parameters, t * model->nu),
model->actuator_ctrlrange, model->nu);
for (int k = 0; k < candidate_policy[i].plan.Size(); k++) {
TimeSpline::Node n = candidate_policy[i].plan.NodeAt(k);
// add noise
mju_addTo(n.values().data(), DataAt(noise, shift + k * model->nu),
model->nu);
// clamp parameters
Clamp(n.values().data(), model->actuator_ctrlrange, model->nu);
}

// end timer
Expand Down Expand Up @@ -385,8 +397,8 @@ void CrossEntropyPlanner::Rollouts(int num_trajectory, int horizon,
const std::shared_lock<std::shared_mutex> lock(s.mtx_);
s.candidate_policy[i].CopyFrom(s.resampled_policy,
s.resampled_policy.num_spline_points);
s.candidate_policy[i].representation =
s.resampled_policy.representation;
s.candidate_policy[i].plan.SetInterpolation(
s.resampled_policy.plan.Interpolation());

// sample noise
s.AddNoiseToPolicy(i, std_min);
Expand Down Expand Up @@ -473,7 +485,7 @@ void CrossEntropyPlanner::Traces(mjvScene* scn) {
void CrossEntropyPlanner::GUI(mjUI& ui) {
mjuiDef defCrossEntropy[] = {
{mjITEM_SLIDERINT, "Rollouts", 2, &num_trajectory_, "0 1"},
{mjITEM_SELECT, "Spline", 2, &policy.representation,
{mjITEM_SELECT, "Spline", 2, &interpolation_,
"Zero\nLinear\nCubic"},
{mjITEM_SLIDERINT, "Spline Pts", 2, &policy.num_spline_points, "0 1"},
{mjITEM_SLIDERNUM, "Init. Std", 2, &std_initial_, "0 1"},
Expand Down
3 changes: 3 additions & 0 deletions mjpc/planners/cross_entropy/planner.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <mujoco/mujoco.h>
#include "mjpc/planners/planner.h"
#include "mjpc/planners/sampling/policy.h"
#include "mjpc/spline/spline.h"
#include "mjpc/states/state.h"
#include "mjpc/task.h"
#include "mjpc/threadpool.h"
Expand Down Expand Up @@ -135,6 +136,8 @@ class CrossEntropyPlanner : public Planner {
double rollouts_compute_time;
double policy_update_compute_time;

mjpc::spline::SplineInterpolation interpolation_ =
mjpc::spline::SplineInterpolation::kZeroSpline;
int num_trajectory_;
mutable std::shared_mutex mtx_;
};
Expand Down
16 changes: 10 additions & 6 deletions mjpc/planners/gradient/policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,19 @@
#include "mjpc/planners/gradient/policy.h"

#include <algorithm>
#include <vector>

#include <mujoco/mujoco.h>
#include "mjpc/planners/gradient/spline_mapping.h"
#include "mjpc/planners/policy.h"
#include "mjpc/spline/spline.h"
#include "mjpc/task.h"
#include "mjpc/trajectory.h"
#include "mjpc/utilities.h"

namespace mjpc {

using mjpc::spline::SplineInterpolation;


// allocate memory
void GradientPolicy::Allocate(const mjModel* model, const Task& task,
int horizon) {
Expand All @@ -48,7 +52,7 @@ void GradientPolicy::Allocate(const mjModel* model, const Task& task,
"gradient_spline_points");

// representation
representation = GetNumberOrDefault(PolicyRepresentation::kLinearSpline,
representation = GetNumberOrDefault(SplineInterpolation::kLinearSpline,
model, "gradient_representation");
}

Expand Down Expand Up @@ -83,13 +87,13 @@ void GradientPolicy::Action(double* action, const double* state,
// ----- get action ----- //

if (bounds[0] == bounds[1] ||
representation == PolicyRepresentation::kZeroSpline) {
representation == SplineInterpolation::kZeroSpline) {
ZeroInterpolation(action, time, times, parameters.data(), model->nu,
num_spline_points);
} else if (representation == PolicyRepresentation::kLinearSpline) {
} else if (representation == SplineInterpolation::kLinearSpline) {
LinearInterpolation(action, time, times, parameters.data(), model->nu,
num_spline_points);
} else if (representation == PolicyRepresentation::kCubicSpline) {
} else if (representation == SplineInterpolation::kCubicSpline) {
CubicInterpolation(action, time, times, parameters.data(), model->nu,
num_spline_points);
}
Expand Down
3 changes: 2 additions & 1 deletion mjpc/planners/gradient/policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <mujoco/mujoco.h>
#include "mjpc/planners/policy.h"
#include "mjpc/spline/spline.h"
#include "mjpc/task.h"

namespace mjpc {
Expand Down Expand Up @@ -62,7 +63,7 @@ class GradientPolicy : public Policy {
std::vector<double> times;
int num_parameters;
int num_spline_points;
PolicyRepresentation representation;
mjpc::spline::SplineInterpolation representation;
};

} // namespace mjpc
Expand Down
Loading

0 comments on commit c82f6a6

Please sign in to comment.