Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LOO Difference Plot #178

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 151 additions & 0 deletions R/loo_difference_plot.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
#' Compare models across domains
#'
#' The LOO difference plot shows how the ELPD of two different models
#' changes when a predictor is varied. This can is useful for identifying
#' opportunities for model stacking or expansion.
#'
#' @param y A vector of observations. See Details.
#' @param psis_object_1,psis_object_2 If using loo version 2.0.0 or greater,
#' an object returned by the `[loo::psis()]` function (or by the
#' `[loo::loo()]` function with argument `save_psis` set to `TRUE`).
#' @param ... Currently unused.
#' @param group A grouping variable (a vector or factor) the same length
#' as `y`. Each value in group is interpreted as the group level pertaining
#' to the corresponding value of `y`. If `FALSE`, ignored.
#' @param outlier_thresh Flag values when the difference in the ELPD exceeds
#' this threshold. Defaults to `NULL`, in which case no values are flagged.
#' @param size,alpha,jitter Passed to `[ggplot2::geom_point()]` to control
#' aesthetics. `size` and `alpha` are passed to to the `size` and `alpha`
#' arguments of `[ggplot2::geom_jitter()]` to control the appearance of
#' points. `jitter` can be either a number or a vector of numbers.
#' Passing a single number will jitter variables along the x axis only, while
#' passing a vector will jitter along both axes.
#' @param quantiles Boolean that determines whether to plot the quantiles of
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For me quantiles sounds like it would accept vector of quantiles (real valued) as, e.g., dnorm and pnorm. I suggest that quantiles parameter would actually allow to define which quantiles are plotted and the type of the plot would be selected with argument called, e.g., type. With options type='y' (default) and type='quantiles', which would allow then also extending to other possible types- Can you also include example of how these plots look like in the discussion thread of this PR?

Copy link
Author

@ParadaCarleton ParadaCarleton May 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For me quantiles sounds like it would accept vector of quantiles (real valued) as, e.g., dnorm and pnorm.

This does seem like it might be confusing; I think it makes sense to remove this, since users can provide transformations of the values themselves by replacing y.

I suggest that quantiles parameter would actually allow to define which quantiles are plotted and the type of the plot would be selected with argument called, e.g., type. With options type='y' (default) and type='quantiles', which would allow then also extending to other possible types

I'm not sure what you mean here; could you elaborate?

Can you also include example of how these plots look like in the discussion thread of this PR?

Will do.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If users can provide transformations of the values themselves by replacing y, and this option is not needed then I think removing it as you did is the correct action and you don't need to care about the rest I said.

#' `y` rather than `y` itself. Useful when `y` has a very irregular
#' distribution.
#' @param sort_by_group Sort observations by `group`, then plot against an
#' arbitrary index. Plotting by index can be useful when categories have
#' very different sample sizes.
#'
#'
#' @template return-ggplot
#'
#' @template reference-vis-paper
#'
#' @examples
#'
#' library(loo)
#'
#' cbPalette <- c("#636363", "#E69F00", "#56B4E9", "#009E73",
#' "#F0E442", "#0072B2","#CC79A7")
#'
#' # Plot using groups from WHO
#'
#' plot_loo_dif(factor(GM@data$super_region_name), loo3, loo2,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like here in the doc it's plot_loo_dif but in the code it's plot_loo_variation.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now none of these examples run because the data GM and loo objects (loo2, loo3) don't exist. For the Examples section in the doc we'll need to change this to a self contained example or add all this data to the loo package itself so it can be used in the example. One possibility is to use a toy example here in the doc and then potentially add a more real example (e.g. this one from the paper) in one of the package vignettes. I'll think a bit more about this too.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is the only problem that's left unresolved; as I mentioned in the email, I'd like to see if it's possible to add this data to the LOO package, since I think the example here is really great.

#' group = GM@data$super_region_name, alpha = .5,
#' jitter = c(.45, .2)
#' ) +
#' xlab("Region") + scale_colour_manual(values=cbPalette)
#'
#' # Plot using groups identified with clustering
#'
#' plot_loo_dif(factor(GM@data$cluster_region), loo3, loo2,
#' group = GM@data$super_region_name, alpha = .5,
#' jitter = c(.45, .2)
#' ) +
#' xlab("Cluster Group") + scale_colour_manual(values=cbPalette)
#'
#' # Plot using an index variable to reduce crowding
#'
#' plot_loo_dif(1:2980, loo3, loo2, group = GM@data$super_region_name,
#' alpha = .5, sort_by_group = TRUE,
#' ) +
#' xlab("Index") + scale_colour_manual(values=cbPalette)
#'
#'
plot_loo_dif <-
function(y,
psis_object_1,
psis_object_2,
...,
group = NULL,
outlier_thresh = NULL,
size = 1,
alpha = 1,
jitter = 0,
quantiles = FALSE,
sort_by_group = FALSE
){

# Adding a 0 at the end lets users provide a single number as input.
# In this case, only horizontal jitter is applied.
jitter <- c(jitter, 0)

elpdDif <- psis_object_1$pointwise[, "elpd_loo"] -
psis_object_2$pointwise[, "elpd_loo"]


if (quantiles){
# If quantiles is set to true, replace all y values with their quantile
y <- ecdf(y)(y)
}


if (sort_by_group){
if (identical(group, NULL) || !identical(y, 1:length(y))){
stop("ERROR: sort_by_group should only be used for grouping categorical
variables, then plotting them with an arbitrary index. You can
create such an index using `1:length(data)`.
")
}

ordering <- order(group)
elpdDif <- elpdDif[ordering]
group <- group[ordering]

}


plot <- ggplot2::ggplot(mapping=aes(y, elpdDif)) +
ggplot2::geom_hline(yintercept=0) +
ggplot2::xlab(ifelse(sort_by_group, "y", "Index")) +
ggplot2::ylab(expression(ELPD[i][1] - ELPD[i][2])) +
ggplot2::labs(color = "Groups")



if (identical(group, FALSE)){
# Don't color by group if no groups are passed
plot <- plot +
ggplot2::geom_jitter(width = jitter[1], height = jitter[2],
alpha = alpha, size = size
)
}
else{
# If group is passed, use color
plot <- plot +
ggplot2::geom_jitter(aes(color = factor(group)),
width = jitter[1], height = jitter[2],
alpha = alpha, size = size
)
}

if (!identical(outlier_thresh, NULL)){
# Flag outliers
is_outlier <- elpdDif > outlier_thresh
index <- 1:length(y)
outlier_labs <- index[is_outlier]

plot <- plot + ggplot2::annotate("text",
x = y[is_outlier],
y = elpdDif[outlier_labs],
label = outlier_labs,
size = 4
)


}

return(plot)
}