Skip to content

Commit

Permalink
transform obc s into correct form
Browse files Browse the repository at this point in the history
  • Loading branch information
thorek1 committed Nov 10, 2023
1 parent 55f87f7 commit c7ceda1
Showing 1 changed file with 112 additions and 1 deletion.
113 changes: 112 additions & 1 deletion src/MacroModelling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,111 @@ function check_for_dynamic_variables(ex::Expr)
end


function transform_expression(expr::Expr)
# Dictionary to store the transformations for reversing
reverse_transformations = Dict{Symbol, Expr}()

# Counter for generating unique placeholders
unique_counter = Ref(0)

# Step 1: Replace min/max calls and record their original form
function replace_min_max(expr)
if expr isa Expr && expr.head == :call && (expr.args[1] == :min || expr.args[1] == :max)
# Replace min/max functions with a placeholder
# placeholder = Symbol("minimal__P", unique_counter[])
placeholder = :minmax__P
unique_counter[] += 1

# Store the original min/max call for reversal
reverse_transformations[placeholder] = expr

return placeholder
else
return expr
end
end

# Step 2: Transform :ref fields in the rest of the expression
function transform_ref_fields(expr)
if expr isa Expr && expr.head == :ref && isa(expr.args[1], Symbol)
# Handle :ref expressions
if isa(expr.args[2], Number) || isa(expr.args[2], Symbol)
new_symbol = Symbol(expr.args[1], "_", expr.args[2])
else
# Generate a unique placeholder for complex :ref
unique_counter[] += 1
placeholder = Symbol("__placeholder", unique_counter[])
new_symbol = placeholder
end

# Record the reverse transformation
reverse_transformations[new_symbol] = expr

return new_symbol
else
return expr
end
end


# Replace equality sign with minus
function replace_equality_with_minus(expr)
if expr isa Expr && expr.head == :(=)
return Expr(:call, :-, expr.args...)
else
return expr
end
end

# Apply transformations
expr = postwalk(replace_min_max, expr)
expr = postwalk(transform_ref_fields, expr)
transformed_expr = postwalk(replace_equality_with_minus, expr)

return transformed_expr, reverse_transformations
end


function reverse_transformation(transformed_expr::Expr, reverse_dict::Dict{Symbol, Expr})
# Function to replace the transformed symbols with their original form
function revert_symbol(expr)
if expr isa Symbol && haskey(reverse_dict, expr)
return reverse_dict[expr]
else
return expr
end
end

# Revert the expression using postwalk
reverted_expr = postwalk(revert_symbol, transformed_expr)

return reverted_expr
end


function transform_obc(ex::Expr)
transformed_expr, reverse_dict = transform_expression(ex)

for symbs in get_symbols(transformed_expr)
eval(:($symbs = SPyPyC.symbols($(string(symbs)), real = true, finite = true)))
end

eq = eval(transformed_expr)

soll = try SPyPyC.solve(eq, minmax__P)
catch
end

if length(soll) == 1
sorted_minmax = Expr(:call, reverse_dict[:minmax__P].args[1], :($(reverse_dict[:minmax__P].args[2]) - $(Meta.parse(string(soll[1])))), :($(reverse_dict[:minmax__P].args[3]) - $(Meta.parse(string(soll[1])))))
return reverse_transformation(sorted_minmax, reverse_dict)
else
@error "Occasionally binding constraint not well-defined. See documentation for examples."
end
end



function set_up_obc_violation_function!(𝓂)
future_varss = collect(reduce(union,match_pattern.(get_symbols.(𝓂.dyn_equations),r"β‚β‚β‚Ž$")))
present_varss = collect(reduce(union,match_pattern.(get_symbols.(𝓂.dyn_equations),r"β‚β‚€β‚Ž$")))
Expand Down Expand Up @@ -863,6 +968,12 @@ function parse_occasionally_binding_constraints(equations_block; max_obc_shift::

for arg in equations_block.args
if isa(arg,Expr)
if check_for_minmax(arg)
arg_trans = transform_obc(arg)
else
arg_trans = arg
end

eq = postwalk(x ->
x isa Expr ?
x.head == :call ?
Expand Down Expand Up @@ -924,7 +1035,7 @@ function parse_occasionally_binding_constraints(equations_block; max_obc_shift::
x :
x :
x,
arg)
arg_trans)

push!(eqs, eq)
end
Expand Down

0 comments on commit c7ceda1

Please sign in to comment.