diff --git a/grpc/agent.proto b/grpc/agent.proto
index 514f415f7..4fb9f02d2 100644
--- a/grpc/agent.proto
+++ b/grpc/agent.proto
@@ -93,6 +93,11 @@ message GetActionRequest {
   // and actions will be averaged over that period.
   // During the rollout, the task's Transition will not be called.
   optional float averaging_duration = 2;
+
+  // For planners that use feedback terms (iLQG), if true, return the nominal
+  // action for the given time rather than applying feedback terms on the
+  // current state. For the sampling planner this has no effect.
+  optional bool nominal_action = 3;
 }
 
 message GetActionResponse {
diff --git a/grpc/agent_service_test.cc b/grpc/agent_service_test.cc
index 1fd03f51f..f9d0cbbe0 100644
--- a/grpc/agent_service_test.cc
+++ b/grpc/agent_service_test.cc
@@ -22,6 +22,7 @@
 #include "testing/base/public/gmock.h"
 #include "testing/base/public/gunit.h"
 #include
+#include
 #include
 #include
 #include
@@ -31,6 +32,7 @@
 #include
 #include "grpc/agent.grpc.pb.h"
 #include "grpc/agent.pb.h"
+#include "third_party/mujoco_mpc/grpc/agent.proto.h"
 #include "mjpc/tasks/tasks.h"
 
 namespace agent_grpc {
@@ -52,8 +54,6 @@ class AgentServiceTest : public ::testing::Test {
   void TearDown() override { server->Shutdown(); }
 
   void RunAndCheckInit(std::string_view task_id, mjModel* model) {
-    grpc::ClientContext init_context;
-
     agent::InitRequest init_request;
     init_request.set_task_id(task_id);
 
@@ -65,11 +65,26 @@ class AgentServiceTest : public ::testing::Test {
       init_request.set_allocated_model(nullptr);
     }
 
-    agent::InitResponse init_response;
-    grpc::Status init_status =
-        stub->Init(&init_context, init_request, &init_response);
+    SendRequest(&Agent::Stub::Init, init_request);
+  }
+
+  // Sends a request, validates the status code and returns the response
+  template <class Req, class Res>
+  Res SendRequest(grpc::Status (Agent::Stub::*method)(grpc::ClientContext*,
+                                                      const Req&, Res*),
+                  const Req& request) {
+    grpc::ClientContext context;
+    Res response;
+    grpc::Status status = (stub.get()->*method)(&context, request, &response);
+    EXPECT_TRUE(status.ok()) << status.error_message();
+    return response;
+  }
 
-    EXPECT_TRUE(init_status.ok()) << init_status.error_message();
+  // an overload which constructs an empty request
+  template <class Req, class Res>
+  Res SendRequest(grpc::Status (Agent::Stub::*method)(grpc::ClientContext*,
+                                                      const Req&, Res*)) {
+    return SendRequest(method, Req());
   }
 
   std::unique_ptr<AgentService> agent_service;
@@ -88,17 +103,11 @@ TEST_F(AgentServiceTest, Init_WithModel) {
 TEST_F(AgentServiceTest, SetState_Works) {
   RunAndCheckInit("Cartpole", nullptr);
 
-  grpc::ClientContext set_state_context;
-
   agent::SetStateRequest set_state_request;
   agent::State* state = new agent::State();
   state->set_time(0.0);
   set_state_request.set_allocated_state(state);
-  agent::SetStateResponse set_state_response;
-  grpc::Status set_state_status = stub->SetState(
-      &set_state_context, set_state_request, &set_state_response);
-
-  EXPECT_TRUE(set_state_status.ok()) << set_state_status.error_message();
+  SendRequest(&Agent::Stub::SetState, set_state_request);
 }
 
 TEST_F(AgentServiceTest, SetState_WrongSize) {
@@ -124,56 +133,124 @@ TEST_F(AgentServiceTest, PlannerStep_ProducesNonzeroAction) {
   RunAndCheckInit("Cartpole", nullptr);
 
   {
-    grpc::ClientContext context;
     agent::SetTaskParametersRequest request;
     (*request.mutable_parameters())["Goal"].set_numeric(-1.0);
-    agent::SetTaskParametersResponse response;
-    grpc::Status status = stub->SetTaskParameters(&context, request, &response);
+    SendRequest(&Agent::Stub::SetTaskParameters, request);
+  }
 
-    EXPECT_TRUE(status.ok());
+  SendRequest(&Agent::Stub::PlannerStep);
+
+  {
+    agent::GetActionResponse response = SendRequest(&Agent::Stub::GetAction);
+
+    ASSERT_EQ(response.action().size(), 1);
+    EXPECT_TRUE(response.action()[0] != 0.0);
   }
+}
+
+TEST_F(AgentServiceTest, ActionAveragingGivesDifferentResult) {
+  RunAndCheckInit("Cartpole", nullptr);
 
   {
-    grpc::ClientContext context;
+    agent::SetTaskParametersRequest request;
+    (*request.mutable_parameters())["Goal"].set_numeric(-1.0);
+    SendRequest(&Agent::Stub::SetTaskParameters, request);
+  }
 
-    agent::PlannerStepRequest request;
-    agent::PlannerStepResponse response;
-    grpc::Status status = stub->PlannerStep(&context, request, &response);
+  SendRequest(&Agent::Stub::PlannerStep);
 
-    EXPECT_TRUE(status.ok()) << status.error_message();
+  double action_without_averaging;
+  {
+    agent::GetActionResponse response = SendRequest(&Agent::Stub::GetAction);
+    ASSERT_EQ(response.action().size(), 1);
+    EXPECT_TRUE(response.action()[0] != 0.0);
+    action_without_averaging = response.action()[0];
   }
 
+  double action_with_averaging;
   {
     grpc::ClientContext context;
     agent::GetActionRequest request;
-    agent::GetActionResponse response;
-    grpc::Status status = stub->GetAction(&context, request, &response);
-
-    EXPECT_TRUE(status.ok()) << status.error_message();
-    EXPECT_EQ(response.action().size(), 1);
+    request.set_averaging_duration(1.0);
+    agent::GetActionResponse response =
+        SendRequest(&Agent::Stub::GetAction, request);
+    ASSERT_EQ(response.action().size(), 1);
     EXPECT_TRUE(response.action()[0] != 0.0);
+    action_with_averaging = response.action()[0];
   }
+  EXPECT_NE(action_with_averaging, action_without_averaging);
 }
 
-TEST_F(AgentServiceTest, Step_AdvancesTime) {
-  RunAndCheckInit("Cartpole", nullptr);
+TEST_F(AgentServiceTest, NominalActionIndependentOfState) {
+  // Pick a task that uses iLQG, where there is normally a feedback term on the
+  // policy.
+  RunAndCheckInit("Swimmer", nullptr);
+
+  SendRequest(&Agent::Stub::PlannerStep);
 
-  agent::State initial_state;
+  double nominal_action1;
   {
-    grpc::ClientContext context;
-    agent::GetStateRequest request;
-    agent::GetStateResponse response;
-    EXPECT_TRUE(stub->GetState(&context, request, &response).ok());
-    initial_state = response.state();
+    agent::GetActionRequest request;
+    request.set_averaging_duration(1.0);
+    request.set_nominal_action(true);
+    request.set_time(0.01);
+    agent::GetActionResponse response =
+        SendRequest(&Agent::Stub::GetAction, request);
+    EXPECT_GE(response.action().size(), 1);
+    nominal_action1 = response.action()[0];
+    EXPECT_NE(nominal_action1, 0.0);
   }
 
+  // Set a new state
+  {
+    agent::SetStateRequest request;
+    static constexpr int kSwimmerDofs = 8;
+    for (int i = 0; i < kSwimmerDofs; i++) {
+      request.mutable_state()->mutable_qpos()->Add(0.1);
+    }
+    SendRequest(&Agent::Stub::SetState, request);
+  }
+
+  double nominal_action2;
+  {
+    agent::GetActionRequest request;
+    request.set_averaging_duration(1.0);
+    request.set_nominal_action(true);
+    request.set_time(0.01);
+    agent::GetActionResponse response =
+        SendRequest(&Agent::Stub::GetAction, request);
+    EXPECT_GE(response.action().size(), 1);
+    nominal_action2 = response.action()[0];
+  }
+
+  double feedback_action;
+  {
+    agent::GetActionRequest request;
+    request.set_averaging_duration(1.0);
+    request.set_nominal_action(false);
+    request.set_time(0.01);
+    agent::GetActionResponse response =
+        SendRequest(&Agent::Stub::GetAction, request);
+    EXPECT_GE(response.action().size(), 1);
+    feedback_action = response.action()[0];
+  }
+
+  EXPECT_EQ(nominal_action1, nominal_action2)
+      << "nominal action should be the same";
+  EXPECT_NE(nominal_action1, feedback_action)
+      << "feedback action should be different from the nominal";
+}
+
+TEST_F(AgentServiceTest, Step_AdvancesTime) {
+  RunAndCheckInit("Cartpole", nullptr);
+
+  agent::State initial_state = SendRequest(&Agent::Stub::GetState).state();
 
   {
-    grpc::ClientContext context;
     agent::SetTaskParametersRequest request;
     (*request.mutable_parameters())["Goal"].set_numeric(-1.0);
-    agent::SetTaskParametersResponse response;
-    EXPECT_TRUE(stub->SetTaskParameters(&context, request, &response).ok());
+    SendRequest(&Agent::Stub::SetTaskParameters, request);
   }
 
   {
@@ -185,33 +262,12 @@ TEST_F(AgentServiceTest, Step_AdvancesTime) {
     EXPECT_EQ(response.parameters().at("Goal").numeric(), -1.0);
   }
 
-  {
-    grpc::ClientContext context;
-    agent::PlannerStepRequest request;
-    agent::PlannerStepResponse response;
-    grpc::Status status = stub->PlannerStep(&context, request, &response);
-
-    EXPECT_TRUE(status.ok()) << status.error_message();
-  }
-
+  SendRequest(&Agent::Stub::PlannerStep);
   for (int i = 0; i < 3; i++) {
-    grpc::ClientContext context;
-    agent::StepRequest request;
-    agent::StepResponse response;
-    grpc::Status status = stub->Step(&context, request, &response);
-
-    EXPECT_TRUE(status.ok()) << status.error_message();
+    SendRequest(&Agent::Stub::Step);
   }
 
-  agent::State final_state;
-  {
-    grpc::ClientContext context;
-    agent::GetStateRequest request;
-    agent::GetStateResponse response;
-    grpc::Status status = stub->GetState(&context, request, &response);
-    EXPECT_TRUE(status.ok()) << status.error_message();
-    final_state = response.state();
-  }
+  agent::State final_state = SendRequest(&Agent::Stub::GetState).state();
   double cartpole_timestep = 0.001;
   EXPECT_DOUBLE_EQ(final_state.time() - initial_state.time(),
                    3 * cartpole_timestep);
@@ -223,34 +279,11 @@ TEST_F(AgentServiceTest, Step_CallsTransition) {
   RunAndCheckInit("Particle", nullptr);
 
-  agent::State initial_state;
-  {
-    grpc::ClientContext context;
-    agent::GetStateRequest request;
-    agent::GetStateResponse response;
-    grpc::Status status = stub->GetState(&context, request, &response);
-    EXPECT_TRUE(status.ok()) << status.error_message();
-    initial_state = response.state();
-  }
+  agent::State initial_state = SendRequest(&Agent::Stub::GetState).state();
 
-  {
-    grpc::ClientContext context;
-    agent::StepRequest request;
-    agent::StepResponse response;
-    grpc::Status status = stub->Step(&context, request, &response);
+  SendRequest(&Agent::Stub::Step);
 
-    EXPECT_TRUE(status.ok()) << status.error_message();
-  }
-
-  agent::State final_state;
-  {
-    grpc::ClientContext context;
-    agent::GetStateRequest request;
-    agent::GetStateResponse response;
-    grpc::Status status = stub->GetState(&context, request, &response);
-    EXPECT_TRUE(status.ok()) << status.error_message();
-    final_state = response.state();
-  }
+  agent::State final_state = SendRequest(&Agent::Stub::GetState).state();
 
   EXPECT_NE(final_state.mocap_pos()[0], initial_state.mocap_pos()[0])
       << "mocap_pos stayed constant. Was Transition called?";
 }
@@ -258,27 +291,16 @@ TEST_F(AgentServiceTest, Step_CallsTransition) {
 TEST_F(AgentServiceTest, SetTaskParameters_Numeric) {
   RunAndCheckInit("Cartpole", nullptr);
 
-  grpc::ClientContext context;
-
   agent::SetTaskParametersRequest request;
   (*request.mutable_parameters())["Goal"].set_numeric(16.0);
-  agent::SetTaskParametersResponse response;
-  grpc::Status status = stub->SetTaskParameters(&context, request, &response);
-
-  EXPECT_TRUE(status.ok());
+  SendRequest(&Agent::Stub::SetTaskParameters, request);
 }
 
 TEST_F(AgentServiceTest, SetTaskParameters_Select) {
   RunAndCheckInit("Quadruped Flat", nullptr);
-
-  grpc::ClientContext context;
-
   agent::SetTaskParametersRequest request;
   (*request.mutable_parameters())["Gait"].set_selection("Trot");
-  agent::SetTaskParametersResponse response;
-  grpc::Status status = stub->SetTaskParameters(&context, request, &response);
-
-  EXPECT_TRUE(status.ok()) << status.error_message();
+  SendRequest(&Agent::Stub::SetTaskParameters, request);
 }
 
 TEST_F(AgentServiceTest, SetCostWeights_Works) {
diff --git a/grpc/grpc_agent_util.cc b/grpc/grpc_agent_util.cc
index 2d7734691..88350a367 100644
--- a/grpc/grpc_agent_util.cc
+++ b/grpc/grpc_agent_util.cc
@@ -142,39 +142,71 @@ grpc::Status SetState(const SetStateRequest* request, mjpc::Agent* agent,
 
 #undef CHECK_SIZE
 
+namespace {
+// TODO(nimrod): make planner a const reference
+std::vector<double> AverageAction(mjpc::Planner& planner, const mjModel* model,
+                                  bool nominal_action, mjData* rollout_data,
+                                  mjpc::State* rollout_state, double time,
+                                  double averaging_duration) {
+  int nu = model->nu;
+  std::vector<double> ret(nu, 0);
+  int nactions = 0;
+  double end_time = time + averaging_duration;
+
+  if (nominal_action) {
+    std::vector<double> action(nu, 0);
+    while (time < end_time) {
+      planner.ActionFromPolicy(action.data(), /*state=*/nullptr, time);
+      mju_addTo(ret.data(), action.data(), nu);
+      time += model->opt.timestep;
+      nactions++;
+    }
+  } else {
+    rollout_data->time = time;
+    while (rollout_data->time <= end_time) {
+      rollout_state->Set(model, rollout_data);
+      const double* state = rollout_state->state().data();
+      planner.ActionFromPolicy(rollout_data->ctrl, state,
+                               rollout_data->time);
+      mju_addTo(ret.data(), rollout_data->ctrl, nu);
+      mj_step(model, rollout_data);
+      nactions++;
+    }
+  }
+  mju_scl(ret.data(), ret.data(), 1.0 / nactions, nu);
+  return ret;
+}
+
+}  // namespace
 
 grpc::Status GetAction(const GetActionRequest* request,
                        const mjpc::Agent* agent, const mjModel* model,
                        mjData* rollout_data, mjpc::State* rollout_state,
                        GetActionResponse* response) {
-  int nu = agent->GetActionDim();
-  std::vector<double> ret = std::vector<double>(nu, 0);
-
   double time =
       request->has_time() ? request->time() : agent->ActiveState().time();
 
   if (request->averaging_duration() > 0) {
-    agent->ActiveState().CopyTo(model, rollout_data);
-    rollout_data->time = time;
-    double end_time = time + request->averaging_duration();
-    int nactions = 0;
-    while (rollout_data->time <= end_time) {
+    if (request->nominal_action()) {
+      rollout_data = nullptr;
+      rollout_state = nullptr;
+    } else {
+      agent->ActiveState().CopyTo(model, rollout_data);
       rollout_state->Set(model, rollout_data);
-      agent->ActivePlanner().ActionFromPolicy(rollout_data->ctrl,
-                                              rollout_state->state().data(),
-                                              rollout_data->time);
-      mju_addTo(ret.data(), rollout_data->ctrl, nu);
-      mj_step(model, rollout_data);
-      nactions++;
     }
-    mju_scl(ret.data(), ret.data(), 1.0 / nactions, nu);
+    std::vector<double> ret = AverageAction(agent->ActivePlanner(), model,
+                                            request->nominal_action(),
+                                            rollout_data, rollout_state, time,
+                                            request->averaging_duration());
+    response->mutable_action()->Assign(ret.begin(), ret.end());
   } else {
-    agent->ActivePlanner().ActionFromPolicy(
-        ret.data(), &agent->ActiveState().state()[0], time);
+    std::vector<double> ret(model->nu, 0);
+    const double* state = request->nominal_action()
+                              ? nullptr
+                              : agent->ActiveState().state().data();
+    agent->ActivePlanner().ActionFromPolicy(ret.data(), state, time);
+    response->mutable_action()->Assign(ret.begin(), ret.end());
   }
-  response->mutable_action()->Assign(ret.begin(), ret.end());
-
   return grpc::Status::OK;
 }
diff --git a/mjpc/planners/gradient/policy.h b/mjpc/planners/gradient/policy.h
index 11413529b..390169b38 100644
--- a/mjpc/planners/gradient/policy.h
+++ b/mjpc/planners/gradient/policy.h
@@ -40,6 +40,7 @@ class GradientPolicy : public Policy {
   void Reset(int horizon) override;
 
   // compute action from policy
+  // state is not used
   void Action(double* action, const double* state, double time) const override;
 
   // copy policy
diff --git a/mjpc/planners/ilqg/planner.h b/mjpc/planners/ilqg/planner.h
index 92a48d85d..d60a6f3ec 100644
--- a/mjpc/planners/ilqg/planner.h
+++ b/mjpc/planners/ilqg/planner.h
@@ -53,6 +53,7 @@ class iLQGPlanner : public Planner {
   void NominalTrajectory(int horizon, ThreadPool& pool) override;
 
   // set action from policy
+  // if state == nullptr, return the nominal action without a feedback term
   void ActionFromPolicy(double* action, const double* state, double time,
                         bool use_previous = false) override;
 
diff --git a/mjpc/planners/ilqg/policy.cc b/mjpc/planners/ilqg/policy.cc
index b18b9ee9d..facc5161b 100644
--- a/mjpc/planners/ilqg/policy.cc
+++ b/mjpc/planners/ilqg/policy.cc
@@ -96,47 +96,56 @@ void iLQGPolicy::Action(double* action, const double* state,
     ZeroInterpolation(action, time, trajectory.times,
                       trajectory.actions.data(), model->nu,
                       trajectory.horizon - 1);
 
-    // state reference
-    ZeroInterpolation(state_interp.data(), time, trajectory.times,
-                      trajectory.states.data(), dim_state, trajectory.horizon);
-
-    // gains
-    ZeroInterpolation(feedback_gain_scratch.data(), time, trajectory.times,
-                      feedback_gain.data(), dim_action * dim_state_derivative,
-                      trajectory.horizon - 1);
+    if (state) {
+      // state reference
+      ZeroInterpolation(state_interp.data(), time, trajectory.times,
+                        trajectory.states.data(), dim_state,
+                        trajectory.horizon);
+
+      // gains
+      ZeroInterpolation(feedback_gain_scratch.data(), time, trajectory.times,
+                        feedback_gain.data(),
+                        dim_action * dim_state_derivative,
+                        trajectory.horizon - 1);
+    }
   } else if (representation == 1) {
     // action
     LinearInterpolation(action, time, trajectory.times,
                         trajectory.actions.data(), model->nu,
                         trajectory.horizon - 1);
 
-    // state
-    LinearInterpolation(state_interp.data(), time, trajectory.times,
-                        trajectory.states.data(), dim_state,
-                        trajectory.horizon);
+    if (state) {
+      // state
+      LinearInterpolation(state_interp.data(), time, trajectory.times,
+                          trajectory.states.data(), dim_state,
+                          trajectory.horizon);
 
-    // normalize quaternions
-    mj_normalizeQuat(model, state_interp.data());
+      // normalize quaternions
+      mj_normalizeQuat(model, state_interp.data());
 
-    LinearInterpolation(feedback_gain_scratch.data(), time, trajectory.times,
-                        feedback_gain.data(), dim_action * dim_state_derivative,
-                        trajectory.horizon - 1);
+      LinearInterpolation(feedback_gain_scratch.data(), time, trajectory.times,
+                          feedback_gain.data(),
+                          dim_action * dim_state_derivative,
+                          trajectory.horizon - 1);
+    }
   } else if (representation == 2) {
     // action
     CubicInterpolation(action, time, trajectory.times,
                        trajectory.actions.data(), model->nu,
                        trajectory.horizon - 1);
 
-    // state
-    CubicInterpolation(state_interp.data(), time, trajectory.times,
-                       trajectory.states.data(), dim_state, trajectory.horizon);
+    if (state) {
+      // state
+      CubicInterpolation(state_interp.data(), time, trajectory.times,
+                         trajectory.states.data(), dim_state,
+                         trajectory.horizon);
 
-    // normalize quaternions
-    mj_normalizeQuat(model, state_interp.data());
+      // normalize quaternions
+      mj_normalizeQuat(model, state_interp.data());
 
-    CubicInterpolation(feedback_gain_scratch.data(), time, trajectory.times,
-                       feedback_gain.data(), dim_action * dim_state_derivative,
-                       trajectory.horizon - 1);
+      CubicInterpolation(feedback_gain_scratch.data(), time, trajectory.times,
+                         feedback_gain.data(), dim_action * dim_state_derivative,
+                         trajectory.horizon - 1);
+    }
   }
 
   // add feedback
diff --git a/mjpc/planners/ilqg/policy.h b/mjpc/planners/ilqg/policy.h
index e6260c6d7..942e1dfad 100644
--- a/mjpc/planners/ilqg/policy.h
+++ b/mjpc/planners/ilqg/policy.h
@@ -39,6 +39,7 @@ class iLQGPolicy : public Policy {
   void Reset(int horizon) override;
 
   // set action from policy
+  // if state == nullptr, return the nominal action without a feedback term
   void Action(double* action, const double* state, double time) const override;
 
   // copy policy
diff --git a/mjpc/planners/policy.h b/mjpc/planners/policy.h
index a1d06c83d..887358ff8 100644
--- a/mjpc/planners/policy.h
+++ b/mjpc/planners/policy.h
@@ -41,6 +41,8 @@ class Policy {
   virtual void Reset(int horizon) = 0;
 
   // set action from policy
+  // for policies that have a feedback term, passing nullptr for state turns
+  // the feedback term off and returns the nominal action for that time
   virtual void Action(double* action, const double* state,
                       double time) const = 0;
 };
diff --git a/python/mujoco_mpc/agent.py b/python/mujoco_mpc/agent.py
index 7b8c52226..36fa64583 100644
--- a/python/mujoco_mpc/agent.py
+++ b/python/mujoco_mpc/agent.py
@@ -193,7 +193,10 @@ def get_state(self) -> agent_pb2.State:
     return self.stub.GetState(agent_pb2.GetStateRequest()).state
 
   def get_action(
-      self, time: Optional[float] = None, averaging_duration: float = 0
+      self,
+      time: Optional[float] = None,
+      averaging_duration: float = 0,
+      nominal_action: bool = False,
   ) -> np.ndarray:
     """Return latest `action` from the `Agent`'s planner.
@@ -201,12 +204,15 @@ def get_action(
       time: `data.time`, i.e. the simulation time.
       averaging_duration: the duration over which actions should be averaged
         (e.g. the control timestep).
+      nominal_action: if True, don't apply feedback terms in the policy.
 
     Returns:
       action: `Agent`'s planner's latest action.
     """
     get_action_request = agent_pb2.GetActionRequest(
-        time=time, averaging_duration=averaging_duration
+        time=time,
+        averaging_duration=averaging_duration,
+        nominal_action=nominal_action,
     )
     get_action_response = self.stub.GetAction(get_action_request)
     return np.array(get_action_response.action)
diff --git a/python/mujoco_mpc/agent_test.py b/python/mujoco_mpc/agent_test.py
index c6b99ecc6..96a21492d 100644
--- a/python/mujoco_mpc/agent_test.py
+++ b/python/mujoco_mpc/agent_test.py
@@ -15,6 +15,7 @@
 
 from absl.testing import absltest
+from absl.testing import parameterized
 import grpc
 import mujoco
 from mujoco_mpc import agent as agent_lib
@@ -39,7 +40,7 @@ def environment_reset(model, data):
   return get_observation(model, data)
 
 
-class AgentTest(absltest.TestCase):
+class AgentTest(parameterized.TestCase):
 
   def test_set_task_parameters(self):
     model_path = (
@@ -84,7 +85,8 @@ def test_step_env_with_planner(self):
     self.assertFalse((observations == 0).all())
     self.assertFalse((actions == 0).all())
 
-  def test_action_averaging_doesnt_change_state(self):
+  @parameterized.parameters({"nominal": False}, {"nominal": True})
+  def test_action_averaging_doesnt_change_state(self, nominal):
     # when calling get_action with action averaging, the Agent needs to roll
     # out physics, but the API should be implemented not to mutate the state
     model_path = (
@@ -108,7 +110,9 @@ def test_action_averaging_doesnt_change_state(self):
         mocap_quat=data.mocap_quat,
         userdata=data.userdata,
     )
-    agent.get_action(averaging_duration=control_timestep)
+    agent.get_action(
+        averaging_duration=control_timestep, nominal_action=nominal
+    )
     state_after = agent.get_state()
     self.assertEqual(data.time, state_after.time)
     np.testing.assert_allclose(data.qpos, state_after.qpos)
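For reference, here is a minimal sketch of how the new `nominal_action` flag might be exercised from the Python client. This is illustrative only, not part of the change: the `Agent` constructor call, the `planner_step()` helper, and the Swimmer task path are assumptions based on `python/mujoco_mpc/agent.py` and the tests above (server lifecycle/shutdown is omitted).

```python
# Usage sketch for the new nominal_action flag (assumed names noted inline).
import mujoco
from mujoco_mpc import agent as agent_lib

# Swimmer is used in agent_service_test.cc above because its default planner
# is iLQG, so the policy normally includes a state-feedback term.
model = mujoco.MjModel.from_xml_path("mjpc/tasks/swimmer/task.xml")  # assumed path
agent = agent_lib.Agent(task_id="Swimmer", model=model)  # assumed constructor

agent.planner_step()  # run one planner iteration so the policy is non-trivial

# With feedback (default): the returned action depends on the current state.
feedback_action = agent.get_action(averaging_duration=0.01)

# Nominal only: the feedback term is skipped, so the action depends only on
# the requested time, not on the state.
nominal_action = agent.get_action(averaging_duration=0.01, nominal_action=True)
```

As the proto comment notes, the flag only changes behavior for planners with feedback terms (iLQG); for the sampling planner both calls return the same policy action.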