diff --git a/src/_roots.jl b/src/_roots.jl index eeaf49b..aa784e2 100644 --- a/src/_roots.jl +++ b/src/_roots.jl @@ -39,10 +39,11 @@ end closeenough(a::T, b::T, ϵ) where {T<:AbstractFloat} = isapprox(a, b; rtol=0, atol=ϵ) closeenough(a::T, b::T, _) where {T<:Integer} = a == b -function refine_grid(grid, ::Val{α}) where {α} +function refine_grid(grid::Vector{T}, ::Val{α}) where {T, α} + isempty(grid) && return float(T)[] n = length(grid) newn = α * (n - 1) + 1 - newgrid = Vector{eltype(grid)}(undef, newn) + newgrid = Vector{float(T)}(undef, newn) @inbounds for i in 1:(n - 1) xb = grid[i] diff --git a/test/_roots.jl b/test/_roots.jl index 45ac513..453cc24 100644 --- a/test/_roots.jl +++ b/test/_roots.jl @@ -2,6 +2,83 @@ using Test using SparseIR @testset "_roots.jl" begin + @testset "refine_grid" begin + @testset "Basic refinement" begin + # Test with α = 2 + grid = [0.0, 1.0, 2.0] + refined = @inferred SparseIR.refine_grid(grid, Val(2)) + @test length(refined) == 5 # α * (n-1) + 1 = 2 * (3-1) + 1 = 5 + @test refined ≈ [0.0, 0.5, 1.0, 1.5, 2.0] + + # Test with α = 3 + refined3 = SparseIR.refine_grid(grid, Val(3)) + @test length(refined3) == 7 # α * (n-1) + 1 = 3 * (3-1) + 1 = 7 + @test refined3 ≈ [0.0, 1/3, 2/3, 1.0, 4/3, 5/3, 2.0] + + # Test with α = 4 + refined4 = SparseIR.refine_grid(grid, Val(4)) + @test length(refined4) == 9 # α * (n-1) + 1 = 4 * (3-1) + 1 = 9 + @test refined4 ≈ [0.0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0] + end + + #= + Currently the SparseIR.refine_grid function is used only for the + SparseIR.find_all function, which should accept only Float64 grids, but + just in case it is a good idea to test that it works for other types. + =# + @testset "Type stability" begin + # Integer grid + int_grid = [0, 1, 2] + refined = @inferred SparseIR.refine_grid(int_grid, Val(2)) + @test eltype(refined) === Float64 + @test refined == [0.0, 0.5, 1.0, 1.5, 2.0] + + # Float32 grid + f32_grid = Float32[0, 1, 2] + refined_f32 = @inferred SparseIR.refine_grid(f32_grid, Val(2)) + @test eltype(refined_f32) === Float32 + @test refined_f32 ≈ Float32[0, 0.5, 1, 1.5, 2] + end + + @testset "Edge cases" begin + # Single interval + single_interval = [0.0, 1.0] + refined_single = SparseIR.refine_grid(single_interval, Val(4)) + @test length(refined_single) == 5 # α * (2-1) + 1 = 4 * 1 + 1 = 5 + @test refined_single ≈ [0.0, 0.25, 0.5, 0.75, 1.0] + + # Empty grid + empty_grid = Float64[] + @test isempty(SparseIR.refine_grid(empty_grid, Val(2))) + # Empty grid + empty_grid = Int[] + out_grid = SparseIR.refine_grid(empty_grid, Val(2)) + @test isempty(out_grid) + @test eltype(out_grid) === Float64 + # Single point + single_point = [1.0] + @test SparseIR.refine_grid(single_point, Val(2)) == [1.0] + end + + @testset "Uneven spacing" begin + # Test with unevenly spaced grid + uneven = [0.0, 1.0, 10.0] + refined_uneven = SparseIR.refine_grid(uneven, Val(2)) + @test length(refined_uneven) == 5 + @test refined_uneven[1:3] ≈ [0.0, 0.5, 1.0] # First interval + @test refined_uneven[3:5] ≈ [1.0, 5.5, 10.0] # Second interval + end + + @testset "Preservation of endpoints" begin + grid = [-1.0, 0.0, 1.0] + for α in [2, 3, 4] + refined = SparseIR.refine_grid(grid, Val(α)) + @test first(refined) == first(grid) + @test last(refined) == last(grid) + end + end + end + @testset "discrete_extrema" begin nonnegative = collect(0:8) symmetric = collect(-8:8)