-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge from ale-data-structure into main
Merge branch 'ale-data-structure' # Conflicts: # R/validation.R # tests/testthat/_snaps/model_bootstrap.md # tests/testthat/test-ALEPlot.R
- Loading branch information
Showing
82 changed files
with
103,190 additions
and
10,510 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,4 @@ ale.Rproj | |
docs | ||
inst/doc | ||
scholar | ||
R/refALEPlot.R |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,21 @@ | ||
Package: ale | ||
Title: Interpretable Machine Learning and Statistical Inference with Accumulated Local Effects (ALE) | ||
Version: 0.3.0.20240426 | ||
Version: 0.3.0.20240823 | ||
Authors@R: c( | ||
person("Chitu", "Okoli", , "[email protected]", role = c("aut", "cre"), | ||
comment = c(ORCID = "0000-0001-5574-7572")), | ||
person("Dan", "Apley", role = "cph", comment = "The current code for calculating ALE interaction values is copied with few changes from Dan Apley's ALEPlot package. We gratefully acknowledge his open-source contribution. However, he was not directly involved in the development of this ale package.") | ||
comment = c(ORCID = "0000-0001-5574-7572")) | ||
) | ||
Description: Accumulated Local Effects (ALE) were initially developed as a model-agnostic approach for global explanations of the results of black-box machine learning algorithms. ALE has a key advantage over other approaches like partial dependency plots (PDP) and SHapley Additive exPlanations (SHAP): its values represent a clean functional decomposition of the model. As such, ALE values are not affected by the presence or absence of interactions among variables in a mode. Moreover, its computation is relatively rapid. This package rewrites the original code from the 'ALEPlot' package for calculating ALE data and it completely reimplements the plotting of ALE values. It also extends the original ALE concept to add bootstrap-based confidence intervals and ALE-based statistics that can be used for statistical inference. For more details, see Okoli, Chitu. 2023. “Statistical Inference Using Machine Learning and Classical Techniques Based on Accumulated Local Effects (ALE).” arXiv. <arXiv:2310.09877>. <doi:10.48550/arXiv.2310.09877>. | ||
License: GPL-2 | ||
Description: Accumulated Local Effects (ALE) were initially developed as a model-agnostic approach for global explanations of the results of black-box machine learning algorithms. ALE has a key advantage over other approaches like partial dependency plots (PDP) and SHapley Additive exPlanations (SHAP): its values represent a clean functional decomposition of the model. As such, ALE values are not affected by the presence or absence of interactions among variables in a mode. Moreover, its computation is relatively rapid. This package reimplements the algorithms for calculating ALE data and develops highly interpretable visualizations for plotting these ALE values. It also extends the original ALE concept to add bootstrap-based confidence intervals and ALE-based statistics that can be used for statistical inference. For more details, see Okoli, Chitu. 2023. “Statistical Inference Using Machine Learning and Classical Techniques Based on Accumulated Local Effects (ALE).” arXiv. <arXiv:2310.09877>. <doi:10.48550/arXiv.2310.09877>. | ||
License: MIT + file LICENSE | ||
Language: en-ca | ||
Encoding: UTF-8 | ||
Roxygen: list(markdown = TRUE) | ||
RoxygenNote: 7.3.1 | ||
RoxygenNote: 7.3.2 | ||
Suggests: | ||
ALEPlot, | ||
knitr, | ||
mgcv, | ||
patchwork, | ||
nnet, | ||
readr, | ||
rmarkdown, | ||
testthat (>= 3.0.0) | ||
|
@@ -32,6 +31,7 @@ Imports: | |
grDevices, | ||
insight, | ||
labeling, | ||
patchwork, | ||
progressr, | ||
purrr, | ||
rlang, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
YEAR: 2024 | ||
COPYRIGHT HOLDER: ale authors |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
# ALEPlot.R | ||
# Functions for compatibility with the ALEPlot package | ||
|
||
|
||
# For missing (NA) cells in 2D interactions, replace delta_pred (dp) with the nearest valid neighbour | ||
nn_na_delta_pred <- function(dp, xd) { | ||
# Hack to silence R-CMD-CHECK | ||
knnIndexDist <- NULL | ||
|
||
# nn_na_delta_pred <- function(dp, numeric_x1) { | ||
x1_ceilings <- xd[[1]]$ceilings | ||
x2_ceilings <- xd[[2]]$ceilings | ||
|
||
# na_delta: xd[[1]]$n_bins by xd[[2]]$n_bins matrix with missing values TRUE | ||
# na_delta_idx: long matrix with row, col columns indicating indices of missing delta values | ||
na_delta <- is.na(dp) | ||
na_delta_idx <- which(na_delta, arr.ind = TRUE, useNames = TRUE) | ||
|
||
if (nrow(na_delta_idx) > 0) { | ||
# not_na_delta_idx: long matrix with row, col columns indicating indices WITHOUT missing delta values | ||
not_na_delta_idx <- which(!na_delta, arr.ind = TRUE, useNames = TRUE) | ||
|
||
range_x1 <- if (xd[[1]]$x_type == 'numeric') { | ||
# range_x1 <- if (numeric_x1) { | ||
max(x1_ceilings) - min(x1_ceilings) | ||
} else { | ||
xd[[1]]$n_bins - 1 | ||
} | ||
range_x2 <- max(x2_ceilings) - min(x2_ceilings) | ||
|
||
# Data Values of na_delta_idx and not_na_delta_idx, but normalized according to ALEPlot formulas | ||
if (xd[[1]]$x_type == 'numeric') { | ||
# if (numeric_x1) { | ||
norm_na_delta <- cbind( | ||
(x1_ceilings[na_delta_idx[, 1]] + x1_ceilings[na_delta_idx[, 1]+1]) / 2 / range_x1, | ||
(x2_ceilings[na_delta_idx[, 2]] + x2_ceilings[na_delta_idx[, 2]+1]) / 2 / range_x2 | ||
) | ||
norm_not_na_delta <- cbind( | ||
(x1_ceilings[not_na_delta_idx[, 1]] + x1_ceilings[not_na_delta_idx[, 1]+1]) / 2 / range_x1, | ||
(x2_ceilings[not_na_delta_idx[, 2]] + x2_ceilings[not_na_delta_idx[, 2]+1]) / 2 / range_x2 | ||
) | ||
} else { | ||
norm_na_delta <- cbind( | ||
na_delta_idx[, 1] / range_x1, | ||
na_delta_idx[, 2] / range_x2 | ||
) | ||
norm_not_na_delta <- cbind( | ||
not_na_delta_idx[, 1] / range_x1, | ||
not_na_delta_idx[, 2] / range_x2 | ||
) | ||
} | ||
|
||
|
||
# # Data Values of na_delta_idx and not_na_delta_idx, but normalized according to ALEPlot formulas | ||
# norm_na_delta <- cbind( | ||
# if (numeric_x1) { | ||
# (x1_ceilings[na_delta_idx[, 1]] + x1_ceilings[na_delta_idx[, 1]+1]) / 2 / range_x1 | ||
# } else { | ||
# na_delta_idx[, 1] / range_x1 | ||
# }, | ||
# (x2_ceilings[na_delta_idx[, 2]] + x2_ceilings[na_delta_idx[, 2]+1]) / 2 / range_x2 | ||
# ) | ||
# norm_not_na_delta <- cbind( | ||
# if (numeric_x1) { | ||
# (x1_ceilings[not_na_delta_idx[, 1]] + x1_ceilings[not_na_delta_idx[, 1]+1]) / 2 / range_x1 | ||
# } else { | ||
# not_na_delta_idx[, 1] / range_x1 | ||
# }, | ||
# (x2_ceilings[not_na_delta_idx[, 2]] + x2_ceilings[not_na_delta_idx[, 2]+1]) / 2 / range_x2 | ||
# ) | ||
|
||
|
||
if (any(is.na(norm_not_na_delta)) || any(is.na(norm_na_delta))) { | ||
closeAllConnections() | ||
browser() | ||
} | ||
|
||
# Use yaImpute::ann() for fast nearest non-NA neighbours of NA cells (consistency with ALEPlot) | ||
na_nbrs <- yaImpute::ann( | ||
norm_not_na_delta, | ||
norm_na_delta, | ||
k = 1, | ||
verbose = FALSE | ||
) |> | ||
(`$`)(knnIndexDist) |> | ||
(`[`)(, 1) | ||
|
||
# drop = FALSE needed to prevent occasionally collapsing into a vector | ||
dp[na_delta_idx] <- dp[not_na_delta_idx[na_nbrs,], drop = FALSE] | ||
|
||
# # Adapted note from ALEPlot: "The matrix() command is needed, because if there is only one empty cell, not_na_delta_idx[na_nbrs] is created as a 2-length vector instead of a 1x2 matrix, which does not index dp properly" | ||
# dp[na_delta_idx] <- dp[matrix(not_na_delta_idx[na_nbrs,], ncol = 2)] | ||
} # end if (nrow(na_delta_idx) > 0) | ||
|
||
dp | ||
} | ||
|
||
|
||
#' Sorted categorical indices based on Kolmogorov-Smirnov distances for empirically ordering categorical categories. | ||
#' | ||
#' @param X X data | ||
#' @param x_col character | ||
#' @param n_bins integer | ||
#' @param x_int_counts bin sizes | ||
idxs_kolmogorov_smirnov <- function( | ||
X, | ||
x_col, | ||
n_bins, | ||
x_int_counts | ||
) { | ||
|
||
# Initialize distance matrices between pairs of intervals of X[[x_col]] | ||
dist_mx <- matrix(0, n_bins, n_bins) | ||
cdm <- matrix(0, n_bins, n_bins) # cumulative distance matrix | ||
|
||
# Calculate distance matrix for each of the other X columns | ||
for (j_col in setdiff(names(X), x_col)) { | ||
if (var_type(X[[j_col]]) == 'numeric') { # distance matrix for numeric j_col | ||
# list of ECDFs for X[[j_col]] by intervals of X[[x_col]] | ||
x_by_j_ecdf <- tapply( | ||
X[[j_col]], | ||
X[[x_col]], | ||
stats::ecdf | ||
) | ||
|
||
# quantiles of X[[j_col]] for all intervals of X[[x_col]] combined | ||
j_quantiles <- stats::quantile( | ||
X[[j_col]], | ||
probs = seq(0, 1, length.out = 100), | ||
na.rm = TRUE, | ||
names = FALSE | ||
) | ||
|
||
for (i in 1:(n_bins - 1)) { | ||
for (k in (i + 1):n_bins) { | ||
# Kolmogorov-Smirnov distance between X[[j_col]] for intervals i and k of X[[x_col]]; always within [0, 1] | ||
dist_mx[i, k] <- (x_by_j_ecdf[[i]](j_quantiles) - | ||
x_by_j_ecdf[[k]](j_quantiles)) |> | ||
abs() |> | ||
max() | ||
# dist_mx[i, k] <- max(abs(x_by_j_ecdf[[i]](j_quantiles) - | ||
# x_by_j_ecdf[[k]](j_quantiles))) | ||
dist_mx[k, i] <- dist_mx[i, k] | ||
} | ||
} | ||
} | ||
else { # distance matrix for non-numeric j_col | ||
x_j_freq <- table(X[[x_col]], X[[j_col]]) #frequency table, rows of which will be compared | ||
x_j_freq <- x_j_freq / as.numeric(x_int_counts) | ||
for (i in 1:(n_bins-1)) { | ||
for (k in (i+1):n_bins) { | ||
# Dissimilarity measure always within [0, 1] | ||
dist_mx[i, k] <- sum(abs(x_j_freq[i, ] - | ||
x_j_freq[k, ])) / 2 | ||
dist_mx[k, i] <- dist_mx[i, k] | ||
} | ||
} | ||
} | ||
|
||
cdm <- cdm + dist_mx | ||
|
||
} | ||
|
||
# Replace any NA with the maximum distance | ||
cdm[is.na(cdm)] <- max(cdm, na.rm = TRUE) | ||
|
||
# Convert cumulative distance matrix to sorted indices | ||
idxs <- cdm |> | ||
stats::cmdscale(k = 1) |> # one-dimensional MDS representation of dist_mx | ||
sort(index.return = TRUE) |> | ||
(`[[`)('ix') | ||
|
||
|
||
return(idxs) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.