From a1737f79fd97a8d53ab4ca0381be4e7fabec7391 Mon Sep 17 00:00:00 2001 From: pedrofale Date: Thu, 25 Apr 2024 11:19:36 +0200 Subject: [PATCH] Don't be so permissive on tree learning --- scatrex/scatrex.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/scatrex/scatrex.py b/scatrex/scatrex.py index b08e1ba..4346533 100644 --- a/scatrex/scatrex.py +++ b/scatrex/scatrex.py @@ -211,12 +211,13 @@ def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=1 self.ntssb.learn_roots(n_epochs, memoized=False, mc_samples=mc_samples, step_size=step_size, return_trace=False) # Update assignments - # self.ntssb.update_local_params(jax.random.PRNGKey(seed), update_ass=True, update_globals=False) + if update_outer_ass: + self.ntssb.update_local_params(jax.random.PRNGKey(seed), update_ass=True, update_globals=False) - # Learn a tree with root updates on noiseless data (over-cluster) and more permissive prior on tree + # Learn a tree with root updates on noiseless data (over-cluster) searcher = StructureSearch(self.ntssb) - searcher.tree.set_tssb_params(dp_alpha=1., dp_gamma=1.,) - searcher.tree.set_node_hyperparams(direction_shape=1.) + searcher.tree.set_tssb_params(dp_alpha=.01, dp_gamma=.01,) + searcher.tree.set_node_hyperparams(direction_shape=.1) searcher.tree.sample_variational_distributions(n_samples=mc_samples) searcher.tree.reset_sufficient_statistics() for batch_idx in range(len(searcher.tree.batch_indices)):