Skip to content

Commit

Permalink
improve ienterface for changing paraemters
Browse files Browse the repository at this point in the history
  • Loading branch information
TorkelE committed Jul 11, 2024
1 parent baa0f81 commit c626fcb
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 101 deletions.
2 changes: 1 addition & 1 deletion src/Catalyst.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ include("spatial_reaction_systems/utility.jl")

# Methods for interfacing with from LatticeReactionSystem derived problems, integrators, and solutions.
include("spatial_reaction_systems/lattice_sim_struct_interfacing.jl")
export lat_getu, lat_setu!, rebuild_lat_internals!
export lat_getp, lat_setp!, lat_getu, lat_setu!, rebuild_lat_internals!

### ReactionSystem Serialisation ###
# Has to be at the end (because it uses records of all metadata declared by Catalyst).
Expand Down
78 changes: 49 additions & 29 deletions src/spatial_reaction_systems/lattice_sim_struct_interfacing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,39 +39,46 @@ lat_setp!(oprob, :k2, lrs, [1.0 0.0 0.0; 0.0 0.0 0.0]) # Sets `k2` to `1.0` in o
```
"""
function lat_setp!(sim_struct, p, lrs::LatticeReactionSystem, p_vals)
# Checks that if u is non-uniform, it has the correct format for the system's lattice.
(u isa Number) || check_lattice_format(extract_lattice(lrs), p_vals)
# Error checks.
(p_vals isa Number) || check_lattice_format(extract_lattice(lrs), p_vals)
edge_param_check(p, lrs)

# Converts symbol parameter to symbolic and find correct species index and numbers.
(p isa Symbol) && (p = _symbol_to_var(lrs, p))
(p isa Num) && (p = Symbolics.unwrap(p))
p_idx, p_tot = get_p_idxs(sp, lrs)
p_idx, p_tot = get_p_idxs(p, lrs)

# Reshapes the values to a vector of the correct form, and calls lat_setu! on the input structure.
# Reshapes the values to a vector of the correct form, and calls lat_setp! on the input structure.
p_vals_reshaped = vertex_value_form(p_vals, lrs, p)
lat_setp!(sim_struct, p_idx, p_tot, p_vals_reshaped, num_verts(lrs))
lat_setp!(sim_struct, p_idx, p_vals_reshaped, num_verts(lrs))
end

function lat_setp!(oprob::ODEProblem, p_idx::Int64, p_tot::Int64, p_vals, num_verts)
# Note: currently, `lat_setp!(oprob::ODEProblem, ...`) and `lat_setp!(SciMLBase.AbstractODEIntegrator, ...`)
# are identical and could be merged to a singe function.
function lat_setp!(oprob::ODEProblem, p_idx::Int64, p_vals, num_verts)
if length(p_vals) == 1
foreach(idx -> (oprob.ps[p_idx + (idx - 1) * p_tot] = p_vals[1]), 1:num_verts)
else
foreach(idx -> (oprob.ps[p_idx + (idx - 1) * p_tot] = p_vals[idx]), 1:num_verts)
foreach(idx -> (oprob.p[p_idx][idx] = p_vals[1]), 1:num_verts)
elseif length(p_vals) == length(oprob.p[p_idx])
foreach(idx -> (oprob.p[p_idx][idx] = p_vals[idx]), 1:num_verts)
elseif length(oprob.p[p_idx]) == 1
oprob.p[p_idx][1] = p_vals[1]
foreach(idx -> (push!(oprob.p[p_idx], p_vals[idx])), 2:num_verts)
end
end
function lat_setp!(jprob::JumpProblem, p_idx::Int64, p_tot::Int64, p_vals, num_verts)
function lat_setp!(jprob::JumpProblem, p_idx::Int64, p_vals, num_verts)
error("The `lat_setp!` function is currently not supported for `JumpProblem`s.")
end
function lat_setp!(oint::SciMLBase.AbstractODEIntegrator, p_idx::Int64, p_tot::Int64,
p_vals, num_verts)
function lat_setp!(oint::SciMLBase.AbstractODEIntegrator, p_idx::Int64, p_vals, num_verts)
if length(p_vals) == 1
foreach(idx -> (oint.ps[p_idx + (idx - 1) * p_tot] = p_vals[1]), 1:num_verts)
else
foreach(idx -> (oint.ps[p_idx + (idx - 1) * p_tot] = p_vals[idx]), 1:num_verts)
foreach(idx -> (oint.p[p_idx][idx] = p_vals[1]), 1:num_verts)
elseif length(p_vals) == length(oint.p[p_idx])
foreach(idx -> (oint.p[p_idx][idx] = p_vals[idx]), 1:num_verts)
elseif length(oint.p[p_idx]) == 1
oint.p[p_idx][1] = p_vals[1]
foreach(idx -> (push!(oint.p[p_idx], p_vals[idx])), 2:num_verts)
end
end
function lat_setp!(jint::JumpProcesses.SSAIntegrator, p_idx::Int64, p_tot::Int64,
p_vals, num_verts)
function lat_setp!(jint::JumpProcesses.SSAIntegrator, p_idx::Int64, p_vals, num_verts)
error("The `lat_setp!` function is currently not supported for jump simulation integrators.")
end

Expand Down Expand Up @@ -117,22 +124,27 @@ lat_getp(oprob, :k1, lrs) # Retrieves the value of `k1`.
```
"""
function lat_getp(sim_struct, p, lrs::LatticeReactionSystem)
p_idx, p_tot = get_p_idxs(p, lrs)
lat_getp(sim_struct, p_idx, p_tot, extract_lattice(lrs))
edge_param_check(p, lrs)
p_idx, _ = get_p_idxs(p, lrs)
lat_getp(sim_struct, p_idx, extract_lattice(lrs), num_verts(lrs))
end

# Retrieves the lattice values for problem or integrator structures.
function lat_getp(oprob::ODEProblem, p_idx, p_tot, lattice)
return reshape_vals(oprob.u0[p_idx:p_tot:end], lattice)
function lat_getp(oprob::ODEProblem, p_idx::Int64, lattice, num_verts)
vals = oprob.p[p_idx]
(length(vals) == 1) && (vals = fill(vals[1], num_verts))
return reshape_vals(vals, lattice)
end
function lat_getp(jprob::JumpProblem, p_idx, p_tot, lattice)
return reshape_vals(jprob.prob.u0[p_idx, :], lattice)
function lat_getp(jprob::JumpProblem, p_idx::Int64, lattice, num_verts)
error("The `lat_getp` function is currently not supported for `JumpProblem`s.")
end
function lat_getp(oint::SciMLBase.AbstractODEIntegrator, p_idx, p_tot, lattice)
return reshape_vals(oint.u[p_idx:p_tot:end], lattice)
function lat_getp(oint::SciMLBase.AbstractODEIntegrator, p_idx::Int64, lattice, num_verts)
vals = oint.p[p_idx]
(length(vals) == 1) && (vals = fill(vals[1], num_verts))
return reshape_vals(vals, lattice)
end
function lat_getp(jint::JumpProcesses.SSAIntegrator, p_idx, p_tot, lattice)
return reshape_vals(jint.u[p_idx, :], lattice)
function lat_getp(jint::JumpProcesses.SSAIntegrator, p_idx::Int64, lattice, num_verts)
error("The `lat_getp` function is currently not supported for jump simulation integrators.")
end

"""
Expand Down Expand Up @@ -533,8 +545,8 @@ end
# Get a parameter index and the total number of parameters. Also handles different symbolic forms.
function get_p_idxs(p, lrs::LatticeReactionSystem)
(p isa Symbol) && (p = _symbol_to_var(lrs, p))
p_idx = findfirst(isequal(sp), parameters(lrs))
p_tot = length(speciparameterses(lrs))
p_idx = findfirst(isequal(p), parameters(lrs))
p_tot = length(parameters(lrs))
return p_idx, p_tot
end

Expand All @@ -556,4 +568,12 @@ function check_lattice_format(lattice::DiGraph, u)
error("The input u should be an AbstractVector. It is a $(typeof(u)).")
(length(u) == nv(lattice)) ||
error("The input u should have length $(nv(lattice)), but has length $(length(u)).")
end

# Throws an error when interfacing with an edge parameter.
function edge_param_check(p, lrs)
(p isa Symbol) && (p = _symbol_to_var(lrs, p))
if isedgeparameter(p)
throw(ArgumentError("The `lat_getp` and `lat_setp!` functions currently does not support edge parameter updating. If you require this functionality, please raise an issue on the Catalyst GitHub page and we can add this feature."))
end
end
152 changes: 81 additions & 71 deletions test/spatial_modelling/lattice_simulation_struct_interfacing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,68 +6,7 @@ using Catalyst, Graphs, JumpProcesses, OrdinaryDiffEq, SparseArrays, Test
# Fetch test networks.
include("../spatial_test_networks.jl")

### Basic Interfacing ###

# Checks that basic interfacing with ODEProblem parameters (getting and setting) works.
# Checks that basic interfacing with ODE integrators parameters (getting and setting) works.
let
# Creates an initial `ODEProblem` and integrator.
lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_1, small_1d_cartesian_grid)
u0 = [:X => 1.0, :Y => 2.0]
ps = [:A => 1.0, :B => [1.0, 2.0, 3.0, 4.0, 5.0], :dX => 0.1]
oprob = ODEProblem(lrs, u0, (0.0, 10.0), ps)
oint = init(oprob, Tsit5())

# Checks that retrieved parameters are correct.
@test oprob.ps[:A] == [1.0]
@test oprob.ps[:B] == [1.0, 2.0, 3.0, 4.0, 5.0]
@test oprob.ps[:dX] == sparse([1], [1], [0.1])

# Updates content.
oprob.ps[:A] = [10.0, 20.0, 30.0, 40.0, 50.0]
oprob.ps[:B] = [10.0]
oprob.ps[:dX] = [0.01]

# Checks that content is correct.
@test oprob.ps[:A] == [10.0, 20.0, 30.0, 40.0, 50.0]
@test oprob.ps[:B] == [10.0]
@test oprob.ps[:dX] == [0.01]

# Checks that the integrator have the updated `ODEProblem` parameter (not sure if this is really desired though).
@test oint.ps[:A] == [10.0, 20.0, 30.0, 40.0, 50.0]
@test oint.ps[:B] == [10.0]
@test oint.ps[:dX] == [0.01]

# Updates content.
oint.ps[:A] = [5.0]
oint.ps[:B] = [0.5]
oint.ps[:dX] = fill(0.2, 5, 5)

# Checks that content is correct.
@test oint.ps[:A] == [5.0]
@test oint.ps[:B] == [0.5]
@test oint.ps[:dX] == fill(0.2, 5, 5)
end

# Checks normal interfacing for jump simulation structures. Currently does not work, and implementing
# this might be a major endeavour. This test primarily keeps track of it not working.
let
# Creates an initial `ODEProblem` and integrator.
lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_1, small_1d_cartesian_grid)
u0 = [:X => 1, :Y => 2]
ps = [:A => 1.0, :B => [1.0, 2.0, 3.0, 4.0, 5.0], :dX => 0.1]
dprob = DiscreteProblem(lrs, u0, (0.0, 10.0), ps)
jprob = JumpProblem(lrs, dprob, NSM())
jint = init(jprob, SSAStepper())

# Make basic checks (this features does not currently work).
@test_broken jprob.ps[:A]
@test_broken jprob.ps[:A] = [1.0]
@test_broken jint.ps[:A]
@test_broken jint.ps[:A] = [1.0]
end

### Problem & Integrator `lat_getu` & `lat_setu!` Tests ###
### Problem & Integrator Interfacing Function Tests ###

# Checks `lat_getu` for ODE and Jump problem and integrators.
# Checks `lat_setu!` for ODE and Jump problem and integrators.
Expand All @@ -85,8 +24,8 @@ let
# Unpacks the `X` and `Y` symbolic variable (so that indexing using it can be tested).
@unpack X, Y = brusselator_system

# Loops through all alternative lattices and `X0`. Checks that `lat_getu` works in all cases.
for (lattice, val0) in zip([lattice_cartesian, lattice_masked, lattice_graph],[val0_cartesian, val0_masked, val0_graph])
# Loops through all alternative lattices and values. Checks that `lat_getu` works in all cases.
for (lattice, val0) in zip([lattice_cartesian, lattice_masked, lattice_graph], [val0_cartesian, val0_masked, val0_graph])
# Prepares various problems and integrators. Uses `deepcopy` to ensure there is no cross-talk
# between the different u vectors as they get updated.
lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_1, lattice)
Expand All @@ -108,11 +47,11 @@ let
lat_setu!(oprob, :Y, lrs, val0)
@test lat_getu(oprob, :Y, lrs) == lat_getu(oprob, Y, lrs) == lat_getu(oprob, brusselator_system.Y, lrs) == val0
lat_setu!(oint, :Y, lrs, val0)
@test lat_getu(oint, :X, lrs) == lat_getu(oint, X, lrs) == lat_getu(oint, brusselator_system.X, lrs) == val0
@test lat_getu(oint, :Y, lrs) == lat_getu(oint, Y, lrs) == lat_getu(oint, brusselator_system.Y, lrs) == val0
lat_setu!(jprob, :Y, lrs, val0)
@test lat_getu(jprob, :X, lrs) == lat_getu(jprob, X, lrs) == lat_getu(jprob, brusselator_system.X, lrs) == val0
@test lat_getu(jprob, :Y, lrs) == lat_getu(jprob, Y, lrs) == lat_getu(jprob, brusselator_system.Y, lrs) == val0
lat_setu!(jint, :Y, lrs, val0)
@test lat_getu(jint, :X, lrs) == lat_getu(jint, X, lrs) == lat_getu(jint, brusselator_system.X, lrs) == val0
@test lat_getu(jint, :Y, lrs) == lat_getu(jint, Y, lrs) == lat_getu(jint, brusselator_system.Y, lrs) == val0

# Tries where we change a spatially non-uniform variable to spatially uniform.
lat_setu!(oprob, X, lrs, 0.0)
Expand All @@ -126,6 +65,77 @@ let
end
end

# Checks `lat_getp` for ODEproblem and integrators.
# Checks `lat_setp!` for ODE problem and integrators.
# Checks for all types of lattices.
# Checks for symbol and symbolic variables input.
let
# Declares various types of lattices and corresponding initial values of `A`.
lattice_cartesian = CartesianGrid((2,2,2))
lattice_masked = [true true; false true]
lattice_graph = cycle_graph(5)
val0_cartesian = fill(1.0, 2, 2, 2)
val0_masked = sparse([1.0 2.0; 0.0 3.0])
val0_graph = [1.0, 2.0, 3.0, 4.0, 5.0]

# Unpacks the `A` and `B` symbolic variable (so that indexing using it can be tested).
@unpack A, B = brusselator_system

# Loops through all alternative lattices and values. Checks that `lat_getp` works in all cases.
for (lattice, val0) in zip([lattice_cartesian, lattice_masked, lattice_graph], [val0_cartesian, val0_masked, val0_graph])
# Prepares various problems and integrators. Uses `deepcopy` to ensure there is no cross-talk
# between the different p vectors as they get updated.
lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_1, lattice)
u0 = [:X => 1.0, :Y => 0.5]
ps = [:A => val0, :B => 2.0, :dX => 0.1]
oprob = ODEProblem(lrs, u0, (0.0, 1.0), deepcopy(ps))
oint = init(deepcopy(oprob), Tsit5())

# Check that `lat_getp` retrieves the correct values.
@test lat_getp(oprob, :A, lrs) == lat_getp(oprob, A, lrs) == lat_getp(oprob, brusselator_system.A, lrs) == val0
@test lat_getp(oint, :A, lrs) == lat_getp(oint, A, lrs) == lat_getp(oint, brusselator_system.A, lrs) == val0

# Updates Y and checks its content.
lat_setp!(oprob, :B, lrs, val0)
@test lat_getp(oprob, :B, lrs) == lat_getp(oprob, B, lrs) == lat_getp(oprob, brusselator_system.B, lrs) == val0
lat_setp!(oint, :B, lrs, val0)
@test lat_getp(oint, :B, lrs) == lat_getp(oint, B, lrs) == lat_getp(oint, brusselator_system.B, lrs) == val0

# Tries where we change a spatially non-uniform variable to spatially uniform.
lat_setp!(oprob, A, lrs, 0.0)
@test all(isequal(0.0), lat_getp(oprob, A, lrs))
lat_setp!(oint, A, lrs, 0.0)
@test all(isequal(0.0), lat_getp(oint, A, lrs))
end
end

# Checks that `lat_getp` and `lat_setp!` generates errors when applied to `JumpProblem`s and their integrators.
let
lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_1, very_small_1d_cartesian_grid)
u0 = [:X => 1, :Y => 0]
ps = [:A => 3.0, :B => 2.0, :dX => 0.1]
dprob = DiscreteProblem(lrs, u0, (0.0, 1.0), ps)
jprob = JumpProblem(lrs, dprob, NSM())
jint = init(jprob, SSAStepper())

@test_throws Exception lat_getp(jprob, :A, lrs)
@test_throws Exception lat_getp(jprob, :A, lrs, 0.0)
@test_throws Exception lat_setp!(jint, :A, lrs)
@test_throws Exception lat_setp!(jint, :A, lrs, 0.0)
end

# Checks that `lat_getp` and `lat_setp!` generates errors when applied to edge parameters.
let
lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_1, very_small_1d_cartesian_grid)
u0 = [:X => 1.0, :Y => 0.0]
ps = [:A => 3.0, :B => 2.0, :dX => 0.1]
oprob = ODEProblem(lrs, u0, (0.0, 1.0), ps)
oint = init(deepcopy(oprob), Tsit5())

@test_throws ArgumentError lat_getp(oprob, :dX, lrs)
@test_throws ArgumentError lat_setp!(oprob, :dX, lrs, 0.0)
end

### Simulation `lat_getu` Tests ###

# Basic test. For simulations without change in system, check that the solution corresponds to known
Expand Down Expand Up @@ -343,8 +353,8 @@ let
oprob_2 = ODEProblem(lrs, u0, (0.0, 10.0), ps; jac, sparse)

# Modifies the initial ODEProblem to be identical to the new one.
oprob_1.ps[:A] = [1.1 1.2; 1.3 1.4]
oprob_1.ps[:B] = [5.0]
lat_setp!(oprob_1, :A, lrs, [1.1 1.2; 1.3 1.4])
lat_setp!(oprob_1, :B, lrs, 5.0)
oprob_1.ps[:dX] = dX_vals
oprob_1.ps[:dY] = [0.01]
rebuild_lat_internals!(oprob_1)
Expand Down Expand Up @@ -393,8 +403,8 @@ let
oprob_2 = ODEProblem(lrs, u0, (0.0, 10.0), ps_1; jac = true, sparse = true)
condition(u, t, integrator) = (t == 5.0)
function affect!(integrator)
integrator.ps[:A] = A2
integrator.ps[:B] = [B2]
lat_setp!(integrator, :A, lrs, A2)
lat_setp!(integrator, :B, lrs, B2)
integrator.ps[:dX] = dX2
integrator.ps[:dY] = [dY2]
rebuild_lat_internals!(integrator)
Expand Down

0 comments on commit c626fcb

Please sign in to comment.