diff --git a/ext/ModelingToolkitExt.jl b/ext/ModelingToolkitExt.jl index b5110701..de0cd3c7 100644 --- a/ext/ModelingToolkitExt.jl +++ b/ext/ModelingToolkitExt.jl @@ -3,8 +3,17 @@ module ModelingToolkitExt export ODESystem, ODEProblem, SteadyStateProblem, NonlinearProblem using HarmonicBalance: - HarmonicEquation, is_rearranged, rearrange_standard, get_variables, ParameterList -using Symbolics: simplify, Equation, substitute, Num, @variables, expand, unwrap, arguments + HarmonicEquation, + is_rearranged, + rearrange_standard, + get_variables, + ParameterList, + DifferentialEquation, + get_independent_variables +using HarmonicBalance.KrylovBogoliubov: + rearrange_standard!, is_rearranged_standard, first_order_transform! +using Symbolics: + simplify, Equation, substitute, Num, @variables, expand, unwrap, arguments, wrap using ModelingToolkit: ModelingToolkit, ODESystem, @@ -38,6 +47,7 @@ function ModelingToolkit.ODESystem(eom::HarmonicEquation) @assert isone(length(slow_time)) "The argument of the variables are not the same." slow_time_ivp = @eval @independent_variables $(Symbol(first(slow_time))) + # we have the replace the param made with @variables with the ones made with @parameters par_names = declare_parameter.(eom.parameters) eqs = deepcopy(eom.equations) @@ -45,14 +55,43 @@ function ModelingToolkit.ODESystem(eom::HarmonicEquation) eqs = simplify.(expand.(eqs)) eqs = substitute(eqs, Dict(zip(eom.parameters, par_names))) - # compute jacobian for performance # ∨ mtk v9 need @mtkbuild @mtkbuild sys = ODESystem(eqs, first(slow_time_ivp), vars, par_names) return sys end +function ModelingToolkit.ODESystem(diff_eq::DifferentialEquation) + if !is_rearranged_standard(diff_eq) + rearrange_standard!(diff_eq) + end + + times = get_independent_variables(diff_eq) + @assert isone(length(times)) "Only one independent variable allowed." + iv = first(@eval @independent_variables $(Symbol(first(times)))) + + first_order_transform!(diff_eq, iv) + + eqs = collect(values(diff_eq.equations)) + vars = get_variables(diff_eq) + + diff_eq_sym = collect(Iterators.flatten(get_variables.(eqs))) + param_undeclared = setdiff(setdiff(wrap.(diff_eq_sym), vars), iv) + params = declare_parameter.(param_undeclared) + + eqs = substitute(eqs, Dict(zip(param_undeclared, params))) + + @mtkbuild sys = ODESystem(eqs, first(iv), vars, params) + + return sys +end + function ModelingToolkit.ODEProblem( - eom::HarmonicEquation, u0, tspan::Tuple, p::ParameterList; in_place=true, kwargs... + eom::Union{HarmonicEquation,DifferentialEquation}, + u0, + tspan::Tuple, + p::ParameterList; + in_place=true, + kwargs..., ) sys = ODESystem(eom) param = varmap_to_vars(p, parameters(sys)) @@ -60,7 +99,7 @@ function ModelingToolkit.ODEProblem( prob = ODEProblem{false}(sys, u0, tspan, param; jac=true, kwargs...) else # in-place prob = ODEProblem{true}(sys, u0, tspan, param; jac=true, kwargs...) - end + end # compute jacobian for performance return prob end @@ -80,7 +119,7 @@ function ModelingToolkit.SteadyStateProblem( prob = SteadyStateProblem{false}(sys, u0, param; jac=true, kwargs...) else # in-place prob = SteadyStateProblem{true}(sys, u0, param; jac=true, kwargs...) - end + end # compute jacobian for performance return prob end diff --git a/test/ModelingToolkitExt.jl b/test/ModelingToolkitExt.jl index 22e05866..778253f1 100644 --- a/test/ModelingToolkitExt.jl +++ b/test/ModelingToolkitExt.jl @@ -1,21 +1,64 @@ using HarmonicBalance using ModelingToolkit -using ModelingToolkit: varmap_to_vars + +ModelingToolkitExt = Base.get_extension(HarmonicBalance, :ModelingToolkitExt) + using Test -@variables α ω ω0 F γ t x(t) -diff_eq = DifferentialEquation( - d(x, t, 2) + ω0^2 * x + α * x^3 + γ * d(x, t) ~ F * cos(ω * t), x -) -add_harmonic!(diff_eq, x, ω) # -harmonic_eq = get_harmonic_equations(diff_eq) +@testset "Utilities" begin + @variables α + check = ModelingToolkitExt.declare_parameter(α) + @test ModelingToolkit.PARAMETER ∈ values(check.val.metadata) +end + +@testset "DifferentialEquation" begin + @testset "ODESystem" begin + @variables α ω ω0 F γ t x(t) + diff_eq = DifferentialEquation( + d(x, t, 2) + ω0^2 * x + α * x^3 + γ * d(x, t) ~ F * cos(ω * t), x + ) + + fixed = (α => 1.0, ω0 => 1.1, F => 0.01, γ => 0.01) + param = ParameterList(merge(Dict(fixed), Dict(ω => 1.1))) + sys = ODESystem(diff_eq) + + for p in string.([α, ω, ω0, F, γ]) + @test p ∈ string.(parameters(sys)) + end + end + @testset "ODEProblem" begin + @variables α ω ω0 F γ t x(t) + diff_eq = DifferentialEquation( + d(x, t, 2) + ω0^2 * x + α * x^3 + γ * d(x, t) ~ F * cos(ω * t), x + ) -sys = ODESystem(harmonic_eq) -fixed = (α => 1.0, ω0 => 1.1, F => 0.01, γ => 0.01) -param = ParameterList(merge(Dict(fixed), Dict(ω => 1.1))) + fixed = (α => 1.0, ω0 => 1.1, F => 0.01, γ => 0.01) + param = ParameterList(merge(Dict(fixed), Dict(ω => 1.1))) -for p in string.([α, ω, ω0, F, γ]) - @test p ∈ string.(parameters(sys)) + ODEProblem(diff_eq, [1.0, 0.0], (0, 100), param) + end end -ODEProblem(harmonic_eq, [1.0, 0.0], (0, 100), param) +@testset "HarmonicEquation" begin + @variables α ω ω0 F γ t x(t) + diff_eq = DifferentialEquation( + d(x, t, 2) + ω0^2 * x + α * x^3 + γ * d(x, t) ~ F * cos(ω * t), x + ) + + add_harmonic!(diff_eq, x, ω) # + harmonic_eq = get_harmonic_equations(diff_eq) + + fixed = (α => 1.0, ω0 => 1.1, F => 0.01, γ => 0.01) + param = ParameterList(merge(Dict(fixed), Dict(ω => 1.1))) + @testset "ODESystem" begin + sys = ODESystem(harmonic_eq) + + for p in string.([α, ω, ω0, F, γ]) + @test p ∈ string.(parameters(sys)) + end + end + + @testset "ODEProblem" begin + ODEProblem(harmonic_eq, [1.0, 0.0], (0, 100), param) + end +end