Skip to content

Commit

Permalink
Try to fix FlexiJoins with other predicates
Browse files Browse the repository at this point in the history
  • Loading branch information
asinghvi17 committed Jun 9, 2024
1 parent ec6090e commit dd2ea98
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 5 deletions.
24 changes: 19 additions & 5 deletions ext/GeometryOpsFlexiJoinsExt/GeometryOpsFlexiJoinsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@ using SortTileRecursiveTree, Tables
# This module defines the FlexiJoins APIs for GeometryOps' boolean comparison functions, taken from DE-9IM.

# First, we define the joining modes (Tree, NestedLoopFast) that the GO DE-9IM functions support.
const GO_DE9IM_FUNCS = Union{typeof(GO.contains), typeof(GO.within), typeof(GO.intersects), typeof(GO.disjoint), typeof(GO.touches), typeof(GO.crosses), typeof(GO.overlaps), typeof(GO.covers), typeof(GO.coveredby), typeof(GO.equals)}
const GO_DE9IM_DIRECT_FUNCS = ((GO.contains), (GO.within), (GO.intersects), (GO.disjoint), (!(GO.disjoint)), (GO.touches), (GO.crosses), (GO.overlaps), (GO.covers), (GO.coveredby), (GO.equals))
const GO_DE9IM_FUNC_TYPES = Union{typeof.(GO_DE9IM_DIRECT_FUNCS)..., typeof.((!).(GO_DE9IM_DIRECT_FUNCS))...}
# NestedLoopFast is the naive fallback method
FlexiJoins.supports_mode(::FlexiJoins.Mode.NestedLoopFast, ::FlexiJoins.ByPred{F}, datas) where F <: GO_DE9IM_FUNCS = true
FlexiJoins.supports_mode(::FlexiJoins.Mode.NestedLoopFast, ::FlexiJoins.ByPred{F}, datas) where F <: GO_DE9IM_FUNC_TYPES = true
# This method allows you to cache a tree, which we do by using an STRtree.
# TODO: wrap GO predicate functions in a `TreeJoiner` struct or something, to indicate that we want to use trees,
# since they can be slower in some situations.
FlexiJoins.supports_mode(::FlexiJoins.Mode.Tree, ::FlexiJoins.ByPred{F}, datas) where F <: GO_DE9IM_FUNCS = true
FlexiJoins.supports_mode(::FlexiJoins.Mode.Tree, ::FlexiJoins.ByPred{F}, datas) where F <: GO_DE9IM_FUNC_TYPES = true

# Nested loop support is simple, and needs no further support.
# However, for trees, we need to define how the tree is prepared and how it is used.
Expand All @@ -26,8 +27,8 @@ FlexiJoins.supports_mode(::FlexiJoins.Mode.Tree, ::FlexiJoins.ByPred{F}, datas)

# In theory, one could extract the tree from e.g a GeoPackage or some future GeoDataFrame.

FlexiJoins.prepare_for_join(::FlexiJoins.Mode.Tree, X, cond::FlexiJoins.ByPred{<: GO_DE9IM_FUNCS}) = (X, SortTileRecursiveTree.STRtree(map(cond.Rf, X)))
function FlexiJoins.findmatchix(::FlexiJoins.Mode.Tree, cond::FlexiJoins.ByPred{F}, ix_a, a, (B, tree)::Tuple, multi::typeof(identity)) where F <: GO_DE9IM_FUNCS
FlexiJoins.prepare_for_join(::FlexiJoins.Mode.Tree, X, cond::FlexiJoins.ByPred{<: GO_DE9IM_FUNC_TYPES}) = (X, SortTileRecursiveTree.STRtree(map(cond.Rf, X)))
function FlexiJoins.findmatchix(::FlexiJoins.Mode.Tree, cond::FlexiJoins.ByPred{F}, ix_a, a, (B, tree)::Tuple, multi::typeof(identity)) where F <: GO_DE9IM_FUNC_TYPES
idxs = SortTileRecursiveTree.query(tree, cond.Lf(a))
intersecting_idxs = filter!(idxs) do idx
cond.pred(a, cond.Rf(B[idx]))
Expand All @@ -42,6 +43,19 @@ FlexiJoins.swap_sides(::typeof(GO.within)) = GO.contains
FlexiJoins.swap_sides(::typeof(GO.coveredby)) = GO.covers
FlexiJoins.swap_sides(::typeof(GO.covers)) = GO.coveredby

FlexiJoins.swap_sides(::typeof(GO.intersects)) = !GO.disjoint
FlexiJoins.swap_sides(::typeof(!(GO.disjoint))) = GO.intersects
FlexiJoins.swap_sides(::typeof(GO.disjoint)) = !GO.intersects
FlexiJoins.swap_sides(::typeof(!(GO.intersects))) = GO.disjoint

FlexiJoins.swap_sides(::typeof(GO.touches)) = !GO.touches
FlexiJoins.swap_sides(::typeof(!(GO.touches))) = GO.touches

FlexiJoins.swap_sides(::typeof(GO.crosses)) = !GO.crosses
FlexiJoins.swap_sides(::typeof(!(GO.crosses))) = GO.crosses

FlexiJoins.swap_sides(::typeof(GO.equals)) = GO.equals

# That's a wrap, folks!

end
Expand Down
5 changes: 5 additions & 0 deletions test/extensions/flexijoins.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,8 @@ points_df = DataFrame(geometry = points)

end

@testset "All Predicates" begin
for func in [GO.contains, GO.within, GO.intersects, GO.disjoint, GO.touches, GO.overlaps, GO.covers, GO.coveredby, GO.equals]
@test_nowarn FlexiJoins.innerjoin((poly_df, points_df), by_pred(:geometry, func, :geometry))
end
end

0 comments on commit dd2ea98

Please sign in to comment.