Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: use the preface API from Symbolics #2217

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions src/inputoutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,15 +231,12 @@
p = map(x -> time_varying_as_func(value(x), sys), ps)
t = get_iv(sys)

# pre = has_difference ? (ex -> ex) : get_postprocess_fbody(sys)

args = (u, inputs, p, t)
if implicit_dae
ddvs = map(Differential(get_iv(sys)), dvs)
args = (ddvs, args...)
end
process = get_postprocess_fbody(sys)
f = build_function(rhss, args...; postprocess_fbody = process,
f = build_function(rhss, args...; preface = get_preface_vec(sys),

Check warning on line 239 in src/inputoutput.jl

View check run for this annotation

Codecov / codecov/patch

src/inputoutput.jl#L239

Added line #L239 was not covered by tests
expression = Val{false}, kwargs...)
(; f, dvs, ps, io_sys = sys)
end
Expand Down
2 changes: 1 addition & 1 deletion src/structural_transformation/StructuralTransformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
has_tearing_state, defaults, InvalidSystemException,
ExtraEquationsSystemException,
ExtraVariablesSystemException,
get_postprocess_fbody, vars!,
get_preface_vec, vars!,
IncrementalCycleTracker, add_edge_checked!, topological_sort,
invalidate_cache!, Substitutions, get_or_construct_tearing_state,
filter_kwargs, lower_varname, setio, SparseMatrixCLIL,
Expand Down
20 changes: 10 additions & 10 deletions src/structural_transformation/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,9 @@
f = Func([DestructuredArgs(vars, inbounds = !checkbounds)
DestructuredArgs(params, inbounds = !checkbounds)],
[],
pre(Let(needed_assignments[inner_idxs],
Let(vcat(pre, needed_assignments[inner_idxs]),
funex,
false))) |> SymbolicUtils.Code.toexpr
false)) |> SymbolicUtils.Code.toexpr

# solver call contains code to call the root-finding solver on the function f
solver_call = LiteralExpr(quote
Expand Down Expand Up @@ -301,9 +301,9 @@
@set! sys.unknown_states = states
syms = map(Symbol, states)

pre = get_postprocess_fbody(sys)
pre = get_preface_vec(sys)

Check warning on line 304 in src/structural_transformation/codegen.jl

View check run for this annotation

Codecov / codecov/patch

src/structural_transformation/codegen.jl#L304

Added line #L304 was not covered by tests
cpre = get_preprocess_constants(rhss)
pre2 = x -> pre(cpre(x))
pre2 = vcat(pre, cpre)

Check warning on line 306 in src/structural_transformation/codegen.jl

View check run for this annotation

Codecov / codecov/patch

src/structural_transformation/codegen.jl#L306

Added line #L306 was not covered by tests

expr = SymbolicUtils.Code.toexpr(Func([out
DestructuredArgs(states,
Expand All @@ -312,10 +312,10 @@
inbounds = !checkbounds)
independent_variables(sys)],
[],
pre2(Let([torn_expr;
Let([pre2; torn_expr;
assignments[is_not_prepended_assignment]],
funbody,
false))),
false)),
sol_states)
if expression
expr, states
Expand Down Expand Up @@ -490,17 +490,17 @@
pre = get_postprocess_fbody(sys)
cpre = get_preprocess_constants([obs[1:maxidx];
isscalar ? ts[1] : MakeArray(ts, output_type)])
pre2 = x -> pre(cpre(x))
pre2 = vcat(pre, cpre)

Check warning on line 493 in src/structural_transformation/codegen.jl

View check run for this annotation

Codecov / codecov/patch

src/structural_transformation/codegen.jl#L493

Added line #L493 was not covered by tests
ex = Code.toexpr(Func([DestructuredArgs(unknown_states, inbounds = !checkbounds)
DestructuredArgs(parameters(sys), inbounds = !checkbounds)
independent_variables(sys)],
[],
pre2(Let([collect(Iterators.flatten(solves))
Let(vcat(pre2, [collect(Iterators.flatten(solves))
assignments[is_not_prepended_assignment]
map(eq -> eq.lhs ← eq.rhs, obs[1:maxidx])
subs],
subs]),
isscalar ? ts[1] : MakeArray(ts, output_type),
false))), sol_states)
false)), sol_states)

expression ? ex : drop_expr(@RuntimeGeneratedFunction(ex))
end
Expand Down
4 changes: 2 additions & 2 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
rf_oop, rf_ip = build_function(rhss, u, p, t; expression = getexpr,
wrap_code = add_integrator_header(integ, outvar),
outputidxs = update_inds,
postprocess_fbody = pre,
preface = pre,
kwargs...)
# applied user-provided function to the generated expression
if postprocess_affect_expr! !== nothing
Expand Down Expand Up @@ -389,7 +389,7 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = states
t = get_iv(sys)
pre = get_preprocess_constants(rhss)
rf_oop, rf_ip = build_function(rhss, u, p, t; expression = Val{false},
postprocess_fbody = pre, kwargs...)
preface = pre, kwargs...)

affect_functions = map(cbs) do cb # Keep affect function separate
eq_aff = affects(cb)
Expand Down
16 changes: 8 additions & 8 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,14 @@
simplify = false, kwargs...)
tgrad = calculate_tgrad(sys, simplify = simplify)
pre = get_preprocess_constants(tgrad)
return build_function(tgrad, dvs, ps, get_iv(sys); postprocess_fbody = pre, kwargs...)
return build_function(tgrad, dvs, ps, get_iv(sys); preface = pre, kwargs...)

Check warning on line 87 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L87

Added line #L87 was not covered by tests
end

function generate_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys);
simplify = false, sparse = false, kwargs...)
jac = calculate_jacobian(sys; simplify = simplify, sparse = sparse)
pre = get_preprocess_constants(jac)
return build_function(jac, dvs, ps, get_iv(sys); postprocess_fbody = pre, kwargs...)
return build_function(jac, dvs, ps, get_iv(sys); preface = pre, kwargs...)

Check warning on line 94 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L94

Added line #L94 was not covered by tests
end

function generate_control_jacobian(sys::AbstractODESystem, dvs = states(sys),
Expand All @@ -113,7 +113,7 @@
jac = ˍ₋gamma * jac_du + jac_u
pre = get_preprocess_constants(jac)
return build_function(jac, derivatives, dvs, ps, ˍ₋gamma, get_iv(sys);
postprocess_fbody = pre, kwargs...)
preface = pre, kwargs...)
end

function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys);
Expand Down Expand Up @@ -148,11 +148,11 @@
no_postprocess = has_difference)

if implicit_dae
build_function(rhss, ddvs, u, p, t; postprocess_fbody = pre,
build_function(rhss, ddvs, u, p, t; preface = pre,

Check warning on line 151 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L151

Added line #L151 was not covered by tests
states = sol_states,
kwargs...)
else
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
build_function(rhss, u, p, t; preface = pre, states = sol_states,
kwargs...)
end
end
Expand Down Expand Up @@ -215,11 +215,11 @@
d.update ? eq.rhs : eq.rhs + v
end

pre = get_postprocess_fbody(sys)
pre = get_preface_vec(sys)

Check warning on line 218 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L218

Added line #L218 was not covered by tests
cpre = get_preprocess_constants(body)
pre2 = x -> pre(cpre(x))
pre2 = vcat(pre, cpre)

Check warning on line 220 in src/systems/diffeqs/abstractodesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/abstractodesystem.jl#L220

Added line #L220 was not covered by tests
f_oop, f_iip = build_function(body, u, p, t; expression = Val{false},
postprocess_fbody = pre2, kwargs...)
preface = pre2, kwargs...)

cb_affect! = let f_oop = f_oop, f_iip = f_iip
function cb_affect!(integ)
Expand Down
6 changes: 3 additions & 3 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -397,12 +397,12 @@
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
args = [dvs, ipts, ps, ivs...]
end
pre = get_postprocess_fbody(sys)
pre = get_preface_vec(sys)

Check warning on line 400 in src/systems/diffeqs/odesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/odesystem.jl#L400

Added line #L400 was not covered by tests

ex = Func(args, [],
pre(Let(obsexprs,
Let(vcat(pre, obsexprs),
isscalar ? ts[1] : MakeArray(ts, output_type),
false))) |> toexpr
false)) |> toexpr
expression ? ex : drop_expr(@RuntimeGeneratedFunction(ex))
end

Expand Down
2 changes: 1 addition & 1 deletion src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@

build_function(rhss, u, p, t; kwargs...)
pre, sol_states = get_substitutions_and_solved_states(sys)
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states, kwargs...)
build_function(rhss, u, p, t; preface = pre, states = sol_states, kwargs...)

Check warning on line 300 in src/systems/discrete_system/discrete_system.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/discrete_system/discrete_system.jl#L300

Added line #L300 was not covered by tests
end

"""
Expand Down
6 changes: 3 additions & 3 deletions src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@
sparse = false, simplify = false, kwargs...)
jac = calculate_jacobian(sys, sparse = sparse, simplify = simplify)
pre = get_preprocess_constants(jac)
return build_function(jac, vs, ps; postprocess_fbody = pre, kwargs...)
return build_function(jac, vs, ps; preface = pre, kwargs...)

Check warning on line 165 in src/systems/nonlinear/nonlinearsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/nonlinear/nonlinearsystem.jl#L165

Added line #L165 was not covered by tests
end

function calculate_hessian(sys::NonlinearSystem; sparse = false, simplify = false)
Expand All @@ -180,15 +180,15 @@
sparse = false, simplify = false, kwargs...)
hess = calculate_hessian(sys, sparse = sparse, simplify = simplify)
pre = get_preprocess_constants(hess)
return build_function(hess, vs, ps; postprocess_fbody = pre, kwargs...)
return build_function(hess, vs, ps; preface = pre, kwargs...)

Check warning on line 183 in src/systems/nonlinear/nonlinearsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/nonlinear/nonlinearsystem.jl#L183

Added line #L183 was not covered by tests
end

function generate_function(sys::NonlinearSystem, dvs = states(sys), ps = parameters(sys);
kwargs...)
rhss = [deq.rhs for deq in equations(sys)]
pre, sol_states = get_substitutions_and_solved_states(sys)

return build_function(rhss, value.(dvs), value.(ps); postprocess_fbody = pre,
return build_function(rhss, value.(dvs), value.(ps); preface = pre,

Check warning on line 191 in src/systems/nonlinear/nonlinearsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/nonlinear/nonlinearsystem.jl#L191

Added line #L191 was not covered by tests
states = sol_states, kwargs...)
end

Expand Down
2 changes: 1 addition & 1 deletion src/systems/optimization/constraints_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@
lhss = generate_canonical_form_lhss(sys)
pre, sol_states = get_substitutions_and_solved_states(sys)

func = build_function(lhss, value.(dvs), value.(ps); postprocess_fbody = pre,
func = build_function(lhss, value.(dvs), value.(ps); preface = pre,

Check warning on line 183 in src/systems/optimization/constraints_system.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/optimization/constraints_system.jl#L183

Added line #L183 was not covered by tests
states = sol_states, kwargs...)

cstr = constraints(sys)
Expand Down
4 changes: 2 additions & 2 deletions src/systems/optimization/optimizationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@
kwargs...)
grad = calculate_gradient(sys)
pre = get_preprocess_constants(grad)
return build_function(grad, vs, ps; postprocess_fbody = pre,
return build_function(grad, vs, ps; preface = pre,

Check warning on line 126 in src/systems/optimization/optimizationsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/optimization/optimizationsystem.jl#L126

Added line #L126 was not covered by tests
conv = AbstractSysToExpr(sys), kwargs...)
end

Expand All @@ -139,7 +139,7 @@
hess = calculate_hessian(sys)
end
pre = get_preprocess_constants(hess)
return build_function(hess, vs, ps; postprocess_fbody = pre,
return build_function(hess, vs, ps; preface = pre,

Check warning on line 142 in src/systems/optimization/optimizationsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/optimization/optimizationsystem.jl#L142

Added line #L142 was not covered by tests
conv = AbstractSysToExpr(sys), kwargs...)
end

Expand Down
28 changes: 9 additions & 19 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -572,20 +572,11 @@
"""
function get_preprocess_constants(eqs)
cs = collect_constants(eqs)
pre = ex -> Let(Assignment[Assignment(x, getdefault(x)) for x in cs],
ex, false)
return pre
Assignment[Assignment(x, getdefault(x)) for x in cs]
end

function get_postprocess_fbody(sys)
if has_preface(sys) && (pre = preface(sys); pre !== nothing)
pre_ = let pre = pre
ex -> Let(pre, ex, false)
end
else
pre_ = ex -> ex
end
return pre_
function get_preface_vec(sys)

Check warning on line 578 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L578

Added line #L578 was not covered by tests
has_preface(sys) && (pre = preface(sys); pre !== nothing) ? pre : []
end

"""
Expand Down Expand Up @@ -629,7 +620,7 @@
cmap, cs = get_cmap(sys)
if empty_substitutions(sys) && isempty(cs)
sol_states = Code.LazyState()
pre = no_postprocess ? (ex -> ex) : get_postprocess_fbody(sys)
assignments = no_postprocess ? [] : get_preface_vec(sys)
else # Have to do some work
if !empty_substitutions(sys)
@unpack subs = get_substitutions(sys)
Expand All @@ -639,15 +630,14 @@
subs = [cmap; subs] # The constants need to go first
sol_states = Code.NameState(Dict(eq.lhs => Symbol(eq.lhs) for eq in subs))
if no_postprocess
pre = ex -> Let(Assignment[Assignment(eq.lhs, eq.rhs) for eq in subs], ex,
false)
assignments = Assignment[Assignment(eq.lhs, eq.rhs) for eq in subs]

Check warning on line 633 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L633

Added line #L633 was not covered by tests
else
process = get_postprocess_fbody(sys)
pre = ex -> Let(Assignment[Assignment(eq.lhs, eq.rhs) for eq in subs],
process(ex), false)
assignments = vcat(get_preface_vec(sys),
Assignment[Assignment(eq.lhs, eq.rhs)
for eq in subs])
end
end
return pre, sol_states
return assignments, sol_states
end

function mergedefaults(defaults, varmap, vars)
Expand Down
Loading