Skip to content

Commit

Permalink
add multiprocessing tests for block overriding
Browse files (browse the repository at this point in the history)
  • Loading branch information
iammosespaulr committed Nov 19, 2024
1 parent ead37b3 commit 52e4d3f
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 4 deletions.
33 changes: 32 additions & 1 deletion tests/test_overriding.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,49 @@
import multiprocessing as mp

import pytest

from marker.v2.providers.pdf import PdfProvider
from marker.v2.schema import BlockTypes
from marker.v2.schema.document import Document
from marker.v2.schema.blocks import SectionHeader
from marker.v2.schema.document import Document
from marker.v2.schema.registry import register_block_class
from marker.v2.schema.text import Line
from tests.utils import setup_pdf_provider


class NewSectionHeader(SectionHeader):
    """SectionHeader subclass used as an override target in these tests."""


class NewLine(Line):
    """Line subclass used as an override target in these tests."""


@pytest.mark.config({
    "page_range": [0],
    "override_map": {BlockTypes.SectionHeader: NewSectionHeader},
})
def test_overriding(pdf_document: Document):
    """Check that the first structure block of the first page is overridden.

    The ``config`` marker maps ``BlockTypes.SectionHeader`` to
    ``NewSectionHeader``; the block referenced by the first entry of the
    first page's structure must then be of exactly that class.
    """
    first_page = pdf_document.pages[0]
    first_block = first_page.get_block(first_page.structure[0])
    assert first_block.__class__ == NewSectionHeader


def get_lines(pdf: str, config=None):
    """Build a PDF provider for ``pdf`` and return ``get_page_lines(0)``.

    Defined at module level and called through ``pool.starmap`` in
    ``test_overriding_mp``, so each call runs in a worker process.
    """
    return setup_pdf_provider(pdf, config).get_page_lines(0)


def test_overriding_mp():
    """Check that a Line override is honoured inside pool worker processes.

    Registers ``NewLine`` for ``BlockTypes.Line`` in this process, then runs
    ``get_lines`` for two PDFs in a two-process pool and asserts that the
    first line returned for each PDF is of class ``NewLine``.
    """
    override_map = {BlockTypes.Line: NewLine}
    config = {
        "page_range": [0],
        "override_map": override_map,
    }

    # Register the overrides here before the pool is created.
    for block_type, block_cls in override_map.items():
        register_block_class(block_type, block_cls)

    pdf_names = ["adversarial.pdf", "adversarial_rot.pdf"]
    jobs = [(name, config) for name in pdf_names]

    with mp.Pool(processes=2) as pool:
        # starmap blocks until every job has finished.
        results = pool.starmap(get_lines, jobs)

    # Look up every first line's class before comparing any of them.
    first_line_classes = [page_lines[0].__class__ for page_lines in results]
    assert all(cls == NewLine for cls in first_line_classes)
14 changes: 11 additions & 3 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,30 @@
from marker.v2.schema.document import Document


def setup_pdf_provider(
    filename='adversarial.pdf',
    config=None,
) -> PdfProvider:
    """Create a ``PdfProvider`` for a PDF taken from the test dataset.

    Loads the ``train`` split of ``datalab-to/pdfs``, finds the row whose
    ``filename`` equals ``filename`` (``list.index`` raises ``ValueError``
    when there is no match), writes that row's ``pdf`` payload to a
    temporary ``.pdf`` file and builds the provider from the file's path.
    """
    pdf_dataset = datasets.load_dataset("datalab-to/pdfs", split="train")
    row = pdf_dataset['filename'].index(filename)

    tmp_file = tempfile.NamedTemporaryFile(suffix=".pdf")
    tmp_file.write(pdf_dataset['pdf'][row])
    # Flush so the bytes are on disk before the provider reads the path.
    tmp_file.flush()

    # NOTE(review): tmp_file is only referenced locally, and a
    # NamedTemporaryFile removes its file once closed/collected. Confirm that
    # PdfProvider does not need to reopen the path after this returns.
    return PdfProvider(tmp_file.name, config)


def setup_pdf_document(
filename='adversarial.pdf',
config=None,
) -> Document:
layout_model = setup_layout_model()
recognition_model = setup_recognition_model()
detection_model = setup_detection_model()

provider = PdfProvider(temp_pdf.name, config)
provider = setup_pdf_provider(filename, config)
layout_builder = LayoutBuilder(layout_model, config)
ocr_builder = OcrBuilder(detection_model, recognition_model, config)
builder = DocumentBuilder(config)
Expand Down

0 comments on commit 52e4d3f

Please sign in to comment.