Add LOO Difference Plot #178
Changes from 5 commits
dca229b
8786be3
8e329d0
90ecb40
48b97b6
5399cb8
4d085d6
d4b8641
a0417dd
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
#' Compare models across domains | ||
#' | ||
#' The LOO difference plot shows how the ELPD of two different models | ||
#' changes when a predictor is varied. This is useful for identifying | ||
#' opportunities for model stacking or expansion. | ||
#' | ||
#' @param y A vector of observations. See Details. | ||
#' @param psis_object_1,psis_object_2 Objects returned by the | ||
#' `[loo::loo()]` function. The pointwise ELPD values are read from the | ||
#' `elpd_loo` column of each object's `pointwise` matrix. | ||
#' @param ... Currently unused. | ||
#' @param group A grouping variable (a vector or factor) the same length | ||
#' as `y`. Each value in group is interpreted as the group level pertaining | ||
#' to the corresponding value of `y`. If `FALSE`, ignored. | ||
#' @param outlier_thresh Flag values when the difference in the ELPD exceeds | ||
#' this threshold. Defaults to `NULL`, in which case no values are flagged. | ||
#' @param size,alpha,jitter Passed to `[ggplot2::geom_jitter()]` to control | ||
#' the appearance of points. `size` and `alpha` are passed to the `size` and | ||
#' `alpha` arguments. `jitter` can be either a single number or a vector of | ||
#' two numbers. Passing a single number jitters points along the x axis | ||
#' only, while passing a vector of two numbers jitters along both axes. | ||
#' @param quantiles Boolean that determines whether to plot the quantiles of | ||
#' `y` rather than `y` itself. Useful when `y` has a very irregular | ||
#' distribution. | ||
#' @param sort_by_group Sort observations by `group`, then plot against an | ||
#' arbitrary index. Plotting by index can be useful when categories have | ||
#' very different sample sizes. | ||
#' | ||
#' | ||
#' @template return-ggplot | ||
#' | ||
#' @template reference-vis-paper | ||
#' | ||
#' @examples | ||
#' | ||
#' library(loo) | ||
#' | ||
#' cbPalette <- c("#636363", "#E69F00", "#56B4E9", "#009E73", | ||
#' "#F0E442", "#0072B2","#CC79A7") | ||
#' | ||
#' # Plot using groups from WHO | ||
#' | ||
#' plot_loo_dif(factor(GM@data$super_region_name), loo3, loo2, | ||
Comment: It looks like here in the doc it's
Comment: Right now none of these examples run because the data
Comment: I believe this is the only problem that's left unresolved; as I mentioned in the email, I'd like to see if it's possible to add this data to the LOO package, since I think the example here is really great. |
||
#' group = GM@data$super_region_name, alpha = .5, | ||
#' jitter = c(.45, .2) | ||
#' ) + | ||
#' xlab("Region") + scale_colour_manual(values=cbPalette) | ||
#' | ||
#' # Plot using groups identified with clustering | ||
#' | ||
#' plot_loo_dif(factor(GM@data$cluster_region), loo3, loo2, | ||
#' group = GM@data$super_region_name, alpha = .5, | ||
#' jitter = c(.45, .2) | ||
#' ) + | ||
#' xlab("Cluster Group") + scale_colour_manual(values=cbPalette) | ||
#' | ||
#' # Plot using an index variable to reduce crowding | ||
#' | ||
#' plot_loo_dif(1:2980, loo3, loo2, group = GM@data$super_region_name, | ||
#' alpha = .5, sort_by_group = TRUE | ||
#' ) + | ||
#' xlab("Index") + scale_colour_manual(values=cbPalette) | ||
#' | ||
#' | ||
plot_loo_dif <- | ||
function(y, | ||
psis_object_1, | ||
psis_object_2, | ||
..., | ||
group = NULL, | ||
outlier_thresh = NULL, | ||
size = 1, | ||
alpha = 1, | ||
jitter = 0, | ||
quantiles = FALSE, | ||
sort_by_group = FALSE | ||
){ | ||
|
||
# Adding a 0 at the end lets users provide a single number as input. | ||
# In this case, only horizontal jitter is applied. | ||
jitter <- c(jitter, 0) | ||
|
||
elpdDif <- psis_object_1$pointwise[, "elpd_loo"] - | ||
psis_object_2$pointwise[, "elpd_loo"] | ||
|
||
|
||
if (quantiles){ | ||
# If quantiles is set to true, replace all y values with their quantile | ||
y <- ecdf(y)(y) | ||
} | ||
|
||
|
||
if (sort_by_group){ | ||
if (is.null(group) || !identical(y, seq_along(y))){ | ||
stop( | ||
"sort_by_group should only be used for grouping categorical variables, ", | ||
"then plotting them against an arbitrary index. ", | ||
"You can create such an index using `seq_along(group)`." | ||
) | ||
} | ||
|
||
ordering <- order(group) | ||
elpdDif <- elpdDif[ordering] | ||
group <- group[ordering] | ||
|
||
} | ||
|
||
|
||
plot <- ggplot2::ggplot(mapping=aes(y, elpdDif)) + | ||
ggplot2::geom_hline(yintercept=0) + | ||
# When sorting by group, the x axis is an arbitrary index rather than y | ||
ggplot2::xlab(if (sort_by_group) "Index" else "y") + | ||
ggplot2::ylab(expression(ELPD[i][1] - ELPD[i][2])) + | ||
ggplot2::labs(color = "Groups") | ||
|
||
|
||
|
||
if (is.null(group) || identical(group, FALSE)){ | ||
# Don't color by group if no groups are passed | ||
plot <- plot + | ||
ggplot2::geom_jitter(width = jitter[1], height = jitter[2], | ||
alpha = alpha, size = size | ||
) | ||
} | ||
else{ | ||
# If group is passed, use color | ||
plot <- plot + | ||
ggplot2::geom_jitter(aes(color = factor(group)), | ||
width = jitter[1], height = jitter[2], | ||
alpha = alpha, size = size | ||
) | ||
} | ||
|
||
if (!identical(outlier_thresh, NULL)){ | ||
# Flag outliers | ||
is_outlier <- elpdDif > outlier_thresh | ||
index <- 1:length(y) | ||
outlier_labs <- index[is_outlier] | ||
|
||
plot <- plot + ggplot2::annotate("text", | ||
x = y[is_outlier], | ||
y = elpdDif[outlier_labs], | ||
label = outlier_labs, | ||
size = 4 | ||
) | ||
|
||
|
||
} | ||
|
||
return(plot) | ||
} | ||
|
Comment: For me, `quantiles` sounds like it would accept a vector of quantiles (real valued), as, e.g., dnorm and pnorm do. I suggest that the `quantiles` parameter would actually define which quantiles are plotted, and that the type of plot would be selected with an argument called, e.g., `type`, with options type='y' (default) and type='quantiles', which would then also allow extending to other possible types.
- Can you also include an example of what these plots look like in the discussion thread of this PR?
Comment: This does seem like it might be confusing; I think it makes sense to remove this, since users can provide transformations of the values themselves by replacing `y`.
I'm not sure what you mean here; could you elaborate?
Will do.
Comment: If users can provide transformations of the values themselves by replacing `y`, and this option is not needed, then I think removing it as you did is the correct action, and you don't need to worry about the rest of what I said.
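For reference, the transformation discussed in this thread can be done by the caller before plotting. The sketch below is only an illustration: it uses simulated values for `y`, and `loo_1` and `loo_2` are placeholder names for fitted loo objects, not objects defined in this PR. It uses the same `ecdf(y)(y)` step that the `quantiles` option applied internally.

```r
# Simulated predictor values (an assumption for illustration only).
set.seed(1)
y <- rexp(200)

# Empirical quantile of each observation: the share of values <= y[i].
y_quantile <- ecdf(y)(y)

# The transformed values could then be passed in place of `y`, e.g.
# plot_loo_dif(y_quantile, loo_1, loo_2)  # loo_1, loo_2 are placeholders
```

Because the empirical CDF is monotone, the ordering of observations along the x axis is preserved; only the spacing changes.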