Better debugging, heading detection
VikParuchuri committed Oct 16, 2024
1 parent 78acbc0 commit d807c17
Showing 11 changed files with 78 additions and 78 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -8,6 +8,7 @@ wandb
*.dat
report.json
benchmark_data
debug

# Byte-compiled / optimized / DLL files
__pycache__/
14 changes: 9 additions & 5 deletions README.md
@@ -89,6 +89,7 @@ First, some configuration:
- Inspect the settings in `marker/settings.py`. You can override any settings with environment variables.
- Your torch device will be automatically detected, but you can override this. For example, `TORCH_DEVICE=cuda`.
- By default, marker will use `surya` for OCR. Surya is slower on CPU, but more accurate than tesseract. It also doesn't require you to specify the languages in the document. If you want faster OCR, set `OCR_ENGINE` to `ocrmypdf`. This also requires external dependencies (see above). If you don't want OCR at all, set `OCR_ENGINE` to `None`.
- Some PDFs, even digital ones, have bad text in them. Set `OCR_ALL_PAGES=true` to force OCR if you find bad output from marker.

## Interactive App

@@ -107,15 +108,15 @@ marker_single /path/to/file.pdf /path/to/output/folder --batch_multiplier 2 --ma

- `--batch_multiplier` is how much to multiply default batch sizes by if you have extra VRAM. Higher numbers will take more VRAM, but process faster. Set to 2 by default. The default batch sizes will take ~3GB of VRAM.
- `--max_pages` is the maximum number of pages to process. Omit this to convert the entire document.
- `--start_page` is the page to start from (default is None, will start from the first page).
- `--langs` is an optional comma separated list of the languages in the document, for OCR. Optional by default, required if you use tesseract.
- `--ocr_all_pages` is an optional argument to force OCR on all pages of the PDF. If this or the env var `OCR_ALL_PAGES` are true, OCR will be forced.

The list of supported languages for surya OCR is [here](https://github.com/VikParuchuri/surya/blob/master/surya/languages.py). If you need more languages, you can use any language supported by [Tesseract](https://tesseract-ocr.github.io/tessdoc/Data-Files#data-files-for-version-400-november-29-2016) if you set `OCR_ENGINE` to `ocrmypdf`. If you don't need OCR, marker can work with any language.

## Convert multiple files

```shell
marker /path/to/input/folder /path/to/output/folder --workers 4 --max 10 --min_length 10000
marker /path/to/input/folder /path/to/output/folder --workers 4 --max 10
```

- `--workers` is the number of pdfs to convert at once. This is set to 1 by default, but you can increase it to increase throughput, at the cost of more CPU/GPU usage. Marker will use 5GB of VRAM per worker at the peak, and 3.5GB average.
@@ -136,7 +137,7 @@ You can use language names or codes. The exact codes depend on the OCR engine.
## Convert multiple files on multiple GPUs

```shell
MIN_LENGTH=10000 METADATA_FILE=../pdf_meta.json NUM_DEVICES=4 NUM_WORKERS=15 marker_chunk_convert ../pdf_in ../md_out
METADATA_FILE=../pdf_meta.json NUM_DEVICES=4 NUM_WORKERS=15 marker_chunk_convert ../pdf_in ../md_out
```

- `METADATA_FILE` is an optional path to a json file with metadata about the pdfs. See above for the format.
@@ -150,15 +151,18 @@ Note that the env variables above are specific to this script, and cannot be set

There are some settings that you may find useful if things aren't working the way you expect:

- `OCR_ALL_PAGES` - set this to true to force OCR all pages. This can be very useful if the table layouts aren't recognized properly by default, or if there is garbled text.
- `OCR_ALL_PAGES` - set this to true to force OCR all pages. This can be very useful if there is garbled text in the output of marker.
- `TORCH_DEVICE` - set this to force marker to use a given torch device for inference.
- `OCR_ENGINE` - can set this to `surya` or `ocrmypdf`.
- `DEBUG` - setting this to `True` shows ray logs when converting multiple pdfs
- Verify that you set the languages correctly, or passed in a metadata file.
- If you're getting out of memory errors, decrease the worker count (increase the `VRAM_PER_TASK` setting). You can also try splitting up long PDFs into multiple files.

In general, if output is not what you expect, trying to OCR the PDF is a good first step. Not all PDFs have good text/bboxes embedded in them.

## Debugging

Set `DEBUG=true` to save data to the `debug` subfolder in the marker root directory. This will save images of each page with detected layout and text, as well as output a json file with additional bounding box information.
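
As a rough guide to where those files land, the sketch below mirrors the path logic in `marker/debug/data.py` from this commit. The helper name `debug_output_paths` and the example file name are illustrative and not part of marker, and the default root assumes `DEBUG_DATA_FOLDER` has not been overridden.

```python
import os


def debug_output_paths(fname: str, num_pages: int, debug_root: str = "debug") -> dict:
    """Hypothetical helper: list the files a DEBUG run would be expected to write."""
    # Drop the directory and the final extension, as the debug helpers do
    doc_base = os.path.basename(fname).rsplit(".", 1)[0]
    # Page images are grouped in a per-document subfolder
    page_folder = os.path.join(debug_root, doc_base)
    return {
        "page_images": [
            os.path.join(page_folder, f"page_{idx}.png") for idx in range(num_pages)
        ],
        # The bounding box dump sits next to that subfolder
        "bbox_json": os.path.join(debug_root, f"{doc_base}_bbox.json"),
    }


paths = debug_output_paths("/data/papers/report.v2.pdf", num_pages=2)
print(paths["bbox_json"])
for image_path in paths["page_images"]:
    print(image_path)
```

Taking the basename first keeps the input directory out of the output name, and `rsplit(".", 1)` removes only the last extension, so a file such as `report.v2.pdf` keeps the `report.v2` stem.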

## Useful settings

These settings can improve/change output quality:
10 changes: 5 additions & 5 deletions convert_single.py
@@ -2,6 +2,9 @@

import pypdfium2 # Needs to be at the top to avoid warnings
import os

from marker.settings import settings

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # For some reason, transformers decided to use .isin for a simple op, which is not supported on MPS

import argparse
@@ -22,23 +25,20 @@ def main():
parser.add_argument("--start_page", type=int, default=None, help="Page to start processing at")
parser.add_argument("--langs", type=str, help="Optional languages to use for OCR, comma separated", default=None)
parser.add_argument("--batch_multiplier", type=int, default=2, help="How much to increase batch sizes")
parser.add_argument("--debug", action="store_true", help="Enable debug logging", default=False)
parser.add_argument("--ocr_all_pages", action="store_true", help="Force OCR on all pages", default=False)
args = parser.parse_args()

langs = args.langs.split(",") if args.langs else None

fname = args.filename
model_lst = load_all_models()
start = time.time()
full_text, images, out_meta = convert_single_pdf(fname, model_lst, max_pages=args.max_pages, langs=langs, batch_multiplier=args.batch_multiplier, start_page=args.start_page, ocr_all_pages=args.ocr_all_pages)
full_text, images, out_meta = convert_single_pdf(fname, model_lst, max_pages=args.max_pages, langs=langs, batch_multiplier=args.batch_multiplier, start_page=args.start_page)

fname = os.path.basename(fname)
subfolder_path = save_markdown(args.output, fname, full_text, images, out_meta)

print(f"Saved markdown to the {subfolder_path} folder")
if args.debug:
print(f"Total time: {time.time() - start}")
print(f"Total time: {time.time() - start}")


if __name__ == "__main__":
13 changes: 6 additions & 7 deletions marker/cleaners/headings.py
@@ -68,7 +68,7 @@ def bucket_headings(line_heights, num_levels=settings.HEADING_LEVEL_COUNT):
data_labels = np.concatenate([data, labels.reshape(-1, 1)], axis=1)
data_labels = np.sort(data_labels, axis=0)

cluster_means = {label: np.mean(data[labels == label, 0]) for label in np.unique(labels)}
cluster_means = {label: np.mean(data_labels[data_labels[:, 1] == label, 0]) for label in np.unique(labels)}
label_max = None
label_min = None
heading_ranges = []
@@ -95,15 +95,14 @@ def bucket_headings(line_heights, num_levels=settings.HEADING_LEVEL_COUNT):
return heading_ranges


def infer_heading_levels(pages: List[Page]):
def infer_heading_levels(pages: List[Page], height_tol=.99):
all_line_heights = []
for page in pages:
for block in page.blocks:
if block.block_type not in ["Title", "Section-header"]:
continue

block_heights = [min(l.height, l.width) for l in block.lines] # Account for rotation
all_line_heights.extend(block_heights)
all_line_heights.extend([l.height for l in block.lines])

heading_ranges = bucket_headings(all_line_heights)

@@ -112,11 +111,11 @@ def infer_heading_levels(pages: List[Page]):
if block.block_type not in ["Title", "Section-header"]:
continue

block_heights = [min(l.height, l.width) for l in block.lines] # Account for rotation
block_heights = [l.height for l in block.lines] # Account for rotation
avg_height = sum(block_heights) / len(block_heights)
for idx, (min_height, max_height) in enumerate(heading_ranges):
if avg_height >= min_height:
block.heading_level = len(heading_ranges) - idx
if avg_height >= min_height * height_tol:
block.heading_level = idx + 1
break

if block.heading_level is None:
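
The level assignment in this hunk can be read in isolation as the sketch below. It assumes `heading_ranges` is ordered from the tallest bucket to the shortest, which the diff does not show but which the first-match `break` and `idx + 1` seem to require. The function name `assign_heading_level` and the sample ranges are illustrative only.

```python
from typing import List, Optional, Tuple


def assign_heading_level(
    avg_height: float,
    heading_ranges: List[Tuple[float, float]],
    height_tol: float = 0.99,
) -> Optional[int]:
    """Return a 1-based heading level, or None when no bucket matches."""
    # Assumption: heading_ranges[0] is the tallest bucket (level 1)
    for idx, (min_height, _max_height) in enumerate(heading_ranges):
        # The tolerance lets a block slightly under the bucket minimum still match
        if avg_height >= min_height * height_tol:
            return idx + 1
    # The caller decides what to do with unmatched blocks
    return None


ranges = [(20.0, 24.0), (14.0, 16.0), (11.0, 12.0)]
print(assign_heading_level(19.9, ranges))
print(assign_heading_level(19.9, ranges, height_tol=1.0))
print(assign_heading_level(5.0, ranges))
```

With the default tolerance, a block averaging 19.9 still lands in the bucket whose minimum is 20.0; with `height_tol=1.0` it falls through to the next bucket.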
5 changes: 3 additions & 2 deletions marker/convert.py
@@ -10,7 +10,7 @@

from marker.utils import flush_cuda_memory
from marker.tables.table import format_tables
from marker.debug.data import dump_bbox_debug_data
from marker.debug.data import dump_bbox_debug_data, draw_page_debug_images
from marker.layout.layout import surya_layout, annotate_block_types
from marker.layout.order import surya_order, sort_blocks_in_reading_order
from marker.ocr.lang import replace_langs_with_codes, validate_langs
@@ -108,7 +108,8 @@ def convert_single_pdf(
annotate_block_types(pages)

# Dump debug data if flags are set
dump_bbox_debug_data(doc, fname, pages)
draw_page_debug_images(fname, pages)
dump_bbox_debug_data(fname, pages)

# Find reading order for blocks
# Sort blocks by reading order
80 changes: 34 additions & 46 deletions marker/debug/data.py
@@ -1,74 +1,62 @@
import base64
import json
import math
import os
from typing import List

from marker.pdf.images import render_image
from marker.debug.render import render_on_image
from marker.schema.bbox import rescale_bbox
from marker.schema.page import Page
from marker.settings import settings
from PIL import Image
import io


def dump_equation_debug_data(doc, images, converted_spans):
if not settings.DEBUG_DATA_FOLDER or settings.DEBUG_LEVEL == 0:
def draw_page_debug_images(fname, pages: List[Page]):
if not settings.DEBUG:
return

if len(images) == 0:
return
# Remove extension from doc name
doc_base = os.path.basename(fname).rsplit(".", 1)[0]

# We attempted one conversion per image
assert len(converted_spans) == len(images)

data_lines = []
for idx, (pil_image, converted_span) in enumerate(zip(images, converted_spans)):
if converted_span is None:
continue
# Image is a BytesIO object
img_bytes = io.BytesIO()
pil_image.save(img_bytes, format="WEBP", lossless=True)
b64_image = base64.b64encode(img_bytes.getvalue()).decode("utf-8")
data_lines.append({
"image": b64_image,
"text": converted_span.text,
"bbox": converted_span.bbox
})
debug_folder = os.path.join(settings.DEBUG_DATA_FOLDER, doc_base)
os.makedirs(debug_folder, exist_ok=True)
for idx, page in enumerate(pages):
img_size = (int(math.ceil(page.text_lines.image_bbox[2])), int(math.ceil(page.text_lines.image_bbox[3])))
png_image = Image.new("RGB", img_size, color="white")

# Remove extension from doc name
doc_base = os.path.basename(doc.name).rsplit(".", 1)[0]
line_bboxes = []
line_text = []
for block in page.blocks:
for line in block.lines:
line_bboxes.append(rescale_bbox(page.bbox, page.text_lines.image_bbox, line.bbox))
line_text.append(line.prelim_text)

debug_file = os.path.join(settings.DEBUG_DATA_FOLDER, f"{doc_base}_equations.json")
with open(debug_file, "w+") as f:
json.dump(data_lines, f)
render_on_image(line_bboxes, png_image, labels=line_text, color="black", draw_bbox=False)

line_bboxes = [line.bbox for line in page.text_lines.bboxes]
render_on_image(line_bboxes, png_image, color="blue")

layout_boxes = [rescale_bbox(page.layout.image_bbox, page.text_lines.image_bbox, box.bbox) for box in page.layout.bboxes]
layout_labels = [box.label for box in page.layout.bboxes]

render_on_image(layout_boxes, png_image, labels=layout_labels, color="red")

def dump_bbox_debug_data(doc, fname, blocks: List[Page]):
if not settings.DEBUG_DATA_FOLDER or settings.DEBUG_LEVEL < 2:
debug_file = os.path.join(debug_folder, f"page_{idx}.png")
png_image.save(debug_file)


def dump_bbox_debug_data(fname, pages: List[Page]):
if not settings.DEBUG:
return

# Remove extension from doc name
doc_base = fname.rsplit(".", 1)[0]
doc_base = os.path.basename(fname).rsplit(".", 1)[0]

debug_file = os.path.join(settings.DEBUG_DATA_FOLDER, f"{doc_base}_bbox.json")
debug_data = []
for idx, page_blocks in enumerate(blocks):
page = doc[idx]

png_image = render_image(page, dpi=settings.TEXIFY_DPI)
width, height = png_image.size
max_dimension = 6000
if width > max_dimension or height > max_dimension:
scaling_factor = min(max_dimension / width, max_dimension / height)
png_image = png_image.resize((int(width * scaling_factor), int(height * scaling_factor)), Image.ANTIALIAS)

img_bytes = io.BytesIO()
png_image.save(img_bytes, format="WEBP", lossless=True, quality=100)
b64_image = base64.b64encode(img_bytes.getvalue()).decode("utf-8")

for idx, page_blocks in enumerate(pages):
page_data = page_blocks.model_dump(exclude=["images", "layout", "text_lines"])
page_data["layout"] = page_blocks.layout.model_dump(exclude=["segmentation_map"])
page_data["text_lines"] = page_blocks.text_lines.model_dump(exclude=["heatmap", "affinity_map"])
#page_data["image"] = b64_image
debug_data.append(page_data)

with open(debug_file, "w+") as f:
4 changes: 0 additions & 4 deletions marker/equations/equations.py
@@ -2,7 +2,6 @@
from copy import deepcopy
from typing import List

from marker.debug.data import dump_equation_debug_data
from marker.equations.inference import get_total_texify_tokens, get_latex_batched
from marker.pdf.images import render_bbox_image
from marker.schema.bbox import rescale_bbox
@@ -177,7 +176,4 @@ def replace_equations(doc, pages: List[Page], texify_model, batch_multiplier=1):
successful_ocr += success_count
unsuccessful_ocr += fail_count

# If debug mode is on, dump out conversions for comparison
dump_equation_debug_data(doc, images, converted_spans)

return pages, {"successful_ocr": successful_ocr, "unsuccessful_ocr": unsuccessful_ocr, "equations": eq_count}
10 changes: 5 additions & 5 deletions marker/postprocessors/markdown.py
@@ -146,22 +146,22 @@ def merge_lines(blocks: List[List[MergedBlock]]):
prev_line = None
block_text = ""
block_type = ""
block_heading_level = None
prev_heading_level = None

for idx, page in enumerate(blocks):
for block in page:
block_type = block.block_type
if block_type != prev_type and prev_type:
if (block_type != prev_type and prev_type) or (block.heading_level != prev_heading_level and prev_heading_level):
text_blocks.append(
FullyMergedBlock(
text=block_surround(block_text, prev_type, block_heading_level),
text=block_surround(block_text, prev_type, prev_heading_level),
block_type=prev_type
)
)
block_text = ""

prev_type = block_type
block_heading_level = block.heading_level
prev_heading_level = block.heading_level
# Join lines in the block together properly
for i, line in enumerate(block.lines):
line_height = line.bbox[3] - line.bbox[1]
@@ -180,7 +180,7 @@ def merge_lines(blocks: List[List[MergedBlock]]):
# Append the final block
text_blocks.append(
FullyMergedBlock(
text=block_surround(block_text, prev_type, block_heading_level),
text=block_surround(block_text, prev_type, prev_heading_level),
block_type=block_type
)
)
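
The effect of the new flush condition in `merge_lines` can be sketched on plain tuples. This is a simplified illustration, not the function itself: blocks are reduced to `(block_type, heading_level, text)`, line joining is replaced by a space join, and `group_blocks` is a made-up name.

```python
from typing import List, Optional, Tuple

Block = Tuple[str, Optional[int], str]  # (block_type, heading_level, text)


def group_blocks(blocks: List[Block]) -> List[Block]:
    """Merge consecutive blocks that share both type and heading level."""
    merged: List[Block] = []
    prev_type: Optional[str] = None
    prev_level: Optional[int] = None
    parts: List[str] = []

    for block_type, level, text in blocks:
        type_changed = bool(prev_type) and block_type != prev_type
        level_changed = bool(prev_level) and level != prev_level
        if type_changed or level_changed:
            # Flush with the previous block's type and level, not the current one's
            merged.append((prev_type, prev_level, " ".join(parts)))
            parts = []
        prev_type, prev_level = block_type, level
        parts.append(text)

    if parts:
        # Append the final group
        merged.append((prev_type, prev_level, " ".join(parts)))
    return merged


sample = [
    ("Section-header", 1, "Intro"),
    ("Section-header", 2, "Background"),
    ("Text", None, "Body a."),
    ("Text", None, "Body b."),
]
for group in group_blocks(sample):
    print(group)
```

Two adjacent `Section-header` blocks at different levels now come out as separate groups, each carrying its own level, while the two `Text` blocks are still merged.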
10 changes: 7 additions & 3 deletions marker/settings.py
@@ -4,6 +4,7 @@
from pydantic import computed_field
from pydantic_settings import BaseSettings
import torch
import os


class Settings(BaseSettings):
@@ -12,6 +13,7 @@ class Settings(BaseSettings):
IMAGE_DPI: int = 96 # DPI to render images pulled from pdf at
EXTRACT_IMAGES: bool = True # Extract images from pdfs and save them
PAGINATE_OUTPUT: bool = False # Paginate output markdown
BASE_DIR: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

@computed_field
@property
@@ -84,9 +86,11 @@ def TORCH_DEVICE_MODEL(self) -> str:
HEADING_DEFAULT_LEVEL: int = 2

# Debug
DEBUG: bool = False # Enable debug logging
DEBUG_DATA_FOLDER: Optional[str] = None
DEBUG_LEVEL: int = 0 # 0 to 2, 2 means log everything
DEBUG_DATA_FOLDER: str = os.path.join(BASE_DIR, "debug")
DEBUG: bool = False
FONT_DIR: str = os.path.join(BASE_DIR, "static", "fonts")
DEBUG_RENDER_FONT: str = os.path.join(FONT_DIR, "GoNotoCurrent-Regular.ttf")
FONT_DL_BASE: str = "https://github.com/satbyy/go-noto-universal/releases/download/v7.0"

@computed_field
@property
7 changes: 6 additions & 1 deletion marker/tables/table.py
@@ -62,7 +62,12 @@ def get_table_boxes(pages: List[Page], doc: PdfDocument, fname):
out_img_sizes = []
for i in range(len(table_counts)):
if i in table_idxs:
text_lines.extend([sel_text_lines.pop(0)] * table_counts[i])
page_ocred = pages[i].ocr_method is not None
if page_ocred:
# This will force re-detection of cells if the page was ocred (the text lines are not accurate)
text_lines.extend([None] * table_counts[i])
else:
text_lines.extend([sel_text_lines.pop(0)] * table_counts[i])
out_img_sizes.extend([img_sizes[i]] * table_counts[i])

assert len(table_imgs) == len(table_bboxes) == len(text_lines) == len(out_img_sizes)
2 changes: 2 additions & 0 deletions static/fonts/.gitignore
@@ -0,0 +1,2 @@
*
!.gitignore
