Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Overriding Node Classes #368

Merged
merged 5 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions marker/v2/builders/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from marker.v2.builders.layout import LayoutBuilder
from marker.v2.builders.ocr import OcrBuilder
from marker.v2.providers.pdf import PdfProvider
from marker.v2.schema import BlockTypes
from marker.v2.schema.document import Document
from marker.v2.schema.groups.page import PageGroup
from marker.v2.schema.registry import get_block_class


class DocumentBuilder(BaseBuilder):
Expand All @@ -15,13 +17,14 @@ def __call__(self, provider: PdfProvider, layout_builder: LayoutBuilder, ocr_bui
return document

def build_document(self, provider: PdfProvider):
PageGroupClass: PageGroup = get_block_class(BlockTypes.Page)
initial_pages = [
PageGroup(
PageGroupClass(
page_id=i,
lowres_image=provider.get_image(i, settings.IMAGE_DPI),
highres_image=provider.get_image(i, settings.HIGHRES_IMAGE_DPI),
polygon=provider.get_page_bbox(i)
) for i in provider.page_range
]

return Document(filepath=provider.filepath, pages=initial_pages)
DocumentClass: Document = get_block_class(BlockTypes.Document)
return DocumentClass(filepath=provider.filepath, pages=initial_pages)
4 changes: 2 additions & 2 deletions marker/v2/builders/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from marker.v2.builders import BaseBuilder
from marker.v2.providers.pdf import PageLines, PageSpans, PdfProvider
from marker.v2.schema import BlockTypes
from marker.v2.schema.blocks import LAYOUT_BLOCK_REGISTRY
from marker.v2.schema.document import Document
from marker.v2.schema.groups.page import PageGroup
from marker.v2.schema.polygon import PolygonBox
from marker.v2.schema.registry import get_block_class
from marker.v2.schema.text.line import Line


Expand Down Expand Up @@ -49,7 +49,7 @@ def add_blocks_to_pages(self, pages: List[PageGroup], layout_results: List[Layou
layout_page_size = PolygonBox.from_bbox(layout_result.image_bbox).size
provider_page_size = page.polygon.size
for bbox in sorted(layout_result.bboxes, key=lambda x: x.position):
block_cls = LAYOUT_BLOCK_REGISTRY[BlockTypes[bbox.label]]
block_cls = get_block_class(BlockTypes[bbox.label])
layout_block = page.add_block(block_cls, PolygonBox(polygon=bbox.polygon))
layout_block.polygon = layout_block.polygon.rescale(layout_page_size, provider_page_size)
page.add_structure(layout_block)
Expand Down
9 changes: 6 additions & 3 deletions marker/v2/builders/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from marker.v2.builders import BaseBuilder
from marker.v2.providers.pdf import PdfProvider
from marker.v2.schema import BlockTypes
from marker.v2.schema.blocks import Block
from marker.v2.schema.document import Document
from marker.v2.schema.polygon import PolygonBox
from marker.v2.schema.registry import get_block_class
from marker.v2.schema.text.line import Line
from marker.v2.schema.text.span import Span

Expand Down Expand Up @@ -63,6 +63,9 @@ def ocr_extraction(self, document: Document, provider: PdfProvider) -> Tuple[Pag
page_lines = {}
page_spans = {}

SpanClass: Span = get_block_class(BlockTypes.Span)
LineClass: Line = get_block_class(BlockTypes.Line)

for page_id, recognition_result in zip((page.page_id for page in page_list), recognition_results):
page_spans.setdefault(page_id, {})
page_lines.setdefault(page_id, [])
Expand All @@ -74,13 +77,13 @@ def ocr_extraction(self, document: Document, provider: PdfProvider) -> Tuple[Pag
image_polygon = PolygonBox.from_bbox(recognition_result.image_bbox)
polygon = PolygonBox.from_bbox(ocr_line.bbox).rescale(image_polygon.size, page_size)

page_lines[page_id].append(Line(
page_lines[page_id].append(LineClass(
polygon=polygon,
page_id=page_id,
))

line_spans.setdefault(ocr_line_idx, [])
line_spans[ocr_line_idx].append(Span(
line_spans[ocr_line_idx].append(SpanClass(
text=ocr_line.text,
formats=['plain'],
page_id=page_id,
Expand Down
9 changes: 3 additions & 6 deletions marker/v2/builders/structure.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
from typing import Optional

from pydantic import BaseModel

from marker.v2.builders import BaseBuilder
from marker.v2.schema import BlockTypes
from marker.v2.schema.document import Document
from marker.v2.schema.groups import GROUP_BLOCK_REGISTRY, ListGroup
from marker.v2.schema.groups import ListGroup
from marker.v2.schema.groups.page import PageGroup
from marker.v2.schema.registry import get_block_class


class StructureBuilder(BaseBuilder):
Expand Down Expand Up @@ -53,7 +50,7 @@ def group_caption_blocks(self, page: PageGroup):

if len(block_structure) > 1:
# Create a merged block
new_block_cls = GROUP_BLOCK_REGISTRY[BlockTypes[block.block_type.name + "Group"]]
new_block_cls = get_block_class(BlockTypes[block.block_type.name + "Group"])
new_polygon = block.polygon.merge(selected_polygons)
group_block = page.add_block(new_block_cls, new_polygon)
group_block.structure = block_structure
Expand Down
18 changes: 14 additions & 4 deletions marker/v2/converters/pdf.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from marker.v2.providers.pdf import PdfProvider
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false" # disables a tokenizers warning

from marker.v2.processors.sectionheader import SectionHeaderProcessor
from marker.v2.providers.pdf import PdfProvider
import tempfile
from collections import defaultdict
from typing import Dict, Type

import click
import datasets
Expand All @@ -14,17 +15,26 @@
from marker.v2.builders.ocr import OcrBuilder
from marker.v2.builders.structure import StructureBuilder
from marker.v2.converters import BaseConverter
from marker.v2.models import setup_detection_model, setup_layout_model, \
setup_recognition_model, setup_table_rec_model, setup_texify_model
from marker.v2.processors.equation import EquationProcessor
from marker.v2.processors.sectionheader import SectionHeaderProcessor
from marker.v2.processors.table import TableProcessor
from marker.v2.models import setup_layout_model, setup_texify_model, setup_recognition_model, setup_table_rec_model, \
setup_detection_model
from marker.v2.renderers.markdown import MarkdownRenderer
from marker.v2.schema import BlockTypes
from marker.v2.schema.blocks import Block
from marker.v2.schema.registry import register_block_class
from marker.v2.processors.debug import DebugProcessor


class PdfConverter(BaseConverter):
override_map: Dict[BlockTypes, Type[Block]] = defaultdict()

def __init__(self, config=None):
super().__init__(config)

for block_type, override_block_type in self.override_map.items():
register_block_class(block_type, override_block_type)

self.layout_model = setup_layout_model()
self.texify_model = setup_texify_model()
Expand Down
11 changes: 7 additions & 4 deletions marker/v2/providers/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import pypdfium2 as pdfium
from pdftext.extraction import dictionary_output
from PIL import Image
from pydantic import BaseModel

from marker.ocr.heuristics import detect_bad_ocr
from marker.v2.providers import BaseProvider
from marker.v2.schema.polygon import PolygonBox
from marker.v2.schema import BlockTypes
from marker.v2.schema.registry import get_block_class
from marker.v2.schema.text.line import Line
from marker.v2.schema.text.span import Span

Expand All @@ -23,7 +24,7 @@ class PdfProvider(BaseProvider):
flatten_pdf: bool = True
force_ocr: bool = False

def __init__(self, filepath: str, config = None):
def __init__(self, filepath: str, config=None):
super().__init__(filepath, config)

self.doc: pdfium.PdfDocument = pdfium.PdfDocument(self.filepath)
Expand Down Expand Up @@ -105,6 +106,8 @@ def pdftext_extraction(self) -> Tuple[PageLines, PageSpans]:
workers=self.pdftext_workers,
flatten_pdf=self.flatten_pdf
)
SpanClass: Span = get_block_class(BlockTypes.Span)
LineClass: Span = get_block_class(BlockTypes.Line)
for page in page_char_blocks:
page_id = page["page"]
lines: List[Line] = []
Expand All @@ -120,7 +123,7 @@ def pdftext_extraction(self) -> Tuple[PageLines, PageSpans]:
font_weight = span["font"]["weight"] or 0
font_size = span["font"]["size"] or 0
spans.append(
Span(
SpanClass(
polygon=PolygonBox.from_bbox(span["bbox"]),
text=span["text"],
font=font_name,
Expand All @@ -133,7 +136,7 @@ def pdftext_extraction(self) -> Tuple[PageLines, PageSpans]:
text_extraction_method="pdftext"
)
)
lines.append(Line(polygon=PolygonBox.from_bbox(line["bbox"]), page_id=page_id))
lines.append(LineClass(polygon=PolygonBox.from_bbox(line["bbox"]), page_id=page_id))
line_spans[len(lines) - 1] = spans
if self.check_line_spans(line_spans):
page_lines[page_id] = lines
Expand Down
8 changes: 0 additions & 8 deletions marker/v2/schema/blocks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,3 @@
from marker.v2.schema.blocks.table import Table
from marker.v2.schema.blocks.text import Text
from marker.v2.schema.blocks.toc import TableOfContents

LAYOUT_BLOCK_REGISTRY = {
v.model_fields['block_type'].default: v for k, v in locals().items()
if isinstance(v, type)
and issubclass(v, Block)
and v != Block # Exclude the base Block class
and not v.model_fields['block_type'].default.name.endswith("Group")
}
8 changes: 0 additions & 8 deletions marker/v2/schema/groups/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,3 @@
from marker.v2.schema.groups.list import ListGroup
from marker.v2.schema.groups.picture import PictureGroup
from marker.v2.schema.groups.page import PageGroup

GROUP_BLOCK_REGISTRY = {
v.model_fields['block_type'].default: v for k, v in locals().items()
if isinstance(v, type)
and issubclass(v, Block)
and v != Block # Exclude the base Block class
and (v.model_fields['block_type'].default.name.endswith("Group") or v.model_fields['block_type'].default.name == "Page")
}
55 changes: 55 additions & 0 deletions marker/v2/schema/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import Dict, Type
from importlib import import_module

from marker.v2.schema import BlockTypes
from marker.v2.schema.blocks import Block, Caption, Code, Equation, Figure, \
Footnote, Form, Handwriting, InlineMath, \
ListItem, PageFooter, PageHeader, Picture, \
SectionHeader, Table, TableOfContents, \
Text
from marker.v2.schema.document import Document
from marker.v2.schema.groups import FigureGroup, ListGroup, PageGroup, \
PictureGroup, TableGroup
from marker.v2.schema.text import Line, Span

BLOCK_REGISTRY: Dict[BlockTypes, str] = {}


def register_block_class(block_type: BlockTypes, block_cls: Type[Block]):
BLOCK_REGISTRY[block_type] = f"{block_cls.__module__}.{block_cls.__name__}"


def get_block_class(block_type: BlockTypes) -> Type[Block]:
class_path = BLOCK_REGISTRY[block_type]
module_name, class_name = class_path.rsplit('.', 1)
module = import_module(module_name)
return getattr(module, class_name)


register_block_class(BlockTypes.Line, Line)
register_block_class(BlockTypes.Span, Span)
register_block_class(BlockTypes.FigureGroup, FigureGroup)
register_block_class(BlockTypes.TableGroup, TableGroup)
register_block_class(BlockTypes.ListGroup, ListGroup)
register_block_class(BlockTypes.PictureGroup, PictureGroup)
register_block_class(BlockTypes.Page, PageGroup)
register_block_class(BlockTypes.Caption, Caption)
register_block_class(BlockTypes.Code, Code)
register_block_class(BlockTypes.Figure, Figure)
register_block_class(BlockTypes.Footnote, Footnote)
register_block_class(BlockTypes.Form, Form)
register_block_class(BlockTypes.Equation, Equation)
register_block_class(BlockTypes.Handwriting, Handwriting)
register_block_class(BlockTypes.TextInlineMath, InlineMath)
register_block_class(BlockTypes.ListItem, ListItem)
register_block_class(BlockTypes.PageFooter, PageFooter)
register_block_class(BlockTypes.PageHeader, PageHeader)
register_block_class(BlockTypes.Picture, Picture)
register_block_class(BlockTypes.SectionHeader, SectionHeader)
register_block_class(BlockTypes.Table, Table)
register_block_class(BlockTypes.Text, Text)
register_block_class(BlockTypes.TableOfContents, TableOfContents)
register_block_class(BlockTypes.Document, Document)

assert len(BLOCK_REGISTRY) == len(BlockTypes)
assert all([get_block_class(k).model_fields['block_type'].default == k for k, _ in BLOCK_REGISTRY.items()])
6 changes: 0 additions & 6 deletions marker/v2/schema/text/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,2 @@
from marker.v2.schema import BlockTypes
from marker.v2.schema.text.line import Line
from marker.v2.schema.text.span import Span

TEXT_BLOCK_REGISTRY = {
BlockTypes.Line: Line,
BlockTypes.Span: Span,
}
17 changes: 16 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ filetype = "^1.2.0"
regex = "^2024.4.28"
pdftext = "^0.3.18"
tabled-pdf = { git = "https://github.com/VikParuchuri/tabled.git", branch = "dev-mose/compilation-updates" }
markdownify = "^0.13.1"

[tool.poetry.group.dev.dependencies]
jupyter = "^1.0.0"
Expand Down
26 changes: 18 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@

import datasets
import pytest
from typing import Dict, Type

from marker.v2.schema import BlockTypes
from marker.v2.schema.blocks import Block
from marker.v2.models import setup_layout_model, setup_texify_model, setup_recognition_model, setup_table_rec_model, \
setup_detection_model
from marker.v2.builders.document import DocumentBuilder
from marker.v2.builders.layout import LayoutBuilder
from marker.v2.builders.ocr import OcrBuilder
from marker.v2.schema.document import Document
from marker.v2.schema.registry import register_block_class


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -48,13 +52,22 @@ def table_rec_model():


@pytest.fixture(scope="function")
def pdf_provider(request):
def config(request):
config_mark = request.node.get_closest_marker("config")
config = config_mark.args[0] if config_mark else {}

override_map: Dict[BlockTypes, Type[Block]] = config.get("override_map", {})
for block_type, override_block_type in override_map.items():
register_block_class(block_type, override_block_type)

return config


@pytest.fixture(scope="function")
def pdf_provider(request, config):
filename_mark = request.node.get_closest_marker("filename")
filename = filename_mark.args[0] if filename_mark else "adversarial.pdf"

config_mark = request.node.get_closest_marker("config")
config = config_mark.args[0] if config_mark else None

dataset = datasets.load_dataset("datalab-to/pdfs", split="train")
idx = dataset['filename'].index(filename)

Expand All @@ -65,10 +78,7 @@ def pdf_provider(request):


@pytest.fixture(scope="function")
def pdf_document(request, pdf_provider, layout_model, recognition_model, detection_model) -> Document:
config_mark = request.node.get_closest_marker("config")
config = config_mark.args[0] if config_mark else None

def pdf_document(request, config, pdf_provider, layout_model, recognition_model, detection_model) -> Document:
layout_builder = LayoutBuilder(layout_model, config)
ocr_builder = OcrBuilder(detection_model, recognition_model, config)
builder = DocumentBuilder(config)
Expand Down
Loading