From 3a79e2a1e6dfd88f0e54aa628a9480517d045e7d Mon Sep 17 00:00:00 2001 From: thorek1 Date: Thu, 12 Oct 2023 11:00:19 +0200 Subject: [PATCH] parse obc --- src/MacroModelling.jl | 59 +++++++++++++++++++++++++++++++++++++++++-- src/macros.jl | 2 ++ 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/src/MacroModelling.jl b/src/MacroModelling.jl index 257f5786..a17d8b6f 100644 --- a/src/MacroModelling.jl +++ b/src/MacroModelling.jl @@ -601,6 +601,63 @@ function match_pattern(strings::Union{Set,Vector}, pattern::Regex) return filter(r -> match(pattern, string(r)) !== nothing, strings) end +function parse_occasionally_binding_constraints(equations_block; max_obc_shift::Int = 20) + eqs = [] + condition_list = [] + + for arg in equations_block.args + if isa(arg,Expr) + condition = [] + eq = postwalk(x -> + x isa Expr ? + x.head == :call ? + x.args[1] ∈ [:>, :<, :≤, :≥] ? + x.args[2].args[1] == :| ? + begin + condition = Expr(x.head, x.args[1], x.args[2].args[end], x.args[end]) + x.args[2].args[2] + end : + x : + x : + x : + x, + arg) + push!(condition_list, condition) + push!(eqs,eq) + end + end + + obc_shocks = Symbol[] + + for a in condition_list + if a isa Expr + s = get_symbols(a) + for ss in s + push!(obc_shocks,ss) + end + end + end + + eqs_with_obc_shocks = [] + for eq in eqs + eqq = postwalk(x -> + x isa Expr ? + x.head == :ref ? + x.args[1] ∈ obc_shocks ? + begin + obc_shock = intersect([x.args[1]], obc_shocks)[1] + obc_shifts = [Expr(:ref,Meta.parse(string(obc_shock) * "ᵒᵇᶜ⁽⁻"*super(string(i))*"⁾"),:(x-$i)) for i in 0:max_obc_shift] + Expr(:call,:+, x, obc_shifts...) + end : + x : + x : + x, + eq) + push!(eqs_with_obc_shocks, eqq) + end + + return Expr(:block,eqs_with_obc_shocks...), condition_list +end # compatibility with SymPy Max = max @@ -2105,8 +2162,6 @@ function solve_steady_state!(𝓂::ℳ; verbose::Bool = false) end push!(SS_solve_func,:($(dyn_exos...))) - - push!(SS_solve_func, min_max_errors...) # push!(SS_solve_func,:(push!(NSSS_solver_cache_tmp, params_scaled_flt))) push!(SS_solve_func,:(if length(NSSS_solver_cache_tmp) == 0 NSSS_solver_cache_tmp = [params_scaled_flt] else NSSS_solver_cache_tmp = [NSSS_solver_cache_tmp...,params_scaled_flt] end)) diff --git a/src/macros.jl b/src/macros.jl index af234c33..ca34d3b1 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -107,6 +107,8 @@ macro model(𝓂,ex...) model_ex = parse_for_loops(ex[end]) + model_ex, condition_list = parse_occasionally_binding_constraints(model_ex) + # write down dynamic equations and add auxilliary variables for leads and lags > 1 for (i,arg) in enumerate(model_ex.args) if isa(arg,Expr)