Skip to content

Commit

Permalink
feat: add `ODEProblem(::DifferentialEquation) (#300)
Browse files Browse the repository at this point in the history
  • Loading branch information
oameye authored Oct 29, 2024
1 parent af6450d commit 889753c
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 19 deletions.
51 changes: 45 additions & 6 deletions ext/ModelingToolkitExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,17 @@ module ModelingToolkitExt
export ODESystem, ODEProblem, SteadyStateProblem, NonlinearProblem

using HarmonicBalance:
HarmonicEquation, is_rearranged, rearrange_standard, get_variables, ParameterList
using Symbolics: simplify, Equation, substitute, Num, @variables, expand, unwrap, arguments
HarmonicEquation,
is_rearranged,
rearrange_standard,
get_variables,
ParameterList,
DifferentialEquation,
get_independent_variables
using HarmonicBalance.KrylovBogoliubov:
rearrange_standard!, is_rearranged_standard, first_order_transform!
using Symbolics:
simplify, Equation, substitute, Num, @variables, expand, unwrap, arguments, wrap
using ModelingToolkit:
ModelingToolkit,
ODESystem,
Expand Down Expand Up @@ -38,29 +47,59 @@ function ModelingToolkit.ODESystem(eom::HarmonicEquation)
@assert isone(length(slow_time)) "The argument of the variables are not the same."
slow_time_ivp = @eval @independent_variables $(Symbol(first(slow_time)))

# we have the replace the param made with @variables with the ones made with @parameters
par_names = declare_parameter.(eom.parameters)

eqs = deepcopy(eom.equations)
eqs = swapsides.(eqs)
eqs = simplify.(expand.(eqs))
eqs = substitute(eqs, Dict(zip(eom.parameters, par_names)))

# compute jacobian for performance
# ∨ mtk v9 need @mtkbuild
@mtkbuild sys = ODESystem(eqs, first(slow_time_ivp), vars, par_names)
return sys
end

function ModelingToolkit.ODESystem(diff_eq::DifferentialEquation)
if !is_rearranged_standard(diff_eq)
rearrange_standard!(diff_eq)
end

times = get_independent_variables(diff_eq)
@assert isone(length(times)) "Only one independent variable allowed."
iv = first(@eval @independent_variables $(Symbol(first(times))))

first_order_transform!(diff_eq, iv)

eqs = collect(values(diff_eq.equations))
vars = get_variables(diff_eq)

diff_eq_sym = collect(Iterators.flatten(get_variables.(eqs)))
param_undeclared = setdiff(setdiff(wrap.(diff_eq_sym), vars), iv)
params = declare_parameter.(param_undeclared)

eqs = substitute(eqs, Dict(zip(param_undeclared, params)))

@mtkbuild sys = ODESystem(eqs, first(iv), vars, params)

return sys
end

function ModelingToolkit.ODEProblem(
eom::HarmonicEquation, u0, tspan::Tuple, p::ParameterList; in_place=true, kwargs...
eom::Union{HarmonicEquation,DifferentialEquation},
u0,
tspan::Tuple,
p::ParameterList;
in_place=true,
kwargs...,
)
sys = ODESystem(eom)
param = varmap_to_vars(p, parameters(sys))
if !in_place # out-of-place
prob = ODEProblem{false}(sys, u0, tspan, param; jac=true, kwargs...)
else # in-place
prob = ODEProblem{true}(sys, u0, tspan, param; jac=true, kwargs...)
end
end # compute jacobian for performance
return prob
end

Expand All @@ -80,7 +119,7 @@ function ModelingToolkit.SteadyStateProblem(
prob = SteadyStateProblem{false}(sys, u0, param; jac=true, kwargs...)
else # in-place
prob = SteadyStateProblem{true}(sys, u0, param; jac=true, kwargs...)
end
end # compute jacobian for performance
return prob
end

Expand Down
69 changes: 56 additions & 13 deletions test/ModelingToolkitExt.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,64 @@
using HarmonicBalance
using ModelingToolkit
using ModelingToolkit: varmap_to_vars

ModelingToolkitExt = Base.get_extension(HarmonicBalance, :ModelingToolkitExt)

using Test

@variables α ω ω0 F γ t x(t)
diff_eq = DifferentialEquation(
d(x, t, 2) + ω0^2 * x + α * x^3 + γ * d(x, t) ~ F * cos* t), x
)
add_harmonic!(diff_eq, x, ω) #
harmonic_eq = get_harmonic_equations(diff_eq)
@testset "Utilities" begin
@variables α
check = ModelingToolkitExt.declare_parameter(α)
@test ModelingToolkit.PARAMETER values(check.val.metadata)
end

@testset "DifferentialEquation" begin
@testset "ODESystem" begin
@variables α ω ω0 F γ t x(t)
diff_eq = DifferentialEquation(
d(x, t, 2) + ω0^2 * x + α * x^3 + γ * d(x, t) ~ F * cos* t), x
)

fixed ==> 1.0, ω0 => 1.1, F => 0.01, γ => 0.01)
param = ParameterList(merge(Dict(fixed), Dict=> 1.1)))
sys = ODESystem(diff_eq)

for p in string.([α, ω, ω0, F, γ])
@test p string.(parameters(sys))
end
end
@testset "ODEProblem" begin
@variables α ω ω0 F γ t x(t)
diff_eq = DifferentialEquation(
d(x, t, 2) + ω0^2 * x + α * x^3 + γ * d(x, t) ~ F * cos* t), x
)

sys = ODESystem(harmonic_eq)
fixed ==> 1.0, ω0 => 1.1, F => 0.01, γ => 0.01)
param = ParameterList(merge(Dict(fixed), Dict=> 1.1)))
fixed ==> 1.0, ω0 => 1.1, F => 0.01, γ => 0.01)
param = ParameterList(merge(Dict(fixed), Dict=> 1.1)))

for p in string.([α, ω, ω0, F, γ])
@test p string.(parameters(sys))
ODEProblem(diff_eq, [1.0, 0.0], (0, 100), param)
end
end

ODEProblem(harmonic_eq, [1.0, 0.0], (0, 100), param)
@testset "HarmonicEquation" begin
@variables α ω ω0 F γ t x(t)
diff_eq = DifferentialEquation(
d(x, t, 2) + ω0^2 * x + α * x^3 + γ * d(x, t) ~ F * cos* t), x
)

add_harmonic!(diff_eq, x, ω) #
harmonic_eq = get_harmonic_equations(diff_eq)

fixed ==> 1.0, ω0 => 1.1, F => 0.01, γ => 0.01)
param = ParameterList(merge(Dict(fixed), Dict=> 1.1)))
@testset "ODESystem" begin
sys = ODESystem(harmonic_eq)

for p in string.([α, ω, ω0, F, γ])
@test p string.(parameters(sys))
end
end

@testset "ODEProblem" begin
ODEProblem(harmonic_eq, [1.0, 0.0], (0, 100), param)
end
end

0 comments on commit 889753c

Please sign in to comment.