Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PNG Bug #83

Merged
merged 3 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added src/cookbooks/images.ipynb
Empty file.
116 changes: 113 additions & 3 deletions src/openparse/processing/basic_transforms.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
import base64
import io
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Dict, List, Literal
from typing import Dict, List, Literal, Type, TypeVar

from openparse.schemas import Bbox, Node, TextElement
from PIL import Image

from openparse.schemas import Bbox, ImageElement, Node, TableElement, TextElement

E = TypeVar("E", TextElement, ImageElement, TableElement)


def get_elements_of_type(nodes: List[Node], element_type: Type[E]) -> List[E]:
    """Collect every element of ``element_type`` found across ``nodes``.

    Nodes are visited in the order given, and each node's ``elements`` in
    their own order, so the result keeps that ordering. Matching is done with
    ``isinstance``, so instances of subclasses are included as well.

    Args:
        nodes: Nodes whose ``elements`` are scanned.
        element_type: The element class to keep.

    Returns:
        A new list with the matching elements; empty when nothing matches.
    """
    return [
        candidate
        for node in nodes
        for candidate in node.elements
        if isinstance(candidate, element_type)
    ]


class ProcessingStep(ABC):
Expand All @@ -14,6 +29,96 @@ def process(self, nodes: List[Node]) -> List[Node]:
raise NotImplementedError("Subclasses must implement this method.")


class CombineSlicedImages(ProcessingStep):
    """
    PDFs may store one large image as several smaller slices. This step
    stitches slices that overlap or touch back into a single image.

    Only image-only nodes (``node.variant == {"image"}``) are rewritten; every
    other node is passed through unchanged.
    """

    def _combine_images_in_group(
        self, image_elements: List[ImageElement]
    ) -> ImageElement:
        """Stack the given images vertically and return them as one PNG element.

        Each element's ``image`` is base64-decoded and opened with PIL. The
        images are pasted left-aligned, one below the other in list order,
        onto an RGB canvas as wide as the widest image and as tall as the sum
        of all heights.

        Args:
            image_elements: Slices to merge, in top-to-bottom paste order.

        Returns:
            A new ``ImageElement`` holding the base64-encoded PNG, with
            ``image_mimetype`` set to ``"image/png"`` and empty ``text``.

        Raises:
            ValueError: If ``image_elements`` is empty.
        """
        if not image_elements:
            raise ValueError("No images to combine.")

        images = [
            Image.open(io.BytesIO(base64.b64decode(element.image)))
            for element in image_elements
        ]

        # Canvas must fit the widest slice and all slices stacked.
        width = max(img.width for img in images)
        total_height = sum(img.height for img in images)
        new_image = Image.new("RGB", (width, total_height))

        y_offset = 0
        for img in images:
            new_image.paste(img, (0, y_offset))
            y_offset += img.height

        buffered = io.BytesIO()
        new_image.save(buffered, format="PNG")
        final_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return ImageElement(
            # NOTE(review): the first slice's bbox is reused as-is; it is not
            # grown to cover the other slices. Confirm whether a union bbox
            # is wanted here.
            bbox=image_elements[0].bbox,
            image=final_base64,
            image_mimetype="image/png",
            text="",
        )

    def _group_overlapping_images(
        self, image_elements: List[ImageElement], buffer: float = 1.0
    ) -> List[List[ImageElement]]:
        """Partition images into groups connected by ``ImageElement.overlaps``.

        Two images land in the same group when they overlap directly or are
        linked through a chain of overlapping images.

        Args:
            image_elements: Images to partition.
            buffer: Tolerance forwarded to ``ImageElement.overlaps``.

        Returns:
            A list of groups; every input image appears in exactly one group.
        """
        groups: List[List[ImageElement]] = []
        used = set()

        for i, elem1 in enumerate(image_elements):
            if i in used:
                continue
            group = [elem1]
            used.add(i)
            # Grow the group outwards from each newly added member.
            queue = [elem1]
            while queue:
                current = queue.pop()
                for j, elem2 in enumerate(image_elements):
                    if j in used:
                        continue
                    if current.overlaps(elem2, buffer=buffer):
                        group.append(elem2)
                        used.add(j)
                        queue.append(elem2)
            groups.append(group)
        return groups

    def process(self, nodes: List[Node]) -> List[Node]:
        """Merge sliced images page by page, leaving all other nodes intact.

        For each page that has image-only nodes, their image elements are
        grouped by overlap and each group is replaced by a single node holding
        the combined image. The page's remaining nodes (text, tables, mixed)
        follow unchanged. Pages without image-only nodes are passed through.

        Args:
            nodes: Nodes to process.

        Returns:
            The processed nodes, grouped by page in first-seen page order.
        """
        nodes_by_page: Dict[int, List[Node]] = defaultdict(list)
        for node in nodes:
            pages = {element.bbox.page for element in node.elements}
            if not pages:
                # A node without elements has no page to be filed under.
                continue
            # File a multi-page node under its first page only, so it is not
            # emitted once per page it touches.
            nodes_by_page[min(pages)].append(node)

        new_nodes: List[Node] = []
        for page_nodes in nodes_by_page.values():
            image_nodes = [n for n in page_nodes if n.variant == {"image"}]
            if not image_nodes:
                new_nodes.extend(page_nodes)
                continue

            image_elements = get_elements_of_type(image_nodes, ImageElement)
            # Only stitch images that actually touch; unrelated images on the
            # same page stay separate.
            for group in self._group_overlapping_images(image_elements):
                combined_image = self._combine_images_in_group(group)
                new_nodes.append(Node(elements=(combined_image,)))

            # Keep text, table and mixed nodes exactly as they were.
            new_nodes.extend(n for n in page_nodes if n.variant != {"image"})
        return new_nodes


class RemoveTextInsideTables(ProcessingStep):
"""
If we're using the table extraction pipeline, we need to remove text that is inside tables to avoid duplication.
Expand Down Expand Up @@ -162,7 +267,12 @@ def __init__(self, min_tokens: int):
self.min_tokens = min_tokens

def process(self, nodes: List[Node]) -> List[Node]:
    """Drop nodes that have fewer than ``self.min_tokens`` tokens.

    A node is kept when ``node.tokens >= self.min_tokens``, so a node with
    exactly ``min_tokens`` tokens survives. Nodes whose ``variant`` contains
    ``"image"`` are always kept, whatever their token count.

    Args:
        nodes: Nodes to filter.

    Returns:
        The surviving nodes, in their original order.
    """
    return [
        node
        for node in nodes
        if node.tokens >= self.min_tokens or "image" in node.variant
    ]


class CombineNodesSpatially(ProcessingStep):
Expand Down
3 changes: 3 additions & 0 deletions src/openparse/processing/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
CombineBullets,
CombineHeadingsWithClosestText,
CombineNodesSpatially,
CombineSlicedImages,
ProcessingStep,
RemoveFullPageStubs,
RemoveMetadataElements,
Expand Down Expand Up @@ -69,6 +70,7 @@ class BasicIngestionPipeline(IngestionPipeline):
def __init__(self):
self.transformations = [
RemoveTextInsideTables(),
CombineSlicedImages(),
RemoveFullPageStubs(max_area_pct=0.35),
# mostly aimed at combining bullets and weird formatting
CombineNodesSpatially(
Expand Down Expand Up @@ -106,6 +108,7 @@ def __init__(

self.transformations = [
RemoveTextInsideTables(),
CombineSlicedImages(),
RemoveFullPageStubs(max_area_pct=0.35),
# mostly aimed at combining bullets and weird formatting
CombineNodesSpatially(
Expand Down
29 changes: 26 additions & 3 deletions src/openparse/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,6 @@ class ImageElement(BaseModel):
def embed_text(self) -> str:
    """Return ``_embed_text`` when it is truthy, otherwise fall back to ``text``."""
    return self._embed_text or self.text

@cached_property
Expand All @@ -381,9 +380,20 @@ def is_at_similar_height(
error_margin: float = 1,
) -> bool:
y_distance = abs(self.bbox.y1 - other.bbox.y1)

return y_distance <= error_margin

def overlaps(self, other: "ImageElement", buffer: float = 1.0) -> bool:
    """Tell whether this image overlaps, or lies close to, ``other``.

    Images on different pages never overlap. On the same page, the two
    boxes count as separated only when a gap remains between them after
    ``buffer`` has been added to one box and subtracted from the other;
    the tolerated gap along an axis is therefore ``2 * buffer``.

    Args:
        other: The image to compare against.
        buffer: Margin applied to both boxes before comparing.

    Returns:
        ``True`` when the boxes are not separated on either axis.
    """
    mine = self.bbox
    theirs = other.bbox
    if mine.page != theirs.page:
        return False

    apart_horizontally = (
        mine.x1 + buffer < theirs.x0 - buffer
        or mine.x0 - buffer > theirs.x1 + buffer
    )
    apart_vertically = (
        mine.y1 + buffer < theirs.y0 - buffer
        or mine.y0 - buffer > theirs.y1 + buffer
    )
    return not (apart_horizontally or apart_vertically)


#############
### NODES ###
Expand Down Expand Up @@ -641,7 +651,20 @@ def _repr_markdown_(self):
"""
When called in a Jupyter environment, this will display the node as Markdown, which Jupyter will then render as HTML.
"""
return self.text
markdown_parts = []
for element in self.elements:
if element.variant == NodeVariant.TEXT:
markdown_parts.append(element.text)
elif element.variant == NodeVariant.IMAGE:
image_data = element.image
mime_type = element.image_mimetype
if mime_type == "unknown":
mime_type = "image/png"
markdown_image = f"![Image](data:{mime_type};base64,{image_data})"
markdown_parts.append(markdown_image)
elif element.variant == NodeVariant.TABLE:
markdown_parts.append(element.text)
return "\n\n".join(markdown_parts)

def __add__(self, other: "Node") -> "Node":
"""
Expand Down
Loading