From 8b96176d8a3246f66480a1e6d26af1e68f5a29d6 Mon Sep 17 00:00:00 2001
From: Sourabh Desai
Date: Fri, 28 Jun 2024 23:31:35 -0700
Subject: [PATCH] allow for bytes or buffer as input (#259)

* allow for bytes or buffer as input
* format & readme update
* lint
---
 README.md            | 28 ++++++++++++++
 llama_parse/base.py  | 92 +++++++++++++++++++++++++++++---------------
 tests/test_reader.py | 50 ++++++++++++++++++++----
 3 files changed, 132 insertions(+), 38 deletions(-)

diff --git a/README.md b/README.md
index 9cb86f2..a94993b 100644
--- a/README.md
+++ b/README.md
@@ -55,6 +55,34 @@ documents = await parser.aload_data("./my_file.pdf")
 documents = await parser.aload_data(["./my_file1.pdf", "./my_file2.pdf"])
 ```

+## Using with a file object
+
+You can parse a file object directly. When passing a file object or raw bytes, provide the file name in `extra_info` so the file type can be detected:
+
+```python
+import nest_asyncio
+
+nest_asyncio.apply()
+
+from llama_parse import LlamaParse
+
+parser = LlamaParse(
+    api_key="llx-...",  # can also be set in your env as LLAMA_CLOUD_API_KEY
+    result_type="markdown",  # "markdown" and "text" are available
+    num_workers=4,  # if multiple files passed, split in `num_workers` API calls
+    verbose=True,
+    language="en",  # Optionally you can define a language, default=en
+)
+
+with open("./my_file1.pdf", "rb") as f:
+    documents = parser.load_data(f, extra_info={"file_name": "my_file1.pdf"})
+
+# you can also pass file bytes directly
+with open("./my_file1.pdf", "rb") as f:
+    file_bytes = f.read()
+    documents = parser.load_data(file_bytes, extra_info={"file_name": "my_file1.pdf"})
+```
+
 ## Using with `SimpleDirectoryReader`

 You can also integrate the parser as the default PDF loader in `SimpleDirectoryReader`:
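The `FileInput` alias introduced in `llama_parse/base.py` below is `Union[str, bytes, BufferedIOBase]`, so a path string, raw `bytes`, and any buffered binary stream (including the handle returned by `open(path, "rb")`) are all accepted. A minimal, self-contained sketch of that type check, using an in-memory `BytesIO` stand-in rather than a real PDF:

```python
from io import BufferedIOBase, BytesIO

# BytesIO, like the BufferedReader returned by open(path, "rb"), subclasses
# BufferedIOBase, so it satisfies the buffer half of the FileInput union.
buf = BytesIO(b"%PDF-1.4 placeholder bytes")
print(isinstance(buf, BufferedIOBase))    # True
# .getvalue() / .read() yield plain bytes, which FileInput also accepts.
print(isinstance(buf.getvalue(), bytes))  # True
```

Either form has to be accompanied by `extra_info={"file_name": ...}` so the parser can infer the file type, which the tests at the end of this patch verify.
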
diff --git a/llama_parse/base.py b/llama_parse/base.py
index e53a7db..27b43c8 100644
--- a/llama_parse/base.py
+++ b/llama_parse/base.py
@@ -5,6 +5,7 @@
 import time
 from pathlib import Path
 from typing import List, Optional, Union
+from io import BufferedIOBase

 from llama_index.core.async_utils import run_jobs
 from llama_index.core.bridge.pydantic import Field, validator
@@ -20,6 +21,10 @@
 )
 from copy import deepcopy

+# can put in a path to the file or the file bytes itself
+# if passing as bytes or a buffer, must provide the file_name in extra_info
+FileInput = Union[str, bytes, BufferedIOBase]
+

 def _get_sub_docs(docs: List[Document]) -> List[Document]:
     """Split docs into pages, by separator."""
@@ -143,28 +148,39 @@ def validate_base_url(cls, v: str) -> str:

     # upload a document and get back a job_id
     async def _create_job(
-        self, file_path: str, extra_info: Optional[dict] = None
+        self, file_input: FileInput, extra_info: Optional[dict] = None
     ) -> str:
-        file_path = str(file_path)
-        file_ext = os.path.splitext(file_path)[1]
-        if file_ext not in SUPPORTED_FILE_TYPES:
-            raise Exception(
-                f"Currently, only the following file types are supported: {SUPPORTED_FILE_TYPES}\n"
-                f"Current file type: {file_ext}"
-            )
-
-        extra_info = extra_info or {}
-        extra_info["file_path"] = file_path
-
         headers = {"Authorization": f"Bearer {self.api_key}"}
-
-        # load data, set the mime type
-        with open(file_path, "rb") as f:
+        url = f"{self.base_url}/api/parsing/upload"
+        files = None
+        file_handle = None
+
+        if isinstance(file_input, (bytes, BufferedIOBase)):
+            if not extra_info or "file_name" not in extra_info:
+                raise ValueError(
+                    "file_name must be provided in extra_info when passing bytes"
+                )
+            file_name = extra_info["file_name"]
+            mime_type = mimetypes.guess_type(file_name)[0]
+            files = {"file": (file_name, file_input, mime_type)}
+        elif isinstance(file_input, str):
+            file_path = str(file_input)
+            file_ext = os.path.splitext(file_path)[1]
+            if file_ext not in SUPPORTED_FILE_TYPES:
+                raise Exception(
+                    f"Currently, only the following file types are supported: {SUPPORTED_FILE_TYPES}\n"
+                    f"Current file type: {file_ext}"
+                )
             mime_type = mimetypes.guess_type(file_path)[0]
-            files = {"file": (f.name, f, mime_type)}
+            # Open the file here for the duration of the async context
+            file_handle = open(file_path, "rb")
+            files = {"file": (os.path.basename(file_path), file_handle, mime_type)}
+        else:
+            raise ValueError(
+                "file_input must be either a file path string, file bytes, or buffer object"
+            )

-            # send the request, start job
-            url = f"{self.base_url}/api/parsing/upload"
+        try:
             async with httpx.AsyncClient(timeout=self.max_timeout) as client:
                 response = await client.post(
                     url,
@@ -187,10 +203,11 @@
                 )
                 if not response.is_success:
                     raise Exception(f"Failed to parse the file: {response.text}")
-
-            # check the status of the job, return when done
-            job_id = response.json()["id"]
-            return job_id
+                job_id = response.json()["id"]
+                return job_id
+        finally:
+            if file_handle is not None:
+                file_handle.close()

     async def _get_job_result(
         self, job_id: str, result_type: str, verbose: bool = False
@@ -240,7 +257,10 @@
             )

     async def _aload_data(
-        self, file_path: str, extra_info: Optional[dict] = None, verbose: bool = False
+        self,
+        file_path: FileInput,
+        extra_info: Optional[dict] = None,
+        verbose: bool = False,
     ) -> List[Document]:
         """Load data from the input path."""
         try:
@@ -264,17 +284,20 @@
             return docs

         except Exception as e:
-            print(f"Error while parsing the file '{file_path}':", e)
+            file_repr = file_path if isinstance(file_path, str) else "<bytes/buffer>"
+            print(f"Error while parsing the file '{file_repr}':", e)
             if self.ignore_errors:
                 return []
             else:
                 raise e

     async def aload_data(
-        self, file_path: Union[List[str], str], extra_info: Optional[dict] = None
+        self,
+        file_path: Union[List[FileInput], FileInput],
+        extra_info: Optional[dict] = None,
     ) -> List[Document]:
         """Load data from the input path."""
-        if isinstance(file_path, (str, Path)):
+        if isinstance(file_path, (str, Path, bytes, BufferedIOBase)):
             return await self._aload_data(
                 file_path, extra_info=extra_info, verbose=self.verbose
             )
@@ -308,7 +331,9 @@
             )

     def load_data(
-        self, file_path: Union[List[str], str], extra_info: Optional[dict] = None
+        self,
+        file_path: Union[List[FileInput], FileInput],
+        extra_info: Optional[dict] = None,
     ) -> List[Document]:
         """Load data from the input path."""
         try:
@@ -320,7 +345,7 @@
                 raise e

     async def _aget_json(
-        self, file_path: str, extra_info: Optional[dict] = None
+        self, file_path: FileInput, extra_info: Optional[dict] = None
     ) -> List[dict]:
         """Load data from the input path."""
         try:
@@ -334,14 +359,17 @@
             return [result]

         except Exception as e:
-            print(f"Error while parsing the file '{file_path}':", e)
+            file_repr = file_path if isinstance(file_path, str) else "<bytes/buffer>"
+            print(f"Error while parsing the file '{file_repr}':", e)
             if self.ignore_errors:
                 return []
             else:
                 raise e

     async def aget_json(
-        self, file_path: Union[List[str], str], extra_info: Optional[dict] = None
+        self,
+        file_path: Union[List[FileInput], FileInput],
+        extra_info: Optional[dict] = None,
     ) -> List[dict]:
         """Load data from the input path."""
         if isinstance(file_path, (str, Path)):
@@ -369,7 +397,9 @@
             )

     def get_json_result(
-        self, file_path: Union[List[str], str], extra_info: Optional[dict] = None
+        self,
+        file_path: Union[List[FileInput], FileInput],
+        extra_info: Optional[dict] = None,
     ) -> List[dict]:
         """Parse the input path."""
         try:
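Why `file_name` is mandatory for bytes and buffers: the upload branch above has no path on disk to inspect, so it derives the MIME type from the caller-supplied name via `mimetypes.guess_type` (for a path string it opens the file itself and closes it in the `finally` block). A small illustrative check with the standard library, using arbitrary example names:

```python
import mimetypes

# guess_type only looks at the extension, so the file name alone is enough.
print(mimetypes.guess_type("attention_is_all_you_need.pdf")[0])  # application/pdf
print(mimetypes.guess_type("notes.txt")[0])                      # text/plain
```
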
diff --git a/tests/test_reader.py b/tests/test_reader.py
index 2736e99..5ebc2ed 100644
--- a/tests/test_reader.py
+++ b/tests/test_reader.py
@@ -18,21 +18,57 @@ def test_simple_page_text() -> None:
     assert len(result[0].text) > 0


-@pytest.mark.skipif(
-    os.environ.get("LLAMA_CLOUD_API_KEY", "") == "",
-    reason="LLAMA_CLOUD_API_KEY not set",
-)
-def test_simple_page_markdown() -> None:
-    parser = LlamaParse(result_type="markdown")
+@pytest.fixture
+def markdown_parser() -> LlamaParse:
+    if os.environ.get("LLAMA_CLOUD_API_KEY", "") == "":
+        pytest.skip("LLAMA_CLOUD_API_KEY not set")
+    return LlamaParse(result_type="markdown", ignore_errors=False)
+

+def test_simple_page_markdown(markdown_parser: LlamaParse) -> None:
     filepath = os.path.join(
         os.path.dirname(__file__), "test_files/attention_is_all_you_need.pdf"
     )
-    result = parser.load_data(filepath)
+    result = markdown_parser.load_data(filepath)
     assert len(result) == 1
     assert len(result[0].text) > 0


+def test_simple_page_markdown_bytes(markdown_parser: LlamaParse) -> None:
+    markdown_parser = LlamaParse(result_type="markdown", ignore_errors=False)
+
+    filepath = os.path.join(
+        os.path.dirname(__file__), "test_files/attention_is_all_you_need.pdf"
+    )
+    with open(filepath, "rb") as f:
+        file_bytes = f.read()
+        # client must provide extra_info with file_name
+        with pytest.raises(ValueError):
+            result = markdown_parser.load_data(file_bytes)
+        result = markdown_parser.load_data(
+            file_bytes, extra_info={"file_name": "attention_is_all_you_need.pdf"}
+        )
+        assert len(result) == 1
+        assert len(result[0].text) > 0
+
+
+def test_simple_page_markdown_buffer(markdown_parser: LlamaParse) -> None:
+    markdown_parser = LlamaParse(result_type="markdown", ignore_errors=False)
+
+    filepath = os.path.join(
+        os.path.dirname(__file__), "test_files/attention_is_all_you_need.pdf"
+    )
+    with open(filepath, "rb") as f:
+        # client must provide extra_info with file_name
+        with pytest.raises(ValueError):
+            result = markdown_parser.load_data(f)
+        result = markdown_parser.load_data(
+            f, extra_info={"file_name": "attention_is_all_you_need.pdf"}
+        )
+        assert len(result) == 1
+        assert len(result[0].text) > 0
+
+
 @pytest.mark.skipif(
     os.environ.get("LLAMA_CLOUD_API_KEY", "") == "",
     reason="LLAMA_CLOUD_API_KEY not set",