From 6bbb92499802f983d4f3fa59a3dc0d0ce0498278 Mon Sep 17 00:00:00 2001 From: pedrofale Date: Tue, 23 Apr 2024 01:00:53 +0200 Subject: [PATCH] Reset direction_shape to initial --- scatrex/scatrex.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/scatrex/scatrex.py b/scatrex/scatrex.py index 42c7133..1c48e55 100644 --- a/scatrex/scatrex.py +++ b/scatrex/scatrex.py @@ -169,7 +169,7 @@ def learn_scales(self, n_epochs=100, mc_samples=10, step_size=0.01): n_genes = self.ntssb.data.shape[1] gs = np.sqrt(np.mean(self.ntssb.data)) root = self.ntssb.root['node'].root['node'] - + gene_scales_alpha_init = 10. * jnp.ones((n_genes,)) #* jnp.exp(np.random.normal(size=self.n_genes)) gene_scales_beta_init = 10. * jnp.ones((n_genes,)) * 1./np.mean(self.ntssb.data, axis=0) #* jnp.exp(10. + 0. * np.random.normal(size=n_genes)) root.variational_parameters['global']['gene_scales']['log_alpha'] = jnp.log(gene_scales_alpha_init) @@ -200,6 +200,8 @@ def learn_scales(self, n_epochs=100, mc_samples=10, step_size=0.01): def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=10, memoized=True, mc_samples=10, step_size=0.01, seed=42): logger.info("Learning roots and noise") # Remove noise and learn roots + n_factors = self.ntssb.root['node'].root['node'].node_hyperparams['n_factors'] + direction_shape = self.ntssb.root['node'].root['node'].node_hyperparams['direction_shape'] self.ntssb.set_node_hyperparams(n_factors=0) self.ntssb.root['node'].root['node'].reset_variational_noise_factors() self.ntssb.sample_variational_distributions(n_samples=mc_samples) @@ -221,8 +223,8 @@ def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=1 memoized=memoized, seed=seed, update_roots=True) self.ntssb = deepcopy(searcher.tree) - self.ntssb.set_node_hyperparams(n_factors=self.model_args['n_factors']) - self.ntssb.set_node_hyperparams(direction_shape=self.model_args['direction_shape']) + self.ntssb.set_node_hyperparams(n_factors=n_factors) + self.ntssb.set_node_hyperparams(direction_shape=direction_shape) self.ntssb.set_tssb_params(dp_alpha=.1, dp_gamma=.1,) self.ntssb.root['node'].root['node'].reset_variational_noise_factors() self.ntssb.sample_variational_distributions(n_samples=mc_samples)