Updated GBM ALE objects to respond to ggplot2 3.5.0 (#2).
tripartio committed Feb 5, 2024
1 parent bed304a commit dafe2d7
Showing 6 changed files with 17 additions and 16 deletions.
Binary file modified download/gbm.data_model.rds
Binary file not shown.
Binary file modified download/gbm_ale_ixn_link.rds
Binary file not shown.
Binary file modified download/gbm_ale_ixn_prob.rds
Binary file not shown.
Binary file modified download/gbm_ale_link.rds
Binary file not shown.
Binary file modified download/gbm_ale_prob.rds
Binary file not shown.
33 changes: 17 additions & 16 deletions vignettes/ale-ALEPlot.Rmd
@@ -226,8 +226,8 @@ data <-
Although gradient boosted trees generally perform quite well, they are rather slow. Rather than having you wait for it to run, the code here downloads a pretrained GBM model. However, the code used to generate it is provided in comments so that you can see it and run it yourself if you want to. Note that the model call is based on `data[,-c(3,4)]`, which drops the third and fourth variables (`fnlwgt` and `education`, respectively).

```{r gbm model}
- # To generate the code, uncomment the following lines.
- # But it is slow, so this vignette loads a pre-created model object.
+ # # To generate the code, uncomment the following lines.
+ # # But they are slow, so this vignette loads a pre-created model object.
# set.seed(0)
# gbm.data <- gbm(higher_income ~ ., data= data[,-c(3,4)],
# distribution = "bernoulli", n.trees=6000, shrinkage=0.02,
@@ -274,19 +274,18 @@ We display all the plots because it is easy to do so with the `{ale}` package bu

```{r ale one-way link, fig.width=7, fig.height=20}
# Custom predict function that returns log odds
- yhat <- function(object, newdata) {
- as.numeric(
- predict(object, newdata, n.trees = 6000,
- type="link") # return log odds
- )
+ yhat <- function(object, newdata, type) {
+ predict(object, newdata, type='link', n.trees = 6000) |> # return log odds
+ as.numeric()
}
# Generate ALE data for all variables
# # To generate the code, uncomment the following lines.
# # But it is slow, so this vignette loads a pre-created model object.
# gbm_ale_link <- ale(
- # data[,-c(3,4)], gbm.data,
+ # # data[,-c(3,4)], gbm.data,
+ # data, gbm.data,
# pred_fun = yhat,
# x_intervals = 500,
# rug_sample_size = 600, # technical issue: rug_sample_size must be > x_intervals + 1
@@ -307,13 +306,13 @@ Now we generate ALE data for all two-way interactions and then plot them. Again,
# # To generate the code, uncomment the following lines.
# # But it is slow, so this vignette loads a pre-created model object.
# gbm_ale_ixn_link <- ale_ixn(
- # data[,-c(3,4)], gbm.data,
+ # # data[,-c(3,4)], gbm.data,
+ # data, gbm.data,
# pred_fun = yhat,
# x_intervals = 500,
# rug_sample_size = 600, # technical issue: rug_sample_size must be > x_intervals + 1
# relative_y = 'zero', # compatibility with ALEPlot
# model_packages = 'gbm' # required for parallel processing
# )
# saveRDS(gbm_ale_ixn_link, file.choose())
gbm_ale_ixn_link <- url('https://github.com/Tripartio/ale/raw/main/download/gbm_ale_ixn_link.rds') |>
@@ -342,7 +341,7 @@ As we can see, the shapes of the plots are similar, but the y axes are more easi

```{r ale one-way prob, fig.width=7, fig.height=20}
# Custom predict function that returns predicted probabilities
- yhat <- function(object, newdata) {
+ yhat <- function(object, newdata, type) {
as.numeric(
predict(object, newdata, n.trees = 6000,
type="response") # return predicted probabilities
@@ -354,11 +353,12 @@ yhat <- function(object, newdata) {
# # To generate the code, uncomment the following lines.
# # But it is slow, so this vignette loads a pre-created model object.
# gbm_ale_prob <- ale(
- # data[,-c(3,4)], gbm.data,
+ # # data[,-c(3,4)], gbm.data,
+ # data, gbm.data,
# pred_fun = yhat,
# x_intervals = 500,
# rug_sample_size = 600, # technical issue: rug_sample_size must be > x_intervals + 1
- # model_packages = 'nnet' # required for parallel processing
+ # model_packages = 'gbm' # required for parallel processing
# )
# saveRDS(gbm_ale_prob, file.choose())
gbm_ale_prob <- url('https://github.com/Tripartio/ale/raw/main/download/gbm_ale_prob.rds') |>
@@ -371,10 +371,11 @@ gridExtra::grid.arrange(grobs = gbm_ale_prob$plots, ncol = 2)
Finally, we again generate two-way interactions, this time based on probabilities instead of on log odds. However, probabilities might not be the best choice for indicating interactions because, as we see from the rugs in the one-way ALE plots, the GBM model heavily concentrates its probabilities in the extremes near 0 and 1. Thus, the plots' suggestions of strong interactions are likely exaggerated. In this case, the log odds ALEs shown above are probably more relevant.

```{r ale ixn prob, fig.width=7, fig.height=5}
- # # To generate the code, uncomment the following lines.
- # # But it is slow, so this vignette loads a pre-created model object.
+ # To generate the code, uncomment the following lines.
+ # But it is slow, so this vignette loads a pre-created model object.
# gbm_ale_ixn_prob <- ale_ixn(
- # data[,-c(3,4)], gbm.data,
+ # # data[,-c(3,4)], gbm.data,
+ # data, gbm.data,
# pred_fun = yhat,
# x_intervals = 500,
# rug_sample_size = 600, # technical issue: rug_sample_size must be > x_intervals + 1
