Skip to content

Commit

Permalink
Allow evaluation of partial beacon assignments (#227)
Browse files Browse the repository at this point in the history
  • Loading branch information
ewfuentes authored Oct 24, 2023
1 parent 04e9fcc commit c1c4ea9
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 23 deletions.
1 change: 1 addition & 0 deletions experimental/beacon_sim/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ cc_library(
srcs = ["correlated_beacons.cc"],
deps = [
"@eigen//:eigen",
"//common:check",
"//common:drake",
"//common/math:combinations",
"//common/math:n_choose_k",
Expand Down
61 changes: 41 additions & 20 deletions experimental/beacon_sim/correlated_beacons.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
#include <cmath>
#include <iterator>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <type_traits>
#include <unordered_map>

#include "common/check.hh"
#include "common/math/combinations.hh"
#include "common/math/n_choose_k.hh"
#include "drake/solvers/mathematical_program.h"
Expand Down Expand Up @@ -144,34 +143,56 @@ double BeaconPotential::log_prob(const std::vector<int> &present_beacons) const
return log_prob(assignment);
}

double BeaconPotential::log_prob(const std::unordered_map<int, bool> &assignment) const {
double BeaconPotential::log_prob(const std::unordered_map<int, bool> &assignment,
const bool allow_partial_assignment) const {
const std::vector<int> sorted_members = sorted_vector(members_);
const std::vector<int> keys = sorted_keys(assignment);
std::vector<int> missing_keys;
std::set_difference(sorted_members.begin(), sorted_members.end(), keys.begin(), keys.end(),
std::back_inserter(missing_keys));

if (!missing_keys.empty()) {
std::ostringstream out;
out << "Missing keys from assignment {";
bool is_first = true;
for (const auto &key : missing_keys) {
if (is_first) {
is_first = false;
} else {
out << ", ";
CHECK(allow_partial_assignment || missing_keys.empty(),
"partial assignment specified when not enabled", assignment, missing_keys, members());

const std::vector<int> to_marginalize = missing_keys;

std::unordered_map<int, int> index_from_id;
for (int i = 0; i < static_cast<int>(members_.size()); i++) {
index_from_id[members_.at(i)] = i;
}

const auto sum_over_marginalized = [&to_marginalize, &index_from_id,
this](const Eigen::VectorXd &x) {
const int n = to_marginalize.size();
std::vector<double> terms;
terms.reserve(1 << n);
for (int num_present = 0; num_present <= n; num_present++) {
// For each number of present beacons
for (const auto &config : math::combinations(n, num_present)) {
// We have a different way of that many beacons being present

// Set the element for the current config
Eigen::VectorXd curr_config = x;
for (const int to_marginalize_idx : config) {
const int marginal_id = to_marginalize.at(to_marginalize_idx);
const int x_idx = index_from_id[marginal_id];
curr_config(x_idx) = 1;
}

// Evaluate the log probability
terms.push_back(curr_config.transpose() * precision_ * curr_config -
log_normalizer_);
}
out << key;
}
out << "}";
throw std::runtime_error(out.str());
}
return logsumexp(terms);
};

Eigen::VectorXd x(members_.size());
for (int i = 0; i < static_cast<int>(members_.size()); i++) {
x(i) = assignment.at(members_.at(i));
Eigen::VectorXd config = Eigen::VectorXd::Zero(members_.size());
for (const auto &[beacon_id, is_present] : assignment) {
config(index_from_id.at(beacon_id)) = is_present;
}
return x.transpose() * precision_ * x - log_normalizer_;

return sum_over_marginalized(config);
}

std::vector<LogMarginal> BeaconPotential::compute_log_marginals(
Expand Down
3 changes: 2 additions & 1 deletion experimental/beacon_sim/correlated_beacons.hh
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ class BeaconPotential {
BeaconPotential(const Eigen::MatrixXd &information, const double log_normalizer,
const std::vector<int> &members);

double log_prob(const std::unordered_map<int, bool> &assignments) const;
double log_prob(const std::unordered_map<int, bool> &assignments,
const bool allow_partial_assignment = false) const;
double log_prob(const std::vector<int> &present_beacons) const;

BeaconPotential operator*(const BeaconPotential &other) const;
Expand Down
6 changes: 4 additions & 2 deletions experimental/beacon_sim/correlated_beacons_python.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ PYBIND11_MODULE(correlated_beacons_python, m) {

py::class_<BeaconPotential>(m, "BeaconPotential")
.def(py::init<Eigen::MatrixXd, double, std::vector<int>>())
.def("log_prob", py::overload_cast<const std::unordered_map<int, bool> &>(
&BeaconPotential::log_prob, py::const_))
.def("log_prob",
py::overload_cast<const std::unordered_map<int, bool> &, bool>(
&BeaconPotential::log_prob, py::const_),
py::arg("assignment"), py::arg("allow_partial_assignment") = false)
.def("log_prob",
py::overload_cast<const std::vector<int> &>(&BeaconPotential::log_prob, py::const_))
.def("__mul__", &BeaconPotential::operator*)
Expand Down
17 changes: 17 additions & 0 deletions experimental/beacon_sim/correlated_beacons_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,23 @@ def test_three_independent_beacons_marginalization(self):
self.assertAlmostEqual(np.exp(with_beacon.log_marginal), p_beacon, places=6)
self.assertAlmostEqual(np.exp(no_beacon.log_marginal), 1 - p_beacon, places=6)

def test_three_independent_beacons_partial_prob(self):
# Setup
p_beacon = 0.75
p_no_beacons = (1 - p_beacon) ** 3
bc = cb.BeaconClique(
p_beacon=p_beacon,
p_no_beacons=p_no_beacons,
members=[1, 2, 3],
)

beacon_pot = cb.create_correlated_beacons(bc)

# Action
log_prob = beacon_pot.log_prob({1:True, 2:False}, allow_partial_assignment=True)

# Verification
self.assertAlmostEqual(np.exp(log_prob), p_beacon * (1-p_beacon), places=6)

if __name__ == "__main__":
unittest.main()
14 changes: 14 additions & 0 deletions toolchain/gcc_toolchain_config.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,20 @@ def _impl(ctx):
)
]
),
feature(
name="c_compile_flags",
enabled=True,
flag_sets = [
flag_set(
actions = [ACTION_NAMES.c_compile],
flag_groups = [
flag_group(
flags=["-fPIC"],
)
]
)
]
),
feature(
name="dbg",
enabled=False,
Expand Down

0 comments on commit c1c4ea9

Please sign in to comment.