From cc08b4bb546e3d445d26817bcaddf42188670a56 Mon Sep 17 00:00:00 2001 From: Vik Paruchuri Date: Tue, 19 Nov 2024 14:32:21 -0500 Subject: [PATCH] Fix issues with spans/lines and providers --- marker/v2/builders/layout.py | 17 ++++---- marker/v2/builders/ocr.py | 65 +++++++++++++--------------- marker/v2/converters/pdf.py | 5 ++- marker/v2/providers/__init__.py | 9 +++- marker/v2/providers/pdf.py | 52 ++++++++++------------ marker/v2/schema/groups/page.py | 36 +++++++++------ tests/builders/test_overriding.py | 2 +- tests/providers/test_pdf_provider.py | 3 +- 8 files changed, 99 insertions(+), 90 deletions(-) diff --git a/marker/v2/builders/layout.py b/marker/v2/builders/layout.py index 0b4df842..25c2249b 100644 --- a/marker/v2/builders/layout.py +++ b/marker/v2/builders/layout.py @@ -5,7 +5,8 @@ from marker.settings import settings from marker.v2.builders import BaseBuilder -from marker.v2.providers.pdf import PageLines, PageSpans, PdfProvider +from marker.v2.providers import ProviderOutput, ProviderPageLines +from marker.v2.providers.pdf import PdfProvider from marker.v2.schema import BlockTypes from marker.v2.schema.document import Document from marker.v2.schema.groups.page import PageGroup @@ -25,7 +26,7 @@ def __init__(self, layout_model, config=None): def __call__(self, document: Document, provider: PdfProvider): layout_results = self.surya_layout(document.pages) self.add_blocks_to_pages(document.pages, layout_results) - self.merge_blocks(document.pages, provider.page_lines, provider.page_spans) + self.merge_blocks(document.pages, provider.page_lines) def get_batch_size(self): if self.batch_size is not None: @@ -54,18 +55,18 @@ def add_blocks_to_pages(self, pages: List[PageGroup], layout_results: List[Layou layout_block.polygon = layout_block.polygon.rescale(layout_page_size, provider_page_size) page.add_structure(layout_block) - def merge_blocks(self, document_pages: List[PageGroup], provider_page_lines: PageLines, provider_page_spans: PageSpans): - for 
document_page, provider_lines in zip(document_pages, provider_page_lines.values()): + def merge_blocks(self, document_pages: List[PageGroup], provider_page_lines: ProviderPageLines): + for document_page in document_pages: + provider_lines = provider_page_lines[document_page.page_id] if not self.check_layout_coverage(document_page, provider_lines): document_page.text_extraction_method = "surya" continue - line_spans = provider_page_spans[document_page.page_id] - document_page.merge_blocks(provider_lines, line_spans, text_extraction_method="pdftext") + document_page.merge_blocks(provider_lines, text_extraction_method="pdftext") def check_layout_coverage( self, document_page: PageGroup, - provider_lines: List[Line], + provider_lines: List[ProviderOutput], coverage_threshold=0.5 ): layout_area = 0 @@ -76,6 +77,6 @@ def check_layout_coverage( continue layout_area += layout_block.polygon.area for provider_line in provider_lines: - provider_area += layout_block.polygon.intersection_area(provider_line.polygon) + provider_area += layout_block.polygon.intersection_area(provider_line.line.polygon) coverage_ratio = provider_area / layout_area if layout_area > 0 else 0 return coverage_ratio >= coverage_threshold diff --git a/marker/v2/builders/ocr.py b/marker/v2/builders/ocr.py index 64e1dac5..114ed55e 100644 --- a/marker/v2/builders/ocr.py +++ b/marker/v2/builders/ocr.py @@ -1,9 +1,8 @@ -from typing import Dict, List, Tuple - from surya.ocr import run_ocr from marker.settings import settings from marker.v2.builders import BaseBuilder +from marker.v2.providers import ProviderOutput, ProviderPageLines from marker.v2.providers.pdf import PdfProvider from marker.v2.schema import BlockTypes from marker.v2.schema.document import Document @@ -12,10 +11,6 @@ from marker.v2.schema.text.line import Line from marker.v2.schema.text.span import Span -PageLines = Dict[int, List[Line]] -LineSpans = Dict[int, List[Span]] -PageSpans = Dict[int, LineSpans] - class OcrBuilder(BaseBuilder): 
recognition_batch_size = None @@ -28,8 +23,8 @@ def __init__(self, detection_model, recognition_model, config=None): self.recognition_model = recognition_model def __call__(self, document: Document, provider: PdfProvider): - page_lines, page_spans = self.ocr_extraction(document, provider) - self.merge_blocks(document, page_lines, page_spans) + page_lines = self.ocr_extraction(document, provider) + self.merge_blocks(document, page_lines) def get_recognition_batch_size(self): if self.recognition_batch_size is not None: @@ -47,7 +42,7 @@ def get_detection_batch_size(self): return 4 return 4 - def ocr_extraction(self, document: Document, provider: PdfProvider) -> Tuple[PageLines, PageSpans]: + def ocr_extraction(self, document: Document, provider: PdfProvider) -> ProviderPageLines: page_list = [page for page in document.pages if page.text_extraction_method == "surya"] recognition_results = run_ocr( images=[page.lowres_image for page in page_list], @@ -61,43 +56,43 @@ def ocr_extraction(self, document: Document, provider: PdfProvider) -> Tuple[Pag ) page_lines = {} - page_spans = {} SpanClass: Span = get_block_class(BlockTypes.Span) LineClass: Line = get_block_class(BlockTypes.Line) for page_id, recognition_result in zip((page.page_id for page in page_list), recognition_results): - page_spans.setdefault(page_id, {}) page_lines.setdefault(page_id, []) page_size = provider.get_page_bbox(page_id).size - line_spans = page_spans[page_id] for ocr_line_idx, ocr_line in enumerate(recognition_result.text_lines): image_polygon = PolygonBox.from_bbox(recognition_result.image_bbox) polygon = PolygonBox.from_bbox(ocr_line.bbox).rescale(image_polygon.size, page_size) - page_lines[page_id].append(LineClass( - polygon=polygon, - page_id=page_id, - )) - - line_spans.setdefault(ocr_line_idx, []) - line_spans[ocr_line_idx].append(SpanClass( - text=ocr_line.text, - formats=['plain'], - page_id=page_id, - polygon=polygon, - minimum_position=0, - maximum_position=0, - font='', - 
font_weight=0, - font_size=0, - )) - - return page_lines, page_spans - - def merge_blocks(self, document: Document, page_lines: PageLines, page_spans: PageSpans): + line = LineClass( + polygon=polygon, + page_id=page_id, + ) + spans = [ + SpanClass( + text=ocr_line.text + "\n", + formats=['plain'], + page_id=page_id, + polygon=polygon, + minimum_position=0, + maximum_position=0, + font='', + font_weight=0, + font_size=0, + ) + ] + + page_lines[page_id].append(ProviderOutput(line=line, spans=spans)) + + return page_lines + + def merge_blocks(self, document: Document, page_lines: ProviderPageLines): ocred_pages = [page for page in document.pages if page.text_extraction_method == "surya"] - for document_page, lines, line_spans in zip(ocred_pages, page_lines.values(), page_spans.values()): - document_page.merge_blocks(lines, line_spans, text_extraction_method="surya") + for document_page in ocred_pages: + lines = page_lines[document_page.page_id] + document_page.merge_blocks(lines, text_extraction_method="surya") diff --git a/marker/v2/converters/pdf.py b/marker/v2/converters/pdf.py index 5f69df69..5d1dee78 100644 --- a/marker/v2/converters/pdf.py +++ b/marker/v2/converters/pdf.py @@ -81,7 +81,8 @@ def __call__(self, filepath: str): @click.option("--debug", is_flag=True) @click.option("--output_format", type=click.Choice(["markdown", "json"]), default="markdown") @click.option("--pages", type=str, default=None) -def main(fpath: str, output_dir: str, debug: bool, output_format: str, pages: str): +@click.option("--force_ocr", is_flag=True) +def main(fpath: str, output_dir: str, debug: bool, output_format: str, pages: str, force_ocr: bool): if pages is not None: pages = list(map(int, pages.split(","))) @@ -96,6 +97,8 @@ def main(fpath: str, output_dir: str, debug: bool, output_format: str, pages: st config["debug_pdf_images"] = True config["debug_layout_images"] = True config["debug_json"] = True + if force_ocr: + config["force_ocr"] = True converter = 
PdfConverter(config=config, output_format=output_format) diff --git a/marker/v2/providers/__init__.py b/marker/v2/providers/__init__.py index 04c9826a..e1d635e4 100644 --- a/marker/v2/providers/__init__.py +++ b/marker/v2/providers/__init__.py @@ -1,11 +1,18 @@ -from typing import List, Optional +from typing import List, Optional, Dict from pydantic import BaseModel +from marker.v2.schema.text import Span from marker.v2.schema.text.line import Line from marker.v2.util import assign_config +class ProviderOutput(BaseModel): + line: Line + spans: List[Span] + +ProviderPageLines = Dict[int, List[ProviderOutput]] + class BaseProvider: def __init__(self, filepath: str, config: Optional[BaseModel | dict] = None): assign_config(self, config) diff --git a/marker/v2/providers/pdf.py b/marker/v2/providers/pdf.py index a924f304..4595f26c 100644 --- a/marker/v2/providers/pdf.py +++ b/marker/v2/providers/pdf.py @@ -7,17 +7,13 @@ from PIL import Image from marker.ocr.heuristics import detect_bad_ocr -from marker.v2.providers import BaseProvider +from marker.v2.providers import BaseProvider, ProviderOutput, ProviderPageLines from marker.v2.schema.polygon import PolygonBox from marker.v2.schema import BlockTypes from marker.v2.schema.registry import get_block_class from marker.v2.schema.text.line import Line from marker.v2.schema.text.span import Span -PageLines = Dict[int, List[Line]] -LineSpans = Dict[int, List[Span]] -PageSpans = Dict[int, LineSpans] - class PdfProvider(BaseProvider): page_range: List[int] | None = None @@ -29,15 +25,14 @@ def __init__(self, filepath: str, config=None): super().__init__(filepath, config) self.doc: pdfium.PdfDocument = pdfium.PdfDocument(self.filepath) - self.page_lines: PageLines = {i: [] for i in range(len(self.doc))} - self.page_spans: PageSpans = {i: {} for i in range(len(self.doc))} + self.page_lines: ProviderPageLines = {i: [] for i in range(len(self.doc))} if self.page_range is None: self.page_range = range(len(self.doc)) assert 
max(self.page_range) < len(self.doc) and min(self.page_range) >= 0, "Invalid page range" if not self.force_ocr: - self.page_lines, self.page_spans = self.pdftext_extraction() + self.page_lines = self.pdftext_extraction() atexit.register(self.cleanup_pdf_doc) @@ -99,9 +94,8 @@ def font_names_to_format(self, font_name: str | None) -> Set[str]: formats.add("italic") return formats - def pdftext_extraction(self) -> Tuple[PageLines, PageSpans]: - page_lines: PageLines = {} - page_spans: PageSpans = {} + def pdftext_extraction(self) -> ProviderPageLines: + page_lines: ProviderPageLines = {} page_char_blocks = dictionary_output( self.filepath, page_range=self.page_range, @@ -110,11 +104,10 @@ def pdftext_extraction(self) -> Tuple[PageLines, PageSpans]: flatten_pdf=self.flatten_pdf ) SpanClass: Span = get_block_class(BlockTypes.Span) - LineClass: Span = get_block_class(BlockTypes.Line) + LineClass: Line = get_block_class(BlockTypes.Line) for page in page_char_blocks: page_id = page["page"] - lines: List[Line] = [] - line_spans: LineSpans = {} + lines: List[ProviderOutput] = [] for block in page["blocks"]: for line in block["lines"]: spans: List[Span] = [] @@ -139,20 +132,24 @@ def pdftext_extraction(self) -> Tuple[PageLines, PageSpans]: text_extraction_method="pdftext" ) ) - lines.append(LineClass(polygon=PolygonBox.from_bbox(line["bbox"]), page_id=page_id)) - line_spans[len(lines) - 1] = spans - if self.check_line_spans(line_spans): + lines.append( + ProviderOutput( + line=LineClass(polygon=PolygonBox.from_bbox(line["bbox"]), page_id=page_id), + spans=spans + ) + ) + if self.check_line_spans(lines): page_lines[page_id] = lines - page_spans[page_id] = line_spans - return page_lines, page_spans + return page_lines - def check_line_spans(self, page_spans: LineSpans) -> bool: - if not len(sum(list(page_spans.values()), [])): + def check_line_spans(self, page_lines: List[ProviderOutput]) -> bool: + page_spans = [span for line in page_lines for span in line.spans] + if 
len(page_spans) == 0: return False + text = "" - for line_spans in page_spans.values(): - for span in line_spans: - text = text + " " + span.text + for span in page_spans: + text = text + " " + span.text text = text + "\n" if len(text.strip()) == 0: return False @@ -171,8 +168,5 @@ def get_page_bbox(self, idx: int) -> PolygonBox: page = self.doc[idx] return PolygonBox.from_bbox(page.get_bbox()) - def get_page_lines(self, idx: int) -> PageLines: - return self.page_lines[idx] - - def get_page_spans(self, idx: int) -> PageSpans: - return self.page_spans[idx] + def get_page_lines(self, idx: int) -> List[ProviderOutput]: + return self.page_lines[idx] \ No newline at end of file diff --git a/marker/v2/schema/groups/page.py b/marker/v2/schema/groups/page.py index 1a2c3ec4..4046cbc5 100644 --- a/marker/v2/schema/groups/page.py +++ b/marker/v2/schema/groups/page.py @@ -2,6 +2,7 @@ from PIL import Image +from marker.v2.providers import ProviderOutput from marker.v2.schema import BlockTypes from marker.v2.schema.blocks import Block, BlockId from marker.v2.schema.groups.base import Group @@ -55,29 +56,34 @@ def assemble_html(self, child_blocks, parent_structure=None): template += f"" return template - def merge_blocks( - self, - page_lines: List[Line], - line_spans: Dict[int, List[Span]], - text_extraction_method: str, - excluded_block_types=[BlockTypes.Line, BlockTypes.Span] - ): - provider_line_idxs = set(range(len(page_lines))) + def compute_line_block_intersections(self, provider_outputs: List[ProviderOutput], excluded_block_types): max_intersections = {} - for line_idx, line in enumerate(page_lines): + for line_idx, line in enumerate(provider_outputs): for block_idx, block in enumerate(self.children): if block.block_type in excluded_block_types: continue - intersection_pct = line.polygon.intersection_pct(block.polygon) + intersection_pct = line.line.polygon.intersection_pct(block.polygon) if line_idx not in max_intersections: max_intersections[line_idx] = 
(intersection_pct, block_idx) elif intersection_pct > max_intersections[line_idx][0]: max_intersections[line_idx] = (intersection_pct, block_idx) + return max_intersections + + def merge_blocks( + self, + provider_outputs: List[ProviderOutput], + text_extraction_method: str, + excluded_block_types=(BlockTypes.Line, BlockTypes.Span) + ): + provider_line_idxs = set(range(len(provider_outputs))) + max_intersections = self.compute_line_block_intersections(provider_outputs, excluded_block_types) assigned_line_idxs = set() - for line_idx, line in enumerate(page_lines): + for line_idx, provider_output in enumerate(provider_outputs): if line_idx in max_intersections and max_intersections[line_idx][0] > 0.0: + line = provider_output.line + spans = provider_output.spans self.add_full_block(line) block_idx = max_intersections[line_idx][1] block: Block = self.children[block_idx] @@ -85,14 +91,16 @@ def merge_blocks( block.polygon = block.polygon.merge([line.polygon]) block.text_extraction_method = text_extraction_method assigned_line_idxs.add(line_idx) - for span in line_spans[line_idx]: + for span in spans: self.add_full_block(span) line.add_structure(span) for line_idx in provider_line_idxs.difference(assigned_line_idxs): min_dist = None min_dist_idx = None - line = page_lines[line_idx] + provider_output: ProviderOutput = provider_outputs[line_idx] + line = provider_output.line + spans = provider_output.spans for block_idx, block in enumerate(self.children): if block.block_type in excluded_block_types: continue @@ -108,6 +116,6 @@ def merge_blocks( nearest_block.polygon = nearest_block.polygon.merge([line.polygon]) nearest_block.text_extraction_method = text_extraction_method assigned_line_idxs.add(line_idx) - for span in line_spans[line_idx]: + for span in spans: self.add_full_block(span) line.add_structure(span) diff --git a/tests/builders/test_overriding.py b/tests/builders/test_overriding.py index 01c4e588..65b28f81 100644 --- a/tests/builders/test_overriding.py +++ 
b/tests/builders/test_overriding.py @@ -46,4 +46,4 @@ def test_overriding_mp(): with mp.Pool(processes=2) as pool: results = pool.starmap(get_lines, [(pdf, config) for pdf in pdf_list]) - assert all([r[0].__class__ == NewLine for r in results]) + assert all([r[0].line.__class__ == NewLine for r in results]) diff --git a/tests/providers/test_pdf_provider.py b/tests/providers/test_pdf_provider.py index 0e765215..1f52007b 100644 --- a/tests/providers/test_pdf_provider.py +++ b/tests/providers/test_pdf_provider.py @@ -7,7 +7,8 @@ def test_pdf_provider(pdf_provider): assert pdf_provider.get_image(0, 72).size == (612, 792) assert pdf_provider.get_image(0, 96).size == (816, 1056) - spans_list = pdf_provider.get_page_spans(0) + page_lines = pdf_provider.get_page_lines(0) + spans_list = [span for line in page_lines for span in line.spans] assert len(spans_list) == 93 spans = spans_list[0]