diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 108b6be4..2acfdfdf 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -31,6 +31,13 @@ jobs:
python-version: ${{ matrix.python-version }}
architecture: x64
+ - name: Cache HuggingFace Models
+ uses: actions/cache@v2
+ id: cache-huggingface
+ with:
+ path: ~/.cache/huggingface/
+ key: ${{ matrix.python-version }}-huggingface
+
- name: Install hatch
run: pip install hatch
diff --git a/.gitignore b/.gitignore
index bc92f9af..60358cc5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -63,7 +63,7 @@ report.xml
*.pickle
*.joblib
*.pdf
-data/
+/data/
# MkDocs output
docs/reference
diff --git a/changelog.md b/changelog.md
index 3583237e..e6a5d2ca 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,5 +1,33 @@
# Changelog
+
+
+## v0.9.0
+
+### Added
+
+- New unified `edspdf.data` API (PDF files, pandas, parquet) and LazyCollection object
+  to efficiently read / write data from / to different formats & sources. This API has
+  been heavily inspired by the `edsnlp.data` API.
+- New unified processing API to select the execution backend via `data.set_processing(...)`
+ to replace the old `accelerators` API (which is now deprecated, but still available).
+- `huggingface-embedding` now supports quantization and other `AutoModel.from_pretrained` kwargs
+- It is now possible to convert a label to multiple labels in the `simple-aggregator` component:
+
+```ini
+# To build the "text" field, we will aggregate "title", "body" and "table" lines,
+# and output "title" lines in a separate field as well.
+label_map = {
+ "text" : [ "title", "body", "table" ],
+ "title": "title",
+ }
+```
+
+### Fixed
+
+- `huggingface-embedding` now resizes bbox features for large PDFs, instead of making the model crash
+- `huggingface-embedding` and `sub-box-cnn-pooler` now handle empty PDFs correctly
+
## v0.8.1
### Fixed
diff --git a/docs/assets/images/multiprocessing.png b/docs/assets/images/multiprocessing.png
new file mode 100644
index 00000000..f6d54762
Binary files /dev/null and b/docs/assets/images/multiprocessing.png differ
diff --git a/docs/assets/images/multiprocessing.svg b/docs/assets/images/multiprocessing.svg
deleted file mode 100644
index 594b0d04..00000000
--- a/docs/assets/images/multiprocessing.svg
+++ /dev/null
@@ -1,3 +0,0 @@
-
-
-
diff --git a/docs/index.md b/docs/index.md
index 52b46d0d..735b6a32 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -99,12 +99,10 @@ See the [rule-based recipe](recipes/rule-based.md) for a step-by-step explanatio
If you use EDS-PDF, please cite us as below.
```bibtex
-@software{edspdf,
- author = {Dura, Basile and Wajsburt, Perceval and Calliger, Alice and GĂ©rardin, Christel and Bey, Romain},
- doi = {10.5281/zenodo.6902977},
- license = {BSD-3-Clause},
- title = {{EDS-PDF: Smart text extraction from PDF documents}},
- url = {https://github.com/aphp/edspdf}
+@article{gerardin_wajsburt_pdf,
+ title={Bridging Clinical PDFs and Downstream Natural Language Processing: An Efficient Neural Approach to Layout Segmentation},
+ author={G{\'e}rardin, Christel Ducroz and Wajsburt, Perceval and Dura, Basile and Calliger, Alice and Mouchet, Alexandre and Tannier, Xavier and Bey, Romain},
+ journal={Available at SSRN 4587624}
}
```
diff --git a/docs/inference.md b/docs/inference.md
index d849ba86..9f7c325b 100644
--- a/docs/inference.md
+++ b/docs/inference.md
@@ -1,61 +1,124 @@
# Inference
-Once you have obtained a pipeline, either by composing rule-based components, training a model or loading a model from the disk, you can use it to make predictions on documents. This is referred to as inference.
+Once you have obtained a pipeline, either by composing rule-based components, training a model or loading a model from the disk, you can use it to make predictions on documents. This is referred to as inference. This page answers the following questions:
+
+> How do we leverage computational resources to run a model on many documents?
+
+> How do we connect to various data sources to retrieve documents?
+
## Inference on a single document
-In EDS-PDF, computing the prediction on a single document is done by calling the pipeline on the document. The input can be either:
+In EDS-PDF, computing the prediction on a single document is done by calling the pipeline on the document. The input can be either:
-- a sequence of bytes
-- or a [PDFDoc][edspdf.structures.PDFDoc] object
+- a bytes string
+- or a [PDFDoc][edspdf.structures.PDFDoc] object
-```python
+```{ .python .no-check }
from pathlib import Path
-pipeline = ...
-content = Path("path/to/.pdf").read_bytes()
-doc = pipeline(content)
+model = ...
+pdf_bytes = b"..."
+doc = model(pdf_bytes)
```
-If you're lucky enough to have a GPU, you can use it to speed up inference by moving the model to the GPU before calling the pipeline. To leverage multiple GPUs, refer to the [multiprocessing accelerator][edspdf.accelerators.multiprocessing.MultiprocessingAccelerator] description below.
+If you're lucky enough to have a GPU, you can use it to speed up inference by moving the model to the GPU before calling the pipeline.
-```python
-pipeline.to("cuda") # same semantics as pytorch
-doc = pipeline(content)
+```{ .python .no-check }
+model.to("cuda") # same semantics as pytorch
+doc = model(pdf_bytes)
```
-## Inference on multiple documents
+To leverage multiple GPUs when processing multiple documents, refer to the [multiprocessing backend][edspdf.processing.multiprocessing.execute_multiprocessing_backend] description below.
+
+## Inference on multiple documents {: #edspdf.lazy_collection.LazyCollection }
+
+When processing multiple documents, we can optimize the inference by batching the computation on a single core, or by parallelizing it over multiple cores and GPUs, or even over multiple machines.
+
+### Lazy collection
+
+These optimizations are enabled by performing *lazy inference*: the operations (e.g., reading a document, converting it to a PDFDoc, running the different pipes of a model or writing the result somewhere) are not executed immediately but are instead scheduled in a [LazyCollection][edspdf.lazy_collection.LazyCollection] object. It can then be executed by calling the `execute` method, iterating over it or calling a writing method (e.g., `to_pandas`). In fact, data connectors like `edspdf.data.read_files`, as well as the `model.pipe` method, return a lazy collection.
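+
+For instance, nothing is computed until the collection is consumed. A minimal sketch,
+assuming a loaded `model` and the `read_files` / `to_pandas` connectors described on
+this page:
+
+```{ .python .no-check }
+data = edspdf.data.read_files("path/to/pdfs", converter=...)
+data = model.pipe(data)  # nothing has been read or processed yet
+
+# Any of the following lines triggers the actual execution
+docs = list(data)  # iterate over the results
+df = data.to_pandas(converter=...)  # or write them somewhere
+data.execute()  # or run the collection explicitly
+```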
+
+A lazy collection contains:
+
+- a `reader`: the source of the data (e.g., a file, a database, a list of strings, etc.)
+- a `pipeline` attribute: the list of operations to perform, each with its name (if any), function / pipe, keyword arguments and context
+- an optional `writer`: the destination of the data (e.g., a file, a database, a list of strings, etc.)
+- the execution `config`, containing the backend to use and its configuration such as the number of workers, the batch size, etc.
+
+All methods (`.map`, `.map_pipeline`, `.set_processing`) of the lazy collection are chainable, meaning that they return a new object (no in-place modification).
+
+For instance, the following code will load a model, read a folder of PDF files, apply the model to each document and write the result in a Parquet folder, using 4 CPUs and 2 GPUs.
-When processing multiple documents, it is usually more efficient to use the `pipeline.pipe(...)` method, especially when using deep learning components, since this allow matrix multiplications to be batched together. Depending on your computational resources and requirements, EDS-PDF comes with various "accelerators" to speed up inference (see the [Accelerators](#accelerators) section for more details). By default, the `.pipe()` method uses the [`simple` accelerator][edspdf.accelerators.simple.SimpleAccelerator] but you can switch to a different one by passing the `accelerator` argument.
+```{ .python .no-check }
+import edspdf
+from edspdf.structures import PDFDoc  # used by the converter below
-```python
-pipeline = ...
-docs = pipeline.pipe(
- [content1, content2, ...],
- batch_size=16, # optional, default to the one defined in the pipeline
- accelerator=my_accelerator,
+# Load or create a model, for instance following the "Recipes"
+model = edspdf.load("path/to/model")
+
+# Read some data (this is lazy, no data will be read until the end of this snippet)
+data = edspdf.data.read_files(
+ "/Users/perceval/Development/edspdf/tests/resources/",
+ # dict to doc converter function
+ converter=lambda x: PDFDoc(id=x["id"], content=x["content"]),
+)
+
+# Apply each pipe of the model to our documents
+data = data.map_pipeline(model)
+# or equivalently: data = model.pipe(data)
+
+# Configure the execution
+data = data.set_processing(
+ # 4 CPUs to parallelize rule-based pipes, IO and preprocessing
+ num_cpu_workers=4,
+ # 2 GPUs to accelerate deep-learning pipes
+ num_gpu_workers=2,
+)
+
+# Write the result, this will execute the lazy collection
+data.write_parquet(
+ "path/to/output_folder",
+ # doc to dict converter function
+ converter=lambda doc: {
+ "id": doc.id,
+ "text": (
+ doc.aggregated_texts["body"].text
+ if "body" in doc.aggregated_texts
+ else ""
+ ),
+ },
)
```
-The `pipe` method supports the following arguments :
+### Applying operations to a lazy collection
+
+To apply an operation to a lazy collection, you can use the `.map` method. It takes a callable as input and an optional dictionary of keyword arguments. The function will be applied to each element of the collection.
+
+To apply a model, you can use the `.map_pipeline` method. It takes a model as input and will add every pipe of the model to the scheduled operations.
+
+In both cases, the operations will not be executed immediately but will be scheduled to be executed when iterating over the collection, or when calling the `.execute`, `.to_*` or `.write_*` methods.
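+
+A minimal sketch of both methods (`my_function` and its `threshold` argument are purely
+illustrative):
+
+```{ .python .no-check }
+def my_function(doc, threshold=0.5):
+    ...  # any processing of the document
+    return doc
+
+
+data = data.map(my_function, kwargs={"threshold": 0.8})  # schedule a custom operation
+data = data.map_pipeline(model)  # schedule every pipe of the model
+```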
+
+### Execution of a lazy collection {: #edspdf.lazy_collection.LazyCollection.set_processing }
+
+You can configure how the operations performed in the lazy collection are executed by calling its `set_processing(...)` method. The following options are available:
-::: edspdf.pipeline.Pipeline.pipe
+::: edspdf.lazy_collection.LazyCollection.set_processing
options:
heading_level: 3
- only_parameters: true
+ only_parameters: "no-header"
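+
+For instance, a sketch of selecting the multiprocessing backend described below
+(assuming the backend is referred to by its name):
+
+```{ .python .no-check }
+data = data.set_processing(
+    backend="multiprocessing",
+    num_cpu_workers=4,
+    num_gpu_workers=2,
+    batch_size=8,
+)
+```
+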
-## Accelerators
+## Backends
-### Simple accelerator {: #edspdf.accelerators.simple.SimpleAccelerator }
+### Simple backend {: #edspdf.processing.simple.execute_simple_backend }
-::: edspdf.accelerators.simple.SimpleAccelerator
+::: edspdf.processing.simple.execute_simple_backend
options:
heading_level: 3
- only_class_level: true
+ show_source: false
-### Multiprocessing accelerator {: #edspdf.accelerators.multiprocessing.MultiprocessingAccelerator }
+### Multiprocessing backend {: #edspdf.processing.multiprocessing.execute_multiprocessing_backend }
-::: edspdf.accelerators.multiprocessing.MultiprocessingAccelerator
+::: edspdf.processing.multiprocessing.execute_multiprocessing_backend
options:
heading_level: 3
- only_class_level: true
+ show_source: false
diff --git a/docs/scripts/plugin.py b/docs/scripts/plugin.py
index 22e7c1bd..b020a225 100644
--- a/docs/scripts/plugin.py
+++ b/docs/scripts/plugin.py
@@ -7,6 +7,8 @@
import mkdocs.structure.files
import mkdocs.structure.nav
import mkdocs.structure.pages
+import regex
+from mkdocs_autorefs.plugin import AutorefsPlugin
try:
from importlib.metadata import entry_points
@@ -128,9 +130,13 @@ def on_page_read_source(page, config):
return None
-HREF_REGEX = r'href=(?:"([^"]*)"|\'([^\']*)|[ ]*([^ =>]*)(?![a-z]+=))'
+HREF_REGEX = (
+ r"(?<=<\s*(?:a[^>]*href|img[^>]*src)=)"
+ r'(?:"([^"]*)"|\'([^\']*)|[ ]*([^ =>]*)(?![a-z]+=))'
+)
+
+
# Maybe find something less specific ?
-PIPE_REGEX = r"(?<=[^a-zA-Z0-9._-])eds[.][a-zA-Z0-9._-]*(?=[^a-zA-Z0-9._-])"
@mkdocs.plugins.event_priority(-1000)
@@ -155,100 +161,57 @@ def on_post_page(
"""
- autorefs = config["plugins"]["autorefs"]
- edspdf_factories_entry_points = {
- ep.name: ep.value for ep in entry_points()["edspdf_factories"]
+ autorefs: AutorefsPlugin = config["plugins"]["autorefs"]
+ factories_entry_points = {
+ ep.name: autorefs.get_item_url(ep.value.replace(":", "."))
+ for ep in entry_points()["edspdf_factories"]
+ }
+ factories_entry_points = {
+ k: "/" + v if not v.startswith("/") else v
+ for k, v in factories_entry_points.items()
}
+ factories_entry_points.update(
+ {
+ "mupdf-extractor": "https://aphp.github.io/edspdf-mupdf/latest/",
+ "poppler-extractor": "https://aphp.github.io/edspdf-poppler/latest/",
+ }
+ )
- def get_component_url(name):
- ep = edspdf_factories_entry_points.get(name)
- if ep is None:
- return None
- try:
- url = autorefs.get_item_url(ep.replace(":", "."))
- except KeyError:
- pass
- else:
- return url
- return None
+ PIPE_REGEX_BASE = "|".join(regex.escape(name) for name in factories_entry_points)
+ PIPE_REGEX = f"""(?x)
+ (?<=")({PIPE_REGEX_BASE})(?=")
+ |(?<=")({PIPE_REGEX_BASE})(?=")
+ |(?<=')({PIPE_REGEX_BASE})(?=')
+ |(?<=)({PIPE_REGEX_BASE})(?=
)
+ """
- def get_relative_link(url):
+ def replace_component(match):
+ name = match.group()
+ preceding = output[match.start(0) - 50 : match.start(0)]
+ if (
+ "DEFAULT:"
+ not in preceding
+            # and output[: match.start(0)].count("<code>")
+            # > output[match.end(0) :].count("</code>")
+ ):
+ try:
+ ep_url = factories_entry_points[name]
+ except KeyError:
+ pass
+ else:
+ if ep_url.split("#")[0].strip("/") != page.file.url.strip("/"):
+ return "{name}".format(href=ep_url, name=name)
+ return name
+
+ def replace_link(match):
+ relative_url = url = match.group(1) or match.group(2) or match.group(3)
page_url = os.path.join("/", page.file.url)
if url.startswith("/"):
- url = os.path.relpath(url, page_url)
- return url
-
- def replace_component_span(span):
- content = span.text
- if content is None:
- return
- link_url = get_component_url(content.strip("\"'"))
- if link_url is None:
- return
- a = etree.Element("a", href="/" + link_url)
- a.text = content
- span.text = ""
- span.append(a)
-
- def replace_component_names(root):
- # Iterate through all span elements
- spans = list(root.iter("span", "code"))
- for i, span in enumerate(spans):
- prev = span.getprevious()
- if span.getparent().tag == "a":
- continue
- # To avoid replacing default component name in parameter tables
- if prev is None or prev.text != "DEFAULT:":
- replace_component_span(span)
- # if span.text == "add_pipe":
- # next_span = span.getnext()
- # if next_span is None:
- # continue
- # next_span = next_span.getnext()
- # if next_span is None or next_span.tag != "span":
- # continue
- # replace_component_span(next_span)
- # continue
- # tokens = ["@", "factory", "="]
- # while True:
- # if len(tokens) == 0:
- # break
- # if span.text != tokens[0]:
- # break
- # tokens = tokens[1:]
- # span = span.getnext()
- # while span is not None and (
- # span.text is None or not span.text.strip()
- # ):
- # span = span.getnext()
- # if len(tokens) == 0:
- # replace_component_span(span)
-
- # Convert the modified tree back to a string
- return root
-
- def replace_absolute_links(root):
- # Iterate through all a elements
- for a in root.iter("a"):
- href = a.get("href")
- if href is None or href.startswith("http"):
- continue
- a.set("href", get_relative_link(href))
- for img in root.iter("img"):
- href = img.get("src")
- if href is None or href.startswith("http"):
- continue
- img.set("src", get_relative_link(href))
-
- # Convert the modified tree back to a string
- return root
+ relative_url = os.path.relpath(url, page_url)
+ return f'"{relative_url}"'
# Replace absolute paths with path relative to the rendered page
- from lxml.html import etree
-
- root = etree.HTML(output)
- root = replace_component_names(root)
- root = replace_absolute_links(root)
- doctype = root.getroottree().docinfo.doctype
- res = etree.tostring(root, encoding="unicode", method="html", doctype=doctype)
- return res
+ output = regex.sub(PIPE_REGEX, replace_component, output)
+ output = regex.sub(HREF_REGEX, replace_link, output)
+
+ return output
diff --git a/docs/trainable-pipes.md b/docs/trainable-pipes.md
index b9dc02c8..ae6f37b0 100644
--- a/docs/trainable-pipes.md
+++ b/docs/trainable-pipes.md
@@ -97,12 +97,12 @@ class MyComponent(TrainablePipe):
"my-feature": ...(doc),
}
- def collate(self, batch, device: torch.device) -> Dict:
+ def collate(self, batch) -> Dict:
# Collate the features of the "embedding" subcomponent
# and the features of this component as well
return {
- "embedding": self.embedding.collate(batch["embedding"], device),
- "my-feature": torch.as_tensor(batch["my-feature"], device=device),
+ "embedding": self.embedding.collate(batch["embedding"]),
+ "my-feature": torch.as_tensor(batch["my-feature"]),
}
def forward(self, batch: Dict, supervision=False) -> Dict:
diff --git a/edspdf/__init__.py b/edspdf/__init__.py
index 9d5ed750..1ab0d6f1 100644
--- a/edspdf/__init__.py
+++ b/edspdf/__init__.py
@@ -3,7 +3,8 @@
from .pipeline import Pipeline, load
from .registry import registry
from .structures import Box, Page, PDFDoc, Text, TextBox, TextProperties
+from . import data
from . import utils # isort:skip
-__version__ = "0.8.1"
+__version__ = "0.9.0"
diff --git a/edspdf/accelerators/base.py b/edspdf/accelerators/base.py
index 07bc01b2..2360377d 100644
--- a/edspdf/accelerators/base.py
+++ b/edspdf/accelerators/base.py
@@ -1,97 +1,2 @@
-from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Union
-
-from ..structures import PDFDoc
-
-
-class FromDictFieldsToDoc:
- def __init__(self, content_field, id_field=None):
- self.content_field = content_field
- self.id_field = id_field
-
- def __call__(self, item):
- if isinstance(item, dict):
- return PDFDoc(
- content=item[self.content_field],
- id=item[self.id_field] if self.id_field else None,
- )
- return item
-
-
-class ToDoc:
- @classmethod
- def __get_validators__(cls):
- yield cls.validate
-
- @classmethod
- def validate(cls, value, config=None):
- if isinstance(value, str):
- value = {"content_field": value}
- if isinstance(value, dict):
- value = FromDictFieldsToDoc(**value)
- if callable(value):
- return value
- raise TypeError(
- f"Invalid entry {value} ({type(value)}) for ToDoc, "
- f"expected string, a dict or a callable."
- )
-
-
-FROM_DOC_TO_DICT_FIELDS_TEMPLATE = """
-def fn(doc):
- return {X}
-"""
-
-
-class FromDocToDictFields:
- def __init__(self, mapping):
- self.mapping = mapping
- dict_fields = ", ".join(f"{repr(k)}: doc.{v}" for k, v in mapping.items())
- local_vars = {}
- exec(FROM_DOC_TO_DICT_FIELDS_TEMPLATE.replace("X", dict_fields), local_vars)
- self.fn = local_vars["fn"]
-
- def __reduce__(self):
- return FromDocToDictFields, (self.mapping,)
-
- def __call__(self, doc):
- return self.fn(doc)
-
-
-class FromDoc:
- """
- A FromDoc converter (from a PDFDoc to an arbitrary type) can be either:
-
- - a dict mapping field names to doc attributes
- - a callable that takes a PDFDoc and returns an arbitrary type
- """
-
- @classmethod
- def __get_validators__(cls):
- yield cls.validate
-
- @classmethod
- def validate(cls, value, config=None):
- if isinstance(value, dict):
- value = FromDocToDictFields(value)
- if callable(value):
- return value
- raise TypeError(
- f"Invalid entry {value} ({type(value)}) for ToDoc, "
- f"expected dict or callable"
- )
-
-
class Accelerator:
- def __call__(
- self,
- inputs: Iterable[Any],
- model: Any,
- to_doc: ToDoc = FromDictFieldsToDoc("content"),
- from_doc: FromDoc = lambda doc: doc,
- ):
- raise NotImplementedError()
-
-
-if TYPE_CHECKING:
- ToDoc = Union[str, Dict[str, Any], Callable[[Any], PDFDoc]] # noqa: F811
- FromDoc = Union[Dict[str, Any], Callable[[PDFDoc], Any]] # noqa: F811
+ pass
diff --git a/edspdf/accelerators/multiprocessing.py b/edspdf/accelerators/multiprocessing.py
index 41ef2f80..b11d556e 100644
--- a/edspdf/accelerators/multiprocessing.py
+++ b/edspdf/accelerators/multiprocessing.py
@@ -1,338 +1,15 @@
-import gc
-import signal
-from multiprocessing.connection import wait
-from random import shuffle
-from typing import Any, Iterable, List, Optional, Union
+from typing import List, Optional, Union
import torch
-import torch.multiprocessing as mp
-from .. import TrainablePipe
from ..registry import registry
-from ..utils.collections import batchify
-from .base import Accelerator, FromDictFieldsToDoc, FromDoc, ToDoc
-
-DEBUG = False
-
-debug = (
- (lambda *args, flush=False, **kwargs: print(*args, **kwargs, flush=True))
- if DEBUG
- else lambda *args, **kwargs: None
-)
-
-
-class Exchanger:
- def __init__(
- self,
- num_stages,
- num_gpu_workers,
- num_cpu_workers,
- gpu_worker_devices,
- ):
- # queue for cpu input tasks
- self.gpu_worker_devices = gpu_worker_devices
- # We add prioritized queue at the end for STOP signals
- self.cpu_inputs_queues = [
- [mp.SimpleQueue()] + [mp.SimpleQueue() for _ in range(num_stages + 1)]
- # The input queue is not shared between processes, since calling `wait`
- # on a queue reader from multiple processes may lead to a deadlock
- for _ in range(num_cpu_workers)
- ]
- self.gpu_inputs_queues = [
- [mp.SimpleQueue() for _ in range(num_stages + 1)]
- for _ in range(num_gpu_workers)
- ]
- self.outputs_queue = mp.Queue()
-
- def get_cpu_tasks(self, idx):
- while True:
- queue_readers = wait(
- [queue._reader for queue in self.cpu_inputs_queues[idx]]
- )
- stage, queue = next(
- (stage, q)
- for stage, q in reversed(list(enumerate(self.cpu_inputs_queues[idx])))
- if q._reader in queue_readers
- )
- try:
- item = queue.get()
- except BaseException:
- continue
- if item is None:
- return
- yield stage, item
-
- def put_cpu(self, item, stage, idx):
- return self.cpu_inputs_queues[idx][stage].put(item)
-
- def get_gpu_tasks(self, idx):
- while True:
- queue_readers = wait(
- [queue._reader for queue in self.gpu_inputs_queues[idx]]
- )
- stage, queue = next(
- (stage, q)
- for stage, q in reversed(list(enumerate(self.gpu_inputs_queues[idx])))
- if q._reader in queue_readers
- )
- try:
- item = queue.get()
- except BaseException: # pragma: no cover
- continue
- if item is None:
- return
- yield stage, item
-
- def put_gpu(self, item, stage, idx):
- return self.gpu_inputs_queues[idx][stage].put(item)
-
- def put_results(self, items):
- self.outputs_queue.put(items)
-
- def iter_results(self):
- for out in iter(self.outputs_queue.get, None):
- yield out
-
-
-class CPUWorker(mp.Process):
- def __init__(
- self,
- cpu_idx: int,
- exchanger: Exchanger,
- gpu_pipe_names: List[str],
- model: Any,
- device: Union[str, torch.device],
- ):
- super(CPUWorker, self).__init__()
-
- self.cpu_idx = cpu_idx
- self.exchanger = exchanger
- self.gpu_pipe_names = gpu_pipe_names
- self.model = model
- self.device = device
-
- def _run(self):
- # Cannot pass torch tensor during init i think ? otherwise i get
- # ValueError: bad value(s) in fds_to_keep
- mp._prctl_pr_set_pdeathsig(signal.SIGINT)
-
- model = self.model.to(self.device)
- stages = [{"cpu_components": [], "gpu_component": None}]
- for name, component in model.pipeline:
- if name in self.gpu_pipe_names:
- stages[-1]["gpu_component"] = component
- stages.append({"cpu_components": [], "gpu_component": None})
- else:
- stages[-1]["cpu_components"].append(component)
-
- next_batch_id = 0
- active_batches = {}
- debug(
- f"CPU worker {self.cpu_idx} is ready",
- next(model.parameters()).device,
- flush=True,
- )
-
- had_error = False
- with torch.no_grad():
- for stage, task in self.exchanger.get_cpu_tasks(self.cpu_idx):
- if had_error:
- continue # pragma: no cover
- try:
- if stage == 0:
- gpu_idx = None
- batch_id = next_batch_id
- debug("preprocess start for", batch_id)
- next_batch_id += 1
- docs = task
- else:
- gpu_idx, batch_id, result = task
- debug("postprocess start for", batch_id)
- docs = active_batches.pop(batch_id)
- gpu_pipe = stages[stage - 1]["gpu_component"]
- docs = gpu_pipe.postprocess(docs, result) # type: ignore
-
- for component in stages[stage]["cpu_components"]:
- if hasattr(component, "batch_process"):
- docs = component.batch_process(docs)
- else:
- docs = [component(doc) for doc in docs]
-
- gpu_pipe = stages[stage]["gpu_component"]
- if gpu_pipe is not None:
- preprocessed = gpu_pipe.make_batch(docs) # type: ignore
- active_batches[batch_id] = docs
- if gpu_idx is None:
- gpu_idx = batch_id % len(self.exchanger.gpu_worker_devices)
- collated = gpu_pipe.collate( # type: ignore
- preprocessed,
- device=self.exchanger.gpu_worker_devices[gpu_idx],
- )
- self.exchanger.put_gpu(
- item=(self.cpu_idx, batch_id, collated),
- idx=gpu_idx,
- stage=stage,
- )
- batch_id += 1
- debug("preprocess end for", batch_id)
- else:
- self.exchanger.put_results((docs, self.cpu_idx, gpu_idx))
- debug("postprocess end for", batch_id)
- except BaseException as e:
- had_error = True
- import traceback
-
- print(traceback.format_exc(), flush=True)
- self.exchanger.put_results((e, self.cpu_idx, None))
- # We need to drain the queues of GPUWorker fed inputs (pre-moved to GPU)
- # to ensure no tensor allocated on producer processes (CPUWorker via
- # collate) are left in consumer processes
- debug("Start draining CPU worker", self.cpu_idx)
- [None for _ in self.exchanger.get_cpu_tasks(self.cpu_idx)]
- debug(f"CPU worker {self.cpu_idx} is about to stop")
-
- def run(self):
- self._run()
- self.model = None
- gc.collect()
- torch.cuda.empty_cache()
-
-
-class GPUWorker(mp.Process):
- def __init__(
- self,
- gpu_idx,
- exchanger: Exchanger,
- gpu_pipe_names: List[str],
- model: Any,
- device: Union[str, torch.device],
- ):
- super().__init__()
-
- self.device = device
- self.gpu_idx = gpu_idx
- self.exchanger = exchanger
-
- self.gpu_pipe_names = gpu_pipe_names
- self.model = model
- self.device = device
-
- def _run(self):
- debug("GPU worker", self.gpu_idx, "started")
- mp._prctl_pr_set_pdeathsig(signal.SIGINT)
- had_error = False
-
- model = self.model.to(self.device)
- stage_components = [model.get_pipe(name) for name in self.gpu_pipe_names]
- del model
- with torch.no_grad():
- for stage, task in self.exchanger.get_gpu_tasks(self.gpu_idx):
- if had_error:
- continue # pragma: no cover
- try:
- cpu_idx, batch_id, batch = task
- debug("forward start for", batch_id)
- component = stage_components[stage]
- res = component.module_forward(batch)
- del batch, task
- # TODO set non_blocking=True here
- res = {
- key: val.to("cpu") if not isinstance(val, int) else val
- for key, val in res.items()
- }
- self.exchanger.put_cpu(
- item=(self.gpu_idx, batch_id, res),
- stage=stage + 1,
- idx=cpu_idx,
- )
- debug("forward end for", batch_id)
- except BaseException as e:
- had_error = True
- self.exchanger.put_results((e, None, self.gpu_idx))
- import traceback
-
- print(traceback.format_exc(), flush=True)
- task = batch = res = None # noqa
- # We need to drain the queues of CPUWorker fed inputs (pre-moved to GPU)
- # to ensure no tensor allocated on producer processes (CPUWorker via
- # collate) are left in consumer processes
- debug("Start draining GPU worker", self.gpu_idx)
- [None for _ in self.exchanger.get_gpu_tasks(self.gpu_idx)]
- debug(f"GPU worker {self.gpu_idx} is about to stop")
-
- def run(self):
- self._run()
- self.model = None
- gc.collect()
- torch.cuda.empty_cache()
-
-
-DEFAULT_MAX_CPU_WORKERS = 4
+from .base import Accelerator
@registry.accelerator.register("multiprocessing")
class MultiprocessingAccelerator(Accelerator):
"""
- If you have multiple CPU cores, and optionally multiple GPUs, we provide a
- `multiprocessing` accelerator that allows to run the inference on multiple
- processes.
-
- This accelerator dispatches the batches between multiple workers
- (data-parallelism), and distribute the computation of a given batch on one or two
- workers (model-parallelism). This is done by creating two types of workers:
-
- - a `CPUWorker` which handles the non deep-learning components and the
- preprocessing, collating and postprocessing of deep-learning components
- - a `GPUWorker` which handles the forward call of the deep-learning components
-
- The advantage of dedicating a worker to the deep-learning components is that it
- allows to prepare multiple batches in parallel in multiple `CPUWorker`, and ensure
- that the `GPUWorker` never wait for a batch to be ready.
-
- The overall architecture described in the following figure, for 3 CPU workers and 2
- GPU workers.
-
-
-
-
-
- Here is how a small pipeline with rule-based components and deep-learning components
- is distributed between the workers:
-
-
-
-
-
- Examples
- --------
-
- ```python
- docs = list(
- pipeline.pipe(
- [content1, content2, ...],
- accelerator={
- "@accelerator": "multiprocessing",
- "num_cpu_workers": 3,
- "num_gpu_workers": 2,
- "batch_size": 8,
- },
- )
- )
- ```
-
- Parameters
- ----------
- batch_size: int
- Number of documents to process at a time in a CPU/GPU worker
- num_cpu_workers: int
- Number of CPU workers. A CPU worker handles the non deep-learning components
- and the preprocessing, collating and postprocessing of deep-learning components.
- num_gpu_workers: Optional[int]
- Number of GPU workers. A GPU worker handles the forward call of the
- deep-learning components.
- gpu_pipe_names: Optional[List[str]]
- List of pipe names to accelerate on a GPUWorker, defaults to all pipes
- that inherit from TrainablePipe
+ Deprecated: Use `docs.map_pipeline(model).set_processing(...)` instead
"""
def __init__(
@@ -350,196 +27,3 @@ def __init__(
self.gpu_pipe_names = gpu_pipe_names
self.gpu_worker_devices = gpu_worker_devices
self.cpu_worker_devices = cpu_worker_devices
-
- def __call__(
- self,
- inputs: Iterable[Any],
- model: Any,
- to_doc: ToDoc = FromDictFieldsToDoc("content"),
- from_doc: FromDoc = lambda doc: doc,
- ):
- """
- Stream of documents to process. Each document can be a string or a tuple
-
- Parameters
- ----------
- inputs
- model
-
- Yields
- ------
- Any
- Processed outputs of the pipeline
- """
- if torch.multiprocessing.get_start_method() != "spawn":
- torch.multiprocessing.set_start_method("spawn", force=True)
-
- gpu_pipe_names = (
- [
- name
- for name, component in model.pipeline
- if isinstance(component, TrainablePipe)
- ]
- if self.gpu_pipe_names is None
- else self.gpu_pipe_names
- )
-
- if not all(model.has_pipe(name) for name in gpu_pipe_names):
- raise ValueError(
- "GPU accelerated pipes {} could not be found in the model".format(
- sorted(set(model.pipe_names) - set(gpu_pipe_names))
- )
- )
-
- num_devices = torch.cuda.device_count()
- print(f"Number of available devices: {num_devices}", flush=True)
-
- num_cpu_workers = self.num_cpu_workers
- num_gpu_workers = self.num_gpu_workers
-
- if num_gpu_workers is None:
- num_gpu_workers = num_devices if len(gpu_pipe_names) > 0 else 0
-
- if num_cpu_workers is None:
- num_cpu_workers = max(
- min(mp.cpu_count() - num_gpu_workers, DEFAULT_MAX_CPU_WORKERS), 0
- )
-
- if num_gpu_workers == 0:
- gpu_pipe_names = []
-
- gpu_worker_devices = (
- [
- torch.device(f"cuda:{gpu_idx * num_devices // num_gpu_workers}")
- for gpu_idx in range(num_gpu_workers)
- ]
- if self.gpu_worker_devices is None
- else self.gpu_worker_devices
- )
- cpu_worker_devices = (
- ["cpu"] * num_cpu_workers
- if self.cpu_worker_devices is None
- else self.cpu_worker_devices
- )
- assert len(cpu_worker_devices) == num_cpu_workers
- assert len(gpu_worker_devices) == num_gpu_workers
- if num_cpu_workers == 0:
- (
- num_cpu_workers,
- num_gpu_workers,
- cpu_worker_devices,
- gpu_worker_devices,
- gpu_pipe_names,
- ) = (num_gpu_workers, 0, gpu_worker_devices, [], [])
-
- debug(f"Number of CPU workers: {num_cpu_workers}")
- debug(f"Number of GPU workers: {num_gpu_workers}")
-
- exchanger = Exchanger(
- num_stages=len(gpu_pipe_names),
- num_cpu_workers=num_cpu_workers,
- num_gpu_workers=num_gpu_workers,
- gpu_worker_devices=gpu_worker_devices,
- )
-
- cpu_workers = []
- gpu_workers = []
- model = model.to("cpu")
-
- for gpu_idx in range(num_gpu_workers):
- gpu_workers.append(
- GPUWorker(
- gpu_idx=gpu_idx,
- exchanger=exchanger,
- gpu_pipe_names=gpu_pipe_names,
- model=model,
- device=gpu_worker_devices[gpu_idx],
- )
- )
-
- for cpu_idx in range(num_cpu_workers):
- cpu_workers.append(
- CPUWorker(
- cpu_idx=cpu_idx,
- exchanger=exchanger,
- gpu_pipe_names=gpu_pipe_names,
- model=model,
- device=cpu_worker_devices[cpu_idx],
- )
- )
-
- for worker in (*cpu_workers, *gpu_workers):
- worker.start()
-
- try:
- num_max_enqueued = num_cpu_workers * 2 + 10
- # Number of input/output batch per process
- total_inputs = [0] * num_cpu_workers
- total_outputs = [0] * num_cpu_workers
- outputs_iterator = exchanger.iter_results()
-
- cpu_worker_indices = list(range(num_cpu_workers))
- inputs_iterator = (to_doc(i) for i in inputs)
- for i, pdfs_batch in enumerate(batchify(inputs_iterator, self.batch_size)):
- if sum(total_inputs) - sum(total_outputs) >= num_max_enqueued:
- outputs, cpu_idx, gpu_idx = next(outputs_iterator)
- if isinstance(outputs, BaseException):
- raise outputs # pragma: no cover
- yield from (from_doc(o) for o in outputs)
- total_outputs[cpu_idx] += 1
-
- # Shuffle to ensure the first process does not receive all the documents
- # in case of total_inputs - total_outputs equality
- shuffle(cpu_worker_indices)
- cpu_idx = min(
- cpu_worker_indices,
- key=lambda i: total_inputs[i] - total_outputs[i],
- )
- exchanger.put_cpu(pdfs_batch, stage=0, idx=cpu_idx)
- total_inputs[cpu_idx] += 1
-
- while sum(total_outputs) < sum(total_inputs):
- outputs, cpu_idx, gpu_idx = next(outputs_iterator)
- if isinstance(outputs, BaseException):
- raise outputs # pragma: no cover
- yield from (from_doc(o) for o in outputs)
- total_outputs[cpu_idx] += 1
- finally:
- # Send gpu and cpu process the order to stop processing data
- # We use the prioritized queue to ensure the stop signal is processed
- # before the next batch of data
- for i, worker in enumerate(gpu_workers):
- exchanger.gpu_inputs_queues[i][-1].put(None)
- debug("Asked gpu worker", i, "to stop processing data")
- for i, worker in enumerate(cpu_workers):
- exchanger.cpu_inputs_queues[i][-1].put(None)
- debug("Asked cpu worker", i, "to stop processing data")
-
- # Enqueue a final non prioritized STOP signal to ensure there remains no
- # data in the queues (cf drain loop in CPUWorker / GPUWorker)
- for i, worker in enumerate(gpu_workers):
- exchanger.gpu_inputs_queues[i][0].put(None)
- debug("Asked gpu", i, "to end")
- for i, worker in enumerate(gpu_workers):
- worker.join(timeout=5)
- debug("Joined gpu worker", i)
- for i, worker in enumerate(cpu_workers):
- exchanger.cpu_inputs_queues[i][0].put(None)
- debug("Asked cpu", i, "to end")
- for i, worker in enumerate(cpu_workers):
- worker.join(timeout=1)
- debug("Joined cpu worker", i)
-
- # If a worker is still alive, kill it
- # This should not happen, but for a reason I cannot explain, it does in
- # some CPU workers sometimes when we catch an error, even though each run
- # method of the workers completes cleanly. Maybe this has something to do
- # with the cleanup of these processes ?
- for i, worker in enumerate(gpu_workers): # pragma: no cover
- if worker.is_alive():
- print("Killing gpu worker", i)
- worker.kill()
- for i, worker in enumerate(cpu_workers): # pragma: no cover
- if worker.is_alive():
- print("Killing cpu worker", i)
- worker.kill()
diff --git a/edspdf/accelerators/simple.py b/edspdf/accelerators/simple.py
deleted file mode 100644
index 340dbebb..00000000
--- a/edspdf/accelerators/simple.py
+++ /dev/null
@@ -1,92 +0,0 @@
-from typing import Any, Dict, Iterable
-
-import torch
-
-from ..registry import registry
-from ..utils.collections import batchify
-from .base import Accelerator, FromDictFieldsToDoc, FromDoc, ToDoc
-
-
-@registry.accelerator.register("simple")
-class SimpleAccelerator(Accelerator):
- """
- This is the simplest accelerator which batches the documents and process each batch
- on the main process (the one calling `.pipe()`).
-
- Examples
- --------
-
- ```python
- docs = list(pipeline.pipe([content1, content2, ...]))
- ```
-
- or, if you want to override the model defined batch size
-
- ```python
- docs = list(pipeline.pipe([content1, content2, ...], batch_size=8))
- ```
-
- which is equivalent to passing a confit dict
-
- ```python
- docs = list(
- pipeline.pipe(
- [content1, content2, ...],
- accelerator={
- "@accelerator": "simple",
- "batch_size": 8,
- },
- )
- )
- ```
-
- or the instantiated accelerator directly
-
- ```python
- from edspdf.accelerators.simple import SimpleAccelerator
-
- accelerator = SimpleAccelerator(batch_size=8)
- docs = list(pipeline.pipe([content1, content2, ...], accelerator=accelerator))
- ```
-
- If you have a GPU, make sure to move the model to the appropriate device before
- calling `.pipe()`. If you have multiple GPUs, use the
- [multiprocessing][edspdf.accelerators.multiprocessing.MultiprocessingAccelerator]
- accelerator instead.
-
- ```python
- pipeline.to("cuda")
- docs = list(pipeline.pipe([content1, content2, ...]))
- ```
-
- Parameters
- ----------
- batch_size: int
- The number of documents to process in each batch.
- """
-
- def __init__(
- self,
- *,
- batch_size: int = 32,
- ):
- self.batch_size = batch_size
-
- def __call__(
- self,
- inputs: Iterable[Any],
- model: Any,
- to_doc: ToDoc = FromDictFieldsToDoc("content"),
- from_doc: FromDoc = lambda doc: doc,
- component_cfg: Dict[str, Dict[str, Any]] = None,
- ):
- docs = (to_doc(doc) for doc in inputs)
- for batch in batchify(docs, batch_size=self.batch_size):
- with torch.no_grad(), model.cache(), model.train(False):
- for name, pipe in model.pipeline:
- if name not in model._disabled:
- if hasattr(pipe, "batch_process"):
- batch = pipe.batch_process(batch)
- else:
- batch = [pipe(doc) for doc in batch] # type: ignore
- yield from (from_doc(doc) for doc in batch)
diff --git a/edspdf/data/__init__.py b/edspdf/data/__init__.py
new file mode 100644
index 00000000..14fc755e
--- /dev/null
+++ b/edspdf/data/__init__.py
@@ -0,0 +1,11 @@
+from typing import TYPE_CHECKING
+from edspdf.utils.lazy_module import lazify
+
+lazify()
+
+if TYPE_CHECKING:
+ from .base import from_iterable, to_iterable
+ from .files import read_files, write_files
+ from .parquet import read_parquet, write_parquet
+ from .pandas import from_pandas, to_pandas
+ from .converters import get_dict2doc_converter, get_doc2dict_converter
diff --git a/edspdf/data/base.py b/edspdf/data/base.py
new file mode 100644
index 00000000..cbe5917e
--- /dev/null
+++ b/edspdf/data/base.py
@@ -0,0 +1,180 @@
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ Optional,
+ Sequence,
+ Tuple,
+ TypeVar,
+ Union,
+)
+
+from edspdf.lazy_collection import LazyCollection
+
+from .converters import get_dict2doc_converter, get_doc2dict_converter
+
+
+class BaseReader:
+ """
+    The BaseReader serves as a base class for all readers. It expects two methods:
+
+ - `read_main` method which is called in the main process and should return a
+ generator of fragments (like filenames) with their estimated size (number of
+ documents)
+ - `read_worker` method which is called in the worker processes and receives
+ batches of fragments and should return a list of dictionaries (one per
+ document), ready to be converted to a Doc object by the converter.
+
+ Additionally, the subclass should define a `DATA_FIELDS` class attribute which
+ contains the names of all attributes that should not be copied when the reader is
+ copied to the worker processes. This is useful for example when the reader holds a
+ reference to a large object like a DataFrame that should not be copied to the
+ worker processes.
+ """
+
+ DATA_FIELDS = ()
+
+ def read_main(self) -> Iterable[Tuple[Any, int]]:
+ raise NotImplementedError()
+
+ def read_worker(self, fragment: Iterable[Any]) -> Iterable[Dict]:
+ raise NotImplementedError()
+
+ def worker_copy(self):
+ # new reader without data, this will not call __init__ since we use __dict__
+ # to set the data
+ reader = self.__class__.__new__(self.__class__)
+ state = {
+ k: v
+ for k, v in self.__dict__.items()
+ if k not in self.__class__.DATA_FIELDS
+ }
+ reader.__dict__ = state
+ return reader
+
+
+T = TypeVar("T")
+
+
+class BaseWriter:
+ def write_worker(self, records: Sequence[Any]) -> T:
+ raise NotImplementedError()
+
+ def write_main(self, fragments: Iterable[T]):
+ raise NotImplementedError()
+
+ def finalize(self):
+ return None, 0
+
+
+class IterableReader(BaseReader):
+ DATA_FIELDS = ("data",)
+
+ def __init__(self, data: Iterable):
+ self.data = data
+
+ super().__init__()
+
+ def read_main(self) -> Iterable[Tuple[Any, int]]:
+ return ((item, 1) for item in self.data)
+
+ def read_worker(self, fragments):
+ return [task for task in fragments]
+
+
+def from_iterable(
+ data: Iterable,
+ converter: Union[str, Callable] = None,
+ **kwargs,
+) -> LazyCollection:
+ """
+    The IterableReader (or `edspdf.data.from_iterable`) reads a list of Python objects
+    (texts, dictionaries, ...) and yields documents by passing them through the
+    `converter` if given, or returns them as is.
+
+ Example
+ -------
+ ```{ .python .no-check }
+
+    import edspdf
+
+    model = edspdf.load("path/to/model")
+    doc_iterator = edspdf.data.from_iterable([{...}], converter=...)
+    annotated_docs = model.pipe(doc_iterator)
+ ```
+
+ !!! note "Generator vs list"
+
+        `edspdf.data.from_iterable` returns a
+ [LazyCollection][edspdf.lazy_collection.LazyCollection].
+ To iterate over the documents multiple times efficiently or to access them by
+ index, you must convert it to a list
+
+ ```{ .python .no-check }
+    docs = list(edspdf.data.from_iterable([{...}], converter=...))
+ ```
+
+ Parameters
+ ----------
+ data: Iterable
+ The data to read
+ converter: Optional[Union[str, Callable]]
+        Converter to use to convert the items of the data source to Doc objects
+ kwargs:
+ Additional keyword arguments to pass to the converter. These are documented
+ on the [Data schemas](/data/schemas) page.
+
+ Returns
+ -------
+ LazyCollection
+ """
+ data = LazyCollection(reader=IterableReader(data))
+ if converter:
+ converter, kwargs = get_dict2doc_converter(converter, kwargs)
+ data = data.map(converter, kwargs=kwargs)
+ return data
+
+
+def to_iterable(
+ data: Union[Any, LazyCollection],
+ converter: Optional[Union[str, Callable]] = None,
+ **kwargs,
+):
+ """
+    `edspdf.data.to_iterable` returns an iterator of documents, as converted by the
+ `converter`. In comparison to just iterating over a LazyCollection, this will
+ also apply the `converter` to the documents, which can lower the data transfer
+ overhead when using multiprocessing.
+
+ Example
+ -------
+ ```{ .python .no-check }
+
+    import edspdf
+
+    model = edspdf.load("path/to/model")
+
+    doc = model(pdf_bytes)  # raw bytes of a PDF file
+
+    edspdf.data.to_iterable([doc], converter=...)
+ ```
+
+ Parameters
+ ----------
+ data: Union[Any, LazyCollection],
+ The data to write (either a list of documents or a LazyCollection).
+ converter: Optional[Union[str, Callable]]
+ Converter to use to convert the documents to dictionary objects.
+ kwargs:
+ Additional keyword arguments passed to the converter. These are documented
+ on the [Data schemas](/data/schemas) page.
+ """
+ data = LazyCollection.ensure_lazy(data)
+ if converter:
+ converter, kwargs = get_doc2dict_converter(converter, kwargs)
+ data = data.map(converter, kwargs=kwargs)
+
+ return data
diff --git a/edspdf/data/converters.py b/edspdf/data/converters.py
new file mode 100644
index 00000000..397be290
--- /dev/null
+++ b/edspdf/data/converters.py
@@ -0,0 +1,88 @@
+"""
+Converters are used to convert documents between python dictionaries and Doc objects.
+There are two types of converters: readers and writers. Readers convert dictionaries to
+Doc objects, and writers convert Doc objects to dictionaries.
+"""
+import inspect
+from copy import copy
+from types import FunctionType
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Optional,
+ Tuple,
+)
+
+from confit.registry import ValidatedFunction
+
+FILENAME = "__FILENAME__"
+CONTENT = "__CONTENT__"
+
+SCHEMA = {}
+
+
+def validate_kwargs(converter, kwargs):
+ converter: FunctionType = copy(converter)
+ spec = inspect.getfullargspec(converter)
+ first = spec.args[0]
+ converter.__annotations__[first] = Optional[Any]
+ converter.__defaults__ = (None, *(spec.defaults or ())[-len(spec.args) + 1 :])
+ vd = ValidatedFunction(converter, {"arbitrary_types_allowed": True})
+ model = vd.init_model_instance(**kwargs)
+ d = {
+ k: v
+ for k, v in model._iter()
+ if (k in model.__fields__ or model.__fields__[k].default_factory)
+ }
+ d.pop("v__duplicate_kwargs", None) # see pydantic ValidatedFunction code
+ d.pop(vd.v_args_name, None)
+ d.pop(first, None)
+ return {**(d.pop(vd.v_kwargs_name, None) or {}), **d}
+
+
+def get_dict2doc_converter(converter: Callable, kwargs) -> Tuple[Callable, Dict]:
+ # kwargs_to_init = False
+ # if not callable(converter):
+ # available = edspdf.registry.factory.get_available()
+ # try:
+ # filtered = [
+ # name
+ # for name in available
+ # if converter == name or (converter in name and "dict2doc" in name)
+ # ]
+ # converter = edspdf.registry.factory.get(filtered[0])
+ # converter = converter(**kwargs).instantiate(nlp=None)
+ # kwargs = {}
+ # return converter, kwargs
+ # except (KeyError, IndexError):
+ # available = [v for v in available if "dict2doc" in v]
+ # raise ValueError(
+ # f"Cannot find converter for format {converter}. "
+ # f"Available converters are {', '.join(available)}"
+ # )
+ # if isinstance(converter, type) or kwargs_to_init:
+ # return converter(**kwargs), {}
+ return converter, validate_kwargs(converter, kwargs)
+
+
+def get_doc2dict_converter(converter: Callable, kwargs) -> Tuple[Callable, Dict]:
+ # if not callable(converter):
+ # available = edspdf.registry.factory.get_available()
+ # try:
+ # filtered = [
+ # name
+ # for name in available
+ # if converter == name or (converter in name and "doc2dict" in name)
+ # ]
+ # converter = edspdf.registry.factory.get(filtered[0])
+ # converter = converter(**kwargs).instantiate(nlp=None)
+ # kwargs = {}
+ # return converter, kwargs
+ # except (KeyError, IndexError):
+ # available = [v for v in available if "doc2dict" in v]
+ # raise ValueError(
+ # f"Cannot find converter for format {converter}. "
+ # f"Available converters are {', '.join(available)}"
+ # )
+ return converter, validate_kwargs(converter, kwargs)
diff --git a/edspdf/data/files.py b/edspdf/data/files.py
new file mode 100644
index 00000000..94cda772
--- /dev/null
+++ b/edspdf/data/files.py
@@ -0,0 +1,337 @@
+# ruff: noqa: F401
+import json
+import os
+from collections import Counter
+from pathlib import Path
+from typing import (
+ Any,
+ Callable,
+ List,
+ Optional,
+ Union,
+)
+
+import pyarrow
+import pyarrow.fs
+from fsspec import AbstractFileSystem
+from fsspec.implementations.arrow import ArrowFSWrapper
+from loguru import logger
+
+from edspdf import registry
+from edspdf.data.base import BaseReader, BaseWriter
+from edspdf.data.converters import (
+ CONTENT,
+ FILENAME,
+ get_dict2doc_converter,
+ get_doc2dict_converter,
+)
+from edspdf.lazy_collection import LazyCollection
+from edspdf.utils.collections import flatten
+
+
+class FileReader(BaseReader):
+ DATA_FIELDS = ()
+
+ def __init__(
+ self,
+ path: Union[str, Path],
+ *,
+ keep_ipynb_checkpoints: bool = False,
+ load_annotations: bool = False,
+ filesystem: Optional[Any] = None,
+ ):
+ super().__init__()
+
+ if filesystem is None or (isinstance(path, str) and "://" in path):
+ path = (
+ path
+ if isinstance(path, Path) or "://" in path
+ else f"file://{os.path.abspath(path)}"
+ )
+ inferred_fs, fs_path = pyarrow.fs.FileSystem.from_uri(path)
+ filesystem = filesystem or inferred_fs
+ assert inferred_fs.type_name == filesystem.type_name, (
+ f"Protocol {inferred_fs.type_name} in path does not match "
+ f"filesystem {filesystem.type_name}"
+ )
+ path = fs_path
+
+ self.path = path
+ self.filesystem = (
+ ArrowFSWrapper(filesystem)
+ if isinstance(filesystem, pyarrow.fs.FileSystem)
+ else filesystem
+ )
+ self.load_annotations = load_annotations
+ if not self.filesystem.exists(path):
+ raise FileNotFoundError(f"Path {path} does not exist")
+
+ self.files: List[str] = [
+ file
+ for file in self.filesystem.glob(os.path.join(str(self.path), "*.pdf"))
+ if (keep_ipynb_checkpoints or ".ipynb_checkpoints" not in str(file))
+ and (
+ not load_annotations
+                or self.filesystem.exists(str(file).replace(".pdf", ".json"))
+ )
+ ]
+ assert len(self.files), f"No .pdf files found in the directory {path}"
+ logger.info(f"The directory contains {len(self.files)} .pdf files.")
+
+ def read_main(self):
+ return ((f, 1) for f in self.files)
+
+ def read_worker(self, fragment):
+ tasks = []
+ for path in fragment:
+ with self.filesystem.open(str(path), "rb") as f:
+ content = f.read()
+
+ json_path = str(path).replace(".pdf", ".json")
+
+ record = {"content": content}
+ if self.load_annotations and self.filesystem.exists(json_path):
+ with self.filesystem.open(json_path) as f:
+ record["annotations"] = json.load(f)
+
+ record[FILENAME] = str(os.path.relpath(path, self.path)).rsplit(".", 1)[0]
+ record["id"] = record[FILENAME]
+ tasks.append(record)
+ return tasks
+
+
+class FileWriter(BaseWriter):
+ def __init__(
+ self,
+ path: Union[str, Path],
+ *,
+ overwrite: bool = False,
+ filesystem: Optional[AbstractFileSystem] = None,
+ ):
+ fs_path = path
+ if filesystem is None or (isinstance(path, str) and "://" in path):
+ path = (
+ path
+ if isinstance(path, Path) or "://" in path
+ else f"file://{os.path.abspath(path)}"
+ )
+ inferred_fs, fs_path = pyarrow.fs.FileSystem.from_uri(path)
+ filesystem = filesystem or inferred_fs
+ assert inferred_fs.type_name == filesystem.type_name, (
+ f"Protocol {inferred_fs.type_name} in path does not match "
+ f"filesystem {filesystem.type_name}"
+ )
+ path = fs_path
+
+ self.path = path
+ self.filesystem = (
+ ArrowFSWrapper(filesystem)
+ if isinstance(filesystem, pyarrow.fs.FileSystem)
+ else filesystem
+ )
+ self.filesystem.mkdirs(fs_path, exist_ok=True)
+
+ if self.filesystem.exists(self.path):
+ suffixes = Counter(f.suffix for f in self.filesystem.listdir(self.path))
+ unsafe_suffixes = {
+ s: v for s, v in suffixes.items() if s == ".pdf" or s == ".json"
+ }
+ if unsafe_suffixes and not overwrite:
+ raise FileExistsError(
+ f"Directory {self.path} already exists and appear to contain "
+ "annotations:"
+ + "".join(f"\n -{s}: {v} files" for s, v in unsafe_suffixes.items())
+ + "\nUse overwrite=True to write files anyway."
+ )
+
+ self.filesystem.mkdirs(path, exist_ok=True)
+
+ super().__init__()
+
+ def write_worker(self, records):
+ # If write as jsonl, we will perform the actual writing in the `write` method
+ results = []
+ for rec in flatten(records):
+ filename = str(rec.pop(FILENAME))
+ path = os.path.join(self.path, f"{filename}.pdf")
+ parent_dir = filename.rsplit("/", 1)[0]
+ if parent_dir and not self.filesystem.exists(parent_dir):
+ self.filesystem.makedirs(parent_dir, exist_ok=True)
+ if CONTENT in rec:
+ content = rec.pop(CONTENT)
+ with self.filesystem.open(path, "wb") as f:
+ f.write(content)
+ ann_path = str(path).replace(".pdf", ".json")
+
+ with self.filesystem.open(ann_path, "w") as f:
+ json.dump(rec, f)
+
+ results.append(path)
+ return results, len(results)
+
+ def write_main(self, fragments):
+ return list(flatten(fragments))
+
+
+# noinspection PyIncorrectDocstring
+@registry.readers.register("files")
+def read_files(
+ path: Union[str, Path],
+ *,
+ keep_ipynb_checkpoints: bool = False,
+ load_annotations: bool = False,
+ converter: Optional[Union[str, Callable]] = None,
+ filesystem: Optional[Any] = None,
+ **kwargs,
+) -> LazyCollection:
+ """
+    The FileReader (or `edspdf.data.read_files`) reads a directory of PDF files (and,
+    optionally, the `.json` annotation files sharing their names) and yields
+    dictionaries, ready to be converted into PDFDoc objects by a converter.
+
+ Example
+ -------
+ ```{ .python .no-check }
+
+ import edspdf
+
+    model = edspdf.load("path/to/model")
+    doc_iterator = edspdf.data.read_files("path/to/pdf/directory", converter=...)
+    annotated_docs = model.pipe(doc_iterator)
+ ```
+
+ !!! note "Generator vs list"
+
+ `edspdf.data.read_files` returns a
+        [LazyCollection][edspdf.lazy_collection.LazyCollection].
+ To iterate over the documents multiple times efficiently or to access them by
+ index, you must convert it to a list :
+
+ ```{ .python .no-check }
+    docs = list(edspdf.data.read_files("path/to/pdf/directory"))
+ ```
+
+    !!! note "Annotations"
+
+        When `load_annotations=True`, the reader only keeps `.pdf` files that have a
+        `.json` file with the same name next to them, and stores the parsed content of
+        that file under the `annotations` key of the yielded dictionary.
+
+        ```{ .python .no-check }
+        doc_iterator = edspdf.data.read_files(
+            "path/to/pdf/directory",
+            load_annotations=True,
+        )
+        ```
+
+ Parameters
+ ----------
+    path : Union[str, Path]
+        Path to the directory containing the `.pdf` files.
+ keep_ipynb_checkpoints : bool
+ Whether to keep files in the `.ipynb_checkpoints` directories.
+ load_annotations : bool
+ Whether to load annotations from the `.json` files that share the same name as
+ the `.pdf` files.
+ converter : Optional[Union[str, Callable]]
+ Converter to use to convert the dictionary objects to documents.
+ filesystem: Optional[AbstractFileSystem]
+        The filesystem to use to read the files. If not set, the local filesystem
+ will be used.
+
+
+ Returns
+ -------
+ LazyCollection
+ """
+ data = LazyCollection(
+ reader=FileReader(
+ path,
+ keep_ipynb_checkpoints=keep_ipynb_checkpoints,
+ load_annotations=load_annotations,
+ filesystem=filesystem,
+ )
+ )
+ if converter:
+ converter, kwargs = get_dict2doc_converter(converter, kwargs)
+ data = data.map(converter, kwargs=kwargs)
+ return data
+
+
+@registry.writers.register("files")
+def write_files(
+ data: Union[Any, LazyCollection],
+ path: Union[str, Path],
+ *,
+ overwrite: bool = False,
+ converter: Union[str, Callable],
+ filesystem: Optional[AbstractFileSystem] = None,
+ **kwargs,
+) -> None:
+ """
+    `edspdf.data.write_files` writes a list of documents as `.pdf` and `.json` files
+    in a directory. The files are named after the filename produced by the converter,
+    and subdirectories will be created if the name contains `/` characters.
+
+ Example
+ -------
+ ```{ .python .no-check }
+
+ import edspdf
+
+    model = edspdf.load("path/to/model")
+
+    doc = model(pdf_bytes)  # raw bytes of a PDF file
+
+    edspdf.data.write_files([doc], "path/to/output/directory", converter=...)
+ ```
+
+ !!! warning "Overwriting files"
+
+ By default, `write_files` will raise an error if the directory already exists
+ and contains files with `.json` or `.pdf` suffixes. This is to avoid overwriting
+ existing annotations. To allow overwriting existing files, use `overwrite=True`.
+
+ Parameters
+ ----------
+ data: Union[Any, LazyCollection],
+ The data to write (either a list of documents or a LazyCollection).
+ path: Union[str, Path]
+        Path to the directory where the `.pdf` and `.json` files will be written.
+ overwrite: bool
+ Whether to overwrite existing directories.
+ converter: Optional[Union[str, Callable]]
+ Converter to use to convert the documents to dictionary objects.
+ filesystem: Optional[AbstractFileSystem]
+ The filesystem to use to write the files. If not set, the local filesystem
+ will be used.
+ """
+ data = LazyCollection.ensure_lazy(data)
+ if converter:
+ converter, kwargs = get_doc2dict_converter(converter, kwargs)
+ data = data.map(converter, kwargs=kwargs)
+
+ return data.write(
+ FileWriter(
+ path=path,
+ filesystem=filesystem,
+ overwrite=overwrite,
+ )
+ )
diff --git a/edspdf/data/pandas.py b/edspdf/data/pandas.py
new file mode 100644
index 00000000..c97fedea
--- /dev/null
+++ b/edspdf/data/pandas.py
@@ -0,0 +1,155 @@
+from __future__ import annotations
+
+from typing import Any, Callable, Iterable, Optional, Tuple, Union
+
+import pandas as pd
+
+from edspdf import registry
+from edspdf.data.base import BaseReader, BaseWriter
+from edspdf.data.converters import (
+ FILENAME,
+ get_dict2doc_converter,
+ get_doc2dict_converter,
+)
+from edspdf.lazy_collection import LazyCollection
+from edspdf.utils.collections import dl_to_ld, flatten, ld_to_dl
+
+
+class PandasReader(BaseReader):
+ DATA_FIELDS = ("data",)
+
+ def __init__(
+ self,
+ data: pd.DataFrame,
+ **kwargs,
+ ):
+ assert isinstance(data, pd.DataFrame)
+ self.data = data
+
+ super().__init__(**kwargs)
+
+ def read_main(self) -> Iterable[Tuple[Any, int]]:
+ return ((item, 1) for item in dl_to_ld(dict(self.data)))
+
+ def read_worker(self, fragments):
+ return [task for task in fragments]
+
+
+@registry.readers.register("pandas")
+def from_pandas(
+ data,
+ converter: Union[str, Callable],
+ **kwargs,
+) -> LazyCollection:
+ """
+ The PandasReader (or `edspdf.data.from_pandas`) reads rows from a pandas
+ DataFrame and yields one dictionary per row, which can then be converted to
+ documents via the `converter` argument.
+
+ Example
+ -------
+ ```{ .python .no-check }
+
+ import edspdf
+
+ nlp = edspdf.blank("eds")
+ nlp.add_pipe(...)
+ doc_iterator = edspdf.data.from_pandas(df, nlp=nlp, converter="omop")
+ annotated_docs = nlp.pipe(doc_iterator)
+ ```
+
+ !!! note "Generator vs list"
+
+ `edspdf.data.from_pandas` returns a
+ [LazyCollection][edspdf.lazy_collection.LazyCollection].
+ To iterate over the documents multiple times efficiently or to access them by
+ index, you must convert it to a list:
+
+ ```{ .python .no-check }
+ docs = list(edspdf.data.from_pandas(df, converter="omop"))
+ ```
+
+ Parameters
+ ----------
+ data: pd.DataFrame
+ The pandas DataFrame to read from
+ converter: Optional[Union[str, Callable]]
+ Converter to use to convert the rows of the DataFrame to Doc objects
+ kwargs:
+ Additional keyword arguments passed to the converter. These are documented
+ on the [Data schemas](/data/schemas) page.
+
+ Returns
+ -------
+ LazyCollection
+ """
+
+ data = LazyCollection(reader=PandasReader(data))
+ if converter:
+ converter, kwargs = get_dict2doc_converter(converter, kwargs)
+ data = data.map(converter, kwargs=kwargs)
+ return data
+
+
+class PandasWriter(BaseWriter):
+ def __init__(self, dtypes: Optional[dict] = None):
+ self.dtypes = dtypes
+
+ def write_worker(self, records):
+ # The DataFrame itself is built in `write_main`; here we only clean up the records
+ for rec in records:
+ if isinstance(rec, dict):
+ rec.pop(FILENAME, None)
+ return records, len(records)
+
+ def write_main(self, fragments):
+ import pandas as pd
+
+ columns = ld_to_dl(flatten(fragments))
+ res = pd.DataFrame(columns)
+ return res.astype(self.dtypes) if self.dtypes else res
+
+
+@registry.writers.register("pandas")
+def to_pandas(
+ data: Union[Any, LazyCollection],
+ converter: Optional[Union[str, Callable]],
+ dtypes: Optional[dict] = None,
+ **kwargs,
+) -> pd.DataFrame:
+ """
+ `edspdf.data.to_pandas` writes a list of documents as a pandas table.
+
+ Example
+ -------
+ ```{ .python .no-check }
+
+ import edspdf
+
+ nlp = edspdf.blank("eds")
+ nlp.add_pipe(...)
+
+ doc = nlp("My document with entities")
+
+ edspdf.data.to_pandas([doc], converter="omop")
+ ```
+
+ Parameters
+ ----------
+ data: Union[Any, LazyCollection],
+ The data to write (either a list of documents or a LazyCollection).
+ converter: Optional[Union[str, Callable]]
+ Converter to use to convert the documents to dictionary objects before storing
+ them in the dataframe.
+ dtypes: Optional[dict]
+ Dictionary of column names to dtypes, passed to `pd.DataFrame.astype` (see the example below).
+ kwargs:
+ Additional keyword arguments passed to the converter. These are documented
+ on the [Data schemas](/data/schemas) page.
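+
+ For instance, given documents `docs`, to force the dtype of some columns (the
+ column names below are illustrative and depend on the converter used):
+
+ ```{ .python .no-check }
+ df = edspdf.data.to_pandas(docs, converter="omop", dtypes={"page_num": "int32"})
+ ```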
+ """
+ data = LazyCollection.ensure_lazy(data)
+ if converter:
+ converter, kwargs = get_doc2dict_converter(converter, kwargs)
+ data = data.map(converter, kwargs=kwargs)
+
+ return data.write(PandasWriter(dtypes))
diff --git a/edspdf/data/parquet.py b/edspdf/data/parquet.py
new file mode 100644
index 00000000..813ff984
--- /dev/null
+++ b/edspdf/data/parquet.py
@@ -0,0 +1,226 @@
+import os
+from itertools import chain
+from pathlib import Path
+from typing import Any, Callable, Dict, Iterable, List, Optional, TypeVar, Union
+
+import pyarrow.dataset
+import pyarrow.fs
+import pyarrow.parquet
+from pyarrow.dataset import ParquetFileFragment
+
+from edspdf.data.base import BaseReader, BaseWriter
+from edspdf.data.converters import (
+ FILENAME,
+ get_dict2doc_converter,
+ get_doc2dict_converter,
+)
+from edspdf.lazy_collection import LazyCollection
+from edspdf.structures import PDFDoc, registry
+from edspdf.utils.collections import dl_to_ld, flatten, ld_to_dl
+
+
+class ParquetReader(BaseReader):
+ DATA_FIELDS = ("dataset",)
+
+ def __init__(
+ self,
+ path: Union[str, Path],
+ *,
+ read_in_worker: bool,
+ filesystem: Optional[pyarrow.fs.FileSystem] = None,
+ ):
+ super().__init__()
+ # Either the filesystem has not been passed
+ # or the path is a URL (e.g. s3://) => we need to infer the filesystem
+ fs_path = path
+ if filesystem is None or (isinstance(path, str) and "://" in path):
+ path = (
+ path
+ if isinstance(path, Path) or "://" in path
+ else f"file://{os.path.abspath(path)}"
+ )
+ inferred_fs, fs_path = pyarrow.fs.FileSystem.from_uri(path)
+ filesystem = filesystem or inferred_fs
+ assert inferred_fs.type_name == filesystem.type_name, (
+ f"Protocol {inferred_fs.type_name} in path does not match "
+ f"filesystem {filesystem.type_name}"
+ )
+ self.read_in_worker = read_in_worker
+ self.dataset = pyarrow.dataset.dataset(
+ fs_path, format="parquet", filesystem=filesystem
+ )
+
+ def read_main(self):
+ fragments: List[ParquetFileFragment] = self.dataset.get_fragments()
+ if self.read_in_worker:
+ # read in worker -> each task is a file to read from
+ return ((f, f.metadata.num_rows) for f in fragments)
+ else:
+ # read in the main process -> each task is a single, already parsed row
+ return (
+ (line, 1)
+ for f in fragments
+ for batch in f.to_table().to_batches(1024)
+ for line in dl_to_ld(batch.to_pydict())
+ )
+
+ def read_worker(self, tasks):
+ if self.read_in_worker:
+ tasks = list(
+ chain.from_iterable(
+ dl_to_ld(batch.to_pydict())
+ for task in tasks
+ for batch in task.to_table().to_batches(1024)
+ )
+ )
+ return tasks
+
+
+T = TypeVar("T")
+
+
+class ParquetWriter(BaseWriter):
+ def __init__(
+ self,
+ path: Union[str, Path],
+ num_rows_per_file: int,
+ overwrite: bool,
+ write_in_worker: bool,
+ accumulate: bool = True,
+ filesystem: Optional[pyarrow.fs.FileSystem] = None,
+ ):
+ super().__init__()
+ fs_path = path
+ if filesystem is None or (isinstance(path, str) and "://" in path):
+ path = (
+ path
+ if isinstance(path, Path) or "://" in path
+ else f"file://{os.path.abspath(path)}"
+ )
+ inferred_fs, fs_path = pyarrow.fs.FileSystem.from_uri(path)
+ filesystem = filesystem or inferred_fs
+ assert inferred_fs.type_name == filesystem.type_name, (
+ f"Protocol {inferred_fs.type_name} in path does not match "
+ f"filesystem {filesystem.type_name}"
+ )
+ path = fs_path
+ # The assert above guarantees the filesystem protocol matches the one in the path
+ filesystem.create_dir(fs_path, recursive=True)
+ if overwrite is False:
+ dataset = pyarrow.dataset.dataset(
+ fs_path, format="parquet", filesystem=filesystem
+ )
+ if len(list(dataset.get_fragments())):
+ raise FileExistsError(
+ f"Directory {fs_path} already exists and is not empty. "
+ "Use overwrite=True to overwrite."
+ )
+ self.filesystem = filesystem
+ self.path = path
+ self.write_in_worker = write_in_worker
+ self.batch = []
+ self.num_rows_per_file = num_rows_per_file
+ self.closed = False
+ self.finalized = False
+ self.accumulate = accumulate
+ if not self.accumulate:
+ self.finalize = super().finalize
+
+ def write_worker(self, records, last=False):
+ # `results` will contain batches of samples ready to be written, or None
+ # entries when write_in_worker is True (in which case the batches have already
+ # been written to the dataset from this worker).
+ results = []
+ count = 0
+
+ for rec in records:
+ if isinstance(rec, dict):
+ rec.pop(FILENAME, None)
+
+ # Flush eagerly when this is the last call or when accumulation is disabled;
+ # otherwise, only flush once `num_rows_per_file` rows have been buffered.
+ greedy = last or not self.accumulate
+ # While there are records left to buffer, or a pending batch to flush
+ while len(records) or (greedy and len(self.batch)):
+ n_to_fill = self.num_rows_per_file - len(self.batch)
+ self.batch.extend(records[:n_to_fill])
+ records = records[n_to_fill:]
+ if greedy or len(self.batch) >= self.num_rows_per_file:
+ fragment = pyarrow.Table.from_pydict(ld_to_dl(flatten(self.batch)))
+ count += len(self.batch)
+ self.batch = []
+ if self.write_in_worker:
+ pyarrow.parquet.write_to_dataset(
+ table=fragment,
+ root_path=self.path,
+ filesystem=self.filesystem,
+ )
+ fragment = None
+ results.append(fragment)
+ return results, count
+
+ def finalize(self):
+ if not self.finalized:
+ self.finalized = True
+ return self.write_worker([], last=True)
+
+ def write_main(self, fragments: Iterable[List[Union[pyarrow.Table, Path]]]):
+ for table in flatten(fragments):
+ if not self.write_in_worker:
+ pyarrow.parquet.write_to_dataset(
+ table=table,
+ root_path=self.path,
+ filesystem=self.filesystem,
+ )
+ return pyarrow.dataset.dataset(
+ self.path, format="parquet", filesystem=self.filesystem
+ )
+
+
+@registry.readers.register("parquet")
+def read_parquet(
+ path: Union[str, Path],
+ converter: Union[str, Callable],
+ *,
+ read_in_worker: bool = False,
+ filesystem: Optional[pyarrow.fs.FileSystem] = None,
+ **kwargs,
+) -> LazyCollection:
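+ # Lazily read parquet fragments from `path`; with read_in_worker=True each worker
+ # parses its own fragments, otherwise rows are parsed in the main process.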
+ data = LazyCollection(
+ reader=ParquetReader(
+ path,
+ read_in_worker=read_in_worker,
+ filesystem=filesystem,
+ )
+ )
+ if converter:
+ converter, kwargs = get_dict2doc_converter(converter, kwargs)
+ data = data.map(converter, kwargs=kwargs)
+ return data
+
+
+@registry.writers.register("parquet")
+def write_parquet(
+ data: Union[Any, LazyCollection],
+ path: Union[str, Path],
+ *,
+ write_in_worker: bool = False,
+ num_rows_per_file: int = 1024,
+ overwrite: bool = False,
+ filesystem: Optional[pyarrow.fs.FileSystem] = None,
+ accumulate: bool = True,
+ converter: Optional[Union[str, Callable[[PDFDoc], Dict]]],
+ **kwargs,
+) -> None:
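+ # Convert the documents to dicts (via `converter`), buffer them into groups of
+ # `num_rows_per_file` rows, and write parquet fragments either from the workers
+ # (write_in_worker=True) or from the main process.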
+ data = LazyCollection.ensure_lazy(data)
+ if converter:
+ converter, kwargs = get_doc2dict_converter(converter, kwargs)
+ data = data.map(converter, kwargs=kwargs)
+
+ return data.write(
+ ParquetWriter(
+ path,
+ num_rows_per_file=num_rows_per_file,
+ overwrite=overwrite,
+ write_in_worker=write_in_worker,
+ accumulate=accumulate,
+ filesystem=filesystem,
+ )
+ )
diff --git a/edspdf/lazy_collection.py b/edspdf/lazy_collection.py
new file mode 100644
index 00000000..e3b60b9f
--- /dev/null
+++ b/edspdf/lazy_collection.py
@@ -0,0 +1,345 @@
+from __future__ import annotations
+
+import contextlib
+from functools import wraps
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Container,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+)
+
+from typing_extensions import Literal
+
+import edspdf.data
+
+if TYPE_CHECKING:
+ import torch
+
+ from edspdf import Pipeline
+ from edspdf.data.base import BaseReader, BaseWriter
+ from edspdf.trainable_pipe import TrainablePipe
+
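+# Sentinel default value meaning "infer this setting automatically"; such values
+# are filtered out when building the processing config in `set_processing`.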
+INFER = type("INFER", (), {"__repr__": lambda self: "INFER"})()
+
+
+def with_non_default_args(fn: Callable) -> Callable:
+ @wraps(fn)
+ def wrapper(self, **kwargs):
+ return fn(self, **kwargs, _non_default_args=kwargs.keys())
+
+ return wrapper
+
+
+class MetaLazyCollection(type):
+ def __getattr__(self, item):
+ if item in edspdf.data.__all__:
+ fn = getattr(edspdf.data, item)
+ setattr(self, item, fn)
+ return fn
+ raise AttributeError(item)
+
+ def __dir__(self): # pragma: no cover
+ return (*super().__dir__(), *edspdf.data.__all__)
+
+
+class LazyCollection(metaclass=MetaLazyCollection):
+ def __init__(
+ self,
+ reader: Optional[BaseReader] = None,
+ writer: Optional[BaseWriter] = None,
+ pipeline: List[Any] = [],
+ config={},
+ ):
+ self.reader = reader
+ self.writer = writer
+ self.pipeline: List[Tuple[str, Callable, Dict]] = pipeline
+ self.config = config
+
+ @property
+ def batch_size(self):
+ return self.config.get("batch_size", 1)
+
+ @property
+ def batch_by(self):
+ return self.config.get("batch_by", "docs")
+
+ @property
+ def sort_chunks(self):
+ return self.config.get("sort_chunks", False)
+
+ @property
+ def split_into_batches_after(self):
+ return self.config.get("split_into_batches_after")
+
+ @property
+ def chunk_size(self):
+ return self.config.get("chunk_size", self.config.get("batch_size", 128))
+
+ @property
+ def disable_implicit_parallelism(self):
+ return self.config.get("disable_implicit_parallelism", True)
+
+ @property
+ def num_cpu_workers(self):
+ return self.config.get("num_cpu_workers")
+
+ @property
+ def num_gpu_workers(self):
+ return self.config.get("num_gpu_workers")
+
+ @property
+ def gpu_pipe_names(self):
+ return self.config.get("gpu_pipe_names")
+
+ @property
+ def gpu_worker_devices(self):
+ return self.config.get("gpu_worker_devices")
+
+ @property
+ def cpu_worker_devices(self):
+ return self.config.get("cpu_worker_devices")
+
+ @property
+ def backend(self):
+ return self.config.get("backend")
+
+ @property
+ def show_progress(self):
+ return self.config.get("show_progress")
+
+ @property
+ def process_start_method(self):
+ return self.config.get("process_start_method")
+
+ @with_non_default_args
+ def set_processing(
+ self,
+ batch_size: int = 1,
+ batch_by: Literal["docs", "pages", "content_boxes"] = "docs",
+ chunk_size: int = INFER,
+ sort_chunks: bool = False,
+ split_into_batches_after: str = INFER,
+ num_cpu_workers: Optional[int] = INFER,
+ num_gpu_workers: Optional[int] = INFER,
+ disable_implicit_parallelism: bool = True,
+ backend: Optional[Literal["simple", "multiprocessing"]] = INFER,
+ gpu_pipe_names: Optional[List[str]] = INFER,
+ show_progress: bool = False,
+ process_start_method: Optional[Literal["fork", "spawn"]] = INFER,
+ gpu_worker_devices: Optional[List[str]] = INFER,
+ cpu_worker_devices: Optional[List[str]] = INFER,
+ _non_default_args: Iterable[str] = (),
+ ) -> "LazyCollection":
+ """
+ Parameters
+ ----------
+ batch_size: int
+ Number of documents to process at a time in a GPU worker (or in the
+ main process if no workers are used).
+ batch_by: Literal["docs", "pages", "content_boxes"]
+ How to compute the batch size:
+
+ - "docs" (default) is the number of documents.
+ - "pages" is the total number of pages in the documents.
+ - "content_boxes" is the total number of content boxes in the documents
+ chunk_size: int
+ Number of documents to build before splitting into batches. Only used
+ with "simple" and "multiprocessing" backends. This is also the number of
+ documents that will be passed through the first components of the pipeline
+ until a GPU worker is used (then the chunks will be split according to the
+ `batch_size` and `batch_by` arguments).
+
+ By default, the chunk size is equal to the batch size, or 128 if the batch
+ size is not set.
+ sort_chunks: bool
+ Whether to sort the documents by size before splitting into batches.
+ split_into_batches_after: str
+ The name of the component after which to split the documents into batches.
+ Only used with "simple" and "multiprocessing" backends.
+ By default, the documents are split into batches as soon as the inputs
+ are converted into Doc objects.
+ num_cpu_workers: int
+ Number of CPU workers. A CPU worker handles the non-deep-learning components
+ and the preprocessing, collating and postprocessing of deep-learning
+ components. If no GPU workers are used, the CPU workers also handle the
+ forward call of the deep-learning components.
+ num_gpu_workers: Optional[int]
+ Number of GPU workers. A GPU worker handles the forward call of the
+ deep-learning components. Only used with "multiprocessing" backend.
+ disable_implicit_parallelism: bool
+ Whether to disable OpenMP and Huggingface tokenizers implicit parallelism in
+ multiprocessing mode. Defaults to True.
+ gpu_pipe_names: Optional[List[str]]
+ List of pipe names to accelerate on a GPUWorker, defaults to all pipes
+ that inherit from TrainablePipe. Only used with "multiprocessing" backend.
+ Inferred from the pipeline if not set.
+ backend: Optional[Literal["simple", "multiprocessing", "spark"]]
+ The backend to use for parallel processing. If not set, the backend is
+ automatically selected based on the input data and the number of workers.
+
+ - "simple" is the default backend and is used when `num_cpu_workers` is 1
+ and `num_gpu_workers` is 0.
+ - "multiprocessing" is used when `num_cpu_workers` is greater than 1 or
+ `num_gpu_workers` is greater than 0.
+ - "spark" is used when the input data is a Spark dataframe and the output
+ writer is a Spark writer.
+ show_progress: Optional[bool]
+ Whether to show progress bars (only applicable with "simple" and
+ "multiprocessing" backends).
+ process_start_method: Optional[Literal["fork", "spawn"]]
+ The start method ("fork" or "spawn") to use for the multiprocessing
+ backend. The default is "fork" on Unix systems and "spawn" on Windows.
+
+ - "fork" is the default start method on Unix systems and is the fastest,
+   but it is not available on Windows, can cause issues with CUDA and is
+   not safe when using multiple threads.
+ - "spawn" is the default start method on Windows and the safest, but it
+   is slower than "fork" since the workers start from a fresh interpreter.
+ gpu_worker_devices: Optional[List[str]]
+ List of GPU devices to use for the GPU workers. Defaults to all available
+ devices, one worker per device. Only used with "multiprocessing" backend.
+ cpu_worker_devices: Optional[List[str]]
+ List of GPU devices to use for the CPU workers. Used for debugging purposes.
+
+ Returns
+ -------
+ LazyCollection
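+
+ Example
+ -------
+ A minimal sketch, assuming an existing pipeline `nlp` and DataFrame `df`
+ (the backend and worker counts below are illustrative):
+
+ ```{ .python .no-check }
+ docs = edspdf.data.from_pandas(df, converter="omop")
+ docs = docs.map_pipeline(nlp)
+ docs = docs.set_processing(
+     backend="multiprocessing",
+     num_cpu_workers=4,
+     batch_size=8,
+     batch_by="pages",
+     show_progress=True,
+ )
+ df_out = docs.to_pandas(converter="omop")
+ ```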
+ """
+ kwargs = {k: v for k, v in locals().items() if k in _non_default_args}
+ return LazyCollection(
+ reader=self.reader,
+ writer=self.writer,
+ pipeline=self.pipeline,
+ config={
+ **self.config,
+ **{k: v for k, v in kwargs.items() if v is not INFER},
+ },
+ )
+
+ @classmethod
+ def ensure_lazy(cls, data):
+ from edspdf.data.base import IterableReader
+
+ if isinstance(data, cls):
+ return data
+ return cls(reader=IterableReader(data))
+
+ def map(self, pipe, name: Optional[str] = None, kwargs={}) -> "LazyCollection":
+ return LazyCollection(
+ reader=self.reader,
+ writer=self.writer,
+ pipeline=[*self.pipeline, (name, pipe, kwargs)],
+ config=self.config,
+ )
+
+ def map_pipeline(self, model: Pipeline) -> "LazyCollection":
+ new_steps = []
+ for name, pipe, kwargs in self.pipeline:
+ new_steps.append((name, pipe, kwargs))
+
+ new_steps.append((None, model.ensure_doc, {}))
+
+ for name, pipe in model.pipeline:
+ if name not in model._disabled:
+ new_steps.append((name, pipe, {}))
+ config = (
+ {**self.config, "batch_size": model.batch_size}
+ if self.batch_size is None
+ else self.config
+ )
+ return LazyCollection(
+ reader=self.reader,
+ writer=self.writer,
+ pipeline=new_steps,
+ config=config,
+ )
+
+ def write(self, writer: BaseWriter, execute: bool = True) -> Any:
+ lc = LazyCollection(
+ reader=self.reader,
+ writer=writer,
+ pipeline=self.pipeline,
+ config=self.config,
+ )
+ return lc.execute() if execute else lc
+
+ def execute(self):
+ import edspdf.processing
+
+ backend = self.backend
+ if backend is None:
+ if (
+ self.num_cpu_workers is not None
+ and self.num_cpu_workers > 1
+ or self.num_gpu_workers is not None
+ and self.num_gpu_workers > 0
+ ):
+ backend = "multiprocessing"
+ else:
+ backend = "simple"
+ execute = getattr(edspdf.processing, f"execute_{backend}_backend")
+ return execute(self)
+
+ def __iter__(self):
+ return iter(self.execute())
+
+ @contextlib.contextmanager
+ def cache(self):
+ for name, pipe, *_ in self.pipeline:
+ if hasattr(pipe, "enable_cache"):
+ pipe.enable_cache()
+ yield
+ for name, pipe, *_ in self.pipeline:
+ if hasattr(pipe, "disable_cache"):
+ pipe.disable_cache()
+
+ def torch_components(
+ self, disable: Container[str] = ()
+ ) -> Iterable[Tuple[str, "TrainablePipe"]]:
+ """
+ Yields components that are PyTorch modules.
+
+ Parameters
+ ----------
+ disable: Container[str]
+ The names of disabled components, which will be skipped.
+
+ Returns
+ -------
+ Iterable[Tuple[str, TrainablePipe]]
+ """
+ for name, pipe, *_ in self.pipeline:
+ if name not in disable and hasattr(pipe, "forward"):
+ yield name, pipe
+
+ def to(self, device: Union[str, Optional["torch.device"]] = None): # noqa F821
+ """Moves the pipeline to a given device"""
+ for name, pipe, *_ in self.torch_components():
+ pipe.to(device)
+ return self
+
+ def worker_copy(self):
+ return LazyCollection(
+ reader=self.reader.worker_copy(),
+ writer=self.writer,
+ pipeline=self.pipeline,
+ config=self.config,
+ )
+
+ def __dir__(self): # pragma: no cover
+ return (*super().__dir__(), *edspdf.data.__all__)
+
+ def __getattr__(self, item):
+ return getattr(LazyCollection, item).__get__(self)
+
+
+if TYPE_CHECKING:
+ # just to add read/from_* and write/to_* methods to the static type hints
+ LazyCollection = edspdf.data # noqa: F811
diff --git a/edspdf/pipeline.py b/edspdf/pipeline.py
index 05e1d226..83ecec1f 100644
--- a/edspdf/pipeline.py
+++ b/edspdf/pipeline.py
@@ -1,13 +1,15 @@
import functools
+import importlib
+import inspect
import json
import os
import shutil
import warnings
-from collections import defaultdict
from contextlib import contextmanager
from enum import Enum
from pathlib import Path
from typing import (
+ TYPE_CHECKING,
Any,
Callable,
Dict,
@@ -24,14 +26,13 @@
from confit import Config
from confit.errors import ConfitValidationError, patch_errors
-from confit.utils.collections import join_path, split_path
from confit.utils.xjson import Reference
-from pydantic import parse_obj_as
from typing_extensions import Literal
import edspdf
-from .accelerators.base import Accelerator, FromDoc, ToDoc
+from .data.converters import FILENAME
+from .lazy_collection import LazyCollection
from .registry import CurriedFactory, registry
from .structures import PDFDoc
from .utils.collections import (
@@ -41,6 +42,9 @@
multi_tee,
)
+if TYPE_CHECKING:
+ import torch
+
EMPTY_LIST = FrozenList()
@@ -210,17 +214,16 @@ def add_pipe(
pipe = factory
if hasattr(pipe, "name"):
if name is not None and name != pipe.name:
- raise ValueError(
+ warnings.warn(
"The provided name does not match the name of the component."
)
+ pipe.name = name
else:
name = pipe.name
- else:
- if name is None:
- raise ValueError(
- "The component does not have a name, so you must provide one",
- )
- pipe.name = name
+ if name is None:
+ raise ValueError(
+ "The component does not have a name, so you must provide one",
+ )
assert sum([before is not None, after is not None, first]) <= 1, (
"You can only use one of before, after, or first",
)
@@ -236,6 +239,18 @@ def add_pipe(
self._components.insert(insertion_idx, (name, pipe))
return pipe
+ def ensure_doc(self, doc):
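+ # Accept PDFDoc objects as-is, wrap raw PDF bytes, or build a PDFDoc from a
+ # dict with a "content" key and an optional "id" (or filename) key.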
+ return (
+ doc
+ if isinstance(doc, PDFDoc)
+ else PDFDoc(content=doc)
+ if isinstance(doc, bytes)
+ else PDFDoc(
+ content=doc["content"],
+ id=doc.get("id") or doc.get(FILENAME),
+ )
+ )
+
def __call__(self, doc: Any) -> PDFDoc:
"""
Apply each component successively on a document.
@@ -265,20 +280,20 @@ def __call__(self, doc: Any) -> PDFDoc:
def pipe(
self,
- inputs: Any,
+ inputs: Union[LazyCollection, Iterable],
batch_size: Optional[int] = None,
*,
- accelerator: Optional[Union[str, Accelerator]] = None,
- to_doc: Optional[ToDoc] = None,
- from_doc: FromDoc = lambda doc: doc,
- ) -> Iterable[PDFDoc]:
+ accelerator: Any = None,
+ to_doc: Any = None,
+ from_doc: Any = None,
+ ) -> LazyCollection:
"""
Process a stream of documents by applying each component successively on
batches of documents.
Parameters
----------
- inputs: Iterable[Union[str, PDFDoc]]
+ inputs: Union[LazyCollection, Iterable]
The inputs to create the PDFDocs from, or the PDFDocs directly.
batch_size: Optional[int]
The batch size to use. If not provided, the batch size of the pipeline
@@ -296,43 +311,102 @@ def pipe(
Returns
-------
- Iterable[PDFDoc]
+ LazyCollection
"""
if batch_size is None:
batch_size = self.batch_size
- if accelerator is None:
- accelerator = "simple"
- if isinstance(accelerator, str):
- accelerator = {"@accelerator": accelerator, "batch_size": batch_size}
- if isinstance(accelerator, dict):
- accelerator = Config(accelerator).resolve(registry=registry)
-
- kwargs = {
- "inputs": inputs,
- "model": self,
- "to_doc": parse_obj_as(Optional[ToDoc], to_doc),
- "from_doc": parse_obj_as(Optional[FromDoc], from_doc),
- }
- for k, v in list(kwargs.items()):
- if v is None:
- del kwargs[k]
-
- with self.train(False):
- return accelerator(**kwargs)
+ lazy_collection = LazyCollection.ensure_lazy(inputs)
+
+ if to_doc is not None:
+ warnings.warn(
+ "The `to_doc` argument is deprecated. "
+ "Please use the returned value's `map` method or the read/from_{} "
+ "method's converter argument instead.",
+ DeprecationWarning,
+ )
+ if isinstance(to_doc, str):
+ to_doc = {"content_field": to_doc}
+ if isinstance(to_doc, dict):
+ to_doc_dict = to_doc
+
+ def to_doc(doc):
+ return PDFDoc(
+ content=doc[to_doc_dict["content_field"]],
+ id=doc[to_doc_dict["id_field"]]
+ if "id_field" in to_doc_dict
+ else None,
+ )
+
+ if not callable(to_doc):
+ raise ValueError(
+ "The `to_doc` argument must be a callable or a dictionary",
+ )
+ lazy_collection = lazy_collection.map(to_doc)
+
+ lazy_collection = lazy_collection.map_pipeline(self).set_processing(
+ batch_size=batch_size
+ )
+
+ if accelerator is not None:
+ warnings.warn(
+ "The `accelerator` argument is deprecated. "
+ "Please use the returned value's `set_processing` method instead.",
+ DeprecationWarning,
+ )
+ if isinstance(accelerator, str):
+ kwargs = {}
+ backend = accelerator
+ elif isinstance(accelerator, dict):
+ kwargs = dict(accelerator)
+ backend = kwargs.pop("@accelerator", "simple")
+ elif "Accelerator" in type(accelerator).__name__:
+ backend = (
+ "multiprocessing"
+ if "Multiprocessing" in type(accelerator).__name__
+ else "simple"
+ )
+ kwargs = accelerator.__dict__
+ lazy_collection.set_processing(
+ backend=backend,
+ **kwargs,
+ )
+ if from_doc is not None:
+ warnings.warn(
+ "The `from_doc` argument is deprecated. "
+ "Please use the returned value's `map` method or the write/to_{} "
+ "method's converter argument instead.",
+ DeprecationWarning,
+ )
+ if isinstance(from_doc, dict):
+ from_doc_dict = from_doc
+
+ def from_doc(doc):
+ return {k: getattr(doc, v) for k, v in from_doc_dict.items()}
+
+ if not callable(from_doc):
+ raise ValueError(
+ "The `from_doc` argument must be a callable or a dictionary",
+ )
+ lazy_collection = lazy_collection.map(from_doc)
+
+ return lazy_collection
@contextmanager
def cache(self):
"""
Enable caching for all (trainable) components in the pipeline
"""
- was_not_cached = self._cache is None
- if was_not_cached:
- self._cache = {}
+ to_disable = set()
+ for name, pipe in self.trainable_pipes():
+ if getattr(pipe, "_current_cache_id", None) is None:
+ pipe.enable_cache()
+ to_disable.add(name)
yield
- if was_not_cached:
- self._cache = None
+ for name, pipe in self.trainable_pipes():
+ if name in to_disable:
+ pipe.disable_cache()
def trainable_pipes(
self, disable: Sequence[str] = ()
@@ -353,7 +427,7 @@ def trainable_pipes(
if name not in disable and hasattr(pipe, "batch_process"):
yield name, pipe
- def post_init(self, gold_data: Iterable[PDFDoc], exclude: Optional[set] = None):
+ def post_init(self, gold_data: Iterable[PDFDoc], exclude: Optional[Set] = None):
"""
Completes the initialization of the pipeline by calling the post_init
method of all components that have one.
@@ -365,7 +439,7 @@ def post_init(self, gold_data: Iterable[PDFDoc], exclude: Optional[set] = None):
gold_data: Iterable[PDFDoc]
The documents to use for initialization.
Each component will not necessarily see all the data.
- exclude: Optional[set]
+ exclude: Optional[Set]
The names of components to exclude from initialization.
This argument will be gradually updated with the names of initialized
components
@@ -580,7 +654,7 @@ def collate(
for name, component in self.pipeline:
if name in batch:
component_inputs = batch[name]
- batch[name] = component.collate(component_inputs, device)
+ batch[name] = component.collate(component_inputs)
return batch
def parameters(self):
@@ -600,7 +674,7 @@ def named_parameters(self):
seen.add(param)
yield f"{name}.{param_name}", param
- def to(self, device: Optional["torch.device"] = None): # noqa F821
+ def to(self, device: Union[str, Optional["torch.device"]] = None): # noqa F821
"""Moves the pipeline to a given device"""
for name, component in self.trainable_pipes():
component.to(device)
@@ -630,7 +704,7 @@ def __exit__(ctx_self, type, value, traceback):
return context()
- def save(
+ def to_disk(
self, path: Union[str, Path], *, exclude: Optional[Set[str]] = None
) -> None:
"""
@@ -648,30 +722,6 @@ def save(
process. This list will be gradually filled in place as components are
saved
"""
-
- def save_tensors(path: Path):
- import safetensors.torch
-
- shutil.rmtree(path, ignore_errors=True)
- os.makedirs(path, exist_ok=True)
- tensors = defaultdict(list)
- tensor_to_group = defaultdict(list)
- for pipe_name, pipe in self.trainable_pipes(disable=exclude):
- for key, tensor in pipe.state_dict(keep_vars=True).items():
- full_key = join_path((pipe_name, *split_path(key)))
- tensors[tensor].append(full_key)
- tensor_to_group[tensor].append(pipe_name)
- group_to_tensors = defaultdict(set)
- for tensor, group in tensor_to_group.items():
- group_to_tensors["+".join(sorted(set(group)))].add(tensor)
- for group, group_tensors in group_to_tensors.items():
- sub_path = path / f"{group}.safetensors"
- tensor_dict = {
- "+".join(tensors[p]): p
- for p in {p.data_ptr(): p for p in group_tensors}.values()
- }
- safetensors.torch.save_file(tensor_dict, sub_path)
-
exclude = set() if exclude is None else exclude
path = Path(path) if isinstance(path, str) else path
@@ -690,18 +740,29 @@ def save_tensors(path: Path):
if "config" not in exclude:
self.config.to_disk(path / "config.cfg")
- extra_exclude = set(exclude)
- for pipe_name, pipe in self._components:
- if hasattr(pipe, "save_extra_data") and pipe_name not in extra_exclude:
- pipe.save_extra_data(path / pipe_name, exclude=extra_exclude)
- if "tensors" not in exclude:
- save_tensors(path / "tensors")
+ pwd = os.getcwd()
+ overrides = {"components": {}}
+ try:
+ os.chdir(path)
+ for pipe_name, pipe in self._components:
+ if hasattr(pipe, "to_disk") and pipe_name not in exclude:
+ pipe_overrides = pipe.to_disk(Path(pipe_name), exclude=exclude)
+ overrides["components"][pipe_name] = pipe_overrides
+ finally:
+ os.chdir(pwd)
+
+ config = self.config.merge(overrides)
+
+ if "config" not in exclude:
+ config.to_disk(path / "config.cfg")
+
+ save = to_disk
- def load_state_from_disk(
+ def from_disk(
self,
path: Union[str, Path],
*,
- exclude: Set[str] = None,
+ exclude: Optional[Union[str, Sequence[str]]] = None,
device: Optional[Union[str, "torch.device"]] = "cpu", # noqa F821
) -> "Pipeline":
"""
@@ -711,77 +772,47 @@ def load_state_from_disk(
----------
path: Union[str, Path]
The path to the directory to load the pipeline from
- exclude: Set[str]
+ exclude: Optional[Union[str, Sequence[str]]]
The names of the components, or attributes to exclude from the loading
- process. This list will be gradually filled in place as components are
- loaded
+ process.
+ device: Optional[Union[str, "torch.device"]]
+ Device to use when loading the tensors
"""
def deserialize_meta(path: Path) -> None:
if path.exists():
- data = json.loads(path.read_text())
+ with open(path, "r") as f:
+ data = json.load(f)
self.meta.update(data)
+
+ exclude = (
+ set()
+ if exclude is None
+ else {exclude}
+ if isinstance(exclude, str)
+ else set(exclude)
+ )
- def deserialize_tensors(path: Path):
- import safetensors.torch
-
- trainable_components = dict(self.trainable_pipes())
- for file_name in path.iterdir():
- pipe_names = file_name.stem.split("+")
- if any(pipe_name in trainable_components for pipe_name in pipe_names):
- # We only load tensors in one of the pipes since parameters
- # are expected to be shared
- pipe = trainable_components[pipe_names[0]]
- tensor_dict = {}
- for keys, tensor in safetensors.torch.load_file(
- file_name, device=device
- ).items():
- split_keys = [split_path(key) for key in keys.split("+")]
- key = next(key for key in split_keys if key[0] == pipe_names[0])
- tensor_dict[join_path(key[1:])] = tensor
- # Non-strict because tensors of a given pipeline can be shared
- # between multiple files
- print(f"Loading tensors of {pipe_names[0]} from {file_name}")
- extra_tensors = set(tensor_dict) - set(
- pipe.state_dict(keep_vars=True).keys()
- )
- if extra_tensors:
- warnings.warn(
- f"{file_name} contains tensors that are not in the state"
- f"dict of {pipe_names[0]}: {sorted(extra_tensors)}"
- )
- pipe.load_state_dict(tensor_dict, strict=False)
-
- exclude = set() if exclude is None else exclude
-
+ path = (Path(path) if isinstance(path, str) else path).absolute()
if "meta" not in exclude:
deserialize_meta(path / "meta.json")
- extra_exclude = set(exclude)
- for name, proc in self._components:
- if hasattr(proc, "load_extra_data") and name not in extra_exclude:
- proc.load_extra_data(path / name, extra_exclude)
-
- if "tensors" not in exclude:
- deserialize_tensors(path / "tensors")
+ pwd = os.getcwd()
+ try:
+ os.chdir(path)
+ for name, proc in self._components:
+ if hasattr(proc, "from_disk") and name not in exclude:
+ proc.from_disk(Path(name), exclude=exclude)
+ # Convert to list here in case exclude is (default) tuple
+ exclude.add(name)
+ finally:
+ os.chdir(pwd)
self._path = path # type: ignore[assignment]
return self
- @classmethod
- def load(
- cls,
- path: Union[str, Path],
- *,
- exclude: Optional[Set[str]] = None,
- device: Optional[Union[str, "torch.device"]] = "cpu", # noqa F821
- ):
- path = Path(path) if isinstance(path, str) else path
- config = Config.from_disk(path / "config.cfg")
- self = Pipeline.from_config(config)
- self.load_state_from_disk(path, exclude=exclude, device=device)
- return self
-
# override config property getter to remove "factory" key from components
@property
def cfg(self) -> Config:
@@ -830,12 +861,10 @@ def __exit__(ctx_self, type, value, traceback):
if enable is None and disable is None:
raise ValueError("Expected either `enable` or `disable`")
- if isinstance(disable, str):
- disable = [disable]
+ disable = [disable] if isinstance(disable, str) else disable
pipe_names = set(self.pipe_names)
if enable is not None:
- if isinstance(enable, str):
- enable = [enable]
+ enable = [enable] if isinstance(enable, str) else enable
if set(enable) - pipe_names:
raise ValueError(
"Enabled pipes {} not found in pipeline.".format(
@@ -863,24 +892,26 @@ def package(
self,
name: Optional[str] = None,
root_dir: Union[str, Path] = ".",
+ build_dir: Union[str, Path] = "build",
+ dist_dir: Union[str, Path] = "dist",
artifacts_name: str = "artifacts",
- check_dependencies: bool = False,
project_type: Optional[Literal["poetry", "setuptools"]] = None,
- version: str = "0.1.0",
+ version: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = {},
distributions: Optional[Sequence[Literal["wheel", "sdist"]]] = ["wheel"],
config_settings: Optional[Mapping[str, Union[str, Sequence[str]]]] = None,
isolation: bool = True,
skip_build_dependency_check: bool = False,
):
- from .utils.package import package
+ from edspdf.utils.package import package
return package(
pipeline=self,
name=name,
root_dir=root_dir,
+ build_dir=build_dir,
+ dist_dir=dist_dir,
artifacts_name=artifacts_name,
- check_dependencies=check_dependencies,
project_type=project_type,
version=version,
metadata=metadata,
@@ -892,19 +923,99 @@ def package(
def load(
- config: Union[Path, str, Config],
- device: Optional[Union[str, "torch.device"]] = "cpu", # noqa F821
-) -> Pipeline:
- error = "The load function expects a Config or a path to a config file"
- if isinstance(config, (Path, str)):
- path = Path(config)
- if path.is_dir():
- return Pipeline.load(path, device=device)
- elif path.is_file():
- config = Config.from_disk(path)
- else:
- raise ValueError(error)
- elif not isinstance(config, Config):
+ model: Union[Path, str, Config],
+ overrides: Optional[Dict[str, Any]] = None,
+ *,
+ exclude: Optional[Union[str, Iterable[str]]] = None,
+ device: Optional[Union[str, "torch.device"]] = "cpu",
+):
+ """
+ Load a pipeline from a config file or a directory.
+
+ Examples
+ --------
+
+ ```{ .python .no-check }
+ import edspdf
+
+ nlp = edspdf.load(
+ "path/to/config.cfg",
+ overrides={"components": {"my_component": {"arg": "value"}}},
+ )
+ ```
+
+ Parameters
+ ----------
+ model: Union[Path, str, Config]
+ The config to use for the pipeline, the path to a config file or a trained
+ model directory, or the name of an installed pipeline package.
+ overrides: Optional[Dict[str, Any]]
+ Overrides to apply to the config when loading the pipeline. These are the
+ same parameters as the ones used when initializing the pipeline.
+ exclude: Optional[Union[str, Iterable[str]]]
+ The names of the components, or attributes to exclude from the loading
+ process. :warning: The `exclude` argument will be mutated in place.
+ device: Optional[Union[str, "torch.device"]]
+ Device to use when loading the tensors
+
+ Returns
+ -------
+ Pipeline
+ """
+ error = (
+ "The load function expects either :\n"
+ "- a confit Config object\n"
+ "- the path of a config file (.cfg file)\n"
+ "- the path of a trained model\n"
+ "- the name of an installed pipeline package\n"
+ f"but got {model!r} which is neither"
+ )
+ if isinstance(model, (Path, str)):
+ path = Path(model)
+ is_dir = path.is_dir()
+ is_config = path.is_file() and path.suffix == ".cfg"
+ try:
+ module = importlib.import_module(model)
+ is_package = True
+ except (ImportError, AttributeError, TypeError):
+ module = None
+ is_package = False
+ if is_dir and is_package:
+ warnings.warn(
+ "The path provided is both a directory and a package : edspdf will "
+ "load the package. To load from the directory instead, please pass the "
+ f'path as "./{path}" instead.'
+ )
+ if is_dir:
+ path = (Path(path) if isinstance(path, str) else path).absolute()
+ config = Config.from_disk(path / "config.cfg")
+ if overrides:
+ config = config.merge(overrides)
+ pwd = os.getcwd()
+ try:
+ os.chdir(path)
+ nlp = Pipeline.from_config(config)
+ nlp.from_disk(path, exclude=exclude, device=device)
+ finally:
+ os.chdir(pwd)
+ return nlp
+ elif is_config:
+ model = Config.from_disk(path)
+ elif is_package:
+ # Load as package
+ available_kwargs = {
+ "overrides": overrides,
+ "exclude": exclude,
+ "device": device,
+ }
+ signature_kwargs = inspect.signature(module.load).parameters
+ kwargs = {
+ name: available_kwargs[name]
+ for name in signature_kwargs
+ if name in available_kwargs
+ }
+ return module.load(**kwargs)
+
+ if not isinstance(model, Config):
raise ValueError(error)
- return Pipeline.from_config(config)
+ return Pipeline.from_config(model)
diff --git a/edspdf/pipes/aggregators/simple.py b/edspdf/pipes/aggregators/simple.py
index 2e46ae57..f5682aac 100644
--- a/edspdf/pipes/aggregators/simple.py
+++ b/edspdf/pipes/aggregators/simple.py
@@ -1,5 +1,5 @@
-from itertools import groupby
-from typing import Dict, List
+from collections import defaultdict
+from typing import Dict, List, Union
import numpy as np
@@ -45,8 +45,13 @@ class SimpleAggregator:
@factory = "simple-aggregator"
new_line_threshold = 0.2
new_paragraph_threshold = 1.5
- label_map = { body = "text", table = "text" }
-
+ # To build the "text" label, we will aggregate lines from
+ # "title", "body" and "table" and output "title" lines in a
+ # separate field "title" as well.
+ label_map = {
+ "text" : [ "title", "body", "table" ],
+ "title" : "title",
+ }
...
```
@@ -77,8 +82,9 @@ class SimpleAggregator:
lines to consider them as being on separate paragraphs and thus add a
newline character between them.
label_map: Dict
- A dictionary mapping labels to new labels. This is useful to group labels
- together, for instance, to output both "body" and "table" as "text".
+ A dictionary mapping from new labels to old labels.
+ This is useful to group labels together, for instance, to output both "body"
+ and "table" as "text".
"""
def __init__(
@@ -88,11 +94,14 @@ def __init__(
sort: bool = False,
new_line_threshold: float = 0.2,
new_paragraph_threshold: float = 1.5,
- label_map: Dict = {},
+ label_map: Dict[str, Union[str, List[str]]] = {},
) -> None:
self.name = name
self.sort = sort
- self.label_map = dict(label_map)
+ self.label_map = {
+ label: [old_labels] if not isinstance(old_labels, list) else old_labels
+ for label, old_labels in label_map.items()
+ }
self.new_line_threshold = new_line_threshold
self.new_paragraph_threshold = new_paragraph_threshold
@@ -107,12 +116,22 @@ def __call__(self, doc: PDFDoc) -> PDFDoc:
all_lines,
key=lambda b: (b.label, b.page_num, b.y1 // row_height, b.x0),
)
- else:
- all_lines = sorted(all_lines, key=lambda b: b.label)
texts = {}
styles = {}
- for label, lines in groupby(all_lines, key=lambda b: b.label):
+
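+ # Invert label_map (new label -> old labels) into old label -> new labels, so
+ # that each line can contribute to every aggregated output it belongs to.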
+ inv_label_map = defaultdict(list)
+ for new_label, old_labels in self.label_map.items():
+ for old_label in old_labels:
+ inv_label_map[old_label].append(new_label)
+
+ lines_per_label = defaultdict(list)
+ lines_per_label.update({k: [] for k in self.label_map})
+ for line in all_lines:
+ for new_label in inv_label_map.get(line.label, [line.label]):
+ lines_per_label[new_label].append(line)
+
+ for label, lines in lines_per_label.items():
styles[label] = []
text = ""
lines: List[TextBox] = list(lines)
diff --git a/edspdf/pipes/classifiers/trainable.py b/edspdf/pipes/classifiers/trainable.py
index f879af48..1f71e79a 100644
--- a/edspdf/pipes/classifiers/trainable.py
+++ b/edspdf/pipes/classifiers/trainable.py
@@ -155,9 +155,9 @@ def preprocess_supervised(self, doc: PDFDoc) -> Dict[str, Any]:
],
}
- def collate(self, batch, device: torch.device) -> Dict:
+ def collate(self, batch) -> Dict:
collated = {
- "embedding": self.embedding.collate(batch["embedding"], device),
+ "embedding": self.embedding.collate(batch["embedding"]),
"doc_id": batch["doc_id"],
}
if "labels" in batch:
@@ -167,7 +167,6 @@ def collate(self, batch, device: torch.device) -> Dict:
batch["labels"],
data_dims=("line",),
full_names=("sample", "page", "line"),
- device=device,
dtype=torch.long,
),
}
@@ -182,7 +181,9 @@ def forward(self, batch: Dict) -> Dict:
output = {"loss": 0, "mask": embeddings.mask}
# Label prediction / learning
- logits = self.classifier(embeddings).refold("line")
+ logits = self.classifier(embeddings.to(self.classifier.weight.dtype)).refold(
+ "line"
+ )
if "labels" in batch:
targets = batch["labels"].refold(logits.data_dims)
output["label_loss"] = (
@@ -208,28 +209,28 @@ def postprocess(self, docs: Sequence[PDFDoc], output: Dict) -> Sequence[PDFDoc]:
b.label = self.label_voc.decode(label) if b.text != "" else None
return docs
- def save_extra_data(self, path: Path, exclude: Set):
+ def to_disk(self, path: Path, exclude: Set):
if self.name in exclude:
return
exclude.add(self.name)
- self.embedding.save_extra_data(path / "embedding", exclude)
-
os.makedirs(path, exist_ok=True)
with (path / "label_voc.json").open("w") as f:
json.dump(self.label_voc.indices, f)
- def load_extra_data(self, path: Path, exclude: Set):
+ return super().to_disk(path, exclude)
+
+ def from_disk(self, path: Path, exclude: Set):
if self.name in exclude:
return
exclude.add(self.name)
- self.embedding.load_extra_data(path / "embedding", exclude)
-
label_voc_indices = dict(self.label_voc.indices)
with (path / "label_voc.json").open("r") as f:
self.label_voc.indices = json.load(f)
self.update_weights_from_vocab_(label_voc_indices)
+
+ super().from_disk(path, exclude)
diff --git a/edspdf/pipes/embeddings/box_layout_embedding.py b/edspdf/pipes/embeddings/box_layout_embedding.py
index 0bd8c29f..e27661a4 100644
--- a/edspdf/pipes/embeddings/box_layout_embedding.py
+++ b/edspdf/pipes/embeddings/box_layout_embedding.py
@@ -9,7 +9,7 @@
BoxLayoutPreprocessor,
)
from edspdf.registry import registry
-from edspdf.trainable_pipe import NestedSequences, TrainablePipe
+from edspdf.trainable_pipe import TrainablePipe
@registry.factory.register("box-layout-embedding")
@@ -71,8 +71,8 @@ def __init__(
def preprocess(self, doc):
return self.box_preprocessor.preprocess(doc)
- def collate(self, batch: NestedSequences, device: torch.device) -> BoxLayoutBatch:
- return self.box_preprocessor.collate(batch, device)
+ def collate(self, batch) -> BoxLayoutBatch:
+ return self.box_preprocessor.collate(batch)
@classmethod
def _make_embed(cls, n_positions, size, mode):
diff --git a/edspdf/pipes/embeddings/box_layout_preprocessor.py b/edspdf/pipes/embeddings/box_layout_preprocessor.py
index ae9715f3..10559f6f 100644
--- a/edspdf/pipes/embeddings/box_layout_preprocessor.py
+++ b/edspdf/pipes/embeddings/box_layout_preprocessor.py
@@ -74,11 +74,10 @@ def preprocess(self, doc: PDFDoc, supervision: bool = False):
"last_page": [[b.page_num == last_p for b in p.text_boxes] for p in pages],
}
- def collate(self, batch, device: torch.device) -> BoxLayoutBatch:
+ def collate(self, batch) -> BoxLayoutBatch:
kw = {
"full_names": ["sample", "page", "line"],
"data_dims": ["line"],
- "device": device,
}
return {
diff --git a/edspdf/pipes/embeddings/huggingface_embedding.py b/edspdf/pipes/embeddings/huggingface_embedding.py
index 16ccf076..58a3e3ad 100644
--- a/edspdf/pipes/embeddings/huggingface_embedding.py
+++ b/edspdf/pipes/embeddings/huggingface_embedding.py
@@ -1,8 +1,12 @@
import math
+import sys
+from typing import Optional, Set
import torch
+from confit import validate_arguments
from foldedtensor import as_folded_tensor
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
+from transformers import BitsAndBytesConfig as BitsAndBytesConfig_
from typing_extensions import Literal
from edspdf import TrainablePipe, registry
@@ -10,6 +14,8 @@
from edspdf.pipes.embeddings import EmbeddingOutput
from edspdf.structures import PDFDoc
+BitsAndBytesConfig = validate_arguments(BitsAndBytesConfig_)
+
def compute_contextualization_scores(windows):
ramp = torch.arange(0, windows.shape[1], 1)
@@ -108,6 +114,11 @@ class HuggingfaceEmbedding(TrainablePipe[EmbeddingOutput]):
The maximum number of tokens that can be processed by the model on a single
device. This does not affect the results but can be used to reduce the memory
usage of the model, at the cost of a longer processing time.
+ quantization_config: Optional[BitsAndBytesConfig]
+ The quantization configuration to use when loading the model
+ kwargs:
+ Additional keyword arguments to pass to the Huggingface
+ `AutoModel.from_pretrained` method
"""
def __init__(
@@ -119,7 +130,9 @@ def __init__(
window: int = 510,
stride: int = 255,
line_pooling: Literal["mean", "max", "sum"] = "mean",
- max_tokens_per_device: int = 128 * 128,
+ max_tokens_per_device: int = sys.maxsize,
+ quantization_config: Optional[BitsAndBytesConfig] = None,
+ **kwargs,
):
super().__init__(pipeline, name)
self.use_image = use_image
@@ -129,7 +142,11 @@ def __init__(
else None
)
self.tokenizer = AutoTokenizer.from_pretrained(model)
- self.hf_model = AutoModel.from_pretrained(model)
+ self.hf_model = AutoModel.from_pretrained(
+ model,
+ quantization_config=quantization_config,
+ **kwargs,
+ )
self.output_size = self.hf_model.config.hidden_size
self.window = window
self.stride = stride
@@ -148,14 +165,21 @@ def preprocess(self, doc: PDFDoc):
for page in doc.pages:
# Preprocess it using LayoutLMv3
+ width = page.width
+ height = page.height
+
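+ # Rescale the page dimensions so that both sides fit within 1000 units while
+ # preserving the aspect ratio: LayoutLM-style models expect bounding boxes in
+ # the 0-1000 range, and larger values would make the model crash.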
+ ratio = width / height
+ width, height = (1000, 1000 / ratio) if width > 1000 else (width, height)
+ width, height = (1000 * ratio, 1000) if height > 1000 else (width, height)
+
prep = self.tokenizer(
text=[line.text for line in page.text_boxes],
boxes=[
(
- int(line.x0 * line.page.width),
- int(line.y0 * line.page.height),
- int(line.x1 * line.page.width),
- int(line.y1 * line.page.height),
+ int(line.x0 * width),
+ int(line.y0 * height),
+ int(line.x1 * width),
+ int(line.y1 * height),
)
for line in page.text_boxes
],
@@ -181,7 +205,7 @@ def preprocess(self, doc: PDFDoc):
return res
- def collate(self, batch, device):
+ def collate(self, batch):
# Flatten most of these arrays to process batches page per page and
# not sample per sample
@@ -214,7 +238,9 @@ def collate(self, batch, device):
data_dims=("window", "token"),
dtype=torch.long,
)
- indexer = torch.zeros(windows.max() + 1, dtype=torch.long)
+ indexer = torch.zeros(
+ (windows.max() + 1) if windows.numel() else 0, dtype=torch.long
+ )
# Sort each occurrence of an initial token by its contextualization score:
# We can only use the amax reduction, so to retrieve the best occurrence, we
@@ -262,7 +288,7 @@ def collate(self, batch, device):
data_dims=("token",),
dtype=torch.long,
)
- last_after_one = max(1, len(line_window_offsets_flat) - 1)
+ last_after_one = max(0, len(line_window_offsets_flat) - 1)
line_window_offsets_flat = as_folded_tensor(
# discard the last offset, since we start from 0 and add each line length
data=torch.as_tensor(line_window_offsets_flat[:last_after_one]),
@@ -270,42 +296,48 @@ def collate(self, batch, device):
full_names=("sample", "page", "line"),
lengths=line_window_indices.lengths[:-1],
)
-
kw = dict(
full_names=("sample", "page", "subword"),
data_dims=("subword",),
- device=device,
)
collated = {
"input_ids": as_folded_tensor(batch["input_ids"], **kw, dtype=torch.long),
"bbox": as_folded_tensor(batch["bbox"], **kw, dtype=torch.long),
- "windows": windows.to(device),
- "indexer": indexer[line_window_indices].to(device),
- "line_window_indices": indexer[line_window_indices].as_tensor().to(device),
- "line_window_offsets_flat": line_window_offsets_flat.to(device),
+ "windows": windows,
+ "indexer": indexer[line_window_indices],
+ "line_window_indices": indexer[line_window_indices].as_tensor(),
+ "line_window_offsets_flat": line_window_offsets_flat,
}
if self.use_image:
- collated["pixel_values"] = (
- torch.stack(
- [
- torch.from_numpy(page_pixels)
- for sample_pages in batch["pixel_values"]
- for page_pixels in sample_pages
- ],
- dim=0,
- )
- .repeat_interleave(torch.as_tensor(windows_count_per_page), dim=0)
- .to(device)
+ collated["pixel_values"] = torch.as_tensor(
+ [
+ page_pixels
+ for sample_pages in batch["pixel_values"]
+ for page_pixels in sample_pages
+ ],
+ ).repeat_interleave(
+ torch.as_tensor(windows_count_per_page, dtype=torch.long), dim=0
)
return collated
def forward(self, batch):
+ if 0 in batch["input_ids"].shape:
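+ # Empty batch (e.g. a PDF without any text box): return an empty, correctly
+ # folded embeddings tensor instead of running the transformer.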
+ return {
+ "embeddings": batch["line_window_offsets_flat"].view(
+ *batch["line_window_offsets_flat"].shape, self.output_size
+ ),
+ }
+
windows = batch["windows"]
kwargs = dict(
input_ids=batch["input_ids"].as_tensor()[windows],
bbox=batch["bbox"].as_tensor()[windows],
attention_mask=windows.mask,
- pixel_values=batch.get("pixel_values"),
+ pixel_values=(
+ batch.get("pixel_values").to(next(self.parameters()).dtype)
+ if self.use_image
+ else None
+ ),
)
num_windows_per_batch = self.max_tokens_per_device // (
windows.shape[1]
@@ -340,3 +372,16 @@ def forward(self, batch):
mode=self.line_pooling,
)
return {"embeddings": line_embedding}
+
+ def to_disk(self, path, *, exclude: Optional[Set[str]]):
+ repr_id = object.__repr__(self)
+ if repr_id in exclude:
+ return
+ for obj in (self.tokenizer, self.image_processor, self.hf_model):
+ if obj is not None:
+ obj.save_pretrained(path)
+ for param in self.hf_model.parameters():
+ exclude.add(object.__repr__(param))
+ cfg = super().to_disk(path, exclude=exclude) or {}
+ cfg["model"] = f"./{path.as_posix()}"
+ return cfg
diff --git a/edspdf/pipes/embeddings/simple_text_embedding.py b/edspdf/pipes/embeddings/simple_text_embedding.py
index cc2d8691..f7674102 100644
--- a/edspdf/pipes/embeddings/simple_text_embedding.py
+++ b/edspdf/pipes/embeddings/simple_text_embedding.py
@@ -154,7 +154,7 @@ def post_init(self, gold_data, exclude: set):
self.update_weights_from_vocab_(vocab_items_before)
- def save_extra_data(self, path: Path, exclude: Set):
+ def to_disk(self, path: Path, exclude: Set):
if self.name in exclude:
return
@@ -169,7 +169,9 @@ def save_extra_data(self, path: Path, exclude: Set):
with (path / "norm_voc.json").open("w") as f:
json.dump(self.norm_voc.indices, f)
- def load_extra_data(self, path: Path, exclude: Set):
+ return super().to_disk(path, exclude)
+
+ def from_disk(self, path: Path, exclude: Set):
if self.name in exclude:
return
@@ -191,6 +193,8 @@ def load_extra_data(self, path: Path, exclude: Set):
self.update_weights_from_vocab_(vocab_items_before)
+ super().from_disk(path, exclude)
+
def preprocess(self, doc: PDFDoc):
tokens_shape = []
tokens_prefix = []
@@ -228,10 +232,9 @@ def preprocess(self, doc: PDFDoc):
"tokens_norm": tokens_norm,
}
- def collate(self, batch, device: torch.device) -> BoxTextEmbeddingInputBatch:
+ def collate(self, batch) -> BoxTextEmbeddingInputBatch:
kwargs = dict(
dtype=torch.long,
- device=device,
data_dims=("word",),
full_names=(
"sample",
diff --git a/edspdf/pipes/embeddings/sub_box_cnn_pooler.py b/edspdf/pipes/embeddings/sub_box_cnn_pooler.py
index a917281e..03edbbb9 100644
--- a/edspdf/pipes/embeddings/sub_box_cnn_pooler.py
+++ b/edspdf/pipes/embeddings/sub_box_cnn_pooler.py
@@ -74,6 +74,15 @@ def forward(self, batch: Any) -> EmbeddingOutput:
embeddings = self.embedding.module_forward(batch["embedding"])[
"embeddings"
].refold("line", "word")
+ if 0 in embeddings.shape:
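+ # No lines or words in the batch: return an empty folded tensor with the
+ # expected output size instead of running the CNN pooler on empty input.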
+ return {
+ "embeddings": as_folded_tensor(
+ data=torch.zeros(0, self.output_size, device=embeddings.device),
+ lengths=embeddings.lengths[:-1], # pooled on the last dim
+ data_dims=["line"], # fully flattened
+ full_names=["sample", "page", "line"],
+ )
+ }
# sample word dim -> sample dim word
box_token_embeddings = embeddings.as_tensor().permute(0, 2, 1)
diff --git a/edspdf/pipes/extractors/pdfminer.py b/edspdf/pipes/extractors/pdfminer.py
index 3e645398..b822adf7 100644
--- a/edspdf/pipes/extractors/pdfminer.py
+++ b/edspdf/pipes/extractors/pdfminer.py
@@ -185,12 +185,10 @@ def __call__(self, doc: Union[PDFDoc, bytes]) -> PDFDoc:
if self.render_pages:
# See https://pypdfium2.readthedocs.io/en/stable/python_api.html#user-unit
- images = pypdfium2.PdfDocument(content).render_topil(
- scale=self.render_dpi / 72
- )
- for page, image in zip(pages, images):
+ pdfium_doc = pypdfium2.PdfDocument(content)
+ for page, pdfium_page in zip(pages, pdfium_doc):
+ image = pdfium_page.render(scale=self.render_dpi / 72).to_pil()
np_img = np.array(image)
- print("NP IMG", np_img.shape)
page.image = np_img
return doc
diff --git a/edspdf/processing/__init__.py b/edspdf/processing/__init__.py
new file mode 100644
index 00000000..65a80d6e
--- /dev/null
+++ b/edspdf/processing/__init__.py
@@ -0,0 +1,9 @@
+from typing import TYPE_CHECKING
+
+from edspdf.utils.lazy_module import lazify
+
+lazify()
+
+if TYPE_CHECKING:
+ from .simple import execute_simple_backend
+ from .multiprocessing import execute_multiprocessing_backend
diff --git a/edspdf/processing/multiprocessing.py b/edspdf/processing/multiprocessing.py
new file mode 100644
index 00000000..0fbbe129
--- /dev/null
+++ b/edspdf/processing/multiprocessing.py
@@ -0,0 +1,933 @@
+from __future__ import annotations
+
+import copyreg
+import gc
+import io
+import logging
+import multiprocessing
+import multiprocessing.reduction
+import os
+import sys
+import tempfile
+import warnings
+from contextlib import nullcontext
+from multiprocessing.connection import wait
+from random import shuffle
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
+
+import dill
+from typing_extensions import TypedDict
+
+from edspdf.lazy_collection import LazyCollection
+from edspdf.utils.collections import batchify, flatten
+
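+# Size functions used to measure batches (per the `batch_by` setting) and documents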
+batch_size_fns = {
+ "content_boxes": lambda batch: sum(len(doc.content_boxes) for doc in batch),
+ "pages": lambda batch: sum(len(doc.pages) for doc in batch),
+ "docs": len,
+}
+
+doc_size_fns = {
+ "content_boxes": lambda doc: len(doc.content_boxes),
+}
+
+if TYPE_CHECKING:
+ import torch
+
+ from edspdf.trainable_pipe import TrainablePipe
+
+Stage = TypedDict(
+ "Stage",
+ {
+ "cpu_components": List[Tuple[str, Callable, Dict]],
+ "gpu_component": Optional[Any],
+ },
+)
+
+
+def apply_basic_pipes(docs, pipes):
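+ # Pipes exposing `batch_process` are applied to the whole batch at once; the
+ # others are applied document by document with their keyword arguments.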
+ for name, pipe, kwargs in pipes:
+ if hasattr(pipe, "batch_process"):
+ docs = pipe.batch_process(docs)
+ else:
+ docs = [pipe(doc, **kwargs) for doc in docs]
+ return docs
+
+
+class ForkingPickler(dill.Pickler):
+ """
+ ForkingPickler that uses dill instead of pickle to transfer objects between
+ processes.
+ """
+
+ _extra_reducers = {}
+ _copyreg_dispatch_table = copyreg.dispatch_table
+
+ def __new__(cls, *args, **kwargs):
+ result = dill.Pickler.__new__(ForkingPickler)
+ # Python would not call __init__ if the original
+ # multiprocessing.reduction.ForkingPickler called, leading to a call to this
+ # monkey-patched __new__ method, because [original cls] != [type of result]
+ # (see https://docs.python.org/3/reference/datamodel.html#basic-customization)
+ # so we force the call to __init__ here
+ if not isinstance(result, cls):
+ result.__init__(*args, **kwargs)
+ return result
+
+ def __init__(self, *args, **kwds):
+ super().__init__(*args, **kwds)
+ self.dispatch_table = self._copyreg_dispatch_table.copy()
+ self.dispatch_table.update(self._extra_reducers)
+
+ @classmethod
+ def register(cls, type, reduce):
+ """Register a reduce function for a type."""
+ cls._extra_reducers[type] = reduce
+
+ @classmethod
+ def dumps(cls, obj, protocol=None, *args, **kwds):
+ buf = io.BytesIO()
+ cls(buf, protocol, *args, **kwds).dump(obj)
+ return buf.getbuffer()
+
+ loads = dill.loads
+
+
+def replace_pickler():
+ """
+ Replace the default pickler used by multiprocessing with dill.
+ "multiprocess" didn't work for obscure reasons (maybe the reducers / dispatchers
+ are not propagated between multiprocessing and multiprocess => torch specific
+ reducers might be missing ?), so this patches multiprocessing directly.
+ directly.
+
+ For some reason I do not explain, this has a massive impact on the performance of
+ the multiprocessing backend. With the original pickler, the performance can be
+ up to 2x slower than with our custom one.
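+
+    Example (illustrative):
+
+        revert = replace_pickler()
+        try:
+            ...  # spawn worker processes that exchange dill-pickled objects
+        finally:
+            revert()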
+ """
+ old_pickler = multiprocessing.reduction.ForkingPickler
+
+ before = (
+ dict(ForkingPickler._extra_reducers),
+ old_pickler.__new__,
+ old_pickler.dumps,
+ old_pickler.loads,
+ old_pickler.register,
+ )
+
+ old_pickler.__new__ = ForkingPickler.__new__
+ old_pickler.dumps = ForkingPickler.dumps
+ old_pickler.loads = ForkingPickler.loads
+ old_pickler.register = ForkingPickler.register
+ ForkingPickler._extra_reducers.update(
+ multiprocessing.reduction.ForkingPickler._extra_reducers
+ )
+
+ def revert():
+ (
+ ForkingPickler._extra_reducers,
+ old_pickler.__new__,
+ old_pickler.dumps,
+ old_pickler.loads,
+ old_pickler.register,
+ ) = before
+
+ return revert
+
+
+# Should we check if the multiprocessing module of edspdf
+# is responsible for this child process before replacing the pickler ?
+if (
+    multiprocessing.current_process().name != "MainProcess"
+ or hasattr(multiprocessing, "parent_process")
+ and multiprocessing.parent_process() is not None
+):
+ replace_pickler()
+
+DEBUG = True
+
+debug = (
+ (lambda *args, flush=False, **kwargs: print(*args, **kwargs, flush=True))
+ if DEBUG
+ else lambda *args, **kwargs: None
+)
+
+try: # pragma: no cover
+ import torch
+
+ # Torch may still be imported as a namespace package, so we can access the
+ # torch.save and torch.load functions
+ torch_save = torch.save
+ torch_load = torch.load
+
+ MAP_LOCATION = None
+
+ try:
+ from accelerate.hooks import AlignDevicesHook
+
+ # We need to replace the "execution_device" attribute of the AlignDevicesHook
+ # using map_location when unpickling the lazy collection
+
+ def save_align_devices_hook(pickler: Any, obj: Any):
+ pickler.save_reduce(load_align_devices_hook, (obj.__dict__,), obj=obj)
+
+ def load_align_devices_hook(state):
+ state["execution_device"] = MAP_LOCATION
+ new_obj = AlignDevicesHook.__new__(AlignDevicesHook)
+ new_obj.__dict__.update(state)
+ return new_obj
+
+ except ImportError:
+ AlignDevicesHook = None
+
+ def dump(*args, **kwargs):
+ # We need to replace the "execution_device" attribute of the AlignDevicesHook
+ # using map_location when pickling the lazy collection
+ old = None
+ try:
+ if AlignDevicesHook is not None:
+ old = dill.Pickler.dispatch.get(AlignDevicesHook)
+ dill.Pickler.dispatch[AlignDevicesHook] = save_align_devices_hook
+ dill.settings["recurse"] = True
+ return torch_save(*args, pickle_module=dill, **kwargs)
+ finally:
+ dill.settings["recurse"] = False
+ if AlignDevicesHook is not None:
+ del dill.Pickler.dispatch[AlignDevicesHook]
+ if old is not None: # pragma: no cover
+ dill.Pickler.dispatch[AlignDevicesHook] = old
+
+ def load(*args, map_location=None, **kwargs):
+ global MAP_LOCATION
+ MAP_LOCATION = map_location
+ if torch.__version__ >= "2.1" and isinstance(args[0], str):
+ kwargs["mmap"] = True
+ # with open(args[0], "rb") as f:
+ # result = dill.load(f, **kwargs)
+ try:
+ if torch.__version__ < "2.0.0":
+ pickle = torch_load.__globals__["pickle"]
+ torch_load.__globals__["pickle"] = dill
+ result = torch_load(
+ *args,
+ pickle_module=dill,
+ map_location=map_location,
+ **kwargs,
+ )
+ finally:
+ import pickle
+
+ torch_load.__globals__["pickle"] = pickle
+ MAP_LOCATION = None
+ return result
+
+except (ImportError, AttributeError): # pragma: no cover
+
+ def load(file, *args, map_location=None, **kwargs):
+ # check if path
+ if isinstance(file, str):
+ with open(file, "rb") as f:
+ return dill.load(f, *args, **kwargs)
+ return dill.load(file, *args, **kwargs)
+
+ dump = dill.dump
+
+
+class Exchanger:
+ def __init__(
+ self,
+ mp: multiprocessing.context.BaseContext,
+ num_stages: int,
+ num_gpu_workers: int,
+ num_cpu_workers: int,
+ gpu_worker_devices: List[Any],
+ ):
+ # queue for cpu input tasks
+ self.gpu_worker_devices = gpu_worker_devices
+ self.num_cpu_workers = num_cpu_workers
+ self.num_gpu_workers = num_gpu_workers
+        # We add a prioritized queue at the end for STOP signals
+ self.cpu_inputs_queues = [
+ [mp.Queue()] + [mp.SimpleQueue() for _ in range(num_stages + 1)]
+ # The input queue is not shared between processes, since calling `wait`
+ # on a queue reader from multiple processes may lead to a deadlock
+ for _ in range(num_cpu_workers)
+ ]
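+        # For each CPU worker: queue 0 receives new tasks from the main
+        # process, queues 1..num_stages receive results sent back by the GPU
+        # workers, and the last queue only carries prioritized STOP signals
+        # (get_cpu_task scans the queues in reverse order).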
+ self.gpu_inputs_queues = [
+ [mp.Queue() for _ in range(num_stages + 1)] for _ in range(num_gpu_workers)
+ ]
+ self.outputs_queue = mp.Queue()
+ self.num_stages = num_stages
+
+ # noinspection PyUnresolvedReferences
+ def get_cpu_task(self, idx):
+ queue_readers = wait([queue._reader for queue in self.cpu_inputs_queues[idx]])
+ stage, queue = next(
+ (stage, q)
+ for stage, q in reversed(list(enumerate(self.cpu_inputs_queues[idx])))
+ if q._reader in queue_readers
+ )
+ item = queue.get()
+ return stage, item
+
+ def put_cpu(self, item, stage, idx):
+ return self.cpu_inputs_queues[idx][stage].put(item)
+
+ def get_gpu_task(self, idx):
+ queue_readers = wait([queue._reader for queue in self.gpu_inputs_queues[idx]])
+ stage, queue = next(
+ (stage, q)
+ for stage, q in reversed(list(enumerate(self.gpu_inputs_queues[idx])))
+ if q._reader in queue_readers
+ )
+ item = queue.get()
+ return stage, item
+
+ def put_gpu(self, item, stage, idx):
+ return self.gpu_inputs_queues[idx][stage].put(item)
+
+ def put_results(self, items):
+ self.outputs_queue.put(items)
+
+ def iter_results(self):
+ for out in iter(self.outputs_queue.get, None):
+ yield out
+
+
+class CPUWorker:
+ def __init__(
+ self,
+ cpu_idx: int,
+ exchanger: Exchanger,
+ gpu_pipe_names: List[str],
+ lazy_collection_path: str,
+ device: Union[str, "torch.device"],
+ ):
+ super(CPUWorker, self).__init__()
+
+ self.cpu_idx = cpu_idx
+ self.exchanger = exchanger
+ self.gpu_pipe_names = gpu_pipe_names
+ self.lazy_collection_path = lazy_collection_path
+ self.device = device
+
+ def run(self):
+        # We cannot pass torch tensors during init, otherwise we get:
+ # ValueError: bad value(s) in fds_to_keep
+ # mp._prctl_pr_set_pdeathsig(signal.SIGINT)
+
+ def read_tasks():
+ next_batch_id = self.cpu_idx
+ expect_new_tasks = True
+
+ while expect_new_tasks or len(active_batches) > 0:
+ stage, task = self.exchanger.get_cpu_task(
+ idx=self.cpu_idx,
+ )
+ # stage, task = next(iterator)
+ # Prioritized STOP signal: something bad happened in another process
+ # -> stop listening to input queues and raise StopIteration (return)
+ if task is None and stage == self.exchanger.num_stages + 1:
+ return
+ # Non prioritized STOP signal: there are no more tasks to process
+ # and we should smoothly stop (wait that there are no more active
+ # tasks, and finalize the writer)
+ if stage == 0 and task is None:
+ expect_new_tasks = False
+ continue
+
+ # If first stage, we receive tasks that may require batching
+ # again => we split them into chunks
+ if stage == 0:
+ task_id, fragments = task
+ chunks = list(
+ batchify(lc.reader.read_worker(fragments), lc.chunk_size)
+ )
+ for chunk_idx, docs in enumerate(chunks):
+ # If we sort by size, we must first create the documents
+ # to have features against which we will sort
+ docs = apply_basic_pipes(docs, preprocess_pipes)
+
+ if lc.sort_chunks:
+ docs.sort(
+ key=doc_size_fns.get(
+ lc.sort_chunks, doc_size_fns["content_boxes"]
+ )
+ )
+
+ batches = [
+ batch
+ for batch in batchify(
+ docs,
+ batch_size=lc.batch_size,
+ formula=batch_size_fns[lc.batch_by],
+ )
+ ]
+
+ for batch_idx, batch in enumerate(batches):
+ assert len(batch) > 0
+ batch_id = next_batch_id
+
+ # We mark the task id only for the last batch of a task
+ # since the purpose of storing the task id is to know
+ # when the worker has finished processing the task,
+ # which is true only when the last batch has been
+ # processed
+ active_batches[batch_id] = (
+ batch,
+ task_id
+ if (batch_idx == len(batches) - 1)
+ and (chunk_idx == len(chunks) - 1)
+ else None,
+ )
+ next_batch_id += num_cpu
+ # gpu_idx = None
+ # batch_id = we have just created a new batch
+ # result from the last stage = None
+ yield stage, (None, batch_id, None)
+ else:
+ yield stage, task
+
+ try:
+ lc: LazyCollection = load(
+ self.lazy_collection_path, map_location=self.device
+ )
+ preprocess_pipes = []
+ num_cpu = self.exchanger.num_cpu_workers
+ split_into_batches_after = lc.split_into_batches_after
+ if (
+ split_into_batches_after is None
+ or lc.batch_by != "docs"
+ or lc.sort_chunks
+ ):
+ split_into_batches_after = next(
+ (p[0] for p in lc.pipeline if p[0] is not None), None
+ )
+ is_before_split = split_into_batches_after is not None
+
+ stages: List[Stage] = [{"cpu_components": [], "gpu_component": None}]
+ for name, pipe, *rest in lc.pipeline:
+ if name in self.gpu_pipe_names:
+ is_before_split = False
+ stages[-1]["gpu_component"] = pipe
+ stages.append({"cpu_components": [], "gpu_component": None})
+ else:
+ if is_before_split:
+ preprocess_pipes.append((name, pipe, *rest))
+ else:
+ stages[-1]["cpu_components"].append((name, pipe, *rest))
+ if name is split_into_batches_after:
+ is_before_split = False
+
+            # Start at cpu_idx to avoid having all workers send their
+            # first batch (0 % num_devices, cf. below) to the same gpu
+ active_batches = {}
+
+ logging.info(f"Starting {self} on {os.getpid()}")
+
+ # Inform the main process that we are ready
+ self.exchanger.put_results((None, 0, None, None))
+
+ for stage, (gpu_idx, batch_id, result) in read_tasks():
+ docs, task_id = active_batches.pop(batch_id)
+ for name, pipe, *rest in lc.pipeline:
+ if hasattr(pipe, "enable_cache"):
+ pipe.enable_cache(batch_id)
+ if stage > 0:
+ gpu_pipe = stages[stage - 1]["gpu_component"]
+ docs = gpu_pipe.postprocess(docs, result) # type: ignore
+
+ docs = apply_basic_pipes(docs, stages[stage]["cpu_components"])
+
+ gpu_pipe: "TrainablePipe" = stages[stage]["gpu_component"]
+ if gpu_pipe is not None:
+ preprocessed = gpu_pipe.make_batch(docs) # type: ignore
+ active_batches[batch_id] = (docs, task_id)
+ if gpu_idx is None:
+ gpu_idx = batch_id % len(self.exchanger.gpu_worker_devices)
+ collated = gpu_pipe.collate(preprocessed)
+ collated = gpu_pipe.batch_to_device(
+ collated,
+ device=self.exchanger.gpu_worker_devices[gpu_idx],
+ )
+ self.exchanger.put_gpu(
+ item=(self.cpu_idx, batch_id, collated),
+ idx=gpu_idx,
+ stage=stage,
+ )
+ else:
+ for name, pipe, *rest in lc.pipeline:
+ if hasattr(pipe, "disable_cache"):
+ pipe.disable_cache(batch_id)
+ results, count = (
+ lc.writer.write_worker(docs)
+ if lc.writer is not None
+ else (docs, len(docs))
+ )
+ self.exchanger.put_results(
+ (
+ results,
+ count,
+ self.cpu_idx,
+ task_id,
+ )
+ )
+
+ results, count = lc.writer.finalize() if lc.writer is not None else ([], 0)
+ self.exchanger.put_results((results, count, self.cpu_idx, "finalize"))
+
+ except BaseException as e: # pragma: no cover
+ import traceback
+
+ print(f"Error in {self}:\n{traceback.format_exc()}", flush=True)
+ self.exchanger.put_results((e, 0, self.cpu_idx, None))
+            # We need to drain the queues fed by the main process and the GPU
+            # workers, to ensure no tensors allocated in producer processes are
+            # left pending in this consumer process
+ task = True # anything but None
+ stage = None
+ while (stage, task) != (0, None):
+ try:
+ stage, task = self.exchanger.get_cpu_task(self.cpu_idx)
+ finally:
+ pass
+
+ def __repr__(self):
+ return f""
+
+
+class GPUWorker:
+ def __init__(
+ self,
+ gpu_idx,
+ exchanger: Exchanger,
+ gpu_pipe_names: List[str],
+ lazy_collection_path: str,
+ device: Union[str, "torch.device"],
+ ):
+ super().__init__()
+
+ self.device = device
+ self.gpu_idx = gpu_idx
+ self.exchanger = exchanger
+
+ self.gpu_pipe_names = gpu_pipe_names
+ self.lazy_collection_path = lazy_collection_path
+
+ def run(self):
+ import torch
+
+ # mp._prctl_pr_set_pdeathsig(signal.SIGINT)
+ try:
+ lc = load(self.lazy_collection_path, map_location=self.device)
+ stage_components = [
+ pipe
+ # move_to_device(pipe, self.device)
+ for name, pipe, *_ in lc.pipeline
+ if name in self.gpu_pipe_names
+ ]
+
+ del lc
+ logging.info(f"Starting {self} on {os.getpid()}")
+
+ # Inform the main process that we are ready
+ self.exchanger.put_results((None, 0, None, None))
+
+ with torch.no_grad():
+ while True:
+ stage, task = self.exchanger.get_gpu_task(self.gpu_idx)
+ if task is None:
+ break
+
+ cpu_idx, batch_id, batch = task
+ pipe = stage_components[stage]
+ pipe.enable_cache(batch_id)
+ res = pipe.module_forward(batch)
+ self.exchanger.put_cpu(
+ item=(
+ self.gpu_idx,
+ batch_id,
+ {
+ k: v.to("cpu") if hasattr(v, "to") else v
+ for k, v in res.items()
+ },
+ ),
+ stage=stage + 1,
+ idx=cpu_idx,
+ )
+ if stage == len(stage_components) - 1:
+ pipe.disable_cache(batch_id)
+ del batch, task
+
+ task = batch = res = None # noqa
+ except BaseException as e: # pragma: no cover
+ import traceback
+
+ print(f"Error in {self}:\n{traceback.format_exc()}", flush=True)
+ self.exchanger.put_results((e, 0, None, None))
+
+ from edspdf.trainable_pipe import _caches
+
+ task = batch = res = None # noqa
+ _caches.clear()
+ gc.collect()
+ sys.modules["torch"].cuda.empty_cache()
+
+ # We need to drain the queues of CPUWorker fed inputs (pre-moved to GPU)
+        # to ensure no tensors allocated on producer processes (CPUWorker via
+        # collate) are left in consumer processes
+ stage = None
+ task = None
+ while (stage, task) != (0, None):
+ try:
+ stage, task = self.exchanger.get_gpu_task(self.gpu_idx)
+ finally:
+ pass
+
+ def __repr__(self):
+ return f""
+
+
+DEFAULT_MAX_CPU_WORKERS = 4
+
+
+def execute_multiprocessing_backend(
+ lc: LazyCollection,
+):
+ """
+    If you have multiple CPU cores, and optionally multiple GPUs, we provide the
+    `multiprocessing` backend that allows running inference on multiple
+    processes.
+
+    This accelerator dispatches the batches between multiple workers
+    (data-parallelism), and distributes the computation of a given batch on one or two
+    workers (model-parallelism):
+
+    - a `CPUWorker` which handles the non-deep-learning components and the
+      preprocessing, collating and postprocessing of deep-learning components
+ - a `GPUWorker` which handles the forward call of the deep-learning components
+
+ If no GPU is available, no `GPUWorker` is started, and the `CPUWorkers` handle
+ the forward call of the deep-learning components as well.
+
+    The advantage of dedicating a worker to the deep-learning components is that it
+    allows preparing multiple batches in parallel in multiple `CPUWorker`s, and ensures
+    that the `GPUWorker` never waits for a batch to be ready.
+
+    The overall architecture is described in the following figure, for 3 CPU workers
+    and 2 GPU workers.
+
+
+
+
+
+ Here is how a small pipeline with rule-based components and deep-learning components
+ is distributed between the workers:
+
+
+
+
+
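+    Example (illustrative sketch; assumes a trained `model` pipeline and a folder of
+    PDF files, and exact keyword arguments may differ):
+
+    ```python
+    import edspdf
+
+    docs = edspdf.data.read_files("path/to/pdfs")
+    docs = docs.map_pipeline(model)
+    docs = docs.set_processing(
+        num_cpu_workers=4,
+        num_gpu_workers=1,
+        batch_by="content_boxes",
+        chunk_size=10,
+    )
+    # Consuming the collection (iterating it or writing it out) runs this backend
+    ```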
+ """
+ try:
+ TrainablePipe = sys.modules["edspdf.trainable_pipe"].TrainablePipe
+ except (KeyError, AttributeError): # pragma: no cover
+ TrainablePipe = None
+
+ steps = lc.pipeline
+ num_cpu_workers = lc.num_cpu_workers
+ num_gpu_workers = lc.num_gpu_workers
+ show_progress = lc.show_progress
+ process_start_method = lc.process_start_method
+
+ # Infer which pipes should be accelerated on GPU
+ gpu_steps_candidates = (
+ [name for name, component, *_ in steps if isinstance(component, TrainablePipe)]
+ if TrainablePipe is not None
+ else []
+ )
+ gpu_pipe_names = (
+ gpu_steps_candidates if lc.gpu_pipe_names is None else lc.gpu_pipe_names
+ )
+ if set(gpu_pipe_names) - set(gpu_steps_candidates):
+ raise ValueError(
+ "GPU accelerated pipes {} could not be found in the model".format(
+ sorted(set(gpu_pipe_names) - set(gpu_steps_candidates))
+ )
+ )
+
+ old_environ = {
+ k: os.environ.get(k) for k in ("TOKENIZERS_PARALLELISM", "OMP_NUM_THREADS")
+ }
+ if lc.disable_implicit_parallelism:
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ os.environ["OMP_NUM_THREADS"] = "1"
+
+ requires_gpu = (
+ num_gpu_workers is None
+ and len(gpu_pipe_names)
+ or num_gpu_workers is not None
+ and num_gpu_workers > 0
+ )
+
+ num_devices = 0
+ if requires_gpu:
+ import torch
+
+ num_devices = torch.cuda.device_count()
+ logging.info(f"Number of available devices: {num_devices}")
+
+ if num_gpu_workers is None:
+ num_gpu_workers = num_devices
+ else:
+ num_gpu_workers = 0
+
+ if any(gpu_steps_candidates):
+ if process_start_method == "fork":
+ warnings.warn(
+ "Using fork start method with GPU workers may lead to deadlocks. "
+ "Consider using process_start_method='spawn' instead."
+ )
+
+ process_start_method = process_start_method or "spawn"
+
+ default_method = multiprocessing.get_start_method()
+ if process_start_method is not None and default_method != process_start_method:
+ logging.info(f"Switching process start method to {process_start_method}")
+
+ mp = multiprocessing.get_context(process_start_method)
+ max_workers = max(min(mp.cpu_count() - num_gpu_workers, DEFAULT_MAX_CPU_WORKERS), 0)
+ num_cpu_workers = (
+ (num_gpu_workers or max_workers)
+ if num_cpu_workers is None
+ else max_workers + num_cpu_workers + 1
+ if num_cpu_workers < 0
+ else num_cpu_workers
+ )
+
+ if num_gpu_workers == 0:
+ gpu_pipe_names = []
+
+ gpu_worker_devices = (
+ [
+ f"cuda:{gpu_idx * num_devices // num_gpu_workers}"
+ for gpu_idx in range(num_gpu_workers)
+ ]
+ if requires_gpu and lc.gpu_worker_devices is None
+ else []
+ if lc.gpu_worker_devices is None
+ else lc.gpu_worker_devices
+ )
+ cpu_worker_devices = (
+ ["cpu"] * num_cpu_workers
+ if lc.cpu_worker_devices is None
+ else lc.cpu_worker_devices
+ )
+ assert len(cpu_worker_devices) == num_cpu_workers
+ assert len(gpu_worker_devices) == num_gpu_workers
+ if num_cpu_workers == 0: # pragma: no cover
+ (
+ num_cpu_workers,
+ num_gpu_workers,
+ cpu_worker_devices,
+ gpu_worker_devices,
+ gpu_pipe_names,
+ ) = (num_gpu_workers, 0, gpu_worker_devices, [], [])
+
+ exchanger = Exchanger(
+ mp,
+ num_stages=len(gpu_pipe_names),
+ num_cpu_workers=num_cpu_workers,
+ num_gpu_workers=num_gpu_workers,
+ gpu_worker_devices=gpu_worker_devices,
+ )
+
+ lc = lc.to("cpu")
+
+ cpu_workers = []
+ gpu_workers = []
+
+ with tempfile.NamedTemporaryFile(delete=False) as fp:
+ dump(lc.worker_copy(), fp)
+ fp.close()
+
+ revert_pickler = replace_pickler()
+
+ for gpu_idx in range(num_gpu_workers):
+ gpu_workers.append(
+ mp.Process(
+ target=GPUWorker.run,
+ args=(
+ GPUWorker(
+ gpu_idx=gpu_idx,
+ exchanger=exchanger,
+ gpu_pipe_names=gpu_pipe_names,
+ lazy_collection_path=fp.name,
+ device=gpu_worker_devices[gpu_idx],
+ ),
+ ),
+ )
+ )
+
+ for cpu_idx in range(num_cpu_workers):
+ cpu_workers.append(
+ mp.Process(
+ target=CPUWorker.run,
+ args=(
+ CPUWorker(
+ cpu_idx=cpu_idx,
+ exchanger=exchanger,
+ gpu_pipe_names=gpu_pipe_names,
+ lazy_collection_path=fp.name,
+ device=cpu_worker_devices[cpu_idx],
+ ),
+ ),
+ )
+ )
+
+ logging.info(f"Main PID {os.getpid()}")
+
+ logging.info(
+ f"Starting {num_cpu_workers} cpu workers and {num_gpu_workers} gpu workers on "
+ f"{gpu_worker_devices}, with accelerated pipes: {gpu_pipe_names}",
+ )
+
+ for worker in (*cpu_workers, *gpu_workers):
+ worker.start()
+
+ logging.info("Workers are ready")
+
+ for i in range(len((*cpu_workers, *gpu_workers))):
+ outputs, count, cpu_idx, output_task_id = exchanger.outputs_queue.get()
+ if isinstance(outputs, BaseException): # pragma: no cover
+ raise outputs
+
+ os.unlink(fp.name)
+
+ num_max_enqueued = 1
+    # Maximum number of chunks that can be enqueued per CPU worker at once
+ outputs_iterator = exchanger.iter_results()
+
+ cpu_worker_indices = list(range(num_cpu_workers))
+ inputs_iterator = lc.reader.read_main()
+ active_chunks = [{} for i in cpu_worker_indices]
+ non_finalized = {i for i in cpu_worker_indices}
+ max_workload = lc.chunk_size * num_max_enqueued
+
+ bar = nullcontext()
+ if show_progress:
+ from tqdm import tqdm
+
+ bar = tqdm(smoothing=0.1, mininterval=5.0)
+
+ def get_and_process_output():
+ outputs, count, cpu_idx, output_task_id = next(outputs_iterator)
+ if output_task_id == "finalize":
+ non_finalized.discard(cpu_idx)
+ if isinstance(outputs, BaseException): # pragma: no cover
+ raise outputs
+ if show_progress:
+ bar.update(count)
+ if count > 0:
+ yield outputs
+ if output_task_id is not None:
+ active_chunks[cpu_idx].pop(output_task_id, None)
+
+ def process():
+ try:
+ with bar:
+ for input_task_id, items in enumerate(
+ batchify(
+ iterable=inputs_iterator,
+ batch_size=lc.chunk_size,
+ drop_last=False,
+ formula=lambda x: sum(item[1] for item in x),
+ )
+ ):
+ batch = [item[0] for item in items]
+ batch_size = sum(item[1] for item in items)
+
+ while all(sum(wl.values()) >= max_workload for wl in active_chunks):
+ yield from get_and_process_output()
+
+ # Shuffle to ensure the first process does not receive all the
+ # documents in case of workload equality
+ shuffle(cpu_worker_indices)
+ cpu_idx = min(
+ cpu_worker_indices,
+ key=lambda i: sum(active_chunks[i].values()),
+ )
+ exchanger.put_cpu((input_task_id, batch), stage=0, idx=cpu_idx)
+ active_chunks[cpu_idx][input_task_id] = batch_size
+
+ # Inform the CPU workers that there are no more tasks to process
+ for i, worker in enumerate(cpu_workers):
+ exchanger.cpu_inputs_queues[i][0].put(None)
+
+ while any(active_chunks):
+ yield from get_and_process_output()
+
+ while len(non_finalized):
+ yield from get_and_process_output()
+
+ finally:
+ revert_pickler()
+
+ for k, v in old_environ.items():
+ os.environ.pop(k, None)
+ if v is not None:
+ os.environ[k] = v
+
+ # Send gpu and cpu process the order to stop processing data
+ # We use the prioritized queue to ensure the stop signal is processed
+ # before the next batch of data
+ for i, worker in enumerate(gpu_workers):
+ exchanger.gpu_inputs_queues[i][-1].put(None)
+ for i, worker in enumerate(cpu_workers):
+ exchanger.cpu_inputs_queues[i][-1].put(None)
+
+ # Enqueue a final non prioritized STOP signal to ensure there remains no
+ # data in the queues (cf drain loop in CPUWorker / GPUWorker)
+ for i, worker in enumerate(gpu_workers):
+ exchanger.gpu_inputs_queues[i][0].put(None)
+ for i, worker in enumerate(gpu_workers):
+ worker.join(timeout=5)
+ for i, worker in enumerate(cpu_workers):
+ exchanger.cpu_inputs_queues[i][0].put(None)
+ for i, worker in enumerate(cpu_workers):
+ worker.join(timeout=1)
+
+ # If a worker is still alive, kill it
+ # This should not happen, but for a reason I cannot explain, it does in
+ # some CPU workers sometimes when we catch an error, even though each run
+ # method of the workers completes cleanly. Maybe this has something to do
+ # with the cleanup of these processes ?
+ for i, worker in enumerate(gpu_workers): # pragma: no cover
+ if worker.is_alive():
+ logging.error(f"Killing ")
+ worker.kill()
+ for i, worker in enumerate(cpu_workers): # pragma: no cover
+ if worker.is_alive():
+ logging.error(f"Killing ")
+ worker.kill()
+
+ for queue_group in (
+ *exchanger.cpu_inputs_queues,
+ *exchanger.gpu_inputs_queues,
+ [exchanger.outputs_queue],
+ ):
+ for queue in queue_group:
+ if hasattr(queue, "cancel_join_thread"):
+ queue.cancel_join_thread()
+
+ gen = process()
+ return lc.writer.write_main(gen) if lc.writer is not None else flatten(gen)
diff --git a/edspdf/processing/simple.py b/edspdf/processing/simple.py
new file mode 100644
index 00000000..71372413
--- /dev/null
+++ b/edspdf/processing/simple.py
@@ -0,0 +1,111 @@
+from __future__ import annotations
+
+import sys
+from contextlib import nullcontext
+from typing import TYPE_CHECKING
+
+from edspdf.utils.collections import batchify, flatten
+
+if TYPE_CHECKING:
+ from edspdf.lazy_collection import LazyCollection
+
+batch_size_fns = {
+ "content_boxes": lambda batch: sum(len(doc.content_boxes) for doc in batch),
+ "pages": lambda batch: sum(len(doc.pages) for doc in batch),
+ "docs": len,
+}
+
+doc_size_fns = {
+ "content_boxes": lambda doc: len(doc.content_boxes),
+}
+
+
+def apply_basic_pipes(docs, pipes):
+ for name, pipe, kwargs in pipes:
+ if hasattr(pipe, "batch_process"):
+ docs = pipe.batch_process(docs)
+ else:
+ docs = [pipe(doc, **kwargs) for doc in docs]
+ return docs
+
+
+def execute_simple_backend(
+ lc: LazyCollection,
+):
+ """
+ This is the default execution mode which batches the documents and processes each
+ batch on the current process in a sequential manner.
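+
+    A minimal sketch (assuming a `model` pipeline); without any `set_processing(...)`
+    call, this sequential backend is the one that runs the pipeline:
+
+    ```python
+    import edspdf
+
+    docs = edspdf.data.read_files("path/to/pdfs")
+    docs = docs.map_pipeline(model)
+    results = list(docs)  # processed lazily, batch by batch, in this process
+    ```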
+ """
+ try:
+ no_grad = sys.modules["torch"].no_grad
+ except (KeyError, AttributeError):
+ no_grad = nullcontext
+ reader = lc.reader
+ writer = lc.writer
+ show_progress = lc.show_progress
+
+ split_into_batches_after = lc.split_into_batches_after
+ if split_into_batches_after is None or lc.batch_by != "docs" or lc.sort_chunks:
+ split_into_batches_after = next(
+ (p[0] for p in lc.pipeline if p[0] is not None), None
+ )
+ names = [step[0] for step in lc.pipeline] + [None]
+ chunk_components = lc.pipeline[: names.index(split_into_batches_after)]
+ batch_components = lc.pipeline[names.index(split_into_batches_after) :]
+
+ def process():
+ bar = nullcontext()
+ if show_progress:
+ from tqdm import tqdm
+
+ bar = tqdm(smoothing=0.1, mininterval=5.0)
+
+ with bar:
+ for docs in batchify(
+ (
+ subtask
+ for task, count in reader.read_main()
+ for subtask in reader.read_worker([task])
+ ),
+ batch_size=lc.chunk_size,
+ ):
+ docs = apply_basic_pipes(docs, chunk_components)
+
+ if lc.sort_chunks:
+ docs.sort(
+ key=doc_size_fns.get(
+ lc.sort_chunks, doc_size_fns["content_boxes"]
+ )
+ )
+
+ batches = [
+ batch
+ for batch in batchify(
+ docs,
+ batch_size=lc.batch_size,
+ formula=batch_size_fns.get(lc.batch_by, len),
+ )
+ ]
+
+ for batch in batches:
+ with no_grad(), lc.cache():
+ batch = apply_basic_pipes(batch, batch_components)
+
+ if writer is not None:
+ result, count = writer.write_worker(batch)
+ if show_progress:
+ bar.update(count)
+ yield result
+ else:
+ if show_progress:
+ bar.update(len(batch))
+ yield batch
+ if writer is not None:
+ result, count = writer.finalize()
+ if show_progress:
+ bar.update(count)
+ if count:
+ yield result
+
+ gen = process()
+ return flatten(gen) if writer is None else writer.write_main(gen)
diff --git a/edspdf/registry.py b/edspdf/registry.py
index 641b0eb2..5219b4cd 100644
--- a/edspdf/registry.py
+++ b/edspdf/registry.py
@@ -220,3 +220,5 @@ class registry(RegistryCollection):
misc = Registry(("edspdf", "misc"), entry_points=True)
adapter = Registry(("edspdf", "adapter"), entry_points=True)
accelerator = Registry(("edspdf", "accelerator"), entry_points=True)
+ readers = Registry(("edspdf", "readers"), entry_points=True)
+ writers = Registry(("edspdf", "writers"), entry_points=True)
diff --git a/edspdf/trainable_pipe.py b/edspdf/trainable_pipe.py
index e87ae84a..e1979baa 100644
--- a/edspdf/trainable_pipe.py
+++ b/edspdf/trainable_pipe.py
@@ -1,28 +1,34 @@
+import os
from abc import ABCMeta
from enum import Enum
from functools import wraps
-from pathlib import Path
from typing import (
Any,
+ Callable,
Dict,
Generic,
Iterable,
Optional,
Sequence,
+ Set,
+ Tuple,
TypeVar,
Union,
)
+import safetensors.torch
import torch
from edspdf.pipeline import Pipeline
from edspdf.structures import PDFDoc
from edspdf.utils.collections import batch_compress_dict, decompress_dict
-NestedSequences = Dict[str, Union["NestedSequences", Sequence]]
-NestedTensors = Dict[str, Union["NestedSequences", torch.Tensor]]
-InputBatch = TypeVar("InputBatch", bound=NestedTensors)
-OutputBatch = TypeVar("OutputBatch", bound=NestedTensors)
+BatchInput = TypeVar("BatchInput", bound=Dict[str, Any])
+BatchOutput = TypeVar("BatchOutput", bound=Dict[str, Any])
+Scorer = Callable[[Sequence[Tuple[PDFDoc, PDFDoc]]], Union[float, Dict[str, Any]]]
+
+ALL_CACHES = object()
+_caches = {}
class CacheEnum(str, Enum):
@@ -36,19 +42,24 @@ def hash_batch(batch):
return hash(tuple(id(item) for item in batch))
elif not isinstance(batch, dict):
return id(batch)
- return hash((tuple(batch.keys()), tuple(map(hash_batch, batch.values()))))
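+    # Cache the hash on the batch dict itself, so that repeated cache lookups
+    # on the same (possibly large, nested) batch do not recompute it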
+ if "__batch_hash__" in batch:
+ return batch["__batch_hash__"]
+ batch_hash = hash((tuple(batch.keys()), tuple(map(hash_batch, batch.values()))))
+ batch["__batch_hash__"] = batch_hash
+ return batch_hash
def cached_preprocess(fn):
@wraps(fn)
def wrapped(self: "TrainablePipe", doc: PDFDoc):
- if self.pipeline is None or self.pipeline._cache is None:
+ if self._current_cache_id is None:
return fn(self, doc)
- cache_id = (id(self), "preprocess", id(doc))
- if cache_id in self.pipeline._cache:
- return self.pipeline._cache[cache_id]
+ cache_key = ("preprocess", f"{type(self)}<{id(self)}>: {id(doc)}")
+ cache = _caches[self._current_cache_id]
+ if cache_key in cache:
+ return cache[cache_key]
res = fn(self, doc)
- self.pipeline._cache[cache_id] = res
+ cache[cache_key] = res
return res
return wrapped
@@ -57,31 +68,30 @@ def wrapped(self: "TrainablePipe", doc: PDFDoc):
def cached_preprocess_supervised(fn):
@wraps(fn)
def wrapped(self: "TrainablePipe", doc: PDFDoc):
- if self.pipeline is None or self.pipeline._cache is None:
+ if self._current_cache_id is None:
return fn(self, doc)
- cache_id = (id(self), "preprocess_supervised", id(doc))
- if cache_id in self.pipeline._cache:
- return self.pipeline._cache[cache_id]
+ cache_key = ("preprocess_supervised", f"{type(self)}<{id(self)}>: {id(doc)}")
+ cache = _caches[self._current_cache_id]
+ if cache_key in cache:
+ return cache[cache_key]
res = fn(self, doc)
- self.pipeline._cache[cache_id] = res
+ cache[cache_key] = res
return res
return wrapped
def cached_collate(fn):
- import torch
-
@wraps(fn)
- def wrapped(self: "TrainablePipe", batch: Dict, device: torch.device):
- if self.pipeline is None or self.pipeline._cache is None:
- return fn(self, batch, device)
- cache_id = (id(self), "collate", hash_batch(batch))
- if cache_id in self.pipeline._cache:
- return self.pipeline._cache[cache_id]
- res = fn(self, batch, device)
- self.pipeline._cache[cache_id] = res
- res["cache_id"] = cache_id
+ def wrapped(self: "TrainablePipe", batch: Dict):
+ if self._current_cache_id is None:
+ return fn(self, batch)
+ cache_key = ("collate", f"{type(self)}<{id(self)}>: {hash_batch(batch)}")
+ cache = _caches[self._current_cache_id]
+ if cache_key in cache:
+ return cache[cache_key]
+ res = fn(self, batch)
+ cache[cache_key] = res
return res
return wrapped
@@ -90,13 +100,35 @@ def wrapped(self: "TrainablePipe", batch: Dict, device: torch.device):
def cached_forward(fn):
@wraps(fn)
def wrapped(self: "TrainablePipe", batch):
- if self.pipeline is None or self.pipeline._cache is None:
+ # Convert args and kwargs to a dictionary matching fn signature
+ if self._current_cache_id is None:
return fn(self, batch)
- cache_id = (id(self), "collate", hash_batch(batch))
- if cache_id in self.pipeline._cache:
- return self.pipeline._cache[cache_id]
+ cache_key = ("forward", f"{type(self)}<{id(self)}>: {hash_batch(batch)}")
+ cache = _caches[self._current_cache_id]
+ if cache_key in cache:
+ return cache[cache_key]
res = fn(self, batch)
- self.pipeline._cache[cache_id] = res
+ cache[cache_key] = res
+ return res
+
+ return wrapped
+
+
+def cached_batch_to_device(fn):
+ @wraps(fn)
+ def wrapped(self: "TrainablePipe", batch, device):
+ # Convert args and kwargs to a dictionary matching fn signature
+ if self._current_cache_id is None:
+ return fn(self, batch, device)
+ cache_key = (
+ "batch_to_device",
+ f"{type(self)}<{id(self)}>: {hash_batch(batch)}",
+ )
+ cache = _caches[self._current_cache_id]
+ if cache_key in cache:
+ return cache[cache_key]
+ res = fn(self, batch, device)
+ cache[cache_key] = res
return res
return wrapped
@@ -112,6 +144,10 @@ def __new__(mcs, name, bases, class_dict):
)
if "collate" in class_dict:
class_dict["collate"] = cached_collate(class_dict["collate"])
+ if "batch_to_device" in class_dict:
+ class_dict["batch_to_device"] = cached_batch_to_device(
+ class_dict["batch_to_device"]
+ )
if "forward" in class_dict:
class_dict["forward"] = cached_forward(class_dict["forward"])
@@ -120,7 +156,7 @@ def __new__(mcs, name, bases, class_dict):
class TrainablePipe(
torch.nn.Module,
- Generic[OutputBatch],
+ Generic[BatchOutput],
metaclass=TrainablePipeMeta,
):
"""
@@ -133,15 +169,30 @@ class TrainablePipe(
for components that share a common subcomponent.
"""
+ call_super_init = True
+
def __init__(self, pipeline: Optional[Pipeline], name: Optional[str]):
super().__init__()
- self.pipeline = pipeline
self.name = name
- self.cfg = {}
- self._preprocess_cache = {}
- self._preprocess_supervised_cache = {}
- self._collate_cache = {}
- self._forward_cache = {}
+ self._current_cache_id = None
+
+ def enable_cache(self, cache_id="default"):
+ self._current_cache_id = cache_id
+ _caches.setdefault(cache_id, {})
+ for name, component in self.named_component_children():
+ if hasattr(component, "enable_cache"):
+ component.enable_cache(cache_id)
+
+ def disable_cache(self, cache_id=ALL_CACHES):
+ if cache_id is ALL_CACHES:
+ _caches.clear()
+ else:
+ if cache_id in _caches:
+ del _caches[cache_id]
+ self._current_cache_id = None
+ for name, component in self.named_component_children():
+ if hasattr(component, "disable_cache"):
+ component.disable_cache(cache_id)
@property
def device(self):
@@ -152,45 +203,7 @@ def named_component_children(self):
if isinstance(module, TrainablePipe):
yield name, module
- def save_extra_data(self, path: Path, exclude: set):
- """
- Dumps vocabularies indices to json files
-
- Parameters
- ----------
- path: Path
- Path to the directory where the files will be saved
- exclude: Set
- The set of component names to exclude from saving
- This is useful when components are repeated in the pipeline.
- """
- if self.name in exclude:
- return
- exclude.add(self.name)
- for name, component in self.named_component_children():
- if hasattr(component, "save_extra_data"):
- component.save_extra_data(path / name, exclude)
-
- def load_extra_data(self, path: Path, exclude: set):
- """
- Loads vocabularies indices from json files
-
- Parameters
- ----------
- path: Path
- Path to the directory where the files will be loaded
- exclude: Set
- The set of component names to exclude from loading
- This is useful when components are repeated in the pipeline.
- """
- if self.name in exclude:
- return
- exclude.add(self.name)
- for name, component in self.named_component_children():
- if hasattr(component, "load_extra_data"):
- component.load_extra_data(path / name, exclude)
-
- def post_init(self, gold_data: Iterable[PDFDoc], exclude: set):
+ def post_init(self, gold_data: Iterable[PDFDoc], exclude: Set[str]):
"""
This method completes the attributes of the component, by looking at some
documents. It is especially useful to build vocabularies or detect the labels
@@ -205,9 +218,10 @@ def post_init(self, gold_data: Iterable[PDFDoc], exclude: set):
This argument will be gradually updated with the names of initialized
components
"""
- if self.name in exclude:
+ repr_id = object.__repr__(self)
+ if repr_id in exclude:
return
- exclude.add(self.name)
+ exclude.add(repr_id)
for name, component in self.named_component_children():
if hasattr(component, "post_init"):
component.post_init(gold_data, exclude=exclude)
@@ -233,52 +247,79 @@ def preprocess(self, doc: PDFDoc) -> Dict[str, Any]:
for name, component in self.named_component_children()
}
- def collate(self, batch: NestedSequences, device: torch.device) -> InputBatch:
+ def collate(self, batch: Dict[str, Any]) -> BatchInput:
"""
Collate the batch of features into a single batch of tensors that can be
used by the forward method of the component.
Parameters
----------
- batch: NestedSequences
+ batch: Dict[str, Any]
Batch of features
- device: torch.device
- Device on which the tensors should be moved
Returns
-------
- InputBatch
+ BatchInput
Dictionary (optionally nested) containing the collated tensors
"""
return {
- name: component.collate(batch[name], device)
+ name: component.collate(batch[name])
for name, component in self.named_component_children()
}
- def forward(self, batch: InputBatch) -> OutputBatch:
+ def batch_to_device(
+ self,
+ batch: BatchInput,
+ device: Optional[Union[str, torch.device]],
+ ) -> BatchInput:
"""
- Perform the forward pass of the neural network, i.e, apply transformations
- over the collated features to compute new embeddings, probabilities, losses, etc
+ Move the batch of tensors to the specified device.
Parameters
----------
- batch: InputBatch
+ batch: BatchInput
+ Batch of tensors
+ device: Optional[Union[str, torch.device]]
+ Device to move the tensors to
+
+ Returns
+ -------
+ BatchInput
+ """
+ return {
+ name: (
+ value.to(device)
+ if hasattr(value, "to")
+ else getattr(self, name).batch_to_device(value, device=device)
+ if hasattr(self, name)
+ else value
+ )
+ for name, value in batch.items()
+ }
+
+ def forward(self, batch: BatchInput) -> BatchOutput:
+ """
+ Perform the forward pass of the neural network.
+
+ Parameters
+ ----------
+ batch: BatchInput
Batch of tensors (nested dictionary) computed by the collate method
Returns
-------
- OutputBatch
+ BatchOutput
"""
raise NotImplementedError()
- def module_forward(self, batch: InputBatch) -> OutputBatch:
+ def module_forward(self, *args, **kwargs):
"""
This is a wrapper around `torch.nn.Module.__call__` to avoid conflict
with the
[`TrainablePipe.__call__`][edspdf.trainable_pipe.TrainablePipe.__call__]
method.
"""
- return torch.nn.Module.__call__(self, batch)
+ return torch.nn.Module.__call__(self, *args, **kwargs)
def make_batch(
self,
@@ -323,9 +364,11 @@ def batch_process(self, docs: Sequence[PDFDoc]) -> Sequence[PDFDoc]:
Sequence[PDFDoc]
Batch of updated documents
"""
+ device = self.device
with torch.no_grad():
batch = self.make_batch(docs)
- inputs = self.collate(batch, device=self.device)
+ inputs = self.collate(batch)
+ inputs = self.batch_to_device(inputs, device=device)
if hasattr(self, "compiled"):
res = self.compiled(inputs)
else:
@@ -334,7 +377,7 @@ def batch_process(self, docs: Sequence[PDFDoc]) -> Sequence[PDFDoc]:
return docs
def postprocess(
- self, docs: Sequence[PDFDoc], batch: OutputBatch
+ self, docs: Sequence[PDFDoc], batch: BatchOutput
) -> Sequence[PDFDoc]:
"""
Update the documents with the predictions of the neural network, for instance
@@ -346,7 +389,7 @@ def postprocess(
----------
docs: Sequence[PDFDoc]
Batch of documents
- batch: OutputBatch
+ batch: BatchOutput
Batch of predictions, as returned by the forward method
Returns
@@ -391,3 +434,41 @@ def __call__(self, doc: PDFDoc) -> PDFDoc:
PDFDoc
"""
return self.batch_process([doc])[0]
+
+ def to_disk(self, path, exclude: Optional[Set[str]]):
+ if object.__repr__(self) in exclude:
+ return
+ exclude.add(object.__repr__(self))
+ overrides = {}
+ for name, component in self.named_component_children():
+ if hasattr(component, "to_disk"):
+ pipe_overrides = component.to_disk(path / name, exclude=exclude)
+ if pipe_overrides:
+ overrides[name] = pipe_overrides
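+        # Parameters already saved by a subcomponent are tracked in `exclude`,
+        # so shared tensors are only written once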
+ tensor_dict = {
+ n: p
+ for n, p in self.named_parameters()
+ if object.__repr__(p) not in exclude
+ }
+ os.makedirs(path, exist_ok=True)
+ safetensors.torch.save_file(tensor_dict, path / "parameters.safetensors")
+ exclude.update(object.__repr__(p) for p in tensor_dict.values())
+ return overrides
+
+ def from_disk(self, path, exclude: Optional[Set[str]]):
+ if object.__repr__(self) in exclude:
+ return
+ exclude.add(object.__repr__(self))
+ for name, component in self.named_component_children():
+ if hasattr(component, "from_disk"):
+ component.from_disk(path / name, exclude=exclude)
+ tensor_dict = safetensors.torch.load_file(path / "parameters.safetensors")
+ self.load_state_dict(tensor_dict, strict=False)
+
+ @property
+ def load_extra_data(self):
+ return self.from_disk
+
+ @property
+ def save_extra_data(self):
+ return self.to_disk
diff --git a/edspdf/utils/collections.py b/edspdf/utils/collections.py
index f0c37db9..7ffe7861 100644
--- a/edspdf/utils/collections.py
+++ b/edspdf/utils/collections.py
@@ -1,8 +1,18 @@
import copy
import itertools
-import math
from collections import defaultdict
-from typing import Any, Dict, Iterable, List, Mapping, Sequence, TypeVar
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ Mapping,
+ Sequence,
+ TypeVar,
+ Union,
+)
It = TypeVar("It", bound=Iterable)
T = TypeVar("T")
@@ -37,16 +47,46 @@ def nest_dict(flat: Dict[str, Any]) -> Dict[str, Any]:
def ld_to_dl(ld: Iterable[Mapping[str, T]]) -> Dict[str, List[T]]:
+ """
+ Convert a list of dictionaries to a dictionary of lists
+
+ Parameters
+ ----------
+ ld: Iterable[Mapping[str, T]]
+ The list of dictionaries
+
+ Returns
+ -------
+ Dict[str, List[T]]
+ The dictionary of lists
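+
+    Examples
+    --------
+    Keys are taken from the first dictionary; missing values are filled with None:
+
+    >>> ld_to_dl([{"a": 1, "b": 2}, {"a": 3}])
+    {'a': [1, 3], 'b': [2, None]}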
+ """
ld = list(ld)
- return {k: [dic[k] for dic in ld] for k in ld[0]}
+ return {k: [dic.get(k) for dic in ld] for k in (ld[0] if len(ld) else ())}
+
+
+def dl_to_ld(dl: Mapping[str, Sequence[Any]]) -> Iterator[Dict[str, Any]]:
+ """
+ Convert a dictionary of lists to a list of dictionaries
+ Parameters
+ ----------
+ dl: Mapping[str, Sequence[Any]]
+ The dictionary of lists
-def dl_to_ld(dl: Mapping[str, Sequence[T]]) -> List[Dict[str, T]]:
- return [dict(zip(dl, t)) for t in zip(*dl.values())]
+ Returns
+ -------
+    Iterator[Dict[str, Any]]
+        An iterator over the dictionaries
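+
+    Examples
+    --------
+    Note that a generator is returned:
+
+    >>> list(dl_to_ld({"a": [1, 2], "b": [3, 4]}))
+    [{'a': 1, 'b': 3}, {'a': 2, 'b': 4}]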
+ """
+ return (dict(zip(dl, t)) for t in zip(*dl.values()))
-def flatten(seq: Sequence[Sequence["T"]]) -> List["T"]:
- return list(itertools.chain.from_iterable(seq))
+def flatten(items):
+ for item in items:
+ if isinstance(item, list):
+ yield from flatten(item)
+ else:
+ yield item
FLATTEN_TEMPLATE = """\
@@ -64,16 +104,17 @@ def rec(current, path):
keys[id(current)].append(path)
return
for key, value in current.items():
- rec(value, (*path, key))
+ if not key.startswith("$"):
+ rec(value, (*path, key))
rec(obj, ())
code = FLATTEN_TEMPLATE.format(
"{"
+ "\n".join(
- "'{}': root{},".format(
- "|".join(map("/".join, key_list)),
- "".join(f"['{k}']" for k in key_list[0]),
+ "{}: root{},".format(
+ repr("|".join(map("/".join, key_list))),
+ "".join(f"[{repr(k)}]" for k in key_list[0]),
)
for key_list in keys.values()
)
@@ -83,6 +124,19 @@ def rec(current, path):
class batch_compress_dict:
+ """
+ Compress a sequence of dictionaries in which values that occur multiple times are
+ deduplicated. The corresponding keys will be merged into a single string using
+ the "|" character as a separator.
+ This is useful to preserve referential identities when decompressing the dictionary
+ after it has been serialized and deserialized.
+
+ Parameters
+ ----------
+ seq: Iterable[Dict[str, Any]]
+ Sequence of dictionaries to compress
+ """
+
__slots__ = ("flatten", "seq")
def __init__(self, seq: Iterable[Dict[str, Any]]):
@@ -109,7 +163,23 @@ def __next__(self) -> Dict[str, List]:
return self.flatten(item)
-def decompress_dict(seq):
+def decompress_dict(seq: Union[Iterable[Dict[str, Any]], Dict[str, Any]]):
+ """
+ Decompress a dictionary of lists into a sequence of dictionaries.
+ This function assumes that the dictionary structure was obtained using the
+ `batch_compress_dict` class.
+ Keys that were merged into a single string using the "|" character as a separator
+ will be split into a nested dictionary structure.
+
+ Parameters
+ ----------
+ seq: Union[Iterable[Dict[str, Any]], Dict[str, Any]]
+ The dictionary to decompress or a sequence of dictionaries to decompress
+
+ Returns
+ -------
+    Dict[str, Any]
+        The decompressed (possibly nested) dictionary
+ """
obj = ld_to_dl(seq) if isinstance(seq, Sequence) else seq
res = {}
for key, value in obj.items():
@@ -122,26 +192,33 @@ def decompress_dict(seq):
return res
-class batchify(Iterable[List[T]]):
- def __init__(self, iterable: Iterable[T], batch_size: int):
- self.iterable = iter(iterable)
- self.batch_size = batch_size
- try:
- self.length = math.ceil(len(iterable) / batch_size)
- except (AttributeError, TypeError):
- pass
-
- def __len__(self):
- return self.length
-
- def __iter__(self):
- return self
-
- def __next__(self):
- batch = list(itertools.islice(self.iterable, self.batch_size))
- if len(batch) == 0:
- raise StopIteration()
- return batch
+def batchify(
+ iterable: Iterable[T],
+ batch_size: int,
+ drop_last: bool = False,
+ formula: Callable = len,
+) -> Iterable[List[T]]:
+ """
+    Yields batches that contain at most `batch_size` elements, as measured by `formula`.
+    If a single item exceeds `batch_size` on its own, it is yielded as a batch
+    of one.
+
+ Parameters
+ ----------
+ iterable: Iterable[T]
+ batch_size: int
+ drop_last: bool
+ formula: Callable
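+
+    Examples
+    --------
+    With the default `formula=len`:
+
+    >>> list(batchify(range(5), batch_size=2))
+    [[0, 1], [2, 3], [4]]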
+ """
+ batch = []
+ for item in iterable:
+ next_size = formula(batch + [item])
+ if next_size > batch_size and len(batch) > 0:
+ yield batch
+ batch = []
+ batch.append(item)
+ if len(batch) > 0 and not drop_last:
+ yield batch
def get_attr_item(base, attr):
diff --git a/edspdf/utils/lazy_module.py b/edspdf/utils/lazy_module.py
new file mode 100644
index 00000000..ddc75d61
--- /dev/null
+++ b/edspdf/utils/lazy_module.py
@@ -0,0 +1,108 @@
+# flake8: noqa: F811
+import ast
+import importlib
+import inspect
+import os
+
+
+def lazify():
+ def _get_module_paths(file):
+ """
+ Reads the content of the current file, parses it with ast and store the
+ import path for future potential imports. This is useful to only import
+ the module that is requested and avoid loading all the modules at once, since
+ some of them are quite heavy, or contain dependencies that are not always
+ available.
+
+ For instance:
+        > from .trainable.span_qualifier.factory import create_component as span_qualifier
+        is stored in the cache as:
+        > module_paths["span_qualifier"] = ("trainable.span_qualifier.factory", "create_component")
+
+ Returns
+ -------
+ Dict[str, Tuple[str, str]]
+            The (module path, attribute name) pair for each lazily importable name.
+ """
+ module_path = os.path.abspath(file)
+ with open(module_path, "r") as f:
+ module_content = f.read()
+ module_ast = ast.parse(module_content)
+ module_paths = {}
+ for node in module_ast.body:
+ # Lookup TYPE_CHECKING
+ if not (
+ isinstance(node, ast.If)
+ and (
+ (
+ isinstance(node.test, ast.Name)
+ and node.test.id == "TYPE_CHECKING"
+ )
+ or (
+ isinstance(node.test, ast.Attribute)
+ and node.test.attr == "TYPE_CHECKING"
+ )
+ )
+ ):
+ continue
+ for import_node in node.body:
+ if isinstance(import_node, ast.ImportFrom):
+ for name in import_node.names:
+ module_paths[name.asname or name.name] = (
+ import_node.module,
+ name.name,
+ )
+
+ return module_paths
+
+ def __getattr__(name):
+ """
+ Imports the actual module if it is in the module_paths dict.
+
+ Parameters
+ ----------
+        name: str
+            Name of the attribute to import
+
+        Returns
+        -------
+        Any
+            The lazily imported attribute
+ """
+ if name in module_paths:
+ module_path, module_name = module_paths[name]
+ result = getattr(
+ importlib.__import__(
+ module_path,
+ fromlist=[module_name],
+ globals=module_globals,
+ level=1,
+ ),
+ module_name,
+ )
+ module_globals[name] = result
+ return result
+ raise AttributeError(f"module {__name__} has no attribute {name}")
+
+ def __dir__():
+ """
+ Returns the list of available modules.
+
+ Returns
+ -------
+ List[str]
+ """
+ return __all__
+
+ # Access upper frame
+ module_globals = inspect.currentframe().f_back.f_globals
+
+ module_paths = _get_module_paths(module_globals["__file__"])
+
+ __all__ = list(module_paths.keys())
+
+ module_globals.update(
+ {
+ "__getattr__": __getattr__,
+ "__dir__": __dir__,
+ "__all__": __all__,
+ }
+ )
diff --git a/edspdf/utils/package.py b/edspdf/utils/package.py
index 8440fa65..31a7c132 100644
--- a/edspdf/utils/package.py
+++ b/edspdf/utils/package.py
@@ -1,4 +1,3 @@
-import io
import os
import re
import shutil
@@ -6,7 +5,6 @@
import sys
from contextlib import contextmanager
from pathlib import Path
-from types import FunctionType
from typing import (
TYPE_CHECKING,
Any,
@@ -14,20 +12,13 @@
Mapping,
Optional,
Sequence,
- Tuple,
- Type,
Union,
)
import build
-import dill
import toml
from build.__main__ import build_package, build_package_via_sdist
from confit import Cli
-from dill._dill import save_function as dill_save_function
-from dill._dill import save_type as dill_save_type
-from importlib_metadata import PackageNotFoundError
-from importlib_metadata import version as get_version
from loguru import logger
from typing_extensions import Literal
@@ -35,57 +26,6 @@
py_version = f"{sys.version_info.major}.{sys.version_info.minor}"
-
-def get_package(obj_type: Type):
- # Retrieve the __package__ attribute of the module of a type, if possible.
- # And returns the package version as well
- try:
- module_name = obj_type.__module__
- if module_name == "__main__":
- raise Exception(f"Could not find package of type {obj_type}")
- module = __import__(module_name, fromlist=["__package__"])
- package = module.__package__
- try:
- version = get_version(package)
- except (PackageNotFoundError, ValueError):
- return None
- return package, version
- except (ImportError, AttributeError):
- raise Exception(f"Cound not find package of type {obj_type}")
-
-
-def save_type(pickler, obj, *args, **kwargs):
- package_name = get_package(obj)
- if package_name is not None:
- pickler.packages.add(package_name)
- dill_save_type(pickler, obj, *args, **kwargs)
-
-
-def save_function(pickler, obj, *args, **kwargs):
- package_name = get_package(obj)
- if package_name is not None:
- pickler.packages.add(package_name)
- return dill_save_function(pickler, obj, *args, **kwargs)
-
-
-class PackagingPickler(dill.Pickler):
- dispatch = dill.Pickler.dispatch.copy()
-
- dispatch[FunctionType] = save_function
- dispatch[type] = save_type
-
- def __init__(self, *args, **kwargs):
- self.file = io.BytesIO()
- super().__init__(self.file, *args, **kwargs)
- self.packages = set()
-
-
-def get_deep_dependencies(obj):
- pickler = PackagingPickler()
- pickler.dump(obj)
- return sorted(pickler.packages)
-
-
app = Cli(pretty_exceptions_show_locals=False, pretty_exceptions_enable=False)
@@ -132,6 +72,8 @@ def validate(cls, value, config=None):
# Initialize the builder
try:
builder = SdistBuilder(poetry, None, None)
+ # Get the list of files to include
+ files = builder.find_files_to_add()
except ModuleOrPackageNotFound:
if not poetry.package.packages:
print([])
@@ -139,15 +81,13 @@ def validate(cls, value, config=None):
print([
{k: v for k, v in {
- "include": include._include,
- "from": include.source,
- "formats": include.formats,
- }.items()}
+ "include": getattr(include, '_include'),
+ "from": getattr(include, 'source', None),
+ "formats": getattr(include, 'formats', None),
+ }.items() if v}
for include in builder._module.includes
])
-# Get the list of files to include
-files = builder.find_files_to_add()
# Print the list of files
for file in files:
@@ -161,12 +101,16 @@ def validate(cls, value, config=None):
import edspdf
from pathlib import Path
+from typing import Optional, Dict, Any
__version__ = {__version__}
-def load(device: "torch.device" = "cpu") -> edspdf.Pipeline:
+def load(
+ overrides: Optional[Dict[str, Any]] = None,
+ device: "torch.device" = "cpu"
+) -> edspdf.Pipeline:
artifacts_path = Path(__file__).parent / "{artifacts_dir}"
- model = edspdf.load(artifacts_path, device=device)
+ model = edspdf.load(artifacts_path, overrides=overrides, device=device)
return model
"""
@@ -194,13 +138,12 @@ def __init__(
self,
pyproject: Optional[Dict[str, Any]],
pipeline: Union[Path, "edspdf.Pipeline"],
- version: str,
+ version: Optional[str],
name: Optional[ModuleName],
root_dir: Path = ".",
- build_name: Path = "build",
- out_dir: Path = "dist",
+ build_dir: Path = "build",
+ dist_dir: Path = "dist",
artifacts_name: ModuleName = "artifacts",
- dependencies: Optional[Sequence[Tuple[str, str]]] = None,
metadata: Optional[Dict[str, Any]] = {},
):
self.poetry_bin_path = (
@@ -212,13 +155,13 @@ def __init__(
self.name = name
self.pyproject = pyproject
self.root_dir = root_dir.resolve()
- self.dependencies = dependencies
self.pipeline = pipeline
self.artifacts_name = artifacts_name
- self.out_dir = self.root_dir / out_dir
+ self.dist_dir = (
+ dist_dir if Path(dist_dir).is_absolute() else self.root_dir / dist_dir
+ )
with self.ensure_pyproject(metadata):
-
python_executable = (
Path(self.poetry_bin_path).read_text().split("\n")[0][2:]
)
@@ -236,7 +179,9 @@ def __init__(
out = result.stdout.decode().strip().split("\n")
self.poetry_packages = eval(out[0])
- self.build_dir = root_dir / build_name / self.name
+ self.build_dir = (
+ build_dir if Path(build_dir).is_absolute() else root_dir / build_dir
+ ) / self.name
self.file_paths = [self.root_dir / file_path for file_path in out[1:]]
logger.info(f"root_dir: {self.root_dir}")
@@ -262,13 +207,9 @@ def ensure_pyproject(self, metadata):
"poetry": {
**metadata,
"name": self.name,
- "version": self.version,
+ "version": self.version or "0.1.0",
"dependencies": {
"python": f">={py_version},<4.0",
- **{
- dep_name: f"^{dep_version}"
- for dep_name, dep_version in self.dependencies
- },
},
},
},
@@ -319,7 +260,7 @@ def build(
distributions = ["wheel"]
build_call(
srcdir=self.build_dir,
- outdir=self.out_dir,
+ outdir=self.dist_dir,
distributions=distributions,
config_settings=config_settings,
isolation=isolation,
@@ -335,12 +276,13 @@ def update_pyproject(self):
f"project"
)
- old_version = self.pyproject["tool"]["poetry"]["version"]
- self.pyproject["tool"]["poetry"]["version"] = self.version
- logger.info(
- f"Replaced project version {old_version!r} with {self.version!r} in poetry "
- f"based project"
- )
+ if self.version is not None:
+ old_version = self.pyproject["tool"]["poetry"]["version"]
+ self.pyproject["tool"]["poetry"]["version"] = self.version
+ logger.info(
+ f"Replaced project version {old_version!r} with {self.version!r} in "
+ f"poetry based project"
+ )
# Adding artifacts to include in pyproject.toml
snake_name = snake_case(self.name.lower())
@@ -381,7 +323,7 @@ def make_src_dir(self):
build_artifacts_dir,
)
else:
- self.pipeline.save(build_artifacts_dir)
+ self.pipeline.to_disk(build_artifacts_dir)
os.makedirs(package_dir, exist_ok=True)
with open(package_dir / "__init__.py", mode="a") as f:
f.write(
@@ -397,10 +339,11 @@ def package(
pipeline: Union[Path, "edspdf.Pipeline"],
name: Optional[ModuleName] = None,
root_dir: Path = ".",
+ build_dir: Path = "build",
+ dist_dir: Path = "dist",
artifacts_name: ModuleName = "artifacts",
- check_dependencies: bool = False,
project_type: Optional[Literal["poetry", "setuptools"]] = None,
- version: str = "0.1.0",
+ version: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = {},
distributions: Optional[Sequence[Literal["wheel", "sdist"]]] = ["wheel"],
config_settings: Optional[Mapping[str, Union[str, Sequence[str]]]] = None,
@@ -411,21 +354,12 @@ def package(
pyproject_path = root_dir / "pyproject.toml"
if not pyproject_path.exists():
- check_dependencies = True
if name is None:
raise ValueError(
f"No pyproject.toml could be found in the root directory {root_dir}, "
f"you need to create one, or fill the name parameter."
)
- dependencies = None
- if check_dependencies:
- if isinstance(pipeline, Path):
- pipeline = edspdf.load(pipeline)
- dependencies = get_deep_dependencies(pipeline)
- for dep in dependencies:
- print("DEPENDENCY", dep[0].ljust(30), dep[1])
-
root_dir = root_dir.resolve()
pyproject = None
@@ -442,8 +376,9 @@ def package(
name=name,
version=version,
root_dir=root_dir,
+ build_dir=build_dir,
+ dist_dir=dist_dir,
artifacts_name=artifacts_name,
- dependencies=dependencies,
metadata=metadata,
)
else:
diff --git a/edspdf/visualization/annotations.py b/edspdf/visualization/annotations.py
index cd5d0a8d..cf1770cc 100644
--- a/edspdf/visualization/annotations.py
+++ b/edspdf/visualization/annotations.py
@@ -59,7 +59,7 @@ def show_annotations(
"""
pdf_doc = pdfium.PdfDocument(pdf)
- pages = list(pdf_doc.render_topil(scale=2))
+ pages = list([page.render(scale=2).to_pil() for page in pdf_doc])
unique_labels = list(dict.fromkeys([box.label for box in annotations]))
if colors is None:
diff --git a/pyproject.toml b/pyproject.toml
index ed6024d6..037daa39 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,22 +18,28 @@ dependencies = [
"anyascii>=0.3.2",
"scikit-learn>=1.0.2,<2.0.0",
"pydantic>=1.2,<2.0.0",
- "catalogue~=2.0",
- "networkx~=2.6",
- "confit>=0.4.3,<1.0.0",
- "foldedtensor>=0.3.1,<1.0.0",
+ "catalogue>=2.0",
+ "networkx>=2.6",
+ "confit>=0.5.3,<1.0.0",
+ "fsspec<2023.1.0 ; python_version<'3.8'",
+ "fsspec ; python_version>='3.8'",
+ "foldedtensor>=0.3.3",
"torch>1.0.0",
"accelerate>=0.12.0,<1.0.0",
- "tqdm~=4.64.1",
+ "tqdm>=4.64",
"regex",
- "pdfminer.six>=20220319",
- "pypdfium2~=2.7",
- "rich-logger>=0.3.0,<1.0.0",
- "safetensors~=0.3.1",
+ "pdfminer.six>=20220319,<20231228 ; python_version<'3.8'",
+ "pdfminer.six ; python_version>='3.8'",
+ "pypdfium2>=4.0",
+ "rich-logger>=0.3",
+ "safetensors>=0.3",
"anyascii>=0.3.2",
- "attrs~=23.1",
+ "attrs>=23.1",
"build>=0.10.0",
+ "pyarrow",
"loguru",
+ "toml",
+ "dill",
]
@@ -50,9 +56,9 @@ dependencies = [
"mypy>=1.0.0",
"streamlit>=1.19",
"coverage>=6.5.0",
- "datasets~=2.10",
+ "datasets>=2.10",
"huggingface_hub>=0.8.1",
- "transformers~=4.30",
+ "transformers>=4.30",
]
[tool.hatch.envs.default.scripts]
@@ -139,19 +145,17 @@ omit-covered-files = false
concurrency = ["multiprocessing"]
[tool.coverage.report]
-omit = [
- "tests/*",
+include = [
+ "edspdf/*",
]
-# omit = [
-# "edspdf/accelerators/multiprocessing.py",
-# ]
-exclude_also = [
+exclude_lines = [
"def __repr__",
"if __name__ == .__main__.:",
"@overload",
"pragma: no cover",
"raise .*Error",
"raise .*Exception",
+ "warn\\(",
"if __name__ == .__main__.:",
"if (self[.])?name in exclude:",
"if TYPE_CHECKING:",
diff --git a/tests/conftest.py b/tests/conftest.py
index 289e7831..dc274392 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,4 @@
-import copy
import os
-from functools import lru_cache
from pathlib import Path
import pytest
@@ -53,18 +51,6 @@ def error_pdf():
return path.read_bytes()
-@lru_cache(maxsize=1)
-def make_pdfdoc(pdf):
- from edspdf.pipes.extractors.pdfminer import PdfMinerExtractor
-
- return PdfMinerExtractor(render_pages=True)(pdf)
-
-
-@fixture()
-def pdfdoc(pdf):
- return copy.deepcopy(make_pdfdoc(pdf))
-
-
@fixture(scope="session")
def dummy_dataset(tmpdir_factory, pdf):
tmp_path = tmpdir_factory.mktemp("datasets")
diff --git a/tests/core/test_data.py b/tests/core/test_data.py
new file mode 100644
index 00000000..a695f66d
--- /dev/null
+++ b/tests/core/test_data.py
@@ -0,0 +1,160 @@
+import json
+import os
+from pathlib import Path
+
+import pandas as pd
+import pytest
+
+import edspdf
+import edspdf.accelerators.multiprocessing
+from edspdf import PDFDoc
+from edspdf.data.converters import CONTENT, FILENAME
+from edspdf.utils.collections import flatten
+
+
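+# Flattens a PDFDoc into one record per content box (doc id, page number and bbox coordinates).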
+def box_converter(x):
+ return [
+ {
+ "id": x.id,
+ "page_num": b.page_num,
+ "x0": b.x0,
+ "x1": b.x1,
+ "y0": b.y0,
+ "y1": b.y1,
+ }
+ for b in x.content_boxes
+ ]
+
+
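+# Builds one record per document (filename, raw content and the list of box annotations),
+# used by the write_files branch of test_write_data below.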
+def full_file_converter(x):
+ return {
+ FILENAME: x.id,
+ CONTENT: x.content,
+ "annotations": [
+ {
+ "page_num": b.page_num,
+ "x0": b.x0,
+ "x1": b.x1,
+ "y0": b.y0,
+ "y1": b.y1,
+ }
+ for b in x.content_boxes
+ ],
+ }
+
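+# Every write backend (parquet, pandas, iterable, files) should produce the same 91 box records.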
+
+@pytest.mark.parametrize("write_mode", ["parquet", "pandas", "iterable", "files"])
+@pytest.mark.parametrize("num_cpu_workers", [1, 2])
+@pytest.mark.parametrize("write_in_worker", [False, True])
+def test_write_data(
+ frozen_pipeline,
+ tmp_path,
+ change_test_dir,
+ write_mode,
+ num_cpu_workers,
+ write_in_worker,
+):
+ docs = edspdf.data.read_files("file://" + os.path.abspath("../resources"))
+ docs = docs.map_pipeline(frozen_pipeline)
+ docs = docs.set_processing(
+ num_cpu_workers=num_cpu_workers,
+ gpu_pipe_names=[],
+ batch_by="content_boxes",
+ chunk_size=3,
+ sort_chunks=True,
+ )
+ if write_mode == "parquet":
+ docs.write_parquet(
+ "file://" + str(tmp_path / "parquet" / "test.parquet"),
+ converter=box_converter,
+ write_in_worker=write_in_worker,
+ )
+ df = pd.read_parquet("file://" + str(tmp_path / "parquet" / "test.parquet"))
+ elif write_mode == "pandas":
+ if write_in_worker:
+ pytest.skip()
+ df = docs.to_pandas(converter=box_converter)
+ elif write_mode == "iterable":
+ if write_in_worker:
+ pytest.skip()
+ df = pd.DataFrame(flatten(docs.to_iterable(converter=box_converter)))
+ else:
+ if write_in_worker:
+ pytest.skip()
+ docs.write_files(
+ tmp_path / "files",
+ converter=full_file_converter,
+ )
+ records = []
+ for f in (tmp_path / "files").rglob("*.json"):
+ records.extend(json.loads(f.read_text())["annotations"])
+ df = pd.DataFrame(records)
+ assert len(df) == 91
+
+
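+# Fixture: writes the resource PDFs to a temporary parquet file reused as input by test_read_data.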
+@pytest.fixture(scope="module")
+def parquet_file(tmp_path_factory, request):
+ os.chdir(request.fspath.dirname)
+ tmp_path = tmp_path_factory.mktemp("test_input_parquet")
+ path = tmp_path / "input_test.pq"
+ docs = edspdf.data.read_files("file://" + os.path.abspath("../resources"))
+ docs.write_parquet(
+ path,
+ converter=lambda x: {
+ "content": x["content"],
+ "id": x["id"],
+ },
+ )
+ os.chdir(request.config.invocation_dir)
+ return path
+
+
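+# Counterpart of test_write_data: every read backend should yield the same 91 box records
+# once the pipeline has been applied.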
+@pytest.mark.parametrize("read_mode", ["parquet", "pandas", "iterable", "files"])
+@pytest.mark.parametrize("num_cpu_workers", [1, 2])
+@pytest.mark.parametrize("read_in_worker", [False, True])
+def test_read_data(
+ frozen_pipeline,
+ tmp_path,
+ parquet_file,
+ change_test_dir,
+ read_mode,
+ num_cpu_workers,
+ read_in_worker,
+):
+ if read_mode == "files":
+ docs = edspdf.data.read_files(
+ "file://" + os.path.abspath("../resources"),
+ converter=lambda x: PDFDoc(id=x["id"], content=x["content"]),
+ # read_in_worker=True,
+ )
+ if read_in_worker:
+ pytest.skip()
+ elif read_mode == "parquet":
+ docs = edspdf.data.read_parquet(
+ parquet_file,
+ converter=lambda x: x["content"],
+ read_in_worker=True,
+ )
+ elif read_mode == "pandas":
+ if read_in_worker:
+ pytest.skip()
+ docs = edspdf.data.from_pandas(
+ pd.read_parquet(parquet_file),
+ converter=lambda x: x["content"],
+ )
+ else:
+ if read_in_worker:
+ pytest.skip()
+ docs = edspdf.data.from_iterable(
+ f.read_bytes() for f in Path("../resources").rglob("*.pdf")
+ )
+ docs = docs.map_pipeline(frozen_pipeline)
+ docs = docs.set_processing(
+ num_cpu_workers=num_cpu_workers,
+ show_progress=True,
+ batch_by="content_boxes",
+ chunk_size=3,
+ sort_chunks=True,
+ )
+ df = docs.to_pandas(converter=box_converter)
+ assert len(df) == 91
diff --git a/tests/core/test_pipeline.py b/tests/core/test_pipeline.py
index 554ef205..637a76a4 100644
--- a/tests/core/test_pipeline.py
+++ b/tests/core/test_pipeline.py
@@ -1,6 +1,7 @@
import copy
from itertools import chain
from pathlib import Path
+from time import sleep
import datasets
import pytest
@@ -214,7 +215,7 @@ def score(golds, preds):
)
assert type(pipeline(pdf)) == PDFDoc
- assert len(list(pipeline.pipe([pdf] * 4))) == 4
+ assert len(list(pipeline.pipe([pdf] * 4).set_processing(show_progress=True))) == 4
data = list(make_segmentation_adapter(dummy_dataset)(pipeline))
with pipeline.select_pipes(enable=["classifier"]):
@@ -223,11 +224,13 @@ def score(golds, preds):
def test_cache(pipeline: Pipeline, dummy_dataset: Path, pdf: bytes):
+ from edspdf.trainable_pipe import _caches
+
pipeline(pdf)
with pipeline.cache():
pipeline(pdf)
- assert len(pipeline._cache) > 0
+ assert len(_caches["default"]) > 0
assert pipeline._cache is None
@@ -249,11 +252,11 @@ def test_different_names(pipeline: Pipeline):
extractor = PdfMinerExtractor(pipeline=pipeline, name="custom_name")
- with pytest.raises(ValueError) as exc_info:
+ with pytest.warns() as record:
pipeline.add_pipe(extractor, name="extractor")
assert "The provided name does not match the name of the component." in str(
- exc_info.value
+ record[0].message
)
@@ -335,19 +338,21 @@ def test_multiprocessing_accelerator(frozen_pipeline, pdf, letter_pdf):
def error_pipe(doc: PDFDoc):
+ sleep(0.1)
if doc.id == "pdf-3":
raise ValueError("error")
return doc
-def test_multiprocessing_gpu_stub(frozen_pipeline, pdf, letter_pdf):
+def test_deprecated_multiprocessing_gpu_stub(frozen_pipeline, pdf, letter_pdf):
edspdf.accelerators.multiprocessing.MAX_NUM_PROCESSES = 2
- accelerator = edspdf.accelerators.multiprocessing.MultiprocessingAccelerator(
- batch_size=2,
- num_gpu_workers=1,
- num_cpu_workers=1,
- gpu_worker_devices=["cpu"],
- )
+ accelerator = {
+ "@accelerator": "multiprocessing",
+ "batch_size": 2,
+ "num_gpu_workers": 1,
+ "num_cpu_workers": 1,
+ "gpu_worker_devices": ["cpu"],
+ }
list(
frozen_pipeline.pipe(
chain.from_iterable(
@@ -364,6 +369,29 @@ def test_multiprocessing_gpu_stub(frozen_pipeline, pdf, letter_pdf):
)
+def test_multiprocessing_gpu_stub(frozen_pipeline, pdf, letter_pdf):
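+ # Same scenario as the deprecated accelerator test above, rewritten with the new data / set_processing API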
+ edspdf.accelerators.multiprocessing.MAX_NUM_PROCESSES = 2
+ iterator = chain.from_iterable(
+ [
+ {"content": pdf},
+ {"content": letter_pdf},
+ ]
+ for i in range(5)
+ )
+ docs = edspdf.data.from_iterable(
+ iterator, converter=lambda x: PDFDoc(content=x["content"])
+ )
+ docs = docs.map_pipeline(frozen_pipeline)
+ docs = docs.set_processing(
+ batch_size=2,
+ num_gpu_workers=1,
+ num_cpu_workers=1,
+ gpu_worker_devices=["cpu"],
+ batch_by="content_boxes",
+ )
+ docs = list(docs.to_iterable(converter=lambda x: {"text": x.aggregated_texts}))
+
+
def test_multiprocessing_rb_error(pipeline, pdf, letter_pdf):
edspdf.accelerators.multiprocessing.MAX_NUM_PROCESSES = 2
pipeline.add_pipe(error_pipe, name="error", after="extractor")
@@ -375,7 +403,7 @@ def test_multiprocessing_rb_error(pipeline, pdf, letter_pdf):
{"content": pdf, "id": f"pdf-{i}"},
{"content": letter_pdf, "id": f"letter-{i}"},
]
- for i in range(5)
+ for i in range(200)
),
accelerator="multiprocessing",
batch_size=2,
@@ -391,13 +419,14 @@ def __init__(self, *args, **kwargs):
def preprocess(self, doc):
return {"num_boxes": len(doc.content_boxes), "doc_id": doc.id}
- def collate(self, batch, device):
+ def collate(self, batch):
return {
- "num_boxes": torch.tensor(batch["num_boxes"], device=device),
+ "num_boxes": torch.tensor(batch["num_boxes"]),
"doc_id": batch["doc_id"],
}
def forward(self, batch):
+ sleep(0.1)
if "pdf-1" in batch["doc_id"]:
raise RuntimeError("Deep learning error")
return {}
@@ -426,10 +455,14 @@ def test_multiprocessing_ml_error(pipeline, pdf, letter_pdf):
{"content": pdf, "id": f"pdf-{i}"},
{"content": letter_pdf, "id": f"letter-{i}"},
]
- for i in range(5)
+ for i in range(200)
),
accelerator=accelerator,
to_doc={"content_field": "content", "id_field": "id"},
)
)
assert "Deep learning error" in str(e.value)
+
+
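+# The pipeline should return a document without content boxes, rather than crash, on a PDF that yields no boxes.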
+def test_apply_on_empty_pdf(error_pdf, frozen_pipeline):
+ assert len(frozen_pipeline(error_pdf).content_boxes) == 0
diff --git a/tests/pipes/aggregators/test_simple.py b/tests/pipes/aggregators/test_simple.py
index f6f0eb12..cda90f65 100644
--- a/tests/pipes/aggregators/test_simple.py
+++ b/tests/pipes/aggregators/test_simple.py
@@ -63,7 +63,13 @@ def test_no_style():
def test_styled_pdfminer_aggregation(styles_pdf):
extractor = PdfMinerExtractor(extract_style=True)
- aggregator = SimpleAggregator()
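+ # label_map values may be a single source label or a list of source labels to merge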
+ aggregator = SimpleAggregator(
+ sort=True,
+ label_map={
+ "header": ["header"],
+ "body": "body",
+ },
+ )
doc = extractor(styles_pdf)
for b, label in zip(doc.text_boxes, cycle(["header", "body"])):
diff --git a/tests/pipes/embeddings/test_custom.py b/tests/pipes/embeddings/test_custom.py
index 31f5a6f3..b0571177 100644
--- a/tests/pipes/embeddings/test_custom.py
+++ b/tests/pipes/embeddings/test_custom.py
@@ -3,9 +3,10 @@
from edspdf.pipes.embeddings.embedding_combiner import EmbeddingCombiner
from edspdf.pipes.embeddings.simple_text_embedding import SimpleTextEmbedding
from edspdf.pipes.embeddings.sub_box_cnn_pooler import SubBoxCNNPooler
+from edspdf.pipes.extractors.pdfminer import PdfMinerExtractor
-def test_custom_embedding(pdfdoc, tmp_path):
+def test_custom_embedding(pdf, error_pdf, tmp_path):
embedding = BoxTransformer(
num_heads=4,
dropout_p=0.1,
@@ -35,8 +36,14 @@ def test_custom_embedding(pdfdoc, tmp_path):
),
)
str(embedding)
+
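+ # Build the test document on the fly, now that the shared pdfdoc fixture has been removed from conftest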
+ extractor = PdfMinerExtractor(render_pages=True)
+ pdfdoc = extractor(pdf)
pdfdoc.text_boxes[0].text = "Very long word of 150 letters : " + "x" * 150
embedding.post_init([pdfdoc], set())
embedding(pdfdoc)
embedding.save_extra_data(tmp_path, set())
embedding.load_extra_data(tmp_path, set())
+
+ # Test empty document
+ embedding(extractor(error_pdf))
diff --git a/tests/pipes/embeddings/test_huggingface.py b/tests/pipes/embeddings/test_huggingface.py
index e82f386d..3e619b4e 100644
--- a/tests/pipes/embeddings/test_huggingface.py
+++ b/tests/pipes/embeddings/test_huggingface.py
@@ -1,7 +1,8 @@
from edspdf.pipes.embeddings.huggingface_embedding import HuggingfaceEmbedding
+from edspdf.pipes.extractors.pdfminer import PdfMinerExtractor
-def test_huggingface_embedding(pdfdoc):
+def test_huggingface_embedding(pdf, error_pdf):
embedding = HuggingfaceEmbedding(
pipeline=None,
name="huggingface",
@@ -15,4 +16,7 @@ def test_huggingface_embedding(pdfdoc):
"height": embedding.hf_model.config.input_size,
"width": embedding.hf_model.config.input_size,
}
- embedding(pdfdoc)
+
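+ # Replace the removed pdfdoc fixture with documents rendered on the fly; error_pdf covers the empty-document case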
+ extractor = PdfMinerExtractor(render_pages=True)
+ embedding(extractor(pdf))
+ embedding(extractor(error_pdf))
diff --git a/tests/utils/test_package.py b/tests/utils/test_package.py
index 28168b12..1743db9f 100644
--- a/tests/utils/test_package.py
+++ b/tests/utils/test_package.py
@@ -5,6 +5,7 @@
import pytest
import torch
+import edspdf
from edspdf.utils.package import package
@@ -82,7 +83,6 @@ def test_package_with_files(frozen_pipeline, tmp_path, package_name):
name=package_name,
pipeline=tmp_path / "model",
root_dir=tmp_path,
- check_dependencies=True,
version="0.1.0",
distributions=None,
metadata={
@@ -133,16 +133,23 @@ def test_package_with_files(frozen_pipeline, tmp_path, package_name):
import edspdf
from pathlib import Path
+from typing import Optional, Dict, Any
__version__ = '0.1.0'
-def load(device: "torch.device" = "cpu") -> edspdf.Pipeline:
+def load(
+ overrides: Optional[Dict[str, Any]] = None,
+ device: "torch.device" = "cpu"
+) -> edspdf.Pipeline:
artifacts_path = Path(__file__).parent / "artifacts"
- model = edspdf.load(artifacts_path, device=device)
+ model = edspdf.load(artifacts_path, overrides=overrides, device=device)
return model
"""
)
+ module.load()
+ edspdf.load(module_name)
+
@pytest.fixture(scope="session", autouse=True)
def clean_after():