diff --git a/README.md b/README.md index 183762094..9b842eced 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ MJPC allows the user to easily author and solve complex robotics tasks, and curr ## Overview -To read the paper describing this software package, please see out [preprint](https://arxiv.org/abs/2212.00541). +To read the paper describing this software package, please see our [preprint](https://arxiv.org/abs/2212.00541). For a quick video overview of MJPC, click below. diff --git a/grpc/agent.proto b/grpc/agent.proto index 514f415f7..4fb9f02d2 100644 --- a/grpc/agent.proto +++ b/grpc/agent.proto @@ -93,6 +93,11 @@ message GetActionRequest { // and actions will be averaged over that period. // During the rollout, the task's Transition will not be called. optional float averaging_duration = 2; + + // For planners that use feedback terms (iLQG), if true, return the nominal + // action for the given time rather than applying feedback terms on the + // current state. For the sampling planner this has no effect. + optional bool nominal_action = 3; } message GetActionResponse { diff --git a/grpc/agent_service_test.cc b/grpc/agent_service_test.cc index 1fd03f51f..1ad9445c0 100644 --- a/grpc/agent_service_test.cc +++ b/grpc/agent_service_test.cc @@ -22,6 +22,7 @@ #include "testing/base/public/gmock.h" #include "testing/base/public/gunit.h" #include +#include #include #include #include @@ -31,6 +32,7 @@ #include #include "grpc/agent.grpc.pb.h" #include "grpc/agent.pb.h" +#include "grpc/agent.proto.h" #include "mjpc/tasks/tasks.h" namespace agent_grpc { @@ -52,8 +54,6 @@ class AgentServiceTest : public ::testing::Test { void TearDown() override { server->Shutdown(); } void RunAndCheckInit(std::string_view task_id, mjModel* model) { - grpc::ClientContext init_context; - agent::InitRequest init_request; init_request.set_task_id(task_id); @@ -65,11 +65,26 @@ class AgentServiceTest : public ::testing::Test { init_request.set_allocated_model(nullptr); } - agent::InitResponse init_response; - grpc::Status init_status = - stub->Init(&init_context, init_request, &init_response); + SendRequest(&Agent::Stub::Init, init_request); + } + + // Sends a request, validates the status code and returns the response + template + Res SendRequest(grpc::Status (Agent::Stub::*method)(grpc::ClientContext*, + const Req&, Res*), + const Req& request) { + grpc::ClientContext context; + Res response; + grpc::Status status = (stub.get()->*method)(&context, request, &response); + EXPECT_TRUE(status.ok()) << status.error_message(); + return response; + } - EXPECT_TRUE(init_status.ok()) << init_status.error_message(); + // an overload which constructs an empty request + template + Res SendRequest(grpc::Status (Agent::Stub::*method)(grpc::ClientContext*, + const Req&, Res*)) { + return SendRequest(method, Req()); } std::unique_ptr agent_service; @@ -88,17 +103,11 @@ TEST_F(AgentServiceTest, Init_WithModel) { TEST_F(AgentServiceTest, SetState_Works) { RunAndCheckInit("Cartpole", nullptr); - grpc::ClientContext set_state_context; - agent::SetStateRequest set_state_request; agent::State* state = new agent::State(); state->set_time(0.0); set_state_request.set_allocated_state(state); - agent::SetStateResponse set_state_response; - grpc::Status set_state_status = stub->SetState( - &set_state_context, set_state_request, &set_state_response); - - EXPECT_TRUE(set_state_status.ok()) << set_state_status.error_message(); + SendRequest(&Agent::Stub::SetState, set_state_request); } TEST_F(AgentServiceTest, SetState_WrongSize) { @@ -124,56 +133,124 @@ TEST_F(AgentServiceTest, PlannerStep_ProducesNonzeroAction) { RunAndCheckInit("Cartpole", nullptr); { - grpc::ClientContext context; agent::SetTaskParametersRequest request; (*request.mutable_parameters())["Goal"].set_numeric(-1.0); - agent::SetTaskParametersResponse response; - grpc::Status status = stub->SetTaskParameters(&context, request, &response); + SendRequest(&Agent::Stub::SetTaskParameters, request); + } - EXPECT_TRUE(status.ok()); + SendRequest(&Agent::Stub::PlannerStep); + + { + agent::GetActionResponse response = SendRequest(&Agent::Stub::GetAction); + + ASSERT_EQ(response.action().size(), 1); + EXPECT_TRUE(response.action()[0] != 0.0); } +} + +TEST_F(AgentServiceTest, ActionAveragingGivesDifferentResult) { + RunAndCheckInit("Cartpole", nullptr); { - grpc::ClientContext context; + agent::SetTaskParametersRequest request; + (*request.mutable_parameters())["Goal"].set_numeric(-1.0); + SendRequest(&Agent::Stub::SetTaskParameters, request); + } - agent::PlannerStepRequest request; - agent::PlannerStepResponse response; - grpc::Status status = stub->PlannerStep(&context, request, &response); + SendRequest(&Agent::Stub::PlannerStep); - EXPECT_TRUE(status.ok()) << status.error_message(); + double action_without_averaging; + { + agent::GetActionResponse response = SendRequest(&Agent::Stub::GetAction); + ASSERT_EQ(response.action().size(), 1); + EXPECT_TRUE(response.action()[0] != 0.0); + action_without_averaging = response.action()[0]; } + double action_with_averaging; { grpc::ClientContext context; agent::GetActionRequest request; - agent::GetActionResponse response; - grpc::Status status = stub->GetAction(&context, request, &response); - - EXPECT_TRUE(status.ok()) << status.error_message(); - EXPECT_EQ(response.action().size(), 1); + request.set_averaging_duration(1.0); + agent::GetActionResponse response = + SendRequest(&Agent::Stub::GetAction, request); + ASSERT_EQ(response.action().size(), 1); EXPECT_TRUE(response.action()[0] != 0.0); + action_with_averaging = response.action()[0]; } + EXPECT_NE(action_with_averaging, action_without_averaging); } -TEST_F(AgentServiceTest, Step_AdvancesTime) { - RunAndCheckInit("Cartpole", nullptr); +TEST_F(AgentServiceTest, NominalActionIndependentOfState) { + // Pick a task that uses iLQG, where there is normally a feedback term on the + // policy. + RunAndCheckInit("Swimmer", nullptr); + + SendRequest(&Agent::Stub::PlannerStep); - agent::State initial_state; + double nominal_action1; { - grpc::ClientContext context; - agent::GetStateRequest request; - agent::GetStateResponse response; - EXPECT_TRUE(stub->GetState(&context, request, &response).ok()); - initial_state = response.state(); + agent::GetActionRequest request; + request.set_averaging_duration(1.0); + request.set_nominal_action(true); + request.set_time(0.01); + agent::GetActionResponse response = + SendRequest(&Agent::Stub::GetAction, request); + EXPECT_GE(response.action().size(), 1); + nominal_action1 = response.action()[0]; + EXPECT_NE(nominal_action1, 0.0); } + // Set a new state + { + agent::SetStateRequest request; + static constexpr int kSwimmerDofs = 8; + for (int i = 0; i < kSwimmerDofs; i++) { + request.mutable_state()->mutable_qpos()->Add(0.1); + } + SendRequest(&Agent::Stub::SetState, request); + } + + double nominal_action2; + { + agent::GetActionRequest request; + request.set_averaging_duration(1.0); + request.set_nominal_action(true); + request.set_time(0.01); + agent::GetActionResponse response = + SendRequest(&Agent::Stub::GetAction, request); + EXPECT_GE(response.action().size(), 1); + nominal_action2 = response.action()[0]; + } + + double feedback_action; + { + agent::GetActionRequest request; + request.set_averaging_duration(1.0); + request.set_nominal_action(false); + request.set_time(0.01); + agent::GetActionResponse response = + SendRequest(&Agent::Stub::GetAction, request); + EXPECT_GE(response.action().size(), 1); + feedback_action = response.action()[0]; + } + + EXPECT_EQ(nominal_action1, nominal_action2) + << "nominal action should be the same"; + EXPECT_NE(nominal_action1, feedback_action) + << "feedback action should be different from the nominal"; +} + +TEST_F(AgentServiceTest, Step_AdvancesTime) { + RunAndCheckInit("Cartpole", nullptr); + + agent::State initial_state = SendRequest(&Agent::Stub::GetState).state(); + { - grpc::ClientContext context; agent::SetTaskParametersRequest request; (*request.mutable_parameters())["Goal"].set_numeric(-1.0); - agent::SetTaskParametersResponse response; - EXPECT_TRUE(stub->SetTaskParameters(&context, request, &response).ok()); + SendRequest(&Agent::Stub::SetTaskParameters, request); } { @@ -185,33 +262,12 @@ TEST_F(AgentServiceTest, Step_AdvancesTime) { EXPECT_EQ(response.parameters().at("Goal").numeric(), -1.0); } - { - grpc::ClientContext context; - agent::PlannerStepRequest request; - agent::PlannerStepResponse response; - grpc::Status status = stub->PlannerStep(&context, request, &response); - - EXPECT_TRUE(status.ok()) << status.error_message(); - } - + SendRequest(&Agent::Stub::PlannerStep); for (int i = 0; i < 3; i++) { - grpc::ClientContext context; - agent::StepRequest request; - agent::StepResponse response; - grpc::Status status = stub->Step(&context, request, &response); - - EXPECT_TRUE(status.ok()) << status.error_message(); + SendRequest(&Agent::Stub::Step); } - agent::State final_state; - { - grpc::ClientContext context; - agent::GetStateRequest request; - agent::GetStateResponse response; - grpc::Status status = stub->GetState(&context, request, &response); - EXPECT_TRUE(status.ok()) << status.error_message(); - final_state = response.state(); - } + agent::State final_state = SendRequest(&Agent::Stub::GetState).state(); double cartpole_timestep = 0.001; EXPECT_DOUBLE_EQ(final_state.time() - initial_state.time(), 3 * cartpole_timestep); @@ -223,34 +279,11 @@ TEST_F(AgentServiceTest, Step_CallsTransition) { RunAndCheckInit("Particle", nullptr); - agent::State initial_state; - { - grpc::ClientContext context; - agent::GetStateRequest request; - agent::GetStateResponse response; - grpc::Status status = stub->GetState(&context, request, &response); - EXPECT_TRUE(status.ok()) << status.error_message(); - initial_state = response.state(); - } + agent::State initial_state = SendRequest(&Agent::Stub::GetState).state(); - { - grpc::ClientContext context; - agent::StepRequest request; - agent::StepResponse response; - grpc::Status status = stub->Step(&context, request, &response); + SendRequest(&Agent::Stub::Step); - EXPECT_TRUE(status.ok()) << status.error_message(); - } - - agent::State final_state; - { - grpc::ClientContext context; - agent::GetStateRequest request; - agent::GetStateResponse response; - grpc::Status status = stub->GetState(&context, request, &response); - EXPECT_TRUE(status.ok()) << status.error_message(); - final_state = response.state(); - } + agent::State final_state = SendRequest(&Agent::Stub::GetState).state(); EXPECT_NE(final_state.mocap_pos()[0], initial_state.mocap_pos()[0]) << "mocap_pos stayed constant. Was Transition called?"; } @@ -258,27 +291,16 @@ TEST_F(AgentServiceTest, Step_CallsTransition) { TEST_F(AgentServiceTest, SetTaskParameters_Numeric) { RunAndCheckInit("Cartpole", nullptr); - grpc::ClientContext context; - agent::SetTaskParametersRequest request; (*request.mutable_parameters())["Goal"].set_numeric(16.0); - agent::SetTaskParametersResponse response; - grpc::Status status = stub->SetTaskParameters(&context, request, &response); - - EXPECT_TRUE(status.ok()); + SendRequest(&Agent::Stub::SetTaskParameters, request); } TEST_F(AgentServiceTest, SetTaskParameters_Select) { RunAndCheckInit("Quadruped Flat", nullptr); - - grpc::ClientContext context; - agent::SetTaskParametersRequest request; (*request.mutable_parameters())["Gait"].set_selection("Trot"); - agent::SetTaskParametersResponse response; - grpc::Status status = stub->SetTaskParameters(&context, request, &response); - - EXPECT_TRUE(status.ok()) << status.error_message(); + SendRequest(&Agent::Stub::SetTaskParameters, request); } TEST_F(AgentServiceTest, SetCostWeights_Works) { diff --git a/grpc/estimator_server.cc b/grpc/estimator_server.cc index 98e834988..dd8e9b056 100644 --- a/grpc/estimator_server.cc +++ b/grpc/estimator_server.cc @@ -20,7 +20,6 @@ #include #include -#include #include #include @@ -33,7 +32,6 @@ ABSL_FLAG(int32_t, mjpc_port, 10000, "port to listen on"); int main(int argc, char** argv) { - absl::ParseCommandLine(argc, argv); absl::ParseCommandLine(argc, argv); int port = absl::GetFlag(FLAGS_mjpc_port); diff --git a/grpc/grpc_agent_util.cc b/grpc/grpc_agent_util.cc index af11f2e5d..286585eae 100644 --- a/grpc/grpc_agent_util.cc +++ b/grpc/grpc_agent_util.cc @@ -142,39 +142,71 @@ grpc::Status SetState(const SetStateRequest* request, mjpc::Agent* agent, #undef CHECK_SIZE +namespace { +// TODO(nimrod): make planner a const reference +std::vector AverageAction(mjpc::Planner& planner, const mjModel* model, + bool nominal_action, mjData* rollout_data, + mjpc::State* rollout_state, double time, + double averaging_duration) { + int nu = model->nu; + std::vector ret(nu, 0); + int nactions = 0; + double end_time = time + averaging_duration; + + if (nominal_action) { + std::vector action(nu, 0); + while (time < end_time) { + planner.ActionFromPolicy(action.data(), /*state=*/nullptr, time); + mju_addTo(ret.data(), action.data(), nu); + time += model->opt.timestep; + nactions++; + } + } else { + rollout_data->time = time; + while (rollout_data->time <= end_time) { + rollout_state->Set(model, rollout_data); + const double* state = rollout_state->state().data(); + planner.ActionFromPolicy(rollout_data->ctrl, state, + rollout_data->time); + mju_addTo(ret.data(), rollout_data->ctrl, nu); + mj_step(model, rollout_data); + nactions++; + } + } + mju_scl(ret.data(), ret.data(), 1.0 / nactions, nu); + return ret; +} + +} // namespace grpc::Status GetAction(const GetActionRequest* request, const mjpc::Agent* agent, const mjModel* model, mjData* rollout_data, mjpc::State* rollout_state, GetActionResponse* response) { - int nu = agent->GetActionDim(); - std::vector ret = std::vector(nu, 0); - double time = request->has_time() ? request->time() : agent->state.time(); if (request->averaging_duration() > 0) { - agent->state.CopyTo(model, rollout_data); - rollout_data->time = time; - double end_time = time + request->averaging_duration(); - int nactions = 0; - while (rollout_data->time <= end_time) { + if (request->nominal_action()) { + rollout_data = nullptr; + rollout_state = nullptr; + } else { + agent->state.CopyTo(model, rollout_data); rollout_state->Set(model, rollout_data); - agent->ActivePlanner().ActionFromPolicy(rollout_data->ctrl, - rollout_state->state().data(), - rollout_data->time); - mju_addTo(ret.data(), rollout_data->ctrl, nu); - mj_step(model, rollout_data); - nactions++; } - mju_scl(ret.data(), ret.data(), 1.0 / nactions, nu); + std::vector ret = AverageAction(agent->ActivePlanner(), model, + request->nominal_action(), rollout_data, rollout_state, + time, request->averaging_duration()); + response->mutable_action()->Assign(ret.begin(), ret.end()); } else { - agent->ActivePlanner().ActionFromPolicy( - ret.data(), &agent->state.state()[0], time); + std::vector ret(model->nu, 0); + const double* state = request->nominal_action() + ? nullptr + : agent->state.state().data(); + agent->ActivePlanner().ActionFromPolicy(ret.data(), state, time); + response->mutable_action()->Assign(ret.begin(), ret.end()); } - response->mutable_action()->Assign(ret.begin(), ret.end()); - return grpc::Status::OK; } @@ -187,7 +219,7 @@ grpc::Status GetCostValuesAndWeights( std::vector residuals(task->num_residual, 0); // scratch space double terms[mjpc::kMaxCostTerms]; task->Residual(model, data, residuals.data()); - task->CostTerms(terms, residuals.data(), /*weighted=*/false); + task->UnweightedCostTerms(terms, residuals.data()); for (int i = 0; i < task->num_term; i++) { CHECK_EQ(agent_model->sensor_type[i], mjSENS_USER); std::string_view sensor_name(agent_model->names + diff --git a/mjpc/interface.cc b/mjpc/interface.cc index 69f8142ce..e8b9b0197 100644 --- a/mjpc/interface.cc +++ b/mjpc/interface.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "third_party/mujoco_mpc/mjpc/interface.h" +#include "mjpc/interface.h" #include #include diff --git a/mjpc/planners/gradient/policy.h b/mjpc/planners/gradient/policy.h index 11413529b..390169b38 100644 --- a/mjpc/planners/gradient/policy.h +++ b/mjpc/planners/gradient/policy.h @@ -40,6 +40,7 @@ class GradientPolicy : public Policy { void Reset(int horizon) override; // compute action from policy + // state is not used void Action(double* action, const double* state, double time) const override; // copy policy diff --git a/mjpc/planners/ilqg/planner.h b/mjpc/planners/ilqg/planner.h index 92a48d85d..d60a6f3ec 100644 --- a/mjpc/planners/ilqg/planner.h +++ b/mjpc/planners/ilqg/planner.h @@ -53,6 +53,7 @@ class iLQGPlanner : public Planner { void NominalTrajectory(int horizon, ThreadPool& pool) override; // set action from policy + // if state == nullptr, return the nominal action without a feedback term void ActionFromPolicy(double* action, const double* state, double time, bool use_previous = false) override; diff --git a/mjpc/planners/ilqg/policy.cc b/mjpc/planners/ilqg/policy.cc index b18b9ee9d..facc5161b 100644 --- a/mjpc/planners/ilqg/policy.cc +++ b/mjpc/planners/ilqg/policy.cc @@ -96,47 +96,56 @@ void iLQGPolicy::Action(double* action, const double* state, ZeroInterpolation(action, time, trajectory.times, trajectory.actions.data(), model->nu, trajectory.horizon - 1); - // state reference - ZeroInterpolation(state_interp.data(), time, trajectory.times, - trajectory.states.data(), dim_state, trajectory.horizon); - - // gains - ZeroInterpolation(feedback_gain_scratch.data(), time, trajectory.times, - feedback_gain.data(), dim_action * dim_state_derivative, - trajectory.horizon - 1); + if (state) { + // state reference + ZeroInterpolation(state_interp.data(), time, trajectory.times, + trajectory.states.data(), dim_state, + trajectory.horizon); + + // gains + ZeroInterpolation(feedback_gain_scratch.data(), time, trajectory.times, + feedback_gain.data(), dim_action * dim_state_derivative, + trajectory.horizon - 1); + } } else if (representation == 1) { // action LinearInterpolation(action, time, trajectory.times, trajectory.actions.data(), model->nu, trajectory.horizon - 1); - // state - LinearInterpolation(state_interp.data(), time, trajectory.times, - trajectory.states.data(), dim_state, - trajectory.horizon); + if (state) { + // state + LinearInterpolation(state_interp.data(), time, trajectory.times, + trajectory.states.data(), dim_state, + trajectory.horizon); - // normalize quaternions - mj_normalizeQuat(model, state_interp.data()); + // normalize quaternions + mj_normalizeQuat(model, state_interp.data()); - LinearInterpolation(feedback_gain_scratch.data(), time, trajectory.times, - feedback_gain.data(), dim_action * dim_state_derivative, - trajectory.horizon - 1); + LinearInterpolation(feedback_gain_scratch.data(), time, trajectory.times, + feedback_gain.data(), + dim_action * dim_state_derivative, + trajectory.horizon - 1); + } } else if (representation == 2) { // action CubicInterpolation(action, time, trajectory.times, trajectory.actions.data(), model->nu, trajectory.horizon - 1); - // state - CubicInterpolation(state_interp.data(), time, trajectory.times, - trajectory.states.data(), dim_state, trajectory.horizon); + if (state) { + // state + CubicInterpolation(state_interp.data(), time, trajectory.times, + trajectory.states.data(), dim_state, + trajectory.horizon); - // normalize quaternions - mj_normalizeQuat(model, state_interp.data()); + // normalize quaternions + mj_normalizeQuat(model, state_interp.data()); - CubicInterpolation(feedback_gain_scratch.data(), time, trajectory.times, - feedback_gain.data(), dim_action * dim_state_derivative, - trajectory.horizon - 1); + CubicInterpolation(feedback_gain_scratch.data(), time, trajectory.times, + feedback_gain.data(), dim_action * dim_state_derivative, + trajectory.horizon - 1); + } } // add feedback diff --git a/mjpc/planners/ilqg/policy.h b/mjpc/planners/ilqg/policy.h index e6260c6d7..942e1dfad 100644 --- a/mjpc/planners/ilqg/policy.h +++ b/mjpc/planners/ilqg/policy.h @@ -39,6 +39,7 @@ class iLQGPolicy : public Policy { void Reset(int horizon) override; // set action from policy + // if state == nullptr, return the nominal action without a feedback term void Action(double* action, const double* state, double time) const override; // copy policy diff --git a/mjpc/planners/ilqs/planner.cc b/mjpc/planners/ilqs/planner.cc index c47670967..01b9dcf98 100644 --- a/mjpc/planners/ilqs/planner.cc +++ b/mjpc/planners/ilqs/planner.cc @@ -14,15 +14,11 @@ #include "mjpc/planners/ilqs/planner.h" -#include #include -#include -#include #include #include "mjpc/array_safety.h" #include "mjpc/planners/ilqg/planner.h" -#include "mjpc/planners/planner.h" #include "mjpc/planners/sampling/planner.h" #include "mjpc/states/state.h" #include "mjpc/trajectory.h" diff --git a/mjpc/planners/policy.h b/mjpc/planners/policy.h index a1d06c83d..887358ff8 100644 --- a/mjpc/planners/policy.h +++ b/mjpc/planners/policy.h @@ -41,6 +41,8 @@ class Policy { virtual void Reset(int horizon) = 0; // set action from policy + // for policies that have a feedback term, passing nullptr for state turns + // the feedback term off and returns the nominal action for that time virtual void Action(double* action, const double* state, double time) const = 0; }; diff --git a/mjpc/task.cc b/mjpc/task.cc index a01980934..c48f47bb3 100644 --- a/mjpc/task.cc +++ b/mjpc/task.cc @@ -23,8 +23,6 @@ namespace mjpc { -using ForwardingResidualFn = internal::ForwardingResidualFn; - namespace { void MissingParameterError(const mjModel* m, int sensorid) { mju_error( @@ -34,9 +32,6 @@ void MissingParameterError(const mjModel* m, int sensorid) { } } // namespace - -Task::Task() : default_residual_(this) {} - // called at: construction, load, and GUI reset void Task::Reset(const mjModel* model) { // ----- defaults ----- // @@ -221,55 +216,6 @@ void BaseResidualFn::Update() { parameters_ = task_->parameters; } -// default implementation calls down to Task::Residual, for backwards compat. -// this is not thread safe, but it's what most existing tasks do. -void ForwardingResidualFn::Residual(const mjModel* model, const mjData* data, - double* residual) const { - task_->Residual(model, data, residual); -} - -// compute weighted cost terms -void ForwardingResidualFn::CostTerms(double* terms, const double* residual, - bool weighted) const { - int f_shift = 0; - int p_shift = 0; - for (int k = 0; k < task_->num_term; k++) { - // running cost - terms[k] = (weighted ? task_->weight[k] : 1) * - Norm(nullptr, nullptr, residual + f_shift, - DataAt(task_->norm_parameter, p_shift), - task_->dim_norm_residual[k], task_->norm[k]); - - // shift residual - f_shift += task_->dim_norm_residual[k]; - - // shift parameters - p_shift += task_->num_norm_parameter[k]; - } -} - -// compute weighted cost from terms -double ForwardingResidualFn::CostValue(const double* residual) const { - // cost terms - double terms[kMaxCostTerms]; - - // evaluate - this->CostTerms(terms, residual, /*weighted=*/true); - - // summation of cost terms - double cost = 0.0; - for (int i = 0; i < task_->num_term; i++) { - cost += terms[i]; - } - - // exponential risk transformation - if (mju_abs(task_->risk) < kRiskNeutralTolerance) { - return cost; - } else { - return (mju_exp(task_->risk * cost) - 1.0) / task_->risk; - } -} - std::unique_ptr ThreadSafeTask::Residual() const { std::lock_guard lock(mutex_); return ResidualLocked(); @@ -288,7 +234,7 @@ void ThreadSafeTask::UpdateResidual() { void ThreadSafeTask::Transition(mjModel* model, mjData* data) { std::lock_guard lock(mutex_); - TransitionLocked(model, data, &mutex_); + TransitionLocked(model, data); InternalResidual()->Update(); } @@ -299,10 +245,15 @@ void ThreadSafeTask::Reset(const mjModel* model) { InternalResidual()->Update(); } -void ThreadSafeTask::CostTerms(double* terms, const double* residual, - bool weighted) const { +void ThreadSafeTask::CostTerms(double* terms, const double* residual) const { + std::lock_guard lock(mutex_); + return InternalResidual()->CostTerms(terms, residual, /*weighted=*/true); +} + +void ThreadSafeTask::UnweightedCostTerms(double* terms, + const double* residual) const { std::lock_guard lock(mutex_); - return InternalResidual()->CostTerms(terms, residual, weighted); + return InternalResidual()->CostTerms(terms, residual, /*weighted=*/false); } double ThreadSafeTask::CostValue(const double* residual) const { diff --git a/mjpc/task.h b/mjpc/task.h index 3bf104ab6..81117fc6e 100644 --- a/mjpc/task.h +++ b/mjpc/task.h @@ -41,7 +41,7 @@ class ResidualFn { virtual void Residual(const mjModel* model, const mjData* data, double* residual) const = 0; virtual void CostTerms(double* terms, const double* residual, - bool weighted = true) const = 0; + bool weighted) const = 0; virtual double CostValue(const double* residual) const = 0; // copies weights and parameters from the Task instance. This should be @@ -56,7 +56,7 @@ class BaseResidualFn : public ResidualFn { virtual ~BaseResidualFn() = default; void CostTerms(double* terms, const double* residual, - bool weighted = true) const override; + bool weighted) const override; double CostValue(const double* residual) const override; void Update() override; @@ -74,43 +74,20 @@ class BaseResidualFn : public ResidualFn { const Task* task_; }; -namespace internal { -// a ResidualFn which simply uses weights from the Task instance, for backwards -// compatibility. -// this isn't thread safe, because weights and parameters in the task can change -// at any time. -class ForwardingResidualFn : public ResidualFn { - public: - explicit ForwardingResidualFn(const Task* task) : task_(task) {} - virtual ~ForwardingResidualFn() = default; - - void Residual(const mjModel* model, const mjData* data, - double* residual) const override; - - void CostTerms(double* terms, const double* residual, - bool weighted = true) const override; - double CostValue(const double* residual) const override; - void Update() override {} - - private: - const Task* task_; -}; -} // namespace internal - +// interface for classes that implement MJPC task specifications +// +// NOTE: Rather than deriving from this class, derive from ThreadSafeTask +// TODO(nimrod): Rename ThreadSafeTask and clean up by assuming it's the only +// implementation class Task { public: // constructor - Task(); + Task() = default; virtual ~Task() = default; // ----- methods ----- // // returns an object which can compute the residual function. - // the default implementation delegates to - // Residual(mjModel*, mjData*, double*), for backwards compability, but - // new implementations should return a custom ResidualFn object. - virtual std::unique_ptr Residual() const { - return std::make_unique(this); - } + virtual std::unique_ptr Residual() const = 0; // should be overridden by subclasses to use internal ResidualFn virtual void Residual(const mjModel* model, const mjData* data, @@ -123,7 +100,7 @@ class Task { // Changes to data will affect the planner at the next set_state. Changes to // model will only affect the physics and render threads, and will not affect // the planner. This is useful for studying planning under model discrepancy, - virtual void Transition(mjModel* model, mjData* data) {} + virtual void Transition(mjModel* model, mjData* data) = 0; // get information from model virtual void Reset(const mjModel* model); @@ -132,15 +109,12 @@ class Task { mjvScene* scene) const {} // compute cost terms - virtual void CostTerms(double* terms, const double* residual, - bool weighted = true) const { - return default_residual_.CostTerms(terms, residual, weighted); - } + virtual void CostTerms(double* terms, const double* residual) const = 0; + virtual void UnweightedCostTerms(double* terms, + const double* residual) const = 0; // compute weighted cost - virtual double CostValue(const double* residual) const { - return default_residual_.CostValue(residual); - } + virtual double CostValue(const double* residual) const = 0; virtual std::string Name() const = 0; virtual std::string XmlPath() const = 0; @@ -169,14 +143,11 @@ class Task { private: // initial residual parameters from model void SetFeatureParameters(const mjModel* model); - internal::ForwardingResidualFn default_residual_; }; // A version of Task which provides a Residual that can be run independently // of the class, and where the parameters and weights used in the residual // computations are guarded with a lock. -// TODO(nimrod): Migrate all tasks to this API, and deprecate the -// not-thread-safe Task. class ThreadSafeTask : public Task { public: virtual ~ThreadSafeTask() override = default; @@ -201,8 +172,11 @@ class ThreadSafeTask : public Task { // calls CostTerms on the pointer returned from InternalResidual(), while // holding a lock - void CostTerms(double* terms, const double* residual, - bool weighted = true) const final; + void CostTerms(double* terms, const double* residual) const final; + + // calls CostTerms on the pointer returned from InternalResidual(), while + // holding a lock + void UnweightedCostTerms(double* terms, const double* residual) const final; // calls CostValue on the pointer returned from InternalResidual(), while // holding a lock @@ -221,11 +195,9 @@ class ThreadSafeTask : public Task { // implementation of Task::Transition() which can assume a lock is held. // in some cases the transition logic requires calling mj_forward (e.g., for // measuring contact forces), which will call the sensor callback, which calls - // ResidualLocked. In order to avoid such resource contention, we give the - // user the ability to temporarily unlock the mutex, but it must be locked - // again before returning. - virtual void TransitionLocked(mjModel* model, mjData* data, - std::mutex* mutex) {} + // ResidualLocked. In order to avoid such resource contention, mutex_ might be + // temporarily unlocked, but it must be locked again before returning. + virtual void TransitionLocked(mjModel* model, mjData* data) {} // implementation of Task::Reset() which can assume a lock is held virtual void ResetLocked(const mjModel* model) {} // mutex which should be held on changes to InternalResidual. diff --git a/mjpc/tasks/hand/hand.cc b/mjpc/tasks/hand/hand.cc index bfa37a775..dc257ecba 100644 --- a/mjpc/tasks/hand/hand.cc +++ b/mjpc/tasks/hand/hand.cc @@ -79,8 +79,7 @@ void Hand::ResidualFn::Residual(const mjModel* model, const mjData* data, // If cube is within tolerance or floor -> // reset cube into hand. // ----------------------------------------------- -void Hand::TransitionLocked(mjModel* model, mjData* data, - std::mutex* mutex) { +void Hand::TransitionLocked(mjModel* model, mjData* data) { // find cube and floor int cube = mj_name2id(model, mjOBJ_GEOM, "cube"); int floor = mj_name2id(model, mjOBJ_GEOM, "floor"); @@ -105,9 +104,9 @@ void Hand::TransitionLocked(mjModel* model, mjData* data, mju_copy(data->qpos + jnt_qposadr, model->qpos0 + jnt_qposadr, 7); mju_zero(data->qvel + jnt_veladr, 6); } - mutex->unlock(); // step calls sensor that calls Residual. + mutex_.unlock(); // step calls sensor that calls Residual. mj_forward(model, data); // mj_step1 would suffice, we just need contact - mutex->lock(); + mutex_.lock(); } } diff --git a/mjpc/tasks/hand/hand.h b/mjpc/tasks/hand/hand.h index 4e0d060c4..ad113113b 100644 --- a/mjpc/tasks/hand/hand.h +++ b/mjpc/tasks/hand/hand.h @@ -15,7 +15,6 @@ #ifndef MJPC_TASKS_HAND_HAND_H_ #define MJPC_TASKS_HAND_HAND_H_ -#include #include #include #include "mjpc/task.h" @@ -46,8 +45,7 @@ class Hand : public ThreadSafeTask { // If cube is within tolerance or floor -> // reset cube into hand. // ----------------------------------------------- - void TransitionLocked(mjModel* model, mjData* data, - std::mutex* mutex) override; + void TransitionLocked(mjModel* model, mjData* data) override; protected: std::unique_ptr ResidualLocked() const override { diff --git a/mjpc/tasks/humanoid/tracking/tracking.cc b/mjpc/tasks/humanoid/tracking/tracking.cc index f63626605..539b67b4d 100644 --- a/mjpc/tasks/humanoid/tracking/tracking.cc +++ b/mjpc/tasks/humanoid/tracking/tracking.cc @@ -222,7 +222,7 @@ void Tracking::ResidualFn::Residual(const mjModel *model, const mjData *data, // Linearly interpolate between two consecutive key frames in order to // smooth the transitions between keyframes. // ---------------------------------------------------------------------------- -void Tracking::TransitionLocked(mjModel *model, mjData *d, std::mutex *mutex) { +void Tracking::TransitionLocked(mjModel *model, mjData *d) { // get motion start index int start = MotionStartIndex(mode); // get motion trajectory length diff --git a/mjpc/tasks/humanoid/tracking/tracking.h b/mjpc/tasks/humanoid/tracking/tracking.h index 0de873d71..08c02081d 100644 --- a/mjpc/tasks/humanoid/tracking/tracking.h +++ b/mjpc/tasks/humanoid/tracking/tracking.h @@ -56,8 +56,7 @@ class Tracking : public ThreadSafeTask { // Linearly interpolate between two consecutive key frames in order to // smooth the transitions between keyframes. // --------------------------------------------------------------------------- - void TransitionLocked(mjModel* model, mjData* data, - std::mutex* mutex) override; + void TransitionLocked(mjModel* model, mjData* data) override; std::string Name() const override; std::string XmlPath() const override; diff --git a/mjpc/tasks/manipulation/manipulation.cc b/mjpc/tasks/manipulation/manipulation.cc index 40fc5bfc6..ef128b2ad 100644 --- a/mjpc/tasks/manipulation/manipulation.cc +++ b/mjpc/tasks/manipulation/manipulation.cc @@ -63,8 +63,7 @@ void manipulation::Bring::ResidualFn::Residual(const mjModel* model, CheckSensorDim(model, counter); } -void manipulation::Bring::TransitionLocked(mjModel* model, mjData* data, - std::mutex* mutex) { +void manipulation::Bring::TransitionLocked(mjModel* model, mjData* data) { double residuals[100]; double terms[10]; residual_.Residual(model, data, residuals); diff --git a/mjpc/tasks/manipulation/manipulation.h b/mjpc/tasks/manipulation/manipulation.h index cd6a1a902..8db47fa27 100644 --- a/mjpc/tasks/manipulation/manipulation.h +++ b/mjpc/tasks/manipulation/manipulation.h @@ -38,8 +38,7 @@ class Bring : public ThreadSafeTask { }; Bring() : residual_(this, ModelValues()) {} - void TransitionLocked(mjModel* model, mjData* data, - std::mutex* mutex) override; + void TransitionLocked(mjModel* model, mjData* data) override; void ResetLocked(const mjModel* model) override; protected: diff --git a/mjpc/tasks/panda/panda.cc b/mjpc/tasks/panda/panda.cc index 15e74b321..fed90e5a4 100644 --- a/mjpc/tasks/panda/panda.cc +++ b/mjpc/tasks/panda/panda.cc @@ -71,12 +71,9 @@ void Panda::ResidualFn::Residual(const mjModel* model, const mjData* data, } } -void Panda::TransitionLocked(mjModel* model, mjData* data, - std::mutex* mutex) { +void Panda::TransitionLocked(mjModel* model, mjData* data) { double residuals[100]; - double terms[10]; residual_.Residual(model, data, residuals); - residual_.CostTerms(terms, residuals); double bring_dist = (mju_norm3(residuals+3) + mju_norm3(residuals+6)) / 2; // reset: diff --git a/mjpc/tasks/panda/panda.h b/mjpc/tasks/panda/panda.h index 01555c499..7d3ff4290 100644 --- a/mjpc/tasks/panda/panda.h +++ b/mjpc/tasks/panda/panda.h @@ -31,8 +31,7 @@ class Panda : public ThreadSafeTask { double* residual) const override; }; Panda() : residual_(this) {} - void TransitionLocked(mjModel* model, mjData* data, - std::mutex* mutex) override; + void TransitionLocked(mjModel* model, mjData* data) override; protected: std::unique_ptr ResidualLocked() const override { diff --git a/mjpc/tasks/particle/particle.cc b/mjpc/tasks/particle/particle.cc index 692424cfe..e1651ac3d 100644 --- a/mjpc/tasks/particle/particle.cc +++ b/mjpc/tasks/particle/particle.cc @@ -49,8 +49,7 @@ void Particle::ResidualFn::Residual(const mjModel* model, const mjData* data, mju_copy(residual + 4, data->ctrl, model->nu); } -void Particle::TransitionLocked(mjModel* model, mjData* data, - std::mutex* mutex) { +void Particle::TransitionLocked(mjModel* model, mjData* data) { // some Lissajous curve double goal[2]{0.25 * mju_sin(data->time), 0.25 * mju_cos(data->time / mjPI)}; diff --git a/mjpc/tasks/particle/particle.h b/mjpc/tasks/particle/particle.h index 19c36f0e7..4851d53f8 100644 --- a/mjpc/tasks/particle/particle.h +++ b/mjpc/tasks/particle/particle.h @@ -37,8 +37,7 @@ class Particle : public ThreadSafeTask { double* residual) const override; }; Particle() : residual_(this) {} - void TransitionLocked(mjModel* model, mjData* data, - std::mutex* mutex) override; + void TransitionLocked(mjModel* model, mjData* data) override; protected: std::unique_ptr ResidualLocked() const override { diff --git a/mjpc/tasks/quadrotor/quadrotor.cc b/mjpc/tasks/quadrotor/quadrotor.cc index 9d9e5515c..7aa530d15 100644 --- a/mjpc/tasks/quadrotor/quadrotor.cc +++ b/mjpc/tasks/quadrotor/quadrotor.cc @@ -64,8 +64,7 @@ void Quadrotor::ResidualFn::Residual(const mjModel* model, const mjData* data, } // ----- Transition for quadrotor task ----- -void Quadrotor::TransitionLocked(mjModel* model, mjData* data, - std::mutex* mutex) { +void Quadrotor::TransitionLocked(mjModel* model, mjData* data) { // set mode to GUI selection if (mode > 0) { current_mode_ = mode - 1; diff --git a/mjpc/tasks/quadrotor/quadrotor.h b/mjpc/tasks/quadrotor/quadrotor.h index 472ee86fb..cea28fc34 100644 --- a/mjpc/tasks/quadrotor/quadrotor.h +++ b/mjpc/tasks/quadrotor/quadrotor.h @@ -42,8 +42,7 @@ class Quadrotor : public ThreadSafeTask { }; Quadrotor() : residual_(this) {} - void TransitionLocked(mjModel* model, mjData* data, - std::mutex* mutex) override; + void TransitionLocked(mjModel* model, mjData* data) override; protected: std::unique_ptr ResidualLocked() const override { diff --git a/mjpc/tasks/quadruped/quadruped.cc b/mjpc/tasks/quadruped/quadruped.cc index 6b6f8663a..56ce76f58 100644 --- a/mjpc/tasks/quadruped/quadruped.cc +++ b/mjpc/tasks/quadruped/quadruped.cc @@ -221,8 +221,7 @@ void QuadrupedFlat::ResidualFn::Residual(const mjModel* model, } // ============ transition ============ -void QuadrupedFlat::TransitionLocked(mjModel* model, mjData* data, - std::mutex* mutex) { +void QuadrupedFlat::TransitionLocked(mjModel* model, mjData* data) { // ---------- handle mjData reset ---------- if (data->time < residual_.last_transition_time_ || residual_.last_transition_time_ == -1) { @@ -777,8 +776,7 @@ void QuadrupedHill::ResidualFn::Residual(const mjModel* model, // If quadruped is within tolerance of goal -> // set goal to next from keyframes. // ----------------------------------------------- -void QuadrupedHill::TransitionLocked(mjModel* model, mjData* data, - std::mutex* mutex) { +void QuadrupedHill::TransitionLocked(mjModel* model, mjData* data) { // set mode to GUI selection if (mode > 0) { residual_.current_mode_ = mode - 1; diff --git a/mjpc/tasks/quadruped/quadruped.h b/mjpc/tasks/quadruped/quadruped.h index 4b660e403..5e3f5c617 100644 --- a/mjpc/tasks/quadruped/quadruped.h +++ b/mjpc/tasks/quadruped/quadruped.h @@ -227,8 +227,7 @@ class QuadrupedFlat : public ThreadSafeTask { }; QuadrupedFlat() : residual_(this) {} - void TransitionLocked(mjModel* model, mjData* data, - std::mutex* mutex) override; + void TransitionLocked(mjModel* model, mjData* data) override; // call base-class Reset, save task-related ids void ResetLocked(const mjModel* model) override; @@ -274,8 +273,7 @@ class QuadrupedHill : public ThreadSafeTask { int current_mode_; }; QuadrupedHill() : residual_(this) {} - void TransitionLocked(mjModel* model, mjData* data, - std::mutex* mutex) override; + void TransitionLocked(mjModel* model, mjData* data) override; protected: std::unique_ptr ResidualLocked() const override { diff --git a/mjpc/tasks/swimmer/swimmer.cc b/mjpc/tasks/swimmer/swimmer.cc index 015af3f48..59350dfa7 100644 --- a/mjpc/tasks/swimmer/swimmer.cc +++ b/mjpc/tasks/swimmer/swimmer.cc @@ -48,8 +48,7 @@ void Swimmer::ResidualFn::Residual(const mjModel* model, const mjData* data, // If swimmer is within tolerance of goal -> // move goal randomly. // --------------------------------------------- -void Swimmer::TransitionLocked(mjModel* model, mjData* data, - std::mutex* mutex) { +void Swimmer::TransitionLocked(mjModel* model, mjData* data) { double* target = SensorByName(model, data, "target"); double* nose = SensorByName(model, data, "nose"); double nose_to_target[2]; diff --git a/mjpc/tasks/swimmer/swimmer.h b/mjpc/tasks/swimmer/swimmer.h index 25b8fcf44..490da794f 100644 --- a/mjpc/tasks/swimmer/swimmer.h +++ b/mjpc/tasks/swimmer/swimmer.h @@ -15,7 +15,6 @@ #ifndef MJPC_TASKS_SWIMMER_SWIMMER_H_ #define MJPC_TASKS_SWIMMER_SWIMMER_H_ -#include #include #include #include "mjpc/task.h" @@ -43,8 +42,7 @@ class Swimmer : public ThreadSafeTask { // If swimmer is within tolerance of goal -> // move goal randomly. // --------------------------------------------- - void TransitionLocked(mjModel* model, mjData* data, - std::mutex* mutex) override; + void TransitionLocked(mjModel* model, mjData* data) override; protected: std::unique_ptr ResidualLocked() const override { diff --git a/mjpc/test/agent/rollout_test.cc b/mjpc/test/agent/rollout_test.cc index 2f09c6cc1..2dc613374 100644 --- a/mjpc/test/agent/rollout_test.cc +++ b/mjpc/test/agent/rollout_test.cc @@ -22,16 +22,34 @@ namespace mjpc { namespace { -struct ParticleCopyTestTask : public mjpc::Task { +class ParticleCopyTestTask : public mjpc::ThreadSafeTask { + public: + ParticleCopyTestTask() : residual_(this) {} + std::string Name() const override {return ""; } std::string XmlPath() const override { return ""; } - void Residual(const mjModel* model, const mjData* data, - double* residual) const override { - mju_copy(residual, data->qpos, model->nq); - mju_copy(residual + model->nq, data->qvel, model->nv); + + private: + class ResidualFn : public mjpc::BaseResidualFn { + public: + explicit ResidualFn(const ParticleCopyTestTask* task) + : mjpc::BaseResidualFn(task) {} + void Residual(const mjModel* model, const mjData* data, + double* residual) const override { + mju_copy(residual, data->qpos, model->nq); + mju_copy(residual + model->nq, data->qvel, model->nv); + } + }; + + std::unique_ptr ResidualLocked() const override { + return std::make_unique(residual_); } + ResidualFn* InternalResidual() override { return &residual_; } + + ResidualFn residual_; }; + ParticleCopyTestTask task; extern "C" { diff --git a/mjpc/test/tasks/task_test.cc b/mjpc/test/tasks/task_test.cc index 9c5aeff59..c0980bd38 100644 --- a/mjpc/test/tasks/task_test.cc +++ b/mjpc/test/tasks/task_test.cc @@ -21,10 +21,23 @@ namespace mjpc { namespace { -struct TestTask : public Task { - std::string Name() const override {return ""; } +class TestTask : public ThreadSafeTask { + public: + TestTask() : residual_(this) {} + std::string Name() const override { return ""; } std::string XmlPath() const override { return ""; } - void Residual(const mjModel*, const mjData*, double*) const override {}; + + class ResidualFn : public BaseResidualFn { + public: + ResidualFn(TestTask* task) : BaseResidualFn(task) {} + void Residual(const mjModel*, const mjData*, double*) const override {} + }; + + std::unique_ptr ResidualLocked() const override { + return std::make_unique(residual_); + } + ResidualFn* InternalResidual() override { return &residual_; } + ResidualFn residual_; }; // test task construction diff --git a/mjpc/test/testdata/particle_residual.h b/mjpc/test/testdata/particle_residual.h index 7b869d9b9..627fab124 100644 --- a/mjpc/test/testdata/particle_residual.h +++ b/mjpc/test/testdata/particle_residual.h @@ -19,19 +19,36 @@ #include "mjpc/task.h" #include -class ParticleTestTask : public mjpc::Task { +class ParticleTestTask : public mjpc::ThreadSafeTask { public: + ParticleTestTask() : residual_(this) {} + std::string Name() const override {return ""; } std::string XmlPath() const override { return ""; } - void Residual(const mjModel* model, const mjData* data, - double* residual) const override { - // goal position - mju_copy(residual, data->qpos, model->nq); - residual[0] -= data->mocap_pos[0]; - residual[1] -= data->mocap_pos[1]; - - // goal velocity error - mju_copy(residual + 2, data->qvel, model->nv); + + private: + class ResidualFn : public mjpc::BaseResidualFn { + public: + explicit ResidualFn(const ParticleTestTask* task) + : mjpc::BaseResidualFn(task) {} + void Residual(const mjModel* model, const mjData* data, + double* residual) const override { + // goal position + mju_copy(residual, data->qpos, model->nq); + residual[0] -= data->mocap_pos[0]; + residual[1] -= data->mocap_pos[1]; + + // goal velocity error + mju_copy(residual + 2, data->qvel, model->nv); + } + }; + + std::unique_ptr ResidualLocked() const override { + return std::make_unique(residual_); } + ResidualFn* InternalResidual() override { return &residual_; } + + ResidualFn residual_; }; + #endif // MJPC_TEST_TESTDATA_PARTICLE_RESIDUAL_H_ diff --git a/mjpc/utilities.cc b/mjpc/utilities.cc index 6bad96781..79e383825 100644 --- a/mjpc/utilities.cc +++ b/mjpc/utilities.cc @@ -571,15 +571,17 @@ void StateDiff(const mjModel* m, mjtNum* ds, const mjtNum* s1, const mjtNum* s2, } // return global height of nearest group 0 geom under given position -mjtNum Ground(const mjModel* model, const mjData* data, const mjtNum pos[3]) { - const mjtByte geomgroup[6] = {1, 0, 0, 0, 0, 0}; // only detect group 0 +mjtNum Ground(const mjModel* model, const mjData* data, const mjtNum pos[3], + const mjtByte* geomgroup) { mjtNum down[3] = {0, 0, -1}; // aim ray straight down const mjtNum height_offset = .5; // add some height in case of penetration const mjtByte flg_static = 1; // include static geoms const int bodyexclude = -1; // don't exclude any bodies int geomid; // id of intersecting geom mjtNum query[3] = {pos[0], pos[1], pos[2] + height_offset}; - mjtNum dist = mj_ray(model, data, query, down, geomgroup, flg_static, + const mjtByte default_geomgroup[6] = {1, 0, 0, 0, 0, 0}; + const mjtByte* query_geomgroup = geomgroup ? geomgroup : default_geomgroup; + mjtNum dist = mj_ray(model, data, query, down, query_geomgroup, flg_static, bodyexclude, &geomid); if (dist < 0) { // SHOULD NOT OCCUR diff --git a/mjpc/utilities.h b/mjpc/utilities.h index 244432987..31783ff13 100644 --- a/mjpc/utilities.h +++ b/mjpc/utilities.h @@ -157,8 +157,9 @@ void Diff(mjtNum* dx, const mjtNum* x1, const mjtNum* x2, mjtNum h, int n); void StateDiff(const mjModel* m, mjtNum* ds, const mjtNum* s1, const mjtNum* s2, mjtNum h); -// return global height of nearest group 0 geom under given position -mjtNum Ground(const mjModel* model, const mjData* data, const mjtNum pos[3]); +// return global height of nearest geom in geomgroup under given position +mjtNum Ground(const mjModel* model, const mjData* data, const mjtNum pos[3], + const mjtByte* geomgroup = nullptr); // set x to be the point on the segment [p0 p1] that is nearest to x void ProjectToSegment(double x[3], const double p0[3], const double p1[3]); diff --git a/python/mujoco_mpc/agent.py b/python/mujoco_mpc/agent.py index 7b8c52226..36fa64583 100644 --- a/python/mujoco_mpc/agent.py +++ b/python/mujoco_mpc/agent.py @@ -193,7 +193,10 @@ def get_state(self) -> agent_pb2.State: return self.stub.GetState(agent_pb2.GetStateRequest()).state def get_action( - self, time: Optional[float] = None, averaging_duration: float = 0 + self, + time: Optional[float] = None, + averaging_duration: float = 0, + nominal_action: bool = False, ) -> np.ndarray: """Return latest `action` from the `Agent`'s planner. @@ -201,12 +204,15 @@ def get_action( time: `data.time`, i.e. the simulation time. averaging_duration: the duration over which actions should be averaged (e.g. the control timestep). + nominal_action: if True, don't apply feedback terms in the policy Returns: action: `Agent`'s planner's latest action. """ get_action_request = agent_pb2.GetActionRequest( - time=time, averaging_duration=averaging_duration + time=time, + averaging_duration=averaging_duration, + nominal_action=nominal_action, ) get_action_response = self.stub.GetAction(get_action_request) return np.array(get_action_response.action) diff --git a/python/mujoco_mpc/agent_test.py b/python/mujoco_mpc/agent_test.py index c6b99ecc6..96a21492d 100644 --- a/python/mujoco_mpc/agent_test.py +++ b/python/mujoco_mpc/agent_test.py @@ -15,6 +15,7 @@ from absl.testing import absltest +from absl.testing import parameterized import grpc import mujoco from mujoco_mpc import agent as agent_lib @@ -39,7 +40,7 @@ def environment_reset(model, data): return get_observation(model, data) -class AgentTest(absltest.TestCase): +class AgentTest(parameterized.TestCase): def test_set_task_parameters(self): model_path = ( @@ -84,7 +85,8 @@ def test_step_env_with_planner(self): self.assertFalse((observations == 0).all()) self.assertFalse((actions == 0).all()) - def test_action_averaging_doesnt_change_state(self): + @parameterized.parameters({"nominal": False}, {"nominal": True}) + def test_action_averaging_doesnt_change_state(self, nominal): # when calling get_action with action averaging, the Agent needs to roll # out physics, but the API should be implemented not to mutate the state model_path = ( @@ -108,7 +110,9 @@ def test_action_averaging_doesnt_change_state(self): mocap_quat=data.mocap_quat, userdata=data.userdata, ) - agent.get_action(averaging_duration=control_timestep) + agent.get_action( + averaging_duration=control_timestep, nominal_action=nominal + ) state_after = agent.get_state() self.assertEqual(data.time, state_after.time) np.testing.assert_allclose(data.qpos, state_after.qpos)