-
-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add LOO Difference Plot #178
base: master
Are you sure you want to change the base?
Changes from 6 commits
dca229b
8786be3
8e329d0
90ecb40
48b97b6
5399cb8
4d085d6
d4b8641
a0417dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
#' Compare models across domains | ||
#' | ||
#' The LOO difference plot shows how the ELPD of two different models | ||
#' changes when a predictor is varied. This can is useful for identifying | ||
#' opportunities for model stacking or expansion. | ||
#' | ||
#' @param y A vector of observations. See Details. | ||
#' @param psis_object_1,psis_object_2 If using loo version 2.0.0 or greater, | ||
#' an object returned by the `[loo::psis()]` function (or by the | ||
#' `[loo::loo()]` function with argument `save_psis` set to `TRUE`). | ||
#' @param ... Currently unused. | ||
#' @param group A grouping variable (a vector or factor) the same length | ||
#' as `y`. Each value in group is interpreted as the group level pertaining | ||
#' to the corresponding value of `y`. If `FALSE`, ignored. | ||
#' @param outlier_thresh Flag values when the difference in the ELPD exceeds | ||
#' this threshold. Defaults to `NULL`, in which case no values are flagged. | ||
#' @param size,alpha,jitter Passed to `[ggplot2::geom_point()]` to control | ||
#' aesthetics. `size` and `alpha` are passed to to the `size` and `alpha` | ||
#' arguments of `[ggplot2::geom_jitter()]` to control the appearance of | ||
#' points. `jitter` can be either a number or a vector of numbers. | ||
#' Passing a single number will jitter variables along the x axis only, while | ||
#' passing a vector will jitter along both axes. | ||
#' @param sort_by_group Sort observations by `group`, then plot against an | ||
#' arbitrary index. Plotting by index can be useful when categories have | ||
#' very different sample sizes. | ||
#' | ||
#' | ||
#' @template return-ggplot | ||
#' | ||
#' @template reference-vis-paper | ||
#' | ||
#' @examples | ||
#' | ||
#' library(loo) | ||
#' | ||
#' cbPalette <- c("#636363", "#E69F00", "#56B4E9", "#009E73", | ||
#' "#F0E442", "#0072B2","#CC79A7") | ||
#' | ||
#' # Plot using groups from WHO | ||
#' | ||
#' plot_loo_dif(factor(GM@data$super_region_name), loo3, loo2, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right now none of these examples run because the data There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe this is the only problem that's left unresolved; as I mentioned in the email, I'd like to see if it's possible to add this data to the LOO package, since I think the example here is really great. |
||
#' group = GM@data$super_region_name, alpha = .5, | ||
#' jitter = c(.45, .2) | ||
#' ) + | ||
#' xlab("Region") + scale_colour_manual(values=cbPalette) | ||
#' | ||
#' # Plot using groups identified with clustering | ||
#' | ||
#' plot_loo_dif(factor(GM@data$cluster_region), loo3, loo2, | ||
#' group = GM@data$super_region_name, alpha = .5, | ||
#' jitter = c(.45, .2) | ||
#' ) + | ||
#' xlab("Cluster Group") + scale_colour_manual(values=cbPalette) | ||
#' | ||
#' # Plot using an index variable to reduce crowding | ||
#' | ||
#' plot_loo_dif(1:2980, loo3, loo2, group = GM@data$super_region_name, | ||
#' alpha = .5, sort_by_group = TRUE, | ||
#' ) + | ||
#' xlab("Index") + scale_colour_manual(values=cbPalette) | ||
#' | ||
#' | ||
plot_loo_dif <- | ||
function(y, | ||
psis_object_1, | ||
psis_object_2, | ||
..., | ||
group = NULL, | ||
outlier_thresh = NULL, | ||
size = 1, | ||
alpha = 1, | ||
jitter = 0, | ||
sort_by_group = FALSE | ||
){ | ||
|
||
# Adding a 0 at the end lets users provide a single number as input. | ||
# In this case, only horizontal jitter is applied. | ||
jitter <- c(jitter, 0) | ||
|
||
elpdDif <- psis_object_1$pointwise[, "elpd_loo"] - | ||
psis_object_2$pointwise[, "elpd_loo"] | ||
|
||
|
||
if (sort_by_group){ | ||
if (identical(group, NULL) || !identical(y, 1:length(y))){ | ||
stop("ERROR: sort_by_group should only be used for grouping categorical | ||
variables, then plotting them with an arbitrary index. You can | ||
create such an index using `1:length(data)`. | ||
") | ||
} | ||
|
||
ordering <- order(group) | ||
elpdDif <- elpdDif[ordering] | ||
group <- group[ordering] | ||
|
||
} | ||
|
||
|
||
plot <- ggplot2::ggplot(mapping=aes(y, elpdDif)) + | ||
ggplot2::geom_hline(yintercept=0) + | ||
ggplot2::xlab(ifelse(sort_by_group, "y", "Index")) + | ||
ggplot2::ylab(expression(ELPD[i][1] - ELPD[i][2])) + | ||
ggplot2::labs(color = "Groups") | ||
|
||
|
||
|
||
if (identical(group, FALSE)){ | ||
# Don't color by group if no groups are passed | ||
plot <- plot + | ||
ggplot2::geom_jitter(width = jitter[1], height = jitter[2], | ||
alpha = alpha, size = size | ||
) | ||
} | ||
else{ | ||
# If group is passed, use color | ||
plot <- plot + | ||
ggplot2::geom_jitter(aes(color = factor(group)), | ||
width = jitter[1], height = jitter[2], | ||
alpha = alpha, size = size | ||
) | ||
} | ||
|
||
if (!identical(outlier_thresh, NULL)){ | ||
# Flag outliers | ||
is_outlier <- elpdDif > outlier_thresh | ||
index <- 1:length(y) | ||
outlier_labs <- index[is_outlier] | ||
|
||
plot <- plot + ggplot2::annotate("text", | ||
x = y[is_outlier], | ||
y = elpdDif[outlier_labs], | ||
label = outlier_labs, | ||
size = 4 | ||
) | ||
|
||
|
||
} | ||
|
||
return(plot) | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like here in the doc it's
plot_loo_dif
but in the code it'splot_loo_variation
.