Skip to content

Commit

Permalink
add OCR builder and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
iammosespaulr committed Nov 15, 2024
1 parent c564341 commit afe6295
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 25 deletions.
11 changes: 9 additions & 2 deletions marker/v2/builders/document.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
from marker.settings import settings
from marker.v2.builders import BaseBuilder
from marker.v2.builders.layout import LayoutBuilder
from marker.v2.builders.ocr import OcrBuilder
from marker.v2.providers.pdf import PdfProvider
from marker.v2.schema.document import Document
from marker.v2.schema.groups.page import PageGroup
from marker.v2.schema.polygon import PolygonBox


class DocumentBuilder(BaseBuilder):
def __call__(self, provider: PdfProvider, layout_builder: LayoutBuilder):
force_ocr: bool = False

def __call__(self, provider: PdfProvider, layout_builder: LayoutBuilder, ocr_builder: OcrBuilder):
document = self.build_document(provider)
layout_builder(document, provider)
if self.force_ocr:
layout_builder(document)
else:
layout_builder(document, provider)
ocr_builder(document)
return document

def build_document(self, provider: PdfProvider):
Expand Down
10 changes: 7 additions & 3 deletions marker/v2/builders/layout.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional

from surya.layout import batch_layout_detection
from surya.schema import LayoutResult
Expand All @@ -20,10 +20,11 @@ def __init__(self, layout_model, config=None):

super().__init__(config)

def __call__(self, document: Document, provider: PdfProvider):
def __call__(self, document: Document, provider: Optional[PdfProvider] = None):
layout_results = self.surya_layout(document.pages)
self.add_blocks_to_pages(document.pages, layout_results)
self.merge_blocks(document.pages, provider.page_lines)
if provider is not None:
self.merge_blocks(document.pages, provider.page_lines)

def get_batch_size(self):
if self.batch_size is not None:
Expand Down Expand Up @@ -70,6 +71,7 @@ def merge_blocks(self, document_pages: List[PageGroup], provider_page_lines: Pdf
document_page.add_full_block(line)
block_idx = max_intersections[line_idx][1]
block: Block = document_page.children[block_idx]
block.text_extraction_method = "pdftext"
block.add_structure(line)
block.polygon = block.polygon.merge([line.polygon])
assigned_line_idxs.add(line_idx)
Expand All @@ -90,6 +92,7 @@ def merge_blocks(self, document_pages: List[PageGroup], provider_page_lines: Pdf
if min_dist_idx is not None:
document_page.add_full_block(line)
nearest_block = document_page.children[min_dist_idx]
nearest_block.text_extraction_method = "pdftext"
nearest_block.add_structure(line)
nearest_block.polygon = nearest_block.polygon.merge([line.polygon])
assigned_line_idxs.add(line_idx)
Expand All @@ -101,6 +104,7 @@ def merge_blocks(self, document_pages: List[PageGroup], provider_page_lines: Pdf
line, spans = provider_lines[line_idx]
document_page.add_full_block(line)
text_block = document_page.add_block(Text, polygon=line.polygon)
text_block.text_extraction_method = "pdftext"
text_block.add_structure(line)
for span in spans:
document_page.add_full_block(span)
Expand Down
86 changes: 86 additions & 0 deletions marker/v2/builders/ocr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import List

from surya.ocr import run_recognition

from marker.settings import settings
from marker.v2.builders import BaseBuilder
from marker.v2.schema.blocks import BlockId
from marker.v2.schema.text.line import Line, Span
from marker.v2.schema.document import Document
from marker.v2.schema.groups.page import PageGroup
from marker.v2.schema.polygon import PolygonBox


class OcrBuilder(BaseBuilder):
    """Fill in text for layout blocks that have no lines attached.

    Text-bearing blocks whose ``structure`` is still ``None`` are cropped
    from the page's high-resolution image by bounding box, sent through
    surya recognition, and given a ``Line`` containing a single ``Span``
    with the recognised text.
    """

    # Optional override; when None, get_batch_size() supplies the default.
    batch_size = None

    def __init__(self, ocr_model, config=None):
        """Store the recognition model and hand ``config`` to the base builder.

        Args:
            ocr_model: Recognition model; its ``processor`` attribute is
                passed to surya alongside the model itself.
            config: Optional builder configuration, forwarded unchanged to
                ``BaseBuilder.__init__``.
        """
        self.ocr_model = ocr_model

        super().__init__(config)

    def __call__(self, document: Document):
        """Run recognition over every page of ``document``, mutating it in place."""
        self.surya_recognition(document.pages)

    def get_batch_size(self):
        """Return the recognition batch size.

        An explicit ``batch_size`` wins. The per-device branches currently
        all resolve to 32; they are kept separate so each device can be
        tuned independently.
        """
        if self.batch_size is not None:
            return self.batch_size
        elif settings.TORCH_DEVICE_MODEL == "cuda":
            return 32
        elif settings.TORCH_DEVICE_MODEL == "mps":
            return 32
        return 32

    def surya_recognition(self, pages: List[PageGroup]) -> None:
        """Recognise text for blocks without structure and attach it to them.

        Args:
            pages: Pages to process. Each page is modified in place: matching
                blocks get ``text_extraction_method = "surya"`` plus a new
                ``Line``/``Span`` pair holding the recognised text.
        """
        # Block types that carry text; hoisted so it is built once.
        ocr_block_types = (
            "Caption", "Code", "Footnote",
            "Form", "Handwriting", "List-item",
            "Page-footer", "Page-header",
            "Section-header", "Text"
        )

        # One inner list per page, kept parallel to ``pages`` so results can
        # be paired back with their page and block positionally.
        ocr_bbox_list = []
        ocr_block_id_list: List[List[BlockId]] = []
        for page in pages:
            ocr_page_bbox_list = []
            ocr_page_block_id_list = []
            for block in page.children:
                if block.block_type not in ocr_block_types:
                    continue
                # Blocks that already have a structure are left untouched.
                if block.structure is not None:
                    continue
                block.text_extraction_method = "surya"
                # Recognition runs on the high-res image, so the bbox has to
                # be expressed in that image's coordinate space.
                block_polygon = block.polygon.rescale(page.polygon.size, page.highres_image.size)
                ocr_page_bbox_list.append([round(coord) for coord in block_polygon.bbox])
                ocr_page_block_id_list.append(block._id)
            ocr_bbox_list.append(ocr_page_bbox_list)
            ocr_block_id_list.append(ocr_page_block_id_list)

        recognition_results = run_recognition(
            images=[p.highres_image for p in pages],
            langs=[None] * len(pages),
            rec_model=self.ocr_model,
            rec_processor=self.ocr_model.processor,
            bboxes=ocr_bbox_list,
            batch_size=int(self.get_batch_size())
        )

        # Pair every page with its own results instead of indexing ``pages``
        # by ``page_id``: a page id is not guaranteed to equal the page's
        # position in the list (e.g. when only a sub-range of pages is given).
        for page, ocr_block_ids, recognition_result in zip(pages, ocr_block_id_list, recognition_results):
            for ocr_block_id, recognition in zip(ocr_block_ids, recognition_result.text_lines):
                # The bboxes sent to recognition were in high-res image
                # coordinates, so map the result back to page coordinates to
                # stay consistent with the other blocks on the page.
                # NOTE(review): assumes surya reports bboxes in the same image
                # space as the ones passed in - confirm.
                polygon = PolygonBox.from_bbox(recognition.bbox).rescale(
                    page.highres_image.size, page.polygon.size
                )
                block = page.get_block(ocr_block_id)
                line_block = page.add_block(Line, polygon=polygon)
                block.add_structure(line_block)
                # OCR output carries no font or position metadata, so those
                # fields are filled with neutral placeholders.
                span_block = page.add_full_block(
                    Span(
                        text=recognition.text,
                        formats=['plain'],
                        page_id=ocr_block_id.page_id,
                        polygon=polygon,
                        minimum_position=0,
                        maximum_position=0,
                        font='',
                        font_weight=0,
                        font_size=0
                    )
                )
                line_block.add_structure(span_block)
4 changes: 3 additions & 1 deletion marker/v2/converters/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from marker.v2.builders.document import DocumentBuilder
from marker.v2.builders.layout import LayoutBuilder
from marker.v2.builders.ocr import OcrBuilder
from marker.v2.builders.structure import StructureBuilder
from marker.v2.converters import BaseConverter
from marker.v2.processors.equation import EquationProcessor
Expand All @@ -30,7 +31,8 @@ def __call__(self, filepath: str, page_range: List[int] | None = None):
pdf_provider = PdfProvider(filepath, {"page_range": page_range})

layout_builder = LayoutBuilder(self.layout_model)
document = DocumentBuilder()(pdf_provider, layout_builder)
ocr_builder = OcrBuilder(self.recognition_model)
document = DocumentBuilder()(pdf_provider, layout_builder, ocr_builder)
StructureBuilder()(document)

equation_processor = EquationProcessor(self.texify_model)
Expand Down
6 changes: 3 additions & 3 deletions marker/v2/processors/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, detection_model, ocr_model, table_rec_model, config: Optional
self.table_rec_model = table_rec_model

def __call__(self, document: Document):
filepath = document.filepath # Path to original pdf file
filepath = document.filepath # Path to original pdf file

table_data = []
for page in document.pages:
Expand All @@ -35,7 +35,7 @@ def __call__(self, document: Document):
image_poly = block.polygon.rescale((page.polygon.width, page.polygon.height), page.highres_image.size)
image = page.highres_image.crop(image_poly.bbox).convert("RGB")

if block.text_extraction_method == "ocr":
if block.text_extraction_method == "surya":
text_lines = None
else:
text_lines = get_page_text_lines(
Expand Down Expand Up @@ -102,4 +102,4 @@ def get_ocr_batch_size(self):
return 32
elif settings.TORCH_DEVICE_MODEL == "cuda":
return 128
return 32
return 32
8 changes: 5 additions & 3 deletions marker/v2/schema/blocks/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, List, Optional
from typing import Any, List, Literal, Optional

from pydantic import BaseModel, ConfigDict

Expand All @@ -22,7 +22,9 @@ def __repr__(self):
return str(self)

def __eq__(self, other):
if not isinstance(other, BlockId):
if isinstance(other, str):
return str(self) == other
elif not isinstance(other, BlockId):
return NotImplemented
return self.page_id == other.page_id and self.block_id == other.block_id and self.block_type == other.block_type

Expand All @@ -32,7 +34,7 @@ class Block(BaseModel):
block_type: Optional[str] = None
block_id: Optional[int] = None
page_id: Optional[int] = None
text_extraction_method: Optional[str] = None
text_extraction_method: Optional[Literal['pdftext', 'surya']] = None
structure: List[BlockId] | None = None # The top-level page structure, which is the block ids in order
rendered: Any | None = None # The rendered output of the block

Expand Down
6 changes: 4 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
setup_detection_model
from marker.v2.builders.document import DocumentBuilder
from marker.v2.builders.layout import LayoutBuilder
from marker.v2.builders.ocr import OcrBuilder
from marker.v2.schema.document import Document


Expand Down Expand Up @@ -47,7 +48,7 @@ def table_rec_model():


@pytest.fixture(scope="function")
def pdf_document(request, layout_model) -> Document:
def pdf_document(request, layout_model, recognition_model) -> Document:
mark = request.node.get_closest_marker("filename")
if mark is None:
filename = "adversarial.pdf"
Expand All @@ -63,6 +64,7 @@ def pdf_document(request, layout_model) -> Document:

provider = PdfProvider(temp_pdf.name)
layout_builder = LayoutBuilder(layout_model)
ocr_builder = OcrBuilder(recognition_model)
builder = DocumentBuilder()
document = builder(provider, layout_builder)
document = builder(provider, layout_builder, ocr_builder)
return document
3 changes: 2 additions & 1 deletion tests/test_document_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ def test_document_builder(pdf_document):

first_block = first_page.get_block(first_page.structure[0])
assert first_block.block_type == 'Section-header'
assert first_block.text_extraction_method == 'pdftext'
first_text_block: Line = first_page.get_block(first_block.structure[0])
assert first_text_block.block_type == 'Line'
first_span = first_page.get_block(first_text_block.structure[0])
Expand All @@ -16,7 +17,7 @@ def test_document_builder(pdf_document):
assert first_span.formats == ['plain']

last_block = first_page.get_block(first_page.structure[-1])
assert last_block.block_type == 'Text'
assert last_block.block_type == 'Text-inline-math'
last_text_block: Line = first_page.get_block(last_block.structure[-1])
assert last_text_block.block_type == 'Line'
last_span = first_page.get_block(last_text_block.structure[-1])
Expand Down
27 changes: 27 additions & 0 deletions tests/test_ocr_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from marker.v2.schema.text.line import Line
from tests.utils import setup_pdf_document


def test_document_builder():
    """Check that the forced-OCR pipeline fills the first block via surya."""
    pdf_document = setup_pdf_document(
        "adversarial.pdf",
        document_builder_config={
            # OCR has to be forced: on the default path the layout builder
            # merges the PDF's embedded text and the block is tagged
            # 'pdftext', so the 'surya' assertion below could never hold.
            # NOTE(review): assumes the builder config is applied onto
            # DocumentBuilder.force_ocr - confirm against BaseBuilder.
            "force_ocr": True
        }
    )

    first_page = pdf_document.pages[0]
    assert first_page.structure[0] == '/page/0/Section-header/0'

    first_block = first_page.get_block(first_page.structure[0])
    assert first_block.text_extraction_method == 'surya'
    assert first_block.block_type == 'Section-header'
    first_text_block: Line = first_page.get_block(first_block.structure[0])
    assert first_text_block.block_type == 'Line'
    first_span = first_page.get_block(first_text_block.structure[0])
    assert first_span.block_type == 'Span'
    assert first_span.text == 'Subspace Adversarial Training'


# Allow running this test directly as a script, without going through pytest.
if __name__ == "__main__":
    test_document_builder()
26 changes: 16 additions & 10 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,33 @@
import tempfile

import datasets
from surya.model.layout.model import load_model
from surya.model.layout.processor import load_processor

from marker.v2.models import setup_layout_model, setup_recognition_model
from marker.v2.builders.document import DocumentBuilder
from marker.v2.builders.layout import LayoutBuilder
from marker.v2.builders.ocr import OcrBuilder
from marker.v2.schema.document import Document


def setup_pdf_document(filename: str) -> Document:
def setup_pdf_document(
filename='adversarial.pdf',
pdf_provider_config=None,
layout_builder_config=None,
ocr_builder_config=None,
document_builder_config=None
) -> Document:
dataset = datasets.load_dataset("datalab-to/pdfs", split="train")
idx = dataset['filename'].index(filename)

temp_pdf = tempfile.NamedTemporaryFile(suffix=".pdf")
temp_pdf.write(dataset['pdf'][idx])
temp_pdf.flush()

layout_model = load_model()
layout_model.processor = load_processor()
layout_model = setup_layout_model()
recognition_model = setup_recognition_model()

provider = PdfProvider(temp_pdf.name)
layout_builder = LayoutBuilder(layout_model)
builder = DocumentBuilder()
document = builder(provider, layout_builder)
provider = PdfProvider(temp_pdf.name, pdf_provider_config)
layout_builder = LayoutBuilder(layout_model, layout_builder_config)
ocr_builder = OcrBuilder(recognition_model, ocr_builder_config)
builder = DocumentBuilder(document_builder_config)
document = builder(provider, layout_builder, ocr_builder)
return document

0 comments on commit afe6295

Please sign in to comment.