Skip to content

Commit

Permalink
format, improve writing
Browse files Browse the repository at this point in the history
  • Loading branch information
TorkelE committed Jul 9, 2024
1 parent 37a14d0 commit e19d67d
Show file tree
Hide file tree
Showing 11 changed files with 50 additions and 51 deletions.
7 changes: 3 additions & 4 deletions src/spatial_reaction_systems/lattice_jump_systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ function JumpProcesses.JumpProblem(lrs::LatticeReactionSystem, dprob, aggregator

# Computes hopping constants and mass action jumps (requires some internal juggling).
# Currently, the resulting JumpProblem does not depend on parameters (no way to incorporate these).
# Hence the parameters of this one does not actually matter. If at some point JumpProcess can
# Hence the parameters of this one do not actually matter. If at some point JumpProcess can
# handle parameters this can be updated and improved.
# The non-spatial DiscreteProblem have a u0 matrix with entries for all combinations of species and vertexes.
hopping_constants = make_hopping_constants(dprob, lrs)
Expand All @@ -54,7 +54,7 @@ end
# Creates the hopping constants from a discrete problem and a lattice reaction system.
function make_hopping_constants(dprob::DiscreteProblem, lrs::LatticeReactionSystem)
# Creates the all_diff_rates vector, containing for each species, its transport rate across all edges.
# If transport rate is uniform for one species, the vector have a single element, else one for each edge.
# If the transport rate is uniform for one species, the vector has a single element, else one for each edge.
spatial_rates_dict = Dict(compute_all_transport_rates(Dict(dprob.p), lrs))
all_diff_rates = [haskey(spatial_rates_dict, s) ? spatial_rates_dict[s] : [0.0]
for s in species(lrs)]
Expand All @@ -77,7 +77,7 @@ function make_hopping_constants(dprob::DiscreteProblem, lrs::LatticeReactionSyst
end

# Creates a SpatialMassActionJump struct from a (spatial) DiscreteProblem and a LatticeReactionSystem.
# Could implementation a version which, if all reaction's rates are uniform, returns a MassActionJump.
# Could implement a version which, if all reactions' rates are uniform, returns a MassActionJump.
# Not sure if there is any form of performance improvement from that though. Likely not the case.
function make_spatial_majumps(dprob, lrs::LatticeReactionSystem)
# Creates a vector, storing which reactions have spatial components.
Expand Down Expand Up @@ -143,7 +143,6 @@ end
# return ModelingToolkit.assemble_maj(eqs.x[1], statetoid, majpmapper)
# end


### Problem & Integrator Rebuilding ###

# Currently not implemented.
Expand Down
28 changes: 14 additions & 14 deletions src/spatial_reaction_systems/lattice_solution_interfacing.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
### Rudimentary Interfacing Function ###
# A single function, `get_lrs_vals`, which contain all interfacing functionality. However,
# long-term it should be replaced with a sleeker interface. Ideally as MTK-wider support for
# lattice problems and solutions are introduced.
# A single function, `get_lrs_vals`, which contains all interfacing functionality. However,
# long-term it should be replaced with a sleeker interface. Ideally as MTK-wide support for
# lattice problems and solutions is introduced.

"""
get_lrs_vals(sol, sp, lrs::LatticeReactionSystem; t = nothing)
Expand All @@ -11,7 +11,7 @@ desired forms. Generally, for `LatticeReactionSystem`s, the values in `sol` is o
way which is not directly interpretable by the user. Furthermore, the normal Catalyst interface
for solutions (e.g. `sol[:X]`) does not work for these solutions. Hence this function is used instead.
The output is a vector, which in each position contain sp's value (either at a time step of time,
The output is a vector, which in each position contains sp's value (either at a time step of time,
depending on the input `t`). Its shape depends on the lattice (using a similar form as heterogeneous
initial conditions). I.e. for a NxM cartesian grid, the values are NxM matrices. For a masked grid,
the values are sparse matrices. For a graph lattice, the values are vectors (where the value in
Expand All @@ -22,8 +22,8 @@ Arguments:
- `sp`: The species which values we wish to retrieve. Can be either a symbol (e.g. `:X`) or a symbolic
variable (e.g. `X`).
- `lrs`: The `LatticeReactionSystem` which was simulated to generate the solution.
- `t = nothing`: If `nothing`, we simply returns the solution across all saved timesteps. If `t`
instead is a vector (or range of values), returns the solutions interpolated at these timepoints.
- `t = nothing`: If `nothing`, we simply return the solution across all saved time steps. If `t`
instead is a vector (or range of values), returns the solutions interpolated at these time points.
Notes:
- The `get_lrs_vals` is not optimised for performance. However, it should still be quite performant,
Expand All @@ -48,7 +48,7 @@ ps = [:k1 => 1, :k2 => 2.0, :D => 0.1]
oprob = ODEProblem(lrs1, u0, tspan, ps)
osol = solve(oprob1, Tsit5())
get_lrs_vals(osol, :X1, lrs) # Returns the value of X1 at each timestep.
get_lrs_vals(osol, :X1, lrs) # Returns the value of X1 at each time step.
get_lrs_vals(osol, :X1, lrs; t = 0.0:10.0) # Returns the value of X1 at times 0.0, 1.0, ..., 10.0
```
"""
Expand All @@ -61,7 +61,7 @@ function get_lrs_vals(sol, sp, lrs::LatticeReactionSystem; t = nothing)
# Extracts the lattice and calls the next function. Masked grids (Array of Bools) are converted
# to sparse array using the same template size as we wish to shape the data to.
lattice = Catalyst.lattice(lrs)
if has_masked_lattice(lrs)
if has_masked_lattice(lrs)
if grid_dims(lrs) == 3
error("The `get_lrs_vals` function is not defined for systems based on 3d sparse arrays. Please raise an issue at the Catalyst GitHub site if this is something which would be useful to you.")
end
Expand All @@ -79,15 +79,16 @@ function get_lrs_vals(sol, lattice, t::Nothing, sp_idx, sp_tot)
if sol.prob isa ODEProblem
return [reshape_vals(vals[sp_idx:sp_tot:end], lattice) for vals in sol.u]
elseif sol.prob isa DiscreteProblem
return [reshape_vals(vals[sp_idx,:], lattice) for vals in sol.u]
return [reshape_vals(vals[sp_idx, :], lattice) for vals in sol.u]
else
error("Unknown type of solution provided to `get_lrs_vals`. Only ODE or Jump solutions are supported.")
end
end

# Function which handles the input in the case where `t` is a range of values (i.e. return `sp`s
# value at all designated time points.
function get_lrs_vals(sol, lattice, t::AbstractVector{T}, sp_idx, sp_tot) where {T <: Number}
function get_lrs_vals(
sol, lattice, t::AbstractVector{T}, sp_idx, sp_tot) where {T <: Number}
if (minimum(t) < sol.t[1]) || (maximum(t) > sol.t[end])
error("The range of the t values provided for sampling, ($(minimum(t)),$(maximum(t))) is not fully within the range of the simulation time span ($(sol.t[1]),$(sol.t[end])).")
end
Expand All @@ -98,15 +99,15 @@ function get_lrs_vals(sol, lattice, t::AbstractVector{T}, sp_idx, sp_tot) where
if sol.prob isa ODEProblem
return [reshape_vals(sol(ti)[sp_idx:sp_tot:end], lattice) for ti in t]
elseif sol.prob isa DiscreteProblem
return [reshape_vals(sol(ti)[sp_idx,:], lattice) for ti in t]
return [reshape_vals(sol(ti)[sp_idx, :], lattice) for ti in t]
else
error("Unknown type of solution provided to `get_lrs_vals`. Only ODE or Jump solutions are supported.")
end
end

# Functions which in each sample point reshapes the vector of values to the correct form (depending
# Functions which in each sample point reshape the vector of values to the correct form (depending
# on the type of lattice used).
function reshape_vals(vals, lattice::CartesianGridRej{N, T}) where {N,T}
function reshape_vals(vals, lattice::CartesianGridRej{N, T}) where {N, T}
return reshape(vals, lattice.dims...)
end
function reshape_vals(vals, lattice::AbstractSparseArray{Bool, Int64, 1})
Expand All @@ -118,4 +119,3 @@ end
function reshape_vals(vals, lattice::DiGraph)
return vals
end

18 changes: 9 additions & 9 deletions src/spatial_reaction_systems/spatial_ODE_systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ struct LatticeTransportODEFunction{P, Q, R, S, T}
mtk_ps::Q
"""
Stores a SymbolicIndexingInterface `setp` function for each heterogeneous vertex parameter (i.e.
vertex parameter which value is not identical across the lattice). The `setp` function at index
vertex parameter whose value is not identical across the lattice). The `setp` function at index
i of `p_setters` corresponds to the parameter in index i of `heterogeneous_vert_p_idxs`.
"""
p_setters::R
Expand Down Expand Up @@ -82,7 +82,7 @@ struct LatticeTransportODEFunction{P, Q, R, S, T}
end
end

# `LatticeTransportODEFunction` helper functions (re used by rebuild function later on).
# `LatticeTransportODEFunction` helper functions (re-used by rebuild function later on).

# Creates a vector with the heterogeneous vertex parameters' indexes in the full parameter vector.
function make_heterogeneous_vert_p_idxs(ps, lrs)
Expand Down Expand Up @@ -226,7 +226,7 @@ function build_odefunction(lrs::LatticeReactionSystem, vert_ps::Vector{Pair{R, V
ofunc_sparse = ODEFunction(osys; jac = true, sparse = true)
transport_rates = make_sidxs_to_transrate_map(vert_ps, edge_ps, lrs)

# Depending on Jacobian and sparsity options, computes the Jacobian transport matrix and prototype.
# Depending on Jacobian and sparsity options, compute the Jacobian transport matrix and prototype.
if !sparse && !jac
jac_transport = nothing
jac_prototype = nothing
Expand All @@ -249,7 +249,7 @@ function build_odefunction(lrs::LatticeReactionSystem, vert_ps::Vector{Pair{R, V
return ODEFunction(f; jac = J, jac_prototype, sys)
end

# Builds a jacobian prototype.
# Builds a Jacobian prototype.
# If requested, populate it with the constant values of the Jacobian's transportation part.
function build_jac_prototype(ns_jac_prototype::SparseMatrixCSC{Float64, Int64},
transport_rates::Vector{Pair{Int64, SparseMatrixCSC{T, Int64}}},
Expand Down Expand Up @@ -328,19 +328,19 @@ end
"""
rebuild_lat_internals!(sciml_struct)
Rebuilds the internal functions for simulating a LatticeReactionSystem. WHenever a problem or
integrator have had its parameter values updated, thus function should be called for the update to
Rebuilds the internal functions for simulating a LatticeReactionSystem. Wenever a problem or
integrator has had its parameter values updated, this function should be called for the update to
be taken into account. For ODE simulations, `rebuild_lat_internals!` needs only to be called when
- An edge parameter have been updated.
- When a parameter with spatially homogeneous values have been given spatially heterogeneous values
- An edge parameter has been updated.
- When a parameter with spatially homogeneous values has been given spatially heterogeneous values
(or vice versa).
Arguments:
- `sciml_struct`: The problem (e.g. an `ODEProblem`) or an integrator which we wish to rebuild.
Notes:
- Currently does not work for `DiscreteProblem`s, `JumpProblem`s, or their integrators.
- The function is not build with performance in mind, so avoid calling it multiple times in
- The function is not built with performance in mind, so avoid calling it multiple times in
performance-critical applications.
Example:
Expand Down
6 changes: 3 additions & 3 deletions src/spatial_reaction_systems/spatial_reactions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,15 @@ function check_spatial_reaction_validity(rs::ReactionSystem, tr::TransportReacti

# Checks that the species does not exist in the system with different metadata.
if any(isequal(tr.species, s) && !isequivalent(tr.species, s) for s in species(rs))
error("A transport reaction used a species, $(tr.species), with metadata not matching its lattice reaction system. Please fetch this species from the reaction system and used in transport reaction creation.")
error("A transport reaction used a species, $(tr.species), with metadata not matching its lattice reaction system. Please fetch this species from the reaction system and use it during transport reaction creation.")
end
# No `for` loop, just weird formatting by the formatter.
if any(isequal(rs_p, tr_p) && !isequivalent(rs_p, tr_p)
for rs_p in parameters(rs), tr_p in Symbolics.get_variables(tr.rate))
error("A transport reaction used a parameter with metadata not matching its lattice reaction system. Please fetch this parameter from the reaction system and used in transport reaction creation.")
error("A transport reaction used a parameter with metadata not matching its lattice reaction system. Please fetch this parameter from the reaction system and use it during transport reaction creation.")
end

# Checks that no edge parameter occur among rates of non-spatial reactions.
# Checks that no edge parameter occurs among rates of non-spatial reactions.
# No `for` loop, just weird formatting by the formatter.
if any(!isempty(intersect(Symbolics.get_variables(r.rate), edge_parameters))
for r in reactions(rs))
Expand Down
8 changes: 4 additions & 4 deletions src/spatial_reaction_systems/utility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ function get_transport_rate(transport_rate::SparseMatrixCSC{T, Int64},
return t_rate_idx_types ? transport_rate[1, 1] : transport_rate[edge[1], edge[2]]
end

# For a `LatticeTransportODEFunction`, updates its stored parameters (in `mtk_ps`) so that they
# For a `LatticeTransportODEFunction`, update its stored parameters (in `mtk_ps`) so that they
# the heterogeneous parameters' values correspond to the values in the specified vertex.
function update_mtk_ps!(lt_ofun::LatticeTransportODEFunction, all_ps::Vector{T},
vert::Int64) where {T}
Expand All @@ -278,7 +278,7 @@ function update_mtk_ps!(lt_ofun::LatticeTransportODEFunction, all_ps::Vector{T},
end
end

# For an expression, computes its values using the provided state and parameter vectors.
# For an expression, compute its values using the provided state and parameter vectors.
# The expression is assumed to be valid in vertexes (and can have vertex parameter and state components).
# If at least one component is non-uniform, output is a vector of length equal to the number of vertexes.
# If all components are uniform, the output is a length one vector.
Expand All @@ -289,7 +289,7 @@ function compute_vertex_value(exp, lrs::LatticeReactionSystem; u = [], ps = [])
throw(ArgumentError("An edge parameter was encountered in expressions: $exp. Here, only vertex-based components are expected."))
end

# Creates a Function that computes the expressions value for a parameter set.
# Creates a Function that computes the expression value for a parameter set.
exp_func = drop_expr(@RuntimeGeneratedFunction(build_function(exp, relevant_syms...)))

# Creates a dictionary with the value(s) for all edge parameters.
Expand All @@ -305,7 +305,7 @@ end

### System Property Checks ###

# For a Symbolic expression, and a parameter set, checks if any relevant parameters have a
# For a Symbolic expression, and a parameter set, check if any relevant parameters have a
# spatial component. Filters out any parameters that are edge parameters.
function has_spatial_vertex_component(exp, ps)
relevant_syms = Symbolics.get_variables(exp)
Expand Down
4 changes: 2 additions & 2 deletions test/spatial_modelling/lattice_reaction_systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ end

### Tests Edge Value Computation Helper Functions ###

# Checks that computes the correct values across various types of grids.
# Checks that we compute the correct values across various types of grids.
let
# Prepares the model and the function that determines the edge values.
rn = @reaction_network begin
Expand All @@ -323,7 +323,7 @@ let
end
end

# Checks that all species ends up in the correct place in in a pure flow system (checking various dimensions).
# Checks that all species end up in the correct place in a pure flow system (checking various dimensions).
let
# Prepares a system with a single species which is transported only.
rn = @reaction_network begin
Expand Down
10 changes: 5 additions & 5 deletions test/spatial_modelling/lattice_reaction_systems_ODEs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ rng = StableRNG(12345)
# Sets defaults
t = default_t()

### Tests Simulations Don't Error ###
### Tests Simulations Do Not Error ###
let
for grid in [small_1d_cartesian_grid, small_1d_masked_grid, small_1d_graph_grid]
for srs in [Vector{TransportReaction}(), SIR_srs_1, SIR_srs_2]
Expand Down Expand Up @@ -186,7 +186,7 @@ let
@test all(isapprox.(ss, solve(oprob_sparse_jac, Rosenbrock23(); abstol = 1e-10, reltol = 1e-10).u[end]; rtol = 0.0001))
end

# Compares Catalyst-generated to hand written one for the brusselator for a line of cells.
# Compares Catalyst-generated to hand-written one for the Brusselator for a line of cells.
let
function spatial_brusselator_f(du, u, p, t)
# Non-spatial
Expand Down Expand Up @@ -534,7 +534,7 @@ end

# Checks that the `rebuild_lat_internals!` function is correctly applied to an ODEProblem.
let
# Creates a brusselator `LatticeReactionSystem`.
# Creates a Brusselator `LatticeReactionSystem`.
lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_2, very_small_2d_cartesian_grid)

# Checks for all combinations of Jacobian and sparsity.
Expand Down Expand Up @@ -572,7 +572,7 @@ end

# Checks that the `rebuild_lat_internals!` function is correctly applied to an integrator.
# Does through by applying it within a callback, and compare to simulations without callback.
# To keep test faster, only checks for `jac = sparse = true`.
# To keep test faster, only check for `jac = sparse = true` only.
let
# Prepares problem inputs.
lrs = LatticeReactionSystem(brusselator_system, brusselator_srs_2, very_small_2d_cartesian_grid)
Expand Down Expand Up @@ -624,7 +624,7 @@ end

### Tests Special Cases ###

# Create network using either graphs or di-graphs.
# Create networks using either graphs or di-graphs.
let
lrs_digraph = LatticeReactionSystem(SIR_system, SIR_srs_2, complete_digraph(3))
lrs_graph = LatticeReactionSystem(SIR_system, SIR_srs_2, complete_graph(3))
Expand Down
6 changes: 3 additions & 3 deletions test/spatial_modelling/lattice_reaction_systems_jumps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ end

### SpatialMassActionJump Testing ###

# Checks that the correct structure;s is produced.
# Checks that the correct structures are produced.
let
# Network for reference:
# A, ∅ → X
Expand All @@ -113,7 +113,7 @@ let
# @test isequal(to_int(getfield.(reactions(reactionsystem(lrs)), :netstoich)), jprob.massaction_jump.net_stoch)
# @test isequal(to_int(Pair.(getfield.(reactions(reactionsystem(lrs)), :substrates),getfield.(reactions(reactionsystem(lrs)), :substoich))), jprob.massaction_jump.net_stoch)

# Checks that problem can be simulated.
# Checks that problems can be simulated.
@test SciMLBase.successful_retcode(solve(jprob, SSAStepper()))
end

Expand Down Expand Up @@ -203,7 +203,7 @@ end

### JumpProblem & Integrator Interfacing ###

# Currently not supported, check that corresponding functions yields errors.
# Currently not supported, check that corresponding functions yield errors.
let
# Prepare `LatticeReactionSystem`.
rs = @reaction_network begin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ include("../spatial_test_networks.jl")

### Run Tests ###

# Test errors when attempting to create networks with dimension > 3.
# Test errors when attempting to create networks with dimensions > 3.
let
@test_throws Exception LatticeReactionSystem(brusselator_system, brusselator_srs_1, CartesianGrid((5, 5, 5, 5)))
@test_throws Exception LatticeReactionSystem(brusselator_system, brusselator_srs_1, fill(true, 5, 5, 5, 5))
Expand Down
Loading

0 comments on commit e19d67d

Please sign in to comment.