diff --git a/scatrex/models/cna/node.py b/scatrex/models/cna/node.py index fda885d..14c947f 100644 --- a/scatrex/models/cna/node.py +++ b/scatrex/models/cna/node.py @@ -12,8 +12,8 @@ from ...utils.math_utils import * from ...ntssb.node import * -MIN_ALPHA = jnp.log(0.01) -MAX_BETA = jnp.log(1./0.001) +MIN_ALPHA = jnp.log(1e-2) +MAX_BETA = jnp.log(1e6) def update_params(params, params_gradient, step_size): new_params = [] diff --git a/scatrex/scatrex.py b/scatrex/scatrex.py index 80ac1e5..42c7133 100644 --- a/scatrex/scatrex.py +++ b/scatrex/scatrex.py @@ -168,15 +168,15 @@ def learn_scales(self, n_epochs=100, mc_samples=10, step_size=0.01): n_cells = self.ntssb.data.shape[0] n_genes = self.ntssb.data.shape[1] gs = np.sqrt(np.mean(self.ntssb.data)) - root = self.ntssb.root['node'].root['node'] + gene_scales_alpha_init = 10. * jnp.ones((n_genes,)) #* jnp.exp(np.random.normal(size=self.n_genes)) - gene_scales_beta_init = 10. * jnp.ones((n_genes,)) * 1./gs * jnp.exp( 0. * np.random.normal(size=n_genes)) + gene_scales_beta_init = 10. * jnp.ones((n_genes,)) * 1./np.mean(self.ntssb.data, axis=0) #* jnp.exp(10. + 0. * np.random.normal(size=n_genes)) root.variational_parameters['global']['gene_scales']['log_alpha'] = jnp.log(gene_scales_alpha_init) root.variational_parameters['global']['gene_scales']['log_beta'] = jnp.log(gene_scales_beta_init) cell_scales_alpha_init = 10. * jnp.ones((n_cells,1)) #* jnp.exp(np.random.normal(size=[500,1])) - cell_scales_beta_init = 10. * jnp.ones((n_cells,1)) * 1./gs * jnp.exp( 0. * np.random.normal(size=[n_cells,1])) + cell_scales_beta_init = 10. * jnp.ones((n_cells,1)) * 1. #* jnp.exp(0. + 0. * np.random.normal(size=[n_cells,1])) root.variational_parameters['local']['cell_scales']['log_alpha'] = jnp.log(cell_scales_alpha_init) root.variational_parameters['local']['cell_scales']['log_beta'] = jnp.log(cell_scales_beta_init)