diff --git a/changelog.md b/changelog.md
index 200ce78a..f541b7ab 100644
--- a/changelog.md
+++ b/changelog.md
@@ -6,6 +6,7 @@
 - Add multi-modal transformers (`huggingface-embedding`) with windowing options
 - Add `render_page` option to `pdfminer` extractor, for multi-modal PDF features
+- Add inference utilities (`accelerators`), with single-process support and multi-GPU / multi-CPU support

 ### Changed

diff --git a/docs/assets/images/multiprocessing.svg b/docs/assets/images/multiprocessing.svg
new file mode 100644
index 00000000..594b0d04
--- /dev/null
+++ b/docs/assets/images/multiprocessing.svg
@@ -0,0 +1,3 @@
+
+
+[SVG diagram: inputs are dispatched to CPU workers 1-3, which run the non deep-learning ops (extractors, aggregators, feature preprocessing, feature collating, forward output postprocessing) and exchange batches tagged with batch_id, cpu_id, gpu_id and stage with GPU workers 1-2, which run the deep-learning forward pass; results are then collected as outputs]
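The batch tags shown in the diagram (`batch_id`, `cpu_id`, `gpu_id`, `stage`) come from the dispatch heuristics implemented in `edspdf/accelerators/multiprocessing.py` later in this patch. A minimal standalone sketch of those two heuristics, with made-up illustrative values:

```python
# Simplified sketch of the dispatch logic used by the multiprocessing accelerator
# (no queues or worker processes here; values are illustrative only).

gpu_worker_devices = ["cuda:0", "cuda:1"]  # 2 GPU workers
total_inputs = [3, 1, 2]   # batches sent so far to each of 3 CPU workers
total_outputs = [2, 1, 1]  # batches received back so far from each CPU worker

# A batch reaching its first deep-learning stage is routed to a GPU worker
# roughly round-robin, keyed on its batch id (as in CPUWorker._run).
batch_id = 8
gpu_idx = batch_id % len(gpu_worker_devices)

# The next input batch is sent to the least-loaded CPU worker
# (as in MultiprocessingAccelerator.__call__).
cpu_worker_indices = list(range(len(total_inputs)))
cpu_idx = min(cpu_worker_indices, key=lambda i: total_inputs[i] - total_outputs[i])

print(gpu_idx, cpu_idx)  # 0 1
```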
diff --git a/docs/assets/stylesheets/extra.css b/docs/assets/stylesheets/extra.css
index 97c90e53..e2b927df 100644
--- a/docs/assets/stylesheets/extra.css
+++ b/docs/assets/stylesheets/extra.css
@@ -155,6 +155,6 @@ body, input {
   margin-top: 1.5rem;
 }

-.references {
-
+.doc td > code {
+  word-break: normal;
 }
diff --git a/docs/inference.md b/docs/inference.md
new file mode 100644
index 00000000..d849ba86
--- /dev/null
+++ b/docs/inference.md
@@ -0,0 +1,61 @@
+# Inference
+
+Once you have obtained a pipeline, either by composing rule-based components, training a model or loading a model from the disk, you can use it to make predictions on documents. This is referred to as inference.
+
+## Inference on a single document
+
+In EDS-PDF, computing the prediction on a single document is done by calling the pipeline on the document. The input can be either:
+
+- a sequence of bytes
+- or a [PDFDoc][edspdf.structures.PDFDoc] object
+
+```python
+from pathlib import Path
+
+pipeline = ...
+content = Path("path/to/.pdf").read_bytes()
+doc = pipeline(content)
+```
+
+If you're lucky enough to have a GPU, you can use it to speed up inference by moving the model to the GPU before calling the pipeline. To leverage multiple GPUs, refer to the [multiprocessing accelerator][edspdf.accelerators.multiprocessing.MultiprocessingAccelerator] description below.
+
+```python
+pipeline.to("cuda")  # same semantics as PyTorch
+doc = pipeline(content)
+```
+
+## Inference on multiple documents
+
+When processing multiple documents, it is usually more efficient to use the `pipeline.pipe(...)` method, especially when using deep learning components, since this allows matrix multiplications to be batched together. Depending on your computational resources and requirements, EDS-PDF comes with various "accelerators" to speed up inference (see the [Accelerators](#accelerators) section for more details). By default, the `.pipe()` method uses the [`simple` accelerator][edspdf.accelerators.simple.SimpleAccelerator], but you can switch to a different one by passing the `accelerator` argument.
+
+```python
+pipeline = ...
+docs = pipeline.pipe(
+    [content1, content2, ...],
+    batch_size=16,  # optional, defaults to the one defined in the pipeline
+    accelerator=my_accelerator,
+)
+```
+
+The `pipe` method supports the following arguments:
+
+::: edspdf.pipeline.Pipeline.pipe
+    options:
+        heading_level: 3
+        only_parameters: true
+
+## Accelerators
+
+### Simple accelerator {: #edspdf.accelerators.simple.SimpleAccelerator }
+
+::: edspdf.accelerators.simple.SimpleAccelerator
+    options:
+        heading_level: 3
+        only_class_level: true
+
+### Multiprocessing accelerator {: #edspdf.accelerators.multiprocessing.MultiprocessingAccelerator }
+
+::: edspdf.accelerators.multiprocessing.MultiprocessingAccelerator
+    options:
+        heading_level: 3
+        only_class_level: true
diff --git a/docs/pipeline.md b/docs/pipeline.md
index 23b0042c..8d1d1b00 100644
--- a/docs/pipeline.md
+++ b/docs/pipeline.md
@@ -57,6 +57,8 @@ model(pdf_bytes)
 model.pipe([pdf_bytes, ...])
 ```

+For more information on how to use the pipeline, refer to the [Inference](../inference) page.
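As a usage sketch tying the new documentation to the accelerator added in this patch: the call below mirrors the configuration from the `MultiprocessingAccelerator` docstring and the `to_doc` / `from_doc` dict mappings exercised in the tests further down. The input variables (`pdf_bytes_1`, `pdf_bytes_2`), the field names (`content`, `id`) and the output attribute (`aggregated_texts`) are illustrative and assume a pipeline that ends with an aggregator.

```python
pipeline = ...  # a rule-based or trained EDS-PDF pipeline

results = list(
    pipeline.pipe(
        [{"content": pdf_bytes_1, "id": "doc-1"}, {"content": pdf_bytes_2, "id": "doc-2"}],
        accelerator={
            "@accelerator": "multiprocessing",
            "num_cpu_workers": 3,
            "num_gpu_workers": 2,
            "batch_size": 8,
        },
        to_doc={"content_field": "content", "id_field": "id"},  # dict input -> PDFDoc
        from_doc={"text": "aggregated_texts"},  # PDFDoc -> dict output
    )
)
```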
+ ## Hybrid models EDS-PDF was designed to facilitate the training and inference of hybrid models that diff --git a/edspdf/accelerators/base.py b/edspdf/accelerators/base.py index a72650a2..07bc01b2 100644 --- a/edspdf/accelerators/base.py +++ b/edspdf/accelerators/base.py @@ -25,23 +25,18 @@ def __get_validators__(cls): @classmethod def validate(cls, value, config=None): if isinstance(value, str): - return FromDictFieldsToDoc(value) - elif isinstance(value, dict): - return FromDictFieldsToDoc(**value) - elif callable(value): + value = {"content_field": value} + if isinstance(value, dict): + value = FromDictFieldsToDoc(**value) + if callable(value): return value - else: - raise TypeError( - f"Invalid entry {value} ({type(value)}) for ToDoc, " - f"expected string, a dict or a callable." - ) - + raise TypeError( + f"Invalid entry {value} ({type(value)}) for ToDoc, " + f"expected string, a dict or a callable." + ) -def identity(x): - return x - -FROM_DOC_TO_DICT_FIELDS_TEMPLATE = """\ +FROM_DOC_TO_DICT_FIELDS_TEMPLATE = """ def fn(doc): return {X} """ @@ -50,8 +45,10 @@ def fn(doc): class FromDocToDictFields: def __init__(self, mapping): self.mapping = mapping - dict_fields = ", ".join(f"{k}: doc.{v}" for k, v in mapping.items()) - self.fn = eval(FROM_DOC_TO_DICT_FIELDS_TEMPLATE.replace("X", dict_fields)) + dict_fields = ", ".join(f"{repr(k)}: doc.{v}" for k, v in mapping.items()) + local_vars = {} + exec(FROM_DOC_TO_DICT_FIELDS_TEMPLATE.replace("X", dict_fields), local_vars) + self.fn = local_vars["fn"] def __reduce__(self): return FromDocToDictFields, (self.mapping,) @@ -75,14 +72,13 @@ def __get_validators__(cls): @classmethod def validate(cls, value, config=None): if isinstance(value, dict): - return FromDocToDictFields(value) - elif callable(value): + value = FromDocToDictFields(value) + if callable(value): return value - else: - raise TypeError( - f"Invalid entry {value} ({type(value)}) for ToDoc, " - f"expected dict or callable" - ) + raise TypeError( + f"Invalid entry {value} ({type(value)}) for ToDoc, " + f"expected dict or callable" + ) class Accelerator: @@ -92,7 +88,6 @@ def __call__( model: Any, to_doc: ToDoc = FromDictFieldsToDoc("content"), from_doc: FromDoc = lambda doc: doc, - component_cfg: Dict[str, Dict[str, Any]] = None, ): raise NotImplementedError() diff --git a/edspdf/accelerators/multiprocessing.py b/edspdf/accelerators/multiprocessing.py new file mode 100644 index 00000000..6802a2ee --- /dev/null +++ b/edspdf/accelerators/multiprocessing.py @@ -0,0 +1,545 @@ +import gc +import signal +from multiprocessing.connection import wait +from random import shuffle +from typing import Any, Iterable, List, Optional, Union + +import torch +import torch.multiprocessing as mp + +from .. 
import TrainablePipe +from ..registry import registry +from ..utils.collections import batchify +from .base import Accelerator, FromDictFieldsToDoc, FromDoc, ToDoc + +DEBUG = True + +debug = ( + (lambda *args, flush=False, **kwargs: print(*args, **kwargs, flush=True)) + if DEBUG + else lambda *args, **kwargs: None +) + + +class Exchanger: + def __init__( + self, + num_stages, + num_gpu_workers, + num_cpu_workers, + gpu_worker_devices, + ): + # queue for cpu input tasks + self.gpu_worker_devices = gpu_worker_devices + # We add prioritized queue at the end for STOP signals + self.cpu_inputs_queues = [ + [mp.SimpleQueue()] + [mp.SimpleQueue() for _ in range(num_stages + 1)] + # The input queue is not shared between processes, since calling `wait` + # on a queue reader from multiple processes may lead to a deadlock + for _ in range(num_cpu_workers) + ] + self.gpu_inputs_queues = [ + [mp.SimpleQueue() for _ in range(num_stages + 1)] + for _ in range(num_gpu_workers) + ] + self.outputs_queue = mp.Queue() + + def get_cpu_tasks(self, idx): + while True: + queue_readers = wait( + [queue._reader for queue in self.cpu_inputs_queues[idx]] + ) + stage, queue = next( + (stage, q) + for stage, q in reversed(list(enumerate(self.cpu_inputs_queues[idx]))) + if q._reader in queue_readers + ) + try: + item = queue.get() + except BaseException: + continue + if item is None: + return + yield stage, item + + def put_cpu(self, item, stage, idx): + return self.cpu_inputs_queues[idx][stage].put(item) + + def get_gpu_tasks(self, idx): + while True: + queue_readers = wait( + [queue._reader for queue in self.gpu_inputs_queues[idx]] + ) + stage, queue = next( + (stage, q) + for stage, q in reversed(list(enumerate(self.gpu_inputs_queues[idx]))) + if q._reader in queue_readers + ) + try: + item = queue.get() + except BaseException: # pragma: no cover + continue + if item is None: + return + yield stage, item + + def put_gpu(self, item, stage, idx): + return self.gpu_inputs_queues[idx][stage].put(item) + + def put_results(self, items): + self.outputs_queue.put(items) + + def iter_results(self): + for out in iter(self.outputs_queue.get, None): + yield out + + +class CPUWorker(mp.Process): + def __init__( + self, + cpu_idx: int, + exchanger: Exchanger, + gpu_pipe_names: List[str], + model: Any, + device: Union[str, torch.device], + ): + super(CPUWorker, self).__init__() + + self.cpu_idx = cpu_idx + self.exchanger = exchanger + self.gpu_pipe_names = gpu_pipe_names + self.model = model + self.device = device + + def _run(self): + # Cannot pass torch tensor during init i think ? 
otherwise i get + # ValueError: bad value(s) in fds_to_keep + mp._prctl_pr_set_pdeathsig(signal.SIGINT) + + model = self.model.to(self.device) + stages = [{"cpu_components": [], "gpu_component": None}] + for name, component in model.pipeline: + if name in self.gpu_pipe_names: + stages[-1]["gpu_component"] = component + stages.append({"cpu_components": [], "gpu_component": None}) + else: + stages[-1]["cpu_components"].append(component) + + next_batch_id = 0 + active_batches = {} + debug( + f"CPU worker {self.cpu_idx} is ready", + next(model.parameters()).device, + flush=True, + ) + + had_error = False + with torch.no_grad(): + for stage, task in self.exchanger.get_cpu_tasks(self.cpu_idx): + if had_error: + continue # pragma: no cover + try: + if stage == 0: + gpu_idx = None + batch_id = next_batch_id + debug("preprocess start for", batch_id) + next_batch_id += 1 + docs = task + else: + gpu_idx, batch_id, result = task + debug("postprocess start for", batch_id) + docs = active_batches.pop(batch_id) + gpu_pipe = stages[stage - 1]["gpu_component"] + docs = gpu_pipe.postprocess(docs, result) # type: ignore + + for component in stages[stage]["cpu_components"]: + if hasattr(component, "batch_process"): + docs = component.batch_process(docs) + else: + docs = [component(doc) for doc in docs] + + gpu_pipe = stages[stage]["gpu_component"] + if gpu_pipe is not None: + preprocessed = gpu_pipe.make_batch(docs) # type: ignore + active_batches[batch_id] = docs + if gpu_idx is None: + gpu_idx = batch_id % len(self.exchanger.gpu_worker_devices) + collated = gpu_pipe.collate( # type: ignore + preprocessed, + device=self.exchanger.gpu_worker_devices[gpu_idx], + ) + self.exchanger.put_gpu( + item=(self.cpu_idx, batch_id, collated), + idx=gpu_idx, + stage=stage, + ) + batch_id += 1 + debug("preprocess end for", batch_id) + else: + self.exchanger.put_results((docs, self.cpu_idx, gpu_idx)) + debug("postprocess end for", batch_id) + except BaseException as e: + had_error = True + import traceback + + print(traceback.format_exc(), flush=True) + self.exchanger.put_results((e, self.cpu_idx, None)) + # We need to drain the queues of GPUWorker fed inputs (pre-moved to GPU) + # to ensure no tensor allocated on producer processes (CPUWorker via + # collate) are left in consumer processes + debug("Start draining CPU worker", self.cpu_idx) + [None for _ in self.exchanger.get_cpu_tasks(self.cpu_idx)] + debug(f"CPU worker {self.cpu_idx} is about to stop") + + def run(self): + self._run() + self.model = None + gc.collect() + torch.cuda.empty_cache() + + +class GPUWorker(mp.Process): + def __init__( + self, + gpu_idx, + exchanger: Exchanger, + gpu_pipe_names: List[str], + model: Any, + device: Union[str, torch.device], + ): + super().__init__() + + self.device = device + self.gpu_idx = gpu_idx + self.exchanger = exchanger + + self.gpu_pipe_names = gpu_pipe_names + self.model = model + self.device = device + + def _run(self): + debug("GPU worker", self.gpu_idx, "started") + mp._prctl_pr_set_pdeathsig(signal.SIGINT) + had_error = False + + model = self.model.to(self.device) + stage_components = [model.get_pipe(name) for name in self.gpu_pipe_names] + del model + with torch.no_grad(): + for stage, task in self.exchanger.get_gpu_tasks(self.gpu_idx): + if had_error: + continue # pragma: no cover + try: + cpu_idx, batch_id, batch = task + debug("forward start for", batch_id) + component = stage_components[stage] + res = component.module_forward(batch) + del batch, task + # TODO set non_blocking=True here + res = { + key: 
val.to("cpu") if not isinstance(val, int) else val + for key, val in res.items() + } + self.exchanger.put_cpu( + item=(self.gpu_idx, batch_id, res), + stage=stage + 1, + idx=cpu_idx, + ) + debug("forward end for", batch_id) + except BaseException as e: + had_error = True + self.exchanger.put_results((e, None, self.gpu_idx)) + import traceback + + print(traceback.format_exc(), flush=True) + task = batch = res = None # noqa + # We need to drain the queues of CPUWorker fed inputs (pre-moved to GPU) + # to ensure no tensor allocated on producer processes (CPUWorker via + # collate) are left in consumer processes + debug("Start draining GPU worker", self.gpu_idx) + [None for _ in self.exchanger.get_gpu_tasks(self.gpu_idx)] + debug(f"GPU worker {self.gpu_idx} is about to stop") + + def run(self): + self._run() + self.model = None + gc.collect() + torch.cuda.empty_cache() + + +DEFAULT_MAX_CPU_WORKERS = 4 + + +@registry.accelerator.register("multiprocessing") +class MultiprocessingAccelerator(Accelerator): + """ + If you have multiple CPU cores, and optionally multiple GPUs, we provide a + `multiprocessing` accelerator that allows to run the inference on multiple + processes. + + This accelerator dispatches the batches between multiple workers + (data-parallelism), and distribute the computation of a given batch on one or two + workers (model-parallelism). This is done by creating two types of workers: + + - a `CPUWorker` which handles the non deep-learning components and the + preprocessing, collating and postprocessing of deep-learning components + - a `GPUWorker` which handles the forward call of the deep-learning components + + The advantage of dedicating a worker to the deep-learning components is that it + allows to prepare multiple batches in parallel in multiple `CPUWorker`, and ensure + that the `GPUWorker` never wait for a batch to be ready. + + The overall architecture described in the following figure, for 3 CPU workers and 2 + GPU workers. + +
+
+    [figure: overall architecture with 3 CPU workers and 2 GPU workers, see docs/assets/images/multiprocessing.svg]
+ + Here is how a small pipeline with rule-based components and deep-learning components + is distributed between the workers: + +
+
+    [figure: distribution of a small pipeline with rule-based and deep-learning components between the workers]
+ + Examples + -------- + + ```python + docs = list( + pipeline.pipe( + [content1, content2, ...], + accelerator={ + "@accelerator": "multiprocessing", + "num_cpu_workers": 3, + "num_gpu_workers": 2, + "batch_size": 8, + }, + ) + ) + ``` + + Parameters + ---------- + batch_size: int + Number of documents to process at a time in a CPU/GPU worker + num_cpu_workers: int + Number of CPU workers. A CPU worker handles the non deep-learning components + and the preprocessing, collating and postprocessing of deep-learning components. + num_gpu_workers: Optional[int] + Number of GPU workers. A GPU worker handles the forward call of the + deep-learning components. + gpu_pipe_names: Optional[List[str]] + List of pipe names to accelerate on a GPUWorker, defaults to all pipes + that inherit from TrainablePipe + """ + + def __init__( + self, + batch_size: int, + num_cpu_workers: Optional[int] = None, + num_gpu_workers: Optional[int] = None, + gpu_pipe_names: Optional[List[str]] = None, + gpu_worker_devices: Optional[List[Union[torch.device, str]]] = None, + cpu_worker_devices: Optional[List[Union[torch.device, str]]] = None, + ): + self.batch_size = batch_size + self.num_gpu_workers: Optional[int] = num_gpu_workers + self.num_cpu_workers = num_cpu_workers + self.gpu_pipe_names = gpu_pipe_names + self.gpu_worker_devices = gpu_worker_devices + self.cpu_worker_devices = cpu_worker_devices + + def __call__( + self, + inputs: Iterable[Any], + model: Any, + to_doc: ToDoc = FromDictFieldsToDoc("content"), + from_doc: FromDoc = lambda doc: doc, + ): + """ + Stream of documents to process. Each document can be a string or a tuple + + Parameters + ---------- + inputs + model + + Yields + ------ + Any + Processed outputs of the pipeline + """ + if torch.multiprocessing.get_start_method() != "spawn": + torch.multiprocessing.set_start_method("spawn", force=True) + + gpu_pipe_names = ( + [ + name + for name, component in model.pipeline + if isinstance(component, TrainablePipe) + ] + if self.gpu_pipe_names is None + else self.gpu_pipe_names + ) + + if not all(model.has_pipe(name) for name in gpu_pipe_names): + raise ValueError( + "GPU accelerated pipes {} could not be found in the model".format( + sorted(set(model.pipe_names) - set(gpu_pipe_names)) + ) + ) + + num_devices = torch.cuda.device_count() + print(f"Number of available devices: {num_devices}", flush=True) + + num_cpu_workers = self.num_cpu_workers + num_gpu_workers = self.num_gpu_workers + + if num_gpu_workers is None: + num_gpu_workers = num_devices if len(gpu_pipe_names) > 0 else 0 + + if num_cpu_workers is None: + num_cpu_workers = max( + min(mp.cpu_count() - num_gpu_workers, DEFAULT_MAX_CPU_WORKERS), 0 + ) + + if num_gpu_workers == 0: + gpu_pipe_names = [] + + gpu_worker_devices = ( + [ + torch.device(f"cuda:{gpu_idx * num_devices // num_gpu_workers}") + for gpu_idx in range(num_gpu_workers) + ] + if self.gpu_worker_devices is None + else self.gpu_worker_devices + ) + cpu_worker_devices = ( + ["cpu"] * num_cpu_workers + if self.cpu_worker_devices is None + else self.cpu_worker_devices + ) + assert len(cpu_worker_devices) == num_cpu_workers + assert len(gpu_worker_devices) == num_gpu_workers + if num_cpu_workers == 0: + ( + num_cpu_workers, + num_gpu_workers, + cpu_worker_devices, + gpu_worker_devices, + gpu_pipe_names, + ) = (num_gpu_workers, 0, gpu_worker_devices, [], []) + + debug(f"Number of CPU workers: {num_cpu_workers}") + debug(f"Number of GPU workers: {num_gpu_workers}") + + exchanger = Exchanger( + num_stages=len(gpu_pipe_names), + 
num_cpu_workers=num_cpu_workers, + num_gpu_workers=num_gpu_workers, + gpu_worker_devices=gpu_worker_devices, + ) + + cpu_workers = [] + gpu_workers = [] + model = model.to("cpu") + + for gpu_idx in range(num_gpu_workers): + gpu_workers.append( + GPUWorker( + gpu_idx=gpu_idx, + exchanger=exchanger, + gpu_pipe_names=gpu_pipe_names, + model=model, + device=gpu_worker_devices[gpu_idx], + ) + ) + + for cpu_idx in range(num_cpu_workers): + cpu_workers.append( + CPUWorker( + cpu_idx=cpu_idx, + exchanger=exchanger, + gpu_pipe_names=gpu_pipe_names, + model=model, + device=cpu_worker_devices[cpu_idx], + ) + ) + + for worker in (*cpu_workers, *gpu_workers): + worker.start() + + try: + num_max_enqueued = num_cpu_workers * 2 + 10 + # Number of input/output batch per process + total_inputs = [0] * num_cpu_workers + total_outputs = [0] * num_cpu_workers + outputs_iterator = exchanger.iter_results() + + cpu_worker_indices = list(range(num_cpu_workers)) + inputs_iterator = (to_doc(i) for i in inputs) + for i, pdfs_batch in enumerate(batchify(inputs_iterator, self.batch_size)): + if sum(total_inputs) - sum(total_outputs) >= num_max_enqueued: + outputs, cpu_idx, gpu_idx = next(outputs_iterator) + if isinstance(outputs, BaseException): + raise outputs # pragma: no cover + yield from (from_doc(o) for o in outputs) + total_outputs[cpu_idx] += 1 + + # Shuffle to ensure the first process does not receive all the documents + # in case of total_inputs - total_outputs equality + shuffle(cpu_worker_indices) + cpu_idx = min( + cpu_worker_indices, + key=lambda i: total_inputs[i] - total_outputs[i], + ) + exchanger.put_cpu(pdfs_batch, stage=0, idx=cpu_idx) + total_inputs[cpu_idx] += 1 + + while sum(total_outputs) < sum(total_inputs): + outputs, cpu_idx, gpu_idx = next(outputs_iterator) + if isinstance(outputs, BaseException): + raise outputs # pragma: no cover + yield from (from_doc(o) for o in outputs) + total_outputs[cpu_idx] += 1 + finally: + # Send gpu and cpu process the order to stop processing data + # We use the prioritized queue to ensure the stop signal is processed + # before the next batch of data + for i, worker in enumerate(gpu_workers): + exchanger.gpu_inputs_queues[i][-1].put(None) + debug("Asked gpu worker", i, "to stop processing data") + for i, worker in enumerate(cpu_workers): + exchanger.cpu_inputs_queues[i][-1].put(None) + debug("Asked cpu worker", i, "to stop processing data") + + # Enqueue a final non prioritized STOP signal to ensure there remains no + # data in the queues (cf drain loop in CPUWorker / GPUWorker) + for i, worker in enumerate(gpu_workers): + exchanger.gpu_inputs_queues[i][0].put(None) + debug("Asked gpu", i, "to end") + for i, worker in enumerate(gpu_workers): + worker.join(timeout=5) + debug("Joined gpu worker", i) + for i, worker in enumerate(cpu_workers): + exchanger.cpu_inputs_queues[i][0].put(None) + debug("Asked cpu", i, "to end") + for i, worker in enumerate(cpu_workers): + worker.join(timeout=1) + debug("Joined cpu worker", i) + + # If a worker is still alive, kill it + # This should not happen, but for a reason I cannot explain, it does in + # some CPU workers sometimes when we catch an error, even though each run + # method of the workers completes cleanly. Maybe this has something to do + # with the cleanup of these processes ? 
+ for i, worker in enumerate(gpu_workers): # pragma: no cover + if worker.is_alive(): + print("Killing gpu worker", i) + worker.kill() + for i, worker in enumerate(cpu_workers): # pragma: no cover + if worker.is_alive(): + print("Killing cpu worker", i) + worker.kill() diff --git a/edspdf/accelerators/simple.py b/edspdf/accelerators/simple.py index 219af4c5..340dbebb 100644 --- a/edspdf/accelerators/simple.py +++ b/edspdf/accelerators/simple.py @@ -80,11 +80,6 @@ def __call__( from_doc: FromDoc = lambda doc: doc, component_cfg: Dict[str, Dict[str, Any]] = None, ): - if from_doc is None: - - def from_doc(doc): - return doc - docs = (to_doc(doc) for doc in inputs) for batch in batchify(docs, batch_size=self.batch_size): with torch.no_grad(), model.cache(), model.train(False): diff --git a/mkdocs.yml b/mkdocs.yml index d0d2b362..b15ef621 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -29,6 +29,7 @@ nav: - configuration.md - data-structures.md - trainable-pipes.md + - inference.md - Recipes: - recipes/index.md - recipes/rule-based.md diff --git a/pyproject.toml b/pyproject.toml index c5e81df5..c218fc6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,10 +122,16 @@ whitelist-regex = [] color = true omit-covered-files = false +[tool.coverage.run] +concurrency = ["multiprocessing"] + [tool.coverage.report] omit = [ - "edspdf/accelerators/multi_gpu.py", + "tests/*", ] +# omit = [ +# "edspdf/accelerators/multiprocessing.py", +# ] exclude_also = [ "def __repr__", "if __name__ == .__main__.:", diff --git a/tests/core/test_pipeline.py b/tests/core/test_pipeline.py index 80a88199..6b0fb6b4 100644 --- a/tests/core/test_pipeline.py +++ b/tests/core/test_pipeline.py @@ -1,12 +1,16 @@ +from itertools import chain from pathlib import Path import datasets import pytest +import torch from confit import Config from confit.errors import ConfitValidationError from confit.registry import validate_arguments import edspdf +import edspdf.accelerators.multiprocessing +from edspdf import TrainablePipe from edspdf.pipeline import Pipeline from edspdf.pipes.aggregators.simple import SimpleAggregator from edspdf.pipes.extractors.pdfminer import PdfMinerExtractor @@ -319,3 +323,116 @@ def test_add_pipe_validation_error(): "-> extractor.foo\n" " unexpected keyword argument" ) + + +def test_multiprocessing_accelerator(frozen_pipeline, pdf, letter_pdf): + edspdf.accelerators.multiprocessing.MAX_NUM_PROCESSES = 2 + docs = list( + frozen_pipeline.pipe( + [pdf, letter_pdf] * 20, + accelerator="multiprocessing", + batch_size=2, + ) + ) + assert len(docs) == 40 + + +def error_pipe(doc: PDFDoc): + if doc.id == "pdf-3": + raise ValueError("error") + return doc + + +def test_multiprocessing_gpu_stub(frozen_pipeline, pdf, letter_pdf): + edspdf.accelerators.multiprocessing.MAX_NUM_PROCESSES = 2 + accelerator = edspdf.accelerators.multiprocessing.MultiprocessingAccelerator( + batch_size=2, + num_gpu_workers=1, + num_cpu_workers=1, + gpu_worker_devices=["cpu"], + ) + list( + frozen_pipeline.pipe( + chain.from_iterable( + [ + {"content": pdf}, + {"content": letter_pdf}, + ] + for i in range(5) + ), + accelerator=accelerator, + to_doc="content", + from_doc={"text": "aggregated_texts"}, + ) + ) + + +def test_multiprocessing_rb_error(pipeline, pdf, letter_pdf): + edspdf.accelerators.multiprocessing.MAX_NUM_PROCESSES = 2 + pipeline.add_pipe(error_pipe, name="error", after="extractor") + with pytest.raises(ValueError): + list( + pipeline.pipe( + chain.from_iterable( + [ + {"content": pdf, "id": f"pdf-{i}"}, + {"content": letter_pdf, 
"id": f"letter-{i}"}, + ] + for i in range(5) + ), + accelerator="multiprocessing", + batch_size=2, + to_doc={"content_field": "content", "id_field": "id"}, + ) + ) + + +class DeepLearningError(TrainablePipe): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def preprocess(self, doc): + return {"num_boxes": len(doc.content_boxes), "doc_id": doc.id} + + def collate(self, batch, device): + return { + "num_boxes": torch.tensor(batch["num_boxes"], device=device), + "doc_id": batch["doc_id"], + } + + def forward(self, batch): + if "pdf-1" in batch["doc_id"]: + raise RuntimeError("Deep learning error") + return {} + + +def test_multiprocessing_ml_error(pipeline, pdf, letter_pdf): + edspdf.accelerators.multiprocessing.MAX_NUM_PROCESSES = 2 + pipeline.add_pipe( + DeepLearningError( + pipeline=pipeline, + name="error", + ), + after="extractor", + ) + accelerator = edspdf.accelerators.multiprocessing.MultiprocessingAccelerator( + batch_size=2, + num_gpu_workers=1, + num_cpu_workers=1, + gpu_worker_devices=["cpu"], + ) + with pytest.raises(RuntimeError) as e: + list( + pipeline.pipe( + chain.from_iterable( + [ + {"content": pdf, "id": f"pdf-{i}"}, + {"content": letter_pdf, "id": f"letter-{i}"}, + ] + for i in range(5) + ), + accelerator=accelerator, + to_doc={"content_field": "content", "id_field": "id"}, + ) + ) + assert "Deep learning error" in str(e.value)