diff --git a/sgmcmcjax/kernels.py b/sgmcmcjax/kernels.py index f37b0df..3156158 100644 --- a/sgmcmcjax/kernels.py +++ b/sgmcmcjax/kernels.py @@ -380,10 +380,10 @@ def build_baoab_kernel( estimate_gradient, init_gradient = build_gradient_estimation_fn( grad_log_post, data, batch_size ) - init_diff, (update_diff1, update_diff1), get_p_diff = baoab(dt, gamma, tau) + init_diff, (update_diff1, update_diff2), get_p_diff = baoab(dt, gamma, tau) init_fn, baoab_kernel, get_params = _build_langevin_kernel( init_diff, - (update_diff1, update_diff1), + (update_diff1, update_diff2), get_p_diff, estimate_gradient, init_gradient, @@ -416,10 +416,10 @@ def build_badodab_kernel( estimate_gradient, init_gradient = build_gradient_estimation_fn( grad_log_post, data, batch_size ) - init_diff, (update_diff1, update_diff1), get_p_diff = badodab(dt, a) + init_diff, (update_diff1, update_diff2), get_p_diff = badodab(dt, a) init_fn, baoab_kernel, get_params = _build_langevin_kernel( init_diff, - (update_diff1, update_diff1), + (update_diff1, update_diff2), get_p_diff, estimate_gradient, init_gradient,