Skip to content

Commit

Permalink
test: add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
christinestraub committed Aug 4, 2024
1 parent 051eca8 commit 06b7f43
Showing 1 changed file with 44 additions and 2 deletions.
46 changes: 44 additions & 2 deletions test_unstructured/partition/pdf_image/test_inference_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from unstructured_inference.inference.elements import TextRegion
from unstructured_inference.inference.elements import TextRegion, ImageTextRegion
from unstructured_inference.inference.layoutelement import LayoutElement

from unstructured.documents.elements import ElementType
from unstructured.partition.pdf_image.inference_utils import (
build_layout_elements_from_ocr_regions,
merge_text_regions,
merge_text_regions, merge_embedded_overlapping_regions,
)


Expand Down Expand Up @@ -35,3 +35,45 @@ def test_build_layout_elements_from_ocr_regions(mock_embedded_text_regions):

elements = build_layout_elements_from_ocr_regions(mock_embedded_text_regions)
assert elements == expected


def test_merge_embedded_overlapping_regions():
    """Check merging when two text regions overlap and the other inputs do not.

    Four regions go in; the test expects three back, with the combined
    "Hello World" region first and the untouched text region second.
    """
    # "Hello" and " World" share the 15..25 square; the remaining two regions
    # use coordinate ranges that do not intersect any other input.
    hello = TextRegion.from_coords(10, 10, 30, 30, "Hello")
    world = TextRegion.from_coords(15, 15, 25, 25, " World")
    standalone = TextRegion.from_coords(50, 50, 80, 80, "Not overlapping")
    image = ImageTextRegion.from_coords(90, 90, 100, 100)
    hello_world = TextRegion.from_coords(10, 10, 30, 30, "Hello World")

    result = merge_embedded_overlapping_regions([hello, world, standalone, image])

    assert len(result) == 3
    # First slot: the overlapping pair collapsed into a single region.
    assert result[0] == hello_world
    # Second slot: the text region that had nothing to merge with.
    assert result[1] == standalone


def test_merge_embedded_overlapping_regions_no_overlap():
    """Check that two text regions with disjoint coordinates are both returned.

    The test expects the same number of regions back, with their texts
    unchanged and in input order.
    """
    disjoint_regions = [
        TextRegion.from_coords(10, 10, 20, 20, "Text 1"),
        TextRegion.from_coords(30, 30, 40, 40, "Text 2"),
    ]

    result = merge_embedded_overlapping_regions(disjoint_regions)

    assert len(result) == 2
    assert result[0].text == "Text 1"
    assert result[1].text == "Text 2"


def test_merge_embedded_overlapping_regions_image_text_region_handling():
    """Check that two disjoint image regions come back as ImageTextRegion objects.

    The test expects two results, each still an ``ImageTextRegion`` instance.
    """
    image_a = ImageTextRegion.from_coords(0, 0, 10, 10, "Image A")
    image_b = ImageTextRegion.from_coords(20, 20, 30, 30, "Image B")

    result = merge_embedded_overlapping_regions([image_a, image_b])

    assert len(result) == 2
    assert isinstance(result[0], ImageTextRegion)
    assert isinstance(result[1], ImageTextRegion)

0 comments on commit 06b7f43

Please sign in to comment.