MethodError Using baum_welch with MvLogNormal Emission Distributions #76

Closed
mcwaga opened this issue Feb 2, 2024 · 3 comments

mcwaga commented Feb 2, 2024

Description:
When attempting to estimate the parameters of a Hidden Markov Model using the baum_welch function with multivariate log-normal (MvLogNormal) emission distributions, I encounter a MethodError. The error suggests a mismatch in expected argument types for logpdf.

Code to reproduce:

```julia
using HiddenMarkovModels, Distributions, PDMats

# Define variables
r_path = randn(3000)
w_path = randn(3000)
r0 = [r_path w_path]

# Define the HMM
states = 4
p = ones(states) / states
transition_matrix = [0.525 0.35 0.03125 0.09375;
                    0.038889 0.836111 0.002083 0.122917;
                    0.09375 0.03125 0.291667 0.583333;
                    0.009115 0.115885 0.024306 0.850694]

transition_matrix[:,4] = 1.0 .- sum(transition_matrix[:,1:3],dims=2)

# Create HMM and attempt parameter estimation
pre_hmm = HMM(p, transition_matrix, [MvLogNormal(randn(2), cov(r0)) for k=1:states])
hmm_est, logL_evolution = HiddenMarkovModels.baum_welch(pre_hmm, r0)
```

The function call results in the following error:

```
ERROR: MethodError: no method matching logpdf(::MvLogNormal{Float64, PDMats.PDMat{Float64, Matrix{Float64}}, Vector{Float64}}, ::Int64)
```
I am new to HMMs, so something may be wrong with my approach. Any guidance on correcting the usage or addressing the error would be greatly appreciated.
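The line `transition_matrix[:,4] = 1.0 .- sum(transition_matrix[:,1:3],dims=2)` in the reproduction recomputes the last column so that every row sums to exactly 1. Assuming the usual convention that entry (i, j) is the probability of moving from state i to state j, that is what a valid transition matrix needs. A minimal sanity check in plain Julia might look like the sketch below; `is_row_stochastic` is a hypothetical helper, not part of HiddenMarkovModels.jl.

```julia
# Transition matrix from the reproduction (rows = current state, columns = next state)
transition_matrix = [0.525 0.35 0.03125 0.09375;
                     0.038889 0.836111 0.002083 0.122917;
                     0.09375 0.03125 0.291667 0.583333;
                     0.009115 0.115885 0.024306 0.850694]

# Same fix-up as in the issue: overwrite column 4 so each row sums to 1
transition_matrix[:, 4] = 1.0 .- sum(transition_matrix[:, 1:3], dims=2)

# Hypothetical helper: nonnegative entries and every row summing to 1 (up to rounding)
is_row_stochastic(A; atol=1e-12) =
    all(>=(0), A) && all(r -> isapprox(sum(r), 1.0; atol=atol), eachrow(A))

println(is_row_stochastic(transition_matrix))  # prints true
```

A check like this only guards the transition matrix; it does not address the `logpdf` error itself.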

gdalle (Owner) commented Feb 2, 2024

Hey @mcwaga,
Here are the successive steps I took to debug your code:

1. In HiddenMarkovModels.jl, each observation sequence must be a vector; matrices are not accepted. Granted, this should probably be caught by dispatch instead of causing a weird downstream bug. To fix it, I transformed your observations into a sequence of 2-vectors, one for each of the 3000 time steps:
   ```julia
   julia> obs_seq = collect(eachrow(r0));

   julia> hmm_est, logL_evolution = HiddenMarkovModels.baum_welch(pre_hmm, obs_seq)
   ERROR: TaskFailedException

       nested task error: OverflowError: Some values are NaN
   ```
2. Your observation distributions are log-normal, but your observations are not all > 0: no wonder we get NaNs. To remove them, I took the liberty of exponentiating the inputs:
   ```julia
   julia> obs_seq_exp = collect(eachrow([exp.(randn(3000)) exp.(randn(3000))]));

   julia> hmm_est, logL_evolution = HiddenMarkovModels.baum_welch(pre_hmm, obs_seq_exp)
   ERROR: suffstats is not implemented for (MvLogNormal{Float64, PDMat{Float64, Matrix{Float64}}, Vector{Float64}}, Matrix{Float64}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}).
   ```
3. This is a Distributions.jl problem: MLE fitting is only implemented for a few distributions, mostly univariate. Not sure how to circumvent this, but if you know that your observations are log-normal, perhaps the easiest way is to take their log (i.e. keep your original, non-exponentiated inputs) and fit a normal distribution instead?
   ```julia
   julia> pre_hmm_normal = HMM(p, transition_matrix, [MvNormal(randn(2), cov(r0)) for k=1:states]);

   julia> hmm_est, logL_evolution = HiddenMarkovModels.baum_welch(pre_hmm_normal, obs_seq);

   julia> logL_evolution
   100-element Vector{Float64}:
    -9889.346931214455
    -8624.38202625008
    -8619.623391367273
   ...
   ```
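The three steps above can be condensed into a small sketch that only uses Julia's standard library. It assumes synthetic positive data (`exp.(randn(...))`, mirroring step 2). The function `weighted_gaussian_mle` and the weights `w` are hypothetical stand-ins for one hidden state's M-step, with `w` playing the role of the posterior state probabilities; the formula is the textbook closed-form weighted Gaussian MLE, not HiddenMarkovModels.jl's internal code.

```julia
using LinearAlgebra, Random, Statistics

# Hypothetical positive data: T time steps, 2 dimensions
Random.seed!(0)
T = 3000
X = exp.(randn(T, 2))            # T×2 matrix, all entries > 0

# Step 1: one observation per time step, as a vector of 2-vectors (not a matrix)
obs_seq = collect(eachrow(X))

# Step 2: MvLogNormal only has support on strictly positive vectors
@assert all(o -> all(>(0), o), obs_seq)

# Step 3: if x ~ MvLogNormal(μ, Σ) then log.(x) ~ MvNormal(μ, Σ), so fit on the logs
log_seq = [log.(o) for o in obs_seq]

# Hypothetical helper: closed-form weighted Gaussian MLE (weighted mean and covariance)
function weighted_gaussian_mle(xs::Vector{<:AbstractVector}, w::AbstractVector)
    W = sum(w)
    μ = sum(w[t] * xs[t] for t in eachindex(xs)) / W
    Σ = sum(w[t] * (xs[t] - μ) * (xs[t] - μ)' for t in eachindex(xs)) / W
    return μ, Symmetric(Σ)
end

w = rand(T)                      # hypothetical per-time-step state weights
μ, Σ = weighted_gaussian_mle(log_seq, w)
```

Because `log.(x)` of an `MvLogNormal(μ, Σ)` sample follows `MvNormal(μ, Σ)`, the `(μ, Σ)` estimated on the logs can be read directly as the parameters of the log-normal emission for the original positive data.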

gdalle (Owner) commented Feb 3, 2024

Did this solve your troubles? Or can I help out further?

mcwaga (Author) commented Feb 3, 2024 via email

mcwaga closed this as completed Feb 4, 2024