Commit

fix doc extract crashes
devxpy committed Feb 12, 2024
1 parent 97f94cf commit 3609a12
Showing 3 changed files with 73 additions and 33 deletions.
10 changes: 6 additions & 4 deletions daras_ai_v2/azure_doc_extract.py
@@ -16,8 +16,10 @@
 auth_headers = {"Ocp-Apim-Subscription-Key": settings.AZURE_FORM_RECOGNIZER_KEY}
 
 
-def azure_doc_extract_pages(pdf_url: str, model_id: str = "prebuilt-layout"):
-    result = azure_form_recognizer(pdf_url, model_id)
+def azure_doc_extract_pages(
+    pdf_url: str, model_id: str = "prebuilt-layout", params: dict = None
+):
+    result = azure_form_recognizer(pdf_url, model_id, params)
     return [
         records_to_text(extract_records(result, page["pageNumber"]))
         for page in result["pages"]
@@ -39,13 +41,13 @@ def azure_form_recognizer_models() -> dict[str, str]:
 
 
 @redis_cache_decorator
-def azure_form_recognizer(url: str, model_id: str):
+def azure_form_recognizer(url: str, model_id: str, params: dict = None):
     r = requests.post(
         str(
             furl(settings.AZURE_FORM_RECOGNIZER_ENDPOINT)
             / f"formrecognizer/documentModels/{model_id}:analyze"
         ),
-        params={"api-version": "2023-07-31"},
+        params={"api-version": "2023-07-31"} | (params or {}),
         headers=auth_headers,
         json={"urlSource": url},
     )
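
For context, a minimal usage sketch of the new params argument (not part of the commit): the extra dict is merged into the query string of the Form Recognizer analyze request, which is how recipes/DocExtract.py below restricts extraction to a single page. The document URL is a placeholder.

    # Sketch only: analyze just page 2 of a PDF with the prebuilt layout model.
    # The URL below is a placeholder, not a file referenced by this commit.
    page_texts = azure_doc_extract_pages(
        "https://example.com/sample.pdf",
        model_id="prebuilt-layout",
        params={"pages": "2"},
    )
    print(page_texts[0])
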
12 changes: 7 additions & 5 deletions daras_ai_v2/vector_search.py
@@ -429,11 +429,9 @@ def pages_to_split_refs(
             "title": (
                 doc_meta.name + (f", page {doc.end + 1}" if len(pages) > 1 else "")
             ),
-            "url": (
-                furl(f_url)
-                .set(fragment_args={"page": doc.end + 1} if len(pages) > 1 else {})
-                .url
-            ),
+            "url": add_page_number_to_pdf(
+                f_url, (doc.end + 1 if len(pages) > 1 else f_url)
+            ).url,
             "snippet": doc.text,
             **doc.kwargs,
             "score": -1,
@@ -445,6 +443,10 @@
     return refs
 
 
+def add_page_number_to_pdf(url: str | furl, page_num: int) -> furl:
+    return furl(url).set(fragment_args={"page": page_num} if page_num else {})
+
+
 sections_re = re.compile(r"(\s*[\r\n\f\v]|^)(\w+)\=", re.MULTILINE)


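
As a quick illustration (a sketch, not part of the commit), the new add_page_number_to_pdf helper only appends a "#page=N" fragment when a truthy page number is given; the URLs below are placeholders.

    from daras_ai_v2.vector_search import add_page_number_to_pdf

    # Page fragment added for multi-page documents.
    print(add_page_number_to_pdf("https://example.com/doc.pdf", 3).url)
    # -> https://example.com/doc.pdf#page=3

    # A falsy page number leaves the URL untouched.
    print(add_page_number_to_pdf("https://example.com/doc.pdf", 0).url)
    # -> https://example.com/doc.pdf
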
84 changes: 60 additions & 24 deletions recipes/DocExtract.py
@@ -1,5 +1,6 @@
 import io
 import random
+import subprocess
 import tempfile
 import threading
 import typing

@@ -38,15 +39,15 @@
 from daras_ai_v2.language_model_settings_widgets import language_model_settings
 from daras_ai_v2.loom_video_widget import youtube_video
 from daras_ai_v2.settings import service_account_key_path
-from daras_ai_v2.vector_search import doc_url_to_metadata
+from daras_ai_v2.vector_search import doc_url_to_metadata, add_page_number_to_pdf
 from recipes.DocSearch import render_documents
 
 DEFAULT_YOUTUBE_BOT_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/ddc8ffac-93fb-11ee-89fb-02420a0001cb/Youtube%20transcripts.jpg.png"
 
 
 class Columns(IntegerChoices):
-    webpage_url = 1, "Source"
-    title = 2, "Title"
+    webpage_url = 1, "url"
+    title = 2, "title"
     description = 3, "Description"
     content_url = 4, "Content"
     transcript = 5, "Transcript"
@@ -148,7 +149,10 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
         request: DocExtractPage.RequestModel = self.RequestModel.parse_obj(state)
 
         entries = yield from flatapply_parallel(
-            extract_info, request.documents, message="Extracting metadata..."
+            extract_info,
+            request.documents,
+            message="Extracting metadata...",
+            max_workers=50,
         )
 
         yield "Preparing sheet..."
@@ -265,8 +269,6 @@ def col_i2a(col: int) -> str:
 
 
 def extract_info(url: str) -> list[dict | None]:
-    from pypdf import PdfReader
-
     if is_yt_url(url):
         import yt_dlp
 
@@ -285,6 +287,9 @@ def extract_info(url: str) -> list[dict | None]:
         f = furl(url)
         if is_gdrive_url(f):
             f_bytes, _ = gdrive_download(f, doc_meta.mime_type)
+            content_url = upload_file_from_bytes(
+                doc_meta.name, f_bytes, content_type=doc_meta.mime_type
+            )
         else:
             r = requests.get(
                 url,
@@ -293,18 +298,47 @@ def extract_info(url: str) -> list[dict | None]:
             )
             raise_for_status(r)
             f_bytes = r.content
-        inputpdf = PdfReader(io.BytesIO(f_bytes))
+            content_url = url
+        num_pages = get_pdf_num_pages(f_bytes)
         return [
             {
-                "webpage_url": f.copy().set(fragment_args={"page": i + 1}).url,
-                "pdf_page": page,
-                "title": (doc_meta.name + f", page {i + 1}"),
+                "webpage_url": add_page_number_to_pdf(f, page_num).url,
+                "title": f"{doc_meta.name}, page {page_num}",
                 "doc_meta": doc_meta,
+                # "pdf_page": page,
+                "content_url": add_page_number_to_pdf(content_url, page_num).url,
+                "page_num": page_num,
             }
-            for i, page in enumerate(inputpdf.pages)
+            for i in range(num_pages)
+            if (page_num := i + 1)
         ]
+    else:
+        return [
+            {
+                "webpage_url": url,
+                "title": doc_meta.name,
+                "doc_meta": doc_meta,
+            },
+        ]
 
-    return [{"webpage_url": url, "title": doc_meta.name, "doc_meta": doc_meta}]
+
+def get_pdf_num_pages(f_bytes: bytes) -> int:
+    with tempfile.NamedTemporaryFile() as infile:
+        infile.write(f_bytes)
+        args = ["pdfinfo", infile.name]
+        print("\t$ " + " ".join(args))
+        try:
+            output = subprocess.check_output(args, stderr=subprocess.STDOUT, text=True)
+        except subprocess.CalledProcessError as e:
+            raise ValueError(f"PDF Error: {e.output}")
+    output = output.lower()
+    for line in output.splitlines():
+        if not line.startswith("pages:"):
+            continue
+        try:
+            return int(line.split("pages:")[-1])
+        except ValueError:
+            raise ValueError(f"Unexpected PDF Info: {line}")
 
 
 def process_entry(
@@ -338,8 +372,6 @@ def process_source(
     row: int,
     entry: dict,
 ) -> typing.Iterator[str | None]:
-    from pypdf import PdfWriter
-
     webpage_url = entry["webpage_url"]
     doc_meta = entry.get("doc_meta")

@@ -370,20 +402,15 @@ def process_source(
         )
         content_url, _ = audio_url_to_wav(webpage_url)
     elif "application/pdf" in doc_meta.mime_type:
-        page = entry["pdf_page"]
-        outputpdf = PdfWriter()
-        outputpdf.add_page(page)
-        with io.BytesIO() as outf:
-            outputpdf.write(outf)
-            content_url = upload_file_from_bytes(
-                entry["title"], outf.getvalue(), content_type="application/pdf"
-            )
+        content_url = entry.get("content_url") or webpage_url
     else:
         raise NotImplementedError(
             f"Unsupported type {doc_meta and doc_meta.mime_type} for {webpage_url}"
         )
     update_cell(spreadsheet_id, row, Columns.content_url.value, content_url)
 
+    usable_out_col = (Columns.transcript.value, "snippet")
+
     transcript = existing_values[Columns.transcript.value]
     if not transcript:
         if (
@@ -395,7 +422,12 @@ def process_source(
             transcript = run_asr(content_url, request.selected_asr_model)
         elif "application/pdf" in doc_meta.mime_type:
             yield "Extracting PDF"
-            transcript = str(azure_doc_extract_pages(content_url)[0])
+            if page_num := entry.get("page_num"):
+                params = dict(pages=str(page_num))
+            else:
+                params = None
+            transcript = str(azure_doc_extract_pages(content_url, params=params)[0])
+            usable_out_col = (Columns.transcript.value, "sections")
         else:
             raise NotImplementedError(
                 f"Unsupported type {doc_meta and doc_meta.mime_type} for {webpage_url}"
@@ -413,6 +445,7 @@ def process_source(
                 glossary_url=request.glossary_document,
             )[0]
             update_cell(spreadsheet_id, row, Columns.translation.value, translation)
+            usable_out_col = (Columns.translation.value, "snippet")
         else:
             translation = transcript
             update_cell(spreadsheet_id, row, Columns.translation.value, "")
@@ -434,6 +467,9 @@ def process_source(
         )
         update_cell(spreadsheet_id, row, Columns.summary.value, summary)
 
+    if usable_out_col:
+        update_cell(spreadsheet_id, 1, *usable_out_col)
+
 
 def google_api_should_retry(e: Exception) -> bool:
     return isinstance(e, HttpError) and (
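
To round out the picture, a hedged sketch of how the new helpers replace the old pypdf flow (assumptions: the pdfinfo binary from poppler-utils is installed, and sample.pdf is a placeholder local file):

    # Sketch only: count pages by shelling out to pdfinfo instead of loading
    # the whole document with pypdf, as extract_info() now does.
    with open("sample.pdf", "rb") as f:  # placeholder file
        f_bytes = f.read()
    num_pages = get_pdf_num_pages(f_bytes)

    # Each page is then addressed with a "#page=N" fragment and extracted via
    # the Azure "pages" parameter, instead of splitting the PDF with PdfWriter.
    for i in range(num_pages):
        page_num = i + 1
        print(add_page_number_to_pdf("https://example.com/sample.pdf", page_num).url)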

