Skip to content

Commit

Permalink
Don't be so permissive on tree learning
Browse files Browse the repository at this point in the history
  • Loading branch information
pedrofale committed Apr 25, 2024
1 parent f2af21c commit a1737f7
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions scatrex/scatrex.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,13 @@ def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=1
self.ntssb.learn_roots(n_epochs, memoized=False, mc_samples=mc_samples, step_size=step_size, return_trace=False)

# Update assignments
# self.ntssb.update_local_params(jax.random.PRNGKey(seed), update_ass=True, update_globals=False)
if update_outer_ass:
self.ntssb.update_local_params(jax.random.PRNGKey(seed), update_ass=True, update_globals=False)

# Learn a tree with root updates on noiseless data (over-cluster) and more permissive prior on tree
# Learn a tree with root updates on noiseless data (over-cluster)
searcher = StructureSearch(self.ntssb)
searcher.tree.set_tssb_params(dp_alpha=1., dp_gamma=1.,)
searcher.tree.set_node_hyperparams(direction_shape=1.)
searcher.tree.set_tssb_params(dp_alpha=.01, dp_gamma=.01,)
searcher.tree.set_node_hyperparams(direction_shape=.1)
searcher.tree.sample_variational_distributions(n_samples=mc_samples)
searcher.tree.reset_sufficient_statistics()
for batch_idx in range(len(searcher.tree.batch_indices)):
Expand Down

0 comments on commit a1737f7

Please sign in to comment.