Skip to content

Commit

Permalink
finish impl of dp8 + DRY
Browse files Browse the repository at this point in the history
  • Loading branch information
weinbe58 committed Dec 22, 2023
1 parent 3b1ec97 commit 6624457
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 114 deletions.
34 changes: 1 addition & 33 deletions src/dp5/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,6 @@ function error_estimation(solver)
return err
end

function estimate_second_derivative(solver, h)

der2 = mapreduce(+, solver.consts.atol_iter, solver.consts.rtol_iter, solver.k2, solver.k1, solver.y) do atoli, rtoli, f1i, f0i, yi
sk = atoli + rtoli*abs(yi)
((f1i-f0i)/sk)^2
end

der2 = sqrt(der2)/h

return der2

end

function stiffness_detection!(solver, naccpt, h)
if (mod(naccpt, solver.options.stiffness_test_activation_step) == 0) || (solver.vars.iasti > 0)
#stnum = 0.0
Expand All @@ -63,7 +50,7 @@ function stiffness_detection!(solver, naccpt, h)
end

if stden > 0.0
solver.vars.hlamb = h*sqrt(stnum/stden)
solver.vars.hlamb = abs(h)*sqrt(stnum/stden)
else
solver.vars.hlamb = Inf
end
Expand All @@ -82,22 +69,3 @@ function stiffness_detection!(solver, naccpt, h)
end
end
end

function euler_first_guess(solver, hmax, posneg)

dnf, dny = mapreduce(.+, solver.consts.atol_iter, solver.consts.rtol_iter, solver.k1, solver.y) do atoli, rtoli, f0i, yi
sk = atoli + rtoli*abs(yi)
abs(f0i/sk)^2, abs(yi/sk)^2 # dnf, dny
end


if (dnf <= 1.0e-10) || (dny <= 1.0e-10)
h = 1.0e-6
else
h = 0.01*sqrt(dny/dnf)
end
h = min(h, hmax)
h = h * Base.sign(posneg)

return h, dnf
end
2 changes: 1 addition & 1 deletion src/dp5/mod.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module DP5

using ..DormandPrince: DormandPrince, DP5Solver, Vars, Consts, Options, Report
using ..DormandPrince: DormandPrince, DP5Solver, Vars, Consts, Options, Report, hinit

include("params.jl")
include("helpers.jl")
Expand Down
40 changes: 0 additions & 40 deletions src/dp5/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,43 +194,3 @@ function dopcor(
)

end

function hinit(
solver,
posneg,
iord,
hmax
# f0 arg is k1 from dopcor
# f1 arg is k2 from dopcor
# y1 arg is k3 from dopcor
)
#=
Compute a first guess for explicit euler as
h = 0.01 * norm (y0) / norm (f0)
the increment for explicit euler is small
compared to the solution
=#
h, dnf = euler_first_guess(solver, hmax, posneg)

###### Perform an explicit step
#y1 = y + h*f0
#fcn(n, x+h, y1, f1)
# copyto!(solver.y1, solver.y + h*solver.k1)
solver.y1 .= solver.y .+ h .*solver.k1
solver.f(solver.vars.x + h, solver.k3, solver.k2)

###### Estimate the second derivative of the solution
der2 = estimate_second_derivative(solver, h)

##### Step size is computed such that
##### H**IORD * MAX ( NORM(F0), NORM(F1), DER2 ) = 0.01
der12 = max(abs(der2), sqrt(dnf))
if der12 <= 1e-15
h1 = max(1.0e-6, abs(h)*1.0e-3)
else
h1 = (0.01/der12)^(1.0/iord)
end
h = min(100*abs(h), h1, hmax)
return h * Base.sign(posneg)

end
44 changes: 5 additions & 39 deletions src/dp8/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,39 +145,24 @@ function error_estimation(solver)
return err

Check warning on line 145 in src/dp8/helpers.jl

View check run for this annotation

Codecov / codecov/patch

src/dp8/helpers.jl#L145

Added line #L145 was not covered by tests
end

function estimate_second_derivative(solver, h)

der2 = mapreduce(+, solver.consts.atol_iter, solver.consts.rtol_iter, solver.k2, solver.k1, solver.y) do atoli, rtoli, f1i, f0i, yi
sk = atoli + rtoli*abs(yi)
((f1i-f0i)/sk)^2
end

der2 = sqrt(der2)/h

return der2

end

function stiffness_detection!(solver, naccpt, h)

Check warning on line 148 in src/dp8/helpers.jl

View check run for this annotation

Codecov / codecov/patch

src/dp8/helpers.jl#L148

Added line #L148 was not covered by tests

if (mod(naccpt, solver.options.stiffness_test_activation_step) == 0) || (solver.vars.iasti > 0)

Check warning on line 150 in src/dp8/helpers.jl

View check run for this annotation

Codecov / codecov/patch

src/dp8/helpers.jl#L150

Added line #L150 was not covered by tests
#stnum = 0.0
#stden = 0.0

stnum, stden = mapreduce(.+, solver.k2, solver.k6, solver.y1, solver.ysti) do k2i, k6i, y1i, ystii
#stnum = abs(k2i-k6i)^2
#stden = abs(y1i-ystii)^2
abs(k2i-k6i)^2, abs(y1i-ystii)^2
# stnum, stden
stnum, stden = mapreduce(.+, solver.k3, solver.k4, solver.k5, solver.y1) do k3i, k4i, k5i, y1i
abs(k4i-k3i)^2, abs(k5i-y1i)^2

Check warning on line 155 in src/dp8/helpers.jl

View check run for this annotation

Codecov / codecov/patch

src/dp8/helpers.jl#L154-L155

Added lines #L154 - L155 were not covered by tests
end

if stden > 0.0
solver.vars.hlamb = h*sqrt(stnum/stden)
solver.vars.hlamb = abs(h)*sqrt(stnum/stden)

Check warning on line 159 in src/dp8/helpers.jl

View check run for this annotation

Codecov / codecov/patch

src/dp8/helpers.jl#L158-L159

Added lines #L158 - L159 were not covered by tests
else
solver.vars.hlamb = Inf

Check warning on line 161 in src/dp8/helpers.jl

View check run for this annotation

Codecov / codecov/patch

src/dp8/helpers.jl#L161

Added line #L161 was not covered by tests
end


if solver.vars.hlamb > 3.25
if solver.vars.hlamb > 6.1
solver.vars.iasti += 1
if solver.vars.iasti == 15
@debug "The problem seems to become stiff at $x"

Check warning on line 168 in src/dp8/helpers.jl

View check run for this annotation

Codecov / codecov/patch

src/dp8/helpers.jl#L165-L168

Added lines #L165 - L168 were not covered by tests
Expand All @@ -190,22 +175,3 @@ function stiffness_detection!(solver, naccpt, h)
end
end
end

function euler_first_guess(solver, hmax, posneg)

dnf, dny = mapreduce(.+, solver.consts.atol_iter, solver.consts.rtol_iter, solver.k1, solver.y) do atoli, rtoli, f0i, yi
sk = atoli + rtoli*abs(yi)
abs(f0i/sk)^2, abs(yi/sk)^2 # dnf, dny
end


if (dnf <= 1.0e-10) || (dny <= 1.0e-10)
h = 1.0e-6
else
h = 0.01*sqrt(dny/dnf)
end
h = min(h, hmax)
h = h * Base.sign(posneg)

return h, dnf
end
2 changes: 1 addition & 1 deletion src/dp8/mod.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module DP8

using ..DormandPrince: DormandPrince, DP8Solver, Vars, Consts, Options, Report
using ..DormandPrince: DormandPrince, DP8Solver, Vars, Consts, Options, Report, hinit

include("params.jl")
include("helpers.jl")
Expand Down
74 changes: 74 additions & 0 deletions src/hinit.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@


function euler_first_guess(solver::AbstractDPSolver{T}, hmax::T, posneg::T) where T

dnf, dny = mapreduce(.+, solver.consts.atol_iter, solver.consts.rtol_iter, solver.k1, solver.y) do atoli, rtoli, f0i, yi
sk = atoli + rtoli*abs(yi)
abs(f0i/sk)^2, abs(yi/sk)^2 # dnf, dny
end


if (dnf <= 1.0e-10) || (dny <= 1.0e-10)
h = 1.0e-6
else
h = 0.01*sqrt(dny/dnf)
end
h = min(h, hmax)
h = h * Base.sign(posneg)

return h, dnf
end


function estimate_second_derivative(solver::AbstractDPSolver{T}, h::T) where {T}

der2 = mapreduce(+, solver.consts.atol_iter, solver.consts.rtol_iter, solver.k2, solver.k1, solver.y) do atoli, rtoli, f1i, f0i, yi
sk = atoli + rtoli*abs(yi)
((f1i-f0i)/sk)^2
end

der2 = sqrt(der2)/h

return der2

end

function hinit(
solver::AbstractDPSolver{T},
posneg::T,
iord::Int,
hmax::T
# f0 arg is k1 from dopcor
# f1 arg is k2 from dopcor
# y1 arg is k3 from dopcor
) where T
#=
Compute a first guess for explicit euler as
h = 0.01 * norm (y0) / norm (f0)
the increment for explicit euler is small
compared to the solution
=#
h, dnf = euler_first_guess(solver, hmax, posneg)

###### Perform an explicit step
#y1 = y + h*f0
#fcn(n, x+h, y1, f1)
# copyto!(solver.y1, solver.y + h*solver.k1)
solver.y1 .= solver.y .+ h .*solver.k1
solver.f(solver.vars.x + h, solver.k3, solver.k2)

###### Estimate the second derivative of the solution
der2 = estimate_second_derivative(solver, h)

##### Step size is computed such that
##### H**IORD * MAX ( NORM(F0), NORM(F1), DER2 ) = 0.01
der12 = max(abs(der2), sqrt(dnf))
if der12 <= 1e-15
h1 = max(1.0e-6, abs(h)*1.0e-3)

Check warning on line 67 in src/hinit.jl

View check run for this annotation

Codecov / codecov/patch

src/hinit.jl#L67

Added line #L67 was not covered by tests
else
h1 = (0.01/der12)^(1.0/iord)
end
h = min(100*abs(h), h1, hmax)
return h * Base.sign(posneg)

end
1 change: 1 addition & 0 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ end

@kwdef mutable struct Vars{T <: Real}
x::T = zero(T)
xph::T = zero(T)
h::T = zero(T)
facold::T = 1e-4
iasti::Int = 0
Expand Down

0 comments on commit 6624457

Please sign in to comment.