From e7b3d33d9ff7cbb5995fef3d66bdb2279ffa12fb Mon Sep 17 00:00:00 2001 From: John lee Date: Tue, 27 Aug 2024 13:07:11 +0000 Subject: [PATCH] switch stray components to bytes --- osm/pipeline/core.py | 6 +++--- osm/pipeline/parsers.py | 7 +++++-- osm/pipeline/savers.py | 4 ++-- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/osm/pipeline/core.py b/osm/pipeline/core.py index 5fd70f8c..7f2d43d6 100644 --- a/osm/pipeline/core.py +++ b/osm/pipeline/core.py @@ -15,7 +15,7 @@ def __init__(self, version: str = "0.0.1"): self._orm_model = None @abstractmethod - def _run(self, data: bytes, **kwargs) -> Any: + def _run(self, data: bytes|dict, **kwargs) -> Any: """Abstract method that subclasses must implement.""" pass @@ -64,7 +64,7 @@ def __iter__(self): yield self.json_saver yield self.osm_saver - def save_file(self, data: str, path: Path): + def save_file(self, data: bytes, path: Path): self.file_saver.run(data, path=path) def save_json(self, data: dict, path: Path): @@ -102,7 +102,7 @@ def __init__( def run(self): for parser in self.parsers: parsed_data = parser.run(self.file_data) - if isinstance(parsed_data, str): + if isinstance(parsed_data, bytes): self.savers.save_file(parsed_data, self.xml_path) for extractor in self.extractors: extracted_metrics = extractor.run(parsed_data, parser=parser.name) diff --git a/osm/pipeline/parsers.py b/osm/pipeline/parsers.py index fd718604..b06a3977 100644 --- a/osm/pipeline/parsers.py +++ b/osm/pipeline/parsers.py @@ -3,6 +3,7 @@ from osm.schemas.custom_fields import LongBytes from .core import Component +import io SCIENCEBEAM_URL = "http://localhost:8070/api/convert" @@ -30,8 +31,10 @@ class ScienceBeamParser(Component): def _run(self, data: bytes) -> str: self.sample = LongBytes(data) headers = {"Accept": "application/tei+xml", "Content-Type": "application/pdf"} - response = requests.post(SCIENCEBEAM_URL, data=data, headers=headers) + files = {'file': ('input.pdf', io.BytesIO(data), 'application/pdf')} + + response = requests.post(SCIENCEBEAM_URL, files=files, headers=headers) if response.status_code == 200: - return response.text + return response.content else: response.raise_for_status() diff --git a/osm/pipeline/savers.py b/osm/pipeline/savers.py index b19880b0..fc7a3b12 100644 --- a/osm/pipeline/savers.py +++ b/osm/pipeline/savers.py @@ -27,14 +27,14 @@ def format_error_message() -> str: class FileSaver(Component): """Basic saver that writes data to a file.""" - def _run(self, data: str, path: Path): + def _run(self, data: bytes, path: Path): """Write data to a file. Args: data (str): Some data. path (Path): A file path. """ - path.write_text(data) + path.write_bytes(data) logger.info(f"Data saved to {path}")