Skip to content

Commit

Permalink
move local seg code to C++
Browse files Browse the repository at this point in the history
  • Loading branch information
elbersb committed Sep 19, 2023
1 parent 1a6c3fb commit f0803fe
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 27 deletions.
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

compress_compute_cpp <- function(neighbors_option, m_neighbors, m_data, unit_names, max_iter) {
.Call(`_segregation_compress_compute_cpp`, neighbors_option, m_neighbors, m_data, unit_names, max_iter)
compress_compute_cpp <- function(neighbors_option, m_neighbors, n_neighbors, m_data, unit_names, max_iter) {
.Call(`_segregation_compress_compute_cpp`, neighbors_option, m_neighbors, n_neighbors, m_data, unit_names, max_iter)
}

get_crosswalk_cpp <- function(old_unit, new_unit) {
Expand Down
24 changes: 6 additions & 18 deletions R/compression.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ compress <- function(data, group, unit, weight = NULL,
if (!is.factor(d[[unit]]) && !is.character(d[[unit]])) {
warning("coercing unit ids to character")
d[[unit]] <- as.character(d[[unit]])
neighbors <- as.character(neighbors)
}

wide <- dcast(d, paste0(unit, "~", group), value.var = "freq", fill = 0)
Expand All @@ -53,24 +54,11 @@ compress <- function(data, group, unit, weight = NULL,
if (is.infinite(max_iter)) max_iter <- -1

if (is.data.frame(neighbors)) {
res <- compress_compute_cpp("df", as.matrix(neighbors), as.matrix(wide), units, max_iter)
} else if (is.character(neighbors) && neighbors == "all") {
res <- compress_compute_cpp("all", matrix(""), as.matrix(wide), units, max_iter)
} else if (is.character(neighbors) && neighbors == "local") {
ls <- mutual_local(d, group, unit, weight = "freq", wide = TRUE)
entropy <- d[, .(entropy = entropy(.SD, group, weight = "freq")), by = unit]
ls <- merge(ls, entropy)

setorder(ls, entropy)
neighbors <- lapply(2:(nrow(ls) - 1), function(u) {
focal <- ls[[unit]][u]
nb_before <- ls[[unit]][max(c(1, u - n_neighbors)):(u - 1)]
nb_after <- ls[[unit]][(u + 1):min(c(nrow(ls), u + n_neighbors))]
data.table(a = focal, b = c(nb_before, nb_after))
})
neighbors <- rbindlist(neighbors)

res <- compress_compute_cpp("local", as.matrix(neighbors), as.matrix(wide), units, max_iter)
res <- compress_compute_cpp("df", as.matrix(neighbors), -1, as.matrix(wide), units, max_iter)
} else if (neighbors == "all") {
res <- compress_compute_cpp("all", matrix(""), -1, as.matrix(wide), units, max_iter)
} else if (neighbors == "local") {
res <- compress_compute_cpp("local", matrix(""), n_neighbors, as.matrix(wide), units, max_iter)
}

iterations <- as.data.table(res)
Expand Down
9 changes: 5 additions & 4 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,18 @@ Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
#endif

// compress_compute_cpp
List compress_compute_cpp(std::string neighbors_option, StringMatrix m_neighbors, NumericMatrix m_data, std::vector<std::string> unit_names, int max_iter);
RcppExport SEXP _segregation_compress_compute_cpp(SEXP neighbors_optionSEXP, SEXP m_neighborsSEXP, SEXP m_dataSEXP, SEXP unit_namesSEXP, SEXP max_iterSEXP) {
List compress_compute_cpp(std::string neighbors_option, StringMatrix m_neighbors, int n_neighbors, NumericMatrix m_data, std::vector<std::string> unit_names, int max_iter);
RcppExport SEXP _segregation_compress_compute_cpp(SEXP neighbors_optionSEXP, SEXP m_neighborsSEXP, SEXP n_neighborsSEXP, SEXP m_dataSEXP, SEXP unit_namesSEXP, SEXP max_iterSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< std::string >::type neighbors_option(neighbors_optionSEXP);
Rcpp::traits::input_parameter< StringMatrix >::type m_neighbors(m_neighborsSEXP);
Rcpp::traits::input_parameter< int >::type n_neighbors(n_neighborsSEXP);
Rcpp::traits::input_parameter< NumericMatrix >::type m_data(m_dataSEXP);
Rcpp::traits::input_parameter< std::vector<std::string> >::type unit_names(unit_namesSEXP);
Rcpp::traits::input_parameter< int >::type max_iter(max_iterSEXP);
rcpp_result_gen = Rcpp::wrap(compress_compute_cpp(neighbors_option, m_neighbors, m_data, unit_names, max_iter));
rcpp_result_gen = Rcpp::wrap(compress_compute_cpp(neighbors_option, m_neighbors, n_neighbors, m_data, unit_names, max_iter));
return rcpp_result_gen;
END_RCPP
}
Expand All @@ -39,7 +40,7 @@ END_RCPP
}

static const R_CallMethodDef CallEntries[] = {
{"_segregation_compress_compute_cpp", (DL_FUNC) &_segregation_compress_compute_cpp, 5},
{"_segregation_compress_compute_cpp", (DL_FUNC) &_segregation_compress_compute_cpp, 6},
{"_segregation_get_crosswalk_cpp", (DL_FUNC) &_segregation_get_crosswalk_cpp, 2},
{NULL, NULL, 0}
};
Expand Down
62 changes: 59 additions & 3 deletions src/compression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,47 @@ std::tuple<double, double> calculate_m(std::map<std::string, std::vector<double>
return std::make_tuple(n_total, m_total);
}

/**
 * Compute a local segregation score for every unit and return the units
 * ordered from the lowest to the highest score.
 *
 * For a unit u the score is  sum_g p(g|u) * log(p(g|u) / p(g)),  where
 * p(g|u) is the share of group g inside the unit and p(g) is the share of
 * group g in the whole data set. Groups with a zero count in a unit
 * contribute nothing to that unit's score.
 *
 * @param data  Map from unit name to that unit's per-group counts. Every
 *              vector is assumed to be at least as long as the first
 *              entry's vector; this is not checked here.
 * @return      (unit name, score) pairs sorted ascending by score. Empty
 *              when `data` is empty. A unit whose counts are all zero gets
 *              a score of 0.
 */
std::vector<std::pair<std::string, double>> calculate_ls(std::map<std::string, std::vector<double>> &data)
{
    std::vector<std::pair<std::string, double>> ls;

    // Nothing to score; also avoids dereferencing begin() of an empty map.
    if (data.empty())
        return ls;

    // create group sums
    const std::size_t n_groups = data.begin()->second.size();
    std::vector<double> group_sums(n_groups, 0.0);
    for (const auto &[unit, counts] : data)
    {
        for (std::size_t i = 0; i < n_groups; i++)
            group_sums[i] += counts[i];
    }

    // create group proportions
    // The initial value must be a double (0.0): with an int initial value
    // std::accumulate keeps an int running sum and truncates fractional
    // (e.g. weighted) counts at every step.
    const double n_total = std::accumulate(group_sums.begin(), group_sums.end(), 0.0);
    std::vector<double> group_p(n_groups, 0.0);
    for (std::size_t i = 0; i < n_groups; i++)
    {
        group_p[i] = group_sums[i] / n_total;
    }

    // create local segregation scores for each unit
    ls.reserve(data.size());
    for (const auto &[unit, counts] : data)
    {
        const double n_unit = std::accumulate(counts.begin(), counts.end(), 0.0);
        double ls_unit = 0.0;
        for (std::size_t i = 0; i < n_groups; i++)
        {
            // A zero count means p(g|u) == 0 and contributes nothing. Testing
            // the count rather than the ratio also keeps an all-zero unit from
            // producing 0/0 = NaN, which would break the sort below.
            if (counts[i] == 0)
                continue;
            const double p_group_given_unit = counts[i] / n_unit;
            ls_unit += p_group_given_unit * std::log(p_group_given_unit / group_p[i]);
        }
        ls.emplace_back(unit, ls_unit);
    }

    std::sort(ls.begin(), ls.end(), [](const auto &left, const auto &right)
              { return left.second < right.second; });

    return ls;
}

double calculate_reduction(double n, std::vector<double> &unit1, std::vector<double> &unit2)
{
// create group sums
Expand Down Expand Up @@ -89,6 +130,7 @@ double calculate_reduction(double n, std::vector<double> &unit1, std::vector<dou
List compress_compute_cpp(
std::string neighbors_option,
StringMatrix m_neighbors,
int n_neighbors,
NumericMatrix m_data,
std::vector<std::string> unit_names,
int max_iter)
Expand Down Expand Up @@ -121,7 +163,21 @@ List compress_compute_cpp(
}
}
}
else if (neighbors_option == "df" || neighbors_option == "local")
else if (neighbors_option == "local")
{
auto ls = calculate_ls(data);
for (int i = 0; i < ls.size(); i++)
{
for (int j = std::max(i - n_neighbors, 0);
j < std::min(i + n_neighbors + 1, static_cast<int>(ls.size()) - 1);
j++)
{
if (i != j)
neighbors[{ls[i].first, ls[j].first}] = 0;
}
}
}
else if (neighbors_option == "df")
{
for (int i = 0; i < m_neighbors.nrow(); i++)
{
Expand Down Expand Up @@ -209,15 +265,15 @@ List compress_compute_cpp(
new_key.erase(unit_delete);
new_key.insert(unit_keep);

auto iter = new_key.begin();
const auto iter = new_key.begin();
const std::string unit1 = *iter;
const std::string unit2 = *next(iter);
new_neighbors[{unit1, unit2}] = calculate_reduction(n_total, data[unit1], data[unit2]);
}
else if (keep_found)
{
// recalculate if updated unit is involved
auto iter = key.begin();
const auto iter = key.begin();
const std::string unit1 = *iter;
const std::string unit2 = *next(iter);
new_neighbors[{unit1, unit2}] = calculate_reduction(n_total, data[unit1], data[unit2]);
Expand Down

0 comments on commit f0803fe

Please sign in to comment.