diff --git a/experimental/beacon_sim/BUILD b/experimental/beacon_sim/BUILD index 445c4637..c9771ed6 100644 --- a/experimental/beacon_sim/BUILD +++ b/experimental/beacon_sim/BUILD @@ -431,6 +431,7 @@ cc_library( srcs = ["correlated_beacons.cc"], deps = [ "@eigen//:eigen", + "//common:check", "//common:drake", "//common/math:combinations", "//common/math:n_choose_k", diff --git a/experimental/beacon_sim/correlated_beacons.cc b/experimental/beacon_sim/correlated_beacons.cc index 8cf741f6..da6ac85a 100644 --- a/experimental/beacon_sim/correlated_beacons.cc +++ b/experimental/beacon_sim/correlated_beacons.cc @@ -5,11 +5,10 @@ #include #include #include -#include -#include #include #include +#include "common/check.hh" #include "common/math/combinations.hh" #include "common/math/n_choose_k.hh" #include "drake/solvers/mathematical_program.h" @@ -144,34 +143,56 @@ double BeaconPotential::log_prob(const std::vector &present_beacons) const return log_prob(assignment); } -double BeaconPotential::log_prob(const std::unordered_map &assignment) const { +double BeaconPotential::log_prob(const std::unordered_map &assignment, + const bool allow_partial_assignment) const { const std::vector sorted_members = sorted_vector(members_); const std::vector keys = sorted_keys(assignment); std::vector missing_keys; std::set_difference(sorted_members.begin(), sorted_members.end(), keys.begin(), keys.end(), std::back_inserter(missing_keys)); - if (!missing_keys.empty()) { - std::ostringstream out; - out << "Missing keys from assignment {"; - bool is_first = true; - for (const auto &key : missing_keys) { - if (is_first) { - is_first = false; - } else { - out << ", "; + CHECK(allow_partial_assignment || missing_keys.empty(), + "partial assignment specified when not enabled", assignment, missing_keys, members()); + + const std::vector to_marginalize = missing_keys; + + std::unordered_map index_from_id; + for (int i = 0; i < static_cast(members_.size()); i++) { + index_from_id[members_.at(i)] = i; + } + + const auto sum_over_marginalized = [&to_marginalize, &index_from_id, + this](const Eigen::VectorXd &x) { + const int n = to_marginalize.size(); + std::vector terms; + terms.reserve(1 << n); + for (int num_present = 0; num_present <= n; num_present++) { + // For each number of present beacons + for (const auto &config : math::combinations(n, num_present)) { + // We have a different way of that many beacons being present + + // Set the element for the current config + Eigen::VectorXd curr_config = x; + for (const int to_marginalize_idx : config) { + const int marginal_id = to_marginalize.at(to_marginalize_idx); + const int x_idx = index_from_id[marginal_id]; + curr_config(x_idx) = 1; + } + + // Evaluate the log probability + terms.push_back(curr_config.transpose() * precision_ * curr_config - + log_normalizer_); } - out << key; } - out << "}"; - throw std::runtime_error(out.str()); - } + return logsumexp(terms); + }; - Eigen::VectorXd x(members_.size()); - for (int i = 0; i < static_cast(members_.size()); i++) { - x(i) = assignment.at(members_.at(i)); + Eigen::VectorXd config = Eigen::VectorXd::Zero(members_.size()); + for (const auto &[beacon_id, is_present] : assignment) { + config(index_from_id.at(beacon_id)) = is_present; } - return x.transpose() * precision_ * x - log_normalizer_; + + return sum_over_marginalized(config); } std::vector BeaconPotential::compute_log_marginals( diff --git a/experimental/beacon_sim/correlated_beacons.hh b/experimental/beacon_sim/correlated_beacons.hh index 6d5343b3..e575da31 100644 --- a/experimental/beacon_sim/correlated_beacons.hh +++ b/experimental/beacon_sim/correlated_beacons.hh @@ -32,7 +32,8 @@ class BeaconPotential { BeaconPotential(const Eigen::MatrixXd &information, const double log_normalizer, const std::vector &members); - double log_prob(const std::unordered_map &assignments) const; + double log_prob(const std::unordered_map &assignments, + const bool allow_partial_assignment = false) const; double log_prob(const std::vector &present_beacons) const; BeaconPotential operator*(const BeaconPotential &other) const; diff --git a/experimental/beacon_sim/correlated_beacons_python.cc b/experimental/beacon_sim/correlated_beacons_python.cc index 08d6379f..cf03140f 100644 --- a/experimental/beacon_sim/correlated_beacons_python.cc +++ b/experimental/beacon_sim/correlated_beacons_python.cc @@ -39,8 +39,10 @@ PYBIND11_MODULE(correlated_beacons_python, m) { py::class_(m, "BeaconPotential") .def(py::init>()) - .def("log_prob", py::overload_cast &>( - &BeaconPotential::log_prob, py::const_)) + .def("log_prob", + py::overload_cast &, bool>( + &BeaconPotential::log_prob, py::const_), + py::arg("assignment"), py::arg("allow_partial_assignment") = false) .def("log_prob", py::overload_cast &>(&BeaconPotential::log_prob, py::const_)) .def("__mul__", &BeaconPotential::operator*) diff --git a/experimental/beacon_sim/correlated_beacons_test.py b/experimental/beacon_sim/correlated_beacons_test.py index b46a63cb..0d778152 100644 --- a/experimental/beacon_sim/correlated_beacons_test.py +++ b/experimental/beacon_sim/correlated_beacons_test.py @@ -167,6 +167,23 @@ def test_three_independent_beacons_marginalization(self): self.assertAlmostEqual(np.exp(with_beacon.log_marginal), p_beacon, places=6) self.assertAlmostEqual(np.exp(no_beacon.log_marginal), 1 - p_beacon, places=6) + def test_three_independent_beacons_partial_prob(self): + # Setup + p_beacon = 0.75 + p_no_beacons = (1 - p_beacon) ** 3 + bc = cb.BeaconClique( + p_beacon=p_beacon, + p_no_beacons=p_no_beacons, + members=[1, 2, 3], + ) + + beacon_pot = cb.create_correlated_beacons(bc) + + # Action + log_prob = beacon_pot.log_prob({1:True, 2:False}, allow_partial_assignment=True) + + # Verification + self.assertAlmostEqual(np.exp(log_prob), p_beacon * (1-p_beacon), places=6) if __name__ == "__main__": unittest.main() diff --git a/toolchain/gcc_toolchain_config.bzl b/toolchain/gcc_toolchain_config.bzl index aefea79a..26aee4e2 100644 --- a/toolchain/gcc_toolchain_config.bzl +++ b/toolchain/gcc_toolchain_config.bzl @@ -98,6 +98,20 @@ def _impl(ctx): ) ] ), + feature( + name="c_compile_flags", + enabled=True, + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile], + flag_groups = [ + flag_group( + flags=["-fPIC"], + ) + ] + ) + ] + ), feature( name="dbg", enabled=False,