Merge from ale-data-structure into main
Merge branch 'ale-data-structure'

# Conflicts:
#	R/validation.R
#	tests/testthat/_snaps/model_bootstrap.md
#	tests/testthat/test-ALEPlot.R
tripartio committed Nov 9, 2024
2 parents 80f946a + d5b3358 commit 2e8738f
Showing 82 changed files with 103,190 additions and 10,510 deletions.
3 changes: 2 additions & 1 deletion .Rbuildignore
@@ -14,5 +14,6 @@
^pkgdown$
^\.github$
^README\.Rmd$
^R/refALEPlot\.R$
^tests/testthat/test-ALEPlot\.R$
^vignettes/articles$
^tests/testthat/test-ALEPlot.R$
1 change: 1 addition & 0 deletions .gitignore
@@ -4,3 +4,4 @@ ale.Rproj
docs
inst/doc
scholar
R/refALEPlot.R
14 changes: 7 additions & 7 deletions DESCRIPTION
@@ -1,22 +1,21 @@
Package: ale
Title: Interpretable Machine Learning and Statistical Inference with Accumulated Local Effects (ALE)
Version: 0.3.0.20240426
Version: 0.3.0.20240823
Authors@R: c(
person("Chitu", "Okoli", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0001-5574-7572")),
person("Dan", "Apley", role = "cph", comment = "The current code for calculating ALE interaction values is copied with few changes from Dan Apley's ALEPlot package. We gratefully acknowledge his open-source contribution. However, he was not directly involved in the development of this ale package.")
comment = c(ORCID = "0000-0001-5574-7572"))
)
Description: Accumulated Local Effects (ALE) were initially developed as a model-agnostic approach for global explanations of the results of black-box machine learning algorithms. ALE has a key advantage over other approaches like partial dependency plots (PDP) and SHapley Additive exPlanations (SHAP): its values represent a clean functional decomposition of the model. As such, ALE values are not affected by the presence or absence of interactions among variables in a model. Moreover, its computation is relatively rapid. This package rewrites the original code from the 'ALEPlot' package for calculating ALE data and it completely reimplements the plotting of ALE values. It also extends the original ALE concept to add bootstrap-based confidence intervals and ALE-based statistics that can be used for statistical inference. For more details, see Okoli, Chitu. 2023. “Statistical Inference Using Machine Learning and Classical Techniques Based on Accumulated Local Effects (ALE).” arXiv. <arXiv:2310.09877>. <doi:10.48550/arXiv.2310.09877>.
License: GPL-2
Description: Accumulated Local Effects (ALE) were initially developed as a model-agnostic approach for global explanations of the results of black-box machine learning algorithms. ALE has a key advantage over other approaches like partial dependency plots (PDP) and SHapley Additive exPlanations (SHAP): its values represent a clean functional decomposition of the model. As such, ALE values are not affected by the presence or absence of interactions among variables in a model. Moreover, its computation is relatively rapid. This package reimplements the algorithms for calculating ALE data and develops highly interpretable visualizations for plotting these ALE values. It also extends the original ALE concept to add bootstrap-based confidence intervals and ALE-based statistics that can be used for statistical inference. For more details, see Okoli, Chitu. 2023. “Statistical Inference Using Machine Learning and Classical Techniques Based on Accumulated Local Effects (ALE).” arXiv. <arXiv:2310.09877>. <doi:10.48550/arXiv.2310.09877>.
License: MIT + file LICENSE
Language: en-ca
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
RoxygenNote: 7.3.2
Suggests:
ALEPlot,
knitr,
mgcv,
patchwork,
nnet,
readr,
rmarkdown,
testthat (>= 3.0.0)
@@ -32,6 +31,7 @@ Imports:
grDevices,
insight,
labeling,
patchwork,
progressr,
purrr,
rlang,
2 changes: 2 additions & 0 deletions LICENSE
@@ -0,0 +1,2 @@
YEAR: 2024
COPYRIGHT HOLDER: ale authors
357 changes: 21 additions & 336 deletions LICENSE.md

Large diffs are not rendered by default.

25 changes: 22 additions & 3 deletions NAMESPACE
@@ -1,23 +1,42 @@
# Generated by roxygen2: do not edit by hand

S3method(plot,ale)
S3method(plot,ale_boot)
S3method(plot,ale_plots)
S3method(print,ale)
S3method(print,ale_plots)
export(ale)
export(ale_ixn)
export(create_p_funs)
export(aucroc)
export(create_p_dist)
export(mad)
export(mae)
export(model_bootstrap)
export(rmse)
export(standardized_accuracy)
export(win_mae)
export(win_rmse)
export(winsorize)
import(dplyr)
import(ggplot2)
importFrom(cli,cli_abort)
importFrom(cli,cli_alert_danger)
importFrom(cli,cli_alert_info)
importFrom(cli,cli_warn)
importFrom(purrr,compact)
importFrom(purrr,imap)
importFrom(purrr,list_transpose)
importFrom(purrr,map)
importFrom(purrr,map2)
importFrom(purrr,map2_dbl)
importFrom(purrr,map_chr)
importFrom(purrr,map_dbl)
importFrom(purrr,pluck)
importFrom(purrr,set_names)
importFrom(purrr,transpose)
importFrom(purrr,walk)
importFrom(rlang,.data)
importFrom(rlang,`:=`)
importFrom(rlang,is_bool)
importFrom(rlang,is_scalar_logical)
importFrom(rlang,is_string)
importFrom(stats,median)
importFrom(stats,quantile)
40 changes: 33 additions & 7 deletions NEWS.md
@@ -2,31 +2,57 @@

## Breaking changes

* We have deeply rethought how best to structure the objects for this package. As a result, the underlying algorithm for calculating ALE has been completely rewritten to be more scalable.
* In addition to rewriting the code under the hood, we have completely redesigned the structure of all ale objects. The new objects are not compatible with those from earlier versions. However, the new structure supports the roadmap of future functionality, so we hope that future changes that break backward compatibility will be minimal.
* We have created several S3 objects to represent different kinds of ale package objects:
* `ale`: the core `ale` package object that holds the results of the [ale()] function.
* `ale_boot`: results of the [model_bootstrap()] function.
* `ale_p`: p-value distribution information as the result of the [create_p_dist()] function.
* With the extensive rewrite, we no longer depend on {ALEPlot} code and so now claim full authorship of the code. One of the most significant implications of this is that we have decided to change the package license from the GPL 2 to MIT, which permits maximum dissemination of our algorithms.
* Renamed the `rug_sample_size` argument of ale() to `sample_size`. It now specifies the size of the sample of `data` stored in the `ale` object, which can be used not only for rug plots but also for other purposes.
* [ale_ixn()] has been eliminated and now both 1D and 2D ALE are calculated with the [ale()] function.
* [ale()] no longer produces plots. ALE plots are now created as `ale_plots` objects that generate all possible plots from the ALE data in `ale` or `ale_boot` objects. Thus, serializing `ale` objects avoids the environment bloat of embedded `ggplot` objects. (See the workflow sketch after this list.)
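
A minimal sketch of how the reworked workflow might fit together. Only ale(), `sample_size`, and the plot() and print() methods are named in these notes, so the data-then-model calling convention (carried over from ale 0.3.0), the GAM model, and the use of `mtcars` are illustrative assumptions rather than code from this release:

```r
library(ale)
library(mgcv)

# Any supported model will do; a GAM on mtcars is used purely for illustration.
gam_cars <- gam(mpg ~ s(wt) + s(hp) + am, data = mtcars)

# One call now computes the ALE data; since ale_ixn() has been removed,
# both 1D and 2D ALE come from ale(). `sample_size` is the renamed argument
# (formerly rug_sample_size).
cars_ale <- ale(mtcars, gam_cars, sample_size = 500)

# Plotting is a separate step: plot() on the ale object is assumed here to
# return an `ale_plots` object, whose print() method displays the plots.
cars_plots <- plot(cars_ale)
print(cars_plots)
```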


## Bug fixes

* Gracefully fails when the input data has missing values.

## Other user-visible changes

* `print()` and `plot()` methods have been added to the `ale_plots` object.
* A `print()` method has been added to the `ale` object.
* Interactions are now supported between pairs of categorical variables. (Before, only numerical pairs or pairs with one numerical and one categorical were supported.)
* Bootstrapping is now supported for ALE interactions.
* ALE statistics are now supported for interactions.
* Categorical y outcomes are now supported. The plots, though, only plot one category at a time.
* `boot_data` is now an output option from ale(). It outputs the ALE values from each bootstrap iteration.
* create_p_funs now produces two types of p-value via the `p_val_type` argument: 'approx fast' for relatively faster but only approximate values (the default) or 'precise slow' for very slow but more exact values.
* model_bootstrap() now provides various model performance measures that are validated using bootstrap validation with the .632 correction.
* The structure of `p_funs` has been completely changed; it has now been converted to an object named `ale_p` and the functions are separated from the object as internal functions. The function create_p_funs() has been renamed create_p_dist().
* create_p_dist() now produces two types of p-values via the `p_speed` argument: 'approx fast' for relatively faster but only approximate values (the default) or 'precise slow' for very slow but more exact values. (See the sketch after this list.)
* Character input data is now accepted as a categorical datatype. It is handled the same as unordered factors.
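
A hedged sketch of the assumed p-value and bootstrap-validation workflow. Only create_p_dist(), its `p_speed` argument, and model_bootstrap() are named in these notes; the argument order and the `p_values` argument passed to ale() are assumptions for illustration only:

```r
library(ale)
library(mgcv)

gam_cars <- gam(mpg ~ s(wt) + s(hp) + am, data = mtcars)

# Build the random-variable p-value distribution (formerly create_p_funs(),
# now returning an `ale_p` object); 'approx fast' is the default speed.
cars_p <- create_p_dist(mtcars, gam_cars, p_speed = 'approx fast')

# The ale_p object is assumed to be supplied to ale() (argument name
# illustrative) so that ALE statistics can report p-values.
cars_ale <- ale(mtcars, gam_cars, p_values = cars_p)

# Bootstrap-validated performance measures with the .632 correction
# (calling convention assumed).
cars_boot <- model_bootstrap(mtcars, gam_cars)
```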

## Under the hood

One of the most fundamental changes is not directly visible but affects how some ALE values are calculated. In certain very specific cases, the ALE values are now slightly different from those of the reference `ALEPlot` package. These differences occur only for non-numeric variables with prediction types that are not scaled on the response variable (e.g., a binary or categorical variable with a logarithmic prediction that is not on the same scale as the response variable). We made this change for two reasons:
* We can understand our implementation and its interpretation for these edge cases much better than that of the reference `ALEPlot` implementation. These cases are not covered at all in the base ALE scientific article and they are poorly documented in the `ALEPlot` code. We cannot help users to interpret results that we do not understand ourselves.
* Our implementation lets us write code that scales smoothly for interactions of arbitrary depth. In contrast, the `ALEPlot` reference implementation is not scalable: custom code must be written for each type and each degree of interaction.
Other than for these edge cases, our implementation continues to give identical results to the reference `ALEPlot` package.

Other notable changes that might not be readily visible to users:
* Reduced dependencies by doing more with the `{rlang}` and `{cli}` packages. Reduced the imported functions to a minimum.
* Messages now use `{cli}`.
* Package messages, warnings, and errors now use `{cli}`.
* Replaced `{assertthat}` with custom validation functions that adapt some `{assertthat}` code.
* Use helper.R test files so that some testing objects are available to the loaded package.
* Configure `{future}` parallelization code to restore original values on exit.
* Increased memory efficiency of p_funs objects.
* Configured `{future}` parallelization code to restore original values on exit.
* Configured code that uses a random seed to restore the original seed on exit. (See the sketch after this list.)
* Improved memory efficiency of `ale_p` objects.
* Plotting code updated for compatibility with ggplot2 3.5.
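
The seed-restoration bullet describes a general R pattern rather than a user-facing API. The following sketch is purely illustrative of that pattern and is not the package's internal code:

```r
# Run expr with a temporary seed, then restore the caller's random stream.
with_temporary_seed <- function(seed, expr) {
  old_seed <- if (exists(".Random.seed", envir = globalenv(), inherits = FALSE)) {
    get(".Random.seed", envir = globalenv())
  } else {
    NULL
  }
  on.exit({
    if (is.null(old_seed)) {
      rm(".Random.seed", envir = globalenv())
    } else {
      assign(".Random.seed", old_seed, envir = globalenv())
    }
  })
  set.seed(seed)
  expr  # lazily evaluated here, after the temporary seed is set
}

# The caller's random stream is unchanged after the call:
with_temporary_seed(42, runif(3))
```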

## Known issues to be addressed in a future version

- Bootstrapping is not yet supported for ALE interactions (`ale_ixn()`).
- ALE statistics are not yet supported for ALE interactions (`ale_ixn()`).
- `ale()` does not yet support multi-output model prediction types (e.g., multi-class classification and multi-time survival probabilities).
- Plots that display all categories of a categorical outcome on a single plot are yet to be implemented. For now, each class or category must be plotted one at a time.
- Effects plots for interactions have not yet been implemented.


# ale 0.3.0
175 changes: 175 additions & 0 deletions R/ALEPlot.R
@@ -0,0 +1,175 @@
# ALEPlot.R
# Functions for compatibility with the ALEPlot package


# For missing (NA) cells in 2D interactions, replace delta_pred (dp) with the nearest valid neighbour
nn_na_delta_pred <- function(dp, xd) {
# Hack to silence R-CMD-CHECK
knnIndexDist <- NULL

# nn_na_delta_pred <- function(dp, numeric_x1) {
x1_ceilings <- xd[[1]]$ceilings
x2_ceilings <- xd[[2]]$ceilings

# na_delta: xd[[1]]$n_bins by xd[[2]]$n_bins matrix with missing values TRUE
# na_delta_idx: long matrix with row, col columns indicating indices of missing delta values
na_delta <- is.na(dp)
na_delta_idx <- which(na_delta, arr.ind = TRUE, useNames = TRUE)

if (nrow(na_delta_idx) > 0) {
# not_na_delta_idx: long matrix with row, col columns indicating indices WITHOUT missing delta values
not_na_delta_idx <- which(!na_delta, arr.ind = TRUE, useNames = TRUE)

range_x1 <- if (xd[[1]]$x_type == 'numeric') {
# range_x1 <- if (numeric_x1) {
max(x1_ceilings) - min(x1_ceilings)
} else {
xd[[1]]$n_bins - 1
}
range_x2 <- max(x2_ceilings) - min(x2_ceilings)

# Data Values of na_delta_idx and not_na_delta_idx, but normalized according to ALEPlot formulas
if (xd[[1]]$x_type == 'numeric') {
# if (numeric_x1) {
norm_na_delta <- cbind(
(x1_ceilings[na_delta_idx[, 1]] + x1_ceilings[na_delta_idx[, 1]+1]) / 2 / range_x1,
(x2_ceilings[na_delta_idx[, 2]] + x2_ceilings[na_delta_idx[, 2]+1]) / 2 / range_x2
)
norm_not_na_delta <- cbind(
(x1_ceilings[not_na_delta_idx[, 1]] + x1_ceilings[not_na_delta_idx[, 1]+1]) / 2 / range_x1,
(x2_ceilings[not_na_delta_idx[, 2]] + x2_ceilings[not_na_delta_idx[, 2]+1]) / 2 / range_x2
)
} else {
norm_na_delta <- cbind(
na_delta_idx[, 1] / range_x1,
na_delta_idx[, 2] / range_x2
)
norm_not_na_delta <- cbind(
not_na_delta_idx[, 1] / range_x1,
not_na_delta_idx[, 2] / range_x2
)
}



if (any(is.na(norm_not_na_delta)) || any(is.na(norm_na_delta))) {
closeAllConnections()
browser()
}

# Use yaImpute::ann() for fast nearest non-NA neighbours of NA cells (consistency with ALEPlot)
na_nbrs <- yaImpute::ann(
norm_not_na_delta,
norm_na_delta,
k = 1,
verbose = FALSE
) |>
(`$`)(knnIndexDist) |>
(`[`)(, 1)

# drop = FALSE needed to prevent occasionally collapsing into a vector
dp[na_delta_idx] <- dp[not_na_delta_idx[na_nbrs,], drop = FALSE]

# # Adapted note from ALEPlot: "The matrix() command is needed, because if there is only one empty cell, not_na_delta_idx[na_nbrs] is created as a 2-length vector instead of a 1x2 matrix, which does not index dp properly"
# dp[na_delta_idx] <- dp[matrix(not_na_delta_idx[na_nbrs,], ncol = 2)]
} # end if (nrow(na_delta_idx) > 0)

dp
}
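
# Standalone toy illustration of the nearest-neighbour fill performed by
# nn_na_delta_pred() above. This hypothetical helper is not part of the package
# API and does not reproduce the internal `xd` bin structure; it simply fills
# NA cells of a plain delta matrix from their nearest non-NA neighbours with
# yaImpute::ann(), the same call used above.
demo_nn_na_fill <- function() {
  dp <- matrix(stats::rnorm(20), nrow = 4, ncol = 5)
  dp[c(3, 11, 18)] <- NA  # pretend these 2D ALE cells had no data

  na_idx     <- which(is.na(dp),  arr.ind = TRUE)
  not_na_idx <- which(!is.na(dp), arr.ind = TRUE)

  # Normalize row and column indices to [0, 1] so that both dimensions
  # contribute comparably to the nearest-neighbour distance.
  norm_na     <- cbind(na_idx[, 1]     / nrow(dp), na_idx[, 2]     / ncol(dp))
  norm_not_na <- cbind(not_na_idx[, 1] / nrow(dp), not_na_idx[, 2] / ncol(dp))

  # k = 1: index of the nearest non-NA cell for each NA cell
  nbrs <- yaImpute::ann(norm_not_na, norm_na, k = 1, verbose = FALSE)$knnIndexDist[, 1]

  # Copy each neighbour's delta value into the corresponding NA cell
  dp[na_idx] <- dp[not_na_idx[nbrs, , drop = FALSE]]
  dp
}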


#' Sorted categorical indices based on Kolmogorov-Smirnov distances for empirically ordering the categories of a categorical variable.
#'
#' @param X dataframe that includes `x_col` and the other columns used to compute distances
#' @param x_col character(1) name of the categorical column of `X` whose categories are to be ordered
#' @param n_bins integer(1) number of categories (bins) of `X[[x_col]]`
#' @param x_int_counts integer vector of the number of rows of `X` in each category of `X[[x_col]]`
idxs_kolmogorov_smirnov <- function(
X,
x_col,
n_bins,
x_int_counts
) {

# Initialize distance matrices between pairs of intervals of X[[x_col]]
dist_mx <- matrix(0, n_bins, n_bins)
cdm <- matrix(0, n_bins, n_bins) # cumulative distance matrix

# Calculate distance matrix for each of the other X columns
for (j_col in setdiff(names(X), x_col)) {
if (var_type(X[[j_col]]) == 'numeric') { # distance matrix for numeric j_col
# list of ECDFs for X[[j_col]] by intervals of X[[x_col]]
x_by_j_ecdf <- tapply(
X[[j_col]],
X[[x_col]],
stats::ecdf
)

# quantiles of X[[j_col]] for all intervals of X[[x_col]] combined
j_quantiles <- stats::quantile(
X[[j_col]],
probs = seq(0, 1, length.out = 100),
na.rm = TRUE,
names = FALSE
)

for (i in 1:(n_bins - 1)) {
for (k in (i + 1):n_bins) {
# Kolmogorov-Smirnov distance between X[[j_col]] for intervals i and k of X[[x_col]]; always within [0, 1]
dist_mx[i, k] <- (x_by_j_ecdf[[i]](j_quantiles) -
x_by_j_ecdf[[k]](j_quantiles)) |>
abs() |>
max()
# dist_mx[i, k] <- max(abs(x_by_j_ecdf[[i]](j_quantiles) -
# x_by_j_ecdf[[k]](j_quantiles)))
dist_mx[k, i] <- dist_mx[i, k]
}
}
}
else { # distance matrix for non-numeric j_col
x_j_freq <- table(X[[x_col]], X[[j_col]]) #frequency table, rows of which will be compared
x_j_freq <- x_j_freq / as.numeric(x_int_counts)
for (i in 1:(n_bins-1)) {
for (k in (i+1):n_bins) {
# Dissimilarity measure always within [0, 1]
dist_mx[i, k] <- sum(abs(x_j_freq[i, ] -
x_j_freq[k, ])) / 2
dist_mx[k, i] <- dist_mx[i, k]
}
}
}

cdm <- cdm + dist_mx

}

# Replace any NA with the maximum distance
cdm[is.na(cdm)] <- max(cdm, na.rm = TRUE)

# Convert cumulative distance matrix to sorted indices
idxs <- cdm |>
stats::cmdscale(k = 1) |> # one-dimensional MDS representation of the cumulative distance matrix cdm
sort(index.return = TRUE) |>
(`[[`)('ix')


return(idxs)
}
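
# Standalone toy illustration of the Kolmogorov-Smirnov ordering implemented by
# idxs_kolmogorov_smirnov() above. This hypothetical helper is not part of the
# package API and uses only one numeric column, whereas the real function
# accumulates distances across all other columns, numeric and categorical.
demo_ks_category_order <- function(X = iris, x_col = "Species", j_col = "Sepal.Length") {
  cats <- levels(X[[x_col]])
  n_bins <- length(cats)
  dist_mx <- matrix(0, n_bins, n_bins)

  # ECDF of the numeric column within each category, evaluated on a common
  # quantile grid of the whole column
  x_by_j_ecdf <- tapply(X[[j_col]], X[[x_col]], stats::ecdf)
  j_quantiles <- stats::quantile(X[[j_col]], probs = seq(0, 1, length.out = 100), names = FALSE)

  for (i in 1:(n_bins - 1)) {
    for (k in (i + 1):n_bins) {
      # Kolmogorov-Smirnov distance between categories i and k; always within [0, 1]
      dist_mx[i, k] <- max(abs(x_by_j_ecdf[[i]](j_quantiles) - x_by_j_ecdf[[k]](j_quantiles)))
      dist_mx[k, i] <- dist_mx[i, k]
    }
  }

  # One-dimensional MDS collapses the distance matrix to a single coordinate
  # per category; sorting that coordinate yields the empirical ordering.
  idxs <- sort(stats::cmdscale(dist_mx, k = 1)[, 1], index.return = TRUE)$ix
  cats[idxs]
}
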
3 changes: 1 addition & 2 deletions R/aaa.R
@@ -2,8 +2,7 @@
# Define package-wide environment variables

# Create a package-wide environment to hold objects shared across functions.
# Its current use is only for objects that
# need to be reused across random iterations for create_p_funs().
# Its current use is only for objects that need to be reused across random iterations for create_p_dist().
# https://r-pkgs.org/data.html#sec-data-state
package_scope <- new.env(parent = emptyenv())
