From c6af68d44c28a5dae88917974a6ffd97b8cb1b97 Mon Sep 17 00:00:00 2001 From: Venkateshprasad <32921645+ven-k@users.noreply.github.com> Date: Tue, 4 Jul 2023 00:43:56 +0530 Subject: [PATCH] Parse the args in extended components (#2202) --- src/systems/abstractsystem.jl | 2 +- src/systems/model_parsing.jl | 35 ++++++++++++++--------------------- test/jumpsystem.jl | 24 +++++++++++++----------- test/model_parsing.jl | 32 +++++++++++++++++++++++++++++++- 4 files changed, 59 insertions(+), 34 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 2215c25cd7..a8ebdd50db 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -941,7 +941,7 @@ function varname_fix!(expr::Expr) for arg in expr.args MLStyle.@match arg begin ::Symbol => continue - Expr(:kw, a) => varname_sanitization!(arg) + Expr(:kw, a...) || Expr(:kw, a) => varname_sanitization!(arg) Expr(:parameters, a...) => begin for _arg in arg.args varname_sanitization!(_arg) diff --git a/src/systems/model_parsing.jl b/src/systems/model_parsing.jl index c130e4474f..63c14a8d40 100644 --- a/src/systems/model_parsing.jl +++ b/src/systems/model_parsing.jl @@ -15,7 +15,7 @@ end @inline is_kwarg(::Symbol) = false @inline is_kwarg(e::Expr) = (e.head == :parameters) -function connector_macro(mod, name, body; arglist = Set([]), kwargs = Set([])) +function connector_macro(mod, name, body) if !Meta.isexpr(body, :block) err = """ connector body must be a block! It should be in the form of @@ -29,6 +29,7 @@ function connector_macro(mod, name, body; arglist = Set([]), kwargs = Set([])) error(err) end vs = [] + kwargs = [] icon = Ref{Union{String, URI}}() dict = Dict{Symbol, Any}() dict[:kwargs] = Dict{Symbol, Any}() @@ -48,7 +49,7 @@ function connector_macro(mod, name, body; arglist = Set([]), kwargs = Set([])) gui_metadata = isassigned(icon) ? GUIMetadata(GlobalRef(mod, name), icon[]) : nothing quote - $name = $Model(($(arglist...); name, $(kwargs...)) -> begin + $name = $Model((; name, $(kwargs...)) -> begin $expr var"#___sys___" = $ODESystem($(Equation[]), $iv, [$(vs...)], $([]); name, gui_metadata = $gui_metadata) @@ -173,7 +174,7 @@ function get_var(mod::Module, b) b isa Symbol ? getproperty(mod, b) : b end -function mtkmodel_macro(mod, name, expr; arglist = Set([]), kwargs = Set([])) +function mtkmodel_macro(mod, name, expr) exprs = Expr(:block) dict = Dict{Symbol, Any}() dict[:kwargs] = Dict{Symbol, Any}() @@ -183,6 +184,7 @@ function mtkmodel_macro(mod, name, expr; arglist = Set([]), kwargs = Set([])) icon = Ref{Union{String, URI}}() vs = [] ps = [] + kwargs = [] for arg in expr.args arg isa LineNumberNode && continue @@ -211,7 +213,7 @@ function mtkmodel_macro(mod, name, expr; arglist = Set([]), kwargs = Set([])) push!(exprs.args, :($extend($sys, $(ext[])))) end - :($name = $Model(($(arglist...); name, $(kwargs...)) -> $exprs, $dict)) + :($name = $Model((; name, $(kwargs...)) -> $exprs, $dict)) end function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, dict, @@ -221,7 +223,7 @@ function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, dict, if mname == Symbol("@components") parse_components!(exprs, comps, dict, body, kwargs) elseif mname == Symbol("@extend") - parse_extend!(exprs, ext, dict, body) + parse_extend!(exprs, ext, dict, body, kwargs) elseif mname == Symbol("@variables") parse_variables!(exprs, vs, dict, mod, body, :variables, kwargs) elseif mname == Symbol("@parameters") @@ -272,27 +274,17 @@ function component_args!(a, b, expr, kwargs) arg = b.args[i] arg isa LineNumberNode && continue MLStyle.@match arg begin - ::Symbol => begin - _v = _rename(a, arg) - push!(kwargs, _v) - b.args[i] = Expr(:kw, arg, _v) - end - Expr(:parameters, x...) => begin - component_args!(a, arg, expr, kwargs) - end - Expr(:kw, x) => begin + x::Symbol || Expr(:kw, x) => begin _v = _rename(a, x) b.args[i] = Expr(:kw, x, _v) - push!(kwargs, _v) + push!(kwargs, Expr(:kw, _v, nothing)) end - Expr(:kw, x, y::Number) => begin - _v = _rename(a, x) - b.args[i] = Expr(:kw, x, _v) - push!(kwargs, Expr(:kw, _v, y)) + Expr(:parameters, x...) => begin + component_args!(a, arg, expr, kwargs) end Expr(:kw, x, y) => begin _v = _rename(a, x) - push!(expr.args, :($y = $_v)) + b.args[i] = Expr(:kw, x, _v) push!(kwargs, Expr(:kw, _v, y)) end _ => error("Could not parse $arg of component $a") @@ -300,7 +292,7 @@ function component_args!(a, b, expr, kwargs) end end -function parse_extend!(exprs, ext, dict, body) +function parse_extend!(exprs, ext, dict, body, kwargs) expr = Expr(:block) push!(exprs, expr) body = deepcopy(body) @@ -313,6 +305,7 @@ function parse_extend!(exprs, ext, dict, body) error("`@extend` destructuring only takes an tuple as LHS. Got $body") end a, b = b.args + component_args!(a, b, expr, kwargs) vars, a, b end ext[] = a diff --git a/test/jumpsystem.jl b/test/jumpsystem.jl index 1e0676c40e..69ac24f9af 100644 --- a/test/jumpsystem.jl +++ b/test/jumpsystem.jl @@ -1,6 +1,8 @@ -using ModelingToolkit, DiffEqBase, JumpProcesses, Test, LinearAlgebra +using ModelingToolkit, DiffEqBase, JumpProcesses, Test, LinearAlgebra, StableRNGs MT = ModelingToolkit +rng = StableRNG(12345) + # basic MT SIR model with tweaks @parameters β γ t @constants h = 1 @@ -63,7 +65,7 @@ tspan = (0.0, 250.0); u₀map = [S => 999, I => 1, R => 0] parammap = [β => 0.1 / 1000, γ => 0.01] dprob = DiscreteProblem(js2, u₀map, tspan, parammap) -jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false)) +jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false), rng = rng) Nsims = 30000 function getmean(jprob, Nsims) m = 0.0 @@ -79,13 +81,13 @@ m = getmean(jprob, Nsims) obs = [S2 ~ 2 * S] @named js2b = JumpSystem([j₁, j₃], t, [S, I, R], [β, γ], observed = obs) dprob = DiscreteProblem(js2b, u₀map, tspan, parammap) -jprob = JumpProblem(js2b, dprob, Direct(), save_positions = (false, false)) +jprob = JumpProblem(js2b, dprob, Direct(), save_positions = (false, false), rng = rng) sol = solve(jprob, SSAStepper(), saveat = tspan[2] / 10) @test all(2 .* sol[S] .== sol[S2]) # test save_positions is working -jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false)) +jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false), rng = rng) sol = solve(jprob, SSAStepper(), saveat = 1.0) @test all((sol.t) .== collect(0.0:tspan[2])) @@ -120,7 +122,7 @@ function a2!(integrator) end j2 = ConstantRateJump(r2, a2!) jset = JumpSet((), (j1, j2), nothing, nothing) -jprob = JumpProblem(prob, Direct(), jset, save_positions = (false, false)) +jprob = JumpProblem(prob, Direct(), jset, save_positions = (false, false), rng = rng) m2 = getmean(jprob, Nsims) # test JumpSystem solution agrees with direct version @@ -131,16 +133,16 @@ maj1 = MassActionJump(2 * β / 2, [S => 1, I => 1], [S => -1, I => 1]) maj2 = MassActionJump(γ, [I => 1], [I => -1, R => 1]) @named js3 = JumpSystem([maj1, maj2], t, [S, I, R], [β, γ]) dprob = DiscreteProblem(js3, u₀map, tspan, parammap) -jprob = JumpProblem(js3, dprob, Direct()) +jprob = JumpProblem(js3, dprob, Direct(), rng = rng) m3 = getmean(jprob, Nsims) @test abs(m - m3) / m < 0.01 # maj jump test with various dep graphs @named js3b = JumpSystem([maj1, maj2], t, [S, I, R], [β, γ]) -jprobb = JumpProblem(js3b, dprob, NRM()) +jprobb = JumpProblem(js3b, dprob, NRM(), rng = rng) m4 = getmean(jprobb, Nsims) @test abs(m - m4) / m < 0.01 -jprobc = JumpProblem(js3b, dprob, RSSA()) +jprobc = JumpProblem(js3b, dprob, RSSA(), rng = rng) m4 = getmean(jprobc, Nsims) @test abs(m - m4) / m < 0.01 @@ -149,7 +151,7 @@ maj1 = MassActionJump(2.0, [0 => 1], [S => 1]) maj2 = MassActionJump(γ, [S => 1], [S => -1]) @named js4 = JumpSystem([maj1, maj2], t, [S], [β, γ]) dprob = DiscreteProblem(js4, [S => 999], (0, 1000.0), [β => 100.0, γ => 0.01]) -jprob = JumpProblem(js4, dprob, Direct()) +jprob = JumpProblem(js4, dprob, Direct(), rng = rng) m4 = getmean(jprob, Nsims) @test abs(m4 - 2.0 / 0.01) * 0.01 / 2.0 < 0.01 @@ -158,7 +160,7 @@ maj1 = MassActionJump(2.0, [0 => 1], [S => 1]) maj2 = MassActionJump(γ, [S => 2], [S => -1]) @named js4 = JumpSystem([maj1, maj2], t, [S], [β, γ]) dprob = DiscreteProblem(js4, [S => 999], (0, 1000.0), [β => 100.0, γ => 0.01]) -jprob = JumpProblem(js4, dprob, Direct()) +jprob = JumpProblem(js4, dprob, Direct(), rng = rng) sol = solve(jprob, SSAStepper()); # issue #819 @@ -179,7 +181,7 @@ p = [k1 => 2.0, k2 => 0.0, k3 => 0.5] u₀ = [A => 100, B => 0] tspan = (0.0, 2000.0) dprob = DiscreteProblem(js5, u₀, tspan, p) -jprob = JumpProblem(js5, dprob, Direct(), save_positions = (false, false)) +jprob = JumpProblem(js5, dprob, Direct(), save_positions = (false, false), rng = rng) @test all(jprob.massaction_jump.scaled_rates .== [1.0, 0.0]) pcondit(u, t, integrator) = t == 1000.0 diff --git a/test/model_parsing.jl b/test/model_parsing.jl index f6a78d0781..817aae94ac 100644 --- a/test/model_parsing.jl +++ b/test/model_parsing.jl @@ -91,10 +91,13 @@ l15 0" stroke="black" stroke-width="1" stroke-linejoin="bevel" fill="none">