Skip to content

Commit

Permalink
Merge pull request #70 from ArneBinder/span_utils
Browse files Browse the repository at this point in the history
add span utils
  • Loading branch information
ArneBinder authored Feb 27, 2024
2 parents a357989 + 2ace3a1 commit cd1cb7e
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 0 deletions.
85 changes: 85 additions & 0 deletions src/pie_modules/utils/span.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from typing import Tuple


def are_nested(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> bool:
"""Check if two spans are nested. The spans are defined by their start and end indices.
Note that spans are considered to be nested if one is completely contained in the other,
including the case where they are identical.
"""
return (start_end[0] <= other_start_end[0] and start_end[1] >= other_start_end[1]) or (
other_start_end[0] <= start_end[0] and other_start_end[1] >= start_end[1]
)


def have_overlap(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> bool:
"""Check if two spans have an overlap. The spans are defined by their start and end indices.
Note that two spans that are touching each other are not considered to have an overlap. But two
spans that are nested, including the case where they are identical, are considered to have an
overlap.
"""

other_start_overlaps = start_end[0] <= other_start_end[0] < start_end[1]
other_end_overlaps = start_end[0] < other_start_end[1] <= start_end[1]
start_overlaps_other = other_start_end[0] <= start_end[0] < other_start_end[1]
end_overlaps_other = other_start_end[0] < start_end[1] <= other_start_end[1]
return other_start_overlaps or other_end_overlaps or start_overlaps_other or end_overlaps_other


def distance_center(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> float:
"""Calculate the distance between the centers of two spans.
The spans are defined by their start and end indices.
"""
center = (start_end[0] + start_end[1]) / 2
center_other = (other_start_end[0] + other_start_end[1]) / 2
return abs(center - center_other)


def distance_outer(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> float:
"""Calculate the distance between the outer edges of two spans. The spans are defined by their
start and end indices.
In case of an overlap, the covered area is considered to be the distance.
"""
_max = max(start_end[0], start_end[1], other_start_end[0], other_start_end[1])
_min = min(start_end[0], start_end[1], other_start_end[0], other_start_end[1])
return float(_max - _min)


def distance_inner(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> float:
"""Calculate the distance between the inner edges of two spans. The spans are defined by their
start and end indices.
In case of an overlap, the negative of the overlapping area is considered to be the distance.
"""
dist_start_other_end = abs(start_end[0] - other_start_end[1])
dist_end_other_start = abs(start_end[1] - other_start_end[0])
dist = float(min(dist_start_other_end, dist_end_other_start))
if not have_overlap(start_end, other_start_end):
return dist
else:
return -dist


def distance(
start_end: Tuple[int, int], other_start_end: Tuple[int, int], distance_type: str
) -> float:
"""Calculate the distance between two spans based on the given distance type.
Args:
start_end: a tuple of two integers representing the start and end index of the first span
other_start_end: a tuple of two integers representing the start and end index of the second span
distance_type: the type of distance to calculate. One of: center, inner, outer
"""
if distance_type == "center":
return distance_center(start_end, other_start_end)
elif distance_type == "inner":
return distance_inner(start_end, other_start_end)
elif distance_type == "outer":
return distance_outer(start_end, other_start_end)
else:
raise ValueError(
f"unknown distance_type={distance_type}. use one of: center, inner, outer"
)
130 changes: 130 additions & 0 deletions tests/utils/test_span.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import pytest

from pie_modules.utils.span import (
are_nested,
distance,
distance_center,
distance_inner,
distance_outer,
have_overlap,
)


def test_have_overlap():
# no overlap, not touching
assert not have_overlap((0, 1), (2, 3))
assert not have_overlap((2, 3), (0, 1))
# no overlap, touching
assert not have_overlap((0, 1), (1, 2))
assert not have_overlap((1, 2), (0, 1))
# overlap, not touching
assert have_overlap((0, 2), (1, 3))
assert have_overlap((1, 3), (0, 2))
# overlap, same start
assert have_overlap((0, 2), (0, 3))
assert have_overlap((0, 3), (0, 2))
# overlap, same end
assert have_overlap((0, 2), (1, 2))
assert have_overlap((1, 2), (0, 2))
# overlap, identical
assert have_overlap((0, 1), (0, 1))


def test_are_nested():
# no overlap, not touching
assert not are_nested((0, 1), (2, 3))
assert not are_nested((2, 3), (0, 1))
# no overlap, touching
assert not are_nested((0, 1), (1, 2))
assert not are_nested((1, 2), (0, 1))
# overlap, not touching
assert not are_nested((0, 2), (1, 3))
assert not are_nested((1, 3), (0, 2))
# overlap, same start
assert are_nested((0, 2), (0, 3))
assert are_nested((0, 3), (0, 2))
# overlap, same end
assert are_nested((0, 2), (1, 2))
assert are_nested((1, 2), (0, 2))
# overlap, identical
assert are_nested((0, 1), (0, 1))
# nested, not touching
assert are_nested((0, 3), (1, 2))
assert are_nested((1, 2), (0, 3))


def test_distance_center():
# no overlap, not touching
assert distance_center((0, 1), (2, 3)) == 2.0
assert distance_center((2, 3), (0, 1)) == 2.0
# no overlap, touching
assert distance_center((0, 1), (1, 2)) == 1.0
assert distance_center((1, 2), (0, 1)) == 1.0
# overlap, not touching
assert distance_center((0, 2), (1, 3)) == 1.0
assert distance_center((1, 3), (0, 2)) == 1.0
# overlap, same start
assert distance_center((0, 2), (0, 3)) == 0.5
assert distance_center((0, 3), (0, 2)) == 0.5
# overlap, same end
assert distance_center((0, 2), (1, 2)) == 0.5
assert distance_center((1, 2), (0, 2)) == 0.5
# overlap, identical
assert distance_center((0, 1), (0, 1)) == 0.0


def test_distance_inner():
# no overlap, not touching
assert distance_inner((0, 1), (2, 3)) == 1.0
assert distance_inner((2, 3), (0, 1)) == 1.0
# no overlap, touching
assert distance_inner((0, 1), (1, 2)) == 0.0
assert distance_inner((1, 2), (0, 1)) == 0.0
# overlap, not touching
assert distance_inner((0, 2), (1, 3)) == -1.0
assert distance_inner((1, 3), (0, 2)) == -1.0
# overlap, same start
assert distance_inner((0, 2), (0, 3)) == -2.0
assert distance_inner((0, 3), (0, 2)) == -2.0
# overlap, same end
assert distance_inner((0, 2), (1, 2)) == -1.0
assert distance_inner((1, 2), (0, 2)) == -1.0
# overlap, identical
assert distance_inner((0, 1), (0, 1)) == -1.0


def test_distance_outer():
# identical
assert distance_outer((0, 1), (0, 1)) == 1.0
# no overlap, not touching
assert distance_outer((0, 1), (2, 3)) == 3.0
assert distance_outer((2, 3), (0, 1)) == 3.0
# no overlap, touching
assert distance_outer((0, 1), (1, 2)) == 2.0
assert distance_outer((1, 2), (0, 1)) == 2.0
# overlap, not touching
assert distance_outer((0, 2), (1, 3)) == 3.0
assert distance_outer((1, 3), (0, 2)) == 3.0
# overlap, same start
assert distance_outer((0, 2), (0, 3)) == 3.0
assert distance_outer((0, 3), (0, 2)) == 3.0
# overlap, same end
assert distance_outer((0, 2), (1, 2)) == 2.0
assert distance_outer((1, 2), (0, 2)) == 2.0


@pytest.mark.parametrize(
"distance_type",
["outer", "center", "inner", "unknown"],
)
def test_distance(distance_type):
start_end = (0, 1)
other_start_end = (2, 3)
if distance_type != "unknown":
distance(start_end, other_start_end, distance_type)
else:
with pytest.raises(ValueError) as excinfo:
distance(start_end, other_start_end, distance_type)
assert (
str(excinfo.value) == "unknown distance_type=unknown. use one of: center, inner, outer"
)

0 comments on commit cd1cb7e

Please sign in to comment.