Skip to content

Commit

Permalink
enable rebuildig oproblems/integrators
Browse files Browse the repository at this point in the history
  • Loading branch information
TorkelE committed Jul 8, 2024
1 parent 4731155 commit 5b6d6ec
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 41 deletions.
1 change: 1 addition & 0 deletions src/Catalyst.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ export make_edge_p_values, make_directed_edge_values

# Specific spatial problem types.
include("spatial_reaction_systems/spatial_ODE_systems.jl")
export rebuild_lat_internals!
include("spatial_reaction_systems/lattice_jump_systems.jl")

# General spatial modelling utility functions.
Expand Down
172 changes: 132 additions & 40 deletions src/spatial_reaction_systems/spatial_ODE_systems.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
### Spatial ODE Functor Structure ###

# Functor with information about a spatial Lattice Reaction ODE;s forcing and Jacobian functions.
# Functor with information about a spatial Lattice Reaction ODEs forcing and Jacobian functions.
# Also used as ODE Function input to corresponding `ODEProblem`.
struct LatticeTransportODEFunction{P,Q,R,S,T}
"""
Expand Down Expand Up @@ -59,40 +59,60 @@ struct LatticeTransportODEFunction{P,Q,R,S,T}
used).
"""
jac_transport::T
""" Whether sparse jacobian representation is used. """
sparse::Bool
"""Remove when we add this as problem metadata"""
lrs::LatticeReactionSystem

function LatticeTransportODEFunction(ofunc::P, ps::Vector{<:Pair},
lrs::LatticeReactionSystem, transport_rates::Vector{Pair{Int64, SparseMatrixCSC{S, Int64}}},
jac_transport::Union{Nothing, Matrix{S}, SparseMatrixCSC{S, Int64}}) where {P,S}

# Creates a vector with the heterogeneous vertex parameters' indexes in the full parameter vector.
p_dict = Dict(ps)
heterogeneous_vert_p_idxs = findall((p_dict[p] isa Vector) && (length(p_dict[p]) > 1)
for p in parameters(lrs))

# Creates the MTKParameters structure and `p_setters` vector (which are used to manage
# the vertex parameter values during the simulations).
nonspatial_osys = complete(convert(ODESystem, reactionsystem(lrs)))
p_init = [p => p_dict[p][1] for p in parameters(nonspatial_osys)]
mtk_ps = MT.MTKParameters(nonspatial_osys, p_init)
p_setters = [MT.setp(nonspatial_osys, p) for p in parameters(lrs)[heterogeneous_vert_p_idxs]]

# Computes the transport rate type vector and leaving rate matrix.
t_rate_idx_types = [size(tr[2]) == (1,1) for tr in transport_rates]
leaving_rates = zeros(length(transport_rates), num_verts(lrs))
for (s_idx, tr_pair) in enumerate(transport_rates)
for e in Catalyst.edge_iterator(lrs)
# Updates the exit rate for species s_idx from vertex e.src.
leaving_rates[s_idx, e[1]] += get_transport_rate(tr_pair[2], e, t_rate_idx_types[s_idx])
end
end
jac_transport::Union{Nothing, Matrix{S}, SparseMatrixCSC{S, Int64}}, sparse) where {P,S}
# Computes `LatticeTransportODEFunction` functor fields.
heterogeneous_vert_p_idxs = make_heterogeneous_vert_p_idxs(ps, lrs)
mtk_ps, p_setters = make_mtk_ps_structs(ps, lrs, heterogeneous_vert_p_idxs)
t_rate_idx_types, leaving_rates = make_t_types_and_leaving_rates(transport_rates, lrs)

# Creates and returns the `LatticeTransportODEFunction` functor.
new{P,typeof(mtk_ps),typeof(p_setters),S,typeof(jac_transport)}(ofunc, num_verts(lrs),
num_species(lrs), heterogeneous_vert_p_idxs, mtk_ps, p_setters, transport_rates,
t_rate_idx_types, leaving_rates, Catalyst.edge_iterator(lrs), jac_transport)
t_rate_idx_types, leaving_rates, Catalyst.edge_iterator(lrs), jac_transport, sparse, lrs)
end
end

# `LatticeTransportODEFunction` helper functions (re used by rebuild function later on).

# Creates a vector with the heterogeneous vertex parameters' indexes in the full parameter vector.
function make_heterogeneous_vert_p_idxs(ps, lrs)
p_dict = Dict(ps)
return findall((p_dict[p] isa Vector) && (length(p_dict[p]) > 1) for p in parameters(lrs))
end

# Creates the MTKParameters structure and `p_setters` vector (which are used to manage
# the vertex parameter values during the simulations).
function make_mtk_ps_structs(ps, lrs, heterogeneous_vert_p_idxs)
p_dict = Dict(ps)
nonspatial_osys = complete(convert(ODESystem, reactionsystem(lrs)))
p_init = [p => p_dict[p][1] for p in parameters(nonspatial_osys)]
mtk_ps = MT.MTKParameters(nonspatial_osys, p_init)
p_setters = [MT.setp(nonspatial_osys, p) for p in parameters(lrs)[heterogeneous_vert_p_idxs]]
return mtk_ps, p_setters
end

# Computes the transport rate type vector and leaving rate matrix.
function make_t_types_and_leaving_rates(transport_rates, lrs)
t_rate_idx_types = [size(tr[2]) == (1,1) for tr in transport_rates]
leaving_rates = zeros(length(transport_rates), num_verts(lrs))
for (s_idx, tr_pair) in enumerate(transport_rates)
for e in Catalyst.edge_iterator(lrs)
# Updates the exit rate for species s_idx from vertex e.src.
leaving_rates[s_idx, e[1]] += get_transport_rate(tr_pair[2], e, t_rate_idx_types[s_idx])
end
end
return t_rate_idx_types, leaving_rates
end

### Spatial ODE Functor Functions ###

# Defines the functor's effect when applied as a forcing function.
function (lt_ofun::LatticeTransportODEFunction)(du::AbstractVector, u, p, t)
# Updates for non-spatial reactions.
Expand Down Expand Up @@ -198,7 +218,7 @@ function build_odefunction(lrs::LatticeReactionSystem, vert_ps::Vector{Pair{R,Ve
transport_rates = make_sidxs_to_transrate_map(vert_ps, edge_ps, lrs)

# Depending on Jacobian and sparsity options, computes the Jacobian transport matrix and prototype.
if sparse && !jac
if !sparse && !jac
jac_transport = nothing
jac_prototype = nothing
else
Expand All @@ -209,7 +229,7 @@ function build_odefunction(lrs::LatticeReactionSystem, vert_ps::Vector{Pair{R,Ve
end

# Creates the `LatticeTransportODEFunction` functor (if `jac`, sets it as the Jacobian as well).
f = LatticeTransportODEFunction(ofunc_dense, [vert_ps; edge_ps], lrs, transport_rates, jac_transport)
f = LatticeTransportODEFunction(ofunc_dense, [vert_ps; edge_ps], lrs, transport_rates, jac_transport, sparse)
J = (jac ? f : nothing)

# Extracts the `Symbol` form for species and parameters. Creates and returns the `ODEFunction`.
Expand Down Expand Up @@ -267,23 +287,95 @@ function build_jac_prototype(ns_jac_prototype::SparseMatrixCSC{Float64, Int64},
end
end

# Create a sparse Jacobian prototype with 0-valued entries.
# Create a sparse Jacobian prototype with 0-valued entries. If requested,
# updates values with non-zero entries.
jac_prototype = sparse(i_idxs, j_idxs, zeros(T, num_entries))
set_nonzero && set_jac_transport_values!(jac_prototype, transport_rates, lrs)

# Set element values.
if set_nonzero
for (s, rates) in transport_rates, e in edge_iterator(lrs)
idx_src = get_index(e[1], s, num_species(lrs))
idx_dst = get_index(e[2], s, num_species(lrs))
val = get_transport_rate(rates, e, size(rates)==(1,1))
return jac_prototype
end

# Term due to species leaving source vertex.
jac_prototype[idx_src, idx_src] -= val
# For a Jacobian prototype with zero-valued entries. Set entry values according to a set of
# transport reaction values.
function set_jac_transport_values!(jac_prototype, transport_rates, lrs)
for (s, rates) in transport_rates, e in edge_iterator(lrs)
idx_src = get_index(e[1], s, num_species(lrs))
idx_dst = get_index(e[2], s, num_species(lrs))
val = get_transport_rate(rates, e, size(rates)==(1,1))

# Term due to species arriving to destination vertex.
jac_prototype[idx_src, idx_dst] += val
end
# Term due to species leaving source vertex.
jac_prototype[idx_src, idx_src] -= val

# Term due to species arriving to destination vertex.
jac_prototype[idx_src, idx_dst] += val
end
end

return jac_prototype
### Functor Updating Functionality ###

# Function for rebuilding a `LatticeReactionSystem` `ODEProblem` after it has been updated.
function rebuild_lat_internals!(oprob::ODEProblem)
rebuild_lat_internals!(oprob.f.f, oprob.p, oprob.f.f.lrs)
end

# Function for rebuilding a `LatticeReactionSystem` integrator after it has been updated.
# We could specify `integrator`'s type, but that required adding OrdinaryDiffEq as a direct
# dependency of Catalyst.
function rebuild_lat_internals!(integrator)
rebuild_lat_internals!(integrator.f.f, integrator.p, integrator.f.f.lrs)
end

# Function which rebuilds a `LatticeTransportODEFunction` functor for a new parameter set.
function rebuild_lat_internals!(lt_ofun::LatticeTransportODEFunction, ps_new, lrs::LatticeReactionSystem)
# Computes Jacobian properties.
jac = !isnothing(lt_ofun.jac_transport)
sparse = lt_ofun.sparse

# Recreates the new parameters on the requisite form.
ps_new = [(length(p) == 1) ? p[1] : p for p in deepcopy(ps_new)]
ps_new = [p => p_val for (p, p_val) in zip(parameters(lrs), deepcopy(ps_new))]
vert_ps, edge_ps = lattice_process_p(ps_new, vertex_parameters(lrs), edge_parameters(lrs), lrs)
ps_new = [vert_ps; edge_ps]

# Creates the new transport rates and transport Jacobian part.
transport_rates = make_sidxs_to_transrate_map(vert_ps, edge_ps, lrs)
if !isnothing(lt_ofun.jac_transport)
lt_ofun.jac_transport .= 0.0
set_jac_transport_values!(lt_ofun.jac_transport, transport_rates, lrs)
end

# Computes new field values.
heterogeneous_vert_p_idxs = make_heterogeneous_vert_p_idxs(ps_new, lrs)
mtk_ps, p_setters = make_mtk_ps_structs(ps_new, lrs, heterogeneous_vert_p_idxs)
t_rate_idx_types, leaving_rates = make_t_types_and_leaving_rates(transport_rates, lrs)

# Updates functor fields.
replace_vec!(lt_ofun.heterogeneous_vert_p_idxs, heterogeneous_vert_p_idxs)
replace_vec!(lt_ofun.p_setters, p_setters)
replace_vec!(lt_ofun.transport_rates, transport_rates)
replace_vec!(lt_ofun.t_rate_idx_types, t_rate_idx_types)
lt_ofun.leaving_rates .= leaving_rates

# Updating the `MTKParameters` structure is a bit more complicated.
p_dict = Dict(ps_new)
osys = complete(convert(ODESystem, reactionsystem(lrs)))
for p in parameters(osys)
MT.setp(osys, p)(lt_ofun.mtk_ps, (p_dict[p] isa Number) ? p_dict[p] : p_dict[p][1])
end

return nothing
end

# Specialised function which replaced one vector in another in a mutating way.
# Required to update the vectors in the `LatticeTransportODEFunction` functor.
function replace_vec!(vec1, vec2)
l1 = length(vec1)
l2 = length(vec2)

# Updates the fields, then deletes superfluous fields, or additional ones.
for (i, v) in enumerate(vec2[1:min(l1, l2)])
vec1[i] = v
end
foreach(idx -> deleteat!(vec1, idx), l1:-1:(l2 + 1))
foreach(val -> push!(vec1, val), vec2[l1+1:l2])
end
5 changes: 5 additions & 0 deletions src/spatial_reaction_systems/spatial_reactions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ function make_transport_reaction(rateex, species)
iv = :(@variables $(DEFAULT_IV_SYM))
trxexpr = :(TransportReaction($rateex, $species))

# Appends `edgeparameter` metadata to all declared parameters.
for idx = 4:2:(2 + 2*length(parameters))
insert!(pexprs.args, idx, :([edgeparameter=true]))
end

quote
$pexprs
$iv
Expand Down
118 changes: 118 additions & 0 deletions test/spatial_modelling/lattice_reaction_systems_ODEs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,124 @@ let
@test all(isequal.(ss_1, ss_2))
end

### ODEProblem & Integrator Interfacing ###

# Checks that basic interfacing with ODEProblem parameters (getting and setting) works.
let
# Creates an initial `ODEProblem`.
lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_1, small_1d_cartesian_grid)
u0 = [:X => 1.0, :Y => 2.0]
ps = [:A => 1.0, :B => [1.0, 2.0, 3.0, 4.0, 5.0], :dX => 0.1]
oprob = ODEProblem(lrs, u0, (0.0, 10.0), ps)

# Checks that retrieved parameters are correct.
@test oprob.ps[:A] == [1.0]
@test oprob.ps[:B] == [1.0, 2.0, 3.0, 4.0, 5.0]
@test oprob.ps[:dX] == sparse([1], [1], [0.1])

# Updates content.
oprob.ps[:A] = [10.0, 20.0, 30.0, 40.0, 50.0]
oprob.ps[:B] = [10.0]
oprob.ps[:dX] = [0.01]

# Checks that content is correct.
@test oprob.ps[:A] == [10.0, 20.0, 30.0, 40.0, 50.0]
@test oprob.ps[:B] == [10.0]
@test oprob.ps[:dX] == [0.01]
end

# Checks that the `rebuild_lat_internals!` function is correctly applied to an ODEProblem.
let
# Creates a brusselator `LatticeReactionSystem`.
lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_2, very_small_2d_cartesian_grid)

# Checks for all combinations of Jacobian and sparsity.
for jac in [false, true], sparse in [false, true]
# Creates an initial ODEProblem.
u0 = [:X => 1.0, :Y => [1.0 2.0; 3.0 4.0]]
dY_vals = spzeros(4,4)
dY_vals[1,2] = 0.1; dY_vals[2,1] = 0.1;
dY_vals[1,3] = 0.2; dY_vals[3,1] = 0.2;
dY_vals[2,4] = 0.3; dY_vals[4,2] = 0.3;
dY_vals[3,4] = 0.4; dY_vals[4,3] = 0.4;
ps = [:A => 1.0, :B => [4.0 5.0; 6.0 7.0], :dX => 0.1, :dY => dY_vals]
oprob_1 = ODEProblem(lrs, u0, (0.0, 10.0), ps; jac, sparse)

# Creates an alternative version of the ODEProblem.
dX_vals = spzeros(4,4)
dX_vals[1,2] = 0.01; dX_vals[2,1] = 0.01;
dX_vals[1,3] = 0.02; dX_vals[3,1] = 0.02;
dX_vals[2,4] = 0.03; dX_vals[4,2] = 0.03;
dX_vals[3,4] = 0.04; dX_vals[4,3] = 0.04;
ps = [:A => [1.1 1.2; 1.3 1.4], :B => 5.0, :dX => dX_vals, :dY => 0.01]
oprob_2 = ODEProblem(lrs, u0, (0.0, 10.0), ps; jac, sparse)

# Modifies the initial ODEProblem to be identical to the new one.
oprob_1.ps[:A] = [1.1 1.2; 1.3 1.4]
oprob_1.ps[:B] = [5.0]
oprob_1.ps[:dX] = dX_vals
oprob_1.ps[:dY] = [0.01]
rebuild_lat_internals!(oprob_1)

# Checks that simulations of the two `ODEProblem`s are identical.
@test solve(oprob_1, Rodas5P()) solve(oprob_2, Rodas5P())
end
end

# Checks that the `rebuild_lat_internals!` function is correctly applied to an integrator.
# Does through by applying it within a callback, and compare to simulations without callback.
let
# Prepares problem inputs.
lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_2, very_small_2d_cartesian_grid)
u0 = [:X => 1.0, :Y => [1.0 2.0; 3.0 4.0]]
A1 = 1.0
B1 = [4.0 5.0; 6.0 7.0]
A2 = [1.1 1.2; 1.3 1.4]
B2 = 5.0
dY_vals = spzeros(4,4)
dY_vals[1,2] = 0.1; dY_vals[2,1] = 0.1;
dY_vals[1,3] = 0.2; dY_vals[3,1] = 0.2;
dY_vals[2,4] = 0.3; dY_vals[4,2] = 0.3;
dY_vals[3,4] = 0.4; dY_vals[4,3] = 0.4;
dX_vals = spzeros(4,4)
dX_vals[1,2] = 0.01; dX_vals[2,1] = 0.01;
dX_vals[1,3] = 0.02; dX_vals[3,1] = 0.02;
dX_vals[2,4] = 0.03; dX_vals[4,2] = 0.03;
dX_vals[3,4] = 0.04; dX_vals[4,3] = 0.04;
dX1 = 0.1
dY1 = dY_vals
dX2 = dX_vals
dY2 = 0.01
ps_1 = [:A => A1, :B => B1, :dX => dX1, :dY => dY1]
ps_2 = [:A => A2, :B => B2, :dX => dX2, :dY => dY2]

# Checks for all combinations of Jacobian and sparsity.
for jac in [false, true], sparse in [false, true]
# Creates simulation through two different separate simulations.
oprob_1_1 = ODEProblem(lrs, u0, (0.0, 5.0), ps_1; jac, sparse)
sol_1_1 = solve(oprob_1_1, Rosenbrock23(); saveat = 1.0, abstol = 1e-8, reltol = 1e-8)
u0_1_2 = [:X => sol_1_1.u[end][1:2:end], :Y => sol_1_1.u[end][2:2:end]]
oprob_1_2 = ODEProblem(lrs, u0_1_2, (0.0, 5.0), ps_2; jac, sparse)
sol_1_2 = solve(oprob_1_2, Rosenbrock23(); saveat = 1.0, abstol = 1e-8, reltol = 1e-8)

# Creates simulation through a single simulation with a callback
oprob_2 = ODEProblem(lrs, u0, (0.0, 10.0), ps_1; jac, sparse)
condition(u, t, integrator) = (t == 5.0)
function affect!(integrator)
integrator.ps[:A] = A2
integrator.ps[:B] = [B2]
integrator.ps[:dX] = dX2
integrator.ps[:dY] = [dY2]
rebuild_lat_internals!(integrator)
end
callback = DiscreteCallback(condition, affect!)
sol_2 = solve(oprob_2, Rosenbrock23(); saveat = 1.0, tstops = [5.0], callback, abstol = 1e-8, reltol = 1e-8)

# Check that trajectories are equivalent.
@test [sol_1_1.u; sol_1_2.u] sol_2.u
end
end

### Tests Special Cases ###

# Create network using either graphs or di-graphs.
Expand Down
5 changes: 4 additions & 1 deletion test/spatial_modelling/lattice_reaction_systems_jumps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,4 +200,7 @@ let
@test abs(d) < reltol * non_spatial_mean[i]
end
end
end
end


### JumpProblem & Integrator Interfacing ###

0 comments on commit 5b6d6ec

Please sign in to comment.