Skip to content

Commit

Permalink
fast diff
Browse files Browse the repository at this point in the history
  • Loading branch information
thorek1 committed Sep 23, 2023
1 parent 21bbcd7 commit 022d9bf
Showing 1 changed file with 67 additions and 45 deletions.
112 changes: 67 additions & 45 deletions src/MacroModelling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ Base.show(io::IO, 𝓂::ℳ) = println(io,



function jacobian_wrt_values(A::AbstractSparseArray{T}, B::AbstractSparseArray{T}) where T
function jacobian_wrt_values(A, B)
# does this without creating dense arrays: reshape(permutedims(reshape(ℒ.I - ℒ.kron(A, B) ,size(B,1), size(A,1), size(A,1), size(B,1)), [2, 3, 4, 1]), size(A,1) * size(B,1), size(A,1) * size(B,1))

# Compute the Kronecker product and subtract from identity
Expand Down Expand Up @@ -156,7 +156,7 @@ end



function jacobian_wrt_A(A::AbstractSparseArray{T}, X::Matrix{T}) where T
function jacobian_wrt_A(A, X)
# does this without creating dense arrays: reshape(permutedims(reshape(ℒ.I - ℒ.kron(A, B) ,size(B,1), size(A,1), size(A,1), size(B,1)), [2, 3, 4, 1]), size(A,1) * size(B,1), size(A,1) * size(B,1))

# Compute the Kronecker product and subtract from identity
Expand Down Expand Up @@ -849,6 +849,7 @@ function levenberg_marquardt(f::Function,
transformation_level::S = 3,
backtracking_order::S = 2,
) where {T <: AbstractFloat, S <: Integer}
# issues with optimization: https://www.gurobi.com/documentation/8.1/refman/numerics_gurobi_guidelines.html

@assert size(lower_bounds) == size(upper_bounds) == size(initial_guess)
@assert lower_bounds < upper_bounds
Expand Down Expand Up @@ -4385,9 +4386,9 @@ function solve_sylvester_equation_forward(ABC::Vector{Float64};
vB = ABC[lengthA .+ (1:lengthB)]
vC = ABC[lengthA + lengthB + 1:end]

A = sparse(coords[1]...,vA,dims[1]...)# |> ThreadedSparseArrays.ThreadedSparseMatrixCSC
B = sparse(coords[2]...,vB,dims[2]...)# |> ThreadedSparseArrays.ThreadedSparseMatrixCSC
C = sparse(coords[3]...,vC,dims[3]...)# |> ThreadedSparseArrays.ThreadedSparseMatrixCSC
A = sparse(coords[1]...,vA,dims[1]...) |> ThreadedSparseArrays.ThreadedSparseMatrixCSC
B = sparse(coords[2]...,vB,dims[2]...) |> ThreadedSparseArrays.ThreadedSparseMatrixCSC
C = sparse(coords[3]...,vC,dims[3]...) |> ThreadedSparseArrays.ThreadedSparseMatrixCSC
else
lengthA = dims[1][1] * dims[1][2]
A = reshape(ABC[1:lengthA],dims[1]...)
Expand Down Expand Up @@ -4416,8 +4417,8 @@ function solve_sylvester_equation_forward(ABC::Vector{Float64};
elseif solver == :iterative
iter = 1
change = 1
𝐂 = copy(C)
𝐂¹ = copy(C)
𝐂 = C
𝐂¹ = C
while change > eps(Float32) && iter < 10000
𝐂¹ = A * 𝐂 * B - C
if !(A isa DenseMatrix)
Expand Down Expand Up @@ -4514,7 +4515,10 @@ function solve_sylvester_equation_forward(abc::Vector{ℱ.Dual{Z,S,N}};
ABC =.value.(abc)

# you can play with the dimension here, sometimes it makes sense to transpose
partial_values = mapreduce(ℱ.partials, hcat, abc)'
partial_values = zeros(length(abc), N)
for i in 1:N
partial_values[:,i] =.partials.(abc, i)
end

# get f(vs)
val, solved = solve_sylvester_equation_forward(ABC, coords = coords, dims = dims, sparse_output = sparse_output, solver = solver)
Expand All @@ -4527,16 +4531,21 @@ function solve_sylvester_equation_forward(abc::Vector{ℱ.Dual{Z,S,N}};
# C = reshape(ABC[lengthA+1:end],dims[2]...)
droptol!(A,eps())

B = sparse(A')

b = hcat(jacobian_wrt_A(A, -val), ℒ.I(length(val)))
droptol!(b,eps())

a = jacobian_wrt_values(A, B)
droptol!(a,eps())
B = sparse(A') |> ThreadedSparseArrays.ThreadedSparseMatrixCSC

partials = zeros(dims[1][1] * dims[1][2] + dims[2][1] * dims[2][2], size(partial_values,2))
partials[vcat(coords[1][1] + (coords[1][2] .- 1) * dims[1][1], dims[1][1] * dims[1][2] + 1:end),:] = partial_values

reshape_matmul_b = LinearOperators.LinearOperator(Float64, length(val) * size(partials,2), 2*size(A,1)^2 * size(partials,2), false, false,
(sol,𝐱) -> begin
𝐗 = reshape(𝐱, (2* size(A,1)^2,size(partials,2))) |> sparse

b = hcat(jacobian_wrt_A(A, val), -.I(length(val)))
droptol!(b,eps())

sol .= vec(b * 𝐗)
return sol
end)
elseif length(coords) == 3
lengthA = length(coords[1][1])
lengthB = length(coords[2][1])
Expand All @@ -4545,53 +4554,65 @@ function solve_sylvester_equation_forward(abc::Vector{ℱ.Dual{Z,S,N}};
vB = ABC[lengthA .+ (1:lengthB)]
# vC = ABC[lengthA + lengthB + 1:end]

A = sparse(coords[1]...,vA,dims[1]...)# |> ThreadedSparseArrays.ThreadedSparseMatrixCSC
B = sparse(coords[2]...,vB,dims[2]...)# |> ThreadedSparseArrays.ThreadedSparseMatrixCSC
A = sparse(coords[1]...,vA,dims[1]...) |> ThreadedSparseArrays.ThreadedSparseMatrixCSC
B = sparse(coords[2]...,vB,dims[2]...) |> ThreadedSparseArrays.ThreadedSparseMatrixCSC
# C = sparse(coords[3]...,vC,dims[3]...) |> ThreadedSparseArrays.ThreadedSparseMatrixCSC

jacobian_A =.kron(-val * B, ℒ.I(size(A,1)))
jacobian_B =.kron(ℒ.I(size(B,1)), -A * val)

b = hcat(jacobian_A', jacobian_B, ℒ.I(length(val)))
droptol!(b,eps())

a = jacobian_wrt_values(A, B)
droptol!(a,eps())

partials = spzeros(dims[1][1] * dims[1][2] + dims[2][1] * dims[2][2] + dims[3][1] * dims[3][2], size(partial_values,2))
partials[vcat(
coords[1][1] + (coords[1][2] .- 1) * dims[1][1],
coords[2][1] + (coords[2][2] .- 1) * dims[2][1] .+ dims[1][1] * dims[1][2],
coords[3][1] + (coords[3][2] .- 1) * dims[3][1] .+ dims[1][1] * dims[1][2] .+ dims[2][1] * dims[2][2]),:] = partial_values

reshape_matmul_b = LinearOperators.LinearOperator(Float64, length(val) * size(partials,2), (length(A) + length(B) + length(val)) * size(partials,2), false, false,
(sol,𝐱) -> begin
𝐗 = reshape(𝐱, (length(A) + length(B) + length(val), size(partials,2))) |> sparse

jacobian_A =.kron(val * B, ℒ.I(size(A,1)))
jacobian_B =.kron(ℒ.I(size(B,1)), A * val)

b = hcat(jacobian_A', jacobian_B, -.I(length(val)))
droptol!(b,eps())

sol .= vec(b * 𝐗)
return sol
end)
else
lengthA = dims[1][1] * dims[1][2]
A = reshape(ABC[1:lengthA],dims[1]...)
A = reshape(ABC[1:lengthA],dims[1]...) |> sparse
droptol!(A, eps())
# C = reshape(ABC[lengthA+1:end],dims[2]...)
B = A'
# jacobian_A = reshape(permutedims(reshape(ℒ.kron(ℒ.I(size(A,1)), -A * val) ,size(A,1), size(A,1), size(A,1), size(A,1)), [1, 2, 4, 3]), size(A,1) * size(A,1), size(A,1) * size(A,1))
B = sparse(A') |> ThreadedSparseArrays.ThreadedSparseMatrixCSC

spA = sparse(A)
droptol!(spA, eps())
partials = partial_values

b = hcat(jacobian_wrt_A(spA, -val), ℒ.I(length(val)))
reshape_matmul_b = LinearOperators.LinearOperator(Float64, length(val) * size(partials,2), 2*size(A,1)^2 * size(partials,2), false, false,
(sol,𝐱) -> begin
𝐗 = reshape(𝐱, (2* size(A,1)^2,size(partials,2))) |> sparse

a = reshape(permutedims(reshape(ℒ.I -.kron(A, B) ,size(B,1), size(A,1), size(A,1), size(B,1)), [2, 3, 4, 1]), size(A,1) * size(B,1), size(A,1) * size(B,1))
b = hcat(jacobian_wrt_A(A, val), -.I(length(val)))
droptol!(b,eps())

partials = partial_values
sol .= vec(b * 𝐗)
return sol
end)
end


# get J(f, vs) * ps (cheating). Write your custom rule here. This used to be the conditions but here they are analytically derived.
reshape_matmul = LinearOperators.LinearOperator(Float64, size(b,1) * size(partials,2), size(b,1) * size(partials,2), false, false,
reshape_matmul_a = LinearOperators.LinearOperator(Float64, length(val) * size(partials,2), length(val) * size(partials,2), false, false,
(sol,𝐱) -> begin
𝐗 = reshape(𝐱, (size(b,1),size(partials,2))) |> sparse
𝐗 = reshape(𝐱, (length(val),size(partials,2))) |> sparse

a = jacobian_wrt_values(A, B)
droptol!(a,eps())

sol .= vec(a * 𝐗)
return sol
end)

X, info = Krylov.gmres(reshape_matmul, -vec(b * partials))#, atol = tol)
X, info = Krylov.gmres(reshape_matmul_a, vec(reshape_matmul_b * vec(partials)))#, atol = tol)

jvp = reshape(X, (size(b,1),size(partials,2)))
jvp = reshape(X, (length(val), size(partials,2)))

out = reshape(map(val, eachrow(jvp)) do v, p
.Dual{Z}(v, p...) # Z is the tag
Expand Down Expand Up @@ -4875,8 +4896,8 @@ function calculate_second_order_moments(

values = vcat(v1, vec(collect(-C)))

Σᶻ₂, info = solve_sylvester_equation_forward(values, coords = coordinates, dims = dimensions, solver = :doubling)
# Σᶻ₂, info = solve_sylvester_equation_AD(values, coords = coordinates, dims = dimensions, solver = :doubling)
# Σᶻ₂, info = solve_sylvester_equation_forward(values, coords = coordinates, dims = dimensions, solver = :doubling)
Σᶻ₂, info = solve_sylvester_equation_AD(values, coords = coordinates, dims = dimensions, solver = :doubling)
# Σᶻ₂, info = solve_sylvester_equation_AD([vec(ŝ_to_ŝ₂); vec(-C)], dims = [size(ŝ_to_ŝ₂) ;size(C)])#, solver = :doubling)
# Σᶻ₂, info = solve_sylvester_equation_forward([vec(ŝ_to_ŝ₂); vec(-C)], dims = [size(ŝ_to_ŝ₂) ;size(C)])

Expand Down Expand Up @@ -5106,8 +5127,8 @@ function calculate_third_order_moments(parameters::Vector{T},

values = vcat(v1, vec(collect(-C)))

Σᶻ₃, info = solve_sylvester_equation_forward(values, coords = coordinates, dims = dimensions, solver = :doubling)
# Σᶻ₃, info = solve_sylvester_equation_AD(values, coords = coordinates, dims = dimensions, solver = :doubling)
# Σᶻ₃, info = solve_sylvester_equation_forward(values, coords = coordinates, dims = dimensions, solver = :doubling)
Σᶻ₃, info = solve_sylvester_equation_AD(values, coords = coordinates, dims = dimensions, solver = :doubling)

Σʸ₃tmp = ŝ_to_y₃ * Σᶻ₃ * ŝ_to_y₃' + ê_to_y₃ * Γ₃ * ê_to_y₃' + ê_to_y₃ * Eᴸᶻ * ŝ_to_y₃' + ŝ_to_y₃ * Eᴸᶻ' * ê_to_y₃'

Expand All @@ -5124,7 +5145,7 @@ function calculate_third_order_moments(parameters::Vector{T},
ŝ_to_ŝ₃ⁱ = zero(ŝ_to_ŝ₃)
ŝ_to_ŝ₃ⁱ +=.diagm(ones(size(Σᶻ₃,1)))

Σᶻ₃ⁱ = copy(Σᶻ₃)
Σᶻ₃ⁱ = Σᶻ₃

for i in autocorrelation_periods
Σᶻ₃ⁱ .= ŝ_to_ŝ₃ * Σᶻ₃ⁱ + ê_to_ŝ₃ * Eᴸᶻ
Expand Down Expand Up @@ -5218,7 +5239,8 @@ function calculate_kalman_filter_loglikelihood(𝓂::ℳ, data::AbstractArray{Fl

values = vcat(vec(A), vec(collect(-𝐁)))

P, _ = solve_sylvester_equation_forward(values, coords = coordinates, dims = dimensions, solver = :doubling)
P, _ = solve_sylvester_equation_AD(values, coords = coordinates, dims = dimensions, solver = :doubling)
# P, _ = solve_sylvester_equation_forward(values, coords = coordinates, dims = dimensions, solver = :doubling)
# P, _ = solve_sylvester_equation_AD_direct(values, coords = coordinates, dims = dimensions, solver = :doubling)
# P, _ = solve_sylvester_equation_AD_direct([vec(A); vec(-𝐁)], dims = [size(A), size(𝐁)], solver = :bicgstab)
# P, _ = solve_sylvester_equation_forward([vec(A); vec(-CC)], dims = [size(A), size(CC)])
Expand Down

0 comments on commit 022d9bf

Please sign in to comment.