From ef9f89b8005ea3c3f922a39e1305e0c37b6b9060 Mon Sep 17 00:00:00 2001 From: pedrofale Date: Mon, 17 Jun 2024 23:27:59 +0100 Subject: [PATCH] Bugfix --- scatrex/models/cna/node.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scatrex/models/cna/node.py b/scatrex/models/cna/node.py index 720b49e..dc25ca1 100644 --- a/scatrex/models/cna/node.py +++ b/scatrex/models/cna/node.py @@ -680,7 +680,7 @@ def compute_local_priors(self, batch_indices): obs_weights_contrib = jnp.sum(jnp.mean(mc_obs_weights_logp_val_and_grad(self.obs_weights_sample[:,batch_indices], 0., log_std)[0], axis=0)) log_alpha = jnp.log(self.node_hyperparams['cell_scale_shape']) log_beta = jnp.log(self.node_hyperparams['cell_scale_shape'] * self.lib_ratio) - cell_scales_contrib = jnp.sum(jnp.mean(mc_cell_scales_logp_val_and_grad(self.cell_scales_sample[batch_indices], log_alpha, log_beta)[0], axis=0)) + cell_scales_contrib = jnp.sum(jnp.mean(mc_cell_scales_logp_val_and_grad(self.cell_scales_sample[:,batch_indices], log_alpha, log_beta)[0], axis=0)) return obs_weights_contrib + cell_scales_contrib def compute_global_entropies(self): @@ -1421,4 +1421,4 @@ def update_direction_adaptive(self, direction_params_grad, direction_sample_grad self.variational_parameters['kernel']['direction']['log_beta'] += step_size * mhat / (jnp.sqrt(vhat) + eps) states = (state1, state2) - self.direction_states = states \ No newline at end of file + self.direction_states = states