Skip to content

Commit

Permalink
refactor: change SCCNonlinearProblem fields
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 3, 2024
1 parent 86aa145 commit 1dcbd1f
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 30 deletions.
30 changes: 16 additions & 14 deletions src/problems/nonlinear_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -462,28 +462,30 @@ Note that this example aliases the parameters together for a memory-reduced repr
* `probs`: the collection of problems to solve
* `explictfuns!`: the explicit functions for mutating the parameter set
"""
mutable struct SCCNonlinearProblem{uType, iip, P, E, I, Par} <:
mutable struct SCCNonlinearProblem{uType, iip, P, E, F <: NonlinearFunction{iip}, Par} <:
AbstractNonlinearProblem{uType, iip}
probs::P
explicitfuns!::E
full_index_provider::I
parameter_object::Par
# NonlinearFunction with `f = Returns(nothing)`
f::F
p::Par
parameters_alias::Bool

function SCCNonlinearProblem{P, E, I, Par}(
probs::P, funs::E, indp::I, pobj::Par, alias::Bool) where {P, E, I, Par}
function SCCNonlinearProblem{P, E, F, Par}(probs::P, funs::E, f::F, pobj::Par,
alias::Bool) where {P, E, F <: NonlinearFunction, Par}
u0 = mapreduce(
state_values, vcat, probs; init = similar(state_values(first(probs)), 0))
uType = typeof(u0)
new{uType, false, P, E, I, Par}(probs, funs, indp, pobj, alias)
new{uType, false, P, E, F, Par}(probs, funs, f, pobj, alias)
end
end

function SCCNonlinearProblem(probs, explicitfuns!, full_index_provider = nothing,
parameter_object = nothing, parameters_alias = false)
function SCCNonlinearProblem(probs, explicitfuns!, parameter_object = nothing,
parameters_alias = false; kwargs...)
f = NonlinearFunction{false}(Returns(nothing); kwargs...)
return SCCNonlinearProblem{typeof(probs), typeof(explicitfuns!),
typeof(full_index_provider), typeof(parameter_object)}(
probs, explicitfuns!, full_index_provider, parameter_object, parameters_alias)
typeof(f), typeof(parameter_object)}(
probs, explicitfuns!, f, parameter_object, parameters_alias)
end

function Base.getproperty(prob::SCCNonlinearProblem, name::Symbol)
Expand All @@ -496,10 +498,10 @@ function Base.getproperty(prob::SCCNonlinearProblem, name::Symbol)
end

function SymbolicIndexingInterface.symbolic_container(prob::SCCNonlinearProblem)
prob.full_index_provider
prob.f
end
function SymbolicIndexingInterface.parameter_values(prob::SCCNonlinearProblem)
prob.parameter_object
prob.p
end
function SymbolicIndexingInterface.state_values(prob::SCCNonlinearProblem)
mapreduce(
Expand All @@ -516,8 +518,8 @@ function SymbolicIndexingInterface.set_state!(prob::SCCNonlinearProblem, val, id
end

function SymbolicIndexingInterface.set_parameter!(prob::SCCNonlinearProblem, val, idx)
if prob.parameter_object !== nothing
set_parameter!(prob.parameter_object, val, idx)
if prob.p !== nothing
set_parameter!(prob.p, val, idx)
prob.parameters_alias && return
end
for scc in prob.probs
Expand Down
14 changes: 6 additions & 8 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -526,8 +526,7 @@ function remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = mi
if p !== missing && !parameters_alias && probs === missing
throw(ArgumentError("`parameters_alias` is `false` for the given `SCCNonlinearProblem`. Please provide the subproblems using the keyword `probs` with the parameters updated appropriately in each."))
end
newu0, newp = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults,
indp = sys === missing ? prob.full_index_provider : sys)
newu0, newp = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
if probs === missing
probs = prob.probs
end
Expand All @@ -547,11 +546,10 @@ function remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = mi
end
end
if sys === missing
sys = prob.full_index_provider
sys = prob.f.sys
end
return SCCNonlinearProblem{
typeof(probs), typeof(explicitfuns!), typeof(sys), typeof(newp)}(
probs, explicitfuns!, sys, newp, parameters_alias)
return SCCNonlinearProblem(
probs, explicitfuns!, newp, parameters_alias; sys)
end

function varmap_has_var(varmap, var)
Expand Down Expand Up @@ -784,11 +782,11 @@ end

function updated_u0_p(
prob, u0, p, t0 = nothing; interpret_symbolicmap = true,
use_defaults = false, indp = has_sys(prob.f) ? prob.f.sys : nothing)
use_defaults = false)
if u0 === missing && p === missing
return state_values(prob), parameter_values(prob)
end
if indp === nothing
if prob.f.sys === nothing
if interpret_symbolicmap && eltype(p) !== Union{} && eltype(p) <: Pair
throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
Expand Down
12 changes: 6 additions & 6 deletions test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ fullsys = complete(fullsys)
prob1 = NonlinearProblem(sys1, u0, p)
prob2 = NonlinearProblem(sys2, u0, prob1.p)
sccprob = SCCNonlinearProblem(
[prob1, prob2], [Returns(nothing), Returns(nothing)], fullsys, prob1.p, true)
[prob1, prob2], [Returns(nothing), Returns(nothing)], prob1.p, true; sys = fullsys)
push!(syss, fullsys)
push!(probs, sccprob)

Expand Down Expand Up @@ -315,16 +315,16 @@ end
prob1 = NonlinearProblem(sys1, u0, p)
prob2 = NonlinearProblem(sys2, u0, prob1.p)
sccprob = SCCNonlinearProblem(
[prob1, prob2], [Returns(nothing), Returns(nothing)], fullsys, prob1.p, true)
[prob1, prob2], [Returns(nothing), Returns(nothing)], prob1.p, true; sys = fullsys)

sccprob2 = remake(sccprob; u0 = 2ones(3))
@test state_values(sccprob2) 2ones(3)
@test sccprob2.probs[1].u0 2ones(2)
@test sccprob2.probs[2].u0 2ones(1)

sccprob3 = remake(sccprob; p ==> 2.0])
@test sccprob3.parameter_object === sccprob3.probs[1].p
@test sccprob3.parameter_object === sccprob3.probs[2].p
@test sccprob3.p === sccprob3.probs[1].p
@test sccprob3.p === sccprob3.probs[2].p

@test_throws ["parameters_alias", "SCCNonlinearProblem"] remake(
sccprob; parameters_alias = false, p ==> 2.0])
Expand All @@ -333,6 +333,6 @@ end
sccprob4 = remake(sccprob; parameters_alias = false, p = newp,
probs = [remake(prob1; p ==> 3.0]), prob2])
@test !sccprob4.parameters_alias
@test sccprob4.parameter_object !== sccprob4.probs[1].p
@test sccprob4.parameter_object !== sccprob4.probs[2].p
@test sccprob4.p !== sccprob4.probs[1].p
@test sccprob4.p !== sccprob4.probs[2].p
end
4 changes: 2 additions & 2 deletions test/downstream/problem_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ prob = SteadyStateProblem(osys, u0, ps)
prob = NonlinearProblem(model, [])
sccprob = SciMLBase.SCCNonlinearProblem([prob1, prob2, prob3],
SciMLBase.Void{Any}.([explicitfun1, explicitfun2, explicitfun3]),
model, copy(cache))
copy(cache); sys = model)

for sym in [u, u..., u[2] + u[3], p * u[1] + u[2]]
@test prob[sym] sccprob[sym]
Expand All @@ -384,7 +384,7 @@ prob = SteadyStateProblem(osys, u0, ps)
end
sccprob.ps[p] = 2.5
@test sccprob.ps[p] 2.5
@test sccprob.parameter_object[1] 2.5
@test sccprob.p[1] 2.5
for scc in sccprob.probs
@test parameter_values(scc)[1] 2.5
end
Expand Down

0 comments on commit 1dcbd1f

Please sign in to comment.