Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add get_image for all DocItem, specialize for FloatingItem #67

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions docling_core/types/doc/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,23 @@ def get_location_tokens(

return location

def get_image(self, doc: "DoclingDocument") -> Optional[PILImage.Image]:
"""Returns the image of this DocItem if the document stores page images."""
if not len(self.prov):
return None

page = doc.pages.get(self.prov[0].page_no)
if page is None or page.size is None or page.image is None:
return None

page_image = page.image.pil_image
crop_bbox = (
self.prov[0]
.bbox.to_top_left_origin(page_height=page.size.height)
.scaled(scale=page_image.height / page.size.height)
)
return page_image.crop(crop_bbox.as_tuple())


class TextItem(DocItem):
"""TextItem."""
Expand Down Expand Up @@ -633,6 +650,12 @@ def caption_text(self, doc: "DoclingDocument") -> str:
text += cap.resolve(doc).text
return text

def get_image(self, doc: "DoclingDocument") -> Optional[PILImage.Image]:
"""Returns the image corresponding to FloatingItem."""
if self.image is not None:
return self.image.pil_image
return super().get_image(doc=doc)


class PictureItem(FloatingItem):
"""PictureItem."""
Expand Down
128 changes: 128 additions & 0 deletions test/test_docling_doc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import deque
from unittest.mock import Mock

import pytest
import yaml
Expand All @@ -7,6 +8,7 @@

from docling_core.types.doc.document import (
CURRENT_VERSION,
BoundingBox,
DocItem,
DoclingDocument,
DocumentOrigin,
Expand All @@ -15,7 +17,9 @@
KeyValueItem,
ListItem,
PictureItem,
ProvenanceItem,
SectionHeaderItem,
Size,
TableCell,
TableData,
TableItem,
Expand Down Expand Up @@ -407,3 +411,127 @@ def test_version_doc():
comp_version = f"{major_split[0]}.{minor_split[0]}.{int(patch_split[0]) + 1}"
doc = DoclingDocument(name="Untitled 1", version=comp_version)
assert doc.version == CURRENT_VERSION


def test_docitem_get_image():
# Prepare the document
doc = DoclingDocument(name="Dummy")

page1_image = PILImage.new(mode="RGB", size=(200, 400), color=(0, 0, 0))
doc_item_image = PILImage.new(mode="RGB", size=(20, 40), color=(255, 0, 0))
page1_image.paste(doc_item_image, box=(20, 40))

doc.add_page( # With image
page_no=1,
size=Size(width=20, height=40),
image=ImageRef.from_pil(page1_image, dpi=72),
)
doc.add_page(page_no=2, size=Size(width=20, height=40), image=None) # Without image

# DocItem with no provenance
doc_item = DocItem(self_ref="#", label=DocItemLabel.TEXT, prov=[])
assert doc_item.get_image(doc=doc) is None

# DocItem on an invalid page
doc_item = DocItem(
self_ref="#",
label=DocItemLabel.TEXT,
prov=[ProvenanceItem(page_no=3, bbox=Mock(spec=BoundingBox), charspan=(1, 2))],
)
assert doc_item.get_image(doc=doc) is None

# DocItem on a page without page image
doc_item = DocItem(
self_ref="#",
label=DocItemLabel.TEXT,
prov=[ProvenanceItem(page_no=2, bbox=Mock(spec=BoundingBox), charspan=(1, 2))],
)
assert doc_item.get_image(doc=doc) is None

# DocItem on a page with valid page image
doc_item = DocItem(
self_ref="#",
label=DocItemLabel.TEXT,
prov=[
ProvenanceItem(
page_no=1, bbox=BoundingBox(l=2, t=4, r=4, b=8), charspan=(1, 2)
)
],
)
returned_doc_item_image = doc_item.get_image(doc=doc)
assert (
returned_doc_item_image is not None
and returned_doc_item_image.tobytes() == doc_item_image.tobytes()
)


def test_floatingitem_get_image():
# Prepare the document
doc = DoclingDocument(name="Dummy")

page1_image = PILImage.new(mode="RGB", size=(200, 400), color=(0, 0, 0))
floating_item_image = PILImage.new(mode="RGB", size=(20, 40), color=(255, 0, 0))
page1_image.paste(floating_item_image, box=(20, 40))

doc.add_page( # With image
page_no=1,
size=Size(width=20, height=40),
image=ImageRef.from_pil(page1_image, dpi=72),
)
doc.add_page(page_no=2, size=Size(width=20, height=40), image=None) # Without image

# FloatingItem with explicit image different from image based on provenance
new_image = PILImage.new(mode="RGB", size=(40, 80), color=(0, 255, 0))
floating_item = FloatingItem(
self_ref="#",
label=DocItemLabel.PICTURE,
prov=[
ProvenanceItem(
page_no=1, bbox=BoundingBox(l=2, t=4, r=6, b=12), charspan=(1, 2)
)
],
image=ImageRef.from_pil(image=new_image, dpi=72),
)
retured_image = floating_item.get_image(doc=doc)
assert retured_image is not None and retured_image.tobytes() == new_image.tobytes()

# FloatingItem without explicit image and no provenance
floating_item = FloatingItem(
self_ref="#", label=DocItemLabel.PICTURE, prov=[], image=None
)
assert floating_item.get_image(doc=doc) is None

# FloatingItem without explicit image on invalid page
floating_item = FloatingItem(
self_ref="#",
label=DocItemLabel.PICTURE,
prov=[ProvenanceItem(page_no=3, bbox=Mock(spec=BoundingBox), charspan=(1, 2))],
image=None,
)
assert floating_item.get_image(doc=doc) is None

# FloatingItem without explicit image on a page without page image
floating_item = FloatingItem(
self_ref="#",
label=DocItemLabel.PICTURE,
prov=[ProvenanceItem(page_no=2, bbox=Mock(spec=BoundingBox), charspan=(1, 2))],
image=None,
)
assert floating_item.get_image(doc=doc) is None

# FloatingItem without explicit image on a page with page image
floating_item = FloatingItem(
self_ref="#",
label=DocItemLabel.PICTURE,
prov=[
ProvenanceItem(
page_no=1, bbox=BoundingBox(l=2, t=4, r=4, b=8), charspan=(1, 2)
)
],
image=None,
)
retured_image = floating_item.get_image(doc=doc)
assert (
retured_image is not None
and retured_image.tobytes() == floating_item_image.tobytes()
)
Loading