From ead37b37631f95be166e67bea12059ca719fb641 Mon Sep 17 00:00:00 2001 From: Moses Paul R Date: Tue, 19 Nov 2024 10:28:29 +0000 Subject: [PATCH] use get_block_class everywhere, make registry mp compatible --- marker/v2/builders/document.py | 7 ++-- marker/v2/builders/layout.py | 4 +- marker/v2/builders/ocr.py | 7 ++-- marker/v2/builders/structure.py | 4 +- marker/v2/converters/pdf.py | 4 +- marker/v2/providers/pdf.py | 7 ++-- marker/v2/schema/registry.py | 74 ++++++++++++++++++--------------- tests/conftest.py | 4 +- 8 files changed, 60 insertions(+), 51 deletions(-) diff --git a/marker/v2/builders/document.py b/marker/v2/builders/document.py index 1ea0a05..af5d9f4 100644 --- a/marker/v2/builders/document.py +++ b/marker/v2/builders/document.py @@ -3,9 +3,10 @@ from marker.v2.builders.layout import LayoutBuilder from marker.v2.builders.ocr import OcrBuilder from marker.v2.providers.pdf import PdfProvider +from marker.v2.schema import BlockTypes from marker.v2.schema.document import Document from marker.v2.schema.groups.page import PageGroup -from marker.v2.schema.registry import get_block_cls +from marker.v2.schema.registry import get_block_class class DocumentBuilder(BaseBuilder): @@ -16,7 +17,7 @@ def __call__(self, provider: PdfProvider, layout_builder: LayoutBuilder, ocr_bui return document def build_document(self, provider: PdfProvider): - PageGroupClass = get_block_cls(PageGroup) + PageGroupClass: PageGroup = get_block_class(BlockTypes.Page) initial_pages = [ PageGroupClass( page_id=i, @@ -25,5 +26,5 @@ def build_document(self, provider: PdfProvider): polygon=provider.get_page_bbox(i) ) for i in provider.page_range ] - DocumentClass = get_block_cls(Document) + DocumentClass: Document = get_block_class(BlockTypes.Document) return DocumentClass(filepath=provider.filepath, pages=initial_pages) diff --git a/marker/v2/builders/layout.py b/marker/v2/builders/layout.py index 5cf7842..0b4df84 100644 --- a/marker/v2/builders/layout.py +++ b/marker/v2/builders/layout.py @@ -10,7 +10,7 @@ from marker.v2.schema.document import Document from marker.v2.schema.groups.page import PageGroup from marker.v2.schema.polygon import PolygonBox -from marker.v2.schema.registry import BLOCK_REGISTRY +from marker.v2.schema.registry import get_block_class from marker.v2.schema.text.line import Line @@ -49,7 +49,7 @@ def add_blocks_to_pages(self, pages: List[PageGroup], layout_results: List[Layou layout_page_size = PolygonBox.from_bbox(layout_result.image_bbox).size provider_page_size = page.polygon.size for bbox in sorted(layout_result.bboxes, key=lambda x: x.position): - block_cls = BLOCK_REGISTRY[BlockTypes[bbox.label]] + block_cls = get_block_class(BlockTypes[bbox.label]) layout_block = page.add_block(block_cls, PolygonBox(polygon=bbox.polygon)) layout_block.polygon = layout_block.polygon.rescale(layout_page_size, provider_page_size) page.add_structure(layout_block) diff --git a/marker/v2/builders/ocr.py b/marker/v2/builders/ocr.py index 1e97afb..9f40da9 100644 --- a/marker/v2/builders/ocr.py +++ b/marker/v2/builders/ocr.py @@ -5,9 +5,10 @@ from marker.settings import settings from marker.v2.builders import BaseBuilder from marker.v2.providers.pdf import PdfProvider +from marker.v2.schema import BlockTypes from marker.v2.schema.document import Document from marker.v2.schema.polygon import PolygonBox -from marker.v2.schema.registry import get_block_cls +from marker.v2.schema.registry import get_block_class from marker.v2.schema.text.line import Line from marker.v2.schema.text.span import Span @@ -62,8 +63,8 @@ def ocr_extraction(self, document: Document, provider: PdfProvider) -> Tuple[Pag page_lines = {} page_spans = {} - SpanClass = get_block_cls(Span) - LineClass = get_block_cls(Line) + SpanClass: Span = get_block_class(BlockTypes.Span) + LineClass: Line = get_block_class(BlockTypes.Line) for page_id, recognition_result in zip((page.page_id for page in page_list), recognition_results): page_spans.setdefault(page_id, {}) diff --git a/marker/v2/builders/structure.py b/marker/v2/builders/structure.py index c6ce697..b316ddb 100644 --- a/marker/v2/builders/structure.py +++ b/marker/v2/builders/structure.py @@ -3,7 +3,7 @@ from marker.v2.schema.document import Document from marker.v2.schema.groups import ListGroup from marker.v2.schema.groups.page import PageGroup -from marker.v2.schema.registry import BLOCK_REGISTRY +from marker.v2.schema.registry import get_block_class class StructureBuilder(BaseBuilder): @@ -50,7 +50,7 @@ def group_caption_blocks(self, page: PageGroup): if len(block_structure) > 1: # Create a merged block - new_block_cls = BLOCK_REGISTRY[BlockTypes[block.block_type.name + "Group"]] + new_block_cls = get_block_class(BlockTypes[block.block_type.name + "Group"]) new_polygon = block.polygon.merge(selected_polygons) group_block = page.add_block(new_block_cls, new_polygon) group_block.structure = block_structure diff --git a/marker/v2/converters/pdf.py b/marker/v2/converters/pdf.py index f149cb1..55caad9 100644 --- a/marker/v2/converters/pdf.py +++ b/marker/v2/converters/pdf.py @@ -23,7 +23,7 @@ from marker.v2.renderers.markdown import MarkdownRenderer from marker.v2.schema import BlockTypes from marker.v2.schema.blocks import Block -from marker.v2.schema.registry import BLOCK_REGISTRY +from marker.v2.schema.registry import register_block_class from marker.v2.processors.debug import DebugProcessor @@ -34,7 +34,7 @@ def __init__(self, config=None): super().__init__(config) for block_type, override_block_type in self.override_map.items(): - BLOCK_REGISTRY[block_type] = override_block_type + register_block_class(block_type, override_block_type) self.layout_model = setup_layout_model() self.texify_model = setup_texify_model() diff --git a/marker/v2/providers/pdf.py b/marker/v2/providers/pdf.py index 55aba74..77d49c2 100644 --- a/marker/v2/providers/pdf.py +++ b/marker/v2/providers/pdf.py @@ -8,7 +8,8 @@ from marker.ocr.heuristics import detect_bad_ocr from marker.v2.providers import BaseProvider from marker.v2.schema.polygon import PolygonBox -from marker.v2.schema.registry import get_block_cls +from marker.v2.schema import BlockTypes +from marker.v2.schema.registry import get_block_class from marker.v2.schema.text.line import Line from marker.v2.schema.text.span import Span @@ -105,8 +106,8 @@ def pdftext_extraction(self) -> Tuple[PageLines, PageSpans]: workers=self.pdftext_workers, flatten_pdf=self.flatten_pdf ) - SpanClass = get_block_cls(Span) - LineClass = get_block_cls(Line) + SpanClass: Span = get_block_class(BlockTypes.Span) + LineClass: Span = get_block_class(BlockTypes.Line) for page in page_char_blocks: page_id = page["page"] lines: List[Line] = [] diff --git a/marker/v2/schema/registry.py b/marker/v2/schema/registry.py index e733746..bf17f99 100644 --- a/marker/v2/schema/registry.py +++ b/marker/v2/schema/registry.py @@ -1,4 +1,5 @@ -from typing import Dict, Type, TypeVar +from typing import Dict, Type +from importlib import import_module from marker.v2.schema import BlockTypes from marker.v2.schema.blocks import Block, Caption, Code, Equation, Figure, \ @@ -11,39 +12,44 @@ PictureGroup, TableGroup from marker.v2.schema.text import Line, Span -BLOCK_REGISTRY: Dict[str, Type[Block]] = { - BlockTypes.Line: Line, - BlockTypes.Span: Span, - BlockTypes.FigureGroup: FigureGroup, - BlockTypes.TableGroup: TableGroup, - BlockTypes.ListGroup: ListGroup, - BlockTypes.PictureGroup: PictureGroup, - BlockTypes.Page: PageGroup, - BlockTypes.Caption: Caption, - BlockTypes.Code: Code, - BlockTypes.Figure: Figure, - BlockTypes.Footnote: Footnote, - BlockTypes.Form: Form, - BlockTypes.Equation: Equation, - BlockTypes.Handwriting: Handwriting, - BlockTypes.TextInlineMath: InlineMath, - BlockTypes.ListItem: ListItem, - BlockTypes.PageFooter: PageFooter, - BlockTypes.PageHeader: PageHeader, - BlockTypes.Picture: Picture, - BlockTypes.SectionHeader: SectionHeader, - BlockTypes.Table: Table, - BlockTypes.Text: Text, - BlockTypes.TableOfContents: TableOfContents, - BlockTypes.Document: Document, -} - -T = TypeVar('T') - - -def get_block_cls(block_cls: T) -> T: - return BLOCK_REGISTRY.get(block_cls.model_fields['block_type'].default, block_cls) +BLOCK_REGISTRY: Dict[BlockTypes, str] = {} +def register_block_class(block_type: BlockTypes, block_cls: Type[Block]): + BLOCK_REGISTRY[block_type] = f"{block_cls.__module__}.{block_cls.__name__}" + + +def get_block_class(block_type: BlockTypes) -> Type[Block]: + class_path = BLOCK_REGISTRY[block_type] + module_name, class_name = class_path.rsplit('.', 1) + module = import_module(module_name) + return getattr(module, class_name) + + +register_block_class(BlockTypes.Line, Line) +register_block_class(BlockTypes.Span, Span) +register_block_class(BlockTypes.FigureGroup, FigureGroup) +register_block_class(BlockTypes.TableGroup, TableGroup) +register_block_class(BlockTypes.ListGroup, ListGroup) +register_block_class(BlockTypes.PictureGroup, PictureGroup) +register_block_class(BlockTypes.Page, PageGroup) +register_block_class(BlockTypes.Caption, Caption) +register_block_class(BlockTypes.Code, Code) +register_block_class(BlockTypes.Figure, Figure) +register_block_class(BlockTypes.Footnote, Footnote) +register_block_class(BlockTypes.Form, Form) +register_block_class(BlockTypes.Equation, Equation) +register_block_class(BlockTypes.Handwriting, Handwriting) +register_block_class(BlockTypes.TextInlineMath, InlineMath) +register_block_class(BlockTypes.ListItem, ListItem) +register_block_class(BlockTypes.PageFooter, PageFooter) +register_block_class(BlockTypes.PageHeader, PageHeader) +register_block_class(BlockTypes.Picture, Picture) +register_block_class(BlockTypes.SectionHeader, SectionHeader) +register_block_class(BlockTypes.Table, Table) +register_block_class(BlockTypes.Text, Text) +register_block_class(BlockTypes.TableOfContents, TableOfContents) +register_block_class(BlockTypes.Document, Document) + assert len(BLOCK_REGISTRY) == len(BlockTypes) -assert all([v.model_fields['block_type'].default == k for k, v in BLOCK_REGISTRY.items()]) +assert all([get_block_class(k).model_fields['block_type'].default == k for k, _ in BLOCK_REGISTRY.items()]) diff --git a/tests/conftest.py b/tests/conftest.py index 2c63a93..31335a1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,7 +13,7 @@ from marker.v2.builders.layout import LayoutBuilder from marker.v2.builders.ocr import OcrBuilder from marker.v2.schema.document import Document -from marker.v2.schema.registry import BLOCK_REGISTRY +from marker.v2.schema.registry import register_block_class @pytest.fixture(scope="session") @@ -58,7 +58,7 @@ def config(request): override_map: Dict[BlockTypes, Type[Block]] = config.get("override_map", {}) for block_type, override_block_type in override_map.items(): - BLOCK_REGISTRY[block_type] = override_block_type + register_block_class(block_type, override_block_type) return config