diff --git a/src/Catalyst.jl b/src/Catalyst.jl index 37cf7fcdf4..95aa69ea03 100644 --- a/src/Catalyst.jl +++ b/src/Catalyst.jl @@ -189,7 +189,7 @@ include("spatial_reaction_systems/utility.jl") # Methods for interfacing with from LatticeReactionSystem derived problems, integrators, and solutions. include("spatial_reaction_systems/lattice_sim_struct_interfacing.jl") -export lat_getu, lat_setu!, rebuild_lat_internals! +export lat_getp, lat_setp!, lat_getu, lat_setu!, rebuild_lat_internals! ### ReactionSystem Serialisation ### # Has to be at the end (because it uses records of all metadata declared by Catalyst). diff --git a/src/spatial_reaction_systems/lattice_sim_struct_interfacing.jl b/src/spatial_reaction_systems/lattice_sim_struct_interfacing.jl index bfa9f6c599..7f3cac9b11 100644 --- a/src/spatial_reaction_systems/lattice_sim_struct_interfacing.jl +++ b/src/spatial_reaction_systems/lattice_sim_struct_interfacing.jl @@ -39,39 +39,46 @@ lat_setp!(oprob, :k2, lrs, [1.0 0.0 0.0; 0.0 0.0 0.0]) # Sets `k2` to `1.0` in o ``` """ function lat_setp!(sim_struct, p, lrs::LatticeReactionSystem, p_vals) - # Checks that if u is non-uniform, it has the correct format for the system's lattice. - (u isa Number) || check_lattice_format(extract_lattice(lrs), p_vals) + # Error checks. + (p_vals isa Number) || check_lattice_format(extract_lattice(lrs), p_vals) + edge_param_check(p, lrs) # Converts symbol parameter to symbolic and find correct species index and numbers. (p isa Symbol) && (p = _symbol_to_var(lrs, p)) (p isa Num) && (p = Symbolics.unwrap(p)) - p_idx, p_tot = get_p_idxs(sp, lrs) + p_idx, p_tot = get_p_idxs(p, lrs) - # Reshapes the values to a vector of the correct form, and calls lat_setu! on the input structure. + # Reshapes the values to a vector of the correct form, and calls lat_setp! on the input structure. p_vals_reshaped = vertex_value_form(p_vals, lrs, p) - lat_setp!(sim_struct, p_idx, p_tot, p_vals_reshaped, num_verts(lrs)) + lat_setp!(sim_struct, p_idx, p_vals_reshaped, num_verts(lrs)) end -function lat_setp!(oprob::ODEProblem, p_idx::Int64, p_tot::Int64, p_vals, num_verts) +# Note: currently, `lat_setp!(oprob::ODEProblem, ...`) and `lat_setp!(SciMLBase.AbstractODEIntegrator, ...`) +# are identical and could be merged to a singe function. +function lat_setp!(oprob::ODEProblem, p_idx::Int64, p_vals, num_verts) if length(p_vals) == 1 - foreach(idx -> (oprob.ps[p_idx + (idx - 1) * p_tot] = p_vals[1]), 1:num_verts) - else - foreach(idx -> (oprob.ps[p_idx + (idx - 1) * p_tot] = p_vals[idx]), 1:num_verts) + foreach(idx -> (oprob.p[p_idx][idx] = p_vals[1]), 1:num_verts) + elseif length(p_vals) == length(oprob.p[p_idx]) + foreach(idx -> (oprob.p[p_idx][idx] = p_vals[idx]), 1:num_verts) + elseif length(oprob.p[p_idx]) == 1 + oprob.p[p_idx][1] = p_vals[1] + foreach(idx -> (push!(oprob.p[p_idx], p_vals[idx])), 2:num_verts) end end -function lat_setp!(jprob::JumpProblem, p_idx::Int64, p_tot::Int64, p_vals, num_verts) +function lat_setp!(jprob::JumpProblem, p_idx::Int64, p_vals, num_verts) error("The `lat_setp!` function is currently not supported for `JumpProblem`s.") end -function lat_setp!(oint::SciMLBase.AbstractODEIntegrator, p_idx::Int64, p_tot::Int64, - p_vals, num_verts) +function lat_setp!(oint::SciMLBase.AbstractODEIntegrator, p_idx::Int64, p_vals, num_verts) if length(p_vals) == 1 - foreach(idx -> (oint.ps[p_idx + (idx - 1) * p_tot] = p_vals[1]), 1:num_verts) - else - foreach(idx -> (oint.ps[p_idx + (idx - 1) * p_tot] = p_vals[idx]), 1:num_verts) + foreach(idx -> (oint.p[p_idx][idx] = p_vals[1]), 1:num_verts) + elseif length(p_vals) == length(oint.p[p_idx]) + foreach(idx -> (oint.p[p_idx][idx] = p_vals[idx]), 1:num_verts) + elseif length(oint.p[p_idx]) == 1 + oint.p[p_idx][1] = p_vals[1] + foreach(idx -> (push!(oint.p[p_idx], p_vals[idx])), 2:num_verts) end end -function lat_setp!(jint::JumpProcesses.SSAIntegrator, p_idx::Int64, p_tot::Int64, - p_vals, num_verts) +function lat_setp!(jint::JumpProcesses.SSAIntegrator, p_idx::Int64, p_vals, num_verts) error("The `lat_setp!` function is currently not supported for jump simulation integrators.") end @@ -117,22 +124,27 @@ lat_getp(oprob, :k1, lrs) # Retrieves the value of `k1`. ``` """ function lat_getp(sim_struct, p, lrs::LatticeReactionSystem) - p_idx, p_tot = get_p_idxs(p, lrs) - lat_getp(sim_struct, p_idx, p_tot, extract_lattice(lrs)) + edge_param_check(p, lrs) + p_idx, _ = get_p_idxs(p, lrs) + lat_getp(sim_struct, p_idx, extract_lattice(lrs), num_verts(lrs)) end # Retrieves the lattice values for problem or integrator structures. -function lat_getp(oprob::ODEProblem, p_idx, p_tot, lattice) - return reshape_vals(oprob.u0[p_idx:p_tot:end], lattice) +function lat_getp(oprob::ODEProblem, p_idx::Int64, lattice, num_verts) + vals = oprob.p[p_idx] + (length(vals) == 1) && (vals = fill(vals[1], num_verts)) + return reshape_vals(vals, lattice) end -function lat_getp(jprob::JumpProblem, p_idx, p_tot, lattice) - return reshape_vals(jprob.prob.u0[p_idx, :], lattice) +function lat_getp(jprob::JumpProblem, p_idx::Int64, lattice, num_verts) + error("The `lat_getp` function is currently not supported for `JumpProblem`s.") end -function lat_getp(oint::SciMLBase.AbstractODEIntegrator, p_idx, p_tot, lattice) - return reshape_vals(oint.u[p_idx:p_tot:end], lattice) +function lat_getp(oint::SciMLBase.AbstractODEIntegrator, p_idx::Int64, lattice, num_verts) + vals = oint.p[p_idx] + (length(vals) == 1) && (vals = fill(vals[1], num_verts)) + return reshape_vals(vals, lattice) end -function lat_getp(jint::JumpProcesses.SSAIntegrator, p_idx, p_tot, lattice) - return reshape_vals(jint.u[p_idx, :], lattice) +function lat_getp(jint::JumpProcesses.SSAIntegrator, p_idx::Int64, lattice, num_verts) + error("The `lat_getp` function is currently not supported for jump simulation integrators.") end """ @@ -533,8 +545,8 @@ end # Get a parameter index and the total number of parameters. Also handles different symbolic forms. function get_p_idxs(p, lrs::LatticeReactionSystem) (p isa Symbol) && (p = _symbol_to_var(lrs, p)) - p_idx = findfirst(isequal(sp), parameters(lrs)) - p_tot = length(speciparameterses(lrs)) + p_idx = findfirst(isequal(p), parameters(lrs)) + p_tot = length(parameters(lrs)) return p_idx, p_tot end @@ -556,4 +568,12 @@ function check_lattice_format(lattice::DiGraph, u) error("The input u should be an AbstractVector. It is a $(typeof(u)).") (length(u) == nv(lattice)) || error("The input u should have length $(nv(lattice)), but has length $(length(u)).") +end + +# Throws an error when interfacing with an edge parameter. +function edge_param_check(p, lrs) + (p isa Symbol) && (p = _symbol_to_var(lrs, p)) + if isedgeparameter(p) + throw(ArgumentError("The `lat_getp` and `lat_setp!` functions currently does not support edge parameter updating. If you require this functionality, please raise an issue on the Catalyst GitHub page and we can add this feature.")) + end end \ No newline at end of file diff --git a/test/spatial_modelling/lattice_simulation_struct_interfacing.jl b/test/spatial_modelling/lattice_simulation_struct_interfacing.jl index 3169b1979b..ed4e97983d 100644 --- a/test/spatial_modelling/lattice_simulation_struct_interfacing.jl +++ b/test/spatial_modelling/lattice_simulation_struct_interfacing.jl @@ -6,68 +6,7 @@ using Catalyst, Graphs, JumpProcesses, OrdinaryDiffEq, SparseArrays, Test # Fetch test networks. include("../spatial_test_networks.jl") -### Basic Interfacing ### - -# Checks that basic interfacing with ODEProblem parameters (getting and setting) works. -# Checks that basic interfacing with ODE integrators parameters (getting and setting) works. -let - # Creates an initial `ODEProblem` and integrator. - lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_1, small_1d_cartesian_grid) - u0 = [:X => 1.0, :Y => 2.0] - ps = [:A => 1.0, :B => [1.0, 2.0, 3.0, 4.0, 5.0], :dX => 0.1] - oprob = ODEProblem(lrs, u0, (0.0, 10.0), ps) - oint = init(oprob, Tsit5()) - - # Checks that retrieved parameters are correct. - @test oprob.ps[:A] == [1.0] - @test oprob.ps[:B] == [1.0, 2.0, 3.0, 4.0, 5.0] - @test oprob.ps[:dX] == sparse([1], [1], [0.1]) - - # Updates content. - oprob.ps[:A] = [10.0, 20.0, 30.0, 40.0, 50.0] - oprob.ps[:B] = [10.0] - oprob.ps[:dX] = [0.01] - - # Checks that content is correct. - @test oprob.ps[:A] == [10.0, 20.0, 30.0, 40.0, 50.0] - @test oprob.ps[:B] == [10.0] - @test oprob.ps[:dX] == [0.01] - - # Checks that the integrator have the updated `ODEProblem` parameter (not sure if this is really desired though). - @test oint.ps[:A] == [10.0, 20.0, 30.0, 40.0, 50.0] - @test oint.ps[:B] == [10.0] - @test oint.ps[:dX] == [0.01] - - # Updates content. - oint.ps[:A] = [5.0] - oint.ps[:B] = [0.5] - oint.ps[:dX] = fill(0.2, 5, 5) - - # Checks that content is correct. - @test oint.ps[:A] == [5.0] - @test oint.ps[:B] == [0.5] - @test oint.ps[:dX] == fill(0.2, 5, 5) -end - -# Checks normal interfacing for jump simulation structures. Currently does not work, and implementing -# this might be a major endeavour. This test primarily keeps track of it not working. -let - # Creates an initial `ODEProblem` and integrator. - lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_1, small_1d_cartesian_grid) - u0 = [:X => 1, :Y => 2] - ps = [:A => 1.0, :B => [1.0, 2.0, 3.0, 4.0, 5.0], :dX => 0.1] - dprob = DiscreteProblem(lrs, u0, (0.0, 10.0), ps) - jprob = JumpProblem(lrs, dprob, NSM()) - jint = init(jprob, SSAStepper()) - - # Make basic checks (this features does not currently work). - @test_broken jprob.ps[:A] - @test_broken jprob.ps[:A] = [1.0] - @test_broken jint.ps[:A] - @test_broken jint.ps[:A] = [1.0] -end - -### Problem & Integrator `lat_getu` & `lat_setu!` Tests ### +### Problem & Integrator Interfacing Function Tests ### # Checks `lat_getu` for ODE and Jump problem and integrators. # Checks `lat_setu!` for ODE and Jump problem and integrators. @@ -85,8 +24,8 @@ let # Unpacks the `X` and `Y` symbolic variable (so that indexing using it can be tested). @unpack X, Y = brusselator_system - # Loops through all alternative lattices and `X0`. Checks that `lat_getu` works in all cases. - for (lattice, val0) in zip([lattice_cartesian, lattice_masked, lattice_graph],[val0_cartesian, val0_masked, val0_graph]) + # Loops through all alternative lattices and values. Checks that `lat_getu` works in all cases. + for (lattice, val0) in zip([lattice_cartesian, lattice_masked, lattice_graph], [val0_cartesian, val0_masked, val0_graph]) # Prepares various problems and integrators. Uses `deepcopy` to ensure there is no cross-talk # between the different u vectors as they get updated. lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_1, lattice) @@ -108,11 +47,11 @@ let lat_setu!(oprob, :Y, lrs, val0) @test lat_getu(oprob, :Y, lrs) == lat_getu(oprob, Y, lrs) == lat_getu(oprob, brusselator_system.Y, lrs) == val0 lat_setu!(oint, :Y, lrs, val0) - @test lat_getu(oint, :X, lrs) == lat_getu(oint, X, lrs) == lat_getu(oint, brusselator_system.X, lrs) == val0 + @test lat_getu(oint, :Y, lrs) == lat_getu(oint, Y, lrs) == lat_getu(oint, brusselator_system.Y, lrs) == val0 lat_setu!(jprob, :Y, lrs, val0) - @test lat_getu(jprob, :X, lrs) == lat_getu(jprob, X, lrs) == lat_getu(jprob, brusselator_system.X, lrs) == val0 + @test lat_getu(jprob, :Y, lrs) == lat_getu(jprob, Y, lrs) == lat_getu(jprob, brusselator_system.Y, lrs) == val0 lat_setu!(jint, :Y, lrs, val0) - @test lat_getu(jint, :X, lrs) == lat_getu(jint, X, lrs) == lat_getu(jint, brusselator_system.X, lrs) == val0 + @test lat_getu(jint, :Y, lrs) == lat_getu(jint, Y, lrs) == lat_getu(jint, brusselator_system.Y, lrs) == val0 # Tries where we change a spatially non-uniform variable to spatially uniform. lat_setu!(oprob, X, lrs, 0.0) @@ -126,6 +65,77 @@ let end end +# Checks `lat_getp` for ODEproblem and integrators. +# Checks `lat_setp!` for ODE problem and integrators. +# Checks for all types of lattices. +# Checks for symbol and symbolic variables input. +let + # Declares various types of lattices and corresponding initial values of `A`. + lattice_cartesian = CartesianGrid((2,2,2)) + lattice_masked = [true true; false true] + lattice_graph = cycle_graph(5) + val0_cartesian = fill(1.0, 2, 2, 2) + val0_masked = sparse([1.0 2.0; 0.0 3.0]) + val0_graph = [1.0, 2.0, 3.0, 4.0, 5.0] + + # Unpacks the `A` and `B` symbolic variable (so that indexing using it can be tested). + @unpack A, B = brusselator_system + + # Loops through all alternative lattices and values. Checks that `lat_getp` works in all cases. + for (lattice, val0) in zip([lattice_cartesian, lattice_masked, lattice_graph], [val0_cartesian, val0_masked, val0_graph]) + # Prepares various problems and integrators. Uses `deepcopy` to ensure there is no cross-talk + # between the different p vectors as they get updated. + lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_1, lattice) + u0 = [:X => 1.0, :Y => 0.5] + ps = [:A => val0, :B => 2.0, :dX => 0.1] + oprob = ODEProblem(lrs, u0, (0.0, 1.0), deepcopy(ps)) + oint = init(deepcopy(oprob), Tsit5()) + + # Check that `lat_getp` retrieves the correct values. + @test lat_getp(oprob, :A, lrs) == lat_getp(oprob, A, lrs) == lat_getp(oprob, brusselator_system.A, lrs) == val0 + @test lat_getp(oint, :A, lrs) == lat_getp(oint, A, lrs) == lat_getp(oint, brusselator_system.A, lrs) == val0 + + # Updates Y and checks its content. + lat_setp!(oprob, :B, lrs, val0) + @test lat_getp(oprob, :B, lrs) == lat_getp(oprob, B, lrs) == lat_getp(oprob, brusselator_system.B, lrs) == val0 + lat_setp!(oint, :B, lrs, val0) + @test lat_getp(oint, :B, lrs) == lat_getp(oint, B, lrs) == lat_getp(oint, brusselator_system.B, lrs) == val0 + + # Tries where we change a spatially non-uniform variable to spatially uniform. + lat_setp!(oprob, A, lrs, 0.0) + @test all(isequal(0.0), lat_getp(oprob, A, lrs)) + lat_setp!(oint, A, lrs, 0.0) + @test all(isequal(0.0), lat_getp(oint, A, lrs)) + end +end + +# Checks that `lat_getp` and `lat_setp!` generates errors when applied to `JumpProblem`s and their integrators. +let + lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_1, very_small_1d_cartesian_grid) + u0 = [:X => 1, :Y => 0] + ps = [:A => 3.0, :B => 2.0, :dX => 0.1] + dprob = DiscreteProblem(lrs, u0, (0.0, 1.0), ps) + jprob = JumpProblem(lrs, dprob, NSM()) + jint = init(jprob, SSAStepper()) + + @test_throws Exception lat_getp(jprob, :A, lrs) + @test_throws Exception lat_getp(jprob, :A, lrs, 0.0) + @test_throws Exception lat_setp!(jint, :A, lrs) + @test_throws Exception lat_setp!(jint, :A, lrs, 0.0) +end + +# Checks that `lat_getp` and `lat_setp!` generates errors when applied to edge parameters. +let + lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_1, very_small_1d_cartesian_grid) + u0 = [:X => 1.0, :Y => 0.0] + ps = [:A => 3.0, :B => 2.0, :dX => 0.1] + oprob = ODEProblem(lrs, u0, (0.0, 1.0), ps) + oint = init(deepcopy(oprob), Tsit5()) + + @test_throws ArgumentError lat_getp(oprob, :dX, lrs) + @test_throws ArgumentError lat_setp!(oprob, :dX, lrs, 0.0) +end + ### Simulation `lat_getu` Tests ### # Basic test. For simulations without change in system, check that the solution corresponds to known @@ -343,8 +353,8 @@ let oprob_2 = ODEProblem(lrs, u0, (0.0, 10.0), ps; jac, sparse) # Modifies the initial ODEProblem to be identical to the new one. - oprob_1.ps[:A] = [1.1 1.2; 1.3 1.4] - oprob_1.ps[:B] = [5.0] + lat_setp!(oprob_1, :A, lrs, [1.1 1.2; 1.3 1.4]) + lat_setp!(oprob_1, :B, lrs, 5.0) oprob_1.ps[:dX] = dX_vals oprob_1.ps[:dY] = [0.01] rebuild_lat_internals!(oprob_1) @@ -393,8 +403,8 @@ let oprob_2 = ODEProblem(lrs, u0, (0.0, 10.0), ps_1; jac = true, sparse = true) condition(u, t, integrator) = (t == 5.0) function affect!(integrator) - integrator.ps[:A] = A2 - integrator.ps[:B] = [B2] + lat_setp!(integrator, :A, lrs, A2) + lat_setp!(integrator, :B, lrs, B2) integrator.ps[:dX] = dX2 integrator.ps[:dY] = [dY2] rebuild_lat_internals!(integrator)