diff --git a/ext/HiddenMarkovModelsChainRulesCoreExt.jl b/ext/HiddenMarkovModelsChainRulesCoreExt.jl index 1e716ed1..2a6221bb 100644 --- a/ext/HiddenMarkovModelsChainRulesCoreExt.jl +++ b/ext/HiddenMarkovModelsChainRulesCoreExt.jl @@ -22,7 +22,6 @@ end function ChainRulesCore.rrule( rc::RuleConfig, ::typeof(logdensityof), hmm::AbstractHMM, obs_seq::Vector ) - @info "Chain rule used" (p, A, logB), pullback = rrule_via_ad(rc, _params_and_loglikelihoods, hmm, obs_seq) storage = HMMs.initialize_forward_backward(hmm, obs_seq) HMMs.forward_backward!(storage, hmm, obs_seq) diff --git a/src/inference/baum_welch.jl b/src/inference/baum_welch.jl index ae4de322..e8c867ec 100644 --- a/src/inference/baum_welch.jl +++ b/src/inference/baum_welch.jl @@ -117,7 +117,7 @@ function baum_welch!( loglikelihood_increasing::Bool, ) for _ in 1:max_iterations - for k in eachindex(obs_seqs, fb_storages) + @threads for k in eachindex(obs_seqs, fb_storages) forward_backward!(fb_storages[k], hmm, obs_seqs[k]) end update_sufficient_statistics!(bw_storage, fb_storages)