Skip to content

Commit

Permalink
Bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
pedrofale committed Jun 17, 2024
1 parent 4d3f2d1 commit ef9f89b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions scatrex/models/cna/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def compute_local_priors(self, batch_indices):
obs_weights_contrib = jnp.sum(jnp.mean(mc_obs_weights_logp_val_and_grad(self.obs_weights_sample[:,batch_indices], 0., log_std)[0], axis=0))
log_alpha = jnp.log(self.node_hyperparams['cell_scale_shape'])
log_beta = jnp.log(self.node_hyperparams['cell_scale_shape'] * self.lib_ratio)
cell_scales_contrib = jnp.sum(jnp.mean(mc_cell_scales_logp_val_and_grad(self.cell_scales_sample[batch_indices], log_alpha, log_beta)[0], axis=0))
cell_scales_contrib = jnp.sum(jnp.mean(mc_cell_scales_logp_val_and_grad(self.cell_scales_sample[:,batch_indices], log_alpha, log_beta)[0], axis=0))
return obs_weights_contrib + cell_scales_contrib

def compute_global_entropies(self):
Expand Down Expand Up @@ -1421,4 +1421,4 @@ def update_direction_adaptive(self, direction_params_grad, direction_sample_grad
self.variational_parameters['kernel']['direction']['log_beta'] += step_size * mhat / (jnp.sqrt(vhat) + eps)

states = (state1, state2)
self.direction_states = states
self.direction_states = states

0 comments on commit ef9f89b

Please sign in to comment.