Skip to content

Commit

Permalink
Initial chunk JSON output
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Nov 19, 2024
1 parent f89089c commit 59b6224
Show file tree
Hide file tree
Showing 14 changed files with 199 additions and 44 deletions.
25 changes: 19 additions & 6 deletions marker/v2/converters/pdf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from marker.v2.providers.pdf import PdfProvider
import os

from marker.v2.renderers.json import JSONRenderer

os.environ["TOKENIZERS_PARALLELISM"] = "false" # disables a tokenizers warning

import tempfile
Expand Down Expand Up @@ -30,7 +32,7 @@
class PdfConverter(BaseConverter):
override_map: Dict[BlockTypes, Type[Block]] = defaultdict()

def __init__(self, config=None):
def __init__(self, config=None, output_format="markdown"):
super().__init__(config)

for block_type, override_block_type in self.override_map.items():
Expand All @@ -42,6 +44,11 @@ def __init__(self, config=None):
self.table_rec_model = setup_table_rec_model()
self.detection_model = setup_detection_model()

if output_format == "markdown":
self.renderer = MarkdownRenderer(self.config)
elif output_format == "json":
self.renderer = JSONRenderer(self.config)

def __call__(self, filepath: str):
pdf_provider = PdfProvider(filepath, self.config)

Expand All @@ -62,18 +69,18 @@ def __call__(self, filepath: str):
debug_processor = DebugProcessor(self.config)
debug_processor(document)

renderer = MarkdownRenderer(self.config)
return renderer(document)
return self.renderer(document)


@click.command()
@click.option("--output", type=click.Path(exists=False), required=False, default="temp")
@click.option("--fname", type=str, default="adversarial.pdf")
@click.option("--debug", is_flag=True)
def main(output: str, fname: str, debug: bool):
@click.option("--output_format", type=click.Choice(["markdown", "json"]), default="markdown")
def main(output: str, fname: str, debug: bool, output_format: str):
dataset = datasets.load_dataset("datalab-to/pdfs", split="train")
idx = dataset['filename'].index(fname)
out_filename = fname.rsplit(".", 1)[0] + ".md"
fname_base = fname.rsplit(".", 1)[0]
os.makedirs(output, exist_ok=True)

config = {}
Expand All @@ -86,14 +93,20 @@ def main(output: str, fname: str, debug: bool):
temp_pdf.write(dataset['pdf'][idx])
temp_pdf.flush()

converter = PdfConverter()
converter = PdfConverter(config=config, output_format=output_format)
rendered = converter(temp_pdf.name)

if output_format == "markdown":
out_filename = f"{fname_base}.md"
with open(os.path.join(output, out_filename), "w+") as f:
f.write(rendered.markdown)

for img_name, img in rendered.images.items():
img.save(os.path.join(output, img_name), "PNG")
elif output_format == "json":
out_filename = f"{fname_base}.json"
with open(os.path.join(output, out_filename), "w+") as f:
f.write(rendered.model_dump_json(indent=2))


if __name__ == "__main__":
Expand Down
28 changes: 28 additions & 0 deletions marker/v2/renderers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import Optional

from pydantic import BaseModel
Expand All @@ -15,3 +16,30 @@ def __init__(self, config: Optional[BaseModel | dict] = None):
def __call__(self, document):
# Children are in reading order
raise NotImplementedError

@staticmethod
def extract_image(document, image_id):
    """Crop the area covered by a block out of its page's high-resolution image.

    The block is looked up via ``document.get_block(image_id)`` and its page
    via ``document.get_page(block.page_id)``. The block polygon is rescaled
    from the page polygon's size to the size of ``page.highres_image`` before
    cropping, and the cropped image is returned.
    """
    block = document.get_block(image_id)
    page = document.get_page(block.page_id)
    highres = page.highres_image
    # Map the block outline from page-polygon space onto the high-res image.
    scaled_polygon = block.polygon.rescale(page.polygon.size, highres.size)
    return highres.crop(scaled_polygon.bbox)

@staticmethod
def merge_consecutive_tags(html, tag):
    """Fuse back-to-back ``<tag>`` elements into a single element.

    Every ``</tag>`` that is followed (optionally after whitespace) by
    ``<tag>`` is removed together with that opening tag; the whitespace
    between them is kept. Falsy input is returned unchanged.
    """
    if not html:
        return html

    def keep_whitespace(match):
        # Only the captured whitespace between the two tags survives.
        return match.group(1)

    joint = re.compile(fr'</{tag}>(\s*)<{tag}>')

    # One substitution pass can bring another closing/opening pair together,
    # so keep substituting until the text stops changing.
    previous = None
    while previous != html:
        previous = html
        html = joint.sub(keep_whitespace, html)

    return html
22 changes: 2 additions & 20 deletions marker/v2/renderers/html.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,6 @@ class HTMLOutput(BaseModel):
images: dict


def merge_consecutive_tags(html, tag):
    """Collapse adjacent ``<tag>...</tag><tag>...</tag>`` runs into one element.

    A closing ``</tag>`` directly followed by an opening ``<tag>`` (with only
    whitespace in between) is replaced by that whitespace. Falsy input is
    returned as-is.
    """
    if not html:
        return html

    pattern = fr'</{tag}>(\s*)<{tag}>'

    merged = re.sub(pattern, lambda match: match.group(1), html)
    # Removing one pair may make a new pair adjacent; iterate to a fixed point.
    while merged != html:
        html = merged
        merged = re.sub(pattern, lambda match: match.group(1), html)

    return html


class HTMLRenderer(BaseRenderer):
remove_blocks: list = [BlockTypes.PageHeader, BlockTypes.PageFooter]
image_blocks: list = [BlockTypes.Picture, BlockTypes.Figure]
Expand Down Expand Up @@ -82,8 +64,8 @@ def extract_html(self, document, document_output, level=0):

output = str(soup)
if level == 0:
output = merge_consecutive_tags(output, 'b')
output = merge_consecutive_tags(output, 'i')
output = self.merge_consecutive_tags(output, 'b')
output = self.merge_consecutive_tags(output, 'i')

return output, images

Expand Down
103 changes: 103 additions & 0 deletions marker/v2/renderers/json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from __future__ import annotations

import base64
import io
from typing import List, Dict

from bs4 import BeautifulSoup
from pydantic import BaseModel

from marker.v2.schema.blocks import Block
from marker.v2.renderers import BaseRenderer
from marker.v2.schema import BlockTypes
from marker.v2.schema.blocks import BlockId
from marker.v2.schema.registry import get_block_class


class JSONBlockOutput(BaseModel):
    """Serialisable record for one rendered block, optionally with nested blocks."""

    # Stringified identifier of the block.
    id: str
    # Stringified type of the block.
    block_type: str
    # HTML produced for the block.
    html: str
    # Block outline as a list of float lists.
    polygon: List[List[float]]
    # Nested block records; left as None when none are supplied.
    children: List[JSONBlockOutput] | None = None
    # Mapping of integer keys to stringified values describing the enclosing sections.
    section_hierarchy: Dict[int, str] | None = None
    # Images attached to the block; left as None when none are supplied.
    images: dict | None = None


class JSONOutput(BaseModel):
    """Top-level JSON result: a list of block records under a Document-typed root."""

    # Block records directly below the root.
    children: List[JSONBlockOutput]
    # Type tag of the root; defaults to the Document block type.
    block_type: BlockTypes = BlockTypes.Document


def reformat_section_hierarchy(section_hierarchy):
    """Return a new dict with every value of ``section_hierarchy`` stringified.

    Keys are kept unchanged and the input mapping is not modified.

    Args:
        section_hierarchy: Mapping to convert, or ``None``.

    Returns:
        A new ``dict`` whose values are ``str(value)`` of the originals, or
        ``None`` when ``section_hierarchy`` is ``None`` (the hierarchy is an
        optional field, so a missing one is passed through instead of
        raising ``AttributeError``).
    """
    if section_hierarchy is None:
        return None
    return {level: str(block_id) for level, block_id in section_hierarchy.items()}


class JSONRenderer(BaseRenderer):
    """Renderer that turns a rendered document into nested ``JSONBlockOutput`` records."""

    # Block types whose referenced content is exported as a base64 PNG.
    image_blocks: list = [BlockTypes.Picture, BlockTypes.Figure]
    page_blocks: list = [BlockTypes.Page]

    def extract_json(self, document, block_output):
        """Convert one block output (and, where applicable, its children) to JSON.

        A block whose class has ``Block`` as its direct base is emitted as a
        leaf: its HTML is resolved through ``extract_html`` and any images
        found are attached. Every other block keeps its own HTML untouched and
        gets one child record per entry in ``block_output.children``.

        Args:
            document: Document the block output was rendered from.
            block_output: Rendered block to convert.

        Returns:
            The ``JSONBlockOutput`` for ``block_output``.
        """
        cls = get_block_class(block_output.id.block_type)
        if cls.__base__ == Block:
            html, images = self.extract_html(document, block_output)
            return JSONBlockOutput(
                html=html,
                polygon=block_output.polygon.polygon,
                id=str(block_output.id),
                block_type=str(block_output.id.block_type),
                images=images,
                section_hierarchy=reformat_section_hierarchy(block_output.section_hierarchy)
            )

        # ``children`` is optional on the block output; treat None as empty.
        children = [
            self.extract_json(document, child)
            for child in (block_output.children or [])
        ]
        return JSONBlockOutput(
            html=block_output.html,
            polygon=block_output.polygon.polygon,
            id=str(block_output.id),
            block_type=str(block_output.id.block_type),
            children=children,
            section_hierarchy=reformat_section_hierarchy(block_output.section_hierarchy)
        )

    def extract_html(self, document, block_output):
        """Resolve the ``content-ref`` placeholders in a block's HTML.

        Each ``<content-ref src=...>`` is matched against the ids of
        ``block_output.children``. For a child whose block type is listed in
        ``image_blocks`` the cropped image is stored as a base64-encoded PNG
        under the child's id and the placeholder is left in the markup. For
        any other child the placeholder is replaced by the child's own
        resolved HTML and the child's images are merged into the result.

        A placeholder that matches no child is left in place. Resolution is
        done per placeholder, so an unmatched reference can neither crash on
        a missing id nor pick up the block found for an earlier placeholder.

        Args:
            document: Document used to look up blocks and page images.
            block_output: Rendered block whose HTML is resolved.

        Returns:
            A ``(html, images)`` tuple: the resolved HTML string and a dict of
            base64-encoded PNG strings keyed by block id.
        """
        soup = BeautifulSoup(block_output.html, 'html.parser')
        # ``children`` is optional on the block output; treat None as empty.
        children = block_output.children or []

        images = {}
        for ref in soup.find_all('content-ref'):
            src = ref.get('src')
            # NOTE(review): relies on the child's id comparing equal to the
            # string taken from the ``src`` attribute - confirm against BlockId.
            child = next((item for item in children if item.id == src), None)
            if child is None:
                # Dangling reference: nothing to inline, keep the placeholder.
                continue

            content, sub_images = self.extract_html(document, child)
            ref_block_id: BlockId = child.id

            if ref_block_id.block_type in self.image_blocks:
                image = self.extract_image(document, ref_block_id)
                image_buffer = io.BytesIO()
                image.save(image_buffer, format='PNG')
                images[ref_block_id] = base64.b64encode(image_buffer.getvalue()).decode('utf-8')
            else:
                images.update(sub_images)
                ref.replace_with(BeautifulSoup(content, 'html.parser'))

        return str(soup), images

    def __call__(self, document) -> JSONOutput:
        """Render ``document`` and wrap each top-level child output as JSON."""
        document_output = document.render()
        return JSONOutput(
            children=[
                self.extract_json(document, page_output)
                for page_output in document_output.children
            ],
        )
26 changes: 22 additions & 4 deletions marker/v2/schema/blocks/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, List, Literal, Optional
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Dict

from pydantic import BaseModel, ConfigDict, field_validator

Expand All @@ -16,6 +16,7 @@ class BlockOutput(BaseModel):
polygon: PolygonBox
id: BlockId
children: List[BlockOutput] | None = None
section_hierarchy: Dict[int, BlockId] | None = None


class BlockId(BaseModel):
Expand Down Expand Up @@ -115,16 +116,33 @@ def assemble_html(self, child_blocks, parent_structure=None):
template += f"<content-ref src='{c.id}'></content-ref>"
return template

def render(self, document, parent_structure):
def assign_section_hierarchy(self, section_hierarchy):
    """Record this block in ``section_hierarchy`` when it is a section header.

    For a ``SectionHeader`` block with a truthy ``heading_level``, every entry
    whose key is greater than or equal to that level is removed and this
    block's id is stored under its own level. Other blocks leave the mapping
    untouched. The mapping is modified in place and also returned.
    """
    # ``heading_level`` is only read once the block is known to be a header.
    is_heading = self.block_type == BlockTypes.SectionHeader and self.heading_level
    if is_heading:
        # Snapshot the keys to drop first; the dict cannot shrink while iterated.
        stale_levels = [lvl for lvl in section_hierarchy.keys() if lvl >= self.heading_level]
        for lvl in stale_levels:
            del section_hierarchy[lvl]
        section_hierarchy[self.heading_level] = self.id

    return section_hierarchy

def render(self, document, parent_structure, section_hierarchy=None):
child_content = []
if section_hierarchy is None:
section_hierarchy = {}
section_hierarchy = self.assign_section_hierarchy(section_hierarchy)

if self.structure is not None and len(self.structure) > 0:
for block_id in self.structure:
block = document.get_block(block_id)
child_content.append(block.render(document, self.structure))
rendered = block.render(document, self.structure, section_hierarchy)
section_hierarchy = rendered.section_hierarchy # Update the section hierarchy from the peer blocks
child_content.append(rendered)

return BlockOutput(
html=self.assemble_html(child_content, parent_structure),
polygon=self.polygon,
id=self.id,
children=child_content
children=child_content,
section_hierarchy=section_hierarchy
)
5 changes: 4 additions & 1 deletion marker/v2/schema/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,11 @@ def assemble_html(self, child_blocks: List[Block]):

def render(self):
child_content = []
section_hierarchy = None
for page in self.pages:
child_content.append(page.render(self, None))
rendered = page.render(self, None, section_hierarchy)
section_hierarchy = rendered.section_hierarchy
child_content.append(rendered)

return DocumentOutput(
children=child_content,
Expand Down
5 changes: 5 additions & 0 deletions marker/v2/schema/groups/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from marker.v2.schema.blocks import Block


class Group(Block):
    """Common base class for group blocks; adds nothing on top of ``Block``."""
4 changes: 2 additions & 2 deletions marker/v2/schema/groups/figure.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from marker.v2.schema import BlockTypes
from marker.v2.schema.blocks import Block
from marker.v2.schema.groups.base import Group


class FigureGroup(Block):
class FigureGroup(Group):
block_type: BlockTypes = BlockTypes.FigureGroup
4 changes: 2 additions & 2 deletions marker/v2/schema/groups/list.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from marker.v2.schema import BlockTypes
from marker.v2.schema.blocks import Block
from marker.v2.schema.groups.base import Group


class ListGroup(Block):
class ListGroup(Group):
block_type: BlockTypes = BlockTypes.ListGroup

def assemble_html(self, child_blocks, parent_structure):
Expand Down
3 changes: 2 additions & 1 deletion marker/v2/schema/groups/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

from marker.v2.schema import BlockTypes
from marker.v2.schema.blocks import Block, BlockId
from marker.v2.schema.groups.base import Group
from marker.v2.schema.polygon import PolygonBox
from marker.v2.schema.text.line import Line
from marker.v2.schema.text.span import Span


class PageGroup(Block):
class PageGroup(Group):
block_type: BlockTypes = BlockTypes.Page
lowres_image: Image.Image | None = None
highres_image: Image.Image | None = None
Expand Down
4 changes: 2 additions & 2 deletions marker/v2/schema/groups/picture.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from marker.v2.schema import BlockTypes
from marker.v2.schema.blocks import Block
from marker.v2.schema.groups.base import Group


class PictureGroup(Block):
class PictureGroup(Group):
block_type: BlockTypes = BlockTypes.PictureGroup
5 changes: 3 additions & 2 deletions marker/v2/schema/groups/table.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from marker.v2.schema import BlockTypes
from marker.v2.schema.blocks import Block
from marker.v2.schema.groups.base import Group

class TableGroup(Block):

class TableGroup(Group):
block_type: BlockTypes = BlockTypes.TableGroup
7 changes: 4 additions & 3 deletions marker/v2/schema/text/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,17 @@ def assemble_html(self, document, child_blocks, parent_structure):
template = strip_trailing_hyphens(raw_text, next_line_raw_text, template)
return template

def render(self, document, parent_structure):
def render(self, document, parent_structure, section_hierarchy=None):
child_content = []
if self.structure is not None and len(self.structure) > 0:
for block_id in self.structure:
block = document.get_block(block_id)
child_content.append(block.render(document, parent_structure))
child_content.append(block.render(document, parent_structure, section_hierarchy))

return BlockOutput(
html=self.assemble_html(document, child_content, parent_structure),
polygon=self.polygon,
id=self.id,
children=[]
children=[],
section_hierarchy=section_hierarchy
)
Loading

0 comments on commit 59b6224

Please sign in to comment.