Skip to content

Commit

Permalink
use get_block_class everywhere, make registry mp compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
iammosespaulr committed Nov 19, 2024
1 parent 8c71b35 commit ead37b3
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 51 deletions.
7 changes: 4 additions & 3 deletions marker/v2/builders/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from marker.v2.builders.layout import LayoutBuilder
from marker.v2.builders.ocr import OcrBuilder
from marker.v2.providers.pdf import PdfProvider
from marker.v2.schema import BlockTypes
from marker.v2.schema.document import Document
from marker.v2.schema.groups.page import PageGroup
from marker.v2.schema.registry import get_block_cls
from marker.v2.schema.registry import get_block_class


class DocumentBuilder(BaseBuilder):
Expand All @@ -16,7 +17,7 @@ def __call__(self, provider: PdfProvider, layout_builder: LayoutBuilder, ocr_bui
return document

def build_document(self, provider: PdfProvider):
PageGroupClass = get_block_cls(PageGroup)
PageGroupClass: PageGroup = get_block_class(BlockTypes.Page)
initial_pages = [
PageGroupClass(
page_id=i,
Expand All @@ -25,5 +26,5 @@ def build_document(self, provider: PdfProvider):
polygon=provider.get_page_bbox(i)
) for i in provider.page_range
]
DocumentClass = get_block_cls(Document)
DocumentClass: Document = get_block_class(BlockTypes.Document)
return DocumentClass(filepath=provider.filepath, pages=initial_pages)
4 changes: 2 additions & 2 deletions marker/v2/builders/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from marker.v2.schema.document import Document
from marker.v2.schema.groups.page import PageGroup
from marker.v2.schema.polygon import PolygonBox
from marker.v2.schema.registry import BLOCK_REGISTRY
from marker.v2.schema.registry import get_block_class
from marker.v2.schema.text.line import Line


Expand Down Expand Up @@ -49,7 +49,7 @@ def add_blocks_to_pages(self, pages: List[PageGroup], layout_results: List[Layou
layout_page_size = PolygonBox.from_bbox(layout_result.image_bbox).size
provider_page_size = page.polygon.size
for bbox in sorted(layout_result.bboxes, key=lambda x: x.position):
block_cls = BLOCK_REGISTRY[BlockTypes[bbox.label]]
block_cls = get_block_class(BlockTypes[bbox.label])
layout_block = page.add_block(block_cls, PolygonBox(polygon=bbox.polygon))
layout_block.polygon = layout_block.polygon.rescale(layout_page_size, provider_page_size)
page.add_structure(layout_block)
Expand Down
7 changes: 4 additions & 3 deletions marker/v2/builders/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from marker.settings import settings
from marker.v2.builders import BaseBuilder
from marker.v2.providers.pdf import PdfProvider
from marker.v2.schema import BlockTypes
from marker.v2.schema.document import Document
from marker.v2.schema.polygon import PolygonBox
from marker.v2.schema.registry import get_block_cls
from marker.v2.schema.registry import get_block_class
from marker.v2.schema.text.line import Line
from marker.v2.schema.text.span import Span

Expand Down Expand Up @@ -62,8 +63,8 @@ def ocr_extraction(self, document: Document, provider: PdfProvider) -> Tuple[Pag
page_lines = {}
page_spans = {}

SpanClass = get_block_cls(Span)
LineClass = get_block_cls(Line)
SpanClass: Span = get_block_class(BlockTypes.Span)
LineClass: Line = get_block_class(BlockTypes.Line)

for page_id, recognition_result in zip((page.page_id for page in page_list), recognition_results):
page_spans.setdefault(page_id, {})
Expand Down
4 changes: 2 additions & 2 deletions marker/v2/builders/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from marker.v2.schema.document import Document
from marker.v2.schema.groups import ListGroup
from marker.v2.schema.groups.page import PageGroup
from marker.v2.schema.registry import BLOCK_REGISTRY
from marker.v2.schema.registry import get_block_class


class StructureBuilder(BaseBuilder):
Expand Down Expand Up @@ -50,7 +50,7 @@ def group_caption_blocks(self, page: PageGroup):

if len(block_structure) > 1:
# Create a merged block
new_block_cls = BLOCK_REGISTRY[BlockTypes[block.block_type.name + "Group"]]
new_block_cls = get_block_class(BlockTypes[block.block_type.name + "Group"])
new_polygon = block.polygon.merge(selected_polygons)
group_block = page.add_block(new_block_cls, new_polygon)
group_block.structure = block_structure
Expand Down
4 changes: 2 additions & 2 deletions marker/v2/converters/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from marker.v2.renderers.markdown import MarkdownRenderer
from marker.v2.schema import BlockTypes
from marker.v2.schema.blocks import Block
from marker.v2.schema.registry import BLOCK_REGISTRY
from marker.v2.schema.registry import register_block_class
from marker.v2.processors.debug import DebugProcessor


Expand All @@ -34,7 +34,7 @@ def __init__(self, config=None):
super().__init__(config)

for block_type, override_block_type in self.override_map.items():
BLOCK_REGISTRY[block_type] = override_block_type
register_block_class(block_type, override_block_type)

self.layout_model = setup_layout_model()
self.texify_model = setup_texify_model()
Expand Down
7 changes: 4 additions & 3 deletions marker/v2/providers/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from marker.ocr.heuristics import detect_bad_ocr
from marker.v2.providers import BaseProvider
from marker.v2.schema.polygon import PolygonBox
from marker.v2.schema.registry import get_block_cls
from marker.v2.schema import BlockTypes
from marker.v2.schema.registry import get_block_class
from marker.v2.schema.text.line import Line
from marker.v2.schema.text.span import Span

Expand Down Expand Up @@ -105,8 +106,8 @@ def pdftext_extraction(self) -> Tuple[PageLines, PageSpans]:
workers=self.pdftext_workers,
flatten_pdf=self.flatten_pdf
)
SpanClass = get_block_cls(Span)
LineClass = get_block_cls(Line)
SpanClass: Span = get_block_class(BlockTypes.Span)
LineClass: Span = get_block_class(BlockTypes.Line)
for page in page_char_blocks:
page_id = page["page"]
lines: List[Line] = []
Expand Down
74 changes: 40 additions & 34 deletions marker/v2/schema/registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict, Type, TypeVar
from typing import Dict, Type
from importlib import import_module

from marker.v2.schema import BlockTypes
from marker.v2.schema.blocks import Block, Caption, Code, Equation, Figure, \
Expand All @@ -11,39 +12,44 @@
PictureGroup, TableGroup
from marker.v2.schema.text import Line, Span

BLOCK_REGISTRY: Dict[str, Type[Block]] = {
BlockTypes.Line: Line,
BlockTypes.Span: Span,
BlockTypes.FigureGroup: FigureGroup,
BlockTypes.TableGroup: TableGroup,
BlockTypes.ListGroup: ListGroup,
BlockTypes.PictureGroup: PictureGroup,
BlockTypes.Page: PageGroup,
BlockTypes.Caption: Caption,
BlockTypes.Code: Code,
BlockTypes.Figure: Figure,
BlockTypes.Footnote: Footnote,
BlockTypes.Form: Form,
BlockTypes.Equation: Equation,
BlockTypes.Handwriting: Handwriting,
BlockTypes.TextInlineMath: InlineMath,
BlockTypes.ListItem: ListItem,
BlockTypes.PageFooter: PageFooter,
BlockTypes.PageHeader: PageHeader,
BlockTypes.Picture: Picture,
BlockTypes.SectionHeader: SectionHeader,
BlockTypes.Table: Table,
BlockTypes.Text: Text,
BlockTypes.TableOfContents: TableOfContents,
BlockTypes.Document: Document,
}

T = TypeVar('T')


def get_block_cls(block_cls: T) -> T:
return BLOCK_REGISTRY.get(block_cls.model_fields['block_type'].default, block_cls)
BLOCK_REGISTRY: Dict[BlockTypes, str] = {}


def register_block_class(block_type: BlockTypes, block_cls: Type[Block]):
    """Register ``block_cls`` as the implementation used for ``block_type``.

    The registry does not hold the class object itself. It stores the
    string ``"<module>.<class name>"`` built from the class's
    ``__module__`` and ``__name__``; ``get_block_class`` later imports the
    class again from that path. Registering a type that already has an
    entry overwrites the previous one.

    NOTE(review): presumably a plain string is stored so the registry
    stays usable across processes (the commit title mentions
    multiprocessing compatibility) -- confirm.

    Args:
        block_type: Registry key to bind.
        block_cls: Class whose import path is recorded for that key.
    """
    module_name = block_cls.__module__
    class_name = block_cls.__name__
    BLOCK_REGISTRY[block_type] = f"{module_name}.{class_name}"


def get_block_class(block_type: BlockTypes) -> Type[Block]:
    """Return the class currently registered for ``block_type``.

    The registry entry is a dotted path string. Everything before the last
    dot is imported as a module, and the attribute named by the final
    component is fetched from that module and returned.

    Args:
        block_type: Registry key to resolve.

    Returns:
        The attribute found at the registered dotted path.

    Raises:
        KeyError: If ``block_type`` has no entry in ``BLOCK_REGISTRY``.
    """
    module_name, class_name = BLOCK_REGISTRY[block_type].rsplit(".", maxsplit=1)
    return getattr(import_module(module_name), class_name)


# Default registrations: one class per block type. Each call stores the
# class's "<module>.<class name>" path in BLOCK_REGISTRY.
register_block_class(BlockTypes.Line, Line)
register_block_class(BlockTypes.Span, Span)
register_block_class(BlockTypes.FigureGroup, FigureGroup)
register_block_class(BlockTypes.TableGroup, TableGroup)
register_block_class(BlockTypes.ListGroup, ListGroup)
register_block_class(BlockTypes.PictureGroup, PictureGroup)
register_block_class(BlockTypes.Page, PageGroup)
register_block_class(BlockTypes.Caption, Caption)
register_block_class(BlockTypes.Code, Code)
register_block_class(BlockTypes.Figure, Figure)
register_block_class(BlockTypes.Footnote, Footnote)
register_block_class(BlockTypes.Form, Form)
register_block_class(BlockTypes.Equation, Equation)
register_block_class(BlockTypes.Handwriting, Handwriting)
register_block_class(BlockTypes.TextInlineMath, InlineMath)
register_block_class(BlockTypes.ListItem, ListItem)
register_block_class(BlockTypes.PageFooter, PageFooter)
register_block_class(BlockTypes.PageHeader, PageHeader)
register_block_class(BlockTypes.Picture, Picture)
register_block_class(BlockTypes.SectionHeader, SectionHeader)
register_block_class(BlockTypes.Table, Table)
register_block_class(BlockTypes.Text, Text)
register_block_class(BlockTypes.TableOfContents, TableOfContents)
register_block_class(BlockTypes.Document, Document)

# Import-time sanity check: the registry must hold as many entries as
# BlockTypes has members.
assert len(BLOCK_REGISTRY) == len(BlockTypes)
# NOTE(review): register_block_class stores strings, so `v` below is a str
# and has no `model_fields`. This looks like the pre-change line of the diff,
# superseded by the assertion that follows it -- confirm and drop.
assert all([v.model_fields['block_type'].default == k for k, v in BLOCK_REGISTRY.items()])
# Every resolved class must declare its own registry key as the default of
# its `block_type` field.
assert all([get_block_class(k).model_fields['block_type'].default == k for k, _ in BLOCK_REGISTRY.items()])
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from marker.v2.builders.layout import LayoutBuilder
from marker.v2.builders.ocr import OcrBuilder
from marker.v2.schema.document import Document
from marker.v2.schema.registry import BLOCK_REGISTRY
from marker.v2.schema.registry import register_block_class


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -58,7 +58,7 @@ def config(request):

override_map: Dict[BlockTypes, Type[Block]] = config.get("override_map", {})
for block_type, override_block_type in override_map.items():
BLOCK_REGISTRY[block_type] = override_block_type
register_block_class(block_type, override_block_type)

return config

Expand Down

0 comments on commit ead37b3

Please sign in to comment.