Skip to content

Commit

Permalink
fix JET
Browse files Browse the repository at this point in the history
  • Loading branch information
thorek1 committed Dec 1, 2024
1 parent 8cb8ab7 commit 38e44af
Showing 1 changed file with 136 additions and 9 deletions.
145 changes: 136 additions & 9 deletions src/moments.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,31 @@ function calculate_mean(parameters::Vector{T},
# return mean_of_variables, 𝐒₁, ∇₁, 𝐒₂, ∇₂, true
end



function calculate_second_order_moments(parameters::Vector{R},
𝓂::ℳ;
covariance::Bool = true,
verbose::Bool = false,
sylvester_algorithm::Symbol = :doubling,
lyapunov_algorithm::Symbol = :doubling,
tol::AbstractFloat = eps()) where R <: Real
calculate_second_order_moments(
parameters,
𝓂,
Val(covariance);
verbose = verbose,
sylvester_algorithm = sylvester_algorithm,
lyapunov_algorithm = lyapunov_algorithm,
tol = tol)
end

function calculate_second_order_moments(
parameters::Vector{R},
𝓂::;
covariance::Bool = true,
𝓂::,
::Val{false}; # covariance;
verbose::Bool = false,
sylvester_algorithm::Symbol = :doubling,
lyapunov_algorithm::Symbol = :doubling,
tol::AbstractFloat = eps())::Union{Tuple{Matrix{R}, Matrix{R}, Vector{R}, Vector{R}, Matrix{R}, Matrix{R}, Matrix{R}, Matrix{R}, Matrix{R}, Vector{R}, Matrix{R}, Matrix{R}, AbstractSparseMatrix{R}, AbstractSparseMatrix{R}, Bool}, Tuple{Vector{R}, Vector{R}, Matrix{R}, Matrix{R}, Vector{R}, Matrix{R}, Matrix{R}, AbstractSparseMatrix{R}, AbstractSparseMatrix{R}, Bool}} where R <: Real
tol::AbstractFloat = eps())::Tuple{Vector{R}, Vector{R}, Matrix{R}, Matrix{R}, Vector{R}, Matrix{R}, Matrix{R}, AbstractSparseMatrix{R}, AbstractSparseMatrix{R}, Bool} where R <: Real

Σʸ₁, 𝐒₁, ∇₁, SS_and_pars, solved = calculate_covariance(parameters, 𝓂, verbose = verbose, lyapunov_algorithm = lyapunov_algorithm)

Expand Down Expand Up @@ -244,10 +258,122 @@ function calculate_second_order_moments(
Δμˢ₂ = vec((ℒ.I - s_to_s₁) \ (s_s_to_s₂ * vec(Σᶻ₁) / 2 + (v_v_to_s₂ + e_e_to_s₂ * vec(ℒ.I(nᵉ))) / 2))
μʸ₂ = SS_and_pars[1:𝓂.timings.nVars] + ŝ_to_y₂ * μˢ⁺₂ + yv₂

if !covariance
return μʸ₂, Δμˢ₂, Σʸ₁, Σᶻ₁, SS_and_pars, 𝐒₁, ∇₁, 𝐒₂, ∇₂, solved && solved2
return μʸ₂, Δμˢ₂, Σʸ₁, Σᶻ₁, SS_and_pars, 𝐒₁, ∇₁, 𝐒₂, ∇₂, (solved && solved2)
end



function calculate_second_order_moments(
parameters::Vector{R},
𝓂::ℳ,
::Val{true}; # covariance
verbose::Bool = false,
sylvester_algorithm::Symbol = :doubling,
lyapunov_algorithm::Symbol = :doubling,
tol::AbstractFloat = eps())::Tuple{Matrix{R}, Matrix{R}, Vector{R}, Vector{R}, Matrix{R}, Matrix{R}, Matrix{R}, Matrix{R}, Matrix{R}, Vector{R}, Matrix{R}, Matrix{R}, AbstractSparseMatrix{R}, AbstractSparseMatrix{R}, Bool} where R <: Real

Σʸ₁, 𝐒₁, ∇₁, SS_and_pars, solved = calculate_covariance(parameters, 𝓂, verbose = verbose, lyapunov_algorithm = lyapunov_algorithm)

nᵉ = 𝓂.timings.nExo

= 𝓂.timings.nPast_not_future_and_mixed

= 𝓂.timings.past_not_future_and_mixed_idx

Σᶻ₁ = Σʸ₁[iˢ, iˢ]

# precalc second order
## mean
I_plus_s_s = sparse(reshape(ℒ.kron(vec(ℒ.I(nˢ)), ℒ.I(nˢ)), nˢ^2, nˢ^2) +.I)

## covariance
E_e⁴ = zeros(nᵉ * (nᵉ + 1)÷2 * (nᵉ + 2)÷3 * (nᵉ + 3)÷4)

quadrup = multiplicate(nᵉ, 4)

comb⁴ = reduce(vcat, generateSumVectors(nᵉ, 4))

comb⁴ = comb⁴ isa Int64 ? reshape([comb⁴],1,1) : comb⁴

for j = 1:size(comb⁴,1)
E_e⁴[j] = product_moments(ℒ.I(nᵉ), 1:nᵉ, comb⁴[j,:])
end

e⁴ = quadrup * E_e⁴

# second order
∇₂ = calculate_hessian(parameters, SS_and_pars, 𝓂)# * 𝓂.solution.perturbation.second_order_auxilliary_matrices.𝐔∇₂

𝐒₂, solved2 = calculate_second_order_solution(∇₁, ∇₂, 𝐒₁,
𝓂.solution.perturbation.second_order_auxilliary_matrices;
T = 𝓂.timings,
tol = tol,
initial_guess = 𝓂.solution.perturbation.second_order_solution,
sylvester_algorithm = sylvester_algorithm,
verbose = verbose)

if eltype(𝐒₂) == Float64 && solved2 𝓂.solution.perturbation.second_order_solution = 𝐒₂ end

𝐒₂ *= 𝓂.solution.perturbation.second_order_auxilliary_matrices.𝐔₂

𝐒₂ = sparse(𝐒₂)

s_in_s⁺ = BitVector(vcat(ones(Bool, nˢ), zeros(Bool, nᵉ + 1)))
e_in_s⁺ = BitVector(vcat(zeros(Bool, nˢ + 1), ones(Bool, nᵉ)))
v_in_s⁺ = BitVector(vcat(zeros(Bool, nˢ), 1, zeros(Bool, nᵉ)))

kron_s_s =.kron(s_in_s⁺, s_in_s⁺)
kron_e_e =.kron(e_in_s⁺, e_in_s⁺)
kron_v_v =.kron(v_in_s⁺, v_in_s⁺)
kron_s_e =.kron(s_in_s⁺, e_in_s⁺)

# first order
s_to_y₁ = 𝐒₁[:, 1:nˢ]
e_to_y₁ = 𝐒₁[:, (nˢ + 1):end]

s_to_s₁ = 𝐒₁[iˢ, 1:nˢ]
e_to_s₁ = 𝐒₁[iˢ, (nˢ + 1):end]


# second order
s_s_to_y₂ = 𝐒₂[:, kron_s_s]
e_e_to_y₂ = 𝐒₂[:, kron_e_e]
v_v_to_y₂ = 𝐒₂[:, kron_v_v]
s_e_to_y₂ = 𝐒₂[:, kron_s_e]

s_s_to_s₂ = 𝐒₂[iˢ, kron_s_s] |> collect
e_e_to_s₂ = 𝐒₂[iˢ, kron_e_e]
v_v_to_s₂ = 𝐒₂[iˢ, kron_v_v] |> collect
s_e_to_s₂ = 𝐒₂[iˢ, kron_s_e]

s_to_s₁_by_s_to_s₁ =.kron(s_to_s₁, s_to_s₁) |> collect
e_to_s₁_by_e_to_s₁ =.kron(e_to_s₁, e_to_s₁)
s_to_s₁_by_e_to_s₁ =.kron(s_to_s₁, e_to_s₁)

# # Set up in pruned state transition matrices
ŝ_to_ŝ₂ = [ s_to_s₁ zeros(nˢ, nˢ +^2)
zeros(nˢ, nˢ) s_to_s₁ s_s_to_s₂ / 2
zeros(nˢ^2, 2*nˢ) s_to_s₁_by_s_to_s₁ ]

ê_to_ŝ₂ = [ e_to_s₁ zeros(nˢ, nᵉ^2 + nᵉ * nˢ)
zeros(nˢ,nᵉ) e_e_to_s₂ / 2 s_e_to_s₂
zeros(nˢ^2,nᵉ) e_to_s₁_by_e_to_s₁ I_plus_s_s * s_to_s₁_by_e_to_s₁]

ŝ_to_y₂ = [s_to_y₁ s_to_y₁ s_s_to_y₂ / 2]

ê_to_y₂ = [e_to_y₁ e_e_to_y₂ / 2 s_e_to_y₂]

ŝv₂ = [ zeros(nˢ)
vec(v_v_to_s₂) / 2 + e_e_to_s₂ / 2 * vec(ℒ.I(nᵉ))
e_to_s₁_by_e_to_s₁ * vec(ℒ.I(nᵉ))]

yv₂ = (vec(v_v_to_y₂) + e_e_to_y₂ * vec(ℒ.I(nᵉ))) / 2

## Mean
μˢ⁺₂ = (ℒ.I - ŝ_to_ŝ₂) \ ŝv₂
Δμˢ₂ = vec((ℒ.I - s_to_s₁) \ (s_s_to_s₂ * vec(Σᶻ₁) / 2 + (v_v_to_s₂ + e_e_to_s₂ * vec(ℒ.I(nᵉ))) / 2))
μʸ₂ = SS_and_pars[1:𝓂.timings.nVars] + ŝ_to_y₂ * μˢ⁺₂ + yv₂

# Covariance
Γ₂ = [ ℒ.I(nᵉ) zeros(nᵉ, nᵉ^2 + nᵉ * nˢ)
zeros(nᵉ^2, nᵉ) reshape(e⁴, nᵉ^2, nᵉ^2) - vec(ℒ.I(nᵉ)) * vec(ℒ.I(nᵉ))' zeros(nᵉ^2, nᵉ * nˢ)
Expand All @@ -265,7 +391,7 @@ function calculate_second_order_moments(

autocorr_tmp = ŝ_to_ŝ₂ * Σᶻ₂ * ŝ_to_y₂' + ê_to_ŝ₂ * Γ₂ * ê_to_y₂'

return Σʸ₂, Σᶻ₂, μʸ₂, Δμˢ₂, autocorr_tmp, ŝ_to_ŝ₂, ŝ_to_y₂, Σʸ₁, Σᶻ₁, SS_and_pars, 𝐒₁, ∇₁, 𝐒₂, ∇₂, solved && solved2 && info
return Σʸ₂, Σᶻ₂, μʸ₂, Δμˢ₂, autocorr_tmp, ŝ_to_ŝ₂, ŝ_to_y₂, Σʸ₁, Σᶻ₁, SS_and_pars, 𝐒₁, ∇₁, 𝐒₂, ∇₂, (solved && solved2 && info)
end


Expand All @@ -286,7 +412,8 @@ function calculate_third_order_moments(parameters::Vector{T},
tol::AbstractFloat = eps()) where {U, T <: Real}

second_order_moments = calculate_second_order_moments(parameters,
𝓂,
𝓂,
Val(true);
verbose = verbose,
sylvester_algorithm = sylvester_algorithm,
lyapunov_algorithm = lyapunov_algorithm)
Expand Down

0 comments on commit 38e44af

Please sign in to comment.