Skip to content

Commit

Permalink
Update clip
Browse files Browse the repository at this point in the history
  • Loading branch information
pedrofale committed Apr 22, 2024
1 parent bb60567 commit db98972
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions scatrex/models/cna/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from ...utils.math_utils import *
from ...ntssb.node import *

MIN_ALPHA = jnp.log(0.01)
MAX_BETA = jnp.log(1./0.001)
MIN_ALPHA = jnp.log(1e-2)
MAX_BETA = jnp.log(1e6)

def update_params(params, params_gradient, step_size):
new_params = []
Expand Down
6 changes: 3 additions & 3 deletions scatrex/scatrex.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,15 @@ def learn_scales(self, n_epochs=100, mc_samples=10, step_size=0.01):
n_cells = self.ntssb.data.shape[0]
n_genes = self.ntssb.data.shape[1]
gs = np.sqrt(np.mean(self.ntssb.data))

root = self.ntssb.root['node'].root['node']

gene_scales_alpha_init = 10. * jnp.ones((n_genes,)) #* jnp.exp(np.random.normal(size=self.n_genes))
gene_scales_beta_init = 10. * jnp.ones((n_genes,)) * 1./gs * jnp.exp( 0. * np.random.normal(size=n_genes))
gene_scales_beta_init = 10. * jnp.ones((n_genes,)) * 1./np.mean(self.ntssb.data, axis=0) #* jnp.exp(10. + 0. * np.random.normal(size=n_genes))
root.variational_parameters['global']['gene_scales']['log_alpha'] = jnp.log(gene_scales_alpha_init)
root.variational_parameters['global']['gene_scales']['log_beta'] = jnp.log(gene_scales_beta_init)

cell_scales_alpha_init = 10. * jnp.ones((n_cells,1)) #* jnp.exp(np.random.normal(size=[500,1]))
cell_scales_beta_init = 10. * jnp.ones((n_cells,1)) * 1./gs * jnp.exp( 0. * np.random.normal(size=[n_cells,1]))
cell_scales_beta_init = 10. * jnp.ones((n_cells,1)) * 1. #* jnp.exp(0. + 0. * np.random.normal(size=[n_cells,1]))
root.variational_parameters['local']['cell_scales']['log_alpha'] = jnp.log(cell_scales_alpha_init)
root.variational_parameters['local']['cell_scales']['log_beta'] = jnp.log(cell_scales_beta_init)

Expand Down

0 comments on commit db98972

Please sign in to comment.