Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize find_intersections for closed intervals #203

Merged
merged 15 commits into from
Jun 14, 2023
52 changes: 52 additions & 0 deletions src/interval_sets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -521,3 +521,55 @@ function find_intersections_helper!(result, x, y, lt)

return unique!.(result)
end

function find_intersections(
x::AbstractVector{<:AbstractInterval{T1,Closed,Closed}},
y::AbstractVector{<:AbstractInterval{T2,Closed,Closed}},
) where {T1,T2}
omus marked this conversation as resolved.
Show resolved Hide resolved
# Strategy:
# two binary searches per interval `I` in `x`
# * identify the set of intervals in `y` that start during-or-after `I`
# * identify the set of intervals in `y` that stop before-or-during `I`
# * intersect them
starts = first.(y)
starts_perm = sortperm(starts)
starts_sorted = starts[starts_perm]

# Sneaky performance optimization (makes a huge difference!)
# Rather than sorting `stops` relative to `y`, we sort it relative to `starts`.
# This allows us to work in the `starts` frame of reference until the very end.
# In particular, when we intersect the sets of intervals obtained from starts and from stops,
# the `starts` set can be kept as a `UnitRange`, making the intersection *much* faster.
stops = last.(y[starts_perm])
stops_perm = sortperm(stops)
stops_sorted = stops[stops_perm]
len = length(stops_sorted)

results = Vector{Vector{Int}}(undef, length(x))
for (i, I) in enumerate(x)
# find all the starts which occur before or at the end of `I`
idx_first = searchsortedlast(starts_sorted, last(I); lt=(<))
if idx_first < 1
omus marked this conversation as resolved.
Show resolved Hide resolved
results[i] = Int[]
continue
end

# find all the stops which occur at or after the start of `I`
idx_last = searchsortedfirst(stops_sorted, first(I); lt=(<))
if idx_last > len
omus marked this conversation as resolved.
Show resolved Hide resolved
results[i] = Int[]
continue
end

# Working in "starts" frame of reference
starts_before_or_during = 1:idx_first
stops_during_or_after = @view stops_perm[idx_last:end]

# Intersect them
r = intersect(starts_before_or_during, stops_during_or_after)

# *Now* go back to y's sorting order, post-intersection.
results[i] = starts_perm[r]
end
return results
end