Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CubicSplines: Add AbstractMatrix support #343

Merged
merged 8 commits into from
Oct 14, 2024
34 changes: 34 additions & 0 deletions src/interpolation_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,40 @@ function CubicSpline(u::uType,
CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters, linear_lookup)
end

function CubicSpline(u::uType,
t;
extrapolate = false, cache_parameters = false,
assume_linear_t = 1e-2) where {uType <:
AbstractArray{T, N}} where {T, N}
u, t = munge_data(u, t)
n = length(t) - 1
h = vcat(0, map(k -> t[k + 1] - t[k], 1:(length(t) - 1)), 0)
dl = vcat(h[2:n], zero(eltype(h)))
d_tmp = 2 .* (h[1:(n + 1)] .+ h[2:(n + 2)])
du = vcat(zero(eltype(h)), h[3:(n + 1)])
tA = Tridiagonal(dl, d_tmp, du)

# zero for element type of d, which we don't know yet
ax = axes(u)[1:(end - 1)]
typed_zero = zero(6(u[ax..., begin + 2] - u[ax..., begin + 1]) / h[begin + 2] -
6(u[ax..., begin + 1] - u[ax..., begin]) / h[begin + 1])

h_ = reshape(h, ones(Int64, N - 1)..., :)
ax_h = axes(h_)[1:(end - 1)]
d = 6 * ((u[ax..., 3:(n + 1)] - u[ax..., 2:n]) ./ h_[ax_h..., 3:(n + 1)]) -
6 * ((u[ax..., 2:n] - u[ax..., 1:(n - 1)]) ./ h_[ax_h..., 2:n])
d = cat(typed_zero, d, typed_zero; dims = ndims(d))
d_reshaped = reshape(d, prod(size(d)[1:(end - 1)]), :)
z = (tA \ d_reshaped')'
z = reshape(z, size(u)...)
linear_lookup = seems_linear(assume_linear_t, t)
p = CubicSplineParameterCache(u, h, z, cache_parameters)
A = CubicSpline(
u, t, nothing, p, h[1:(n + 1)], z, extrapolate, cache_parameters, linear_lookup)
I = cumulative_integral(A, cache_parameters)
CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, cache_parameters, linear_lookup)
end

function CubicSpline(
u::uType, t; extrapolate = false, cache_parameters = false,
assume_linear_t = 1e-2) where {uType <:
Expand Down
12 changes: 12 additions & 0 deletions src/interpolation_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,18 @@ function _interpolate(A::CubicSpline{<:AbstractVector}, t::Number, iguess)
I + C + D
end

function _interpolate(A::CubicSpline{<:AbstractArray{T, N}}, t::Number, iguess) where {T, N}
idx = get_idx(A, t, iguess)
Δt₁ = t - A.t[idx]
Δt₂ = A.t[idx + 1] - t
ax = axes(A.z)[1:(end - 1)]
I = (A.z[ax..., idx] * Δt₂^3 + A.z[ax..., idx + 1] * Δt₁^3) / (6A.h[idx + 1])
c₁, c₂ = get_parameters(A, idx)
C = c₁ * Δt₁
D = c₂ * Δt₂
I + C + D
end

# BSpline Curve Interpolation
function _interpolate(A::BSplineInterpolation{<:AbstractVector{<:Number}},
t::Number,
Expand Down
9 changes: 8 additions & 1 deletion src/parameter_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,19 @@ function CubicSplineParameterCache(u, h, z, cache_parameters)
end
end

function cubic_spline_parameters(u, h, z, idx)
function cubic_spline_parameters(u::AbstractVector, h, z, idx)
c₁ = (u[idx + 1] / h[idx + 1] - z[idx + 1] * h[idx + 1] / 6)
c₂ = (u[idx] / h[idx + 1] - z[idx] * h[idx + 1] / 6)
return c₁, c₂
end

function cubic_spline_parameters(u::AbstractArray, h, z, idx)
ax = axes(u)[1:(end - 1)]
c₁ = (u[ax..., idx + 1] / h[idx + 1] - z[ax..., idx + 1] * h[idx + 1] / 6)
c₂ = (u[ax..., idx] / h[idx + 1] - z[ax..., idx] * h[idx + 1] / 6)
return c₁, c₂
end

struct CubicHermiteParameterCache{pType}
c₁::pType
c₂::pType
Expand Down
21 changes: 21 additions & 0 deletions test/interpolation_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,27 @@ end
A = CubicSpline(u, t)
@test_throws DataInterpolations.ExtrapolationError A(-2.0)
@test_throws DataInterpolations.ExtrapolationError A(2.0)

@testset "AbstractMatrix" begin
t = 0.1:0.1:1.0
u = [sin.(t) cos.(t)]' |> collect
c = CubicSpline(u, t)
t_test = 0.1:0.05:1.0
u_test = reduce(hcat, c.(t_test))
@test isapprox(u_test[1, :], sin.(t_test), atol = 1e-3)
@test isapprox(u_test[2, :], cos.(t_test), atol = 1e-3)
end
@testset "AbstractArray{T, 3}" begin
f3d(t) = [sin(t) cos(t);
0.0 cos(2t)]
t = 0.1:0.1:1.0
u3d = f3d.(t)
c = CubicSpline(u3d, t)
t_test = 0.1:0.05:1.0
u_test = reduce(hcat, c.(t_test))
f_test = reduce(hcat, f3d.(t_test))
@test isapprox(u_test, f_test, atol = 1e-2)
end
end

@testset "BSplines" begin
Expand Down
Loading