From cb900970b848558d61ed399e68c8099f7fa9faed Mon Sep 17 00:00:00 2001 From: Zach Furman Date: Mon, 23 Oct 2023 04:31:49 -0400 Subject: [PATCH] Fix BAOAB/BADODAB regression #69 (#70) Typo in commit 1a43e5ea391139a52b634f68432e76f88cb2b1c5 broke BAOAB and BADODAB. This fixes the typo. --- sgmcmcjax/kernels.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sgmcmcjax/kernels.py b/sgmcmcjax/kernels.py index f37b0df..3156158 100644 --- a/sgmcmcjax/kernels.py +++ b/sgmcmcjax/kernels.py @@ -380,10 +380,10 @@ def build_baoab_kernel( estimate_gradient, init_gradient = build_gradient_estimation_fn( grad_log_post, data, batch_size ) - init_diff, (update_diff1, update_diff1), get_p_diff = baoab(dt, gamma, tau) + init_diff, (update_diff1, update_diff2), get_p_diff = baoab(dt, gamma, tau) init_fn, baoab_kernel, get_params = _build_langevin_kernel( init_diff, - (update_diff1, update_diff1), + (update_diff1, update_diff2), get_p_diff, estimate_gradient, init_gradient, @@ -416,10 +416,10 @@ def build_badodab_kernel( estimate_gradient, init_gradient = build_gradient_estimation_fn( grad_log_post, data, batch_size ) - init_diff, (update_diff1, update_diff1), get_p_diff = badodab(dt, a) + init_diff, (update_diff1, update_diff2), get_p_diff = badodab(dt, a) init_fn, baoab_kernel, get_params = _build_langevin_kernel( init_diff, - (update_diff1, update_diff1), + (update_diff1, update_diff2), get_p_diff, estimate_gradient, init_gradient,