Skip to content

Commit

Permalink
Fix issues with spans/lines and providers
Browse files (browse the repository at this point in the history)
  • Loading branch information
VikParuchuri committed Nov 19, 2024
1 parent 7b2b3d8 commit cc08b4b
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 90 deletions.
17 changes: 9 additions & 8 deletions marker/v2/builders/layout.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,7 +5,8 @@

from marker.settings import settings
from marker.v2.builders import BaseBuilder
from marker.v2.providers.pdf import PageLines, PageSpans, PdfProvider
from marker.v2.providers import ProviderOutput, ProviderPageLines
from marker.v2.providers.pdf import PdfProvider
from marker.v2.schema import BlockTypes
from marker.v2.schema.document import Document
from marker.v2.schema.groups.page import PageGroup
Expand All @@ -25,7 +26,7 @@ def __init__(self, layout_model, config=None):
def __call__(self, document: Document, provider: PdfProvider):
layout_results = self.surya_layout(document.pages)
self.add_blocks_to_pages(document.pages, layout_results)
self.merge_blocks(document.pages, provider.page_lines, provider.page_spans)
self.merge_blocks(document.pages, provider.page_lines)

def get_batch_size(self):
if self.batch_size is not None:
Expand Down Expand Up @@ -54,18 +55,18 @@ def add_blocks_to_pages(self, pages: List[PageGroup], layout_results: List[Layou
layout_block.polygon = layout_block.polygon.rescale(layout_page_size, provider_page_size)
page.add_structure(layout_block)

def merge_blocks(self, document_pages: List[PageGroup], provider_page_lines: PageLines, provider_page_spans: PageSpans):
for document_page, provider_lines in zip(document_pages, provider_page_lines.values()):
def merge_blocks(self, document_pages: List[PageGroup], provider_page_lines: ProviderPageLines):
for document_page in document_pages:
provider_lines = provider_page_lines[document_page.page_id]
if not self.check_layout_coverage(document_page, provider_lines):
document_page.text_extraction_method = "surya"
continue
line_spans = provider_page_spans[document_page.page_id]
document_page.merge_blocks(provider_lines, line_spans, text_extraction_method="pdftext")
document_page.merge_blocks(provider_lines, text_extraction_method="pdftext")

def check_layout_coverage(
self,
document_page: PageGroup,
provider_lines: List[Line],
provider_lines: List[ProviderOutput],
coverage_threshold=0.5
):
layout_area = 0
Expand All @@ -76,6 +77,6 @@ def check_layout_coverage(
continue
layout_area += layout_block.polygon.area
for provider_line in provider_lines:
provider_area += layout_block.polygon.intersection_area(provider_line.polygon)
provider_area += layout_block.polygon.intersection_area(provider_line.line.polygon)
coverage_ratio = provider_area / layout_area if layout_area > 0 else 0
return coverage_ratio >= coverage_threshold
65 changes: 30 additions & 35 deletions marker/v2/builders/ocr.py
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,8 @@
from typing import Dict, List, Tuple

from surya.ocr import run_ocr

from marker.settings import settings
from marker.v2.builders import BaseBuilder
from marker.v2.providers import ProviderOutput, ProviderPageLines
from marker.v2.providers.pdf import PdfProvider
from marker.v2.schema import BlockTypes
from marker.v2.schema.document import Document
Expand All @@ -12,10 +11,6 @@
from marker.v2.schema.text.line import Line
from marker.v2.schema.text.span import Span

PageLines = Dict[int, List[Line]]
LineSpans = Dict[int, List[Span]]
PageSpans = Dict[int, LineSpans]


class OcrBuilder(BaseBuilder):
recognition_batch_size = None
Expand All @@ -28,8 +23,8 @@ def __init__(self, detection_model, recognition_model, config=None):
self.recognition_model = recognition_model

def __call__(self, document: Document, provider: PdfProvider):
page_lines, page_spans = self.ocr_extraction(document, provider)
self.merge_blocks(document, page_lines, page_spans)
page_lines = self.ocr_extraction(document, provider)
self.merge_blocks(document, page_lines)

def get_recognition_batch_size(self):
if self.recognition_batch_size is not None:
Expand All @@ -47,7 +42,7 @@ def get_detection_batch_size(self):
return 4
return 4

def ocr_extraction(self, document: Document, provider: PdfProvider) -> Tuple[PageLines, PageSpans]:
def ocr_extraction(self, document: Document, provider: PdfProvider) -> ProviderPageLines:
page_list = [page for page in document.pages if page.text_extraction_method == "surya"]
recognition_results = run_ocr(
images=[page.lowres_image for page in page_list],
Expand All @@ -61,43 +56,43 @@ def ocr_extraction(self, document: Document, provider: PdfProvider) -> Tuple[Pag
)

page_lines = {}
page_spans = {}

SpanClass: Span = get_block_class(BlockTypes.Span)
LineClass: Line = get_block_class(BlockTypes.Line)

for page_id, recognition_result in zip((page.page_id for page in page_list), recognition_results):
page_spans.setdefault(page_id, {})
page_lines.setdefault(page_id, [])

page_size = provider.get_page_bbox(page_id).size
line_spans = page_spans[page_id]

for ocr_line_idx, ocr_line in enumerate(recognition_result.text_lines):
image_polygon = PolygonBox.from_bbox(recognition_result.image_bbox)
polygon = PolygonBox.from_bbox(ocr_line.bbox).rescale(image_polygon.size, page_size)

page_lines[page_id].append(LineClass(
polygon=polygon,
page_id=page_id,
))

line_spans.setdefault(ocr_line_idx, [])
line_spans[ocr_line_idx].append(SpanClass(
text=ocr_line.text,
formats=['plain'],
page_id=page_id,
polygon=polygon,
minimum_position=0,
maximum_position=0,
font='',
font_weight=0,
font_size=0,
))

return page_lines, page_spans

def merge_blocks(self, document: Document, page_lines: PageLines, page_spans: PageSpans):
line = LineClass(
polygon=polygon,
page_id=page_id,
)
spans = [
SpanClass(
text=ocr_line.text + "\n",
formats=['plain'],
page_id=page_id,
polygon=polygon,
minimum_position=0,
maximum_position=0,
font='',
font_weight=0,
font_size=0,
)
]

page_lines[page_id].append(ProviderOutput(line=line, spans=spans))

return page_lines

def merge_blocks(self, document: Document, page_lines: ProviderPageLines):
ocred_pages = [page for page in document.pages if page.text_extraction_method == "surya"]
for document_page, lines, line_spans in zip(ocred_pages, page_lines.values(), page_spans.values()):
document_page.merge_blocks(lines, line_spans, text_extraction_method="surya")
for document_page in ocred_pages:
lines = page_lines[document_page.page_id]
document_page.merge_blocks(lines, text_extraction_method="surya")
5 changes: 4 additions & 1 deletion marker/v2/converters/pdf.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -81,7 +81,8 @@ def __call__(self, filepath: str):
@click.option("--debug", is_flag=True)
@click.option("--output_format", type=click.Choice(["markdown", "json"]), default="markdown")
@click.option("--pages", type=str, default=None)
def main(fpath: str, output_dir: str, debug: bool, output_format: str, pages: str):
@click.option("--force_ocr", is_flag=True)
def main(fpath: str, output_dir: str, debug: bool, output_format: str, pages: str, force_ocr: bool):
if pages is not None:
pages = list(map(int, pages.split(",")))

Expand All @@ -96,6 +97,8 @@ def main(fpath: str, output_dir: str, debug: bool, output_format: str, pages: st
config["debug_pdf_images"] = True
config["debug_layout_images"] = True
config["debug_json"] = True
if force_ocr:
config["force_ocr"] = True


converter = PdfConverter(config=config, output_format=output_format)
Expand Down
9 changes: 8 additions & 1 deletion marker/v2/providers/__init__.py
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,18 @@
from typing import List, Optional
from typing import List, Optional, Dict

from pydantic import BaseModel

from marker.v2.schema.text import Span
from marker.v2.schema.text.line import Line
from marker.v2.util import assign_config


class ProviderOutput(BaseModel):
line: Line
spans: List[Span]

ProviderPageLines = Dict[int, List[ProviderOutput]]

class BaseProvider:
def __init__(self, filepath: str, config: Optional[BaseModel | dict] = None):
assign_config(self, config)
Expand Down
52 changes: 23 additions & 29 deletions marker/v2/providers/pdf.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -7,17 +7,13 @@
from PIL import Image

from marker.ocr.heuristics import detect_bad_ocr
from marker.v2.providers import BaseProvider
from marker.v2.providers import BaseProvider, ProviderOutput, ProviderPageLines
from marker.v2.schema.polygon import PolygonBox
from marker.v2.schema import BlockTypes
from marker.v2.schema.registry import get_block_class
from marker.v2.schema.text.line import Line
from marker.v2.schema.text.span import Span

PageLines = Dict[int, List[Line]]
LineSpans = Dict[int, List[Span]]
PageSpans = Dict[int, LineSpans]


class PdfProvider(BaseProvider):
page_range: List[int] | None = None
Expand All @@ -29,15 +25,14 @@ def __init__(self, filepath: str, config=None):
super().__init__(filepath, config)

self.doc: pdfium.PdfDocument = pdfium.PdfDocument(self.filepath)
self.page_lines: PageLines = {i: [] for i in range(len(self.doc))}
self.page_spans: PageSpans = {i: {} for i in range(len(self.doc))}
self.page_lines: ProviderPageLines = {i: [] for i in range(len(self.doc))}

if self.page_range is None:
self.page_range = range(len(self.doc))
assert max(self.page_range) < len(self.doc) and min(self.page_range) >= 0, "Invalid page range"

if not self.force_ocr:
self.page_lines, self.page_spans = self.pdftext_extraction()
self.page_lines = self.pdftext_extraction()

atexit.register(self.cleanup_pdf_doc)

Expand Down Expand Up @@ -99,9 +94,8 @@ def font_names_to_format(self, font_name: str | None) -> Set[str]:
formats.add("italic")
return formats

def pdftext_extraction(self) -> Tuple[PageLines, PageSpans]:
page_lines: PageLines = {}
page_spans: PageSpans = {}
def pdftext_extraction(self) -> ProviderPageLines:
page_lines: ProviderPageLines = {}
page_char_blocks = dictionary_output(
self.filepath,
page_range=self.page_range,
Expand All @@ -110,11 +104,10 @@ def pdftext_extraction(self) -> Tuple[PageLines, PageSpans]:
flatten_pdf=self.flatten_pdf
)
SpanClass: Span = get_block_class(BlockTypes.Span)
LineClass: Span = get_block_class(BlockTypes.Line)
LineClass: Line = get_block_class(BlockTypes.Line)
for page in page_char_blocks:
page_id = page["page"]
lines: List[Line] = []
line_spans: LineSpans = {}
lines: List[ProviderOutput] = []
for block in page["blocks"]:
for line in block["lines"]:
spans: List[Span] = []
Expand All @@ -139,20 +132,24 @@ def pdftext_extraction(self) -> Tuple[PageLines, PageSpans]:
text_extraction_method="pdftext"
)
)
lines.append(LineClass(polygon=PolygonBox.from_bbox(line["bbox"]), page_id=page_id))
line_spans[len(lines) - 1] = spans
if self.check_line_spans(line_spans):
lines.append(
ProviderOutput(
line=LineClass(polygon=PolygonBox.from_bbox(line["bbox"]), page_id=page_id),
spans=spans
)
)
if self.check_line_spans(lines):
page_lines[page_id] = lines
page_spans[page_id] = line_spans
return page_lines, page_spans
return page_lines

def check_line_spans(self, page_spans: LineSpans) -> bool:
if not len(sum(list(page_spans.values()), [])):
def check_line_spans(self, page_lines: List[ProviderOutput]) -> bool:
page_spans = [span for line in page_lines for span in line.spans]
if len(page_spans) == 0:
return False

text = ""
for line_spans in page_spans.values():
for span in line_spans:
text = text + " " + span.text
for span in page_spans:
text = text + " " + span.text
text = text + "\n"
if len(text.strip()) == 0:
return False
Expand All @@ -171,8 +168,5 @@ def get_page_bbox(self, idx: int) -> PolygonBox:
page = self.doc[idx]
return PolygonBox.from_bbox(page.get_bbox())

def get_page_lines(self, idx: int) -> PageLines:
return self.page_lines[idx]

def get_page_spans(self, idx: int) -> PageSpans:
return self.page_spans[idx]
def get_page_lines(self, idx: int) -> List[ProviderOutput]:
return self.page_lines[idx]
36 changes: 22 additions & 14 deletions marker/v2/schema/groups/page.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,6 +2,7 @@

from PIL import Image

from marker.v2.providers import ProviderOutput
from marker.v2.schema import BlockTypes
from marker.v2.schema.blocks import Block, BlockId
from marker.v2.schema.groups.base import Group
Expand Down Expand Up @@ -55,44 +56,51 @@ def assemble_html(self, child_blocks, parent_structure=None):
template += f"<content-ref src='{c.id}'></content-ref>"
return template

def merge_blocks(
self,
page_lines: List[Line],
line_spans: Dict[int, List[Span]],
text_extraction_method: str,
excluded_block_types=[BlockTypes.Line, BlockTypes.Span]
):
provider_line_idxs = set(range(len(page_lines)))
def compute_line_block_intersections(self, provider_outputs: List[ProviderOutput], excluded_block_types):
max_intersections = {}

for line_idx, line in enumerate(page_lines):
for line_idx, line in enumerate(provider_outputs):
for block_idx, block in enumerate(self.children):
if block.block_type in excluded_block_types:
continue
intersection_pct = line.polygon.intersection_pct(block.polygon)
intersection_pct = line.line.polygon.intersection_pct(block.polygon)
if line_idx not in max_intersections:
max_intersections[line_idx] = (intersection_pct, block_idx)
elif intersection_pct > max_intersections[line_idx][0]:
max_intersections[line_idx] = (intersection_pct, block_idx)
return max_intersections

def merge_blocks(
self,
provider_outputs: List[ProviderOutput],
text_extraction_method: str,
excluded_block_types=(BlockTypes.Line, BlockTypes.Span)
):
provider_line_idxs = set(range(len(provider_outputs)))
max_intersections = self.compute_line_block_intersections(provider_outputs, excluded_block_types)

assigned_line_idxs = set()
for line_idx, line in enumerate(page_lines):
for line_idx, provider_output in enumerate(provider_outputs):
if line_idx in max_intersections and max_intersections[line_idx][0] > 0.0:
line = provider_output.line
spans = provider_output.spans
self.add_full_block(line)
block_idx = max_intersections[line_idx][1]
block: Block = self.children[block_idx]
block.add_structure(line)
block.polygon = block.polygon.merge([line.polygon])
block.text_extraction_method = text_extraction_method
assigned_line_idxs.add(line_idx)
for span in line_spans[line_idx]:
for span in spans:
self.add_full_block(span)
line.add_structure(span)

for line_idx in provider_line_idxs.difference(assigned_line_idxs):
min_dist = None
min_dist_idx = None
line = page_lines[line_idx]
provider_output: ProviderOutput = provider_outputs[line_idx]
line = provider_output.line
spans = provider_output.spans
for block_idx, block in enumerate(self.children):
if block.block_type in excluded_block_types:
continue
Expand All @@ -108,6 +116,6 @@ def merge_blocks(
nearest_block.polygon = nearest_block.polygon.merge([line.polygon])
nearest_block.text_extraction_method = text_extraction_method
assigned_line_idxs.add(line_idx)
for span in line_spans[line_idx]:
for span in spans:
self.add_full_block(span)
line.add_structure(span)
2 changes: 1 addition & 1 deletion tests/builders/test_overriding.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -46,4 +46,4 @@ def test_overriding_mp():

with mp.Pool(processes=2) as pool:
results = pool.starmap(get_lines, [(pdf, config) for pdf in pdf_list])
assert all([r[0].__class__ == NewLine for r in results])
assert all([r[0].line.__class__ == NewLine for r in results])
Loading

0 comments on commit cc08b4b

Please sign in to comment.