google drive support for bulk runner
devxpy committed Oct 12, 2023
1 parent 73cf54f commit a8e4501
Showing 2 changed files with 127 additions and 62 deletions.
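The gist of the change: document URLs are no longer always fetched with a plain HTTP GET; the download path now branches on Google Drive links and routes them through the Drive downloader. A minimal sketch of that branch, condensed from the download_content_bytes helper in the diff below (error handling and encoding normalization omitted; is_gdrive_url, gdrive_download and guess_ext_from_response are existing helpers in the codebase, not defined here):

import requests
from furl import furl

def fetch_bytes(f_url: str, mime_type: str) -> tuple[bytes, str]:
    f = furl(f_url)
    if is_gdrive_url(f):
        # Google Drive links can't be fetched with a plain GET
        return gdrive_download(f, mime_type)
    r = requests.get(f_url)
    r.raise_for_status()
    return r.content, guess_ext_from_response(r)
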
121 changes: 81 additions & 40 deletions daras_ai_v2/vector_search.py
@@ -424,7 +424,7 @@ def doc_url_to_text_pages(
doc_meta: DocMetadata,
google_translate_target: str | None,
selected_asr_model: str | None,
) -> list[str]:
) -> typing.Union[list[str], "pd.DataFrame"]:
"""
Download document from url and convert to text pages.
@@ -437,33 +437,61 @@ def doc_url_to_text_pages(
Returns:
list of text pages
"""
f_bytes, ext = download_content_bytes(f_url=f_url, mime_type=doc_meta.mime_type)
if not f_bytes:
return []
pages = bytes_to_text_pages_or_df(
f_url=f_url,
f_name=doc_meta.name,
f_bytes=f_bytes,
ext=ext,
mime_type=doc_meta.mime_type,
selected_asr_model=selected_asr_model,
)
# optionally, translate text
if google_translate_target and isinstance(pages, list):
pages = run_google_translate(pages, google_translate_target)
return pages


def download_content_bytes(*, f_url: str, mime_type: str) -> tuple[bytes, str]:
f = furl(f_url)
f_name = doc_meta.name
if is_gdrive_url(f):
# download from google drive
f_bytes, ext = gdrive_download(f, doc_meta.mime_type)
else:
return gdrive_download(f, mime_type)
try:
# download from url
r = requests.get(
f_url,
headers={"User-Agent": random.choice(FAKE_USER_AGENTS)},
timeout=settings.EXTERNAL_REQUEST_TIMEOUT_SEC,
)
r.raise_for_status()
except requests.RequestException as e:
print(f"ignore error while downloading {f_url}: {e}")
return b"", ""
f_bytes = r.content
# if it's a known encoding, standardize to utf-8
if r.encoding:
try:
r = requests.get(
f_url,
headers={"User-Agent": random.choice(FAKE_USER_AGENTS)},
timeout=settings.EXTERNAL_REQUEST_TIMEOUT_SEC,
)
r.raise_for_status()
except requests.RequestException as e:
print(f"ignore error while downloading {f_url}: {e}")
return []
f_bytes = r.content
# if it's a known encoding, standardize to utf-8
if r.encoding:
try:
codec = codecs.lookup(r.encoding)
except LookupError:
pass
else:
f_bytes = codec.decode(f_bytes)[0].encode()
ext = guess_ext_from_response(r)
codec = codecs.lookup(r.encoding)
except LookupError:
pass
else:
f_bytes = codec.decode(f_bytes)[0].encode()
ext = guess_ext_from_response(r)
return f_bytes, ext
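
A hypothetical call of the new helper (the URL and mime type are made up for illustration; the real callers are doc_url_to_text_pages above and _read_df in recipes/BulkRunner.py below):

f_bytes, ext = download_content_bytes(
    f_url="https://drive.google.com/file/d/abc123/view",  # made-up example link
    mime_type="application/pdf",
)
if not f_bytes:
    # download failures are logged and returned as (b"", "") rather than raised
    ...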


def bytes_to_text_pages_or_df(
*,
f_url: str,
f_name: str,
f_bytes: bytes,
ext: str,
mime_type: str,
selected_asr_model: str | None,
) -> typing.Union[list[str], "pd.DataFrame"]:
# convert document to text pages
match ext:
case ".pdf":
@@ -477,25 +505,42 @@ def doc_url_to_text_pages(
raise ValueError(
"For transcribing audio/video, please choose an ASR model from the settings!"
)
if is_gdrive_url(f):
f_url = upload_file_from_bytes(
f_name, f_bytes, content_type=doc_meta.mime_type
)
if is_gdrive_url(furl(f_url)):
f_url = upload_file_from_bytes(f_name, f_bytes, content_type=mime_type)
pages = [run_asr(f_url, selected_model=selected_asr_model, language="en")]
case ".csv" | ".xlsx" | ".tsv" | ".ods":
import pandas as pd

df = pd.read_csv(io.BytesIO(f_bytes), dtype=str).fillna("")
case _:
df = bytes_to_df(f_name=f_name, f_bytes=f_bytes, ext=ext)
assert (
"snippet" in df.columns or "sections" in df.columns
), f'uploaded spreadsheet must contain a "snippet" or "sections" column - {f_name !r}'
pages = df
return df

return pages


def bytes_to_df(
*,
f_name: str,
f_bytes: bytes,
ext: str,
) -> "pd.DataFrame":
import pandas as pd

f = io.BytesIO(f_bytes)
match ext:
case ".csv":
df = pd.read_csv(f, dtype=str)
case ".tsv":
df = pd.read_csv(f, sep="\t", dtype=str)
case ".xls" | ".xlsx":
df = pd.read_excel(f, dtype=str)
case ".json":
df = pd.read_json(f, dtype=str)
case ".xml":
df = pd.read_xml(f, dtype=str)
case _:
raise ValueError(f"Unsupported document format {ext!r} ({f_name})")
# optionally, translate text
if google_translate_target:
pages = run_google_translate(pages, google_translate_target)
return pages
return df.fillna("")
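
For illustration, how the extension dispatch above plays out for a small TSV payload (the data is made up):

f_bytes = b"snippet\tsections\nhello world\t#intro\n"  # made-up spreadsheet content
df = bytes_to_df(f_name="example.tsv", f_bytes=f_bytes, ext=".tsv")
# bytes_to_text_pages_or_df later asserts that a "snippet" or "sections" column exists
print(df.columns.tolist())  # ['snippet', 'sections']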


def pdf_to_text_pages(f: typing.BinaryIO) -> list[str]:
@@ -534,10 +579,6 @@ def pandoc_to_text(f_name: str, f_bytes: bytes, to="plain") -> str:
subprocess.check_call(args)
return outfile.read()

refs = st.session_state.get("references", [])
if not refs:
return


def render_sources_widget(refs: list[SearchReference]):
if not refs:
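Net effect on vector_search.py: doc_url_to_text_pages can now return either a list of text pages or a DataFrame for spreadsheets, so callers have to branch on the type. A sketch, assuming the arguments keep the signature shown above:

import pandas as pd

pages = doc_url_to_text_pages(
    f_url=url,
    doc_meta=doc_meta,
    google_translate_target=None,
    selected_asr_model=None,
)
if isinstance(pages, pd.DataFrame):
    ...  # spreadsheet rows with "snippet"/"sections" columns
else:
    ...  # plain strings, one per page; already translated if a target was set
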
68 changes: 46 additions & 22 deletions recipes/BulkRunner.py
@@ -1,8 +1,6 @@
import io
import typing

import pandas as pd
import requests
from fastapi import HTTPException
from furl import furl
from pydantic import BaseModel, Field
@@ -14,6 +12,10 @@
from daras_ai_v2.doc_search_settings_widgets import document_uploader
from daras_ai_v2.functional import map_parallel
from daras_ai_v2.query_params_util import extract_query_params
from daras_ai_v2.vector_search import (
doc_url_to_metadata,
download_content_bytes,
)
from recipes.DocSearch import render_documents

CACHED_COLUMNS = "__cached_columns"
@@ -71,15 +73,17 @@ def render_form_v2(self):
)

if files:
dfs = map_parallel(_read_df, files)
st.session_state[CACHED_COLUMNS] = list(
{
col: None
for df in map_parallel(_read_df, files)
for df in dfs
for col in df.columns
if not col.startswith("Unnamed:")
}
)
else:
dfs = []
st.session_state.pop(CACHED_COLUMNS, None)

required_input_fields = {}
@@ -145,12 +149,26 @@ def render_form_v2(self):
st.write(
"""
##### Input Data Preview
Here's how we've parsed your data.
Here's how we've parsed your data.
"""
)

for file in files:
st.data_table(file)
for df in dfs:
st.text_area(
"",
value=df.to_string(
max_cols=10, max_rows=10, max_colwidth=40, show_dimensions=True
),
label_visibility="collapsed",
disabled=True,
style={
"white-space": "pre",
"overflow": "scroll",
"font-family": "monospace",
"font-size": "0.9rem",
},
height=250,
)

if not (required_input_fields or optional_input_fields):
return
@@ -218,6 +236,8 @@ def run_v2(
request: "BulkRunnerPage.RequestModel",
response: "BulkRunnerPage.ResponseModel",
) -> typing.Iterator[str | None]:
import pandas as pd

response.output_documents = []

for doc_ix, doc in enumerate(request.documents):
@@ -408,21 +428,25 @@ def is_arr(field_props: dict) -> bool:
return False


def _read_df(f: str) -> "pd.DataFrame":
def _read_df(f_url: str) -> "pd.DataFrame":
import pandas as pd

r = requests.get(f)
r.raise_for_status()
if f.endswith(".csv"):
df = pd.read_csv(io.StringIO(r.text))
elif f.endswith(".xlsx") or f.endswith(".xls"):
df = pd.read_excel(io.BytesIO(r.content))
elif f.endswith(".json"):
df = pd.read_json(io.StringIO(r.text))
elif f.endswith(".tsv"):
df = pd.read_csv(io.StringIO(r.text), sep="\t")
elif f.endswith(".xml"):
df = pd.read_xml(io.StringIO(r.text))
else:
raise ValueError(f"Unsupported file type: {f}")
return df.dropna(how="all", axis=1).dropna(how="all", axis=0)
doc_meta = doc_url_to_metadata(f_url)
f_bytes, ext = download_content_bytes(f_url=f_url, mime_type=doc_meta.mime_type)

f = io.BytesIO(f_bytes)
match ext:
case ".csv":
df = pd.read_csv(f)
case ".tsv":
df = pd.read_csv(f, sep="\t")
case ".xls" | ".xlsx":
df = pd.read_excel(f)
case ".json":
df = pd.read_json(f)
case ".xml":
df = pd.read_xml(f)
case _:
raise ValueError(f"Unsupported file type: {f_url}")

return df.dropna(how="all", axis=1).dropna(how="all", axis=0).fillna("")
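
With this rewrite _read_df goes through doc_url_to_metadata and download_content_bytes, so Google Drive links work the same way as direct file URLs. A hypothetical call (the URL is made up):

df = _read_df("https://docs.google.com/spreadsheets/d/abc123/edit")
# the column names are later cached in st.session_state[CACHED_COLUMNS]
print(df.columns.tolist())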
