Skip to content

Commit

Permalink
compression: use memoization
Browse files Browse the repository at this point in the history
  • Loading branch information
elbersb committed Sep 17, 2023
1 parent 783c50a commit d31ae17
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 49 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ URL: https://elbersb.github.io/segregation/
BugReports: https://github.com/elbersb/segregation/issues
RoxygenNote: 7.2.3
VignetteBuilder: knitr
SystemRequirements: C++17
LinkingTo:
Rcpp,
RcppProgress
105 changes: 56 additions & 49 deletions src/compression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ struct CompressionResults
std::vector<std::string> new_unit;
};

std::tuple<double, double> calculate_m(std::map<std::string, std::vector<double>> data)
std::tuple<double, double> calculate_m(std::map<std::string, std::vector<double>> &data)
{
// create group sums
int n_groups = data.begin()->second.size();
Expand Down Expand Up @@ -44,10 +44,10 @@ std::tuple<double, double> calculate_m(std::map<std::string, std::vector<double>
return std::make_tuple(n_total, m_total);
}

std::tuple<double, double> calculate_twounit_m(std::vector<double> &unit1, std::vector<double> &unit2)
double calculate_reduction(double n, std::vector<double> &unit1, std::vector<double> &unit2)
{
// create group sums
int n_groups = unit1.size();
const int n_groups = unit1.size();
double n_total = 0.0;
std::vector<double> group_sums(n_groups, 0.0);
for (int i = 0; i < n_groups; i++)
Expand All @@ -59,8 +59,8 @@ std::tuple<double, double> calculate_twounit_m(std::vector<double> &unit1, std::
}

// create unit sums
double n_unit1 = std::accumulate(unit1.begin(), unit1.end(), 0);
double n_unit2 = std::accumulate(unit2.begin(), unit2.end(), 0);
const double n_unit1 = std::accumulate(unit1.begin(), unit1.end(), 0);
const double n_unit2 = std::accumulate(unit2.begin(), unit2.end(), 0);

// calculate M
double m_total = 0.0;
Expand All @@ -82,7 +82,7 @@ std::tuple<double, double> calculate_twounit_m(std::vector<double> &unit1, std::
}
}

return std::make_tuple(n_total, m_total);
return n_total / n * m_total;
}

// [[Rcpp::export]]
Expand All @@ -92,17 +92,6 @@ List compress_compute_cpp(
StringVector unit_names,
int max_iter)
{
// prepare neighbors data structure (list of sets)
std::vector<std::set<std::string>> neighbors;
for (int i = 0; i < m_neighbors.nrow(); i++)
{
neighbors.push_back({});
for (int j = 0; j < m_neighbors.ncol(); j++)
{
neighbors[neighbors.size() - 1].insert(Rcpp::as<std::string>(m_neighbors(i, j)));
}
}

// prepare main data structure: map, where the key is the unit name
// and the values are the ordered group counts
std::map<std::string, std::vector<double>> data;
Expand All @@ -111,17 +100,23 @@ List compress_compute_cpp(
std::string unit = Rcpp::as<std::string>(unit_names[i]);
data[unit] = {};
for (int j = 0; j < m_data.ncol(); j++)
{
data[unit].push_back(m_data(i, j));
}
}

int n_groups = m_data.ncol();

// compute total M index
double n_total, m_total;
std::tie(n_total, m_total) = calculate_m(data);

// prepare neighbors data structure: map from each neighbor pair (a set with 2 unit names) to its precomputed reduction
std::map<std::set<std::string>, double> neighbors;
for (int i = 0; i < m_neighbors.nrow(); i++)
{
std::string unit1 = Rcpp::as<std::string>(m_neighbors(i, 0));
std::string unit2 = Rcpp::as<std::string>(m_neighbors(i, 1));
neighbors[{unit1, unit2}] = calculate_reduction(n_total, data[unit1], data[unit2]);
}

CompressionResults results;
results.iter.reserve(max_iter);
results.M_wgt.reserve(max_iter);
Expand All @@ -140,57 +135,69 @@ List compress_compute_cpp(
if (Progress::check_abort())
return List::create();

// analyze reductions for all neighbors
// find smallest reduction
double min_reduction = 10000;
int min_index;

for (int i = 0; i < neighbors.size(); i++)
std::set<std::string> min_key;
for (const auto &[key, reduction] : neighbors)
{
auto iter = neighbors[i].begin();
double n_total_pair, m_total_pair;
std::tie(n_total_pair, m_total_pair) = calculate_twounit_m(
data[*iter], data[*next(iter)]);
// reduction = p_AB * M_AB
double reduction = n_total_pair / n_total * m_total_pair;
if (reduction < min_reduction)
{
min_reduction = reduction;
min_index = i;
min_key = key;
if (reduction == 0)
break;
}
}

auto iter = neighbors[min_index].begin();
auto iter = min_key.begin();
const std::string unit_keep = *iter;
const std::string unit_delete = *next(iter);
// add counts of second to first, delete second
// add counts of 'delete' to 'keep', delete unit
for (int i = 0; i < n_groups; i++)
{
data[unit_keep][i] += data[unit_delete][i];
}
data.erase(unit_delete);
// update all neighbor references
int mark_for_deletion;
for (int i = 0; i < neighbors.size(); i++)

// update neighbors
std::vector<std::set<std::string>> delete_neighbors;
std::map<std::set<std::string>, double> new_neighbors;
for (const auto &[key, reduction] : neighbors)
{
const bool is_in = neighbors[i].find(unit_delete) != neighbors[i].end();
if (is_in)
const bool delete_found = key.find(unit_delete) != key.end();
const bool keep_found = key.find(unit_keep) != key.end();
if (keep_found && delete_found)
{
neighbors[i].erase(unit_delete);
neighbors[i].insert(unit_keep);
// remove the neighbor pair
delete_neighbors.push_back(key);
}
else if (delete_found)
{
// replace deleted unit with the kept unit and recalculate the reduction
delete_neighbors.push_back(key);
std::set<std::string> new_key(key);
new_key.erase(unit_delete);
new_key.insert(unit_keep);

auto iter = new_key.begin();
const std::string unit1 = *iter;
const std::string unit2 = *next(iter);
new_neighbors[{unit1, unit2}] = calculate_reduction(n_total, data[unit1], data[unit2]);
}
if (neighbors[i].size() == 1)
else if (keep_found)
{
mark_for_deletion = i;
// recalculate if updated unit is involved
auto iter = key.begin();
const std::string unit1 = *iter;
const std::string unit2 = *next(iter);
new_neighbors[{unit1, unit2}] = calculate_reduction(n_total, data[unit1], data[unit2]);
}
}
// clean up neighbors: erase equal elements
neighbors.erase(neighbors.begin() + mark_for_deletion);
// delete neighbors that involve the old unit
for (int i = 0; i < delete_neighbors.size(); i++)
neighbors.erase(delete_neighbors[i]);

// clean up neighbors: erase duplicates
std::sort(neighbors.begin(), neighbors.end());
neighbors.erase(std::unique(neighbors.begin(), neighbors.end()), neighbors.end());
// update and add new neighbors
for (const auto &[key, reduction] : new_neighbors)
neighbors[key] = reduction;

// update results
m_current -= min_reduction;
Expand Down

0 comments on commit d31ae17

Please sign in to comment.