Skip to content

Commit

Permalink
Reset direction_shape to initial
Browse files Browse the repository at this point in the history
  • Loading branch information
pedrofale committed Apr 22, 2024
1 parent 05278c5 commit 6bbb924
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions scatrex/scatrex.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def learn_scales(self, n_epochs=100, mc_samples=10, step_size=0.01):
n_genes = self.ntssb.data.shape[1]
gs = np.sqrt(np.mean(self.ntssb.data))
root = self.ntssb.root['node'].root['node']

gene_scales_alpha_init = 10. * jnp.ones((n_genes,)) #* jnp.exp(np.random.normal(size=self.n_genes))
gene_scales_beta_init = 10. * jnp.ones((n_genes,)) * 1./np.mean(self.ntssb.data, axis=0) #* jnp.exp(10. + 0. * np.random.normal(size=n_genes))
root.variational_parameters['global']['gene_scales']['log_alpha'] = jnp.log(gene_scales_alpha_init)
Expand Down Expand Up @@ -200,6 +200,8 @@ def learn_scales(self, n_epochs=100, mc_samples=10, step_size=0.01):
def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=10, memoized=True, mc_samples=10, step_size=0.01, seed=42):
logger.info("Learning roots and noise")
# Remove noise and learn roots
n_factors = self.ntssb.root['node'].root['node'].node_hyperparams['n_factors']
direction_shape = self.ntssb.root['node'].root['node'].node_hyperparams['direction_shape']
self.ntssb.set_node_hyperparams(n_factors=0)
self.ntssb.root['node'].root['node'].reset_variational_noise_factors()
self.ntssb.sample_variational_distributions(n_samples=mc_samples)
Expand All @@ -221,8 +223,8 @@ def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=1
memoized=memoized, seed=seed, update_roots=True)

self.ntssb = deepcopy(searcher.tree)
self.ntssb.set_node_hyperparams(n_factors=self.model_args['n_factors'])
self.ntssb.set_node_hyperparams(direction_shape=self.model_args['direction_shape'])
self.ntssb.set_node_hyperparams(n_factors=n_factors)
self.ntssb.set_node_hyperparams(direction_shape=direction_shape)
self.ntssb.set_tssb_params(dp_alpha=.1, dp_gamma=.1,)
self.ntssb.root['node'].root['node'].reset_variational_noise_factors()
self.ntssb.sample_variational_distributions(n_samples=mc_samples)
Expand Down

0 comments on commit 6bbb924

Please sign in to comment.