Skip to content

Commit

Permalink
Merge pull request #2196 from ven-k/vkb/kw
Browse files Browse the repository at this point in the history
Pass arguments (including hierarchal ones) to `Model`s
  • Loading branch information
ChrisRackauckas authored Jun 26, 2023
2 parents 50cf8f9 + fb7d334 commit a864541
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 51 deletions.
2 changes: 2 additions & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ import SymbolicUtils.Code: toexpr
import SymbolicUtils.Rewriters: Chain, Postwalk, Prewalk, Fixpoint
import JuliaFormatter

using MLStyle

using Reexport
using Symbolics: degree
@reexport using Symbolics
Expand Down
28 changes: 28 additions & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,32 @@ function split_assign(expr)
name, call = expr.args
end

varname_fix!(s) = return

function varname_fix!(expr::Expr)
for arg in expr.args
MLStyle.@match arg begin
::Symbol => continue
Expr(:kw, a) => varname_sanitization!(arg)
Expr(:parameters, a...) => begin
for _arg in arg.args
varname_sanitization!(_arg)
end
end
_ => @debug "skipping variable sanitization of $arg"
end
end
end

varname_sanitization!(a) = return

function varname_sanitization!(expr::Expr)
var_splits = split(string(expr.args[1]), ".")
if length(var_splits) > 1
expr.args[1] = Symbol(join(var_splits, "__"))
end
end

function _named(name, call, runtime = false)
has_kw = false
call isa Expr || throw(Meta.ParseError("The rhs must be an Expr. Got $call."))
Expand All @@ -948,6 +974,8 @@ function _named(name, call, runtime = false)
end
end

varname_fix!(call)

if !has_kw
param = Expr(:parameters)
if length(call.args) == 1
Expand Down
187 changes: 144 additions & 43 deletions src/systems/model_parsing.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
macro connector(name::Symbol, body)
esc(connector_macro(__module__, name, body))
end

struct Model{F, S}
f::F
structure::S
end
(m::Model)(args...; kw...) = m.f(args...; kw...)

using MLStyle
for f in (:connector, :model)
@eval begin
macro $f(name::Symbol, body)
esc($(Symbol(f, :_macro))(__module__, name, body))
end
end
end

@inline is_kwarg(::Symbol) = false
@inline is_kwarg(e::Expr) = (e.head == :parameters)

function connector_macro(mod, name, body)
function connector_macro(mod, name, body; arglist = Set([]), kwargs = Set([]))
if !Meta.isexpr(body, :block)
err = """
connector body must be a block! It should be in the form of
Expand All @@ -23,16 +28,18 @@ function connector_macro(mod, name, body)
"""
error(err)
end
vs = Num[]
vs = []
icon = Ref{Union{String, URI}}()
dict = Dict{Symbol, Any}()
dict[:kwargs] = Dict{Symbol, Any}()
expr = Expr(:block)
for arg in body.args
arg isa LineNumberNode && continue
if arg.head == :macrocall && arg.args[1] == Symbol("@icon")
parse_icon!(icon, dict, dict, arg.args[end])
continue
end
push!(vs, Num(parse_variable_def!(dict, mod, arg, :variables)))
parse_variable_arg!(expr, vs, dict, mod, arg, :variables, kwargs)
end
iv = get(dict, :independent_variable, nothing)
if iv === nothing
Expand All @@ -41,31 +48,50 @@ function connector_macro(mod, name, body)
gui_metadata = isassigned(icon) ? GUIMetadata(GlobalRef(mod, name), icon[]) :
nothing
quote
$name = $Model((; name) -> begin
var"#___sys___" = $ODESystem($(Equation[]), $iv, $vs, $([]);
$name = $Model(($(arglist...); name, $(kwargs...)) -> begin
$expr
var"#___sys___" = $ODESystem($(Equation[]), $iv, [$(vs...)], $([]);
name, gui_metadata = $gui_metadata)
$Setfield.@set!(var"#___sys___".connector_type=$connector_type(var"#___sys___"))
end, $dict)
end
end

function parse_variable_def!(dict, mod, arg, varclass)
function parse_variable_def!(dict, mod, arg, varclass, kwargs, def = nothing)
arg isa LineNumberNode && return
MLStyle.@match arg begin
::Symbol => generate_var!(dict, arg, varclass)
Expr(:call, a, b) => generate_var!(dict, a, b, varclass)
a::Symbol => begin
push!(kwargs, Expr(:kw, a, nothing))
var = generate_var!(dict, a, varclass)
dict[:kwargs][getname(var)] = def
(var, nothing)
end
Expr(:call, a, b) => begin
push!(kwargs, Expr(:kw, a, nothing))
var = generate_var!(dict, a, b, varclass)
dict[:kwargs][getname(var)] = def
(var, nothing)
end
Expr(:(=), a, b) => begin
var = parse_variable_def!(dict, mod, a, varclass)
def = parse_default(mod, b)
Base.remove_linenums!(b)
def, meta = parse_default(mod, b)
var, _ = parse_variable_def!(dict, mod, a, varclass, kwargs, def)
dict[varclass][getname(var)][:default] = def
setdefault(var, def)
if !isnothing(meta)
if (ct = get(meta, VariableConnectType, nothing)) !== nothing
dict[varclass][getname(var)][:connection_type] = nameof(ct)
end
var = set_var_metadata(var, meta)
end
(var, def)
end
Expr(:tuple, a, b) => begin
var = parse_variable_def!(dict, mod, a, varclass)
var, def = parse_variable_def!(dict, mod, a, varclass, kwargs)
meta = parse_metadata(mod, b)
if (ct = get(meta, VariableConnectType, nothing)) !== nothing
dict[varclass][getname(var)][:connection_type] = nameof(ct)
end
set_var_metadata(var, meta)
(set_var_metadata(var, meta), def)
end
_ => error("$arg cannot be parsed")
end
Expand All @@ -78,14 +104,17 @@ function generate_var(a, varclass)
end
var
end

function generate_var!(dict, a, varclass)
#var = generate_var(Symbol("#", a), varclass)
var = generate_var(a, varclass)
vd = get!(dict, varclass) do
Dict{Symbol, Dict{Symbol, Any}}()
end
vd[a] = Dict{Symbol, Any}()
var
end

function generate_var!(dict, a, b, varclass)
iv = generate_var(b, :variables)
prev_iv = get!(dict, :independent_variable) do
Expand All @@ -102,77 +131,101 @@ function generate_var!(dict, a, b, varclass)
end
var
end

function parse_default(mod, a)
a = Base.remove_linenums!(deepcopy(a))
MLStyle.@match a begin
Expr(:block, a) => get_var(mod, a)
::Symbol => get_var(mod, a)
::Number => a
Expr(:block, x) => parse_default(mod, x)
Expr(:tuple, x, y) => begin
def, _ = parse_default(mod, x)
meta = parse_metadata(mod, y)
(def, meta)
end
::Symbol || ::Number => (a, nothing)
Expr(:call, a...) => begin
def = parse_default.(Ref(mod), a)
expr = Expr(:call)
for (d, _) in def
push!(expr.args, d)
end
(expr, nothing)
end
_ => error("Cannot parse default $a")
end
end

function parse_metadata(mod, a)
MLStyle.@match a begin
Expr(:vect, eles...) => Dict(parse_metadata(mod, e) for e in eles)
Expr(:(=), a, b) => Symbolics.option_to_metadata_type(Val(a)) => get_var(mod, b)
_ => error("Cannot parse metadata $a")
end
end

function set_var_metadata(a, ms)
for (m, v) in ms
a = setmetadata(a, m, v)
end
a
end

function get_var(mod::Module, b)
b isa Symbol ? getproperty(mod, b) : b
end

macro model(name::Symbol, expr)
esc(model_macro(__module__, name, expr))
end

function model_macro(mod, name, expr)
function model_macro(mod, name, expr; arglist = Set([]), kwargs = Set([]))
exprs = Expr(:block)
dict = Dict{Symbol, Any}()
dict[:kwargs] = Dict{Symbol, Any}()
comps = Symbol[]
ext = Ref{Any}(nothing)
vs = Symbol[]
ps = Symbol[]
eqs = Expr[]
icon = Ref{Union{String, URI}}()
vs = []
ps = []

for arg in expr.args
arg isa LineNumberNode && continue
arg.head == :macrocall || error("$arg is not valid syntax. Expected a macro call.")
parse_model!(exprs.args, comps, ext, eqs, vs, ps, icon, dict, mod, arg)
if arg.head == :macrocall
parse_model!(exprs.args, comps, ext, eqs, icon, vs, ps,
dict, mod, arg, kwargs)
elseif arg.head == :block
push!(exprs.args, arg)
else
error("$arg is not valid syntax. Expected a macro call.")
end
end
iv = get(dict, :independent_variable, nothing)
if iv === nothing
iv = dict[:independent_variable] = variable(:t)
end

gui_metadata = isassigned(icon) > 0 ? GUIMetadata(GlobalRef(mod, name), icon[]) :
nothing

sys = :($ODESystem($Equation[$(eqs...)], $iv, [$(vs...)], [$(ps...)];
systems = [$(comps...)], name, gui_metadata = $gui_metadata))
systems = [$(comps...)], name, gui_metadata = $gui_metadata)) #, defaults = $defaults))
if ext[] === nothing
push!(exprs.args, sys)
else
push!(exprs.args, :($extend($sys, $(ext[]))))
end
:($name = $Model((; name) -> $exprs, $dict))

:($name = $Model(($(arglist...); name, $(kwargs...)) -> $exprs, $dict))
end

function parse_model!(exprs, comps, ext, eqs, vs, ps, icon, dict, mod, arg)
function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, dict,
mod, arg, kwargs)
mname = arg.args[1]
body = arg.args[end]
if mname == Symbol("@components")
parse_components!(exprs, comps, dict, body)
parse_components!(exprs, comps, dict, body, kwargs)
elseif mname == Symbol("@extend")
parse_extend!(exprs, ext, dict, body)
elseif mname == Symbol("@variables")
parse_variables!(exprs, vs, dict, mod, body, :variables)
parse_variables!(exprs, vs, dict, mod, body, :variables, kwargs)
elseif mname == Symbol("@parameters")
parse_variables!(exprs, ps, dict, mod, body, :parameters)
parse_variables!(exprs, ps, dict, mod, body, :parameters, kwargs)
elseif mname == Symbol("@equations")
parse_equations!(exprs, eqs, dict, body)
elseif mname == Symbol("@icon")
Expand All @@ -182,7 +235,7 @@ function parse_model!(exprs, comps, ext, eqs, vs, ps, icon, dict, mod, arg)
end
end

function parse_components!(exprs, cs, dict, body)
function parse_components!(exprs, cs, dict, body, kwargs)
expr = Expr(:block)
push!(exprs, expr)
comps = Vector{String}[]
Expand All @@ -194,6 +247,9 @@ function parse_components!(exprs, cs, dict, body)
push!(comps, [String(a), String(b.args[1])])
arg = deepcopy(arg)
b = deepcopy(arg.args[2])

component_args!(a, b, expr, kwargs)

push!(b.args, Expr(:kw, :name, Meta.quot(a)))
arg.args[2] = b
push!(expr.args, arg)
Expand All @@ -204,6 +260,46 @@ function parse_components!(exprs, cs, dict, body)
dict[:components] = comps
end

function _rename(compname, varname)
compname = Symbol(compname, :__, varname)
end

function component_args!(a, b, expr, kwargs)
# Whenever `b` is a function call, skip the first arg aka the function name.
# Whenver it is a kwargs list, include it.
start = b.head == :call ? 2 : 1
for i in start:lastindex(b.args)
arg = b.args[i]
arg isa LineNumberNode && continue
MLStyle.@match arg begin
::Symbol => begin
_v = _rename(a, arg)
push!(kwargs, _v)
b.args[i] = Expr(:kw, arg, _v)
end
Expr(:parameters, x...) => begin
component_args!(a, arg, expr, kwargs)
end
Expr(:kw, x) => begin
_v = _rename(a, x)
b.args[i] = Expr(:kw, x, _v)
push!(kwargs, _v)
end
Expr(:kw, x, y::Number) => begin
_v = _rename(a, x)
b.args[i] = Expr(:kw, x, _v)
push!(kwargs, Expr(:kw, _v, y))
end
Expr(:kw, x, y) => begin
_v = _rename(a, x)
push!(expr.args, :($y = $_v))
push!(kwargs, Expr(:kw, _v, y))
end
_ => error("Could not parse $arg of component $a")
end
end
end

function parse_extend!(exprs, ext, dict, body)
expr = Expr(:block)
push!(exprs, expr)
Expand Down Expand Up @@ -231,16 +327,21 @@ function parse_extend!(exprs, ext, dict, body)
end
end

function parse_variables!(exprs, vs, dict, mod, body, varclass)
function parse_variable_arg!(expr, vs, dict, mod, arg, varclass, kwargs)
vv, def = parse_variable_def!(dict, mod, arg, varclass, kwargs)
v = Num(vv)
name = getname(v)
push!(vs, name)
push!(expr.args,
:($name = $name === nothing ? $setdefault($vv, $def) : $setdefault($vv, $name)))
end

function parse_variables!(exprs, vs, dict, mod, body, varclass, kwargs)
expr = Expr(:block)
push!(exprs, expr)
for arg in body.args
arg isa LineNumberNode && continue
vv = parse_variable_def!(dict, mod, arg, varclass)
v = Num(vv)
name = getname(v)
push!(vs, name)
push!(expr.args, :($name = $v))
parse_variable_arg!(expr, vs, dict, mod, arg, varclass, kwargs)
end
end

Expand Down
Loading

0 comments on commit a864541

Please sign in to comment.