Skip to content

Commit

Permalink
Merge branch 'main' into obc
Browse files Browse the repository at this point in the history
  • Loading branch information
thorek1 authored Oct 27, 2023
2 parents 7ca4314 + b2c4a08 commit 9241112
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ RuntimeGeneratedFunctions = "^0.5"
SpecialFunctions = "^2"
SpeedMapping = "^0.3"
Subscripts = "^0.1"
SymPyPythonCall = "^0.1.1"
SymPyPythonCall = "^0.2"
Symbolics = "^5"
ThreadedSparseArrays = "^0.2"
julia = "1.8"
Expand Down
1 change: 1 addition & 0 deletions docs/src/unfinished_docs/todo.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- [ ] nonlinear conditional forecasts
- [ ] add balanced growth path handling
- [ ] feedback: write out RBC equations, provide option for external SS guess, sell the sampler better (ESS vs dynare), more details on algorithm (SS solver)
- [ ] higher order solutions: some kron matrix mults are later compressed. write custom compressed kron mult; check if sometimes dense mult is faster? (e.g. GNSS2010 seems dense at higher order)
- [ ] recheck function examples and docs (include output description)
- [ ] riccati with analytical derivatives (much faster if sparse) instead of implicit diff
- [ ] add user facing option to choose sylvester solver
Expand Down
33 changes: 19 additions & 14 deletions src/MacroModelling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1527,7 +1527,7 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo

SS_solve_func = []

atoms_in_equations = Set()
atoms_in_equations = Set{Symbol}()
atoms_in_equations_list = []
relevant_pars_across = []
NSSS_solver_cache_init_tmp = []
Expand Down Expand Up @@ -1568,16 +1568,17 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo
soll = try SPyPyC.solve(eq_to_solve,var_to_solve_for)
catch
end

if isnothing(soll)
# println("Could not solve single variables case symbolically.")
println("Failed finding solution symbolically for: ",var_to_solve_for," in: ",eq_to_solve)
# solve numerically
continue
# elseif PythonCall.pyconvert(Bool,soll[1].is_number)
elseif soll[1].is_number == SPyPyC.TRUE
# ss_equations = ss_equations.subs(var_to_solve_for,soll[1])
ss_equations = [eq.subs(var_to_solve_for,soll[1]) for eq in ss_equations]

elseif soll[1].is_number == true
# ss_equations = ss_equations.subs(var_to_solve,soll[1])
ss_equations = [eq.subs(var_to_solve,soll[1]) for eq in ss_equations]

push!(𝓂.solved_vars,Symbol(var_to_solve_for))
push!(𝓂.solved_vals,Meta.parse(string(soll[1])))
Expand All @@ -1595,8 +1596,10 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo
push!(𝓂.solved_vals,Meta.parse(string(soll[1])))

# atoms = reduce(union,soll[1].atoms())
[push!(atoms_in_equations, a) for a in soll[1].atoms()]

[push!(atoms_in_equations, Symbol(a)) for a in soll[1].atoms()]
push!(atoms_in_equations_list, Set(union(setdiff(get_symbols(parsed_eq_to_solve_for), get_symbols(minmax_fixed_eqs)),Symbol.(soll[1].atoms()))))

# println(atoms_in_equations)
# push!(atoms_in_equations, soll[1].atoms())

Expand Down Expand Up @@ -1635,7 +1638,9 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo
end
numerical_sol = true
# continue
elseif length(intersect(vars_to_solve,reduce(union,map(x->x.atoms(),collect(soll[1]))))) > 0
# elseif length(intersect(vars_to_solve,reduce(union,map(x->x.atoms(),collect(soll[1]))))) > 0
elseif length(intersect((union(SPyPyC.free_symbols.(soll[1])...) .|> SPyPyC.:↓),(vars_to_solve .|> SPyPyC.:↓))) > 0
# elseif length(intersect(union(SPyPyC.free_symbols.(soll[1])...),vars_to_solve)) > 0
if verbose
println("Failed finding solution symbolically for: ",vars_to_solve," in: ",eqs_to_solve,". Solving numerically.")
end
Expand All @@ -1651,7 +1656,7 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo
# relevant_pars = reduce(union,map(x->x.atoms(),collect(soll[1])))
atoms = reduce(union,map(x->x.atoms(),collect(soll[1])))
# println(atoms)
[push!(atoms_in_equations, a) for a in atoms]
[push!(atoms_in_equations, Symbol(a)) for a in atoms]

for (k, vars) in enumerate(vars_to_solve)
push!(𝓂.solved_vars,Symbol(vars))
Expand Down Expand Up @@ -1679,7 +1684,7 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo

for i in eqs_to_solve
# push!(syms_in_eqs, Symbol.(PythonCall.pystr.(i.atoms()))...)
push!(syms_in_eqs, Symbol.(SPyPyC.unSym.(SPyPyC.free_symbols(i)))...)
push!(syms_in_eqs, Symbol.(SPyPyC.:↓(SPyPyC.free_symbols(i)))...)
end

# println(syms_in_eqs)
Expand Down Expand Up @@ -1943,7 +1948,7 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo
parameters_in_equations = []

for (i, parss) in enumerate(𝓂.parameters)
if parss ∈ union(Symbol.(atoms_in_equations),relevant_pars_across)
if parss ∈ union(atoms_in_equations, relevant_pars_across)
push!(parameters_in_equations,:($parss = params[$i]))
end
end
Expand Down Expand Up @@ -1995,7 +2000,7 @@ function solve_steady_state!(𝓂::β„³, symbolic_SS, Symbolics::symbolics; verbo
# fix parameter bounds
par_bounds = []

for varpar in intersect(𝓂.bounded_vars, intersect(𝓂.parameters,union(Symbol.(atoms_in_equations),relevant_pars_across)))
for varpar in intersect(𝓂.bounded_vars, intersect(𝓂.parameters,union(atoms_in_equations, relevant_pars_across)))
i = indexin([varpar],𝓂.bounded_vars)
push!(par_bounds, :($varpar = min(max($varpar,$(𝓂.lower_bounds[i...])),$(𝓂.upper_bounds[i...]))))
end
Expand Down Expand Up @@ -2096,7 +2101,7 @@ function solve_steady_state!(𝓂::β„³; verbose::Bool = false)

SS_solve_func = []

atoms_in_equations = Set()
atoms_in_equations = Set{Symbol}()
atoms_in_equations_list = []
relevant_pars_across = []
NSSS_solver_cache_init_tmp = []
Expand Down Expand Up @@ -2309,7 +2314,7 @@ function solve_steady_state!(𝓂::β„³; verbose::Bool = false)
parameters_in_equations = []

for (i, parss) in enumerate(𝓂.parameters)
if parss ∈ union(Symbol.(atoms_in_equations),relevant_pars_across)
if parss ∈ union(atoms_in_equations, relevant_pars_across)
push!(parameters_in_equations,:($parss = params[$i]))
end
end
Expand Down Expand Up @@ -2351,7 +2356,7 @@ function solve_steady_state!(𝓂::β„³; verbose::Bool = false)
# fix parameter bounds
par_bounds = []

for varpar in intersect(𝓂.bounded_vars, intersect(𝓂.parameters,union(Symbol.(atoms_in_equations),relevant_pars_across)))
for varpar in intersect(𝓂.bounded_vars, intersect(𝓂.parameters,union(atoms_in_equations, relevant_pars_across)))
i = indexin([varpar],𝓂.bounded_vars)
push!(par_bounds, :($varpar = min(max($varpar,$(𝓂.lower_bounds[i...])),$(𝓂.upper_bounds[i...]))))
end
Expand Down
3 changes: 2 additions & 1 deletion test/test_standalone_function.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using SparseArrays
using MacroModelling: timings
using MacroModelling
import MacroModelling: timings
using ForwardDiff, FiniteDifferences, Zygote
import Optim, LineSearches
import LinearAlgebra as β„’
Expand Down

0 comments on commit 9241112

Please sign in to comment.