From a6c09ec621d0c2202ad7604d937ef3e5a1b624f7 Mon Sep 17 00:00:00 2001
From: Roman Isecke <136338424+rbiseck3@users.noreply.github.com>
Date: Fri, 14 Jun 2024 09:46:44 -0400
Subject: [PATCH] Roman/dry ingest pipeline step (#3203)

### Description
The main goal of this change is to reduce the duplicate code that was being written for each ingest pipeline step to support both async and non-async functionality.

Additional bugs found and fixed:
* The ingest loggers weren't being instantiated correctly. They are now instantiated at the beginning of a pipeline run, as soon as the verbosity level can be determined.
* The `requires_dependencies` wrapper wasn't wrapping async functions correctly: it always returned a sync wrapper, so `asyncio.iscoroutinefunction()` checks on decorated functions never triggered. It now returns an async wrapper for coroutine functions.

---
 CHANGELOG.md                                  |  2 +-
 unstructured/__version__.py                   |  2 +-
 unstructured/ingest/v2/example.py             |  4 +-
 unstructured/ingest/v2/logger.py              |  6 ++-
 unstructured/ingest/v2/pipeline/interfaces.py | 22 ++++----
 unstructured/ingest/v2/pipeline/pipeline.py   |  4 +-
 .../ingest/v2/pipeline/steps/chunk.py         | 30 ++++-------
 .../ingest/v2/pipeline/steps/download.py      | 52 ++++++++-----------
 .../ingest/v2/pipeline/steps/embed.py         | 29 ++++-------
 .../ingest/v2/pipeline/steps/partition.py     | 33 ++++--------
 .../ingest/v2/pipeline/steps/stage.py         | 41 ++++++---------
 .../ingest/v2/pipeline/steps/uncompress.py    | 16 ++++--
 .../ingest/v2/pipeline/steps/upload.py        | 26 +++++-----
 unstructured/utils.py                         | 15 +++++-
 14 files changed, 128 insertions(+), 154 deletions(-)
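Note: the per-step sync/async duplication collapses into one dispatch shape, sketched below. This is a simplified illustration (the standalone `dispatch` name and signature are illustrative, not code from this patch); each real step builds its own `fn_kwargs` and wraps the result in its own step response type:

```python
import asyncio
from typing import Any, Callable, Optional


async def dispatch(
    fn: Callable,
    semaphore: Optional[asyncio.Semaphore] = None,
    **fn_kwargs: Any,
) -> Any:
    # Plain callables (process.run) are invoked directly.
    if not asyncio.iscoroutinefunction(fn):
        return fn(**fn_kwargs)
    # Coroutine functions (process.run_async) are awaited, gated by the
    # pipeline context's semaphore when one is configured.
    if semaphore:
        async with semaphore:
            return await fn(**fn_kwargs)
    return await fn(**fn_kwargs)
```

With this shape, the base step's `run()` reduces to `asyncio.run(self.run_async(...))`, so each concrete step only implements `_run_async()`.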
diff --git a/CHANGELOG.md b/CHANGELOG.md
index da5e783c16..d0d6aaec76 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,4 +1,4 @@
-## 0.14.6-dev6
+## 0.14.6-dev7
 
 ### Enhancements
 
diff --git a/unstructured/__version__.py b/unstructured/__version__.py
index 45f2dd0f27..d1884901fb 100644
--- a/unstructured/__version__.py
+++ b/unstructured/__version__.py
@@ -1 +1 @@
-__version__ = "0.14.6-dev6"  # pragma: no cover
+__version__ = "0.14.6-dev7"  # pragma: no cover
diff --git a/unstructured/ingest/v2/example.py b/unstructured/ingest/v2/example.py
index 02b47e135a..c4545f926d 100644
--- a/unstructured/ingest/v2/example.py
+++ b/unstructured/ingest/v2/example.py
@@ -24,7 +24,9 @@
 if __name__ == "__main__":
     logger.info(f"Writing all content in: {work_dir.resolve()}")
     Pipeline.from_configs(
-        context=ProcessorConfig(work_dir=str(work_dir.resolve()), tqdm=True),
+        context=ProcessorConfig(
+            work_dir=str(work_dir.resolve()), tqdm=True, reprocess=True, verbose=True
+        ),
         indexer_config=S3IndexerConfig(remote_url="s3://utic-dev-tech-fixtures/small-pdf-set/"),
         downloader_config=S3DownloaderConfig(download_dir=download_path),
         source_connection_config=S3ConnectionConfig(anonymous=True),
diff --git a/unstructured/ingest/v2/logger.py b/unstructured/ingest/v2/logger.py
index b73196c9ca..34c5c1df3d 100644
--- a/unstructured/ingest/v2/logger.py
+++ b/unstructured/ingest/v2/logger.py
@@ -84,7 +84,8 @@ def redact_jsons(s: str) -> str:
         try:
             formatted_j = json.dumps(json.loads(j))
         except json.JSONDecodeError:
-            formatted_j = json.dumps(ast.literal_eval(j))
+            lit = ast.literal_eval(j)
+            formatted_j = json.dumps(lit)
         hidden_j = json.dumps(hide_sensitive_fields(json.loads(formatted_j)))
         s = s.replace(j, hidden_j)
     return s
@@ -112,7 +113,8 @@ def make_default_logger(level: int) -> Logger:
     handler.name = "ingest_log_handler"
     formatter = SensitiveFormatter("%(asctime)s %(processName)-10s %(levelname)-8s %(message)s")
     handler.setFormatter(formatter)
-    logger.addHandler(handler)
+    if handler.name not in [h.name for h in logger.handlers]:
+        logger.addHandler(handler)
     logger.setLevel(level)
     remove_root_handlers(logger)
     return logger
diff --git a/unstructured/ingest/v2/pipeline/interfaces.py b/unstructured/ingest/v2/pipeline/interfaces.py
index 7afc347871..959799e7c9 100644
--- a/unstructured/ingest/v2/pipeline/interfaces.py
+++ b/unstructured/ingest/v2/pipeline/interfaces.py
@@ -6,13 +6,13 @@
 from functools import wraps
 from pathlib import Path
 from time import time
-from typing import Any, Optional, TypeVar
+from typing import Any, Callable, Optional, TypeVar
 
 from tqdm import tqdm
 from tqdm.asyncio import tqdm as tqdm_asyncio
 
 from unstructured.ingest.v2.interfaces import BaseProcess, ProcessorConfig
-from unstructured.ingest.v2.logger import logger
+from unstructured.ingest.v2.logger import logger, make_default_logger
 
 BaseProcessT = TypeVar("BaseProcessT", bound=BaseProcess)
 iterable_input = list[dict[str, Any]]
@@ -98,7 +98,7 @@ def _wrap_mp(self, input_kwargs: dict) -> Any:
 
     def _set_log_level(self, log_level: int):
         # Set the log level for each spawned process when using multiprocessing pool
-        logger.setLevel(log_level)
+        make_default_logger(log_level)
 
     @timed
     def __call__(self, iterable: Optional[iterable_input] = None) -> Any:
@@ -113,15 +113,16 @@ def __call__(self, iterable: Optional[iterable_input] = None) -> Any:
             return self.process_async(iterable=iterable)
         return self.process_multiprocess(iterable=iterable)
 
-    def _run(self, *args, **kwargs: Any) -> Optional[Any]:
-        raise NotImplementedError
+    def _run(self, fn: Callable, **kwargs: Any) -> Optional[Any]:
+        return asyncio.run(self.run_async(_fn=fn, **kwargs))
 
-    async def _run_async(self, *args, **kwargs: Any) -> Optional[Any]:
+    async def _run_async(self, fn: Callable, **kwargs: Any) -> Optional[Any]:
         raise NotImplementedError
 
-    def run(self, *args, **kwargs: Any) -> Optional[Any]:
+    def run(self, _fn: Optional[Callable] = None, **kwargs: Any) -> Optional[Any]:
         try:
-            return self._run(*args, **kwargs)
+            fn = _fn or self.process.run
+            return self._run(fn=fn, **kwargs)
         except Exception as e:
             logger.error(f"Exception raised while running {self.identifier}", exc_info=e)
             if "file_data_path" in kwargs:
@@ -130,9 +131,10 @@ def run(self, *args, **kwargs: Any) -> Optional[Any]:
             raise e
         return None
 
-    async def run_async(self, *args, **kwargs: Any) -> Optional[Any]:
+    async def run_async(self, _fn: Optional[Callable] = None, **kwargs: Any) -> Optional[Any]:
         try:
-            return await self._run_async(*args, **kwargs)
+            fn = _fn or self.process.run_async
+            return await self._run_async(fn=fn, **kwargs)
         except Exception as e:
             logger.error(f"Exception raised while running {self.identifier}", exc_info=e)
             if "file_data_path" in kwargs:
diff --git a/unstructured/ingest/v2/pipeline/pipeline.py b/unstructured/ingest/v2/pipeline/pipeline.py
index f2a67460b3..1d089cc307 100644
--- a/unstructured/ingest/v2/pipeline/pipeline.py
+++ b/unstructured/ingest/v2/pipeline/pipeline.py
@@ -5,7 +5,7 @@
 from typing import Any, Optional, Union
 
 from unstructured.ingest.v2.interfaces import ProcessorConfig
-from unstructured.ingest.v2.logger import logger
+from unstructured.ingest.v2.logger import logger, make_default_logger
 from unstructured.ingest.v2.pipeline.steps.chunk import Chunker, ChunkStep
 from unstructured.ingest.v2.pipeline.steps.download import DownloaderT, DownloadStep
 from unstructured.ingest.v2.pipeline.steps.embed import Embedder, EmbedStep
@@ -59,7 +59,7 @@ def __post_init__(
         stager: UploadStager = None,
         uploader: Uploader = None,
     ):
-        logger.setLevel(level=logging.DEBUG if self.context.verbose else logging.INFO)
+        make_default_logger(level=logging.DEBUG if self.context.verbose else logging.INFO)
         self.indexer_step = IndexStep(process=indexer, context=self.context)
         self.downloader_step = DownloadStep(process=downloader, context=self.context)
         self.partitioner_step = PartitionStep(process=partitioner, context=self.context)
diff --git a/unstructured/ingest/v2/pipeline/steps/chunk.py b/unstructured/ingest/v2/pipeline/steps/chunk.py
index d8a4506b03..07eb680d7f 100644
--- a/unstructured/ingest/v2/pipeline/steps/chunk.py
+++ b/unstructured/ingest/v2/pipeline/steps/chunk.py
@@ -1,8 +1,9 @@
+import asyncio
 import hashlib
 import json
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, TypedDict
+from typing import Callable, Optional, TypedDict
 
 from unstructured.ingest.v2.interfaces import FileData
 from unstructured.ingest.v2.logger import logger
@@ -53,32 +54,23 @@ def _save_output(self, output_filepath: str, chunked_content: list[dict]):
         logger.debug(f"Writing chunker output to: {output_filepath}")
         json.dump(chunked_content, f, indent=2)
 
-    def _run(self, path: str, file_data_path: str) -> ChunkStepResponse:
+    async def _run_async(
+        self, fn: Callable, path: str, file_data_path: str, **kwargs
+    ) -> ChunkStepResponse:
         path = Path(path)
         file_data = FileData.from_file(path=file_data_path)
         output_filepath = self.get_output_filepath(filename=path)
         if not self.should_chunk(filepath=output_filepath, file_data=file_data):
             logger.debug(f"Skipping chunking, output already exists: {output_filepath}")
             return ChunkStepResponse(file_data_path=file_data_path, path=str(output_filepath))
-        chunked_content_raw = self.process.run(elements_filepath=path)
-        self._save_output(
-            output_filepath=str(output_filepath),
-            chunked_content=elements_to_dicts(chunked_content_raw),
-        )
-        return ChunkStepResponse(file_data_path=file_data_path, path=str(output_filepath))
-
-    async def _run_async(self, path: str, file_data_path: str) -> ChunkStepResponse:
-        path = Path(path)
-        file_data = FileData.from_file(path=file_data_path)
-        output_filepath = self.get_output_filepath(filename=path)
-        if not self.should_chunk(filepath=output_filepath, file_data=file_data):
-            logger.debug(f"Skipping chunking, output already exists: {output_filepath}")
-            return ChunkStepResponse(file_data_path=file_data_path, path=str(output_filepath))
-        if semaphore := self.context.semaphore:
+        fn_kwargs = {"elements_filepath": path}
+        if not asyncio.iscoroutinefunction(fn):
+            chunked_content_raw = fn(**fn_kwargs)
+        elif semaphore := self.context.semaphore:
             async with semaphore:
-                chunked_content_raw = await self.process.run_async(elements_filepath=path)
+                chunked_content_raw = await fn(**fn_kwargs)
         else:
-            chunked_content_raw = await self.process.run_async(elements_filepath=path)
+            chunked_content_raw = await fn(**fn_kwargs)
         self._save_output(
             output_filepath=str(output_filepath),
             chunked_content=elements_to_dicts(chunked_content_raw),
diff --git a/unstructured/ingest/v2/pipeline/steps/download.py b/unstructured/ingest/v2/pipeline/steps/download.py
index c5d08e1c0f..52e72fa4b4 100644
--- a/unstructured/ingest/v2/pipeline/steps/download.py
+++ b/unstructured/ingest/v2/pipeline/steps/download.py
@@ -1,7 +1,8 @@
+import asyncio
 import hashlib
 import json
 from dataclasses import dataclass
-from typing import Optional, TypedDict, TypeVar
+from typing import Callable, Optional, TypedDict, TypeVar
 
 from unstructured.ingest.v2.interfaces import FileData, download_responses
 from unstructured.ingest.v2.interfaces.downloader import Downloader
@@ -55,7 +56,7 @@ def should_download(self, file_data: FileData, file_data_path: str) -> bool:
         if self.context.re_download:
             return True
         download_path = self.process.get_download_path(file_data=file_data)
-        if not download_path.exists():
+        if not download_path or not download_path.exists():
             return True
         if (
             download_path.is_file()
@@ -69,6 +70,24 @@ def should_download(self, file_data: FileData, file_data_path: str) -> bool:
             return True
         return False
 
+    async def _run_async(self, fn: Callable, file_data_path: str) -> list[DownloadStepResponse]:
+        file_data = FileData.from_file(path=file_data_path)
+        download_path = self.process.get_download_path(file_data=file_data)
+        if not self.should_download(file_data=file_data, file_data_path=file_data_path):
+            logger.debug(f"Skipping download, file already exists locally: {download_path}")
+            return [DownloadStepResponse(file_data_path=file_data_path, path=str(download_path))]
+        fn_kwargs = {"file_data": file_data}
+        if not asyncio.iscoroutinefunction(fn):
+            download_results = fn(**fn_kwargs)
+        elif semaphore := self.context.semaphore:
+            async with semaphore:
+                download_results = await fn(**fn_kwargs)
+        else:
+            download_results = await fn(**fn_kwargs)
+        return self.create_step_results(
+            current_file_data_path=file_data_path, download_results=download_results
+        )
+
     def create_step_results(
         self, current_file_data_path: str, download_results: download_responses
     ) -> list[DownloadStepResponse]:
@@ -87,35 +106,6 @@ def create_step_results(
         )
         return download_step_results
 
-    def _run(self, file_data_path: str) -> list[DownloadStepResponse]:
-        file_data = FileData.from_file(path=file_data_path)
-        download_path = self.process.get_download_path(file_data=file_data)
-        if not self.should_download(file_data=file_data, file_data_path=file_data_path):
-            logger.debug(f"Skipping download, file already exists locally: {download_path}")
-            return [DownloadStepResponse(file_data_path=file_data_path, path=str(download_path))]
-
-        download_results = self.process.run(file_data=file_data)
-        return self.create_step_results(
-            current_file_data_path=file_data_path, download_results=download_results
-        )
-
-    async def _run_async(self, file_data_path: str) -> list[DownloadStepResponse]:
-        file_data = FileData.from_file(path=file_data_path)
-        download_path = self.process.get_download_path(file_data=file_data)
-        if download_path and not self.should_download(
-            file_data=file_data, file_data_path=file_data_path
-        ):
-            logger.debug(f"Skipping download, file already exists locally: {download_path}")
-            return [DownloadStepResponse(file_data_path=file_data_path, path=str(download_path))]
-        if semaphore := self.context.semaphore:
-            async with semaphore:
-                download_results = await self.process.run_async(file_data=file_data)
-        else:
-            download_results = await self.process.run_async(file_data=file_data)
-        return self.create_step_results(
-            current_file_data_path=file_data_path, download_results=download_results
-        )
-
     def persist_new_file_data(self, file_data: FileData) -> str:
         record_hash = self.get_hash(extras=[file_data.identifier])
         filename = f"{record_hash}.json"
diff --git a/unstructured/ingest/v2/pipeline/steps/embed.py b/unstructured/ingest/v2/pipeline/steps/embed.py
index 7dcb94ae4e..3503a1af5d 100644
--- a/unstructured/ingest/v2/pipeline/steps/embed.py
+++ b/unstructured/ingest/v2/pipeline/steps/embed.py
@@ -1,8 +1,9 @@
+import asyncio
 import hashlib
 import json
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, TypedDict
+from typing import Callable, Optional, TypedDict
 
 from unstructured.ingest.v2.interfaces import FileData
 from unstructured.ingest.v2.logger import logger
@@ -53,33 +54,21 @@ def _save_output(self, output_filepath: str, embedded_content: list[dict]):
         logger.debug(f"Writing embedded output to: {output_filepath}")
         json.dump(embedded_content, f, indent=2)
 
-    def _run(self, path: str, file_data_path: str) -> EmbedStepResponse:
-        path = Path(path)
-        file_data = FileData.from_file(path=file_data_path)
-
-        output_filepath = self.get_output_filepath(filename=path)
-        if not self.should_embed(filepath=output_filepath, file_data=file_data):
-            logger.debug(f"Skipping embedding, output already exists: {output_filepath}")
-            return EmbedStepResponse(file_data_path=file_data_path, path=str(output_filepath))
-        embed_content_raw = self.process.run(elements_filepath=path)
-        self._save_output(
-            output_filepath=str(output_filepath),
-            embedded_content=elements_to_dicts(embed_content_raw),
-        )
-        return EmbedStepResponse(file_data_path=file_data_path, path=str(output_filepath))
-
-    async def _run_async(self, path: str, file_data_path: str) -> EmbedStepResponse:
+    async def _run_async(self, fn: Callable, path: str, file_data_path: str) -> EmbedStepResponse:
         path = Path(path)
         file_data = FileData.from_file(path=file_data_path)
         output_filepath = self.get_output_filepath(filename=path)
         if not self.should_embed(filepath=output_filepath, file_data=file_data):
             logger.debug(f"Skipping embedding, output already exists: {output_filepath}")
             return EmbedStepResponse(file_data_path=file_data_path, path=str(output_filepath))
-        if semaphore := self.context.semaphore:
+        fn_kwargs = {"elements_filepath": path}
+        if not asyncio.iscoroutinefunction(fn):
+            embed_content_raw = fn(**fn_kwargs)
+        elif semaphore := self.context.semaphore:
             async with semaphore:
-                embed_content_raw = await self.process.run_async(elements_filepath=path)
+                embed_content_raw = await fn(**fn_kwargs)
         else:
-            embed_content_raw = await self.process.run_async(elements_filepath=path)
+            embed_content_raw = await fn(**fn_kwargs)
         self._save_output(
             output_filepath=str(output_filepath),
diff --git a/unstructured/ingest/v2/pipeline/steps/partition.py b/unstructured/ingest/v2/pipeline/steps/partition.py
index f6c2f3c9e4..bb35624515 100644
--- a/unstructured/ingest/v2/pipeline/steps/partition.py
+++ b/unstructured/ingest/v2/pipeline/steps/partition.py
@@ -1,8 +1,9 @@
+import asyncio
 import hashlib
 import json
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, TypedDict
+from typing import Callable, Optional, TypedDict
 
 from unstructured.ingest.v2.interfaces import FileData
 from unstructured.ingest.v2.logger import logger
@@ -48,35 +49,23 @@ def _save_output(self, output_filepath: str, partitioned_content: list[dict]):
         logger.debug(f"Writing partitioned output to: {output_filepath}")
         json.dump(partitioned_content, f, indent=2)
 
-    def _run(self, path: str, file_data_path: str) -> PartitionStepResponse:
+    async def _run_async(
+        self, fn: Callable, path: str, file_data_path: str
+    ) -> PartitionStepResponse:
         path = Path(path)
         file_data = FileData.from_file(path=file_data_path)
         output_filepath = self.get_output_filepath(filename=Path(file_data_path))
         if not self.should_partition(filepath=output_filepath, file_data=file_data):
             logger.debug(f"Skipping partitioning, output already exists: {output_filepath}")
             return PartitionStepResponse(file_data_path=file_data_path, path=str(output_filepath))
-        partitioned_content = self.process.run(filename=path, metadata=file_data.metadata)
-        self._save_output(
-            output_filepath=str(output_filepath), partitioned_content=partitioned_content
-        )
-        return PartitionStepResponse(file_data_path=file_data_path, path=str(output_filepath))
-
-    async def _run_async(self, path: str, file_data_path: str) -> PartitionStepResponse:
-        path = Path(path)
-        file_data = FileData.from_file(path=file_data_path)
-        output_filepath = self.get_output_filepath(filename=Path(file_data_path))
-        if not self.should_partition(filepath=output_filepath, file_data=file_data):
-            logger.debug(f"Skipping partitioning, output already exists: {output_filepath}")
-            return PartitionStepResponse(file_data_path=file_data_path, path=str(output_filepath))
-        if semaphore := self.context.semaphore:
+        fn_kwargs = {"filename": path, "metadata": file_data.metadata}
+        if not asyncio.iscoroutinefunction(fn):
+            partitioned_content = fn(**fn_kwargs)
+        elif semaphore := self.context.semaphore:
             async with semaphore:
-                partitioned_content = await self.process.run_async(
-                    filename=path, metadata=file_data.metadata
-                )
+                partitioned_content = await fn(**fn_kwargs)
         else:
-            partitioned_content = await self.process.run_async(
-                filename=path, metadata=file_data.metadata
-            )
+            partitioned_content = await fn(**fn_kwargs)
         self._save_output(
             output_filepath=str(output_filepath), partitioned_content=partitioned_content
         )
diff --git a/unstructured/ingest/v2/pipeline/steps/stage.py b/unstructured/ingest/v2/pipeline/steps/stage.py
index 59bbe90c16..b4c6204ad4 100644
--- a/unstructured/ingest/v2/pipeline/steps/stage.py
+++ b/unstructured/ingest/v2/pipeline/steps/stage.py
@@ -1,8 +1,9 @@
+import asyncio
 import hashlib
 import json
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, TypedDict
+from typing import Callable, Optional, TypedDict
 
 from unstructured.ingest.v2.interfaces.file_data import FileData
 from unstructured.ingest.v2.interfaces.upload_stager import UploadStager
@@ -35,33 +36,23 @@ def __post_init__(self):
             self.cache_dir.mkdir(parents=True, exist_ok=True)
         logger.info(f"Created {self.identifier} with configs: {config}")
 
-    def _run(self, path: str, file_data_path: str) -> UploadStageStepResponse:
+    async def _run_async(
+        self, fn: Callable, path: str, file_data_path: str
+    ) -> UploadStageStepResponse:
         path = Path(path)
-        staged_output_path = self.process.run(
-            elements_filepath=path,
-            file_data=FileData.from_file(path=file_data_path),
-            output_dir=self.cache_dir,
-            output_filename=self.get_hash(extras=[path.name]),
-        )
-        return UploadStageStepResponse(file_data_path=file_data_path, path=str(staged_output_path))
-
-    async def _run_async(self, path: str, file_data_path: str) -> UploadStageStepResponse:
-        path = Path(path)
-        if semaphore := self.context.semaphore:
+        fn_kwargs = {
+            "elements_filepath": path,
+            "file_data": FileData.from_file(path=file_data_path),
+            "output_dir": self.cache_dir,
+            "output_filename": self.get_hash(extras=[path.name]),
+        }
+        if not asyncio.iscoroutinefunction(fn):
+            staged_output_path = fn(**fn_kwargs)
+        elif semaphore := self.context.semaphore:
             async with semaphore:
-                staged_output_path = await self.process.run_async(
-                    elements_filepath=path,
-                    file_data=FileData.from_file(path=file_data_path),
-                    output_dir=self.cache_dir,
-                    output_filename=self.get_hash(extras=[path.name]),
-                )
+                staged_output_path = await fn(**fn_kwargs)
         else:
-            staged_output_path = await self.process.run_async(
-                elements_filepath=path,
-                file_data=FileData.from_file(path=file_data_path),
-                output_dir=self.cache_dir,
-                output_filename=self.get_hash(extras=[path.name]),
-            )
+            staged_output_path = await fn(**fn_kwargs)
         return UploadStageStepResponse(file_data_path=file_data_path, path=str(staged_output_path))
 
     def get_hash(self, extras: Optional[list[str]]) -> str:
diff --git a/unstructured/ingest/v2/pipeline/steps/uncompress.py b/unstructured/ingest/v2/pipeline/steps/uncompress.py
index 77fda8c99b..987c9d5f64 100644
--- a/unstructured/ingest/v2/pipeline/steps/uncompress.py
+++ b/unstructured/ingest/v2/pipeline/steps/uncompress.py
@@ -1,5 +1,6 @@
+import asyncio
 from pathlib import Path
-from typing import TypedDict
+from typing import Callable, TypedDict
 
 from unstructured.ingest.v2.interfaces.file_data import FileData
 from unstructured.ingest.v2.logger import logger
@@ -42,13 +43,18 @@ def _run(self, path: str, file_data_path: str) -> list[UncompressStepResponse]:
             )
         return responses
 
-    async def _run_async(self, path: str, file_data_path: str) -> list[UncompressStepResponse]:
+    async def _run_async(
+        self, fn: Callable, path: str, file_data_path: str
+    ) -> list[UncompressStepResponse]:
         file_data = FileData.from_file(path=file_data_path)
-        if semaphore := self.context.semaphore:
+        fn_kwargs = {"file_data": file_data}
+        if not asyncio.iscoroutinefunction(fn):
+            new_file_data = fn(**fn_kwargs)
+        elif semaphore := self.context.semaphore:
             async with semaphore:
-                new_file_data = await self.process.run_async(file_data=file_data)
+                new_file_data = await fn(**fn_kwargs)
         else:
-            new_file_data = await self.process.run_async(file_data=file_data)
+            new_file_data = await fn(**fn_kwargs)
         responses = []
         for new_file in new_file_data:
             new_file_data_path = Path(file_data_path).parent / f"{new_file.identifier}.json"
diff --git a/unstructured/ingest/v2/pipeline/steps/upload.py b/unstructured/ingest/v2/pipeline/steps/upload.py
index dd438bb45a..25540c9524 100644
--- a/unstructured/ingest/v2/pipeline/steps/upload.py
+++ b/unstructured/ingest/v2/pipeline/steps/upload.py
@@ -1,7 +1,7 @@
 import asyncio
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TypedDict
+from typing import Callable, Optional, TypedDict
 
 from unstructured.ingest.v2.interfaces import FileData
 from unstructured.ingest.v2.interfaces.uploader import UploadContent, Uploader
@@ -42,7 +42,7 @@ def __post_init__(self):
         )
 
     def process_whole(self, iterable: iterable_input):
-        self.run(iterable)
+        self.run(contents=iterable)
 
     async def _process_async(self, iterable: iterable_input):
         return await asyncio.gather(*[self.run_async(**i) for i in iterable])
@@ -60,20 +60,20 @@ def __call__(self, iterable: iterable_input):
         else:
             self.process_whole(iterable=iterable)
 
-    def _run(self, contents: list[UploadStepContent]):
+    def _run(self, fn: Callable, contents: list[UploadStepContent]):
         upload_contents = [
             UploadContent(path=Path(c["path"]), file_data=FileData.from_file(c["file_data_path"]))
             for c in contents
         ]
-        self.process.run(contents=upload_contents)
+        fn(contents=upload_contents)
 
-    async def _run_async(self, path: str, file_data_path: str):
-        if semaphore := self.context.semaphore:
-            with semaphore:
-                await self.process.run_async(
-                    path=Path(path), file_data=FileData.from_file(path=file_data_path)
-                )
+    async def _run_async(self, path: str, file_data_path: str, fn: Optional[Callable] = None):
+        fn = fn or self.process.run_async
+        fn_kwargs = {"path": Path(path), "file_data": FileData.from_file(path=file_data_path)}
+        if not asyncio.iscoroutinefunction(fn):
+            fn(**fn_kwargs)
+        elif semaphore := self.context.semaphore:
+            async with semaphore:
+                await fn(**fn_kwargs)
         else:
-            await self.process.run_async(
-                path=Path(path), file_data=FileData.from_file(path=file_data_path)
-            )
+            await fn(**fn_kwargs)
diff --git a/unstructured/utils.py b/unstructured/utils.py
index 84f1c52100..55fecc319c 100644
--- a/unstructured/utils.py
+++ b/unstructured/utils.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 import functools
 import html
 import importlib
@@ -227,8 +228,7 @@ def requires_dependencies(
         dependencies = [dependencies]
 
     def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
-        @wraps(func)
-        def wrapper(*args: _P.args, **kwargs: _P.kwargs):
+        def run_check():
             missing_deps: List[str] = []
             for dep in dependencies:
                 if not dependency_exists(dep):
@@ -242,8 +242,19 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs):
                         else f"Please install them using `pip install {' '.join(missing_deps)}`."
                     ),
                 )
+
+        @wraps(func)
+        def wrapper(*args: _P.args, **kwargs: _P.kwargs):
+            run_check()
             return func(*args, **kwargs)
 
+        @wraps(func)
+        async def wrapper_async(*args: _P.args, **kwargs: _P.kwargs):
+            run_check()
+            return await func(*args, **kwargs)
+
+        if asyncio.iscoroutinefunction(func):
+            return wrapper_async
         return wrapper
 
     return decorator
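For illustration, a sketch of what the `requires_dependencies` fix enables (the `fetch_page` function and `aiohttp` dependency below are hypothetical examples, not part of this patch). Because `wrapper_async` is returned for coroutine functions, the decorated function still passes `asyncio.iscoroutinefunction()`, so the pipeline dispatch above takes the async path:

```python
import asyncio

from unstructured.utils import requires_dependencies


@requires_dependencies(["aiohttp"])
async def fetch_page(url: str) -> str:
    import aiohttp  # deferred import; the decorator has already verified it is installed

    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            return await response.text()


# Before this patch the decorator always returned a plain sync wrapper,
# so this assertion would have failed:
assert asyncio.iscoroutinefunction(fetch_page)
```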