diff --git a/src/interval_sets.jl b/src/interval_sets.jl index 4d915c3..68b106e 100644 --- a/src/interval_sets.jl +++ b/src/interval_sets.jl @@ -521,3 +521,55 @@ function find_intersections_helper!(result, x, y, lt) return unique!.(result) end + +function find_intersections( + x::AbstractVector{<:AbstractInterval{T1,Closed,Closed}}, + y::AbstractVector{<:AbstractInterval{T2,Closed,Closed}}, +) where {T1,T2} + # Strategy: + # two binary searches per interval `I` in `x` + # * identify the set of intervals in `y` that start during-or-after `I` + # * identify the set of intervals in `y` that stop before-or-during `I` + # * intersect them + starts = first.(y) + starts_perm = sortperm(starts) + starts_sorted = starts[starts_perm] + + # Sneaky performance optimization (makes a huge difference!) + # Rather than sorting `stops` relative to `y`, we sort it relative to `starts`. + # This allows us to work in the `starts` frame of reference until the very end. + # In particular, when we intersect the sets of intervals obtained from starts and from stops, + # the `starts` set can be kept as a `UnitRange`, making the intersection *much* faster. + stops = last.(y[starts_perm]) + stops_perm = sortperm(stops) + stops_sorted = stops[stops_perm] + len = length(stops_sorted) + + results = Vector{Vector{Int}}(undef, length(x)) + for (i, I) in enumerate(x) + # find all the starts which occur before or at the end of `I` + idx_first = searchsortedlast(starts_sorted, last(I); lt=(<)) + if idx_first < 1 + results[i] = Int[] + continue + end + + # find all the stops which occur at or after the start of `I` + idx_last = searchsortedfirst(stops_sorted, first(I); lt=(<)) + if idx_last > len + results[i] = Int[] + continue + end + + # Working in "starts" frame of reference + starts_before_or_during = 1:idx_first + stops_during_or_after = @view stops_perm[idx_last:end] + + # Intersect them + r = intersect(starts_before_or_during, stops_during_or_after) + + # *Now* go back to y's sorting order, post-intersection. + results[i] = starts_perm[r] + end + return results +end