Skip to content

Commit

Permalink
Fix BAOAB/BADODAB regression #69 (#70)
Browse files Browse the repository at this point in the history
Typo in commit 1a43e5e broke BAOAB and BADODAB. This fixes the typo.
  • Loading branch information
zfurman56 authored Oct 23, 2023
1 parent fc33234 commit cb90097
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions sgmcmcjax/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,10 +380,10 @@ def build_baoab_kernel(
estimate_gradient, init_gradient = build_gradient_estimation_fn(
grad_log_post, data, batch_size
)
init_diff, (update_diff1, update_diff1), get_p_diff = baoab(dt, gamma, tau)
init_diff, (update_diff1, update_diff2), get_p_diff = baoab(dt, gamma, tau)
init_fn, baoab_kernel, get_params = _build_langevin_kernel(
init_diff,
(update_diff1, update_diff1),
(update_diff1, update_diff2),
get_p_diff,
estimate_gradient,
init_gradient,
Expand Down Expand Up @@ -416,10 +416,10 @@ def build_badodab_kernel(
estimate_gradient, init_gradient = build_gradient_estimation_fn(
grad_log_post, data, batch_size
)
init_diff, (update_diff1, update_diff1), get_p_diff = badodab(dt, a)
init_diff, (update_diff1, update_diff2), get_p_diff = badodab(dt, a)
init_fn, baoab_kernel, get_params = _build_langevin_kernel(
init_diff,
(update_diff1, update_diff1),
(update_diff1, update_diff2),
get_p_diff,
estimate_gradient,
init_gradient,
Expand Down

0 comments on commit cb90097

Please sign in to comment.