Skip to content

Commit

Permalink
Make planning model differentiable.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 694207411
Change-Id: Id1b9d2edf83a0db96595b7dfce4ffaca136136f6
  • Loading branch information
thowell authored and copybara-github committed Nov 7, 2024
1 parent c08d406 commit 49783b8
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 15 deletions.
82 changes: 67 additions & 15 deletions mjpc/agent.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,19 @@ void Agent::Initialize(const mjModel* model) {
planner_threads_ =
std::max(1, NumAvailableHardwareThreads() - 3 - 2 * estimator_threads_);

// differentiable planning model
// by default gradient-based planners use a differentiable model
int gradient_planner = false;
if (planner_ == kGradientPlanner || planner_ == kILQGPlanner ||
planner_ == kILQSPlanner) {
gradient_planner = true;
}
differentiable_ =
GetNumberOrDefault(gradient_planner, model, "agent_differentiable");
jnt_solimp_.resize(model->njnt);
geom_solimp_.resize(model->ngeom);
pair_solimp_.resize(model->npair);

// delete the previous model after all the planners have been updated to use
// the new one.
if (old_model) {
Expand Down Expand Up @@ -279,6 +292,22 @@ void Agent::PlanIteration(ThreadPool* pool) {
steps_ =
mju_max(mju_min(horizon_ / timestep_ + 1, kMaxTrajectoryHorizon), 1);

// make model differentiable
int differentiable = differentiable_;
if (differentiable) {
// cache solimp defaults
for (int i = 0; i < model_->njnt; i++) {
jnt_solimp_[i] = model_->jnt_solimp[mjNIMP * i];
}
for (int i = 0; i < model_->ngeom; i++) {
geom_solimp_[i] = model_->geom_solimp[mjNIMP * i];
}
for (int i = 0; i < model_->npair; i++) {
pair_solimp_[i] = model_->pair_solimp[mjNIMP * i];
}
MakeDifferentiable(model_);
}

// plan
if (!allocate_enabled) {
// set state
Expand Down Expand Up @@ -312,6 +341,19 @@ void Agent::PlanIteration(ThreadPool* pool) {
// release the planning residual function
residual_fn_.reset();
}

// restore solimp defaults
if (differentiable) {
for (int i = 0; i < model_->njnt; i++) {
model_->jnt_solimp[mjNIMP * i] = jnt_solimp_[i];
}
for (int i = 0; i < model_->ngeom; i++) {
model_->geom_solimp[mjNIMP * i] = geom_solimp_[i];
}
for (int i = 0; i < model_->npair; i++) {
model_->pair_solimp[mjNIMP * i] = pair_solimp_[i];
}
}
}

// call planner to update nominal policy
Expand Down Expand Up @@ -644,21 +686,23 @@ void Agent::GUI(mjUI& ui) {
}

// ----- agent ----- //
mjuiDef defAgent[] = {{mjITEM_SECTION, "Agent", 1, nullptr, "AP"},
{mjITEM_BUTTON, "Reset", 2, nullptr, " #459"},
{mjITEM_SELECT, "Planner", 2, &planner_, ""},
{mjITEM_SELECT, "Estimator", 2, &estimator_, ""},
{mjITEM_CHECKINT, "Plan", 2, &plan_enabled, ""},
{mjITEM_CHECKINT, "Action", 2, &action_enabled, ""},
{mjITEM_CHECKINT, "Plots", 2, &plot_enabled, ""},
{mjITEM_CHECKINT, "Traces", 2, &visualize_enabled, ""},
{mjITEM_SEPARATOR, "Agent Settings", 1},
{mjITEM_SLIDERNUM, "Horizon", 2, &horizon_, "0 1"},
{mjITEM_SLIDERNUM, "Timestep", 2, &timestep_, "0 1"},
{mjITEM_SELECT, "Integrator", 2, &integrator_,
"Euler\nRK4\nImplicit\nImplicitFast"},
{mjITEM_SEPARATOR, "Planner Settings", 1},
{mjITEM_END}};
mjuiDef defAgent[] = {
{mjITEM_SECTION, "Agent", 1, nullptr, "AP"},
{mjITEM_BUTTON, "Reset", 2, nullptr, " #459"},
{mjITEM_SELECT, "Planner", 2, &planner_, ""},
{mjITEM_SELECT, "Estimator", 2, &estimator_, ""},
{mjITEM_CHECKINT, "Plan", 2, &plan_enabled, ""},
{mjITEM_CHECKINT, "Action", 2, &action_enabled, ""},
{mjITEM_CHECKINT, "Plots", 2, &plot_enabled, ""},
{mjITEM_CHECKINT, "Traces", 2, &visualize_enabled, ""},
{mjITEM_SEPARATOR, "Agent Settings", 1},
{mjITEM_SLIDERNUM, "Horizon", 2, &horizon_, "0 1"},
{mjITEM_SLIDERNUM, "Timestep", 2, &timestep_, "0 1"},
{mjITEM_SELECT, "Integrator", 2, &integrator_,
"Euler\nRK4\nImplicit\nImplicitFast"},
{mjITEM_CHECKINT, "Differentiable", 2, &differentiable_, ""},
{mjITEM_SEPARATOR, "Planner Settings", 1},
{mjITEM_END}};

// planner names
mju::strcpy_arr(defAgent[2].other, planner_names_);
Expand Down Expand Up @@ -730,6 +774,14 @@ void Agent::AgentEvent(mjuiItem* it, mjData* data,
this->PlotInitialize();
this->PlotReset();

// by default gradient-based planners use a differentiable model
if (planner_ == kGradientPlanner || planner_ == kILQGPlanner ||
planner_ == kILQSPlanner) {
differentiable_ = true;
} else {
differentiable_ = false;
}

// reset agent
uiloadrequest.fetch_sub(1);
}
Expand Down
6 changes: 6 additions & 0 deletions mjpc/agent.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,12 @@ class Agent {

// max threads for estimation
int estimator_threads_;

// differentiable planning model
bool differentiable_;
std::vector<double> jnt_solimp_;
std::vector<double> geom_solimp_;
std::vector<double> pair_solimp_;
};

} // namespace mjpc
Expand Down
11 changes: 11 additions & 0 deletions mjpc/planners/include.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,17 @@

namespace mjpc {

// planner types
enum PlannerType : int {
kSamplingPlanner = 0,
kGradientPlanner,
kILQGPlanner,
kILQSPlanner,
kRobustPlanner,
kCrossEntropyPlanner,
kSampleGradientPlanner,
};

// Planner names, separated by '\n'.
extern const char kPlannerNames[];

Expand Down
18 changes: 18 additions & 0 deletions mjpc/utilities.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,24 @@ extern "C" {
}

namespace mjpc {
// make model differentiable by setting solimp[0] to zero
void MakeDifferentiable(mjModel* model) {
// joints
for (int i = 0; i < model->njnt; i++) {
model->jnt_solimp[mjNIMP * i] = 0.0;
}

// geoms
for (int i = 0; i < model->ngeom; i++) {
model->geom_solimp[mjNIMP * i] = 0.0;
}

// contact pairs
for (int i = 0; i < model->npair; i++) {
model->pair_solimp[mjNIMP * i] = 0.0;
}
}

// set mjData state
void SetState(const mjModel* model, mjData* data, const double* state) {
mju_copy(data->qpos, state, model->nq);
Expand Down
3 changes: 3 additions & 0 deletions mjpc/utilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ namespace mjpc {
// maximum number of traces that are visualized
inline constexpr int kMaxTraces = 99;

// make model differentiable by setting solimp[0] to zero
void MakeDifferentiable(mjModel* model);

// set mjData state
void SetState(const mjModel* model, mjData* data, const double* state);

Expand Down

0 comments on commit 49783b8

Please sign in to comment.