Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LOO Difference Plot #178

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions R/loo_difference_plot.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
#' Compare models across domains
#'
#' The LOO difference plot shows how the ELPD of two different models
#' changes when a predictor is varied. This can is useful for identifying
#' opportunities for model stacking or expansion.
#'
#' @param y A vector of observations. See Details.
#' @param psis_object_1,psis_object_2 If using loo version 2.0.0 or greater,
#' an object returned by the `[loo::psis()]` function (or by the
#' `[loo::loo()]` function with argument `save_psis` set to `TRUE`).
#' @param ... Currently unused.
#' @param group A grouping variable (a vector or factor) the same length
#' as `y`. Each value in group is interpreted as the group level pertaining
#' to the corresponding value of `y`. If `FALSE`, ignored.
#' @param outlier_thresh Flag values when the difference in the ELPD exceeds
#' this threshold. Defaults to `NULL`, in which case no values are flagged.
#' @param size,alpha,jitter Passed to `[ggplot2::geom_point()]` to control
#' aesthetics. `size` and `alpha` are passed to to the `size` and `alpha`
#' arguments of `[ggplot2::geom_jitter()]` to control the appearance of
#' points. `jitter` can be either a number or a vector of numbers.
#' Passing a single number will jitter variables along the x axis only, while
#' passing a vector will jitter along both axes.
#' @param sort_by_group Sort observations by `group`, then plot against an
#' arbitrary index. Plotting by index can be useful when categories have
#' very different sample sizes.
#'
#'
#' @template return-ggplot
#'
#' @template reference-vis-paper
#'
#' @examples
#'
#' library(loo)
#'
#' cbPalette <- c("#636363", "#E69F00", "#56B4E9", "#009E73",
#' "#F0E442", "#0072B2","#CC79A7")
#'
#' # Plot using groups from WHO
#'
#' plot_loo_dif(factor(GM@data$super_region_name), loo3, loo2,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like here in the doc it's plot_loo_dif but in the code it's plot_loo_variation.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now none of these examples run because the data GM and loo objects (loo2, loo3) don't exist. For the Examples section in the doc we'll need to change this to a self contained example or add all this data to the loo package itself so it can be used in the example. One possibility is to use a toy example here in the doc and then potentially add a more real example (e.g. this one from the paper) in one of the package vignettes. I'll think a bit more about this too.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is the only problem that's left unresolved; as I mentioned in the email, I'd like to see if it's possible to add this data to the LOO package, since I think the example here is really great.

#' group = GM@data$super_region_name, alpha = .5,
#' jitter = c(.45, .2)
#' ) +
#' xlab("Region") + scale_colour_manual(values=cbPalette)
#'
#' # Plot using groups identified with clustering
#'
#' plot_loo_dif(factor(GM@data$cluster_region), loo3, loo2,
#' group = GM@data$super_region_name, alpha = .5,
#' jitter = c(.45, .2)
#' ) +
#' xlab("Cluster Group") + scale_colour_manual(values=cbPalette)
#'
#' # Plot using an index variable to reduce crowding
#'
#' plot_loo_dif(1:2980, loo3, loo2, group = GM@data$super_region_name,
#' alpha = .5, sort_by_group = TRUE,
#' ) +
#' xlab("Index") + scale_colour_manual(values=cbPalette)
#'
#'
plot_loo_dif <-
function(y,
psis_object_1,
psis_object_2,
...,
group = NULL,
outlier_thresh = NULL,
size = 1,
alpha = 1,
jitter = 0,
sort_by_group = FALSE
){

# Adding a 0 at the end lets users provide a single number as input.
# In this case, only horizontal jitter is applied.
jitter <- c(jitter, 0)

elpdDif <- psis_object_1$pointwise[, "elpd_loo"] -
psis_object_2$pointwise[, "elpd_loo"]


if (sort_by_group){
if (identical(group, NULL) || !identical(y, 1:length(y))){
stop("ERROR: sort_by_group should only be used for grouping categorical
variables, then plotting them with an arbitrary index. You can
create such an index using `1:length(data)`.
")
}

ordering <- order(group)
elpdDif <- elpdDif[ordering]
group <- group[ordering]

}


plot <- ggplot2::ggplot(mapping=aes(y, elpdDif)) +
ggplot2::geom_hline(yintercept=0) +
ggplot2::xlab(ifelse(sort_by_group, "y", "Index")) +
ggplot2::ylab(expression(ELPD[i][1] - ELPD[i][2])) +
ggplot2::labs(color = "Groups")



if (identical(group, FALSE)){
# Don't color by group if no groups are passed
plot <- plot +
ggplot2::geom_jitter(width = jitter[1], height = jitter[2],
alpha = alpha, size = size
)
}
else{
# If group is passed, use color
plot <- plot +
ggplot2::geom_jitter(aes(color = factor(group)),
width = jitter[1], height = jitter[2],
alpha = alpha, size = size
)
}

if (!identical(outlier_thresh, NULL)){
# Flag outliers
is_outlier <- elpdDif > outlier_thresh
index <- 1:length(y)
outlier_labs <- index[is_outlier]

plot <- plot + ggplot2::annotate("text",
x = y[is_outlier],
y = elpdDif[outlier_labs],
label = outlier_labs,
size = 4
)


}

return(plot)
}