From 99ea2561f7a75bf92f246f3baf48b559d7d4ff30 Mon Sep 17 00:00:00 2001 From: thorek1 Date: Wed, 4 Oct 2023 14:12:18 +0200 Subject: [PATCH] added cheaper kron multiplication --- src/MacroModelling.jl | 44 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/src/MacroModelling.jl b/src/MacroModelling.jl index f928e392..79c261cb 100644 --- a/src/MacroModelling.jl +++ b/src/MacroModelling.jl @@ -122,6 +122,47 @@ Base.show(io::IO, 𝓂::β„³) = println(io, ) + + +function A_mult_kron_power_3_B(A::AbstractArray{T},B::AbstractArray{T}; tol::AbstractFloat = eps()) where T <: Real + n_row = size(B,1) + n_col = size(B,2) + + BΜ„ = collect(B) + + vals = T[] + rows = Int[] + cols = Int[] + + for row in 1:size(A,1) + idx_mat, vals_mat = A[row,:] |> findnz + + if length(vals_mat) == 0 continue end + + for col in 1:size(B,2)^3 + col_1, col_3 = divrem((col - 1) % (n_col^2), n_col) .+ 1 + col_2 = ((col - 1) Γ· (n_col^2)) + 1 + + mult_val = 0.0 + + for (i,idx) in enumerate(idx_mat) + i_1, i_3 = divrem((idx - 1) % (n_row^2), n_row) .+ 1 + i_2 = ((idx - 1) Γ· (n_row^2)) + 1 + mult_val += vals_mat[i] * BΜ„[i_1,col_1] * BΜ„[i_2,col_2] * BΜ„[i_3,col_3] + end + + if abs(mult_val) > tol + push!(vals,mult_val) + push!(rows,row) + push!(cols,col) + end + end + end + + sparse(rows,cols,vals,size(A,1),size(B,2)^3) +end + + function translate_symbol_to_ascii(x::Symbol) ss = Unicode.normalize(replace(string(x), "β—–" => "__", "β——" => "__"), :NFD) @@ -3827,7 +3868,8 @@ function calculate_third_order_solution(βˆ‡β‚::AbstractMatrix{<: Real}, #first aux = M₃.𝐒𝐏 * βŽΈπ’β‚π’β‚β‚‹β•±πŸβ‚‘βŽΉβ•±π’β‚β•±πŸβ‚‘β‚‹ - 𝐗₃ = -βˆ‡β‚ƒ * β„’.kron(β„’.kron(aux, aux), aux) + # 𝐗₃ = -βˆ‡β‚ƒ * β„’.kron(β„’.kron(aux, aux), aux) + 𝐗₃ = -A_mult_kron_power_3_B(βˆ‡β‚ƒ,aux) tmpkron = β„’.kron(βŽΈπ’β‚π’β‚β‚‹β•±πŸβ‚‘βŽΉβ•±π’β‚β•±πŸβ‚‘β‚‹, β„’.kron(π’β‚β‚Šβ•±πŸŽ, π’β‚β‚Šβ•±πŸŽ) * Mβ‚‚.𝛔) out = - βˆ‡β‚ƒ * tmpkron - βˆ‡β‚ƒ * M₃.𝐏₁ₗ̂ * tmpkron * M₃.𝐏₁ᡣ̃ - βˆ‡β‚ƒ * M₃.𝐏₂ₗ̂ * tmpkron * M₃.𝐏₂ᡣ̃