allow for bytes or buffer as input #259

Merged · 3 commits · Jun 29, 2024
Changes from 1 commit
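
In caller terms, this change lets load_data (and the async variants) accept a plain file path, raw bytes, or an open binary buffer. A minimal usage sketch, assuming LLAMA_CLOUD_API_KEY is set in the environment and using "report.pdf" as a placeholder path; when the input is not a path, file_name must be supplied in extra_info, as enforced in _create_job below:

from llama_parse import LlamaParse

parser = LlamaParse(result_type="markdown")

# Existing behavior: pass a file path.
docs = parser.load_data("report.pdf")

# New: pass raw bytes, supplying the file name via extra_info
# (the name is used to guess the mime type of the upload).
with open("report.pdf", "rb") as f:
    file_bytes = f.read()
docs = parser.load_data(file_bytes, extra_info={"file_name": "report.pdf"})

# New: pass an open binary buffer (BufferedIOBase), again with file_name.
with open("report.pdf", "rb") as buf:
    docs = parser.load_data(buf, extra_info={"file_name": "report.pdf"})
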
71 changes: 40 additions & 31 deletions llama_parse/base.py
@@ -5,6 +5,7 @@
import time
from pathlib import Path
from typing import List, Optional, Union
from io import BufferedIOBase

from llama_index.core.async_utils import run_jobs
from llama_index.core.bridge.pydantic import Field, validator
@@ -20,6 +21,10 @@
)
from copy import deepcopy

# can put in a path to the file or the file bytes itself
# if passing as bytes or a buffer, must provide the file_name in extra_info
FileInput = Union[str, bytes, BufferedIOBase]


def _get_sub_docs(docs: List[Document]) -> List[Document]:
"""Split docs into pages, by separator."""
@@ -142,29 +147,32 @@ def validate_base_url(cls, v: str) -> str:
return url or v or DEFAULT_BASE_URL

# upload a document and get back a job_id
async def _create_job(
self, file_path: str, extra_info: Optional[dict] = None
) -> str:
file_path = str(file_path)
file_ext = os.path.splitext(file_path)[1]
if file_ext not in SUPPORTED_FILE_TYPES:
raise Exception(
f"Currently, only the following file types are supported: {SUPPORTED_FILE_TYPES}\n"
f"Current file type: {file_ext}"
)

extra_info = extra_info or {}
extra_info["file_path"] = file_path

async def _create_job(self, file_input: FileInput, extra_info: Optional[dict] = None) -> str:
headers = {"Authorization": f"Bearer {self.api_key}"}

# load data, set the mime type
with open(file_path, "rb") as f:
url = f"{self.base_url}/api/parsing/upload"
files = None
file_handle = None

if isinstance(file_input, (bytes, BufferedIOBase)):
if not extra_info or "file_name" not in extra_info:
raise ValueError("file_name must be provided in extra_info when passing bytes")
file_name = extra_info["file_name"]
mime_type = mimetypes.guess_type(file_name)[0]
files = {"file": (file_name, file_input, mime_type)}
elif isinstance(file_input, str):
file_path = str(file_input)
file_ext = os.path.splitext(file_path)[1]
if file_ext not in SUPPORTED_FILE_TYPES:
raise Exception(f"Currently, only the following file types are supported: {SUPPORTED_FILE_TYPES}\n"
f"Current file type: {file_ext}")
mime_type = mimetypes.guess_type(file_path)[0]
files = {"file": (f.name, f, mime_type)}
# Open the file here for the duration of the async context
file_handle = open(file_path, 'rb')
files = {"file": (os.path.basename(file_path), file_handle, mime_type)}
else:
raise ValueError("file_input must be either a file path string, file bytes, or buffer object")

# send the request, start job
url = f"{self.base_url}/api/parsing/upload"
try:
async with httpx.AsyncClient(timeout=self.max_timeout) as client:
response = await client.post(
url,
@@ -187,10 +195,11 @@ async def _create_job(
)
if not response.is_success:
raise Exception(f"Failed to parse the file: {response.text}")

# check the status of the job, return when done
job_id = response.json()["id"]
return job_id
job_id = response.json()["id"]
return job_id
finally:
if file_handle is not None:
file_handle.close()

async def _get_job_result(
self, job_id: str, result_type: str, verbose: bool = False
@@ -240,7 +249,7 @@ async def _get_job_result(
)

async def _aload_data(
self, file_path: str, extra_info: Optional[dict] = None, verbose: bool = False
self, file_path: FileInput, extra_info: Optional[dict] = None, verbose: bool = False
) -> List[Document]:
"""Load data from the input path."""
try:
@@ -271,10 +280,10 @@ async def _aload_data(
raise e

async def aload_data(
self, file_path: Union[List[str], str], extra_info: Optional[dict] = None
self, file_path: Union[List[FileInput], FileInput], extra_info: Optional[dict] = None
) -> List[Document]:
"""Load data from the input path."""
if isinstance(file_path, (str, Path)):
if isinstance(file_path, (str, Path, bytes, BufferedIOBase)):
return await self._aload_data(
file_path, extra_info=extra_info, verbose=self.verbose
)
@@ -308,7 +317,7 @@ async def aload_data(
)

def load_data(
self, file_path: Union[List[str], str], extra_info: Optional[dict] = None
self, file_path: Union[List[FileInput], FileInput], extra_info: Optional[dict] = None
) -> List[Document]:
"""Load data from the input path."""
try:
@@ -320,7 +329,7 @@ def load_data(
raise e

async def _aget_json(
self, file_path: str, extra_info: Optional[dict] = None
self, file_path: FileInput, extra_info: Optional[dict] = None
) -> List[dict]:
"""Load data from the input path."""
try:
@@ -341,7 +350,7 @@ async def _aget_json(
raise e

async def aget_json(
self, file_path: Union[List[str], str], extra_info: Optional[dict] = None
self, file_path: Union[List[FileInput], FileInput], extra_info: Optional[dict] = None
) -> List[dict]:
"""Load data from the input path."""
if isinstance(file_path, (str, Path)):
@@ -369,7 +378,7 @@ async def aget_json(
)

def get_json_result(
self, file_path: Union[List[str], str], extra_info: Optional[dict] = None
self, file_path: Union[List[FileInput], FileInput], extra_info: Optional[dict] = None
) -> List[dict]:
"""Parse the input path."""
try:
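aload_data accepts the same FileInput union for single inputs, so bytes and buffers work from async code as well. A brief sketch under the same assumptions as above (placeholder file name, API key set in the environment):

import asyncio

from llama_parse import LlamaParse

async def parse_bytes(data: bytes) -> list:
    parser = LlamaParse(result_type="markdown")
    # file_name is still required in extra_info for non-path input;
    # it is only used to guess the mime type of the upload.
    return await parser.aload_data(data, extra_info={"file_name": "report.pdf"})

# docs = asyncio.run(parse_bytes(open("report.pdf", "rb").read()))
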
43 changes: 36 additions & 7 deletions tests/test_reader.py
@@ -17,21 +17,50 @@ def test_simple_page_text() -> None:
assert len(result) == 1
assert len(result[0].text) > 0

@pytest.fixture
def markdown_parser() -> LlamaParse:
if os.environ.get("LLAMA_CLOUD_API_KEY", "") == "":
pytest.skip("LLAMA_CLOUD_API_KEY not set")
return LlamaParse(result_type="markdown", ignore_errors=False)

@pytest.mark.skipif(
os.environ.get("LLAMA_CLOUD_API_KEY", "") == "",
reason="LLAMA_CLOUD_API_KEY not set",
)
def test_simple_page_markdown() -> None:
parser = LlamaParse(result_type="markdown")

def test_simple_page_markdown(markdown_parser: LlamaParse) -> None:
filepath = os.path.join(
os.path.dirname(__file__), "test_files/attention_is_all_you_need.pdf"
)
result = parser.load_data(filepath)
result = markdown_parser.load_data(filepath)
assert len(result) == 1
assert len(result[0].text) > 0

def test_simple_page_markdown_bytes(markdown_parser: LlamaParse) -> None:
markdown_parser = LlamaParse(result_type="markdown", ignore_errors=False)

filepath = os.path.join(
os.path.dirname(__file__), "test_files/attention_is_all_you_need.pdf"
)
with open(filepath, "rb") as f:
file_bytes = f.read()
# client must provide extra_info with file_name
with pytest.raises(ValueError):
result = markdown_parser.load_data(file_bytes)
result = markdown_parser.load_data(file_bytes, extra_info={"file_name": "attention_is_all_you_need.pdf"})
assert len(result) == 1
assert len(result[0].text) > 0

def test_simple_page_markdown_buffer(markdown_parser: LlamaParse) -> None:
markdown_parser = LlamaParse(result_type="markdown", ignore_errors=False)

filepath = os.path.join(
os.path.dirname(__file__), "test_files/attention_is_all_you_need.pdf"
)
with open(filepath, "rb") as f:
# client must provide extra_info with file_name
with pytest.raises(ValueError):
result = markdown_parser.load_data(f)
result = markdown_parser.load_data(f, extra_info={"file_name": "attention_is_all_you_need.pdf"})
assert len(result) == 1
assert len(result[0].text) > 0


@pytest.mark.skipif(
os.environ.get("LLAMA_CLOUD_API_KEY", "") == "",