Skip to content

Commit

Permalink
Handle noise parameter default values
Browse files Browse the repository at this point in the history
  • Loading branch information
TorkelE committed Sep 18, 2023
1 parent 4e4c3b3 commit d132d25
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 14 deletions.
8 changes: 5 additions & 3 deletions src/networkapi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -581,22 +581,24 @@ function symmap_to_varmap(sys, symmap::Tuple)
if all(p -> p isa Pair{Symbol}, symmap)
return ((_symbol_to_var(sys, sym) => val for (sym, val) in symmap)...,)
else # if not all entries map a symbol to value pass through
return symmapAny
return symmap
end
end

function symmap_to_varmap(sys, symmap::AbstractArray{Pair{Any, T}}) where {T}
function symmap_to_varmap(sys, symmap::AbstractArray{Pair{T, V}}) where {T, V}
[_symbol_to_var(sys, sym) => val for (sym, val) in symmap]
end

function symmap_to_varmap(sys, symmap::Dict{Any, T}) where {T}
function symmap_to_varmap(sys, symmap::Dict{T, V}) where {T, V}
Dict(_symbol_to_var(sys, sym) => val for (sym, val) in symmap)
end

# don't permute any other types and let varmap_to_vars handle erroring.
# If all keys are `Num`, conversion not needed.
symmap_to_varmap(sys, symmap) = symmap
symmap_to_varmap(sys, symmap::AbstractArray{Pair{SymbolicUtils.BasicSymbolic{Real}, T}}) where {T} = symmap
symmap_to_varmap(sys, symmap::AbstractArray{Pair{Num, T}}) where {T} = symmap
symmap_to_varmap(sys, symmap::Dict{SymbolicUtils.BasicSymbolic{Real}, T}) where {T} = symmap
symmap_to_varmap(sys, symmap::Dict{Num, T}) where {T} = symmap

Check warning on line 602 in src/networkapi.jl

View check run for this annotation

Codecov / codecov/patch

src/networkapi.jl#L601-L602

Added lines #L601 - L602 were not covered by tests
#error("symmap_to_varmap requires a Dict, AbstractArray or Tuple to map Symbols to values.")

Expand Down
12 changes: 8 additions & 4 deletions src/reaction_network.jl
Original file line number Diff line number Diff line change
Expand Up @@ -545,10 +545,14 @@ function get_noise_scaling_pexpr(options)
haskey(options, :noise_scaling_parameters) || return []
ns_expr = options[:noise_scaling_parameters]
for idx = length(ns_expr.args):-1:3
if ns_expr.args[idx] isa Symbol
insert!(ns_expr.args, idx+1, :([noisescalingparameter=true]))
elseif (ns_expr.args[idx] isa Expr) && (ns_expr.args[idx].head == :ref)
insert!(ns_expr.args, idx+1, :([noisescalingparameter=true]))
if (ns_expr.args[idx] isa Symbol) || # Parameter on form η.
(ns_expr.args[idx] isa Expr) && (ns_expr.args[idx].head == :ref) || # Parameter on form η[1:3].
(ns_expr.args[idx] isa Expr) && (ns_expr.args[idx].head == :(=)) # Parameter on form η=0.1.
if idx < length(ns_expr.args) && (ns_expr.args[idx+1] isa Expr) && (ns_expr.args[idx+1].head == :vect)
push!(ns_expr.args[idx+1].args,:(noisescalingparameter=true))
else
insert!(ns_expr.args, idx+1, :([noisescalingparameter=true]))
end
end
end
return ns_expr.args[3:end]
Expand Down
17 changes: 10 additions & 7 deletions test/model_simulation/simulate_SDEs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,32 +167,35 @@ let
@test var(sol_2_1[1,:]) > var(sol_2_2[1,:]) > var(sol_2_3[1,:])
end

# Tests using default values for nosie scaling.
# Tests using default values for noise scaling.
let
noise_scaling_network = @reaction_network begin
@noise_scaling_parameters η=0.0
(k1, k2), X1 X2
end
u0 = [:X1 => 1100.0, :X2 => 2900.0]
p = [:k1 => 2.0, :k2 => 0.66]
u0 = [:X1 => 1100.0, :X2 => 3900.0]
p = [:k1 => 2.0, :k2 => 0.5, =>0.0]
ss = solve(SDEProblem(noise_scaling_network, u0, (0.0, 1000.0), p), ImplicitEM())[end]
@test ss [1000.0, 3000.0]
@test ss [1000.0, 4000.0]
end

# Complicated test with many combinations of options.
let
noise_scaling_network = @reaction_network begin
@parameters k1 par1 [description="Parameter par1"] par2 η1 [noisescalingparameter=true]
@noise_scaling_parameters η2=0.0 [description="Parameter η2"] η3=1.0 η4 [description="Parameter η4"]
@noise_scaling_parameters η2=0.0 [description="Parameter η2"] η3=1.0 η4
(p, d), 0 X1
(k1, k2), X1 X2
end
@unpack X1, η4 = noise_scaling_network
@unpack X1, η4, p = noise_scaling_network
u0 = [X1 => 500.0, :X2 => 500.0]
p = [p => 20.0, :d => 0.1, :η1 => 0.0, :η3 => 0.0, η4 => 0.0, :k1 => 2.0, :k2 => 2.0, :par1 => 1000.0, :par2 => 1000.0]

@test getdescription(parameters(noise_scaling_network)[2]) == "Parameter par1"
@test getdescription(parameters(noise_scaling_network)[8]) == "Parameter η2"

ss = solve(SDEProblem(noise_scaling_network, u0, (0.0, 1000.0), p), ImplicitEM())[end]
ss [200.0, 200.0]
@test ss [200.0, 200.0]
end

# Tests that nosie scaling wor
Expand Down
20 changes: 20 additions & 0 deletions test/model_simulation/u0_n_parameter_inputs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,23 @@ let
end
end
end

# Tests uding mix of symbols and symbolics in input.
let
test_network = @reaction_network begin
(p1, d1), 0 X1
(p2, d2), 0 X2
end
@unpack p1, d1, p2, d2, X1, X2 = test_network
u0_1 = [X1 => 0.7, X2 => 3.6]
u0_2 = [:X1 => 0.7, X2 => 3.6]
u0_3 = [:X1 => 0.7, :X2 => 3.6]
p_1 = [p1 => 1.2, d1 => 4.0, p2 => 2.5, d2 =>0.1]
p_2 = [:p1 => 1.2, d1 => 4.0, :p2 => 2.5, d2 =>0.1]
p_3 = [:p1 => 1.2, :d1 => 4.0, :p2 => 2.5, :d2 =>0.1]

ss_base = solve(ODEProblem(test_network, u0_1, (0.0, 10.0), p_1), Tsit5())[end]
for u0 in [u0_1, u0_2, u0_3], p in [p_1, p_2, p_3]
@test ss_base == solve(ODEProblem(test_network, u0, (0.0, 10.0), p), Tsit5())[end]
end
end

0 comments on commit d132d25

Please sign in to comment.