Skip to content

Commit

Permalink
Merge pull request #857 from AayushSabharwal/as/get-nlsolve
Browse files Browse the repository at this point in the history
feat: add fields to `OverrideInit`, better `nlsolve_alg` handling
  • Loading branch information
ChrisRackauckas authored Nov 18, 2024
2 parents 169d419 + 7f93fb2 commit 258063b
Show file tree
Hide file tree
Showing 8 changed files with 241 additions and 62 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,4 @@ jobs:
with:
file: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: true
fail_ci_if_error: false
1 change: 0 additions & 1 deletion src/ODE_nlsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,3 @@ struct ODE_NLProbData{NLProb, UNLProb, NLProbMap, NLProbPmap}
"""
nlprobpmap::NLProbPmap
end

13 changes: 11 additions & 2 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import CommonSolve: solve, init, step!, solve!
import FunctionWrappersWrappers
import RuntimeGeneratedFunctions
import EnumX
import ADTypes: AbstractADType
import ADTypes: ADTypes, AbstractADType
import Accessors: @set, @reset
using Expronicon.ADT: @match

Expand Down Expand Up @@ -351,7 +351,16 @@ struct CheckInit <: DAEInitializationAlgorithm end
"""
$(TYPEDEF)
"""
struct OverrideInit <: DAEInitializationAlgorithm end
struct OverrideInit{T1, T2, F} <: DAEInitializationAlgorithm
abstol::T1
reltol::T2
nlsolve::F
end

function OverrideInit(; abstol = nothing, reltol = nothing, nlsolve = nothing)
OverrideInit(abstol, reltol, nlsolve)
end
OverrideInit(abstol) = OverrideInit(; abstol = abstol, nlsolve = nothing)

# PDE Discretizations

Expand Down
101 changes: 64 additions & 37 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,26 @@ function Base.showerror(io::IO, e::OverrideInitMissingAlgorithm)
"OverrideInit specified but no NonlinearSolve.jl algorithm provided. Provide an algorithm via the `nlsolve_alg` keyword argument to `get_initial_values`.")
end

struct OverrideInitNoTolerance <: Exception
tolerance::Symbol
end

function Base.showerror(io::IO, e::OverrideInitNoTolerance)
print(io,
"Tolerances were not provided to `OverrideInit`. `$(e.tolerance)` must be provided as a keyword argument to `get_initial_values` or as a keyword argument to the `OverrideInit` constructor.")
end

"""
Utility function to evaluate the RHS of the ODE, using the integrator's `tmp_cache` if
Utility function to evaluate the RHS, using the integrator's `tmp_cache` if
it is in-place or simply calling the function if not.
"""
function _evaluate_f_ode(integrator, f, isinplace::Val{true}, args...)
function _evaluate_f(integrator, f, isinplace::Val{true}, args...)
tmp = first(get_tmp_cache(integrator))
f(tmp, args...)
return tmp
end

function _evaluate_f_ode(integrator, f, isinplace::Val{false}, args...)
function _evaluate_f(integrator, f, isinplace::Val{false}, args...)
return f(args...)
end

Expand All @@ -98,53 +107,49 @@ _vec(v::AbstractVector) = v
Check if the algebraic constraints are satisfied, and error if they aren't. Returns
the `u0` and `p` as-is, and is always successful if it returns. Valid only for
`ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument.
`AbstractDEProblem` and `AbstractDAEProblem`. Requires a `DEIntegrator` as its second argument.
Keyword arguments:
- `abstol`: The absolute value below which the norm of the residual of algebraic equations
should lie. The norm function used is `integrator.opts.internalnorm` if present, and
`LinearAlgebra.norm` if not.
"""
function get_initial_values(prob::AbstractODEProblem, integrator, f, alg::CheckInit,
isinplace::Union{Val{true}, Val{false}}; kwargs...)
function get_initial_values(
prob::AbstractDEProblem, integrator::DEIntegrator, f, alg::CheckInit,
isinplace::Union{Val{true}, Val{false}}; abstol, kwargs...)
u0 = state_values(integrator)
p = parameter_values(integrator)
t = current_time(integrator)
M = f.mass_matrix

algebraic_vars = [all(iszero, x) for x in eachcol(M)]
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return u0, p, true
update_coefficients!(M, u0, p, t)
tmp = _evaluate_f_ode(integrator, f, isinplace, u0, p, t)
tmp = _evaluate_f(integrator, f, isinplace, u0, p, t)
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))

normresid = integrator.opts.internalnorm(tmp, t)
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
normresid = isdefined(integrator.opts, :internalnorm) ?
integrator.opts.internalnorm(tmp, t) : norm(tmp)
if normresid > abstol
throw(CheckInitFailureError(normresid, abstol))
end
return u0, p, true
end

"""
Utility function to evaluate the RHS of the DAE, using the integrator's `tmp_cache` if
it is in-place or simply calling the function if not.
"""
function _evaluate_f_dae(integrator, f, isinplace::Val{true}, args...)
tmp = get_tmp_cache(integrator)[2]
f(tmp, args...)
return tmp
end

function _evaluate_f_dae(integrator, f, isinplace::Val{false}, args...)
return f(args...)
end

function get_initial_values(prob::AbstractDAEProblem, integrator, f, alg::CheckInit,
isinplace::Union{Val{true}, Val{false}}; kwargs...)
function get_initial_values(
prob::AbstractDAEProblem, integrator::DEIntegrator, f, alg::CheckInit,
isinplace::Union{Val{true}, Val{false}}; abstol, kwargs...)
u0 = state_values(integrator)
p = parameter_values(integrator)
t = current_time(integrator)

resid = _evaluate_f_dae(integrator, f, isinplace, integrator.du, u0, p, t)
normresid = integrator.opts.internalnorm(resid, t)
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
resid = _evaluate_f(integrator, f, isinplace, integrator.du, u0, p, t)
normresid = isdefined(integrator.opts, :internalnorm) ?
integrator.opts.internalnorm(resid, t) : norm(resid)

if normresid > abstol
throw(CheckInitFailureError(normresid, abstol))
end
return u0, p, true
end
Expand All @@ -155,12 +160,19 @@ end
Solve a `NonlinearProblem`/`NonlinearLeastSquaresProblem` to obtain the initial `u0` and
`p`. Requires that `f` have the field `initialization_data` which is an `OverrideInitData`.
If the field is absent or the value is `nothing`, return `u0` and `p` successfully as-is.
The NonlinearSolve.jl algorithm to use must be specified through the `nlsolve_alg` keyword
argument, failing which this function will throw an error. The success value returned
depends on the success of the nonlinear solve.
The success value returned depends on the success of the nonlinear solve.
Keyword arguments:
- `nlsolve_alg`: The NonlinearSolve.jl algorithm to use. If not provided, this function will
throw an error.
- `abstol`, `reltol`: The `abstol` (`reltol`) to use for the nonlinear solve. The value
provided to the `OverrideInit` constructor takes priority over this keyword argument.
If the former is `nothing`, this keyword argument will be used. If it is also not provided,
an error will be thrown.
"""
function get_initial_values(prob, valp, f, alg::OverrideInit,
isinplace::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, kwargs...)
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...)
u0 = state_values(valp)
p = parameter_values(valp)

Expand All @@ -171,15 +183,30 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
initdata::OverrideInitData = f.initialization_data
initprob = initdata.initializeprob

if nlsolve_alg === nothing
nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing))
if nlsolve_alg === nothing && state_values(initprob) !== nothing
throw(OverrideInitMissingAlgorithm())
end

if initdata.update_initializeprob! !== nothing
initdata.update_initializeprob!(initprob, valp)
end

nlsol = solve(initprob, nlsolve_alg)
if alg.abstol !== nothing
_abstol = alg.abstol
elseif abstol !== nothing
_abstol = abstol
else
throw(OverrideInitNoTolerance(:abstol))
end
if alg.reltol !== nothing
_reltol = alg.reltol
elseif reltol !== nothing
_reltol = reltol
else
throw(OverrideInitNoTolerance(:reltol))
end
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)

u0 = initdata.initializeprobmap(nlsol)
if initdata.initializeprobpmap !== nothing
Expand Down
14 changes: 9 additions & 5 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,8 @@ numerically-defined functions.
"""
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
O, TCV,
SYS, ID<:Union{Nothing, OverrideInitData}, NLP<:Union{Nothing, ODE_NLProbData}} <: AbstractODEFunction{iip}
SYS, ID <: Union{Nothing, OverrideInitData}, NLP <: Union{Nothing, ODE_NLProbData}} <:
AbstractODEFunction{iip}
f::F
mass_matrix::TMM
analytic::Ta
Expand Down Expand Up @@ -522,7 +523,8 @@ information on generating the SplitFunction from this symbolic engine.
"""
struct SplitFunction{
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt,
TPJ, O, TCV, SYS, ID<:Union{Nothing, OverrideInitData}, NLP<:Union{Nothing, ODE_NLProbData}} <: AbstractODEFunction{iip}
TPJ, O, TCV, SYS, ID <: Union{Nothing, OverrideInitData},
NLP <: Union{Nothing, ODE_NLProbData}} <: AbstractODEFunction{iip}
f1::F1
f2::F2
mass_matrix::TMM
Expand Down Expand Up @@ -2442,7 +2444,7 @@ function ODEFunction{iip, specialize}(f;
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
initialization_data = __has_initialization_data(f) ? f.initialization_data :
nothing,
nlprob_data = __has_nlprob_data(f) ? f.nlprob_data : nothing,
nlprob_data = __has_nlprob_data(f) ? f.nlprob_data : nothing
) where {iip,
specialize
}
Expand Down Expand Up @@ -2500,7 +2502,8 @@ function ODEFunction{iip, specialize}(f;
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
Any,
typeof(_colorvec),
typeof(sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(_f, mass_matrix, analytic, tgrad, jac,
typeof(sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(
_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initdata, nlprob_data)
Expand Down Expand Up @@ -2770,7 +2773,8 @@ function SplitFunction{iip, specialize}(f1, f2;
if specialize === NoSpecialize
SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any,
Any, Any, Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(f1, f2, mass_matrix, _func_cache,
Any, Any, Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(
f1, f2, mass_matrix, _func_cache,
analytic,
tgrad, jac, jvp, vjp, jac_prototype, W_prototype,
sparsity, Wfact, Wfact_t, paramjac,
Expand Down
61 changes: 61 additions & 0 deletions test/downstream/initialization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using OrdinaryDiffEq, Sundials, SciMLBase, Test

@testset "CheckInit" begin
abstol = 1e-10
@testset "Sundials + ODEProblem" begin
function rhs(u, p, t)
return [u[1] * t, u[1]^2 - u[2]^2]
end
function rhs!(du, u, p, t)
du[1] = u[1] * t
du[2] = u[1]^2 - u[2]^2
end

oopfn = ODEFunction{false}(rhs, mass_matrix = [1 0; 0 0])
iipfn = ODEFunction{true}(rhs!, mass_matrix = [1 0; 0 0])

@testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn]
prob = ODEProblem(f, [1.0, 1.0], (0.0, 1.0))
integ = init(prob, Sundials.ARKODE())
u0, _, success = SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol)
@test success
@test u0 == prob.u0

integ.u[2] = 2.0
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol)
end
end

@testset "Sundials + DAEProblem" begin
function daerhs(du, u, p, t)
return [du[1] - u[1] * t - p, u[1]^2 - u[2]^2]
end
function daerhs!(resid, du, u, p, t)
resid[1] = du[1] - u[1] * t - p
resid[2] = u[1]^2 - u[2]^2
end

oopfn = DAEFunction{false}(daerhs)
iipfn = DAEFunction{true}(daerhs!)

@testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn]
prob = DAEProblem(f, [1.0, 0.0], [1.0, 1.0], (0.0, 1.0), 1.0)
integ = init(prob, Sundials.IDA())
u0, _, success = SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol)
@test success
@test u0 == prob.u0

integ.u[2] = 2.0
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol)

integ.u[2] = 1.0
integ.du[1] = 2.0
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol)
end
end
end
Loading

0 comments on commit 258063b

Please sign in to comment.