diff --git a/.gitignore b/.gitignore index 0c6bc443..36d46902 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ report.json benchmark_data debug_data temp.md +temp # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/marker/v2/converters/pdf.py b/marker/v2/converters/pdf.py index c74854b1..4a7d4777 100644 --- a/marker/v2/converters/pdf.py +++ b/marker/v2/converters/pdf.py @@ -1,8 +1,13 @@ +<<<<<<< HEAD +import os +======= from marker.v2.providers.pdf import PdfProvider +>>>>>>> origin/v2 import tempfile from typing import List, Optional +import click import datasets from pydantic import BaseModel @@ -46,9 +51,14 @@ def __call__(self, filepath: str, page_range: List[int] | None = None): return renderer(document) -if __name__ == "__main__": +@click.command() +@click.option("--output", type=click.Path(exists=False), required=False, default="temp") +@click.option("--fname", type=str, default="adversarial.pdf") +def main(output: str, fname: str): dataset = datasets.load_dataset("datalab-to/pdfs", split="train") - idx = dataset['filename'].index('adversarial.pdf') + idx = dataset['filename'].index(fname) + out_filename = fname.rsplit(".", 1)[0] + ".md" + os.makedirs(output, exist_ok=True) with tempfile.NamedTemporaryFile(suffix=".pdf") as temp_pdf: temp_pdf.write(dataset['pdf'][idx]) @@ -57,5 +67,12 @@ def __call__(self, filepath: str, page_range: List[int] | None = None): converter = PdfConverter() rendered = converter(temp_pdf.name) - with open("temp.md", "w+") as f: - f.write(rendered) + with open(os.path.join(output, out_filename), "w+") as f: + f.write(rendered.markdown) + + for img_name, img in rendered.images.items(): + img.save(os.path.join(output, img_name), "PNG") + + +if __name__ == "__main__": + main() diff --git a/marker/v2/processors/__init__.py b/marker/v2/processors/__init__.py index 32d436e4..27947721 100644 --- a/marker/v2/processors/__init__.py +++ b/marker/v2/processors/__init__.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Tuple from pydantic import BaseModel @@ -8,7 +8,7 @@ class BaseProcessor: - block_type: BlockTypes | None = None # What block type this processor is responsible for + block_types: Tuple[str] | None = None # What block types this processor is responsible for def __init__(self, config: Optional[BaseModel | dict] = None): assign_config(self, config) diff --git a/marker/v2/processors/equation.py b/marker/v2/processors/equation.py index d7555796..14ec0097 100644 --- a/marker/v2/processors/equation.py +++ b/marker/v2/processors/equation.py @@ -11,7 +11,7 @@ class EquationProcessor(BaseProcessor): - block_type = BlockTypes.Equation + block_types = (BlockTypes.Equation, ) model_max_length = 384 batch_size = None token_buffer = 256 @@ -26,7 +26,7 @@ def __call__(self, document: Document): for page in document.pages: for block in page.children: - if block.block_type != self.block_type: + if block.block_type not in self.block_types: continue image_poly = block.polygon.rescale((page.polygon.width, page.polygon.height), page.lowres_image.size) image = page.lowres_image.crop(image_poly.bbox).convert("RGB") diff --git a/marker/v2/processors/table.py b/marker/v2/processors/table.py index 6c330aac..67084e77 100644 --- a/marker/v2/processors/table.py +++ b/marker/v2/processors/table.py @@ -12,7 +12,7 @@ class TableProcessor(BaseProcessor): - block_type = BlockTypes.Table + block_types = (BlockTypes.Table, BlockTypes.TableOfContents, BlockTypes.Form) detect_boxes = False detector_batch_size = None table_rec_batch_size = None @@ -31,7 +31,7 @@ def __call__(self, document: Document): table_data = [] for page in document.pages: for block in page.children: - if block.block_type != self.block_type: + if block.block_type not in self.block_types: continue image_poly = block.polygon.rescale((page.polygon.width, page.polygon.height), page.highres_image.size) diff --git a/marker/v2/renderers/__init__.py b/marker/v2/renderers/__init__.py index 7f6297c5..385df6d5 100644 --- a/marker/v2/renderers/__init__.py +++ b/marker/v2/renderers/__init__.py @@ -5,6 +5,7 @@ from marker.v2.schema import BlockTypes + class BaseRenderer: block_type: BlockTypes | None = None diff --git a/marker/v2/renderers/html.py b/marker/v2/renderers/html.py index c2f37431..06b82b6b 100644 --- a/marker/v2/renderers/html.py +++ b/marker/v2/renderers/html.py @@ -1,33 +1,85 @@ +import re + from bs4 import BeautifulSoup +from pydantic import BaseModel + from marker.v2.renderers import BaseRenderer from marker.v2.schema import BlockTypes +from marker.v2.schema.blocks import BlockId + + +class HTMLOutput(BaseModel): + html: str + images: dict + + +def merge_consecutive_tags(html, tag): + if not html: + return html + + def replace_whitespace(match): + return match.group(1) + + pattern = fr'(\s*)<{tag}>' + + while True: + new_merged = re.sub(pattern, replace_whitespace, html) + if new_merged == html: + break + html = new_merged + + return html class HTMLRenderer(BaseRenderer): remove_blocks: list = [BlockTypes.PageHeader, BlockTypes.PageFooter] image_blocks: list = [BlockTypes.Picture, BlockTypes.Figure] - def extract_html(self, document, document_output): + def extract_image(self, document, image_id): + image_block = document.get_block(image_id) + page = document.get_page(image_block.page_id) + page_img = page.highres_image + image_box = image_block.polygon.rescale(page.polygon.size, page_img.size) + cropped = page_img.crop(image_box.bbox) + return cropped + + def extract_html(self, document, document_output, level=0): soup = BeautifulSoup(document_output.html, 'html.parser') content_refs = soup.find_all('content-ref') - ref_block_type = None + ref_block_id = None + images = {} for ref in content_refs: src = ref.get('src') + sub_images = {} for item in document_output.children: if item.id == src: - content = self.extract_html(document, item) - ref_block_type = item.id.block_type + content, sub_images = self.extract_html(document, item, level + 1) + ref_block_id: BlockId = item.id break - if ref_block_type in self.remove_blocks: + if ref_block_id.block_type in self.remove_blocks: ref.replace_with('') + elif ref_block_id.block_type in self.image_blocks: + image = self.extract_image(document, ref_block_id) + image_name = f"{ref_block_id.to_path()}.png" + images[image_name] = image + ref.replace_with(BeautifulSoup(f"

", 'html.parser')) else: - ref.replace_with(BeautifulSoup(f"
{content}
", 'html.parser')) + images.update(sub_images) + ref.replace_with(BeautifulSoup(f"{content}", 'html.parser')) + + output = str(soup) + if level == 0: + output = merge_consecutive_tags(output, 'b') + output = merge_consecutive_tags(output, 'i') - return str(soup) + return output, images - def __call__(self, document): + def __call__(self, document) -> HTMLOutput: document_output = document.render() - full_html = self.extract_html(document, document_output) - return full_html + full_html, images = self.extract_html(document, document_output) + return HTMLOutput( + html=full_html, + images=images, + ) diff --git a/marker/v2/renderers/markdown.py b/marker/v2/renderers/markdown.py index 878a20c4..025e3e95 100644 --- a/marker/v2/renderers/markdown.py +++ b/marker/v2/renderers/markdown.py @@ -1,16 +1,31 @@ -from markdownify import markdownify +from markdownify import markdownify, MarkdownConverter +from pydantic import BaseModel + from marker.v2.renderers.html import HTMLRenderer from marker.v2.schema.document import Document +class Markdownify(MarkdownConverter): + pass + + +class MarkdownOutput(BaseModel): + markdown: str + images: dict + + class MarkdownRenderer(HTMLRenderer): - def __call__(self, document: Document): + def __call__(self, document: Document) -> MarkdownOutput: document_output = document.render() - full_html = self.extract_html(document, document_output) - return markdownify( - full_html, + full_html, images = self.extract_html(document, document_output) + md_cls = Markdownify( heading_style="ATX", bullets="-", escape_misc=False, escape_underscores=False ) + markdown = md_cls.convert(full_html) + return MarkdownOutput( + markdown=markdown, + images=images + ) diff --git a/marker/v2/schema/blocks/base.py b/marker/v2/schema/blocks/base.py index acbba983..b8c5dad6 100644 --- a/marker/v2/schema/blocks/base.py +++ b/marker/v2/schema/blocks/base.py @@ -48,6 +48,9 @@ def validate_block_type(cls, v): raise ValueError(f"Invalid block type: {v}") return v + def to_path(self): + return str(self).replace('/', '_') + class Block(BaseModel): polygon: PolygonBox diff --git a/marker/v2/schema/blocks/equation.py b/marker/v2/schema/blocks/equation.py index f3c577e0..22441d70 100644 --- a/marker/v2/schema/blocks/equation.py +++ b/marker/v2/schema/blocks/equation.py @@ -7,4 +7,4 @@ class Equation(Block): latex: str | None = None def assemble_html(self, child_blocks, parent_structure=None): - return f"
{self.latex}
" + return f"

{self.latex}

" diff --git a/marker/v2/schema/blocks/figure.py b/marker/v2/schema/blocks/figure.py index acd4f7bd..feda0353 100644 --- a/marker/v2/schema/blocks/figure.py +++ b/marker/v2/schema/blocks/figure.py @@ -6,4 +6,4 @@ class Figure(Block): block_type: BlockTypes = BlockTypes.Figure def assemble_html(self, child_blocks, parent_structure): - return f"Image {self.block_id}" + return f"

Image {self.block_id}

" diff --git a/marker/v2/schema/blocks/form.py b/marker/v2/schema/blocks/form.py index a45b7924..7ececa71 100644 --- a/marker/v2/schema/blocks/form.py +++ b/marker/v2/schema/blocks/form.py @@ -1,6 +1,16 @@ +from typing import List + +from tabled.formats import html_format +from tabled.schema import SpanTableCell + from marker.v2.schema import BlockTypes from marker.v2.schema.blocks import Block class Form(Block): - block_type: BlockTypes = BlockTypes.Form + block_type: str = BlockTypes.Form + cells: List[SpanTableCell] | None = None + + def assemble_html(self, child_blocks, parent_structure=None): + return html_format(self.cells) + diff --git a/marker/v2/schema/blocks/pagefooter.py b/marker/v2/schema/blocks/pagefooter.py index 96b7519f..fd42b603 100644 --- a/marker/v2/schema/blocks/pagefooter.py +++ b/marker/v2/schema/blocks/pagefooter.py @@ -3,4 +3,9 @@ class PageFooter(Block): - block_type: BlockTypes = BlockTypes.PageFooter + block_type: str = BlockTypes.PageFooter + + def assemble_html(self, child_blocks, parent_structure): + template = super().assemble_html(child_blocks, parent_structure) + template = template.replace("\n", " ") + return f"

{template}

" diff --git a/marker/v2/schema/blocks/pageheader.py b/marker/v2/schema/blocks/pageheader.py index ef46da46..efc88464 100644 --- a/marker/v2/schema/blocks/pageheader.py +++ b/marker/v2/schema/blocks/pageheader.py @@ -3,4 +3,9 @@ class PageHeader(Block): - block_type: BlockTypes = BlockTypes.PageHeader + block_type: str = BlockTypes.PageHeader + + def assemble_html(self, child_blocks, parent_structure): + template = super().assemble_html(child_blocks, parent_structure) + template = template.replace("\n", " ") + return f"

{template}

" diff --git a/marker/v2/schema/blocks/picture.py b/marker/v2/schema/blocks/picture.py index b4e2e177..595bac90 100644 --- a/marker/v2/schema/blocks/picture.py +++ b/marker/v2/schema/blocks/picture.py @@ -6,4 +6,4 @@ class Picture(Block): block_type: BlockTypes = BlockTypes.Picture def assemble_html(self, child_blocks, parent_structure): - return f"Image {self.block_id}" + return f"

Image {self.block_id}

" diff --git a/marker/v2/schema/blocks/text.py b/marker/v2/schema/blocks/text.py index aaa9a3ee..53055be6 100644 --- a/marker/v2/schema/blocks/text.py +++ b/marker/v2/schema/blocks/text.py @@ -1,7 +1,6 @@ from marker.v2.schema import BlockTypes from marker.v2.schema.blocks import Block - class Text(Block): block_type: BlockTypes = BlockTypes.Text diff --git a/marker/v2/schema/blocks/toc.py b/marker/v2/schema/blocks/toc.py index 4a80a01a..8057e3f7 100644 --- a/marker/v2/schema/blocks/toc.py +++ b/marker/v2/schema/blocks/toc.py @@ -1,6 +1,15 @@ +from typing import List + +from tabled.formats import html_format +from tabled.schema import SpanTableCell + from marker.v2.schema import BlockTypes from marker.v2.schema.blocks import Block class TableOfContents(Block): - block_type: BlockTypes = BlockTypes.TableOfContents + block_type: str = BlockTypes.TableOfContents + cells: List[SpanTableCell] | None = None + + def assemble_html(self, child_blocks, parent_structure=None): + return html_format(self.cells) diff --git a/marker/v2/schema/document.py b/marker/v2/schema/document.py index 161f89e5..83f49fd3 100644 --- a/marker/v2/schema/document.py +++ b/marker/v2/schema/document.py @@ -21,11 +21,17 @@ class Document(BaseModel): block_type: BlockTypes = BlockTypes.Document def get_block(self, block_id: BlockId): - block = self.pages[block_id.page_id].get_block(block_id) + page = self.get_page(block_id.page_id) + block = page.get_block(block_id) if block: return block return None + def get_page(self, page_id): + page = self.pages[page_id] + assert page.page_id == page_id, "Mismatch between page_id and page index" + return page + def assemble_html(self, child_blocks): template = "" for c in child_blocks: diff --git a/marker/v2/schema/groups/list.py b/marker/v2/schema/groups/list.py index 0baa2931..f5880bc4 100644 --- a/marker/v2/schema/groups/list.py +++ b/marker/v2/schema/groups/list.py @@ -7,4 +7,4 @@ class ListGroup(Block): def assemble_html(self, child_blocks, parent_structure): template = super().assemble_html(child_blocks, parent_structure) - return f"" + return f"

" diff --git a/marker/v2/schema/text/line.py b/marker/v2/schema/text/line.py index 7eda5123..5f8eb27d 100644 --- a/marker/v2/schema/text/line.py +++ b/marker/v2/schema/text/line.py @@ -27,7 +27,8 @@ def strip_trailing_hyphens(line_text, next_line_text, line_html) -> str: next_line_starts_lowercase = regex.match(rf"^\s?[{lowercase_letters}]", next_line_text) if hyphen_regex.match(line_text) and next_line_starts_lowercase: - return replace_last(line_html, rf'[{HYPHENS}]', "") + line_html = replace_last(line_html, rf'[{HYPHENS}]', "") + return line_html