Skip to content

Commit

Permalink
inversion filter finiteness check and inversion feasability check
Browse files Browse the repository at this point in the history
  • Loading branch information
Thore Kockerols authored and Thore Kockerols committed Oct 27, 2024
1 parent cddc808 commit f834332
Showing 1 changed file with 78 additions and 24 deletions.
102 changes: 78 additions & 24 deletions src/filter/inversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ function calculate_inversion_filter_loglikelihood(::Val{:first_order},
return -Inf
end

logabsdets =.logabsdet(jac ./ precision_factor)[1]
logabsdets =.logabsdet(jac)[1]
invjac = inv(jacdecomp)
else
jacdecomp = try.svd(jac)
Expand All @@ -107,7 +107,7 @@ function calculate_inversion_filter_loglikelihood(::Val{:first_order},
return -Inf
end

logabsdets = sum(x -> log(abs(x)), ℒ.svdvals(jac ./ precision_factor))
logabsdets = sum(x -> log(abs(x)), ℒ.svdvals(jac))
invjac = try inv(jacdecomp)
catch
if verbose println("Inversion filter failed") end
Expand All @@ -117,6 +117,8 @@ function calculate_inversion_filter_loglikelihood(::Val{:first_order},

logabsdets *= size(data_in_deviations,2) - presample_periods

if !isfinite(logabsdets) return -Inf end

𝐒obs = 𝐒[cond_var_idx,1:end-T.nExo]

@timeit_debug timer "Loop" begin
Expand All @@ -129,6 +131,7 @@ function calculate_inversion_filter_loglikelihood(::Val{:first_order},

if i > presample_periods
shocks² += sum(abs2,x)
if !isfinite(shocks²) return -Inf end
end

.mul!(state, 𝐒, vcat(state[T.past_not_future_and_mixed_idx], x))
Expand Down Expand Up @@ -187,7 +190,7 @@ function rrule(::typeof(calculate_inversion_filter_loglikelihood),
jac = 𝐒[obs_idx,end-T.nExo+1:end]

if T.nExo == length(observables)
logabsdets =.logabsdet(-jac' ./ precision_factor)[1]
logabsdets =.logabsdet(jac)[1] # ./ precision_factor

jacdecomp =.lu(jac, check = false)

Expand All @@ -198,13 +201,17 @@ function rrule(::typeof(calculate_inversion_filter_loglikelihood),

invjac = inv(jacdecomp)
else
logabsdets = sum(x -> log(abs(x)), ℒ.svdvals(-jac' ./ precision_factor))
logabsdets = sum(x -> log(abs(x)), ℒ.svdvals(jac)) #' ./ precision_factor
jacdecomp =.svd(jac)
invjac = inv(jacdecomp)
end

logabsdets *= size(data_in_deviations,2) - presample_periods

if !isfinite(logabsdets)
return -Inf, x -> NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent()
end

@views 𝐒obs = 𝐒[obs_idx,1:end-T.nExo]

for i in axes(data_in_deviations,2)
Expand All @@ -215,6 +222,9 @@ function rrule(::typeof(calculate_inversion_filter_loglikelihood),

if i > presample_periods
shocks² += sum(abs2,x[i])
if !isfinite(shocks²)
return -Inf, x -> NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent()
end
end

.mul!(state[i+1], 𝐒, vcat(state[i][t⁻], x[i]))
Expand Down Expand Up @@ -516,6 +526,10 @@ function calculate_inversion_filter_loglikelihood(::Val{:pruned_second_order},
end

shocks² += sum(abs2,x)

if !isfinite(logabsdets) || !isfinite(shocks²)
return -Inf
end
end

# aug_state₁ = [state₁; 1; x]
Expand Down Expand Up @@ -745,6 +759,10 @@ function rrule(::typeof(calculate_inversion_filter_loglikelihood),
end

shocks² += sum(abs2,x[i])

if !isfinite(logabsdets) || !isfinite(shocks²)
return -Inf, x -> NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent()
end
end

# aug_state₁[i] = [state₁; 1; x[i]]
Expand Down Expand Up @@ -855,11 +873,14 @@ function rrule(::typeof(calculate_inversion_filter_loglikelihood),
end

# logabsdets += ℒ.logabsdet(jacc ./ precision_factor)[1]
if size(jacc[i], 1) == size(jacc[i], 2)
∂jacc = inv(jacc[i])'
else
∂jacc = inv(ℒ.svd(jacc[i]))'
end
∂jacc = try if size(jacc[i], 1) == size(jacc[i], 2)
inv(jacc[i])'
else
inv(ℒ.svd(jacc[i]))'
end
catch
return NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent()
end

# jacc = 𝐒ⁱ + 2 * 𝐒ⁱ²ᵉ * ℒ.kron(ℒ.I(T.nExo), x[1])
# ∂kronIx = 𝐒ⁱ²ᵉ' * ∂jacc
Expand Down Expand Up @@ -1190,6 +1211,10 @@ function calculate_inversion_filter_loglikelihood(::Val{:second_order},
end

shocks² += sum(abs2,x)

if !isfinite(logabsdets) || !isfinite(shocks²)
return -Inf
end
end

# aug_state = [state; 1; x]
Expand Down Expand Up @@ -1417,6 +1442,10 @@ function rrule(::typeof(calculate_inversion_filter_loglikelihood),
end

shocks² += sum(abs2, x[i])

if !isfinite(logabsdets) || !isfinite(shocks²)
return -Inf, x -> NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent()
end
end

# aug_state[i] = [state¹⁻; 1; x[i]]
Expand Down Expand Up @@ -1513,11 +1542,14 @@ function rrule(::typeof(calculate_inversion_filter_loglikelihood),
end

# logabsdets += ℒ.logabsdet(jacc ./ precision_factor)[1]
if size(jacc[i], 1) == size(jacc[i], 2)
∂jacc = inv(jacc[i])'
else
∂jacc = inv(ℒ.svd(jacc[i]))'
end
∂jacc = try if size(jacc[i], 1) == size(jacc[i], 2)
inv(jacc[i])'
else
inv(ℒ.svd(jacc[i]))'
end
catch
return NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent()
end

# jacc = 𝐒ⁱ + 2 * 𝐒ⁱ²ᵉ * ℒ.kron(ℒ.I(T.nExo), x[1])
.mul!(∂kronIx, 𝐒ⁱ²ᵉ', ∂jacc)
Expand Down Expand Up @@ -1963,6 +1995,10 @@ function calculate_inversion_filter_loglikelihood(::Val{:pruned_third_order},
end

shocks² += sum(abs2,x)

if !isfinite(logabsdets) || !isfinite(shocks²)
return -Inf
end
end

aug_state₁ = [state[1]; 1; x]
Expand Down Expand Up @@ -2236,6 +2272,10 @@ function rrule(::typeof(calculate_inversion_filter_loglikelihood),
end

shocks² += sum(abs2,x[i])

if !isfinite(logabsdets) || !isfinite(shocks²)
return -Inf, x -> NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent()
end
end

aug_state₁[i] = [state₁; 1; x[i]]
Expand Down Expand Up @@ -2357,11 +2397,14 @@ function rrule(::typeof(calculate_inversion_filter_loglikelihood),
end

# logabsdets += ℒ.logabsdet(jacc ./ precision_factor)[1]
if size(jacc[i], 1) == size(jacc[i], 2)
∂jacc = inv(jacc[i])'
else
∂jacc = inv(ℒ.svd(jacc[i]))'
end
∂jacc = try if size(jacc[i], 1) == size(jacc[i], 2)
inv(jacc[i])'
else
inv(ℒ.svd(jacc[i]))'
end
catch
return NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent()
end

# jacc = 𝐒ⁱ + 2 * 𝐒ⁱ²ᵉ * ℒ.kron(ℒ.I(T.nExo), x) + 3 * 𝐒ⁱ³ᵉ * ℒ.kron(ℒ.I(T.nExo), ℒ.kron(x, x))
# ∂𝐒ⁱ = -∂jacc / 2 # fine
Expand Down Expand Up @@ -2830,6 +2873,10 @@ function calculate_inversion_filter_loglikelihood(::Val{:third_order},
end

shocks² += sum(abs2,x)

if !isfinite(logabsdets) || !isfinite(shocks²)
return -Inf
end
end

aug_state = [state; 1; x]
Expand Down Expand Up @@ -3069,6 +3116,10 @@ function rrule(::typeof(calculate_inversion_filter_loglikelihood),
end

shocks² += sum(abs2,x[i])

if !isfinite(logabsdets) || !isfinite(shocks²)
return -Inf, x -> NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent()
end
end

aug_state[i] = [stt; 1; x[i]]
Expand Down Expand Up @@ -3152,11 +3203,14 @@ function rrule(::typeof(calculate_inversion_filter_loglikelihood),
end

# logabsdets += ℒ.logabsdet(jacc ./ precision_factor)[1]
if size(jacc[i], 1) == size(jacc[i], 2)
∂jacc = inv(jacc[i])'
else
∂jacc = inv(ℒ.svd(jacc[i]))'
end
∂jacc = try if size(jacc[i], 1) == size(jacc[i], 2)
inv(jacc[i])'
else
inv(ℒ.svd(jacc[i]))'
end
catch
return NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent()
end

# jacc = 𝐒ⁱ + 2 * 𝐒ⁱ²ᵉ * ℒ.kron(ℒ.I(T.nExo), x) + 3 * 𝐒ⁱ³ᵉ * ℒ.kron(ℒ.I(T.nExo), ℒ.kron(x, x))
# ∂𝐒ⁱ = -∂jacc / 2 # fine
Expand Down

0 comments on commit f834332

Please sign in to comment.