Skip to content

Commit

Permalink
Merge pull request #302 from thowell/estimator_update
Browse files Browse the repository at this point in the history
Add mode input to Estimator::Update
  • Loading branch information
erez-tom authored Mar 3, 2024
2 parents c1b15e5 + 15221bd commit a01e100
Show file tree
Hide file tree
Showing 10 changed files with 20 additions and 12 deletions.
2 changes: 1 addition & 1 deletion mjpc/estimators/batch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ void Batch::Reset(const mjData* data) {
}

// update
void Batch::Update(const double* ctrl, const double* sensor) {
void Batch::Update(const double* ctrl, const double* sensor, int mode) {
// start timer
auto start = std::chrono::steady_clock::now();

Expand Down
2 changes: 1 addition & 1 deletion mjpc/estimators/batch.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class Batch : public Direct, public Estimator {
void Reset(const mjData* data = nullptr) override;

// update
void Update(const double* ctrl, const double* sensor) override;
void Update(const double* ctrl, const double* sensor, int mode = 0) override;

// get state
double* State() override { return state.data(); };
Expand Down
5 changes: 3 additions & 2 deletions mjpc/estimators/estimator.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ class Estimator {

// TODO(etom): time input
// update
virtual void Update(const double* ctrl, const double* sensor) = 0;
virtual void Update(const double* ctrl, const double* sensor,
int mode = 0) = 0;

// get state
virtual double* State() = 0;
Expand Down Expand Up @@ -197,7 +198,7 @@ class GroundTruth : public Estimator {
}

// update
void Update(const double* ctrl, const double* sensor) override {
void Update(const double* ctrl, const double* sensor, int mode = 0) override {
mju_copy(data_->qpos, state.data(), model->nq);
mju_copy(data_->qvel, state.data() + model->nq, model->nv);
mju_copy(data_->act, state.data() + model->nq + model->nv, model->na);
Expand Down
6 changes: 3 additions & 3 deletions mjpc/estimators/kalman.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ class Kalman : public Estimator {
void UpdatePrediction();

// update
void Update(const double* ctrl, const double* sensor) override {
void Update(const double* ctrl, const double* sensor, int mode = 0) override {
// correct state with latest measurement
UpdateMeasurement(ctrl, sensor);
if (mode == 0 || mode == 1) UpdateMeasurement(ctrl, sensor);

// propagate state forward in time with model
UpdatePrediction();
if (mode == 0 || mode == 2) UpdatePrediction();

// set time
time = data_->time;
Expand Down
2 changes: 1 addition & 1 deletion mjpc/estimators/unscented.cc
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ void Unscented::SigmaCovariances() {
}

// unscented filter update
void Unscented::Update(const double* ctrl, const double* sensor) {
void Unscented::Update(const double* ctrl, const double* sensor, int mode) {
// start timer
auto start = std::chrono::steady_clock::now();

Expand Down
2 changes: 1 addition & 1 deletion mjpc/estimators/unscented.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class Unscented : public Estimator {
void SigmaCovariances();

// update
void Update(const double* ctrl, const double* sensor) override;
void Update(const double* ctrl, const double* sensor, int mode = 0) override;

// quaternion means
void QuaternionMeans();
Expand Down
2 changes: 2 additions & 0 deletions mjpc/grpc/filter.proto
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ message ResetResponse {}
message UpdateRequest {
repeated double ctrl = 1 [packed = true];
repeated double sensor = 2 [packed = true];
optional int32 mode = 3;

}

message UpdateResponse {}
Expand Down
5 changes: 3 additions & 2 deletions mjpc/grpc/filter_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,9 @@ grpc::Status FilterService::Update(grpc::ServerContext* context,
return {grpc::StatusCode::FAILED_PRECONDITION, "Init not called."};
}

// measurement update
filters_[filter_]->Update(request->ctrl().data(), request->sensor().data());
// update
filters_[filter_]->Update(request->ctrl().data(), request->sensor().data(),
request->mode());

return grpc::Status::OK;
}
Expand Down
2 changes: 2 additions & 0 deletions python/mujoco_mpc/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,13 @@ def update(
self,
ctrl: Optional[npt.ArrayLike] = [],
sensor: Optional[npt.ArrayLike] = [],
mode: Optional[int] = 0,
):
# request
request = filter_pb2.UpdateRequest(
ctrl=ctrl,
sensor=sensor,
mode=mode,
)

# response
Expand Down
4 changes: 3 additions & 1 deletion python/mujoco_mpc/filter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,12 @@ def test_updates(self):
self.assertLess(np.linalg.norm(noise["process"] - process), 1.0e-5)
self.assertLess(np.linalg.norm(noise["sensor"] - sensor), 1.0e-5)

# measurement update
# update
ctrl = np.random.normal(scale=1.0, size=model.nu)
sensor = np.random.normal(scale=1.0, size=model.nsensordata)
filter.update(ctrl=ctrl, sensor=sensor)
filter.update(ctrl=ctrl, sensor=sensor, mode=0)
filter.update(ctrl=ctrl, sensor=sensor, mode=1)

# TODO(etom): more tests

Expand Down

0 comments on commit a01e100

Please sign in to comment.