diff --git a/src/compression.cpp b/src/compression.cpp index 464c16b..a1f0dae 100644 --- a/src/compression.cpp +++ b/src/compression.cpp @@ -126,6 +126,16 @@ double calculate_reduction(double n, std::vector &unit1, std::vector t_neighbor; + +t_neighbor neighbor_make_pair(std::string a, std::string b) +{ + if (a < b) + return t_neighbor(a, b); + else + return t_neighbor(b, a); +} + // [[Rcpp::export]] List compress_compute_cpp( std::string neighbors_option, @@ -151,20 +161,21 @@ List compress_compute_cpp( double n_total, m_total; std::tie(n_total, m_total) = calculate_m(data); - // prepare neighbors data structure (list of sets with 2 elements each) - std::map, double> neighbors; + // prepare neighbors data structure (list of pairs, where order is unimportant) + std::map neighbors; if (neighbors_option == "all") { for (int row = 0; row < unit_names.size(); row++) { for (int col = row + 1; col < unit_names.size(); col++) { - neighbors[{unit_names[row], unit_names[col]}] = 0; + neighbors[neighbor_make_pair(unit_names[row], unit_names[col])] = 0; } } } else if (neighbors_option == "local") { + auto ls = calculate_ls(data); for (int i = 0; i < ls.size(); i++) { @@ -173,7 +184,7 @@ List compress_compute_cpp( j++) { if (i != j) - neighbors[{ls[i].first, ls[j].first}] = 0; + neighbors[neighbor_make_pair(ls[i].first, ls[j].first)] = 0; } } } @@ -184,7 +195,7 @@ List compress_compute_cpp( std::string unit1 = Rcpp::as(m_neighbors(i, 0)); std::string unit2 = Rcpp::as(m_neighbors(i, 1)); if (unit1 != unit2) - neighbors[{unit1, unit2}] = 0; + neighbors[neighbor_make_pair(unit1, unit2)] = 0; } } @@ -193,10 +204,7 @@ List compress_compute_cpp( // might do a lot of duplicate calculations) for (const auto &[key, reduction] : neighbors) { - auto iter = key.begin(); - const std::string unit1 = *iter; - const std::string unit2 = *next(iter); - neighbors[{unit1, unit2}] = calculate_reduction(n_total, data[unit1], data[unit2]); + neighbors[neighbor_make_pair(key.first, key.second)] = calculate_reduction(n_total, data[key.first], data[key.second]); } // determine maximum number of iterations @@ -225,7 +233,7 @@ List compress_compute_cpp( // find smallest reduction double min_reduction = 10000; - std::set min_key; + t_neighbor min_key; for (const auto &[key, reduction] : neighbors) { if (reduction < min_reduction) @@ -237,48 +245,36 @@ List compress_compute_cpp( } } - auto iter = min_key.begin(); - const std::string unit_keep = *iter; - const std::string unit_delete = *next(iter); + const std::string unit_keep = min_key.first; + const std::string unit_delete = min_key.second; // add counts of 'delete' to 'keep', delete unit for (int i = 0; i < n_groups; i++) data[unit_keep][i] += data[unit_delete][i]; data.erase(unit_delete); // update neighbors - std::vector> delete_neighbors; - std::map, double> new_neighbors; + neighbors.erase(min_key); + + std::vector delete_neighbors; + std::map new_neighbors; for (const auto &[key, reduction] : neighbors) { - const bool delete_found = key.find(unit_delete) != key.end(); - const bool keep_found = key.find(unit_keep) != key.end(); - if (keep_found && delete_found) + // update pairs if deleted unit is involved + if (key.first == unit_delete || key.second == unit_delete) { - // remove the neighbor pair + // this is a pair some_unit - deleted_unit -- replace deleted_unit with unit_keep delete_neighbors.push_back(key); + std::string some_unit = (key.first == unit_delete) ? key.second : key.first; + new_neighbors[neighbor_make_pair(unit_keep, some_unit)] = + calculate_reduction(n_total, data[unit_keep], data[some_unit]); } - else if (delete_found) + // recalculate reduction if kept unit is involved + else if (key.first == unit_keep || key.second == unit_keep) { - // replace deleted unit with new unit - delete_neighbors.push_back(key); - std::set new_key(key); - new_key.erase(unit_delete); - new_key.insert(unit_keep); - - const auto iter = new_key.begin(); - const std::string unit1 = *iter; - const std::string unit2 = *next(iter); - new_neighbors[{unit1, unit2}] = calculate_reduction(n_total, data[unit1], data[unit2]); - } - else if (keep_found) - { - // recalculate if updated unit is involved - const auto iter = key.begin(); - const std::string unit1 = *iter; - const std::string unit2 = *next(iter); - new_neighbors[{unit1, unit2}] = calculate_reduction(n_total, data[unit1], data[unit2]); + new_neighbors[key] = calculate_reduction(n_total, data[key.first], data[key.second]); } } + // delete neighbors that involve the old unit for (int i = 0; i < delete_neighbors.size(); i++) neighbors.erase(delete_neighbors[i]);