Add nested R-hat convergence diagnostic #303

Merged
merged 13 commits into from
Oct 29, 2023

n-kall
Collaborator

@n-kall n-kall commented Oct 11, 2023

Summary

This adds the nested R-hat convergence diagnostic, which is useful when running many short chains. Chains need to be grouped into superchains (given as an additional argument). This addresses issue #256.

Nested R-hat is described in:
Charles C. Margossian, Matthew D. Hoffman, Pavel Sountsov, Lionel Riou-Durand, Aki Vehtari, Andrew Gelman (2022). Nested R̂: Assessing the convergence of Markov chain Monte Carlo when running many short chains. https://arxiv.org/abs/2110.13017

Status

Working.
Currently work in progress. Functionality seems to work but is untested. Opening this draft PR early as a place for discussion and further development.

TODOs:

  • add rhat_nested.rvar method (done)
  • add tests

Example usage

x <- example_draws()
example_superchain_ids <- c(1,1,2,2) # first two chains are part of superchain 1, second two chains are part of superchain 2
summarise_draws(x, rhat_nested, .args = list(superchain_ids = example_superchain_ids))
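With many short chains, writing the ID vector by hand gets tedious. A minimal sketch of building it with base R, assuming the chains are stored so that chains belonging to the same superchain are adjacent (the variable names and sizes here are illustrative, not part of the PR):

```r
# Illustrative sizes: 4 superchains with 8 chains each (32 chains in total)
n_superchains <- 4
chains_per_superchain <- 8

# Chains 1-8 get ID 1, chains 9-16 get ID 2, and so on
superchain_ids <- rep(seq_len(n_superchains), each = chains_per_superchain)

# Sanity check: every superchain should contain the same number of chains
superchain_sizes <- table(superchain_ids)
stopifnot(all(superchain_sizes == chains_per_superchain))
```

The resulting vector can then be passed as `superchain_ids`, as in the call above.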

Copyright and Licensing

By submitting this pull request, the copyright holder is agreeing to
license the submitted work under the following licenses:

@codecov-commenter

codecov-commenter commented Oct 11, 2023

Codecov Report

Merging #303 (afd6cff) into master (e23467b) will decrease coverage by 0.07%.
Report is 4 commits behind head on master.
The diff coverage is 89.18%.

❗ Current head afd6cff differs from pull request most recent head 32f97c4. Consider uploading reports for the commit 32f97c4 to get more accurate results

@@            Coverage Diff             @@
##           master     #303      +/-   ##
==========================================
- Coverage   95.66%   95.60%   -0.07%     
==========================================
  Files          46       47       +1     
  Lines        3645     3682      +37     
==========================================
+ Hits         3487     3520      +33     
- Misses        158      162       +4     
Files              Coverage Δ
R/convergence.R    91.16% <ø> (ø)
R/nested_rhat.R    89.18% <89.18%> (ø)

@github-actions

This is how benchmark results would change (along with a 95% confidence interval in relative change) if da6fed7 is merged into master:

  •   :ballot_box_with_check:as_draws_array: 146ms -> 146ms [-0.78%, +0.96%]
  •   :ballot_box_with_check:as_draws_df: 53ms -> 52.5ms [-2.24%, +0.47%]
  •   :ballot_box_with_check:as_draws_list: 275ms -> 271ms [-2.61%, +0.16%]
  •   :ballot_box_with_check:as_draws_matrix: 46.6ms -> 46.6ms [-1.76%, +2.08%]
  •   :ballot_box_with_check:as_draws_rvars: 252ms -> 249ms [-3.38%, +1.56%]
  •   :ballot_box_with_check:summarise_draws_100_variables: 1s -> 990ms [-4.65%, +2.04%]
  •   :ballot_box_with_check:summarise_draws_10_variables: 119ms -> 119ms [-5.88%, +6.46%]
    Further explanation regarding interpretation and methodology can be found in the documentation.

@github-actions

This is how benchmark results would change (along with a 95% confidence interval in relative change) if e1d22ff is merged into master:

  •   :ballot_box_with_check:as_draws_array: 144ms -> 142ms [-6.15%, +3.25%]
  •   :ballot_box_with_check:as_draws_df: 45.6ms -> 45.6ms [-3.51%, +3.74%]
  •   :ballot_box_with_check:as_draws_list: 241ms -> 249ms [-5.65%, +12.82%]
  •   :ballot_box_with_check:as_draws_matrix: 44.3ms -> 45ms [-2.96%, +5.94%]
  •   :ballot_box_with_check:as_draws_rvars: 223ms -> 219ms [-5.41%, +1.94%]
  •   :ballot_box_with_check:summarise_draws_100_variables: 974ms -> 988ms [-2.19%, +5.19%]
  •   :ballot_box_with_check:summarise_draws_10_variables: 105ms -> 105ms [-1.1%, +1.77%]
    Further explanation regarding interpretation and methodology can be found in the documentation.

@github-actions

This is how benchmark results would change (along with a 95% confidence interval in relative change) if e4d63a9 is merged into master:

  •   :ballot_box_with_check:as_draws_array: 137ms -> 137ms [-0.97%, +0.08%]
  •   :ballot_box_with_check:as_draws_df: 47.8ms -> 47.8ms [-0.6%, +0.75%]
  •   :ballot_box_with_check:as_draws_list: 247ms -> 248ms [-0.56%, +1.07%]
  •   :ballot_box_with_check:as_draws_matrix: 42.6ms -> 42.3ms [-1.67%, +0.43%]
  •   :ballot_box_with_check:as_draws_rvars: 227ms -> 227ms [-1.51%, +1.41%]
  •   :ballot_box_with_check:summarise_draws_100_variables: 946ms -> 950ms [-0.02%, +0.84%]
  •   :ballot_box_with_check:summarise_draws_10_variables: 105ms -> 105ms [-0.66%, +0.8%]
    Further explanation regarding interpretation and methodology can be found in the documentation.

@n-kall n-kall marked this pull request as ready for review October 17, 2023 11:29
@n-kall
Collaborator Author

n-kall commented Oct 17, 2023

One thing to check: Should the rhat_nested with 1 chain per superchain be exactly equal to rhat_basic? Do we have some reference data to compare to?

Currently rhat_basic gives slightly different values, so perhaps there is an issue in the current implementation.

summarise_draws(example_draws(), rhat_basic, rhat_nested, .args = list(superchain_ids = c(1,2,3,4)))

# A tibble: 10 × 3
   variable   rhat_basic rhat_nested
   <chr>           <dbl>       <dbl>
 1 mu       0.9979105738 1.003389884
 2 tau      1.009976393  1.003445833
 3 theta[1] 1.014966741  1.007488973
 4 theta[2] 0.9981447065 1.002107795
 5 theta[3] 1.000405648  1.008267202
 6 theta[4] 0.9957624905 1.000607206
 7 theta[5] 0.9987923422 1.007260715
 8 theta[6] 0.9982158544 1.002237514
 9 theta[7] 1.002538583  1.003347448
10 theta[8] 0.9933503132 1.003124265

@paul-buerkner
Collaborator

@avehtari do you want this feature to be included in the new posterior release that is coming out shortly? Just asking so I can set priority accordingly.

@n-kall
Collaborator Author

n-kall commented Oct 25, 2023

I've now added input checks for the superchain_ids and NA/Inf values in the chains, and corresponding tests.
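For readers who have not opened the diff, the kind of checks described here can be sketched as follows. This is a hypothetical helper written for illustration, not the code added in this PR: the function name, return convention, and error message are assumptions. It takes an iterations-by-chains matrix and the ID vector.

```r
# Hypothetical validation helper (illustrative only, not the package code).
# Returns TRUE when the inputs look usable and FALSE when the diagnostic
# should be reported as NA; errors when the ID vector has the wrong length.
validate_superchain_inputs_sketch <- function(x, superchain_ids) {
  if (length(superchain_ids) != ncol(x)) {
    stop("Length of superchain_ids must equal the number of chains.")
  }
  # NA, NaN, Inf and -Inf all fail is.finite()
  if (!all(is.finite(x))) {
    return(FALSE)
  }
  # Every superchain must contain the same number of chains
  sizes <- as.vector(table(superchain_ids))
  if (length(unique(sizes)) != 1) {
    return(FALSE)
  }
  TRUE
}

draws_ok <- matrix(0.5, nrow = 10, ncol = 4)
draws_inf <- draws_ok
draws_inf[1, 1] <- Inf

check_equal <- validate_superchain_inputs_sketch(draws_ok, c(1, 1, 2, 2))
check_unequal <- validate_superchain_inputs_sketch(draws_ok, c(1, 1, 1, 2))
check_nonfinite <- validate_superchain_inputs_sketch(draws_inf, c(1, 1, 2, 2))
```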

@paul-buerkner
Collaborator

Great! Is there anything else you want to add from your side? Or is this ready for me to check and then merge?

@paul-buerkner
Collaborator

I have checked and things look good to me. @n-kall do I have your OK to merge?

@github-actions

This is how benchmark results would change (along with a 95% confidence interval in relative change) if 10bdfc5 is merged into master:

  •   :ballot_box_with_check:as_draws_array: 135ms -> 135ms [-1.17%, +0.74%]
  •   :ballot_box_with_check:as_draws_df: 41.3ms -> 41.7ms [-0.12%, +2.04%]
  •   :ballot_box_with_check:as_draws_list: 219ms -> 218ms [-2.87%, +1.81%]
  •   :ballot_box_with_check:as_draws_matrix: 39.5ms -> 39.7ms [-1.6%, +2.58%]
  •   :ballot_box_with_check:as_draws_rvars: 192ms -> 191ms [-1.42%, +1.04%]
  • ❗🐌summarise_draws_100_variables: 929ms -> 943ms [+0.57%, +2.24%]
  •   :ballot_box_with_check:summarise_draws_10_variables: 103ms -> 103ms [-0.68%, +0.64%]
    Further explanation regarding interpretation and methodology can be found in the documentation.

@n-kall
Collaborator Author

n-kall commented Oct 25, 2023

@paul-buerkner hold off on merging, as Charles might take a look beforehand. We'll let you know when it's ready for merge

@n-kall
Collaborator Author

n-kall commented Oct 25, 2023

There's still a discrepancy between rhat_basic (without splitting) and rhat_nested with 1 chain per superchain.

summarise_draws(
  example_draws(),
  rhat_basic_nosplit = ~rhat_basic(.x, split = FALSE),
  rhat_nested = ~rhat_nested(.x, superchain_ids = c(1, 2, 3, 4))
)
# # A tibble: 10 × 3
#    variable rhat_basic_nosplit rhat_nested
#    <chr>                 <dbl>       <dbl>
#  1 mu                  0.99839      1.0034
#  2 tau                 0.99845      1.0034
#  3 theta[1]            1.0025       1.0075
#  4 theta[2]            0.99711      1.0021
#  5 theta[3]            1.0033       1.0083
#  6 theta[4]            0.99560      1.0006
#  7 theta[5]            1.0023       1.0073
#  8 theta[6]            0.99724      1.0022
#  9 theta[7]            0.99835      1.0033
# 10 theta[8]            0.99813      1.0031

@n-kall
Collaborator Author

n-kall commented Oct 25, 2023

Ok, it pays to check the footnotes. Page 7 of Margossian et al. states: "The original R-hat uses a slightly different estimate for the within-chain variance when computing the numerator in R-hat. There W is scaled by 1/N, rather than 1/(N − 1). This explains why occasionally R-hat < 1. This is of little concern when N is large, but we care about the case where N is small, and we therefore adjust the R-hat statistic slightly." So I think this discrepancy is expected and fine.
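The size of the gap can be checked by hand. The sketch below is a base-R reading of the definitions in the preprint, not the code from this PR; the function names are made up for illustration, and `x` is assumed to be an iterations-by-chains matrix for a single variable. With one chain per superchain, the squared statistics should differ by 1/N, which is consistent with the table above (N = 100 draws per chain, and 1.0034^2 - 0.99839^2 is about 0.01).

```r
# Classic (non-split) R-hat: sqrt((N - 1) / N + B / W)
rhat_classic_sketch <- function(x) {
  n <- nrow(x)
  w <- mean(apply(x, 2, var))  # mean within-chain variance
  b <- var(colMeans(x))        # variance of the chain means
  sqrt((n - 1) / n + b / w)
}

# Nested R-hat: sqrt(1 + B_hat / W_hat), assuming equal-sized superchains
rhat_nested_sketch <- function(x, superchain_ids) {
  chain_means <- colMeans(x)
  chain_vars <- if (nrow(x) > 1) apply(x, 2, var) else rep(0, ncol(x))
  ids <- unique(superchain_ids)
  # Mean of the chain means within each superchain
  super_means <- vapply(ids, function(k) {
    mean(chain_means[superchain_ids == k])
  }, numeric(1))
  # Within-superchain variance: between-chain part plus within-chain part
  super_within <- vapply(ids, function(k) {
    in_k <- superchain_ids == k
    between_chains <- if (sum(in_k) > 1) var(chain_means[in_k]) else 0
    between_chains + mean(chain_vars[in_k])
  }, numeric(1))
  sqrt(1 + var(super_means) / mean(super_within))
}

set.seed(1)
n_iter <- 100
x <- matrix(rnorm(n_iter * 4), nrow = n_iter, ncol = 4)

r_classic <- rhat_classic_sketch(x)
r_nested <- rhat_nested_sketch(x, superchain_ids = 1:4)

# One chain per superchain: squared values differ by 1 / N (up to rounding)
gap <- r_nested^2 - r_classic^2
```

Because the nested statistic is sqrt(1 + a non-negative ratio), it cannot drop below 1, whereas the classic version can when the between-chain term is small.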

@github-actions

This is how benchmark results would change (along with a 95% confidence interval in relative change) if 9500ca2 is merged into master:

  •   :ballot_box_with_check:as_draws_array: 163ms -> 163ms [-0.75%, +0.33%]
  •   :ballot_box_with_check:as_draws_df: 63ms -> 55.6ms [-36.38%, +12.82%]
  •   :ballot_box_with_check:as_draws_list: 294ms -> 294ms [-1.06%, +1.32%]
  •   :ballot_box_with_check:as_draws_matrix: 51ms -> 51ms [-1.27%, +1.1%]
  •   :ballot_box_with_check:as_draws_rvars: 269ms -> 270ms [-0.46%, +1.37%]
  • ❗🐌summarise_draws_100_variables: 1.13s -> 1.14s [+0.39%, +1.56%]
  •   :ballot_box_with_check:summarise_draws_10_variables: 126ms -> 128ms [-3.08%, +5.92%]
    Further explanation regarding interpretation and methodology can be found in the documentation.

@github-actions

This is how benchmark results would change (along with a 95% confidence interval in relative change) if 8fce13a is merged into master:

  •   :ballot_box_with_check:as_draws_array: 163ms -> 163ms [-1.42%, +1.74%]
  •   :ballot_box_with_check:as_draws_df: 58.2ms -> 56.4ms [-6.79%, +0.74%]
  •   :ballot_box_with_check:as_draws_list: 297ms -> 292ms [-4.93%, +1.99%]
  •   :ballot_box_with_check:as_draws_matrix: 51ms -> 50.7ms [-3.82%, +2.34%]
  •   :ballot_box_with_check:as_draws_rvars: 271ms -> 273ms [-2.07%, +3.1%]
  •   :ballot_box_with_check:summarise_draws_100_variables: 1.19s -> 1.22s [-1.58%, +6.26%]
  •   :ballot_box_with_check:summarise_draws_10_variables: 128ms -> 127ms [-4.97%, +4.05%]
    Further explanation regarding interpretation and methodology can be found in the documentation.

@n-kall n-kall changed the title from "[WIP] Add nested R-hat convergence diagnostic" to "Add nested R-hat convergence diagnostic" Oct 26, 2023

@charlesm93 charlesm93 left a comment

This looks good to me! The only changes I would make are to the docs:

  • indicate that the referenced preprint is version 4
  • in the documentation, explain the slight discrepancy between nRhat and Rhat because the former is lower-bounded by 1 (footnote 1 in the preprint).

if (nchains_per_superchain != min(superchain_id_table)) {
  warning_no_call("Number of chains per superchain is not the same for ",
                  "each superchain, returning NA.")
  return(NA_real_)

Technically we could define nRhat to work on superchains with different sizes, but we didn't do this in the paper. I don't see a strong motivation for addressing this, but we can think about it in the future.

#' Charles C. Margossian, Matthew D. Hoffman, Pavel Sountsov, Lionel
#' Riou-Durand, Aki Vehtari and Andrew Gelman (2023). Nested R-hat:
#' Assessing the convergence of Markov chain Monte Carlo when running
#' many short chains. arxiv:arXiv:2110.13017

Since you're listing equation numbers and the paper is still under review, make sure to list which version of the preprint you're referencing (the latest version is v4)

add details section with explanation of superchains and discrepancy
between Rhat and nested Rhat calculation, specify version of preprint
@n-kall
Collaborator Author

n-kall commented Oct 28, 2023

I've now updated the docs as @charlesm93 suggested and added some further details. I think it's ready to merge.

@github-actions

This is how benchmark results would change (along with a 95% confidence interval in relative change) if afd6cff is merged into master:

  •   :ballot_box_with_check:as_draws_array: 189ms -> 188ms [-3.53%, +2%]
  •   :ballot_box_with_check:as_draws_df: 63.3ms -> 63.5ms [-2.63%, +3.14%]
  •   :ballot_box_with_check:as_draws_list: 339ms -> 340ms [-2.42%, +2.94%]
  •   :ballot_box_with_check:as_draws_matrix: 64ms -> 63.6ms [-9.74%, +8.46%]
  •   :ballot_box_with_check:as_draws_rvars: 310ms -> 311ms [-1.96%, +2.23%]
  • ❗🐌summarise_draws_100_variables: 1.41s -> 1.5s [+5.86%, +7.71%]
  •   :ballot_box_with_check:summarise_draws_10_variables: 156ms -> 157ms [-3.34%, +3.65%]
    Further explanation regarding interpretation and methodology can be found in the documentation.

@paul-buerkner
Collaborator

Thank you! The failing tests seem to be unrelated to this PR so I will merge it now.

@paul-buerkner paul-buerkner merged commit 6add3e6 into stan-dev:master Oct 29, 2023
7 of 10 checks passed
@n-kall n-kall deleted the nested_rhat branch October 30, 2023 07:40
@n-kall n-kall mentioned this pull request Oct 30, 2023
4 participants