Skip to content

Commit

Permalink
allow for bytes or buffer as input (#259)
Browse files Browse the repository at this point in the history
* allow for bytes or buffer as input

* format & readme update

* lint
  • Loading branch information
sourabhdesai authored Jun 29, 2024
1 parent 1bbf5f4 commit 8b96176
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 38 deletions.
28 changes: 28 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,34 @@ documents = await parser.aload_data("./my_file.pdf")
documents = await parser.aload_data(["./my_file1.pdf", "./my_file2.pdf"])
```

## Using with a file object

You can parse a file object directly:

```python
import nest_asyncio

nest_asyncio.apply()

from llama_parse import LlamaParse

parser = LlamaParse(
api_key="llx-...", # can also be set in your env as LLAMA_CLOUD_API_KEY
result_type="markdown", # "markdown" and "text" are available
num_workers=4, # if multiple files passed, split in `num_workers` API calls
verbose=True,
language="en", # Optionally you can define a language, default=en
)

with open("./my_file1.pdf", "rb") as f:
    # when passing a file object (or raw bytes), you must supply the
    # file name via extra_info so the file type can be determined
    documents = parser.load_data(f, extra_info={"file_name": "my_file1.pdf"})

# you can also pass file bytes directly
with open("./my_file1.pdf", "rb") as f:
    file_bytes = f.read()
    documents = parser.load_data(
        file_bytes, extra_info={"file_name": "my_file1.pdf"}
    )
```

## Using with `SimpleDirectoryReader`

You can also integrate the parser as the default PDF loader in `SimpleDirectoryReader`:
Expand Down
92 changes: 61 additions & 31 deletions llama_parse/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import time
from pathlib import Path
from typing import List, Optional, Union
from io import BufferedIOBase

from llama_index.core.async_utils import run_jobs
from llama_index.core.bridge.pydantic import Field, validator
Expand All @@ -20,6 +21,10 @@
)
from copy import deepcopy

# can put in a path to the file or the file bytes itself
# if passing as bytes or a buffer, must provide the file_name in extra_info
FileInput = Union[str, bytes, BufferedIOBase]


def _get_sub_docs(docs: List[Document]) -> List[Document]:
"""Split docs into pages, by separator."""
Expand Down Expand Up @@ -143,28 +148,39 @@ def validate_base_url(cls, v: str) -> str:

# upload a document and get back a job_id
async def _create_job(
self, file_path: str, extra_info: Optional[dict] = None
self, file_input: FileInput, extra_info: Optional[dict] = None
) -> str:
file_path = str(file_path)
file_ext = os.path.splitext(file_path)[1]
if file_ext not in SUPPORTED_FILE_TYPES:
raise Exception(
f"Currently, only the following file types are supported: {SUPPORTED_FILE_TYPES}\n"
f"Current file type: {file_ext}"
)

extra_info = extra_info or {}
extra_info["file_path"] = file_path

headers = {"Authorization": f"Bearer {self.api_key}"}

# load data, set the mime type
with open(file_path, "rb") as f:
url = f"{self.base_url}/api/parsing/upload"
files = None
file_handle = None

if isinstance(file_input, (bytes, BufferedIOBase)):
if not extra_info or "file_name" not in extra_info:
raise ValueError(
"file_name must be provided in extra_info when passing bytes"
)
file_name = extra_info["file_name"]
mime_type = mimetypes.guess_type(file_name)[0]
files = {"file": (file_name, file_input, mime_type)}
elif isinstance(file_input, str):
file_path = str(file_input)
file_ext = os.path.splitext(file_path)[1]
if file_ext not in SUPPORTED_FILE_TYPES:
raise Exception(
f"Currently, only the following file types are supported: {SUPPORTED_FILE_TYPES}\n"
f"Current file type: {file_ext}"
)
mime_type = mimetypes.guess_type(file_path)[0]
files = {"file": (f.name, f, mime_type)}
# Open the file here for the duration of the async context
file_handle = open(file_path, "rb")
files = {"file": (os.path.basename(file_path), file_handle, mime_type)}
else:
raise ValueError(
"file_input must be either a file path string, file bytes, or buffer object"
)

# send the request, start job
url = f"{self.base_url}/api/parsing/upload"
try:
async with httpx.AsyncClient(timeout=self.max_timeout) as client:
response = await client.post(
url,
Expand All @@ -187,10 +203,11 @@ async def _create_job(
)
if not response.is_success:
raise Exception(f"Failed to parse the file: {response.text}")

# check the status of the job, return when done
job_id = response.json()["id"]
return job_id
job_id = response.json()["id"]
return job_id
finally:
if file_handle is not None:
file_handle.close()

async def _get_job_result(
self, job_id: str, result_type: str, verbose: bool = False
Expand Down Expand Up @@ -240,7 +257,10 @@ async def _get_job_result(
)

async def _aload_data(
self, file_path: str, extra_info: Optional[dict] = None, verbose: bool = False
self,
file_path: FileInput,
extra_info: Optional[dict] = None,
verbose: bool = False,
) -> List[Document]:
"""Load data from the input path."""
try:
Expand All @@ -264,17 +284,20 @@ async def _aload_data(
return docs

except Exception as e:
print(f"Error while parsing the file '{file_path}':", e)
file_repr = file_path if isinstance(file_path, str) else "<bytes/buffer>"
print(f"Error while parsing the file '{file_repr}':", e)
if self.ignore_errors:
return []
else:
raise e

async def aload_data(
self, file_path: Union[List[str], str], extra_info: Optional[dict] = None
self,
file_path: Union[List[FileInput], FileInput],
extra_info: Optional[dict] = None,
) -> List[Document]:
"""Load data from the input path."""
if isinstance(file_path, (str, Path)):
if isinstance(file_path, (str, Path, bytes, BufferedIOBase)):
return await self._aload_data(
file_path, extra_info=extra_info, verbose=self.verbose
)
Expand Down Expand Up @@ -308,7 +331,9 @@ async def aload_data(
)

def load_data(
self, file_path: Union[List[str], str], extra_info: Optional[dict] = None
self,
file_path: Union[List[FileInput], FileInput],
extra_info: Optional[dict] = None,
) -> List[Document]:
"""Load data from the input path."""
try:
Expand All @@ -320,7 +345,7 @@ def load_data(
raise e

async def _aget_json(
self, file_path: str, extra_info: Optional[dict] = None
self, file_path: FileInput, extra_info: Optional[dict] = None
) -> List[dict]:
"""Load data from the input path."""
try:
Expand All @@ -334,14 +359,17 @@ async def _aget_json(
return [result]

except Exception as e:
print(f"Error while parsing the file '{file_path}':", e)
file_repr = file_path if isinstance(file_path, str) else "<bytes/buffer>"
print(f"Error while parsing the file '{file_repr}':", e)
if self.ignore_errors:
return []
else:
raise e

async def aget_json(
self, file_path: Union[List[str], str], extra_info: Optional[dict] = None
self,
file_path: Union[List[FileInput], FileInput],
extra_info: Optional[dict] = None,
) -> List[dict]:
"""Load data from the input path."""
if isinstance(file_path, (str, Path)):
Expand Down Expand Up @@ -369,7 +397,9 @@ async def aget_json(
)

def get_json_result(
self, file_path: Union[List[str], str], extra_info: Optional[dict] = None
self,
file_path: Union[List[FileInput], FileInput],
extra_info: Optional[dict] = None,
) -> List[dict]:
"""Parse the input path."""
try:
Expand Down
50 changes: 43 additions & 7 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,57 @@ def test_simple_page_text() -> None:
assert len(result[0].text) > 0


@pytest.mark.skipif(
os.environ.get("LLAMA_CLOUD_API_KEY", "") == "",
reason="LLAMA_CLOUD_API_KEY not set",
)
def test_simple_page_markdown() -> None:
parser = LlamaParse(result_type="markdown")
@pytest.fixture
def markdown_parser() -> LlamaParse:
    """Provide a markdown-mode parser; skip the test when no API key is set."""
    api_key = os.environ.get("LLAMA_CLOUD_API_KEY", "")
    if not api_key:
        pytest.skip("LLAMA_CLOUD_API_KEY not set")
    return LlamaParse(result_type="markdown", ignore_errors=False)


def test_simple_page_markdown(markdown_parser: LlamaParse) -> None:
    """Parse a known PDF by path with the fixture parser and sanity-check output.

    Fix: removed the stale `result = parser.load_data(filepath)` line left over
    from the pre-fixture version — `parser` is undefined in this function and
    the call was duplicated by the `markdown_parser` call below it.
    """
    filepath = os.path.join(
        os.path.dirname(__file__), "test_files/attention_is_all_you_need.pdf"
    )
    result = markdown_parser.load_data(filepath)
    # one PDF in, one Document out, with non-empty parsed text
    assert len(result) == 1
    assert len(result[0].text) > 0


def test_simple_page_markdown_bytes(markdown_parser: LlamaParse) -> None:
    """Raw bytes input requires file_name in extra_info; parsing then succeeds.

    Fix: removed the local `markdown_parser = LlamaParse(...)` reassignment,
    which shadowed the injected fixture and silently discarded its
    configuration.
    """
    filepath = os.path.join(
        os.path.dirname(__file__), "test_files/attention_is_all_you_need.pdf"
    )
    with open(filepath, "rb") as f:
        file_bytes = f.read()

    # client must provide extra_info with file_name when passing bytes
    with pytest.raises(ValueError):
        markdown_parser.load_data(file_bytes)

    result = markdown_parser.load_data(
        file_bytes, extra_info={"file_name": "attention_is_all_you_need.pdf"}
    )
    assert len(result) == 1
    assert len(result[0].text) > 0


def test_simple_page_markdown_buffer(markdown_parser: LlamaParse) -> None:
    """A file object (buffer) input requires file_name in extra_info.

    Fix: removed the local `markdown_parser = LlamaParse(...)` reassignment,
    which shadowed the injected fixture and silently discarded its
    configuration.
    """
    filepath = os.path.join(
        os.path.dirname(__file__), "test_files/attention_is_all_you_need.pdf"
    )
    # keep the file open for both calls — the second one consumes the buffer
    with open(filepath, "rb") as f:
        # client must provide extra_info with file_name when passing a buffer
        with pytest.raises(ValueError):
            markdown_parser.load_data(f)
        result = markdown_parser.load_data(
            f, extra_info={"file_name": "attention_is_all_you_need.pdf"}
        )
    assert len(result) == 1
    assert len(result[0].text) > 0


@pytest.mark.skipif(
os.environ.get("LLAMA_CLOUD_API_KEY", "") == "",
reason="LLAMA_CLOUD_API_KEY not set",
Expand Down

0 comments on commit 8b96176

Please sign in to comment.