Skip to content

Commit

Permalink
use pair instead of set
Browse files Browse the repository at this point in the history
  • Loading branch information
elbersb committed Sep 19, 2023
1 parent f0803fe commit cc52d5e
Showing 1 changed file with 34 additions and 38 deletions.
72 changes: 34 additions & 38 deletions src/compression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,16 @@ double calculate_reduction(double n, std::vector<double> &unit1, std::vector<dou
return n_total / n * m_total;
}

// Key type for the neighbors map: a pair of unit names that is treated as
// unordered. Keys are always stored in canonical form (first <= second), so
// (a, b) and (b, a) refer to the same map entry.
using t_neighbor = std::pair<std::string, std::string>;

// Builds the canonical neighbor key for the units `a` and `b`.
//
// The lexicographically smaller name is placed in `first` and the other one
// in `second`, so the argument order does not matter. If both names are
// equal, the result simply holds that name twice.
//
// The arguments are taken by value and moved into the result, so each name
// is copied at most once per call (when the caller passes an lvalue).
t_neighbor neighbor_make_pair(std::string a, std::string b)
{
    if (a < b)
        return t_neighbor(std::move(a), std::move(b));
    return t_neighbor(std::move(b), std::move(a));
}

// [[Rcpp::export]]
List compress_compute_cpp(
std::string neighbors_option,
Expand All @@ -151,20 +161,21 @@ List compress_compute_cpp(
double n_total, m_total;
std::tie(n_total, m_total) = calculate_m(data);

// prepare neighbors data structure (list of sets with 2 elements each)
std::map<std::set<std::string>, double> neighbors;
// prepare neighbors data structure (list of pairs, where order is unimportant)
std::map<t_neighbor, double> neighbors;
if (neighbors_option == "all")
{
for (int row = 0; row < unit_names.size(); row++)
{
for (int col = row + 1; col < unit_names.size(); col++)
{
neighbors[{unit_names[row], unit_names[col]}] = 0;
neighbors[neighbor_make_pair(unit_names[row], unit_names[col])] = 0;
}
}
}
else if (neighbors_option == "local")
{

auto ls = calculate_ls(data);
for (int i = 0; i < ls.size(); i++)
{
Expand All @@ -173,7 +184,7 @@ List compress_compute_cpp(
j++)
{
if (i != j)
neighbors[{ls[i].first, ls[j].first}] = 0;
neighbors[neighbor_make_pair(ls[i].first, ls[j].first)] = 0;
}
}
}
Expand All @@ -184,7 +195,7 @@ List compress_compute_cpp(
std::string unit1 = Rcpp::as<std::string>(m_neighbors(i, 0));
std::string unit2 = Rcpp::as<std::string>(m_neighbors(i, 1));
if (unit1 != unit2)
neighbors[{unit1, unit2}] = 0;
neighbors[neighbor_make_pair(unit1, unit2)] = 0;
}
}

Expand All @@ -193,10 +204,7 @@ List compress_compute_cpp(
// might do a lot of duplicate calculations)
for (const auto &[key, reduction] : neighbors)
{
auto iter = key.begin();
const std::string unit1 = *iter;
const std::string unit2 = *next(iter);
neighbors[{unit1, unit2}] = calculate_reduction(n_total, data[unit1], data[unit2]);
neighbors[neighbor_make_pair(key.first, key.second)] = calculate_reduction(n_total, data[key.first], data[key.second]);
}

// determine maximum number of iterations
Expand Down Expand Up @@ -225,7 +233,7 @@ List compress_compute_cpp(

// find smallest reduction
double min_reduction = 10000;
std::set<std::string> min_key;
t_neighbor min_key;
for (const auto &[key, reduction] : neighbors)
{
if (reduction < min_reduction)
Expand All @@ -237,48 +245,36 @@ List compress_compute_cpp(
}
}

auto iter = min_key.begin();
const std::string unit_keep = *iter;
const std::string unit_delete = *next(iter);
const std::string unit_keep = min_key.first;
const std::string unit_delete = min_key.second;
// add counts of 'delete' to 'keep', delete unit
for (int i = 0; i < n_groups; i++)
data[unit_keep][i] += data[unit_delete][i];
data.erase(unit_delete);

// update neighbors
std::vector<std::set<std::string>> delete_neighbors;
std::map<std::set<std::string>, double> new_neighbors;
neighbors.erase(min_key);

std::vector<t_neighbor> delete_neighbors;
std::map<t_neighbor, double> new_neighbors;
for (const auto &[key, reduction] : neighbors)
{
const bool delete_found = key.find(unit_delete) != key.end();
const bool keep_found = key.find(unit_keep) != key.end();
if (keep_found && delete_found)
// update pairs if deleted unit is involved
if (key.first == unit_delete || key.second == unit_delete)
{
// remove the neighbor pair
// this is a pair some_unit - deleted_unit -- replace deleted_unit with unit_keep
delete_neighbors.push_back(key);
std::string some_unit = (key.first == unit_delete) ? key.second : key.first;
new_neighbors[neighbor_make_pair(unit_keep, some_unit)] =
calculate_reduction(n_total, data[unit_keep], data[some_unit]);
}
else if (delete_found)
// recalculate reduction if kept unit is involved
else if (key.first == unit_keep || key.second == unit_keep)
{
// replace deleted unit with new unit
delete_neighbors.push_back(key);
std::set<std::string> new_key(key);
new_key.erase(unit_delete);
new_key.insert(unit_keep);

const auto iter = new_key.begin();
const std::string unit1 = *iter;
const std::string unit2 = *next(iter);
new_neighbors[{unit1, unit2}] = calculate_reduction(n_total, data[unit1], data[unit2]);
}
else if (keep_found)
{
// recalculate if updated unit is involved
const auto iter = key.begin();
const std::string unit1 = *iter;
const std::string unit2 = *next(iter);
new_neighbors[{unit1, unit2}] = calculate_reduction(n_total, data[unit1], data[unit2]);
new_neighbors[key] = calculate_reduction(n_total, data[key.first], data[key.second]);
}
}

// delete neighbors that involve the old unit
for (int i = 0; i < delete_neighbors.size(); i++)
neighbors.erase(delete_neighbors[i]);
Expand Down

0 comments on commit cc52d5e

Please sign in to comment.