Skip to content

Commit

Permalink
Parse the args in extended components (#2202)
Browse files Browse the repository at this point in the history
  • Loading branch information
ven-k authored Jul 3, 2023
1 parent bf3551e commit c6af68d
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 34 deletions.
2 changes: 1 addition & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,7 @@ function varname_fix!(expr::Expr)
for arg in expr.args
MLStyle.@match arg begin
::Symbol => continue
Expr(:kw, a) => varname_sanitization!(arg)
Expr(:kw, a...) || Expr(:kw, a) => varname_sanitization!(arg)
Expr(:parameters, a...) => begin
for _arg in arg.args
varname_sanitization!(_arg)
Expand Down
35 changes: 14 additions & 21 deletions src/systems/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ end
@inline is_kwarg(::Symbol) = false
@inline is_kwarg(e::Expr) = (e.head == :parameters)

function connector_macro(mod, name, body; arglist = Set([]), kwargs = Set([]))
function connector_macro(mod, name, body)
if !Meta.isexpr(body, :block)
err = """
connector body must be a block! It should be in the form of
Expand All @@ -29,6 +29,7 @@ function connector_macro(mod, name, body; arglist = Set([]), kwargs = Set([]))
error(err)
end
vs = []
kwargs = []
icon = Ref{Union{String, URI}}()
dict = Dict{Symbol, Any}()
dict[:kwargs] = Dict{Symbol, Any}()
Expand All @@ -48,7 +49,7 @@ function connector_macro(mod, name, body; arglist = Set([]), kwargs = Set([]))
gui_metadata = isassigned(icon) ? GUIMetadata(GlobalRef(mod, name), icon[]) :
nothing
quote
$name = $Model(($(arglist...); name, $(kwargs...)) -> begin
$name = $Model((; name, $(kwargs...)) -> begin
$expr
var"#___sys___" = $ODESystem($(Equation[]), $iv, [$(vs...)], $([]);
name, gui_metadata = $gui_metadata)
Expand Down Expand Up @@ -173,7 +174,7 @@ function get_var(mod::Module, b)
b isa Symbol ? getproperty(mod, b) : b
end

function mtkmodel_macro(mod, name, expr; arglist = Set([]), kwargs = Set([]))
function mtkmodel_macro(mod, name, expr)
exprs = Expr(:block)
dict = Dict{Symbol, Any}()
dict[:kwargs] = Dict{Symbol, Any}()
Expand All @@ -183,6 +184,7 @@ function mtkmodel_macro(mod, name, expr; arglist = Set([]), kwargs = Set([]))
icon = Ref{Union{String, URI}}()
vs = []
ps = []
kwargs = []

for arg in expr.args
arg isa LineNumberNode && continue
Expand Down Expand Up @@ -211,7 +213,7 @@ function mtkmodel_macro(mod, name, expr; arglist = Set([]), kwargs = Set([]))
push!(exprs.args, :($extend($sys, $(ext[]))))
end

:($name = $Model(($(arglist...); name, $(kwargs...)) -> $exprs, $dict))
:($name = $Model((; name, $(kwargs...)) -> $exprs, $dict))
end

function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, dict,
Expand All @@ -221,7 +223,7 @@ function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, dict,
if mname == Symbol("@components")
parse_components!(exprs, comps, dict, body, kwargs)
elseif mname == Symbol("@extend")
parse_extend!(exprs, ext, dict, body)
parse_extend!(exprs, ext, dict, body, kwargs)
elseif mname == Symbol("@variables")
parse_variables!(exprs, vs, dict, mod, body, :variables, kwargs)
elseif mname == Symbol("@parameters")
Expand Down Expand Up @@ -272,35 +274,25 @@ function component_args!(a, b, expr, kwargs)
arg = b.args[i]
arg isa LineNumberNode && continue
MLStyle.@match arg begin
::Symbol => begin
_v = _rename(a, arg)
push!(kwargs, _v)
b.args[i] = Expr(:kw, arg, _v)
end
Expr(:parameters, x...) => begin
component_args!(a, arg, expr, kwargs)
end
Expr(:kw, x) => begin
x::Symbol || Expr(:kw, x) => begin
_v = _rename(a, x)
b.args[i] = Expr(:kw, x, _v)
push!(kwargs, _v)
push!(kwargs, Expr(:kw, _v, nothing))
end
Expr(:kw, x, y::Number) => begin
_v = _rename(a, x)
b.args[i] = Expr(:kw, x, _v)
push!(kwargs, Expr(:kw, _v, y))
Expr(:parameters, x...) => begin
component_args!(a, arg, expr, kwargs)
end
Expr(:kw, x, y) => begin
_v = _rename(a, x)
push!(expr.args, :($y = $_v))
b.args[i] = Expr(:kw, x, _v)
push!(kwargs, Expr(:kw, _v, y))
end
_ => error("Could not parse $arg of component $a")
end
end
end

function parse_extend!(exprs, ext, dict, body)
function parse_extend!(exprs, ext, dict, body, kwargs)
expr = Expr(:block)
push!(exprs, expr)
body = deepcopy(body)
Expand All @@ -313,6 +305,7 @@ function parse_extend!(exprs, ext, dict, body)
error("`@extend` destructuring only takes an tuple as LHS. Got $body")
end
a, b = b.args
component_args!(a, b, expr, kwargs)
vars, a, b
end
ext[] = a
Expand Down
24 changes: 13 additions & 11 deletions test/jumpsystem.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using ModelingToolkit, DiffEqBase, JumpProcesses, Test, LinearAlgebra
using ModelingToolkit, DiffEqBase, JumpProcesses, Test, LinearAlgebra, StableRNGs
MT = ModelingToolkit

rng = StableRNG(12345)

# basic MT SIR model with tweaks
@parameters β γ t
@constants h = 1
Expand Down Expand Up @@ -63,7 +65,7 @@ tspan = (0.0, 250.0);
u₀map = [S => 999, I => 1, R => 0]
parammap ==> 0.1 / 1000, γ => 0.01]
dprob = DiscreteProblem(js2, u₀map, tspan, parammap)
jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false))
jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false), rng = rng)
Nsims = 30000
function getmean(jprob, Nsims)
m = 0.0
Expand All @@ -79,13 +81,13 @@ m = getmean(jprob, Nsims)
obs = [S2 ~ 2 * S]
@named js2b = JumpSystem([j₁, j₃], t, [S, I, R], [β, γ], observed = obs)
dprob = DiscreteProblem(js2b, u₀map, tspan, parammap)
jprob = JumpProblem(js2b, dprob, Direct(), save_positions = (false, false))
jprob = JumpProblem(js2b, dprob, Direct(), save_positions = (false, false), rng = rng)
sol = solve(jprob, SSAStepper(), saveat = tspan[2] / 10)
@test all(2 .* sol[S] .== sol[S2])

# test save_positions is working

jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false))
jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false), rng = rng)
sol = solve(jprob, SSAStepper(), saveat = 1.0)
@test all((sol.t) .== collect(0.0:tspan[2]))

Expand Down Expand Up @@ -120,7 +122,7 @@ function a2!(integrator)
end
j2 = ConstantRateJump(r2, a2!)
jset = JumpSet((), (j1, j2), nothing, nothing)
jprob = JumpProblem(prob, Direct(), jset, save_positions = (false, false))
jprob = JumpProblem(prob, Direct(), jset, save_positions = (false, false), rng = rng)
m2 = getmean(jprob, Nsims)

# test JumpSystem solution agrees with direct version
Expand All @@ -131,16 +133,16 @@ maj1 = MassActionJump(2 * β / 2, [S => 1, I => 1], [S => -1, I => 1])
maj2 = MassActionJump(γ, [I => 1], [I => -1, R => 1])
@named js3 = JumpSystem([maj1, maj2], t, [S, I, R], [β, γ])
dprob = DiscreteProblem(js3, u₀map, tspan, parammap)
jprob = JumpProblem(js3, dprob, Direct())
jprob = JumpProblem(js3, dprob, Direct(), rng = rng)
m3 = getmean(jprob, Nsims)
@test abs(m - m3) / m < 0.01

# maj jump test with various dep graphs
@named js3b = JumpSystem([maj1, maj2], t, [S, I, R], [β, γ])
jprobb = JumpProblem(js3b, dprob, NRM())
jprobb = JumpProblem(js3b, dprob, NRM(), rng = rng)
m4 = getmean(jprobb, Nsims)
@test abs(m - m4) / m < 0.01
jprobc = JumpProblem(js3b, dprob, RSSA())
jprobc = JumpProblem(js3b, dprob, RSSA(), rng = rng)
m4 = getmean(jprobc, Nsims)
@test abs(m - m4) / m < 0.01

Expand All @@ -149,7 +151,7 @@ maj1 = MassActionJump(2.0, [0 => 1], [S => 1])
maj2 = MassActionJump(γ, [S => 1], [S => -1])
@named js4 = JumpSystem([maj1, maj2], t, [S], [β, γ])
dprob = DiscreteProblem(js4, [S => 999], (0, 1000.0), [β => 100.0, γ => 0.01])
jprob = JumpProblem(js4, dprob, Direct())
jprob = JumpProblem(js4, dprob, Direct(), rng = rng)
m4 = getmean(jprob, Nsims)
@test abs(m4 - 2.0 / 0.01) * 0.01 / 2.0 < 0.01

Expand All @@ -158,7 +160,7 @@ maj1 = MassActionJump(2.0, [0 => 1], [S => 1])
maj2 = MassActionJump(γ, [S => 2], [S => -1])
@named js4 = JumpSystem([maj1, maj2], t, [S], [β, γ])
dprob = DiscreteProblem(js4, [S => 999], (0, 1000.0), [β => 100.0, γ => 0.01])
jprob = JumpProblem(js4, dprob, Direct())
jprob = JumpProblem(js4, dprob, Direct(), rng = rng)
sol = solve(jprob, SSAStepper());

# issue #819
Expand All @@ -179,7 +181,7 @@ p = [k1 => 2.0, k2 => 0.0, k3 => 0.5]
u₀ = [A => 100, B => 0]
tspan = (0.0, 2000.0)
dprob = DiscreteProblem(js5, u₀, tspan, p)
jprob = JumpProblem(js5, dprob, Direct(), save_positions = (false, false))
jprob = JumpProblem(js5, dprob, Direct(), save_positions = (false, false), rng = rng)
@test all(jprob.massaction_jump.scaled_rates .== [1.0, 0.0])

pcondit(u, t, integrator) = t == 1000.0
Expand Down
32 changes: 31 additions & 1 deletion test/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,13 @@ l15 0" stroke="black" stroke-width="1" stroke-linejoin="bevel" fill="none"></pat
end

@mtkmodel Capacitor begin
@extend v, i = oneport = OnePort()
@parameters begin
C
end
@variables begin
v = 0.0
end
@extend v, i = oneport = OnePort(; v = v)
@icon "https://upload.wikimedia.org/wikipedia/commons/7/78/Capacitor_symbol.svg"
@equations begin
D(v) ~ i / C
Expand Down Expand Up @@ -182,3 +185,30 @@ model = complete(model)
@test getdefault(model.i) == 4
@test isequal(getdefault(model.j), model.jval)
@test isequal(getdefault(model.k), model.kval)

@mtkmodel A begin
@parameters begin
p
end
@components begin
b = B(i = p, j = 1 / p, k = 1)
end
end

@mtkmodel B begin
@parameters begin
i
j
k
end
end

@named a = A(p = 10)
getdefault(a.b.i) == 10
getdefault(a.b.j) == 0.1
getdefault(a.b.k) == 1

@named a = A(p = 10, b.i = 20, b.j = 30, b.k = 40)
getdefault(a.b.i) == 20
getdefault(a.b.j) == 30
getdefault(a.b.k) == 40

0 comments on commit c6af68d

Please sign in to comment.