From 412732d33137e36ca5f277159fe99aa30459c3b7 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Wed, 3 Jan 2024 13:54:04 -0500 Subject: [PATCH] adding interface for getting current state of integrator. --- Project.toml | 2 +- src/interface.jl | 7 ++++--- src/types.jl | 2 ++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 95dbf0d..861e8fa 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DormandPrince" uuid = "5e45e72d-22b8-4dd0-9c8b-f96714864bcd" authors = ["John Long", "Phillip Weinberg"] -version = "0.2.0" +version = "0.3.0" [deps] diff --git a/src/interface.jl b/src/interface.jl index 7099c5b..69a7f28 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -12,7 +12,7 @@ function Base.iterate(solver_iter::SolverIterator) # integrate to first time integrate(solver_iter.solver, first(solver_iter.times)) # return value and index which is the state - return (solver_iter.times[1], solver_iter.solver.y), 2 + return (solver_iter.times[1], get_current_state(solver_iter.solver)), 2 end # gets the next (t,y), return index+! which is the updated state @@ -21,7 +21,7 @@ function Base.iterate(solver_iter::SolverIterator, index::Int) # integrate to next time integrate(solver_iter.solver, solver_iter.times[index]) # return time and state - return (solver_iter.times[index], solver_iter.solver.y), index+1 + return (solver_iter.times[index], get_current_state(solver_iter.solver)), index+1 end # 3 modes of operation for integrate @@ -29,6 +29,7 @@ end # 2. integrate(solver, times) -> iterator # 3. integrate(callback, solver, times) -> vector of states with callback applied +get_current_state(::AbstractDPSolver) = error("not implemented") integrate(solver::AbstractDPSolver{T}, times::AbstractVector{T}) where {T <: Real} = SolverIterator(solver, times) function integrate(callback, solver::AbstractDPSolver{T}, times::AbstractVector{T}; sort_times::Bool = true) where {T <: Real} @@ -37,7 +38,7 @@ function integrate(callback, solver::AbstractDPSolver{T}, times::AbstractVector{ result = [] for time in times integrate(solver, time) - push!(result, callback(time, solver.y)) + push!(result, callback(time, get_current_state(solver))) end return result diff --git a/src/types.jl b/src/types.jl index 1ad85e2..45d9983 100644 --- a/src/types.jl +++ b/src/types.jl @@ -213,3 +213,5 @@ function DP8Solver( DP8Solver(f, x, y, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, y1;kw...) end +get_current_state(solver::DP5Solver) = solver.y +get_current_state(solver::DP8Solver) = solver.y \ No newline at end of file