-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #70 from ArneBinder/span_utils
add span utils
- Loading branch information
Showing
2 changed files
with
215 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from typing import Tuple | ||
|
||
|
||
def are_nested(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> bool: | ||
"""Check if two spans are nested. The spans are defined by their start and end indices. | ||
Note that spans are considered to be nested if one is completely contained in the other, | ||
including the case where they are identical. | ||
""" | ||
return (start_end[0] <= other_start_end[0] and start_end[1] >= other_start_end[1]) or ( | ||
other_start_end[0] <= start_end[0] and other_start_end[1] >= start_end[1] | ||
) | ||
|
||
|
||
def have_overlap(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> bool: | ||
"""Check if two spans have an overlap. The spans are defined by their start and end indices. | ||
Note that two spans that are touching each other are not considered to have an overlap. But two | ||
spans that are nested, including the case where they are identical, are considered to have an | ||
overlap. | ||
""" | ||
|
||
other_start_overlaps = start_end[0] <= other_start_end[0] < start_end[1] | ||
other_end_overlaps = start_end[0] < other_start_end[1] <= start_end[1] | ||
start_overlaps_other = other_start_end[0] <= start_end[0] < other_start_end[1] | ||
end_overlaps_other = other_start_end[0] < start_end[1] <= other_start_end[1] | ||
return other_start_overlaps or other_end_overlaps or start_overlaps_other or end_overlaps_other | ||
|
||
|
||
def distance_center(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> float: | ||
"""Calculate the distance between the centers of two spans. | ||
The spans are defined by their start and end indices. | ||
""" | ||
center = (start_end[0] + start_end[1]) / 2 | ||
center_other = (other_start_end[0] + other_start_end[1]) / 2 | ||
return abs(center - center_other) | ||
|
||
|
||
def distance_outer(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> float: | ||
"""Calculate the distance between the outer edges of two spans. The spans are defined by their | ||
start and end indices. | ||
In case of an overlap, the covered area is considered to be the distance. | ||
""" | ||
_max = max(start_end[0], start_end[1], other_start_end[0], other_start_end[1]) | ||
_min = min(start_end[0], start_end[1], other_start_end[0], other_start_end[1]) | ||
return float(_max - _min) | ||
|
||
|
||
def distance_inner(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> float: | ||
"""Calculate the distance between the inner edges of two spans. The spans are defined by their | ||
start and end indices. | ||
In case of an overlap, the negative of the overlapping area is considered to be the distance. | ||
""" | ||
dist_start_other_end = abs(start_end[0] - other_start_end[1]) | ||
dist_end_other_start = abs(start_end[1] - other_start_end[0]) | ||
dist = float(min(dist_start_other_end, dist_end_other_start)) | ||
if not have_overlap(start_end, other_start_end): | ||
return dist | ||
else: | ||
return -dist | ||
|
||
|
||
def distance( | ||
start_end: Tuple[int, int], other_start_end: Tuple[int, int], distance_type: str | ||
) -> float: | ||
"""Calculate the distance between two spans based on the given distance type. | ||
Args: | ||
start_end: a tuple of two integers representing the start and end index of the first span | ||
other_start_end: a tuple of two integers representing the start and end index of the second span | ||
distance_type: the type of distance to calculate. One of: center, inner, outer | ||
""" | ||
if distance_type == "center": | ||
return distance_center(start_end, other_start_end) | ||
elif distance_type == "inner": | ||
return distance_inner(start_end, other_start_end) | ||
elif distance_type == "outer": | ||
return distance_outer(start_end, other_start_end) | ||
else: | ||
raise ValueError( | ||
f"unknown distance_type={distance_type}. use one of: center, inner, outer" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import pytest | ||
|
||
from pie_modules.utils.span import ( | ||
are_nested, | ||
distance, | ||
distance_center, | ||
distance_inner, | ||
distance_outer, | ||
have_overlap, | ||
) | ||
|
||
|
||
def test_have_overlap(): | ||
# no overlap, not touching | ||
assert not have_overlap((0, 1), (2, 3)) | ||
assert not have_overlap((2, 3), (0, 1)) | ||
# no overlap, touching | ||
assert not have_overlap((0, 1), (1, 2)) | ||
assert not have_overlap((1, 2), (0, 1)) | ||
# overlap, not touching | ||
assert have_overlap((0, 2), (1, 3)) | ||
assert have_overlap((1, 3), (0, 2)) | ||
# overlap, same start | ||
assert have_overlap((0, 2), (0, 3)) | ||
assert have_overlap((0, 3), (0, 2)) | ||
# overlap, same end | ||
assert have_overlap((0, 2), (1, 2)) | ||
assert have_overlap((1, 2), (0, 2)) | ||
# overlap, identical | ||
assert have_overlap((0, 1), (0, 1)) | ||
|
||
|
||
def test_are_nested(): | ||
# no overlap, not touching | ||
assert not are_nested((0, 1), (2, 3)) | ||
assert not are_nested((2, 3), (0, 1)) | ||
# no overlap, touching | ||
assert not are_nested((0, 1), (1, 2)) | ||
assert not are_nested((1, 2), (0, 1)) | ||
# overlap, not touching | ||
assert not are_nested((0, 2), (1, 3)) | ||
assert not are_nested((1, 3), (0, 2)) | ||
# overlap, same start | ||
assert are_nested((0, 2), (0, 3)) | ||
assert are_nested((0, 3), (0, 2)) | ||
# overlap, same end | ||
assert are_nested((0, 2), (1, 2)) | ||
assert are_nested((1, 2), (0, 2)) | ||
# overlap, identical | ||
assert are_nested((0, 1), (0, 1)) | ||
# nested, not touching | ||
assert are_nested((0, 3), (1, 2)) | ||
assert are_nested((1, 2), (0, 3)) | ||
|
||
|
||
def test_distance_center(): | ||
# no overlap, not touching | ||
assert distance_center((0, 1), (2, 3)) == 2.0 | ||
assert distance_center((2, 3), (0, 1)) == 2.0 | ||
# no overlap, touching | ||
assert distance_center((0, 1), (1, 2)) == 1.0 | ||
assert distance_center((1, 2), (0, 1)) == 1.0 | ||
# overlap, not touching | ||
assert distance_center((0, 2), (1, 3)) == 1.0 | ||
assert distance_center((1, 3), (0, 2)) == 1.0 | ||
# overlap, same start | ||
assert distance_center((0, 2), (0, 3)) == 0.5 | ||
assert distance_center((0, 3), (0, 2)) == 0.5 | ||
# overlap, same end | ||
assert distance_center((0, 2), (1, 2)) == 0.5 | ||
assert distance_center((1, 2), (0, 2)) == 0.5 | ||
# overlap, identical | ||
assert distance_center((0, 1), (0, 1)) == 0.0 | ||
|
||
|
||
def test_distance_inner(): | ||
# no overlap, not touching | ||
assert distance_inner((0, 1), (2, 3)) == 1.0 | ||
assert distance_inner((2, 3), (0, 1)) == 1.0 | ||
# no overlap, touching | ||
assert distance_inner((0, 1), (1, 2)) == 0.0 | ||
assert distance_inner((1, 2), (0, 1)) == 0.0 | ||
# overlap, not touching | ||
assert distance_inner((0, 2), (1, 3)) == -1.0 | ||
assert distance_inner((1, 3), (0, 2)) == -1.0 | ||
# overlap, same start | ||
assert distance_inner((0, 2), (0, 3)) == -2.0 | ||
assert distance_inner((0, 3), (0, 2)) == -2.0 | ||
# overlap, same end | ||
assert distance_inner((0, 2), (1, 2)) == -1.0 | ||
assert distance_inner((1, 2), (0, 2)) == -1.0 | ||
# overlap, identical | ||
assert distance_inner((0, 1), (0, 1)) == -1.0 | ||
|
||
|
||
def test_distance_outer(): | ||
# identical | ||
assert distance_outer((0, 1), (0, 1)) == 1.0 | ||
# no overlap, not touching | ||
assert distance_outer((0, 1), (2, 3)) == 3.0 | ||
assert distance_outer((2, 3), (0, 1)) == 3.0 | ||
# no overlap, touching | ||
assert distance_outer((0, 1), (1, 2)) == 2.0 | ||
assert distance_outer((1, 2), (0, 1)) == 2.0 | ||
# overlap, not touching | ||
assert distance_outer((0, 2), (1, 3)) == 3.0 | ||
assert distance_outer((1, 3), (0, 2)) == 3.0 | ||
# overlap, same start | ||
assert distance_outer((0, 2), (0, 3)) == 3.0 | ||
assert distance_outer((0, 3), (0, 2)) == 3.0 | ||
# overlap, same end | ||
assert distance_outer((0, 2), (1, 2)) == 2.0 | ||
assert distance_outer((1, 2), (0, 2)) == 2.0 | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"distance_type", | ||
["outer", "center", "inner", "unknown"], | ||
) | ||
def test_distance(distance_type): | ||
start_end = (0, 1) | ||
other_start_end = (2, 3) | ||
if distance_type != "unknown": | ||
distance(start_end, other_start_end, distance_type) | ||
else: | ||
with pytest.raises(ValueError) as excinfo: | ||
distance(start_end, other_start_end, distance_type) | ||
assert ( | ||
str(excinfo.value) == "unknown distance_type=unknown. use one of: center, inner, outer" | ||
) |