From afe629561e26854b64d228e16eeaa343d91225dd Mon Sep 17 00:00:00 2001 From: Moses Paul R Date: Fri, 15 Nov 2024 11:28:21 +0000 Subject: [PATCH] add OCR builder and tests --- marker/v2/builders/document.py | 11 ++++- marker/v2/builders/layout.py | 10 ++-- marker/v2/builders/ocr.py | 86 +++++++++++++++++++++++++++++++++ marker/v2/converters/pdf.py | 4 +- marker/v2/processors/table.py | 6 +-- marker/v2/schema/blocks/base.py | 8 +-- tests/conftest.py | 6 ++- tests/test_document_builder.py | 3 +- tests/test_ocr_pipeline.py | 27 +++++++++++ tests/utils.py | 26 ++++++---- 10 files changed, 162 insertions(+), 25 deletions(-) create mode 100644 marker/v2/builders/ocr.py create mode 100644 tests/test_ocr_pipeline.py diff --git a/marker/v2/builders/document.py b/marker/v2/builders/document.py index 1493119d..0726194b 100644 --- a/marker/v2/builders/document.py +++ b/marker/v2/builders/document.py @@ -1,6 +1,7 @@ from marker.settings import settings from marker.v2.builders import BaseBuilder from marker.v2.builders.layout import LayoutBuilder +from marker.v2.builders.ocr import OcrBuilder from marker.v2.providers.pdf import PdfProvider from marker.v2.schema.document import Document from marker.v2.schema.groups.page import PageGroup @@ -8,9 +9,15 @@ class DocumentBuilder(BaseBuilder): - def __call__(self, provider: PdfProvider, layout_builder: LayoutBuilder): + force_ocr: bool = False + + def __call__(self, provider: PdfProvider, layout_builder: LayoutBuilder, ocr_builder: OcrBuilder): document = self.build_document(provider) - layout_builder(document, provider) + if self.force_ocr: + layout_builder(document) + else: + layout_builder(document, provider) + ocr_builder(document) return document def build_document(self, provider: PdfProvider): diff --git a/marker/v2/builders/layout.py b/marker/v2/builders/layout.py index 8f614c7a..2d119d6b 100644 --- a/marker/v2/builders/layout.py +++ b/marker/v2/builders/layout.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from surya.layout import batch_layout_detection from surya.schema import LayoutResult @@ -20,10 +20,11 @@ def __init__(self, layout_model, config=None): super().__init__(config) - def __call__(self, document: Document, provider: PdfProvider): + def __call__(self, document: Document, provider: Optional[PdfProvider] = None): layout_results = self.surya_layout(document.pages) self.add_blocks_to_pages(document.pages, layout_results) - self.merge_blocks(document.pages, provider.page_lines) + if provider is not None: + self.merge_blocks(document.pages, provider.page_lines) def get_batch_size(self): if self.batch_size is not None: @@ -70,6 +71,7 @@ def merge_blocks(self, document_pages: List[PageGroup], provider_page_lines: Pdf document_page.add_full_block(line) block_idx = max_intersections[line_idx][1] block: Block = document_page.children[block_idx] + block.text_extraction_method = "pdftext" block.add_structure(line) block.polygon = block.polygon.merge([line.polygon]) assigned_line_idxs.add(line_idx) @@ -90,6 +92,7 @@ def merge_blocks(self, document_pages: List[PageGroup], provider_page_lines: Pdf if min_dist_idx is not None: document_page.add_full_block(line) nearest_block = document_page.children[min_dist_idx] + nearest_block.text_extraction_method = "pdftext" nearest_block.add_structure(line) nearest_block.polygon = nearest_block.polygon.merge([line.polygon]) assigned_line_idxs.add(line_idx) @@ -101,6 +104,7 @@ def merge_blocks(self, document_pages: List[PageGroup], provider_page_lines: Pdf line, spans = provider_lines[line_idx] document_page.add_full_block(line) text_block = document_page.add_block(Text, polygon=line.polygon) + text_block.text_extraction_method = "pdftext" text_block.add_structure(line) for span in spans: document_page.add_full_block(span) diff --git a/marker/v2/builders/ocr.py b/marker/v2/builders/ocr.py new file mode 100644 index 00000000..de942204 --- /dev/null +++ b/marker/v2/builders/ocr.py @@ -0,0 +1,86 @@ +from typing import List + +from surya.ocr import run_recognition + +from marker.settings import settings +from marker.v2.builders import BaseBuilder +from marker.v2.schema.blocks import BlockId +from marker.v2.schema.text.line import Line, Span +from marker.v2.schema.document import Document +from marker.v2.schema.groups.page import PageGroup +from marker.v2.schema.polygon import PolygonBox + + +class OcrBuilder(BaseBuilder): + batch_size = None + + def __init__(self, ocr_model, config=None): + self.ocr_model = ocr_model + + super().__init__(config) + + def __call__(self, document: Document): + self.surya_recognition(document.pages) + + def get_batch_size(self): + if self.batch_size is not None: + return self.batch_size + elif settings.TORCH_DEVICE_MODEL == "cuda": + return 32 + elif settings.TORCH_DEVICE_MODEL == "mps": + return 32 + return 32 + + def surya_recognition(self, pages: List[PageGroup]) -> List[List[str]]: + ocr_bbox_list = [] + ocr_block_id_list: List[List[BlockId]] = [] + for page in pages: + ocr_page_bbox_list = [] + ocr_page_block_id_list = [] + for block in page.children: + if block.block_type in [ + "Caption", "Code", "Footnote", + "Form", "Handwriting", "List-item", + "Page-footer", "Page-header", + "Section-header", "Text" + ]: + if block.structure is None: + block.text_extraction_method = "surya" + block_polygon = block.polygon.rescale(page.polygon.size, page.highres_image.size) + bbox = list(map(round, block_polygon.bbox)) + ocr_page_bbox_list.append(bbox) + ocr_page_block_id_list.append(block._id) + ocr_bbox_list.append(ocr_page_bbox_list) + ocr_block_id_list.append(ocr_page_block_id_list) + + recognition_results = run_recognition( + images=[p.highres_image for p in pages], + langs=[None] * len(pages), + rec_model=self.ocr_model, + rec_processor=self.ocr_model.processor, + bboxes=ocr_bbox_list, + batch_size=int(self.get_batch_size()) + ) + + for ocr_block_ids, recognition_result in zip(ocr_block_id_list, recognition_results): + for ocr_block_id, recognition in zip(ocr_block_ids, recognition_result.text_lines): + page_id = ocr_block_id.page_id + polygon = PolygonBox.from_bbox(recognition.bbox) + page = pages[page_id] + block = page.get_block(ocr_block_id) + line_block = page.add_block(Line, polygon=polygon) + block.add_structure(line_block) + span_block = page.add_full_block( + Span( + text=recognition.text, + formats=['plain'], + page_id=page_id, + polygon=polygon, + minimum_position=0, + maximum_position=0, + font='', + font_weight=0, + font_size=0 + ) + ) + line_block.add_structure(span_block) diff --git a/marker/v2/converters/pdf.py b/marker/v2/converters/pdf.py index 1b307b9c..8759b951 100644 --- a/marker/v2/converters/pdf.py +++ b/marker/v2/converters/pdf.py @@ -8,6 +8,7 @@ from marker.v2.builders.document import DocumentBuilder from marker.v2.builders.layout import LayoutBuilder +from marker.v2.builders.ocr import OcrBuilder from marker.v2.builders.structure import StructureBuilder from marker.v2.converters import BaseConverter from marker.v2.processors.equation import EquationProcessor @@ -30,7 +31,8 @@ def __call__(self, filepath: str, page_range: List[int] | None = None): pdf_provider = PdfProvider(filepath, {"page_range": page_range}) layout_builder = LayoutBuilder(self.layout_model) - document = DocumentBuilder()(pdf_provider, layout_builder) + ocr_builder = OcrBuilder(self.recognition_model) + document = DocumentBuilder()(pdf_provider, layout_builder, ocr_builder) StructureBuilder()(document) equation_processor = EquationProcessor(self.texify_model) diff --git a/marker/v2/processors/table.py b/marker/v2/processors/table.py index e2b1b08b..af765dab 100644 --- a/marker/v2/processors/table.py +++ b/marker/v2/processors/table.py @@ -25,7 +25,7 @@ def __init__(self, detection_model, ocr_model, table_rec_model, config: Optional self.table_rec_model = table_rec_model def __call__(self, document: Document): - filepath = document.filepath # Path to original pdf file + filepath = document.filepath # Path to original pdf file table_data = [] for page in document.pages: @@ -35,7 +35,7 @@ def __call__(self, document: Document): image_poly = block.polygon.rescale((page.polygon.width, page.polygon.height), page.highres_image.size) image = page.highres_image.crop(image_poly.bbox).convert("RGB") - if block.text_extraction_method == "ocr": + if block.text_extraction_method == "surya": text_lines = None else: text_lines = get_page_text_lines( @@ -102,4 +102,4 @@ def get_ocr_batch_size(self): return 32 elif settings.TORCH_DEVICE_MODEL == "cuda": return 128 - return 32 \ No newline at end of file + return 32 diff --git a/marker/v2/schema/blocks/base.py b/marker/v2/schema/blocks/base.py index 6a5b7d5a..d24be9eb 100644 --- a/marker/v2/schema/blocks/base.py +++ b/marker/v2/schema/blocks/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, List, Optional +from typing import Any, List, Literal, Optional from pydantic import BaseModel, ConfigDict @@ -22,7 +22,9 @@ def __repr__(self): return str(self) def __eq__(self, other): - if not isinstance(other, BlockId): + if isinstance(other, str): + return str(self) == other + elif not isinstance(other, BlockId): return NotImplemented return self.page_id == other.page_id and self.block_id == other.block_id and self.block_type == other.block_type @@ -32,7 +34,7 @@ class Block(BaseModel): block_type: Optional[str] = None block_id: Optional[int] = None page_id: Optional[int] = None - text_extraction_method: Optional[str] = None + text_extraction_method: Optional[Literal['pdftext', 'surya']] = None structure: List[BlockId] | None = None # The top-level page structure, which is the block ids in order rendered: Any | None = None # The rendered output of the block diff --git a/tests/conftest.py b/tests/conftest.py index 1a1a607d..f8948cdf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,7 @@ setup_detection_model from marker.v2.builders.document import DocumentBuilder from marker.v2.builders.layout import LayoutBuilder +from marker.v2.builders.ocr import OcrBuilder from marker.v2.schema.document import Document @@ -47,7 +48,7 @@ def table_rec_model(): @pytest.fixture(scope="function") -def pdf_document(request, layout_model) -> Document: +def pdf_document(request, layout_model, recognition_model) -> Document: mark = request.node.get_closest_marker("filename") if mark is None: filename = "adversarial.pdf" @@ -63,6 +64,7 @@ def pdf_document(request, layout_model) -> Document: provider = PdfProvider(temp_pdf.name) layout_builder = LayoutBuilder(layout_model) + ocr_builder = OcrBuilder(recognition_model) builder = DocumentBuilder() - document = builder(provider, layout_builder) + document = builder(provider, layout_builder, ocr_builder) return document diff --git a/tests/test_document_builder.py b/tests/test_document_builder.py index 1c5444a0..085a1bdb 100644 --- a/tests/test_document_builder.py +++ b/tests/test_document_builder.py @@ -7,6 +7,7 @@ def test_document_builder(pdf_document): first_block = first_page.get_block(first_page.structure[0]) assert first_block.block_type == 'Section-header' + assert first_block.text_extraction_method == 'pdftext' first_text_block: Line = first_page.get_block(first_block.structure[0]) assert first_text_block.block_type == 'Line' first_span = first_page.get_block(first_text_block.structure[0]) @@ -16,7 +17,7 @@ def test_document_builder(pdf_document): assert first_span.formats == ['plain'] last_block = first_page.get_block(first_page.structure[-1]) - assert last_block.block_type == 'Text' + assert last_block.block_type == 'Text-inline-math' last_text_block: Line = first_page.get_block(last_block.structure[-1]) assert last_text_block.block_type == 'Line' last_span = first_page.get_block(last_text_block.structure[-1]) diff --git a/tests/test_ocr_pipeline.py b/tests/test_ocr_pipeline.py new file mode 100644 index 00000000..5b2d756b --- /dev/null +++ b/tests/test_ocr_pipeline.py @@ -0,0 +1,27 @@ +from marker.v2.schema.text.line import Line +from tests.utils import setup_pdf_document + + +def test_document_builder(): + pdf_document = setup_pdf_document( + "adversarial.pdf", + document_builder_config={ + "force_ocr": False + } + ) + + first_page = pdf_document.pages[0] + assert first_page.structure[0] == '/page/0/Section-header/0' + + first_block = first_page.get_block(first_page.structure[0]) + assert first_block.text_extraction_method == 'surya' + assert first_block.block_type == 'Section-header' + first_text_block: Line = first_page.get_block(first_block.structure[0]) + assert first_text_block.block_type == 'Line' + first_span = first_page.get_block(first_text_block.structure[0]) + assert first_span.block_type == 'Span' + assert first_span.text == 'Subspace Adversarial Training' + + +if __name__ == "__main__": + test_document_builder() diff --git a/tests/utils.py b/tests/utils.py index 50af92b9..c8f326ec 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,15 +2,20 @@ import tempfile import datasets -from surya.model.layout.model import load_model -from surya.model.layout.processor import load_processor - +from marker.v2.models import setup_layout_model, setup_recognition_model from marker.v2.builders.document import DocumentBuilder from marker.v2.builders.layout import LayoutBuilder +from marker.v2.builders.ocr import OcrBuilder from marker.v2.schema.document import Document -def setup_pdf_document(filename: str) -> Document: +def setup_pdf_document( + filename='adversarial.pdf', + pdf_provider_config=None, + layout_builder_config=None, + ocr_builder_config=None, + document_builder_config=None +) -> Document: dataset = datasets.load_dataset("datalab-to/pdfs", split="train") idx = dataset['filename'].index(filename) @@ -18,11 +23,12 @@ def setup_pdf_document(filename: str) -> Document: temp_pdf.write(dataset['pdf'][idx]) temp_pdf.flush() - layout_model = load_model() - layout_model.processor = load_processor() + layout_model = setup_layout_model() + recognition_model = setup_recognition_model() - provider = PdfProvider(temp_pdf.name) - layout_builder = LayoutBuilder(layout_model) - builder = DocumentBuilder() - document = builder(provider, layout_builder) + provider = PdfProvider(temp_pdf.name, pdf_provider_config) + layout_builder = LayoutBuilder(layout_model, layout_builder_config) + ocr_builder = OcrBuilder(recognition_model, ocr_builder_config) + builder = DocumentBuilder(document_builder_config) + document = builder(provider, layout_builder, ocr_builder) return document