Skip to content

Commit

Permalink
Merge pull request #2231 from SciML/myb/ps
Browse files Browse the repository at this point in the history
Handle inhomogeneous parameters using a Tuple of Vectors
  • Loading branch information
YingboMa authored Sep 14, 2023
2 parents 5ef23af + 6acee61 commit 736520a
Show file tree
Hide file tree
Showing 9 changed files with 250 additions and 36 deletions.
46 changes: 46 additions & 0 deletions src/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,49 @@ macro parameters(xs...)
xs,
toparam) |> esc
end

function find_types(array)
by = let set = Dict{Any, Int}(), counter = Ref(0)
x -> begin
# t = typeof(x)

get!(set, typeof(x)) do
# if t == Float64
# 1
# else
counter[] += 1
# end
end
end
end
return by.(array)
end

function split_parameters_by_type(ps)
if ps === SciMLBase.NullParameters()
return Float64[], [] #use Float64 to avoid Any type warning
else
by = let set = Dict{Any, Int}(), counter = Ref(0)
x -> begin
get!(set, typeof(x)) do
counter[] += 1
end
end
end
idxs = by.(ps)
split_idxs = [Int[]]
for (i, idx) in enumerate(idxs)
if idx > length(split_idxs)
push!(split_idxs, Int[])
end
push!(split_idxs[idx], i)
end
tighten_types = x -> identity.(x)
split_ps = tighten_types.(Base.Fix1(getindex, ps).(split_idxs))
if length(split_ps) == 1 #Tuple not needed, only 1 type
return split_ps[1], split_idxs
else
return (split_ps...,), split_idxs
end
end
end
6 changes: 3 additions & 3 deletions src/structural_transformation/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,8 @@ function ODAEProblem{iip}(sys,
tspan,
parammap = DiffEqBase.NullParameters();
callback = nothing,
use_union = false,
use_union = true,
tofloat = true,
check = true,
kwargs...) where {iip}
eqs = equations(sys)
Expand All @@ -540,8 +541,7 @@ function ODAEProblem{iip}(sys,
defs = ModelingToolkit.mergedefaults(defs, parammap, ps)
defs = ModelingToolkit.mergedefaults(defs, u0map, dvs)
u0 = ModelingToolkit.varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
p = ModelingToolkit.varmap_to_vars(parammap, ps; defaults = defs, tofloat = !use_union,
use_union)
p = ModelingToolkit.varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)

has_difference = any(isdifferenceeq, eqs)
cbs = process_events(sys; callback, has_difference, kwargs...)
Expand Down
93 changes: 75 additions & 18 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,14 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param
states = sol_states,
kwargs...)
else
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
kwargs...)
if p isa Tuple
build_function(rhss, u, p..., t; postprocess_fbody = pre,
states = sol_states,
kwargs...)
else
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
kwargs...)
end
end
end
end
Expand Down Expand Up @@ -332,8 +338,15 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
f_oop, f_iip = eval_expression ?
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) :
f_gen
f(u, p, t) = f_oop(u, p, t)
f(du, u, p, t) = f_iip(du, u, p, t)
if p isa Tuple
g(u, p, t) = f_oop(u, p..., t)
g(du, u, p, t) = f_iip(du, u, p..., t)
f = g
else
k(u, p, t) = f_oop(u, p, t)
k(du, u, p, t) = f_iip(du, u, p, t)
f = k
end

if specialize === SciMLBase.FunctionWrapperSpecialize && iip
if u0 === nothing || p === nothing || t === nothing
Expand Down Expand Up @@ -384,32 +397,64 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s

obs = observed(sys)
observedfun = if steady_state
let sys = sys, dict = Dict()
let sys = sys, dict = Dict(), ps = ps
function generated_observed(obsvar, args...)
obs = get!(dict, value(obsvar)) do
build_explicit_observed_function(sys, obsvar)
end
if args === ()
let obs = obs
(u, p, t = Inf) -> obs(u, p, t)
(u, p, t = Inf) -> if ps isa Tuple
obs(u, p..., t)
else
obs(u, p, t)
end
end
else
length(args) == 2 ? obs(args..., Inf) : obs(args...)
if ps isa Tuple
if length(args) == 2
u, p = args
obs(u, p..., Inf)
else
u, p, t = args
obs(u, p..., t)
end
else
if length(args) == 2
u, p = args
obs(u, p, Inf)
else
u, p, t = args
obs(u, p, t)
end
end
end
end
end
else
let sys = sys, dict = Dict()
let sys = sys, dict = Dict(), ps = ps
function generated_observed(obsvar, args...)
obs = get!(dict, value(obsvar)) do
build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds)
build_explicit_observed_function(sys,
obsvar;
checkbounds = checkbounds,
ps)
end
if args === ()
let obs = obs
(u, p, t) -> obs(u, p, t)
(u, p, t) -> if ps isa Tuple
obs(u, p..., t)
else
obs(u, p, t)
end
end
else
obs(args...)
if ps isa Tuple # split parameters
u, p, t = args
obs(u, p..., t)
else
obs(args...)
end
end
end
end
Expand Down Expand Up @@ -677,15 +722,15 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
end

"""
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=false, tofloat=!use_union)
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=true, tofloat=true)
Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
"""
function get_u0_p(sys,
u0map,
parammap;
use_union = false,
tofloat = !use_union,
use_union = true,
tofloat = true,
symbolic_u0 = false)
dvs = states(sys)
ps = parameters(sys)
Expand All @@ -712,16 +757,27 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
simplify = false,
linenumbers = true, parallel = SerialForm(),
eval_expression = true,
use_union = false,
tofloat = !use_union,
use_union = true,
tofloat = true,
symbolic_u0 = false,
kwargs...)
eqs = equations(sys)
dvs = states(sys)
ps = parameters(sys)
iv = get_iv(sys)

u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union, symbolic_u0)
u0, p, defs = get_u0_p(sys,
u0map,
parammap;
tofloat,
use_union,
symbolic_u0)

p, split_idxs = split_parameters_by_type(p)
if p isa Tuple
ps = Base.Fix1(getindex, parameters(sys)).(split_idxs)
ps = (ps...,) #if p is Tuple, ps should be Tuple
end

if implicit_dae && du0map !== nothing
ddvs = map(Differential(iv), dvs)
Expand All @@ -738,7 +794,8 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac,
checkbounds = checkbounds, p = p,
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
sparse = sparse, eval_expression = eval_expression, kwargs...)
sparse = sparse, eval_expression = eval_expression,
kwargs...)
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
end

Expand Down
14 changes: 9 additions & 5 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ function build_explicit_observed_function(sys, ts;
output_type = Array,
checkbounds = true,
drop_expr = drop_expr,
ps = parameters(sys),
throw = true)
if (isscalar = !(ts isa AbstractVector))
ts = [ts]
Expand Down Expand Up @@ -385,17 +386,20 @@ function build_explicit_observed_function(sys, ts;
push!(obsexprs, lhs rhs)
end

pars = parameters(sys)
if inputs !== nothing
pars = setdiff(pars, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
end
if ps isa Tuple
ps = DestructuredArgs.(ps, inbounds = !checkbounds)
else
ps = (DestructuredArgs(ps, inbounds = !checkbounds),)
end
ps = DestructuredArgs(pars, inbounds = !checkbounds)
dvs = DestructuredArgs(states(sys), inbounds = !checkbounds)
if inputs === nothing
args = [dvs, ps, ivs...]
args = [dvs, ps..., ivs...]
else
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
args = [dvs, ipts, ps, ivs...]
args = [dvs, ipts, ps..., ivs...]
end
pre = get_postprocess_fbody(sys)

Expand Down
18 changes: 17 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,10 @@ end

hasdefault(v) = hasmetadata(v, Symbolics.VariableDefaultValue)
getdefault(v) = value(getmetadata(v, Symbolics.VariableDefaultValue))
function getdefaulttype(v)
def = value(getmetadata(unwrap(v), Symbolics.VariableDefaultValue, nothing))
def === nothing ? Float64 : typeof(def)
end
function setdefault(v, val)
val === nothing ? v : setmetadata(v, Symbolics.VariableDefaultValue, value(val))
end
Expand Down Expand Up @@ -642,10 +646,15 @@ end
throw(ArgumentError("$vars are either missing from the variable map or missing from the system's states/parameters list."))
end

function promote_to_concrete(vs; tofloat = true, use_union = false)
function promote_to_concrete(vs; tofloat = true, use_union = true)
if isempty(vs)
return vs
end
if vs isa Tuple #special rule, if vs is a Tuple, preserve types, container converted to Array
tofloat = false
use_union = true
vs = Any[vs...]
end
T = eltype(vs)
if Base.isconcretetype(T) && (!tofloat || T === float(T)) # nothing to do
vs
Expand All @@ -656,6 +665,7 @@ function promote_to_concrete(vs; tofloat = true, use_union = false)
I = Int8
has_int = false
has_array = false
has_bool = false
array_T = nothing
for v in vs
if v isa AbstractArray
Expand All @@ -668,6 +678,9 @@ function promote_to_concrete(vs; tofloat = true, use_union = false)
has_int = true
I = promote_type(I, E)
end
if E <: Bool
has_bool = true
end
end
if tofloat && !has_array
C = float(C)
Expand All @@ -678,6 +691,9 @@ function promote_to_concrete(vs; tofloat = true, use_union = false)
if has_int
C = Union{C, I}
end
if has_bool
C = Union{C, Bool}
end
return copyto!(similar(vs, C), vs)
end
convert.(C, vs)
Expand Down
9 changes: 5 additions & 4 deletions src/variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ applicable.
"""
function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
toterm = default_toterm, promotetoconcrete = nothing,
tofloat = true, use_union = false)
tofloat = true, use_union = true)
varlist = collect(map(unwrap, varlist))

# Edge cases where one of the arguments is effectively empty.
Expand All @@ -75,9 +75,10 @@ function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
end
end

T = typeof(varmap)
# We respect the input type
container_type = T <: Dict ? Array : T
# T = typeof(varmap)
# We respect the input type (feature removed, not needed with Tuple support)
# container_type = T <: Union{Dict,Tuple} ? Array : T
container_type = Array

vals = if eltype(varmap) <: Pair # `varmap` is a dict or an array of pairs
varmap = todict(varmap)
Expand Down
20 changes: 15 additions & 5 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -734,18 +734,28 @@ let
u0map = [A => 1.0]
pmap = (k1 => 1.0, k2 => 1)
tspan = (0.0, 1.0)
prob = ODEProblem(sys, u0map, tspan, pmap; tofloat = false)
@test prob.p == ([1], [1.0]) #Tuple([(Dict(pmap))[k] for k in values(parameters(sys))])

prob = ODEProblem(sys, u0map, tspan, pmap)
@test prob.p === Tuple([(Dict(pmap))[k] for k in values(parameters(sys))])
@test prob.p isa Vector{Float64}

pmap = [k1 => 1, k2 => 1]
tspan = (0.0, 1.0)
prob = ODEProblem(sys, u0map, tspan, pmap)
@test eltype(prob.p) === Float64

pmap = Pair{Any, Union{Int, Float64}}[k1 => 1, k2 => 1.0]
tspan = (0.0, 1.0)
prob = ODEProblem(sys, u0map, tspan, pmap, use_union = true)
@test eltype(prob.p) === Union{Float64, Int}
prob = ODEProblem(sys, u0map, tspan, pmap; tofloat = false)
@test eltype(prob.p) === Int

prob = ODEProblem(sys, u0map, tspan, pmap)
@test prob.p isa Vector{Float64}

# No longer supported, Tuple used instead
# pmap = Pair{Any, Union{Int, Float64}}[k1 => 1, k2 => 1.0]
# tspan = (0.0, 1.0)
# prob = ODEProblem(sys, u0map, tspan, pmap, use_union = true)
# @test eltype(prob.p) === Union{Float64, Int}
end

let
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ using SafeTestsets, Test
@safetestset "JumpSystem Test" include("jumpsystem.jl")
@safetestset "Constraints Test" include("constraints.jl")
@safetestset "Reduction Test" include("reduction.jl")
@safetestset "Split Parameters Test" include("split_parameters.jl")
@safetestset "ODAEProblem Test" include("odaeproblem.jl")
@safetestset "Components Test" include("components.jl")
@safetestset "Model Parsing Test" include("model_parsing.jl")
Expand Down
Loading

0 comments on commit 736520a

Please sign in to comment.