Skip to content

Commit

Permalink
add option to hide progress bar in fit_em
Browse files Browse the repository at this point in the history
  • Loading branch information
slinderman committed Aug 25, 2022
1 parent 3eb8074 commit 4d549c1
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions ssm_jax/hmm/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _single_expected_log_joint(emissions, posterior, **covariates):
num_epochs=num_sgd_epochs_per_mstep)
self.unconstrained_params = params

def fit_em(self, batch_emissions, num_iters=50, mstep_kwargs=dict(), **batch_covariates):
def fit_em(self, batch_emissions, num_iters=50, mstep_kwargs=dict(), verbose=True, **batch_covariates):
"""Fit this HMM with Expectation-Maximization (EM).
Args:
batch_emissions (_type_): _description_
Expand All @@ -173,7 +173,8 @@ def em_step(params):

log_probs = []
params = self.unconstrained_params
for _ in trange(num_iters):
pbar = trange(num_iters) if verbose else range(num_iters)
for _ in pbar:
params, lp = em_step(params)
log_probs.append(lp)

Expand Down

0 comments on commit 4d549c1

Please sign in to comment.