Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the SparseIR.refine_grid function #58

Merged
merged 4 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/_roots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ end
closeenough(a::T, b::T, ϵ) where {T<:AbstractFloat} = isapprox(a, b; rtol=0, atol=ϵ)
closeenough(a::T, b::T, _) where {T<:Integer} = a == b

function refine_grid(grid, ::Val{α}) where {α}
function refine_grid(grid::Vector{T}, ::Val{α}) where {T, α}
isempty(grid) && return float(T)[]
n = length(grid)
newn = α * (n - 1) + 1
newgrid = Vector{eltype(grid)}(undef, newn)
newgrid = Vector{float(T)}(undef, newn)

@inbounds for i in 1:(n - 1)
xb = grid[i]
Expand Down
77 changes: 77 additions & 0 deletions test/_roots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,83 @@ using Test
using SparseIR

@testset "_roots.jl" begin
@testset "refine_grid" begin
@testset "Basic refinement" begin
# Test with α = 2
grid = [0.0, 1.0, 2.0]
refined = @inferred SparseIR.refine_grid(grid, Val(2))
@test length(refined) == 5 # α * (n-1) + 1 = 2 * (3-1) + 1 = 5
@test refined ≈ [0.0, 0.5, 1.0, 1.5, 2.0]

# Test with α = 3
refined3 = SparseIR.refine_grid(grid, Val(3))
@test length(refined3) == 7 # α * (n-1) + 1 = 3 * (3-1) + 1 = 7
@test refined3 ≈ [0.0, 1/3, 2/3, 1.0, 4/3, 5/3, 2.0]

# Test with α = 4
refined4 = SparseIR.refine_grid(grid, Val(4))
@test length(refined4) == 9 # α * (n-1) + 1 = 4 * (3-1) + 1 = 9
@test refined4 ≈ [0.0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0]
end

#=
Currently the SparseIR.refine_grid function is used only for the
SparseIR.find_all function, which should accept only Float64 grids, but
just in case it is a good idea to test that it works for other types.
=#
@testset "Type stability" begin
# Integer grid
int_grid = [0, 1, 2]
refined = @inferred SparseIR.refine_grid(int_grid, Val(2))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've inserted @inferred macro:

refined = @inferred SparseIR.refine_grid(int_grid, Val(2))

@test eltype(refined) === Float64
@test refined == [0.0, 0.5, 1.0, 1.5, 2.0]

# Float32 grid
f32_grid = Float32[0, 1, 2]
refined_f32 = @inferred SparseIR.refine_grid(f32_grid, Val(2))
@test eltype(refined_f32) === Float32
@test refined_f32 ≈ Float32[0, 0.5, 1, 1.5, 2]
end

@testset "Edge cases" begin
# Single interval
single_interval = [0.0, 1.0]
refined_single = SparseIR.refine_grid(single_interval, Val(4))
@test length(refined_single) == 5 # α * (2-1) + 1 = 4 * 1 + 1 = 5
@test refined_single ≈ [0.0, 0.25, 0.5, 0.75, 1.0]

# Empty grid
empty_grid = Float64[]
@test isempty(SparseIR.refine_grid(empty_grid, Val(2)))
# Empty grid
empty_grid = Int[]
out_grid = SparseIR.refine_grid(empty_grid, Val(2))
@test isempty(out_grid)
@test eltype(out_grid) === Float64
# Single point
single_point = [1.0]
@test SparseIR.refine_grid(single_point, Val(2)) == [1.0]
end

@testset "Uneven spacing" begin
# Test with unevenly spaced grid
uneven = [0.0, 1.0, 10.0]
refined_uneven = SparseIR.refine_grid(uneven, Val(2))
@test length(refined_uneven) == 5
@test refined_uneven[1:3] ≈ [0.0, 0.5, 1.0] # First interval
@test refined_uneven[3:5] ≈ [1.0, 5.5, 10.0] # Second interval
end

@testset "Preservation of endpoints" begin
grid = [-1.0, 0.0, 1.0]
for α in [2, 3, 4]
refined = SparseIR.refine_grid(grid, Val(α))
@test first(refined) == first(grid)
@test last(refined) == last(grid)
end
end
end

@testset "discrete_extrema" begin
nonnegative = collect(0:8)
symmetric = collect(-8:8)
Expand Down
Loading