Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Sep 30, 2024
1 parent 26c099f commit 3bb007d
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/HiddenMarkovModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, rrule_via_ad
using DensityInterface: DensityInterface, DensityKind, HasDensity, NoDensity, logdensityof
using DocStringExtensions
using FillArrays: Fill
using LinearAlgebra: Transpose, dot, ldiv!, lmul!, mul!, parent
using LinearAlgebra: Transpose, axpy!, dot, ldiv!, lmul!, mul!, parent
using Random: Random, AbstractRNG, default_rng
using SparseArrays: AbstractSparseArray, SparseMatrixCSC, nonzeros, nnz, nzrange, rowvals
using StatsAPI: StatsAPI, fit, fit!
Expand Down
2 changes: 1 addition & 1 deletion src/types/abstract_hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ function obs_logdensities!(
logb::AbstractVector{T}, hmm::AbstractHMM, obs, control
) where {T}
dists = obs_distributions(hmm, control)
@inbounds @simd for i in eachindex(logb, dists)
@simd for i in eachindex(logb, dists)
logb[i] = logdensityof(dists[i], obs)
end
@argcheck maximum(logb) < typemax(T)
Expand Down
2 changes: 1 addition & 1 deletion src/utils/lightcategorical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ function StatsAPI.fit!(
@argcheck 1 <= minimum(x) <= maximum(x) <= length(dist.p)
w_tot = sum(w)
fill!(dist.p, zero(T1))
@inbounds @simd for i in eachindex(x, w)
@simd for i in eachindex(x, w)
dist.p[x[i]] += w[i]
end
dist.p ./= w_tot
Expand Down
8 changes: 4 additions & 4 deletions src/utils/lightdiagnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ function DensityInterface.logdensityof(
) where {T1,T2,T3}
l = zero(promote_type(T1, T2, T3, eltype(x)))
l -= sum(dist.logσ) + log2π * length(x) / 2
@inbounds @simd for i in eachindex(x, dist.μ, dist.σ)
@simd for i in eachindex(x, dist.μ, dist.σ)
l -= abs2(x[i] - dist.μ[i]) / (2 * abs2(dist.σ[i]))
end
return l
Expand All @@ -58,11 +58,11 @@ function StatsAPI.fit!(
w_tot = sum(w)
fill!(dist.μ, zero(T1))
fill!(dist.σ, zero(T2))
@inbounds @simd for i in eachindex(x, w)
dist.μ .+= x[i] .* w[i]
@simd for i in eachindex(x, w)
axpy!(w[i], x[i], dist.μ)
end
dist.μ ./= w_tot
@inbounds @simd for i in eachindex(x, w)
@simd for i in eachindex(x, w)
dist.σ .+= abs2.(x[i] .- dist.μ) .* w[i]
end
dist.σ .= sqrt.(dist.σ ./ w_tot)
Expand Down
10 changes: 8 additions & 2 deletions test/correctness.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ using Test

TEST_SUITE = get(ENV, "JULIA_HMM_TEST_SUITE", "Standard")

rng = StableRNG(63)

## Settings

T, K = 100, 200
Expand All @@ -31,6 +29,7 @@ p_guess = [[0.7, 0.3], [0.3, 0.7]]

σ = ones(2)

rng = StableRNG(63)
control_seqs = [fill(nothing, rand(rng, T:(2T))) for k in 1:K];
control_seq = reduce(vcat, control_seqs);
seq_ends = cumsum(length.(control_seqs));
Expand All @@ -44,6 +43,7 @@ seq_ends = cumsum(length.(control_seqs));
hmm = HMM(init, trans, dists)
hmm_guess = HMM(init_guess, trans_guess, dists_guess)

rng = StableRNG(63)
if TEST_SUITE == "HMMBase"
test_identical_hmmbase(rng, hmm, T; hmm_guess)
else
Expand All @@ -62,6 +62,7 @@ end
hmm = HMM(init, trans, dists)
hmm_guess = HMM(init_guess, trans_guess, dists_guess)

rng = StableRNG(63)
if TEST_SUITE == "HMMBase"
test_identical_hmmbase(rng, hmm, T; hmm_guess)
else
Expand All @@ -77,6 +78,7 @@ end
hmm = HMM(init, trans, dists)
hmm_guess = HMM(init_guess, trans_guess, dists_guess)

rng = StableRNG(63)
if TEST_SUITE != "HMMBase"
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
Expand All @@ -91,6 +93,7 @@ end
hmm = HMM(init, trans, dists)
hmm_guess = HMM(init_guess, trans_guess, dists_guess)

rng = StableRNG(63)
if TEST_SUITE != "HMMBase"
test_coherent_algorithms(rng, hmm, control_seq; seq_ends, hmm_guess, init=false)
test_type_stability(rng, hmm, control_seq; seq_ends, hmm_guess)
Expand All @@ -105,6 +108,7 @@ end
hmm = HMM(init, sparse(trans), dists)
hmm_guess = HMM(init_guess, trans_guess, dists_guess)

rng = StableRNG(63)
if TEST_SUITE == "HMMBase"
test_identical_hmmbase(rng, hmm, T; hmm_guess)
else
Expand All @@ -121,6 +125,7 @@ end
hmm = transpose_hmm(HMM(init, trans, dists))
hmm_guess = transpose_hmm(HMM(init_guess, trans_guess, dists_guess))

rng = StableRNG(63)
if TEST_SUITE == "HMMBase"
test_identical_hmmbase(rng, hmm, T; hmm_guess)
else
Expand All @@ -137,6 +142,7 @@ end
hmm = HMM(init, trans, dists)
hmm_guess = HMM(init_guess, trans_guess, dists_guess)

rng = StableRNG(63)
if TEST_SUITE == "HMMBase"
test_identical_hmmbase(rng, hmm, T; hmm_guess)
else
Expand Down
4 changes: 2 additions & 2 deletions test/distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ end
end
@test val_count ./ length(x) p atol = 2e-2
# Fitting
dist_est = deepcopy(dist)
dist_est = LightCategorical(rand_prob_vec(rng, 10))
w = ones(length(x))
fit!(dist_est, x, w)
@test dist_est.p p atol = 2e-2
Expand All @@ -43,7 +43,7 @@ end
@test mean(x) μ atol = 2e-2
@test std(x) σ atol = 2e-2
# Fitting
dist_est = deepcopy(dist)
dist_est = LightDiagNormal(randn(rng, 10), rand(rng, 10))
w = ones(length(x))
fit!(dist_est, x, w)
@test dist_est.μ μ atol = 2e-2
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ Pkg.develop(; path=joinpath(dirname(@__DIR__), "libs", "HMMTest"))
end
end

@testset "Correctness - $TEST_SUITE" begin
@testset verbose = true "Correctness - $TEST_SUITE" begin
include("correctness.jl")
end
end

0 comments on commit 3bb007d

Please sign in to comment.