From 12ca16a64b376b20acc4a45a11bdc2edf8288446 Mon Sep 17 00:00:00 2001
From: Nimrod Gileadi
Date: Mon, 31 Jul 2023 08:06:26 -0700
Subject: [PATCH] Add an option to GetAction to return the average nominal
 action instead of rolling out.

For the sampling and gradient planners, this will make no difference.
For iLQG, this will ignore the feedback terms in the policy.

Also, remove some repetitive code from agent_service_test.cc

PiperOrigin-RevId: 552483307
Change-Id: Ib3f9420b3541a5f77a6abb7c577737169f3780d5
---
 grpc/agent.proto                |   5 +
 grpc/agent_service_test.cc      | 224 ++++++++++++++++++--------------
 grpc/grpc_agent_util.cc         |  70 +++++++---
 mjpc/planners/gradient/policy.h |   1 +
 mjpc/planners/ilqg/planner.h    |   1 +
 mjpc/planners/ilqg/policy.cc    |  59 +++++----
 mjpc/planners/ilqg/policy.h     |   1 +
 mjpc/planners/policy.h          |   2 +
 python/mujoco_mpc/agent.py      |  10 +-
 python/mujoco_mpc/agent_test.py |  10 +-
 10 files changed, 233 insertions(+), 150 deletions(-)

diff --git a/grpc/agent.proto b/grpc/agent.proto
index 514f415f7..4fb9f02d2 100644
--- a/grpc/agent.proto
+++ b/grpc/agent.proto
@@ -93,6 +93,11 @@ message GetActionRequest {
   // and actions will be averaged over that period.
   // During the rollout, the task's Transition will not be called.
   optional float averaging_duration = 2;
+
+  // For planners that use feedback terms (iLQG), if true, return the nominal
+  // action for the given time rather than applying feedback terms on the
+  // current state. For the sampling planner this has no effect.
+  optional bool nominal_action = 3;
 }
 
 message GetActionResponse {
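For illustration, a minimal client-side sketch of how the new field can be populated through the generated Python bindings. This is not part of the patch; the import path and the numeric values are assumptions, and only the field names come from the message definition above.

```python
# Hypothetical usage sketch: build a GetActionRequest that asks for the
# time-averaged *nominal* action. Import path is an assumption; field names
# follow the proto definition above.
from mujoco_mpc.proto import agent_pb2

request = agent_pb2.GetActionRequest(
    time=0.01,                # query the policy at this simulation time
    averaging_duration=0.02,  # average over one control timestep
    nominal_action=True,      # iLQG: skip the state-feedback term
)
# response = stub.GetAction(request)  # `stub` would be an Agent gRPC stub
```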
diff --git a/grpc/agent_service_test.cc b/grpc/agent_service_test.cc
index 1fd03f51f..f9d0cbbe0 100644
--- a/grpc/agent_service_test.cc
+++ b/grpc/agent_service_test.cc
@@ -22,6 +22,7 @@
 #include "testing/base/public/gmock.h"
 #include "testing/base/public/gunit.h"
 #include
+#include
 #include
 #include
 #include
@@ -31,6 +32,7 @@
 #include
 #include "grpc/agent.grpc.pb.h"
 #include "grpc/agent.pb.h"
+#include "third_party/mujoco_mpc/grpc/agent.proto.h"
 #include "mjpc/tasks/tasks.h"
 
 namespace agent_grpc {
@@ -52,8 +54,6 @@ class AgentServiceTest : public ::testing::Test {
   void TearDown() override { server->Shutdown(); }
 
   void RunAndCheckInit(std::string_view task_id, mjModel* model) {
-    grpc::ClientContext init_context;
-
     agent::InitRequest init_request;
     init_request.set_task_id(task_id);
@@ -65,11 +65,26 @@
       init_request.set_allocated_model(nullptr);
     }
 
-    agent::InitResponse init_response;
-    grpc::Status init_status =
-        stub->Init(&init_context, init_request, &init_response);
+    SendRequest(&Agent::Stub::Init, init_request);
+  }
+
+  // Sends a request, validates the status code and returns the response
+  template <class Req, class Res>
+  Res SendRequest(grpc::Status (Agent::Stub::*method)(grpc::ClientContext*,
+                                                      const Req&, Res*),
+                  const Req& request) {
+    grpc::ClientContext context;
+    Res response;
+    grpc::Status status = (stub.get()->*method)(&context, request, &response);
+    EXPECT_TRUE(status.ok()) << status.error_message();
+    return response;
+  }
 
-    EXPECT_TRUE(init_status.ok()) << init_status.error_message();
+  // an overload which constructs an empty request
+  template <class Req, class Res>
+  Res SendRequest(grpc::Status (Agent::Stub::*method)(grpc::ClientContext*,
+                                                      const Req&, Res*)) {
+    return SendRequest(method, Req());
   }
 
   std::unique_ptr<AgentService> agent_service;
@@ -88,17 +103,11 @@ TEST_F(AgentServiceTest, Init_WithModel) {
 
 TEST_F(AgentServiceTest, SetState_Works) {
   RunAndCheckInit("Cartpole", nullptr);
 
-  grpc::ClientContext set_state_context;
-
   agent::SetStateRequest set_state_request;
   agent::State* state = new agent::State();
   state->set_time(0.0);
   set_state_request.set_allocated_state(state);
 
-  agent::SetStateResponse set_state_response;
-  grpc::Status set_state_status = stub->SetState(
-      &set_state_context, set_state_request, &set_state_response);
-
-  EXPECT_TRUE(set_state_status.ok()) << set_state_status.error_message();
+  SendRequest(&Agent::Stub::SetState, set_state_request);
 }
 
 TEST_F(AgentServiceTest, SetState_WrongSize) {
@@ -124,56 +133,124 @@ TEST_F(AgentServiceTest, PlannerStep_ProducesNonzeroAction) {
   RunAndCheckInit("Cartpole", nullptr);
 
   {
-    grpc::ClientContext context;
     agent::SetTaskParametersRequest request;
     (*request.mutable_parameters())["Goal"].set_numeric(-1.0);
-    agent::SetTaskParametersResponse response;
-    grpc::Status status = stub->SetTaskParameters(&context, request, &response);
+    SendRequest(&Agent::Stub::SetTaskParameters, request);
+  }
 
-    EXPECT_TRUE(status.ok());
+  SendRequest(&Agent::Stub::PlannerStep);
+
+  {
+    agent::GetActionResponse response = SendRequest(&Agent::Stub::GetAction);
+
+    ASSERT_EQ(response.action().size(), 1);
+    EXPECT_TRUE(response.action()[0] != 0.0);
   }
+}
+
+TEST_F(AgentServiceTest, ActionAveragingGivesDifferentResult) {
+  RunAndCheckInit("Cartpole", nullptr);
 
   {
-    grpc::ClientContext context;
+    agent::SetTaskParametersRequest request;
+    (*request.mutable_parameters())["Goal"].set_numeric(-1.0);
+    SendRequest(&Agent::Stub::SetTaskParameters, request);
+  }
 
-    agent::PlannerStepRequest request;
-    agent::PlannerStepResponse response;
-    grpc::Status status = stub->PlannerStep(&context, request, &response);
+  SendRequest(&Agent::Stub::PlannerStep);
 
-    EXPECT_TRUE(status.ok()) << status.error_message();
+  double action_without_averaging;
+  {
+    agent::GetActionResponse response = SendRequest(&Agent::Stub::GetAction);
+    ASSERT_EQ(response.action().size(), 1);
+    EXPECT_TRUE(response.action()[0] != 0.0);
+    action_without_averaging = response.action()[0];
   }
 
+  double action_with_averaging;
   {
     grpc::ClientContext context;
     agent::GetActionRequest request;
-    agent::GetActionResponse response;
-    grpc::Status status = stub->GetAction(&context, request, &response);
-
-    EXPECT_TRUE(status.ok()) << status.error_message();
-    EXPECT_EQ(response.action().size(), 1);
+    request.set_averaging_duration(1.0);
+    agent::GetActionResponse response =
+        SendRequest(&Agent::Stub::GetAction, request);
+    ASSERT_EQ(response.action().size(), 1);
     EXPECT_TRUE(response.action()[0] != 0.0);
+    action_with_averaging = response.action()[0];
   }
+  EXPECT_NE(action_with_averaging, action_without_averaging);
 }
 
-TEST_F(AgentServiceTest, Step_AdvancesTime) {
-  RunAndCheckInit("Cartpole", nullptr);
+ RunAndCheckInit("Swimmer", nullptr); + + SendRequest(&Agent::Stub::PlannerStep); - agent::State initial_state; + double nominal_action1; { - grpc::ClientContext context; - agent::GetStateRequest request; - agent::GetStateResponse response; - EXPECT_TRUE(stub->GetState(&context, request, &response).ok()); - initial_state = response.state(); + agent::GetActionRequest request; + request.set_averaging_duration(1.0); + request.set_nominal_action(true); + request.set_time(0.01); + agent::GetActionResponse response = + SendRequest(&Agent::Stub::GetAction, request); + EXPECT_GE(response.action().size(), 1); + nominal_action1 = response.action()[0]; + EXPECT_NE(nominal_action1, 0.0); } + // Set a new state + { + agent::SetStateRequest request; + static constexpr int kSwimmerDofs = 8; + for (int i = 0; i < kSwimmerDofs; i++) { + request.mutable_state()->mutable_qpos()->Add(0.1); + } + SendRequest(&Agent::Stub::SetState, request); + } + + double nominal_action2; + { + agent::GetActionRequest request; + request.set_averaging_duration(1.0); + request.set_nominal_action(true); + request.set_time(0.01); + agent::GetActionResponse response = + SendRequest(&Agent::Stub::GetAction, request); + EXPECT_GE(response.action().size(), 1); + nominal_action2 = response.action()[0]; + } + + double feedback_action; + { + agent::GetActionRequest request; + request.set_averaging_duration(1.0); + request.set_nominal_action(false); + request.set_time(0.01); + agent::GetActionResponse response = + SendRequest(&Agent::Stub::GetAction, request); + EXPECT_GE(response.action().size(), 1); + feedback_action = response.action()[0]; + } + + EXPECT_EQ(nominal_action1, nominal_action2) + << "nominal action should be the same"; + EXPECT_NE(nominal_action1, feedback_action) + << "feedback action should be different from the nominal"; +} + +TEST_F(AgentServiceTest, Step_AdvancesTime) { + RunAndCheckInit("Cartpole", nullptr); + + agent::State initial_state = SendRequest(&Agent::Stub::GetState).state(); + { - grpc::ClientContext context; agent::SetTaskParametersRequest request; (*request.mutable_parameters())["Goal"].set_numeric(-1.0); - agent::SetTaskParametersResponse response; - EXPECT_TRUE(stub->SetTaskParameters(&context, request, &response).ok()); + SendRequest(&Agent::Stub::SetTaskParameters, request); } { @@ -185,33 +262,12 @@ TEST_F(AgentServiceTest, Step_AdvancesTime) { EXPECT_EQ(response.parameters().at("Goal").numeric(), -1.0); } - { - grpc::ClientContext context; - agent::PlannerStepRequest request; - agent::PlannerStepResponse response; - grpc::Status status = stub->PlannerStep(&context, request, &response); - - EXPECT_TRUE(status.ok()) << status.error_message(); - } - + SendRequest(&Agent::Stub::PlannerStep); for (int i = 0; i < 3; i++) { - grpc::ClientContext context; - agent::StepRequest request; - agent::StepResponse response; - grpc::Status status = stub->Step(&context, request, &response); - - EXPECT_TRUE(status.ok()) << status.error_message(); + SendRequest(&Agent::Stub::Step); } - agent::State final_state; - { - grpc::ClientContext context; - agent::GetStateRequest request; - agent::GetStateResponse response; - grpc::Status status = stub->GetState(&context, request, &response); - EXPECT_TRUE(status.ok()) << status.error_message(); - final_state = response.state(); - } + agent::State final_state = SendRequest(&Agent::Stub::GetState).state(); double cartpole_timestep = 0.001; EXPECT_DOUBLE_EQ(final_state.time() - initial_state.time(), 3 * cartpole_timestep); @@ -223,34 +279,11 @@ TEST_F(AgentServiceTest, 
diff --git a/grpc/grpc_agent_util.cc b/grpc/grpc_agent_util.cc
index 2d7734691..88350a367 100644
--- a/grpc/grpc_agent_util.cc
+++ b/grpc/grpc_agent_util.cc
@@ -142,39 +142,71 @@ grpc::Status SetState(const SetStateRequest* request, mjpc::Agent* agent,
 
 #undef CHECK_SIZE
 
+namespace {
+// TODO(nimrod): make planner a const reference
+std::vector<double> AverageAction(mjpc::Planner& planner, const mjModel* model,
+                                  bool nominal_action, mjData* rollout_data,
+                                  mjpc::State* rollout_state, double time,
+                                  double averaging_duration) {
+  int nu = model->nu;
+  std::vector<double> ret(nu, 0);
+  int nactions = 0;
+  double end_time = time + averaging_duration;
+
+  if (nominal_action) {
+    std::vector<double> action(nu, 0);
+    while (time < end_time) {
+      planner.ActionFromPolicy(action.data(), /*state=*/nullptr, time);
+      mju_addTo(ret.data(), action.data(), nu);
+      time += model->opt.timestep;
+      nactions++;
+    }
+  } else {
+    rollout_data->time = time;
+    while (rollout_data->time <= end_time) {
+      rollout_state->Set(model, rollout_data);
+      const double* state = rollout_state->state().data();
+      planner.ActionFromPolicy(rollout_data->ctrl, state,
+                               rollout_data->time);
+      mju_addTo(ret.data(), rollout_data->ctrl, nu);
+      mj_step(model, rollout_data);
+      nactions++;
+    }
+  }
+  mju_scl(ret.data(), ret.data(), 1.0 / nactions, nu);
+  return ret;
+}
+
+}  // namespace
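Conceptually, AverageAction averages the policy's output over the averaging window, either by pure time lookups (nominal) or by rolling the physics forward and querying the policy on each rolled-out state (feedback). A minimal sketch of that idea, not part of the patch, with `policy` and `step_physics` as hypothetical stand-ins for Planner::ActionFromPolicy and mj_step:

```python
# Conceptual sketch of the averaging performed by AverageAction above.
import numpy as np

def average_action(policy, step_physics, state, time, duration, timestep,
                   nominal_action=False):
  actions = []
  t = time
  while t < time + duration:
    if nominal_action:
      actions.append(policy(None, t))   # nominal: no state, no rollout
    else:
      actions.append(policy(state, t))  # feedback: query policy on the state
      state = step_physics(state)       # then advance the physics
    t += timestep
  return np.mean(actions, axis=0)
```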
 
 grpc::Status GetAction(const GetActionRequest* request,
                        const mjpc::Agent* agent, const mjModel* model,
                        mjData* rollout_data, mjpc::State* rollout_state,
                        GetActionResponse* response) {
-  int nu = agent->GetActionDim();
-  std::vector<double> ret = std::vector<double>(nu, 0);
-
   double time = request->has_time() ? request->time()
                                     : agent->ActiveState().time();
 
   if (request->averaging_duration() > 0) {
-    agent->ActiveState().CopyTo(model, rollout_data);
-    rollout_data->time = time;
-    double end_time = time + request->averaging_duration();
-    int nactions = 0;
-    while (rollout_data->time <= end_time) {
+    if (request->nominal_action()) {
+      rollout_data = nullptr;
+      rollout_state = nullptr;
+    } else {
+      agent->ActiveState().CopyTo(model, rollout_data);
       rollout_state->Set(model, rollout_data);
-      agent->ActivePlanner().ActionFromPolicy(rollout_data->ctrl,
-                                              rollout_state->state().data(),
-                                              rollout_data->time);
-      mju_addTo(ret.data(), rollout_data->ctrl, nu);
-      mj_step(model, rollout_data);
-      nactions++;
     }
-    mju_scl(ret.data(), ret.data(), 1.0 / nactions, nu);
+    std::vector<double> ret = AverageAction(agent->ActivePlanner(), model,
+        request->nominal_action(), rollout_data, rollout_state,
+        time, request->averaging_duration());
+    response->mutable_action()->Assign(ret.begin(), ret.end());
   } else {
-    agent->ActivePlanner().ActionFromPolicy(
-        ret.data(), &agent->ActiveState().state()[0], time);
+    std::vector<double> ret(model->nu, 0);
+    const double* state = request->nominal_action()
+                              ? nullptr
+                              : agent->ActiveState().state().data();
+    agent->ActivePlanner().ActionFromPolicy(ret.data(), state, time);
+    response->mutable_action()->Assign(ret.begin(), ret.end());
   }
-  response->mutable_action()->Assign(ret.begin(), ret.end());
-
   return grpc::Status::OK;
 }
diff --git a/mjpc/planners/gradient/policy.h b/mjpc/planners/gradient/policy.h
index 11413529b..390169b38 100644
--- a/mjpc/planners/gradient/policy.h
+++ b/mjpc/planners/gradient/policy.h
@@ -40,6 +40,7 @@ class GradientPolicy : public Policy {
   void Reset(int horizon) override;
 
   // compute action from policy
+  // state is not used
   void Action(double* action, const double* state, double time) const override;
 
   // copy policy
diff --git a/mjpc/planners/ilqg/planner.h b/mjpc/planners/ilqg/planner.h
index 92a48d85d..d60a6f3ec 100644
--- a/mjpc/planners/ilqg/planner.h
+++ b/mjpc/planners/ilqg/planner.h
@@ -53,6 +53,7 @@ class iLQGPlanner : public Planner {
   void NominalTrajectory(int horizon, ThreadPool& pool) override;
 
   // set action from policy
+  // if state == nullptr, return the nominal action without a feedback term
   void ActionFromPolicy(double* action, const double* state, double time,
                         bool use_previous = false) override;
diff --git a/mjpc/planners/ilqg/policy.cc b/mjpc/planners/ilqg/policy.cc
index b18b9ee9d..facc5161b 100644
--- a/mjpc/planners/ilqg/policy.cc
+++ b/mjpc/planners/ilqg/policy.cc
@@ -96,47 +96,56 @@ void iLQGPolicy::Action(double* action, const double* state,
     ZeroInterpolation(action, time, trajectory.times,
                       trajectory.actions.data(), model->nu,
                       trajectory.horizon - 1);
 
-    // state reference
-    ZeroInterpolation(state_interp.data(), time, trajectory.times,
-                      trajectory.states.data(), dim_state, trajectory.horizon);
-
-    // gains
-    ZeroInterpolation(feedback_gain_scratch.data(), time, trajectory.times,
-                      feedback_gain.data(), dim_action * dim_state_derivative,
-                      trajectory.horizon - 1);
+    if (state) {
+      // state reference
+      ZeroInterpolation(state_interp.data(), time, trajectory.times,
+                        trajectory.states.data(), dim_state,
+                        trajectory.horizon);
+
+      // gains
+      ZeroInterpolation(feedback_gain_scratch.data(), time, trajectory.times,
+                        feedback_gain.data(), dim_action * dim_state_derivative,
+                        trajectory.horizon - 1);
+    }
   } else if (representation == 1) {
     // action
     LinearInterpolation(action, time, trajectory.times,
                         trajectory.actions.data(), model->nu,
                         trajectory.horizon - 1);
 
-    // state
-    LinearInterpolation(state_interp.data(), time, trajectory.times,
-                        trajectory.states.data(), dim_state,
-                        trajectory.horizon);
+    if (state) {
+      // state
+      LinearInterpolation(state_interp.data(), time, trajectory.times,
+                          trajectory.states.data(), dim_state,
+                          trajectory.horizon);
 
-    // normalize quaternions
-    mj_normalizeQuat(model, state_interp.data());
+      // normalize quaternions
+      mj_normalizeQuat(model, state_interp.data());
 
-    LinearInterpolation(feedback_gain_scratch.data(), time, trajectory.times,
-                        feedback_gain.data(), dim_action * dim_state_derivative,
-                        trajectory.horizon - 1);
+      LinearInterpolation(feedback_gain_scratch.data(), time, trajectory.times,
+                          feedback_gain.data(),
+                          dim_action * dim_state_derivative,
+                          trajectory.horizon - 1);
+    }
   } else if (representation == 2) {
     // action
     CubicInterpolation(action, time, trajectory.times,
                        trajectory.actions.data(), model->nu,
                        trajectory.horizon - 1);
 
-    // state
-    CubicInterpolation(state_interp.data(), time, trajectory.times,
-                       trajectory.states.data(), dim_state, trajectory.horizon);
+    if (state) {
+      // state
+      CubicInterpolation(state_interp.data(), time, trajectory.times,
+                         trajectory.states.data(), dim_state,
+                         trajectory.horizon);
 
-    // normalize quaternions
-    mj_normalizeQuat(model, state_interp.data());
+      // normalize quaternions
+      mj_normalizeQuat(model, state_interp.data());
 
-    CubicInterpolation(feedback_gain_scratch.data(), time, trajectory.times,
-                       feedback_gain.data(), dim_action * dim_state_derivative,
-                       trajectory.horizon - 1);
+      CubicInterpolation(feedback_gain_scratch.data(), time, trajectory.times,
+                         feedback_gain.data(), dim_action * dim_state_derivative,
+                         trajectory.horizon - 1);
+    }
   }
 
   // add feedback
diff --git a/mjpc/planners/ilqg/policy.h b/mjpc/planners/ilqg/policy.h
index e6260c6d7..942e1dfad 100644
--- a/mjpc/planners/ilqg/policy.h
+++ b/mjpc/planners/ilqg/policy.h
@@ -39,6 +39,7 @@ class iLQGPolicy : public Policy {
   void Reset(int horizon) override;
 
   // set action from policy
+  // if state == nullptr, return the nominal action without a feedback term
   void Action(double* action, const double* state, double time) const override;
 
   // copy policy
diff --git a/mjpc/planners/policy.h b/mjpc/planners/policy.h
index a1d06c83d..887358ff8 100644
--- a/mjpc/planners/policy.h
+++ b/mjpc/planners/policy.h
@@ -41,6 +41,8 @@ class Policy {
   virtual void Reset(int horizon) = 0;
 
   // set action from policy
+  // for policies that have a feedback term, passing nullptr for state turns
+  // the feedback term off and returns the nominal action for that time
+  virtual void Action(double* action, const double* state,
                      double time) const = 0;
 };
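The nullptr-state convention works because the nominal action is a pure lookup of the planned trajectory by time; only the feedback term needs the interpolated state reference and gains. A small sketch of that lookup (not from the patch), using a zero-order hold like `representation == 0` above:

```python
# Illustrative zero-order-hold lookup of a nominal action by time.
import bisect

def nominal_action(times, actions, t):
  """times: sorted plan times; actions: one action per time; t: query time."""
  i = max(0, bisect.bisect_right(times, t) - 1)
  return actions[min(i, len(actions) - 1)]
```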
diff --git a/python/mujoco_mpc/agent.py b/python/mujoco_mpc/agent.py
index 7b8c52226..36fa64583 100644
--- a/python/mujoco_mpc/agent.py
+++ b/python/mujoco_mpc/agent.py
@@ -193,7 +193,10 @@ def get_state(self) -> agent_pb2.State:
     return self.stub.GetState(agent_pb2.GetStateRequest()).state
 
   def get_action(
-      self, time: Optional[float] = None, averaging_duration: float = 0
+      self,
+      time: Optional[float] = None,
+      averaging_duration: float = 0,
+      nominal_action: bool = False,
   ) -> np.ndarray:
     """Return latest `action` from the `Agent`'s planner.
 
@@ -201,12 +204,15 @@ def get_action(
       time: `data.time`, i.e. the simulation time.
       averaging_duration: the duration over which actions should be averaged
         (e.g. the control timestep).
+      nominal_action: if True, don't apply feedback terms in the policy
 
     Returns:
       action: `Agent`'s planner's latest action.
     """
     get_action_request = agent_pb2.GetActionRequest(
-        time=time, averaging_duration=averaging_duration
+        time=time,
+        averaging_duration=averaging_duration,
+        nominal_action=nominal_action,
     )
     get_action_response = self.stub.GetAction(get_action_request)
     return np.array(get_action_response.action)
diff --git a/python/mujoco_mpc/agent_test.py b/python/mujoco_mpc/agent_test.py
index c6b99ecc6..96a21492d 100644
--- a/python/mujoco_mpc/agent_test.py
+++ b/python/mujoco_mpc/agent_test.py
@@ -15,6 +15,7 @@
 
 from absl.testing import absltest
+from absl.testing import parameterized
 import grpc
 import mujoco
 from mujoco_mpc import agent as agent_lib
@@ -39,7 +40,7 @@ def environment_reset(model, data):
   return get_observation(model, data)
 
 
-class AgentTest(absltest.TestCase):
+class AgentTest(parameterized.TestCase):
 
   def test_set_task_parameters(self):
     model_path = (
@@ -84,7 +85,8 @@ def test_step_env_with_planner(self):
     self.assertFalse((observations == 0).all())
     self.assertFalse((actions == 0).all())
 
-  def test_action_averaging_doesnt_change_state(self):
+  @parameterized.parameters({"nominal": False}, {"nominal": True})
+  def test_action_averaging_doesnt_change_state(self, nominal):
     # when calling get_action with action averaging, the Agent needs to roll
     # out physics, but the API should be implemented not to mutate the state
     model_path = (
@@ -108,7 +110,9 @@ def test_action_averaging_doesnt_change_state(self):
         mocap_quat=data.mocap_quat,
         userdata=data.userdata,
     )
-    agent.get_action(averaging_duration=control_timestep)
+    agent.get_action(
+        averaging_duration=control_timestep, nominal_action=nominal
+    )
     state_after = agent.get_state()
     self.assertEqual(data.time, state_after.time)
     np.testing.assert_allclose(data.qpos, state_after.qpos)
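Finally, a hedged end-to-end usage sketch of the new Python keyword argument. This is not part of the patch; the Agent constructor arguments, context-manager usage, task name and durations are assumptions for illustration only.

```python
# Usage sketch: compare feedback and nominal actions through the Python client.
import numpy as np
from mujoco_mpc import agent as agent_lib

with agent_lib.Agent(task_id="Swimmer") as agent:  # assumed constructor usage
  agent.planner_step()
  u_feedback = agent.get_action(averaging_duration=0.01)
  u_nominal = agent.get_action(averaging_duration=0.01, nominal_action=True)
  # For iLQG the two can differ; for the sampling planner they agree.
  print(np.linalg.norm(u_feedback - u_nominal))
```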