Skip to content

Commit

Permalink
Merge pull request #359 from SouthEndMusic/integrals_refactor
Browse files Browse the repository at this point in the history
Refactor integration and `QuadraticInterpolation`
  • Loading branch information
ChrisRackauckas authored Nov 17, 2024
2 parents baa252d + 3e10278 commit 96d017d
Show file tree
Hide file tree
Showing 11 changed files with 174 additions and 147 deletions.
24 changes: 16 additions & 8 deletions ext/DataInterpolationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ if isdefined(Base, :get_extension)
LinearInterpolation, QuadraticInterpolation,
LagrangeInterpolation, AkimaInterpolation,
BSplineInterpolation, BSplineApprox, get_idx, get_parameters,
_quad_interp_indices, munge_data
munge_data
using ChainRulesCore
else
using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation,
LinearInterpolation, QuadraticInterpolation,
LagrangeInterpolation, AkimaInterpolation,
BSplineInterpolation, BSplineApprox, get_parameters,
_quad_interp_indices, munge_data
munge_data
using ..ChainRulesCore
end

Expand Down Expand Up @@ -74,6 +74,11 @@ function u_tangent(A::LinearInterpolation, t, Δ)
out
end

function _quad_interp_indices(A::QuadraticInterpolation, t::Number, iguess)
idx = get_idx(A, t, iguess; idx_shift = A.mode == :Backward ? -1 : 0, ub_shift = -2)
idx, idx + 1, idx + 2
end

function u_tangent(A::QuadraticInterpolation, t, Δ)
out = zero.(A.u)
i₀, i₁, i₂ = _quad_interp_indices(A, t, A.iguesser)
Expand All @@ -83,14 +88,17 @@ function u_tangent(A::QuadraticInterpolation, t, Δ)
Δt₀ = t₁ - t₀
Δt₁ = t₂ - t₁
Δt₂ = t₂ - t₀
Δt_rel₀ = t - A.t[i₀]
Δt_rel₁ = t - A.t[i₁]
Δt_rel₂ = t - A.t[i₂]
if eltype(out) <: Number
out[i₀] = Δ * (t - A.t[i₁]) * (t - A.t[i₂]) / (Δt₀ * Δt₂)
out[i₁] = -Δ * (t - A.t[i₀]) * (t - A.t[i₂]) / (Δt₀ * Δt₁)
out[i₂] = Δ * (t - A.t[i₀]) * (t - A.t[i₁]) / (Δt₂ * Δt₁)
out[i₀] = Δ * Δt_rel₁ * Δt_rel₂ / (Δt₀ * Δt₂)
out[i₁] = -Δ * Δt_rel₀ * Δt_rel₂ / (Δt₀ * Δt₁)
out[i₂] = Δ * Δt_rel₀ * Δt_rel₁ / (Δt₂ * Δt₁)
else
@. out[i₀] = Δ * (t - A.t[i₁]) * (t - A.t[i₂]) / (Δt₀ * Δt₂)
@. out[i₁] = -Δ * (t - A.t[i₀]) * (t - A.t[i₂]) / (Δt₀ * Δt₁)
@. out[i₂] = Δ * (t - A.t[i₀]) * (t - A.t[i₁]) / (Δt₂ * Δt₁)
@. out[i₀] = Δ * Δt_rel₁ * Δt_rel₂ / (Δt₀ * Δt₂)
@. out[i₁] = -Δ * Δt_rel₀ * Δt_rel₂ / (Δt₀ * Δt₁)
@. out[i₂] = Δ * Δt_rel₀ * Δt_rel₁ / (Δt₂ * Δt₁)
end
out
end
Expand Down
10 changes: 4 additions & 6 deletions src/derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@ function _derivative(A::LinearInterpolation, t::Number, iguess)
end

function _derivative(A::QuadraticInterpolation, t::Number, iguess)
i₀, i₁, i₂ = _quad_interp_indices(A, t, iguess)
l₀, l₁, l₂ = get_parameters(A, i₀)
du₀ = l₀ * (2t - A.t[i₁] - A.t[i₂])
du₁ = l₁ * (2t - A.t[i₀] - A.t[i₂])
du₂ = l₂ * (2t - A.t[i₀] - A.t[i₁])
return @views @. du₀ + du₁ + du₂
idx = get_idx(A, t, iguess)
Δt = t - A.t[idx]
α, β = get_parameters(A, idx)
return 2α * Δt + β
end

function _derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number)
Expand Down
6 changes: 5 additions & 1 deletion src/integral_inverses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@ function invertible_integral(A::LinearInterpolation{<:AbstractVector{<:Number}})
return all(A.u .> 0)
end

get_I(A::AbstractInterpolation) = isempty(A.I) ? cumulative_integral(A, true) : A.I
function get_I(A::AbstractInterpolation)
I = isempty(A.I) ? cumulative_integral(A, true) : copy(A.I)
pushfirst!(I, 0)
I
end

function invert_integral(A::LinearInterpolation{<:AbstractVector{<:Number}})
!invertible_integral(A) && throw(IntegralNotInvertibleError())
Expand Down
122 changes: 59 additions & 63 deletions src/integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,122 +7,118 @@ function integral(A::AbstractInterpolation, t1::Number, t2::Number)
((t1 < A.t[1] || t1 > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
((t2 < A.t[1] || t2 > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
!hasfield(typeof(A), :I) && throw(IntegralNotFoundError())
(t2 < t1) && return -integral(A, t2, t1)
# the index less than or equal to t1
idx1 = get_idx(A, t1, 0)
# the index less than t2
idx2 = get_idx(A, t2, 0; idx_shift = -1, side = :first)

if A.cache_parameters
total = A.I[idx2] - A.I[idx1]
total = A.I[max(1, idx2 - 1)] - A.I[idx1]
return if t1 == t2
zero(total)
else
total += _integral(A, idx1, A.t[idx1])
total -= _integral(A, idx1, t1)
total += _integral(A, idx2, t2)
total -= _integral(A, idx2, A.t[idx2])
if idx1 == idx2
total += _integral(A, idx1, t1, t2)
else
total += _integral(A, idx1, t1, A.t[idx1 + 1])
total += _integral(A, idx2, A.t[idx2], t2)
end
total
end
else
total = zero(eltype(A.u))
for idx in idx1:idx2
lt1 = idx == idx1 ? t1 : A.t[idx]
lt2 = idx == idx2 ? t2 : A.t[idx + 1]
total += _integral(A, idx, lt2) - _integral(A, idx, lt1)
total += _integral(A, idx, lt1, lt2)
end
total
end
end

function _integral(A::LinearInterpolation{<:AbstractVector{<:Number}},
idx::Number,
t::Number)
Δt = t - A.t[idx]
idx::Number, t1::Number, t2::Number)
slope = get_parameters(A, idx)
Δt * (A.u[idx] + slope * Δt / 2)
u_mean = A.u[idx] + slope * ((t1 + t2) / 2 - A.t[idx])
u_mean * (t2 - t1)
end

function _integral(
A::ConstantInterpolation{<:AbstractVector{<:Number}}, idx::Number, t::Number)
A::ConstantInterpolation{<:AbstractVector{<:Number}}, idx::Number, t1::Number, t2::Number)
Δt = t2 - t1
if A.dir === :left
# :left means that value to the left is used for interpolation
return A.u[idx] * t
return A.u[idx] * Δt
else
# :right means that value to the right is used for interpolation
return A.u[idx + 1] * t
return A.u[idx + 1] * Δt
end
end

function _integral(A::QuadraticInterpolation{<:AbstractVector{<:Number}},
idx::Number,
t::Number)
A.mode == :Backward && idx > 1 && (idx -= 1)
idx = min(length(A.t) - 2, idx)
t₀ = A.t[idx]
t₁ = A.t[idx + 1]
t₂ = A.t[idx + 2]

t_sq = (t^2) / 3
l₀, l₁, l₂ = get_parameters(A, idx)
Iu₀ = l₀ * t * (t_sq - t * (t₁ + t₂) / 2 + t₁ * t₂)
Iu₁ = l₁ * t * (t_sq - t * (t₀ + t₂) / 2 + t₀ * t₂)
Iu₂ = l₂ * t * (t_sq - t * (t₀ + t₁) / 2 + t₀ * t₁)
return Iu₀ + Iu₁ + Iu₂
idx::Number, t1::Number, t2::Number)
α, β = get_parameters(A, idx)
uᵢ = A.u[idx]
tᵢ = A.t[idx]
t1_rel = t1 - tᵢ
t2_rel = t2 - tᵢ
Δt = t2 - t1
Δt ** (t2_rel^2 + t1_rel * t2_rel + t1_rel^2) / 3 + β * (t2_rel + t1_rel) / 2 + uᵢ)
end

function _integral(A::QuadraticSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number)
function _integral(
A::QuadraticSpline{<:AbstractVector{<:Number}}, idx::Number, t1::Number, t2::Number)
α, β = get_parameters(A, idx)
uᵢ = A.u[idx]
Δt = t - A.t[idx]
Δt_full = A.t[idx + 1] - A.t[idx]
Δt ** Δt^2 / (3Δt_full^2) + β * Δt / (2Δt_full) + uᵢ)
tᵢ = A.t[idx]
t1_rel = t1 - tᵢ
t2_rel = t2 - tᵢ
Δt = t2 - t1
Δt ** (t2_rel^2 + t1_rel * t2_rel + t1_rel^2) / 3 + β * (t2_rel + t1_rel) / 2 + uᵢ)
end

function _integral(A::CubicSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number)
Δt₁sq = (t - A.t[idx])^2 / 2
Δt₂sq = (A.t[idx + 1] - t)^2 / 2
II = (-A.z[idx] * Δt₂sq^2 + A.z[idx + 1] * Δt₁sq^2) / (6A.h[idx + 1])
function _integral(
A::CubicSpline{<:AbstractVector{<:Number}}, idx::Number, t1::Number, t2::Number)
tᵢ = A.t[idx]
tᵢ₊₁ = A.t[idx + 1]
c₁, c₂ = get_parameters(A, idx)
IC = c₁ * Δt₁sq
ID = -c₂ * Δt₂sq
II + IC + ID
integrate_cubic_polynomial(t1, t2, tᵢ, 0, c₁, 0, A.z[idx + 1] / (6A.h[idx + 1])) +
integrate_cubic_polynomial(t1, t2, tᵢ₊₁, 0, -c₂, 0, -A.z[idx] / (6A.h[idx + 1]))
end

function _integral(A::AkimaInterpolation{<:AbstractVector{<:Number}},
idx::Number,
t::Number)
t1 = A.t[idx]
A.u[idx] * (t - t1) + A.b[idx] * ((t - t1)^2 / 2) + A.c[idx] * ((t - t1)^3 / 3) +
A.d[idx] * ((t - t1)^4 / 4)
idx::Number, t1::Number, t2::Number)
integrate_cubic_polynomial(t1, t2, A.t[idx], A.u[idx], A.b[idx], A.c[idx], A.d[idx])
end

_integral(A::LagrangeInterpolation, idx::Number, t::Number) = throw(IntegralNotFoundError())
_integral(A::BSplineInterpolation, idx::Number, t::Number) = throw(IntegralNotFoundError())
_integral(A::BSplineApprox, idx::Number, t::Number) = throw(IntegralNotFoundError())
function _integral(A::LagrangeInterpolation, idx::Number, t1::Number, t2::Number)
throw(IntegralNotFoundError())
end
function _integral(A::BSplineInterpolation, idx::Number, t1::Number, t2::Number)
throw(IntegralNotFoundError())
end
function _integral(A::BSplineApprox, idx::Number, t1::Number, t2::Number)
throw(IntegralNotFoundError())
end

# Cubic Hermite Spline
function _integral(
A::CubicHermiteSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number)
Δt₀ = t - A.t[idx]
Δt₁ = t - A.t[idx + 1]
out = Δt₀ * (A.u[idx] + Δt₀ * A.du[idx] / 2)
A::CubicHermiteSpline{<:AbstractVector{<:Number}}, idx::Number, t1::Number, t2::Number)
c₁, c₂ = get_parameters(A, idx)
p = c₁ + Δt₁ * c₂
dp = c₂
out += Δt₀^3 / 3 * (p - dp * Δt₀ / 4)
out
tᵢ = A.t[idx]
tᵢ₊₁ = A.t[idx + 1]
c = c₁ - c₂ * (tᵢ₊₁ - tᵢ)
integrate_cubic_polynomial(t1, t2, tᵢ, A.u[idx], A.du[idx], c, c₂)
end

# Quintic Hermite Spline
function _integral(
A::QuinticHermiteSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number)
Δt₀ = t - A.t[idx]
Δt= t - A.t[idx + 1]
out = Δt₀ * (A.u[idx] + A.du[idx] * Δt₀ / 2 + A.ddu[idx] * Δt₀^2 / 6)
A::QuinticHermiteSpline{<:AbstractVector{<:Number}}, idx::Number, t1::Number, t2::Number)
tᵢ = A.t[idx]
tᵢ₊= A.t[idx + 1]
Δt = tᵢ₊₁ - tᵢ
c₁, c₂, c₃ = get_parameters(A, idx)
p = c₁ + c₂ * Δt₁ + c₃ * Δt₁^2
dp = c₂ + 2c₃ * Δt₁
ddp = 2c₃
out += Δt₀^4 / 4 * (p - Δt₀ / 5 * dp + Δt₀^2 / 30 * ddp)
out
integrate_quintic_polynomial(t1, t2, tᵢ, A.u[idx], A.du[idx], A.ddu[idx] / 2,
c₁ + Δt * (-c₂ + c₃ * Δt), c₂ - 2c₃ * Δt, c₃)
end
4 changes: 2 additions & 2 deletions src/interpolation_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ struct QuadraticInterpolation{uType, tType, IType, pType, T, N} <:
error("mode should be :Forward or :Backward for QuadraticInterpolation")
linear_lookup = seems_linear(assume_linear_t, t)
N = get_output_dim(u)
new{typeof(u), typeof(t), typeof(I), typeof(p.l₀), eltype(u), N}(
new{typeof(u), typeof(t), typeof(I), typeof(p.α), eltype(u), N}(
u, t, I, p, mode, extrapolate, Guesser(t), cache_parameters, linear_lookup)
end
end
Expand All @@ -93,7 +93,7 @@ function QuadraticInterpolation(
u, t, mode; extrapolate = false, cache_parameters = false, assume_linear_t = 1e-2)
u, t = munge_data(u, t)
linear_lookup = seems_linear(assume_linear_t, t)
p = QuadraticParameterCache(u, t, cache_parameters)
p = QuadraticParameterCache(u, t, cache_parameters, mode)
A = QuadraticInterpolation(
u, t, nothing, p, mode, extrapolate, cache_parameters, linear_lookup)
I = cumulative_integral(A, cache_parameters)
Expand Down
18 changes: 6 additions & 12 deletions src/interpolation_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,13 @@ function _interpolate(A::LinearInterpolation{<:AbstractArray}, t::Number, iguess
end

# Quadratic Interpolation
_quad_interp_indices(A, t) = _quad_interp_indices(A, t, firstindex(A.t) - 1)
function _quad_interp_indices(A::QuadraticInterpolation, t::Number, iguess)
idx = get_idx(A, t, iguess; idx_shift = A.mode == :Backward ? -1 : 0, ub_shift = -2)
idx, idx + 1, idx + 2
end

function _interpolate(A::QuadraticInterpolation, t::Number, iguess)
i₀, i₁, i₂ = _quad_interp_indices(A, t, iguess)
l₀, l₁, l₂ = get_parameters(A, i₀)
u₀ = l₀ * (t - A.t[i₁]) * (t - A.t[i₂])
u₁ = l₁ * (t - A.t[i₀]) * (t - A.t[i₂])
u₂ = l₂ * (t - A.t[i₀]) * (t - A.t[i₁])
return u₀ + u₁ + u₂
idx = get_idx(A, t, iguess)
Δt = t - A.t[idx]
α, β = get_parameters(A, idx)
out = A.u isa AbstractMatrix ? A.u[:, idx] : A.u[idx]
out += @. Δt * (α * Δt + β)
out
end

# Lagrange Interpolation
Expand Down
30 changes: 24 additions & 6 deletions src/interpolation_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,9 @@ function get_idx(A::AbstractInterpolation, t, iguess::Union{<:Integer, Guesser};
end

function cumulative_integral(A, cache_parameters)
if cache_parameters && hasmethod(_integral, Tuple{typeof(A), Number, Number})
integral_values = [_integral(A, idx, A.t[idx + 1]) - _integral(A, idx, A.t[idx])
for idx in 1:(length(A.t) - 1)]
pushfirst!(integral_values, zero(first(integral_values)))
if cache_parameters && hasmethod(_integral, Tuple{typeof(A), Number, Number, Number})
integral_values = _integral.(
Ref(A), 1:(length(A.t) - 1), A.t[1:(end - 1)], A.t[2:end])
cumsum(integral_values)
else
promote_type(eltype(A.u), eltype(A.t))[]
Expand All @@ -210,9 +209,9 @@ end

function get_parameters(A::QuadraticInterpolation, idx)
if A.cache_parameters
A.p.l₀[idx], A.p.l₁[idx], A.p.l₂[idx]
A.p.α[idx], A.p.β[idx]
else
quadratic_interpolation_parameters(A.u, A.t, idx)
quadratic_interpolation_parameters(A.u, A.t, idx, A.mode)
end
end

Expand Down Expand Up @@ -282,3 +281,22 @@ function du_PCHIP(u, t)

return _du.(eachindex(t))
end

function integrate_cubic_polynomial(t1, t2, offset, a, b, c, d)
t1_rel = t1 - offset
t2_rel = t2 - offset
t_sum = t1_rel + t2_rel
t_sq_sum = t1_rel^2 + t2_rel^2
Δt = t2 - t1
Δt * (a + t_sum * (b / 2 + d * t_sq_sum / 4) + c * (t_sq_sum + t1_rel * t2_rel) / 3)
end

function integrate_quintic_polynomial(t1, t2, offset, a, b, c, d, e, f)
t1_rel = t1 - offset
t2_rel = t2 - offset
t_sum = t1_rel + t2_rel
t_sq_sum = t1_rel^2 + t2_rel^2
Δt = t2 - t1
Δt * (a + t_sum * (b / 2 + d * t_sq_sum / 4) + c * (t_sq_sum + t1_rel * t2_rel) / 3) +
e * (t2_rel^5 - t1_rel^5) / 5 + f * (t2_rel^6 - t1_rel^6) / 6
end
Loading

0 comments on commit 96d017d

Please sign in to comment.